66from autoPyTorch .pipeline .components .base_component import autoPyTorchComponent
77
88
9- class HyperparameterSearchSpaceUpdate ():
9+ class HyperparameterSearchSpaceUpdate :
10+ """
11+ Allows specifying update to the search space of a
12+ particular hyperparameter.
13+
14+ Args:
15+ node_name (str):
16+ The name of the node in the pipeline
17+ hyperparameter (str):
18+ The name of the hyperparameter
19+ value_range (Union[List, Tuple]):
20+ In case of categorical hyperparameter, defines the new categorical choices.
21+ In case of numerical hyperparameter, defines the new range
22+ in the form of (LOWER, UPPER)
23+ default_value (Union[int, float, str]):
24+ New default value for the hyperparameter
25+ log (bool) (default=False):
26+ In case of numerical hyperparameters, whether to sample on a log scale
27+ """
1028 def __init__ (self , node_name : str , hyperparameter : str , value_range : Union [List , Tuple ],
1129 default_value : Union [int , float , str ], log : bool = False ) -> None :
1230 self .node_name = node_name
@@ -16,6 +34,15 @@ def __init__(self, node_name: str, hyperparameter: str, value_range: Union[List,
1634 self .default_value = default_value
1735
1836 def apply (self , pipeline : List [Tuple [str , Union [autoPyTorchComponent , autoPyTorchChoice ]]]) -> None :
37+ """
38+ Applies the update to the appropriate hyperparameter of the pipeline
39+ Args:
40+ pipeline (List[Tuple[str, Union[autoPyTorchComponent, autoPyTorchChoice]]]):
41+ The named steps of the current autopytorch pipeline
42+
43+ Returns:
44+ None
45+ """
1946 [node [1 ]._apply_search_space_update (name = self .hyperparameter ,
2047 new_value_range = self .value_range ,
2148 log = self .log ,
@@ -29,30 +56,69 @@ def __str__(self) -> str:
2956 (" log" if self .log else "" ))
3057
3158
32- class HyperparameterSearchSpaceUpdates ():
59+ class HyperparameterSearchSpaceUpdates :
60+ """ Contains a collection of HyperparameterSearchSpaceUpdate """
3361 def __init__ (self , updates : Optional [List [HyperparameterSearchSpaceUpdate ]] = None ) -> None :
3462 self .updates = updates if updates is not None else []
3563
3664 def apply (self , pipeline : List [Tuple [str , Union [autoPyTorchComponent , autoPyTorchChoice ]]]) -> None :
65+ """
66+ Iteratively applies updates to the pipeline
67+
68+ Args:
69+ pipeline: (List[Tuple[str, Union[autoPyTorchComponent, autoPyTorchChoice]]]):
70+ The named steps of the current autoPyTorch pipeline
71+
72+ Returns:
73+ None
74+ """
3775 for update in self .updates :
3876 update .apply (pipeline )
3977
4078 def append (self , node_name : str , hyperparameter : str , value_range : Union [List , Tuple ],
4179 default_value : Union [int , float , str ], log : bool = False ) -> None :
80+ """
81+ Add a new update
82+
83+ Args:
84+ node_name (str):
85+ The name of the node in the pipeline
86+ hyperparameter (str):
87+ The name of the hyperparameter
88+ value_range (Union[List, Tuple]):
89+ In case of categorical hyperparameter, defines the new categorical choices.
90+ In case of numerical hyperparameter, defines the new range
91+ in the form of (LOWER, UPPER)
92+ default_value (Union[int, float, str]):
93+ New default value for the hyperparameter
94+ log (bool) (default=False):
95+ In case of numerical hyperparameters, whether to sample on a log scale
96+
97+ Returns:
98+ None
99+ """
42100 self .updates .append (HyperparameterSearchSpaceUpdate (node_name = node_name ,
43101 hyperparameter = hyperparameter ,
44102 value_range = value_range ,
45103 default_value = default_value ,
46104 log = log ))
47105
48106 def save_as_file (self , path : str ) -> None :
107+ """
108+ Save the updates as a file to reuse later
109+
110+ Args:
111+ path (str): path of the file
112+
113+ Returns:
114+ None
115+ """
49116 with open (path , "w" ) as f :
50- with open (path , "w" ) as f :
51- for update in self .updates :
52- print (update .node_name , update .hyperparameter , # noqa: T001
53- str (update .value_range ), "'{}'" .format (update .default_value )
54- if isinstance (update .default_value , str ) else update .default_value ,
55- (" log" if update .log else "" ), file = f )
117+ for update in self .updates :
118+ print (update .node_name , update .hyperparameter , # noqa: T001
119+ str (update .value_range ), "'{}'" .format (update .default_value )
120+ if isinstance (update .default_value , str ) else update .default_value ,
121+ (" log" if update .log else "" ), file = f )
56122
57123
58124def parse_hyperparameter_search_space_updates (updates_file : Optional [str ]
0 commit comments