diff --git a/tsa/eval.py b/tsa/eval.py index 50c7583..4c2f2f0 100644 --- a/tsa/eval.py +++ b/tsa/eval.py @@ -44,11 +44,11 @@ def evaluate(test_iter, criterion, model, config, ts): predictions, targets = torch.cat(predictions), torch.cat(targets) if config.general.do_eval: - preds, targets = ts.invert_scale(predictions), ts.invert_scale(targets) + preds, targets_ = ts.invert_scale(predictions), ts.invert_scale(targets) plt.figure() plt.plot(preds, linewidth=.3) - plt.plot(targets, linewidth=.3) + plt.plot(targets_, linewidth=.3) plt.savefig("{}/preds.png".format(config.general.output_dir)) plt.close()