|
12 | 12 |
|
13 | 13 | import numpy as np |
14 | 14 |
|
15 | | -import pynisher |
16 | | - |
17 | 15 | import torch |
18 | 16 | from torch.optim import Optimizer |
19 | 17 | from torch.optim.lr_scheduler import _LRScheduler |
@@ -196,37 +194,16 @@ def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchCom |
196 | 194 | ] if 'logger_port' in X else logging.handlers.DEFAULT_TCP_LOGGING_PORT, |
197 | 195 | ) |
198 | 196 |
|
199 | | - fit_function = self._fit |
200 | | - if X['use_pynisher']: |
201 | | - wall_time_in_s = X['runtime'] if 'runtime' in X else None |
202 | | - memory_limit = X['cpu_memory_limit'] if 'cpu_memory_limit' in X else None |
203 | | - fit_function = pynisher.enforce_limits( |
204 | | - wall_time_in_s=wall_time_in_s, |
205 | | - mem_in_mb=memory_limit, |
206 | | - logger=self.logger |
207 | | - )(self._fit) |
208 | | - |
209 | 197 | # Call the actual fit function. |
210 | | - state_dict = fit_function( |
| 198 | + self._fit( |
211 | 199 | X=X, |
212 | 200 | y=y, |
213 | 201 | **kwargs |
214 | 202 | ) |
215 | 203 |
|
216 | | - if X['use_pynisher']: |
217 | | - # Normally the X[network] is a pointer to the object, so at the |
218 | | - # end, when we train using X, the pipeline network is updated for free |
219 | | - # If we do multiprocessing (because of pynisher) we have to update |
220 | | - # X[network] manually. we do so in a way that every pipeline component |
221 | | - # can see this new network -- via an update, not overwrite of the pointer |
222 | | - state_dict = state_dict.result |
223 | | - X['network'].load_state_dict(state_dict) |
224 | | - |
225 | | - # TODO: when have the optimizer code, the pynisher object might have failed |
226 | | - # We should process this function as Failure if so trough fit_function.exit_status |
227 | 204 | return cast(autoPyTorchComponent, self.choice) |
228 | 205 |
|
229 | | - def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> torch.nn.Module: |
| 206 | + def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice': |
230 | 207 | """ |
231 | 208 | Fits a component by using an input dictionary with pre-requisites |
232 | 209 |
|
@@ -359,7 +336,7 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> torch.nn.Modu |
359 | 336 | # Tag as fitted |
360 | 337 | self.fitted_ = True |
361 | 338 |
|
362 | | - return X['network'].state_dict() |
| 339 | + return self |
363 | 340 |
|
364 | 341 | def early_stop_handler(self, X: Dict[str, Any]) -> bool: |
365 | 342 | """ |
@@ -444,10 +421,6 @@ def check_requirements(self, X: Dict[str, Any], y: Any = None) -> None: |
444 | 421 | raise ValueError('Need a backend to provide the working directory, ' |
445 | 422 | "yet 'backend' was not found in the fit dictionary") |
446 | 423 |
|
447 | | - # For resource allocation, we need to know if pynisher is enabled |
448 | | - if 'use_pynisher' not in X: |
449 | | - raise ValueError('To fit a Trainer, expected fit dictionary to have use_pynisher') |
450 | | - |
451 | 424 | # Whether we should evaluate metrics during training or no |
452 | 425 | if 'metrics_during_training' not in X: |
453 | 426 | raise ValueError('Missing metrics_during_training in the fit dictionary') |
|
0 commit comments