diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index bc2b004fc..00eabd72b 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -23,6 +23,24 @@ @dataclass class Model: lagged_reg_layers: Optional[List[int]] + quantiles: Optional[List[float]] = None + + def setup_quantiles(self): + # convert quantiles to empty list [] if None + if self.quantiles is None: + self.quantiles = [] + # assert quantiles is a list type + assert isinstance(self.quantiles, list), "Quantiles must be provided as list." + # check if quantiles are float values in (0, 1) + assert all( + 0 < quantile < 1 for quantile in self.quantiles + ), "The quantiles specified need to be floats in-between (0, 1)." + # sort the quantiles + self.quantiles.sort() + # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index + self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)] + # 0 is the median quantile index + self.quantiles.insert(0, 0.5) @dataclass @@ -92,9 +110,9 @@ class Train: batch_size: Optional[int] loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] optimizer: Union[str, Type[torch.optim.Optimizer]] - quantiles: List[float] = field(default_factory=list) + # quantiles: List[float] = field(default_factory=list) optimizer_args: dict = field(default_factory=dict) - scheduler: Optional[Type[torch.optim.lr_scheduler.OneCycleLR]] = None + scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None scheduler_args: dict = field(default_factory=dict) newer_samples_weight: float = 1.0 newer_samples_start: float = 0.0 @@ -104,18 +122,21 @@ class Train: n_data: int = field(init=False) loss_func_name: str = field(init=False) lr_finder_args: dict = field(default_factory=dict) + optimizer_state: dict = field(default_factory=dict) + continue_training: bool = False + trainer_config: dict = field(default_factory=dict) def __post_init__(self): - # assert the uncertainty estimation params and then finalize the quantiles - self.set_quantiles() assert self.newer_samples_weight >= 1.0 assert self.newer_samples_start >= 0.0 assert self.newer_samples_start < 1.0 - self.set_loss_func() - self.set_optimizer() - self.set_scheduler() + # self.set_loss_func(self.quantiles) - def set_loss_func(self): + # called in TimeNet configure_optimizers: + # self.set_optimizer() + # self.set_scheduler() + + def set_loss_func(self, quantiles: List[float]): if isinstance(self.loss_func, str): if self.loss_func.lower() in ["smoothl1", "smoothl1loss", "huber"]: # keeping 'huber' for backwards compatiblility, though not identical @@ -135,25 +156,8 @@ def set_loss_func(self): self.loss_func_name = type(self.loss_func).__name__ else: raise NotImplementedError(f"Loss function {self.loss_func} not found") - if len(self.quantiles) > 1: - self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=self.quantiles) - - def set_quantiles(self): - # convert quantiles to empty list [] if None - if self.quantiles is None: - self.quantiles = [] - # assert quantiles is a list type - assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar." - # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index - self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)] - # check if quantiles are float values in (0, 1) - assert all( - 0 < quantile < 1 for quantile in self.quantiles - ), "The quantiles specified need to be floats in-between (0, 1)." - # sort the quantiles - self.quantiles.sort() - # 0 is the median quantile index - self.quantiles.insert(0, 0.5) + if len(quantiles) > 1: + self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=quantiles) def set_auto_batch_epoch( self, @@ -182,26 +186,87 @@ def set_optimizer(self): """ Set the optimizer and optimizer args. If optimizer is a string, then it will be converted to the corresponding torch optimizer. The optimizer is not initialized yet as this is done in configure_optimizers in TimeNet. + + Parameters + ---------- + optimizer_name : int + Object provided to NeuralProphet as optimizer. + optimizer_args : dict + Arguments for the optimizer. + """ - self.optimizer, self.optimizer_args = utils_torch.create_optimizer_from_config( - self.optimizer, self.optimizer_args - ) + if isinstance(self.optimizer, str): + if self.optimizer.lower() == "adamw": + # Tends to overfit, but reliable + self.optimizer = torch.optim.AdamW + self.optimizer_args["weight_decay"] = 1e-3 + elif self.optimizer.lower() == "sgd": + # better validation performance, but diverges sometimes + self.optimizer = torch.optim.SGD + self.optimizer_args["momentum"] = 0.9 + self.optimizer_args["weight_decay"] = 1e-4 + else: + raise ValueError( + f"The optimizer name {self.optimizer} is not supported. Please pass the optimizer class." + ) + elif not issubclass(self.optimizer, torch.optim.Optimizer): + raise ValueError("The provided optimizer is not supported.") def set_scheduler(self): """ - Set the scheduler and scheduler args. + Set the scheduler and scheduler arg depending on the user selection. The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet. """ - self.scheduler = torch.optim.lr_scheduler.OneCycleLR - self.scheduler_args.update( - { - "pct_start": 0.3, - "anneal_strategy": "cos", - "div_factor": 10.0, - "final_div_factor": 10.0, - "three_phase": True, - } - ) + if self.continue_training: + if (isinstance(self.scheduler, str) and self.scheduler.lower() == "onecyclelr") or isinstance( + self.scheduler, torch.optim.lr_scheduler.OneCycleLR + ): + log.warning( + "OneCycleLR scheduler is not supported for continued training. Please set another scheduler. Falling back to ExponentialLR scheduler" + ) + self.scheduler = "exponentiallr" + + if self.scheduler is None: + log.warning("No scheduler specified. Falling back to ExponentialLR scheduler.") + self.scheduler = "exponentiallr" + + if isinstance(self.scheduler, str): + if self.scheduler.lower() == "onecyclelr": + self.scheduler = torch.optim.lr_scheduler.OneCycleLR + defaults = { + "pct_start": 0.3, + "anneal_strategy": "cos", + "div_factor": 10.0, + "final_div_factor": 10.0, + "three_phase": True, + } + elif self.scheduler.lower() == "steplr": + self.scheduler = torch.optim.lr_scheduler.StepLR + defaults = { + "step_size": 10, + "gamma": 0.1, + } + elif self.scheduler.lower() == "exponentiallr": + self.scheduler = torch.optim.lr_scheduler.ExponentialLR + defaults = { + "gamma": 0.95, + } + elif self.scheduler.lower() == "cosineannealinglr": + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR + defaults = { + "T_max": 50, + } + else: + raise NotImplementedError( + f"Scheduler {self.scheduler} is not supported from string. Please pass the scheduler class." + ) + if self.scheduler_args is not None: + defaults.update(self.scheduler_args) + self.scheduler_args = defaults + else: + assert issubclass( + self.scheduler, torch.optim.lr_scheduler.LRScheduler + ), "Scheduler must be a subclass of torch.optim.lr_scheduler.LRScheduler" def set_lr_finder_args(self, dataset_size, num_batches): """ @@ -239,6 +304,9 @@ def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, re delay_weight = 1 return delay_weight + def set_optimizer_state(self, optimizer_state: dict): + self.optimizer_state = optimizer_state + @dataclass class Trend: diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 85939955e..5c1b6d9cd 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -1,7 +1,9 @@ import logging +import math import os import time from collections import OrderedDict +from dataclasses import dataclass, field from typing import Callable, List, Optional, Tuple, Type, Union import matplotlib @@ -299,6 +301,21 @@ class NeuralProphet: >>> # use custorm torchmetrics names >>> m = NeuralProphet(collect_metrics={"MAPE": "MeanAbsolutePercentageError", "MSLE": "MeanSquaredLogError", + scheduler : str, torch.optim.lr_scheduler._LRScheduler + Type of learning rate scheduler to use. + + Options + * (default) ``OneCycleLR``: One Cycle Learning Rate scheduler + * ``StepLR``: Step Learning Rate scheduler + * ``ExponentialLR``: Exponential Learning Rate scheduler + * ``CosineAnnealingLR``: Cosine Annealing Learning Rate scheduler + + Examples + -------- + >>> from neuralprophet import NeuralProphet + >>> # Step Learning Rate scheduler + >>> m = NeuralProphet(scheduler="StepLR") + COMMENT Uncertainty Estimation COMMENT @@ -432,9 +449,11 @@ def __init__( batch_size: Optional[int] = None, loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] = "SmoothL1Loss", optimizer: Union[str, Type[torch.optim.Optimizer]] = "AdamW", + scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = "onecyclelr", + scheduler_args: Optional[dict] = None, newer_samples_weight: float = 2, newer_samples_start: float = 0.0, - quantiles: List[float] = [], + quantiles: Optional[List[float]] = None, impute_missing: bool = True, impute_linear: int = 10, impute_rolling: int = 10, @@ -445,7 +464,7 @@ def __init__( global_time_normalization: bool = True, unknown_data_normalization: bool = False, accelerator: Optional[str] = None, - trainer_config: dict = {}, + trainer_config: Optional[dict] = None, prediction_frequency: Optional[dict] = None, ): self.config = locals() @@ -487,7 +506,11 @@ def __init__( self.max_lags = self.n_lags # Model - self.config_model = configure.Model(lagged_reg_layers=lagged_reg_layers) + self.config_model = configure.Model( + lagged_reg_layers=lagged_reg_layers, + quantiles=quantiles, + ) + self.config_model.setup_quantiles() # Trend self.config_trend = configure.Trend( @@ -502,17 +525,17 @@ def __init__( ) # Training - self.config_train = configure.Train( - quantiles=quantiles, - learning_rate=learning_rate, - epochs=epochs, - batch_size=batch_size, - loss_func=loss_func, - optimizer=optimizer, - newer_samples_weight=newer_samples_weight, - newer_samples_start=newer_samples_start, - trend_reg_threshold=self.config_trend.trend_reg_threshold, - ) + self.learning_rate = learning_rate + self.scheduler = scheduler + self.scheduler_args = scheduler_args + self.epochs = epochs + self.batch_size = batch_size + self.loss_func = loss_func + self.optimizer = optimizer + self.newer_samples_weight = newer_samples_weight + self.newer_samples_start = newer_samples_start + self.trend_reg_threshold = self.config_trend.trend_reg_threshold + self.continue_training = False # Seasonality self.config_seasonality = configure.ConfigSeasonality( @@ -550,7 +573,7 @@ def __init__( # Pytorch Lightning Trainer self.metrics_logger = MetricsLogger(save_dir=os.getcwd()) self.accelerator = accelerator - self.trainer_config = trainer_config + self.trainer_config = trainer_config if trainer_config is not None else {} # set during prediction self.future_periods = None @@ -916,6 +939,9 @@ def fit( continue_training: bool = False, num_workers: int = 0, deterministic: bool = False, + scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None, + scheduler_args: Optional[dict] = None, + trainer_config: Optional[dict] = None, ): """Train, and potentially evaluate model. @@ -967,31 +993,57 @@ def fit( Note: using multiple workers and therefore distributed training might significantly increase the training time since each batch needs to be copied to each worker for each epoch. Keeping all data on the main process might be faster for most datasets. + scheduler : str + Type of learning rate scheduler to use for continued training. If None, uses ExponentialLR as + default as specified in the model config. + Options + * ``StepLR``: Step Learning Rate scheduler + * ``ExponentialLR``: Exponential Learning Rate scheduler + * ``CosineAnnealingLR``: Cosine Annealing Learning Rate scheduler Returns ------- pd.DataFrame metrics with training and potentially evaluation metrics """ - if self.fitted: - raise RuntimeError("Model has been fitted already. Please initialize a new model to fit again.") + if minimal: + # overrides these settings: + checkpointing = False + self.metrics = False + progress = None - # Configuration - if epochs is not None: - self.config_train.epochs = epochs + if self.fitted and not continue_training: + raise RuntimeError( + "Model has been fitted already. If you want to continue training please set the flag continue_training." + ) - if batch_size is not None: - self.config_train.batch_size = batch_size + if continue_training: + if epochs is None: + raise ValueError("Continued training requires setting the number of epochs to train for.") - if learning_rate is not None: - self.config_train.learning_rate = learning_rate + if continue_training and self.metrics_logger.checkpoint_path is None: + log.error("Continued training requires checkpointing in model to continue from last epoch.") + + # Configuration + self.config_train = configure.Train( + learning_rate=self.learning_rate if learning_rate is None else learning_rate, + scheduler=self.scheduler if scheduler is None else scheduler, + scheduler_args=self.scheduler_args if scheduler is None else scheduler_args, + epochs=self.epochs if epochs is None else epochs, + batch_size=self.batch_size if batch_size is None else batch_size, + loss_func=self.loss_func, + optimizer=self.optimizer, + newer_samples_weight=self.newer_samples_weight, + newer_samples_start=self.newer_samples_start, + trend_reg_threshold=self.config_trend.trend_reg_threshold, + continue_training=continue_training, + trainer_config=self.trainer_config if trainer_config is None else trainer_config, + ) + self.config_train.set_loss_func(quantiles=self.config_model.quantiles) if early_stopping is not None: self.early_stopping = early_stopping - if metrics is not None: - self.metrics = utils_metrics.get_metrics(metrics) - # Warnings if early_stopping: reg_enabled = utils.check_for_regularization( @@ -1012,18 +1064,16 @@ def fit( number of epochs to train for." ) - if progress == "plot" and metrics is False: - log.info("Progress plot requires metrics to be enabled. Enabling the default metrics.") - metrics = utils_metrics.get_metrics(True) + if metrics: + self.metrics = utils_metrics.get_metrics(metrics) + + if progress == "plot" and not metrics: + log.info("Progress plot requires metrics to be enabled. Disabling progress plot.") + progress = None if not self.config_normalization.global_normalization: log.info("When Global modeling with local normalization, metrics are displayed in normalized scale.") - if minimal: - checkpointing = False - self.metrics = False - progress = None - # Pre-processing # Copy df and save list of unique time series IDs (the latter for global-local modelling if enabled) df, _, _, self.id_list = df_utils.prep_or_copy_df(df) @@ -1060,8 +1110,6 @@ def fit( or any(value != 1 for value in self.num_seasonalities_modelled_dict.values()) ) - if self.fitted is True and not continue_training: - log.error("Model has already been fitted. Re-fitting may break or produce different results.") self.max_lags = df_utils.get_max_num_lags( n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors ) @@ -1191,7 +1239,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a dates=dates, predicted=predicted, n_forecasts=self.n_forecasts, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, components=components, ) if auto_extend and periods_added[df_name] > 0: @@ -1206,7 +1254,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a n_forecasts=self.n_forecasts, max_lags=self.max_lags, freq=self.data_freq, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, config_lagged_regressors=self.config_lagged_regressors, ) if auto_extend and periods_added[df_name] > 0: @@ -1847,7 +1895,7 @@ def predict_trend(self, df: pd.DataFrame, quantile: float = 0.5): else: meta_name_tensor = None - quantile_index = self.config_train.quantiles.index(quantile) + quantile_index = self.config_model.quantiles.index(quantile) trend = self.model.trend(t, meta_name_tensor).detach().numpy()[:, :, quantile_index].squeeze() data_params = self.config_normalization.get_data_params(df_name) @@ -1912,7 +1960,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5): for name in self.config_seasonality.periods: features = inputs["seasonalities"][name] - quantile_index = self.config_train.quantiles.index(quantile) + quantile_index = self.config_model.quantiles.index(quantile) y_season = torch.squeeze( self.model.seasonality.compute_fourier(features=features, name=name, meta=meta_name_tensor)[ :, :, quantile_index @@ -2044,7 +2092,7 @@ def plot( log.info(f"Plotting data from ID {df_name}") if forecast_in_focus is None: forecast_in_focus = self.highlight_forecast_step_n - if len(self.config_train.quantiles) > 1: + if len(self.config_model.quantiles) > 1: if (self.highlight_forecast_step_n) is None and ( self.n_forecasts > 1 or self.n_lags > 0 ): # rather query if n_forecasts >1 than n_lags>1 @@ -2084,7 +2132,7 @@ def plot( if plotting_backend.startswith("plotly"): return plot_plotly( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, xlabel=xlabel, ylabel=ylabel, figsize=tuple(x * 70 for x in figsize), @@ -2095,7 +2143,7 @@ def plot( else: return plot( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, ax=ax, xlabel=xlabel, ylabel=ylabel, @@ -2164,7 +2212,7 @@ def get_latest_forecast( elif include_history_data is True: fcst = fcst fcst = utils.fcst_df_to_latest_forecast( - fcst, self.config_train.quantiles, n_last=1 + include_previous_forecasts + fcst, self.config_model.quantiles, n_last=1 + include_previous_forecasts ) return fcst @@ -2233,7 +2281,7 @@ def plot_latest_forecast( else: fcst = fcst[fcst["ID"] == df_name].copy(deep=True) log.info(f"Plotting data from ID {df_name}") - if len(self.config_train.quantiles) > 1: + if len(self.config_model.quantiles) > 1: log.warning( "Plotting latest forecasts when uncertainty estimation enabled" " plots only the median quantile forecasts." @@ -2245,7 +2293,7 @@ def plot_latest_forecast( elif plot_history_data is True: fcst = fcst fcst = utils.fcst_df_to_latest_forecast( - fcst, self.config_train.quantiles, n_last=1 + include_previous_forecasts + fcst, self.config_model.quantiles, n_last=1 + include_previous_forecasts ) # Check whether a local or global plotting backend is set. @@ -2255,7 +2303,7 @@ def plot_latest_forecast( if plotting_backend.startswith("plotly"): return plot_plotly( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, ylabel=ylabel, xlabel=xlabel, figsize=tuple(x * 70 for x in figsize), @@ -2267,7 +2315,7 @@ def plot_latest_forecast( else: return plot( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, ax=ax, ylabel=ylabel, xlabel=xlabel, @@ -2433,7 +2481,7 @@ def plot_components( m=self, fcst=fcst, plot_configuration=valid_plot_configuration, - quantile=self.config_train.quantiles[0], # plot components only for median quantile + quantile=self.config_model.quantiles[0], # plot components only for median quantile figsize=figsize, df_name=df_name, one_period_per_season=one_period_per_season, @@ -2543,11 +2591,11 @@ def plot_parameters( if not (0 < quantile < 1): raise ValueError("The quantile selected needs to be a float in-between (0,1)") # ValueError if selected quantile is out of range - if quantile not in self.config_train.quantiles: + if quantile not in self.config_model.quantiles: raise ValueError("Selected quantile is not specified in the model configuration.") else: # plot parameters for median quantile if not specified - quantile = self.config_train.quantiles[0] + quantile = self.config_model.quantiles[0] # Validate components to be plotted valid_parameters_set = [ @@ -2615,13 +2663,9 @@ def plot_parameters( ) def _init_model(self): - """Build Pytorch model with configured hyperparamters. - - Returns - ------- - TimeNet model - """ + """Build Pytorch model with configured hyperparamters.""" self.model = time_net.TimeNet( + config_model=self.config_model, config_train=self.config_train, config_trend=self.config_trend, config_ar=self.config_ar, @@ -2644,7 +2688,6 @@ def _init_model(self): meta_used_in_model=self.meta_used_in_model, ) log.debug(self.model) - return self.model def _init_train_loader(self, df, num_workers=0): """Executes data preparation steps and initiates training procedure. @@ -2661,23 +2704,23 @@ def _init_train_loader(self, df, num_workers=0): torch DataLoader """ df, _, _, _ = df_utils.prep_or_copy_df(df) # TODO: Can this call be avoided? - # if not self.fitted: - self.config_normalization.init_data_params( - df=df, - config_lagged_regressors=self.config_lagged_regressors, - config_regressors=self.config_regressors, - config_events=self.config_events, - config_seasonality=self.config_seasonality, - ) + if not self.fitted: + self.config_normalization.init_data_params( + df=df, + config_lagged_regressors=self.config_lagged_regressors, + config_regressors=self.config_regressors, + config_events=self.config_events, + config_seasonality=self.config_seasonality, + ) df = _normalize(df=df, config_normalization=self.config_normalization) - # if not self.fitted: - if self.config_trend.changepoints is not None: - # scale user-specified changepoint times - df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)}) + if not self.fitted: + if self.config_trend.changepoints is not None: + # scale user-specified changepoint times + df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)}) - df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization) - self.config_trend.changepoints = df_normalized["t"].values # type: ignore + df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization) + self.config_trend.changepoints = df_normalized["t"].values # type: ignore # df_merged, _ = df_utils.join_dataframes(df) # df_merged = df_merged.sort_values("ds") @@ -2765,21 +2808,33 @@ def _train( # Internal flag to check if validation is enabled validation_enabled = df_val is not None - # Init the model, if not continue from checkpoint + # Load model and optimizer state from checkpoint if continue_training is True if continue_training: - raise NotImplementedError( - "Continuing training from checkpoint is not implemented yet. This feature is planned for one of the \ - upcoming releases." - ) + checkpoint_path = self.metrics_logger.checkpoint_path + checkpoint = torch.load(checkpoint_path) + + checkpoint_epoch = checkpoint["epoch"] if "epoch" in checkpoint else 0 + previous_epoch = max(self.model.current_epoch, checkpoint_epoch) + + # Set continue_training flag in model to update scheduler correctly + self.model.continue_training = True + self.model.start_epoch = previous_epoch + + # Adjust epochs + new_total_epochs = previous_epoch + self.config_train.epochs + self.config_train.epochs = new_total_epochs + + self.config_train.set_optimizer_state(checkpoint["optimizer_states"][0]) + else: - self.model = self._init_model() + self._init_model() self.model.train_loader = train_loader # Init the Trainer self.trainer, checkpoint_callback = utils.configure_trainer( config_train=self.config_train, - config=self.trainer_config, + config=self.config_train.trainer_config, metrics_logger=self.metrics_logger, early_stopping=self.early_stopping, early_stopping_target="Loss_val" if validation_enabled else "Loss", @@ -2852,8 +2907,18 @@ def _train( if not metrics_enabled: return None + # Return metrics collected in logger as dataframe - metrics_df = pd.DataFrame(self.metrics_logger.history) + if self.metrics_logger.history is not None: + # avoid array mismatch when continuing training + history = self.metrics_logger.history + max_length = max(len(lst) for lst in history.values()) + for key in history: + while len(history[key]) < max_length: + history[key].append(None) + metrics_df = pd.DataFrame(history) + else: + metrics_df = pd.DataFrame() return metrics_df def restore_trainer(self, accelerator: Optional[str] = None): @@ -2867,7 +2932,7 @@ def restore_trainer(self, accelerator: Optional[str] = None): """ self.trainer, _ = utils.configure_trainer( config_train=self.config_train, - config=self.trainer_config, + config=self.config_train.trainer_config, metrics_logger=self.metrics_logger, early_stopping=self.early_stopping, accelerator=accelerator, @@ -3072,7 +3137,7 @@ def conformal_predict( alpha=alpha, method=method, n_forecasts=self.n_forecasts, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, ) df_forecast = c.predict(df=df_test, df_cal=df_cal, show_all_PI=show_all_PI) diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index a4fbfee3a..8f847d56d 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -42,6 +42,7 @@ class TimeNet(pl.LightningModule): def __init__( self, + config_model: configure.Model, config_seasonality: configure.ConfigSeasonality, config_train: Optional[configure.Train] = None, config_trend: Optional[configure.Trend] = None, @@ -63,6 +64,8 @@ def __init__( num_seasonalities_modelled: int = 1, num_seasonalities_modelled_dict: dict = None, meta_used_in_model: bool = False, + continue_training: bool = False, + start_epoch: int = 0, ): """ Parameters @@ -149,6 +152,7 @@ def __init__( pass # General + self.config_model = config_model self.n_forecasts = n_forecasts # Lightning Config @@ -156,9 +160,16 @@ def __init__( self.config_normalization = config_normalization self.compute_components_flag = compute_components_flag + # Continued training + self.continue_training = continue_training + self.start_epoch = start_epoch + # Optimizer and LR Scheduler - self._optimizer = self.config_train.optimizer - self._scheduler = self.config_train.scheduler + # self.config_train.set_optimizer() + # self.config_train.set_scheduler() + # self._optimizer = self.config_train.optimizer + # self._scheduler = self.config_train.scheduler + # Manual optimization: we are responsible for calling .backward(), .step(), .zero_grad(). self.automatic_optimization = False # Hyperparameters (can be tuned using trainer.tune()) @@ -200,7 +211,7 @@ def __init__( ) # Quantiles - self.quantiles = self.config_train.quantiles + self.quantiles = self.config_model.quantiles # Trend self.config_trend = config_trend @@ -861,16 +872,36 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): return prediction, components def configure_optimizers(self): + self.config_train.set_optimizer() + self.config_train.set_scheduler() + self._optimizer = self.config_train.optimizer + self._scheduler = self.config_train.scheduler + # Optimizer optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args) + if self.continue_training: + optimizer.load_state_dict(self.config_train.optimizer_state) + + # Update initial learning rate to the last learning rate for continued training + last_lr = float(optimizer.param_groups[0]["lr"]) # Ensure it's a float + + for param_group in optimizer.param_groups: + param_group["initial_lr"] = (last_lr,) + # Scheduler - lr_scheduler = self._scheduler( - optimizer, - max_lr=self.learning_rate, - total_steps=self.trainer.estimated_stepping_batches, - **self.config_train.scheduler_args, - ) + if self._scheduler == torch.optim.lr_scheduler.OneCycleLR: + lr_scheduler = self._scheduler( + optimizer, + max_lr=self.learning_rate, + total_steps=self.trainer.estimated_stepping_batches, + **self.config_train.scheduler_args, + ) + else: + lr_scheduler = self._scheduler( + optimizer, + **self.config_train.scheduler_args, + ) return {"optimizer": optimizer, "lr_scheduler": lr_scheduler} diff --git a/tests/test_configure.py b/tests/test_configure.py index e5c5e9800..a93539e29 100644 --- a/tests/test_configure.py +++ b/tests/test_configure.py @@ -1,20 +1,6 @@ import pytest -from neuralprophet.configure import Train - - -def generate_config_train_params(overrides={}): - config_train_params = { - "quantiles": None, - "learning_rate": None, - "epochs": None, - "batch_size": None, - "loss_func": "SmoothL1Loss", - "optimizer": "AdamW", - } - for key, value in overrides.items(): - config_train_params[key] = value - return config_train_params +from neuralprophet import NeuralProphet def test_config_training_quantiles(): @@ -26,24 +12,21 @@ def test_config_training_quantiles(): ({"quantiles": [0.2, 0.8]}, [0.5, 0.2, 0.8]), ({"quantiles": [0.5, 0.8]}, [0.5, 0.8]), ] - for overrides, expected in checks: - config_train_params = generate_config_train_params(overrides) - config = Train(**config_train_params) - assert config.quantiles == expected + model = NeuralProphet(**overrides) + assert model.config_model.quantiles == expected def test_config_training_quantiles_error_invalid_type(): - config_train_params = generate_config_train_params() - config_train_params["quantiles"] = "hello world" with pytest.raises(AssertionError) as err: - Train(**config_train_params) - assert str(err.value) == "Quantiles must be in a list format, not None or scalar." + _ = NeuralProphet(quantiles="hello world") + assert str(err.value) == "Quantiles must be provided as list." def test_config_training_quantiles_error_invalid_scale(): - config_train_params = generate_config_train_params() - config_train_params["quantiles"] = [-1] with pytest.raises(Exception) as err: - Train(**config_train_params) + _ = NeuralProphet(quantiles=[-1]) + assert str(err.value) == "The quantiles specified need to be floats in-between (0, 1)." + with pytest.raises(Exception) as err: + _ = NeuralProphet(quantiles=[1.3]) assert str(err.value) == "The quantiles specified need to be floats in-between (0, 1)." diff --git a/tests/test_train_config.py b/tests/test_train_config.py new file mode 100644 index 000000000..95716365e --- /dev/null +++ b/tests/test_train_config.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +import io +import logging +import os +import pathlib + +import pandas as pd +import pytest + +from neuralprophet import NeuralProphet, df_utils, load, save + +log = logging.getLogger("NP.test") +log.setLevel("ERROR") +log.parent.setLevel("ERROR") + +DIR = pathlib.Path(__file__).parent.parent.absolute() +DATA_DIR = os.path.join(DIR, "tests", "test-data") +PEYTON_FILE = os.path.join(DATA_DIR, "wp_log_peyton_manning.csv") +AIR_FILE = os.path.join(DATA_DIR, "air_passengers.csv") +YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv") +NROWS = 512 +EPOCHS = 10 +ADDITIONAL_EPOCHS = 5 +LR = 1.0 +BATCH_SIZE = 64 + +PLOT = False + + +def generate_config_train_params(overrides={}): + config_train_params = { + "learning_rate": None, + "epochs": None, + "batch_size": None, + "loss_func": "SmoothL1Loss", + "optimizer": "AdamW", + } + for key, value in overrides.items(): + config_train_params[key] = value + return config_train_params + + +def test_continue_training(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + n_lags=6, + n_forecasts=3, + n_changepoints=0, + ) + metrics = m.fit(df, checkpointing=True, freq="D") + metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS) + assert metrics["Loss"].min() >= metrics2["Loss"].min() + + +def test_continue_training_with_scheduler_selection(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + n_lags=6, + n_forecasts=3, + n_changepoints=0, + ) + metrics = m.fit(df, checkpointing=True, freq="D") + # Continue training with StepLR + metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") + assert metrics["Loss"].min() >= metrics2["Loss"].min() + + +def test_save_load_continue_training(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet( + epochs=EPOCHS, + n_lags=6, + n_forecasts=3, + n_changepoints=0, + ) + metrics = m.fit(df, checkpointing=True, freq="D") + save(m, "test_model.pt") + m2 = load("test_model.pt") + metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") + assert metrics["Loss"].min() >= metrics2["Loss"].min() diff --git a/tests/test_utils.py b/tests/test_utils.py index a327f3122..8bed33192 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -21,6 +21,7 @@ YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv") NROWS = 512 EPOCHS = 10 +ADDITIONAL_EPOCHS = 5 LR = 1.0 BATCH_SIZE = 64 @@ -99,19 +100,3 @@ def test_save_load_io(): # Check that the forecasts are the same pd.testing.assert_frame_equal(forecast, forecast2) pd.testing.assert_frame_equal(forecast, forecast3) - - -# TODO: add functionality to continue training -# def test_continue_training(): -# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) -# m = NeuralProphet( -# epochs=EPOCHS, -# batch_size=BATCH_SIZE, -# learning_rate=LR, -# n_lags=6, -# n_forecasts=3, -# n_changepoints=0, -# ) -# metrics = m.fit(df, freq="D") -# metrics2 = m.fit(df, freq="D", continue_training=True) -# assert metrics1["Loss"].sum() >= metrics2["Loss"].sum()