Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,11 +656,17 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
value = (
value.clone().detach()
if isinstance(value, Tensor)
else torch.tensor(value, device=self.device, dtype=_get_default_dtype())
)
if isinstance(value, Tensor):
# Keep tensor on its original device to avoid unnecessary transfers
value = value.clone().detach()
else:
# Create scalar metrics on the CPU to avoid a CPU-GPU transfer and the implicit
# device synchronization it triggers: `torch.tensor(value, device="cuda")` synchronizes
# with the device, while the resulting metric is only ever consumed on the CPU side,
# so keeping scalar inputs on the CPU is more efficient.
# For non-CUDA devices, preserve the original behavior.
device = "cpu" if self.device.type == "cuda" else self.device
value = torch.tensor(value, device=device, dtype=_get_default_dtype())

if not torch.numel(value) == 1:
raise ValueError(
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def metrics(self) -> _METRICS:
"""This function returns either batch or epoch metrics."""
on_step = self._first_loop_iter is not None
assert self.trainer._results is not None
return self.trainer._results.metrics(on_step)
# Only include progress bar metrics if a progress bar callback is present
include_pbar_metrics = self.trainer.progress_bar_callback is not None
return self.trainer._results.metrics(on_step, include_pbar_metrics=include_pbar_metrics)

@property
def callback_metrics(self) -> _OUT_DICT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> tuple[str
forked_name += dataloader_suffix
return name, forked_name

def metrics(self, on_step: bool) -> _METRICS:
def metrics(self, on_step: bool, *, include_pbar_metrics: bool = True) -> _METRICS:
metrics = _METRICS(callback={}, log={}, pbar={})

for _, result_metric in self.valid_items():
Expand All @@ -489,7 +489,7 @@ def metrics(self, on_step: bool) -> _METRICS:
metrics["callback"][forked_name] = value

# Populate progress-bar metrics, converting tensors to plain Python numbers.
if result_metric.meta.prog_bar:
if result_metric.meta.prog_bar and include_pbar_metrics:
metrics["pbar"][forked_name] = convert_tensors_to_scalars(value)

return metrics
Expand Down
Loading