Skip to content

Commit 35b3c31

Browse files
Bycobmergify[bot]
authored andcommitted
feat(nbeats): add parameter coefficient to backcast loss
1 parent 6597b53 commit 35b3c31

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

src/backends/torch/native/native_factory.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,18 @@ namespace dd
5050
if (tdef.find("nbeats") != std::string::npos)
5151
{
5252
std::vector<std::string> p;
53+
double bc_loss_coef = 1;
5354
if (template_params.has("stackdef"))
5455
{
5556
p = template_params.get("stackdef")
5657
.get<std::vector<std::string>>();
5758
}
58-
return new NBeats(inputc, p);
59+
if (template_params.has("backcast_loss_coef"))
60+
{
61+
bc_loss_coef
62+
= template_params.get("backcast_loss_coef").get<double>();
63+
}
64+
return new NBeats(inputc, p, bc_loss_coef);
5965
}
6066
else if (tdef.find("ttransformer") != std::string::npos)
6167
return new TTransformer(inputc, template_params, logger);

src/backends/torch/native/templates/nbeats.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ namespace dd
297297
}
298298

299299
NBeats(const CSVTSTorchInputFileConn &inputc,
300-
std::vector<std::string> stackdef,
300+
std::vector<std::string> stackdef, double backcast_loss_coef = 1,
301301
std::vector<BlockType> stackTypes = NBEATS_DEFAULT_STACK_TYPES,
302302
int nb_blocks_per_stack = NBEATS_DEFAULT_NB_BLOCKS,
303303
int data_size = NBEATS_DEFAULT_DATA_SIZE,
@@ -311,7 +311,8 @@ namespace dd
311311
_hidden_layer_units(hidden_layer_units),
312312
_nb_blocks_per_stack(nb_blocks_per_stack),
313313
_share_weights_in_stack(share_weights_in_stack),
314-
_stack_types(stackTypes), _thetas_dims(thetas_dims)
314+
_stack_types(stackTypes), _thetas_dims(thetas_dims),
315+
_backcast_loss_coef(backcast_loss_coef)
315316
{
316317
parse_stackdef(stackdef);
317318
update_params(inputc);
@@ -413,12 +414,14 @@ namespace dd
413414
torch::Tensor y_pred = torch::slice(output, 1, _backcast_length,
414415
_backcast_length + _forecast_length);
415416
torch::Tensor input_zeros = torch::zeros_like(input_real);
417+
416418
if (loss.empty() || loss == "L1" || loss == "l1")
417419
return torch::l1_loss(y_pred, target)
418-
+ torch::l1_loss(x_pred, input_zeros);
420+
+ torch::l1_loss(x_pred, input_zeros) * _backcast_loss_coef;
419421
if (loss == "L2" || loss == "l2" || loss == "eucl")
420422
return torch::mse_loss(y_pred, target)
421-
+ torch::mse_loss(x_pred, input_zeros);
423+
+ torch::mse_loss(x_pred, input_zeros) * _backcast_loss_coef;
424+
422425
throw MLLibBadParamException("unknown loss " + loss);
423426
}
424427

@@ -435,6 +438,8 @@ namespace dd
435438
bool _share_weights_in_stack = NBEATS_DEFAULT_SHARE_WEIGHTS;
436439
std::vector<BlockType> _stack_types = NBEATS_DEFAULT_STACK_TYPES;
437440
std::vector<int> _thetas_dims = NBEATS_DEFAULT_THETAS;
441+
double _backcast_loss_coef
442+
= 1; /** < Coefficient applied to backcast loss */
438443

439444
std::vector<Stack> _stacks;
440445
torch::nn::Linear _fcn{ nullptr };

0 commit comments

Comments
 (0)