diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1b9ca10ccb61..ed77e1c79384 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3004,7 +3004,7 @@ def _maybe_log_save_evaluate(
             # reset tr_loss to zero
             tr_loss -= tr_loss
 
-            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+            logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
             if grad_norm is not None:
                 logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
             if learning_rate is not None:
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 86cb61d18212..6fe3470d0441 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -665,12 +665,12 @@ def on_log(self, args, state, control, logs=None, **kwargs):
                         f"[String too long to display, length: {len(v)} > {self.max_str_len}. "
                         "Consider increasing `max_str_len` if needed.]"
                     )
+                elif isinstance(v, float):
+                    # Format floats for better readability
+                    shallow_logs[k] = f"{v:.4g}"
                 else:
                     shallow_logs[k] = v
             _ = shallow_logs.pop("total_flos", None)
-            # round numbers so that it looks better in console
-            if "epoch" in shallow_logs:
-                shallow_logs["epoch"] = round(shallow_logs["epoch"], 2)
             self.training_bar.write(str(shallow_logs))
 
     def on_train_end(self, args, state, control, **kwargs):
@@ -687,6 +687,8 @@ class PrinterCallback(TrainerCallback):
     """
 
     def on_log(self, args, state, control, logs=None, **kwargs):
         _ = logs.pop("total_flos", None)
         if state.is_local_process_zero:
+            if logs is not None:
+                logs = {k: (f"{v:.4g}" if isinstance(v, float) else v) for k, v in logs.items()}
             print(logs)