@@ -297,7 +297,7 @@ namespace dd
297297 }
298298
299299 NBeats (const CSVTSTorchInputFileConn &inputc,
300- std::vector<std::string> stackdef,
300+ std::vector<std::string> stackdef, double backcast_loss_coef = 1 ,
301301 std::vector<BlockType> stackTypes = NBEATS_DEFAULT_STACK_TYPES,
302302 int nb_blocks_per_stack = NBEATS_DEFAULT_NB_BLOCKS,
303303 int data_size = NBEATS_DEFAULT_DATA_SIZE,
@@ -311,7 +311,8 @@ namespace dd
311311 _hidden_layer_units(hidden_layer_units),
312312 _nb_blocks_per_stack(nb_blocks_per_stack),
313313 _share_weights_in_stack(share_weights_in_stack),
314- _stack_types(stackTypes), _thetas_dims(thetas_dims)
314+ _stack_types(stackTypes), _thetas_dims(thetas_dims),
315+ _backcast_loss_coef(backcast_loss_coef)
315316 {
316317 parse_stackdef (stackdef);
317318 update_params (inputc);
@@ -413,12 +414,14 @@ namespace dd
413414 torch::Tensor y_pred = torch::slice (output, 1 , _backcast_length,
414415 _backcast_length + _forecast_length);
415416 torch::Tensor input_zeros = torch::zeros_like (input_real);
417+
416418 if (loss.empty () || loss == " L1" || loss == " l1" )
417419 return torch::l1_loss (y_pred, target)
418- + torch::l1_loss (x_pred, input_zeros);
420+ + torch::l1_loss (x_pred, input_zeros) * _backcast_loss_coef ;
419421 if (loss == " L2" || loss == " l2" || loss == " eucl" )
420422 return torch::mse_loss (y_pred, target)
421- + torch::mse_loss (x_pred, input_zeros);
423+ + torch::mse_loss (x_pred, input_zeros) * _backcast_loss_coef;
424+
422425 throw MLLibBadParamException (" unknown loss " + loss);
423426 }
424427
@@ -435,6 +438,8 @@ namespace dd
435438 bool _share_weights_in_stack = NBEATS_DEFAULT_SHARE_WEIGHTS;
436439 std::vector<BlockType> _stack_types = NBEATS_DEFAULT_STACK_TYPES;
437440 std::vector<int > _thetas_dims = NBEATS_DEFAULT_THETAS;
441+ double _backcast_loss_coef
442+ = 1 ; /* * < Coefficient applied to backcast loss */
438443
439444 std::vector<Stack> _stacks;
440445 torch::nn::Linear _fcn{ nullptr };
0 commit comments