@@ -574,6 +574,10 @@ def forward(self,
574574 past_observed_targets : Optional [torch .BoolTensor ] = None ,
575575 decoder_observed_values : Optional [torch .Tensor ] = None ,
576576 ) -> ALL_NET_OUTPUT :
577+
578+ if isinstance (past_targets , dict ):
579+ past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
580+
577581 x_past , x_future , x_static , loc , scale , static_context_initial_hidden , _ = self .pre_processing (
578582 past_targets = past_targets ,
579583 past_observed_targets = past_observed_targets ,
@@ -603,6 +607,38 @@ def forward(self,
603607
604608 return self .rescale_output (output , loc , scale , self .device )
605609
610+ def _unwrap_past_targets (
611+ self ,
612+ past_targets : dict
613+ ) -> Tuple [
614+ torch .Tensor ,
615+ Optional [torch .Tensor ],
616+ Optional [torch .Tensor ],
617+ Optional [torch .Tensor ],
618+ Optional [torch .BoolTensor ],
619+ Optional [torch .Tensor ]]:
620+ """
621+ Time series forecasting network requires multiple inputs for the forward pass which is different to how pytorch
622+ networks usually work. SWA's update_bn in line #452 of trainer choice, does not unwrap the dictionary of the
623+ input when running the forward pass. So we need to check for that here.
624+
625+ Args:
626+ past_targets (dict):
627+ Input mistakenly passed to past_targets variable
628+
629+ Returns:
630+ _type_: _description_
631+ """
632+
633+ past_targets_copy = past_targets .copy ()
634+ past_targets = past_targets_copy .pop ('past_targets' )
635+ future_targets = past_targets_copy .pop ('future_targets' , None )
636+ past_features = past_targets_copy .pop ('past_features' , None )
637+ future_features = past_targets_copy .pop ('future_features' , None )
638+ past_observed_targets = past_targets_copy .pop ('past_observed_targets' , None )
639+ decoder_observed_values = past_targets_copy .pop ('decoder_observed_values' , None )
640+ return past_targets ,past_features ,future_features ,past_observed_targets
641+
606642 def pred_from_net_output (self , net_output : ALL_NET_OUTPUT ) -> torch .Tensor :
607643 if self .output_type == 'regression' :
608644 return net_output
@@ -694,6 +730,10 @@ def forward(self,
694730 future_features : Optional [torch .Tensor ] = None ,
695731 past_observed_targets : Optional [torch .BoolTensor ] = None ,
696732 decoder_observed_values : Optional [torch .Tensor ] = None , ) -> ALL_NET_OUTPUT :
733+
734+ if isinstance (past_targets , dict ):
735+ past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
736+
697737 x_past , _ , x_static , loc , scale , static_context_initial_hidden , past_targets = self .pre_processing (
698738 past_targets = past_targets ,
699739 past_observed_targets = past_observed_targets ,
@@ -983,6 +1023,10 @@ def forward(self,
9831023 future_features : Optional [torch .Tensor ] = None ,
9841024 past_observed_targets : Optional [torch .BoolTensor ] = None ,
9851025 decoder_observed_values : Optional [torch .Tensor ] = None , ) -> ALL_NET_OUTPUT :
1026+
1027+ if isinstance (past_targets , dict ):
1028+ past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
1029+
9861030 encode_length = min (self .window_size , past_targets .shape [1 ])
9871031
9881032 if past_observed_targets is None :
@@ -1250,6 +1294,9 @@ def forward(self, # type: ignore[override]
12501294 decoder_observed_values : Optional [torch .Tensor ] = None , ) -> Union [torch .Tensor ,
12511295 Tuple [torch .Tensor , torch .Tensor ]]:
12521296
1297+ if isinstance (past_targets , dict ):
1298+ past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
1299+
12531300 # Unlike other networks, NBEATS network is required to predict both past and future targets.
12541301 # Thereby, we return two tensors for backcast and forecast
12551302 if past_observed_targets is None :
0 commit comments