From f1a9c82da388fc6e5e8c94b3ef32e3c54dd1677f Mon Sep 17 00:00:00 2001 From: Sugat Mahanti Date: Tue, 4 Nov 2025 19:05:13 -0500 Subject: [PATCH 1/3] Modernize metrics module with type hints and generic types --- src/llmcompressor/metrics/logger.py | 265 +++++++++--------- .../metrics/utils/frequency_manager.py | 12 +- 2 files changed, 138 insertions(+), 139 deletions(-) diff --git a/src/llmcompressor/metrics/logger.py b/src/llmcompressor/metrics/logger.py index 10a030cc99..47a09be4c4 100644 --- a/src/llmcompressor/metrics/logger.py +++ b/src/llmcompressor/metrics/logger.py @@ -11,7 +11,7 @@ from datetime import datetime from pathlib import Path from types import ModuleType -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable from loguru import logger @@ -94,7 +94,7 @@ def enabled(self, value: bool): def __repr__(self): return f"{self.__class__.__name__}(name={self._name}, enabled={self._enabled})" - def log_hyperparams(self, params: Dict[str, float]) -> bool: + def log_hyperparams(self, params: dict[str, float]) -> bool: """ :param params: Each key-value pair in the dictionary is the name of the hyper parameter and it's corresponding value. @@ -106,8 +106,8 @@ def log_scalar( self, tag: str, value: float, - step: Optional[int] = None, - wall_time: Optional[float] = None, + step: int | None = None, + wall_time: float | None = None, **kwargs, ) -> bool: """ @@ -123,9 +123,9 @@ def log_scalar( def log_scalars( self, tag: str, - values: Dict[str, float], - step: Optional[int] = None, - wall_time: Optional[float] = None, + values: dict[str, float], + step: int | None = None, + wall_time: float | None = None, **kwargs, ) -> bool: """ @@ -142,8 +142,8 @@ def log_string( self, tag: str, string: str, - step: Optional[int] = None, - wall_time: Optional[float] = None, + step: int | None = None, + wall_time: float | None = None, **kwargs, ) -> bool: """ @@ -185,12 +185,12 @@ def __init__( self, lambda_func: Callable[ [ - Optional[str], - Optional[Union[float, str]], - Optional[Dict[str, float]], - Optional[int], - Optional[float], - Optional[int], + str | None, + float | str | None, + dict[str, float] | None, + int | None, + float | None, + int | None, ], bool, ], @@ -206,12 +206,12 @@ def lambda_func( self, ) -> Callable[ [ - Optional[str], - Optional[Union[float, str]], - Optional[Dict[str, float]], - Optional[int], - Optional[float], - Optional[int], + str | None, + float | str | None, + dict[str, float] | None, + int | None, + float | None, + int | None, ], bool, ]: @@ -223,8 +223,8 @@ def lambda_func( def log_hyperparams( self, - params: Dict, - level: Optional[Union[int, str]] = None, + params: dict, + level: int | str | None = None, ) -> bool: """ :param params: Each key-value pair in the dictionary is the name of the @@ -248,9 +248,9 @@ def log_scalar( self, tag: str, value: float, - step: Optional[int] = None, - wall_time: Optional[float] = None, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + level: int | str | None = None, ) -> bool: """ :param tag: identifying tag to log the value with @@ -277,10 +277,10 @@ def log_scalar( def log_scalars( self, tag: str, - values: Dict[str, float], - step: Optional[int] = None, - wall_time: Optional[float] = None, - level: Optional[Union[int, str]] = None, + values: dict[str, float], + step: int | None = None, + wall_time: float | None = None, + level: int | str | None = None, ) -> bool: """ :param tag: identifying tag to log the values with @@ -366,12 +366,12 @@ def _create_default_logger(self) -> None: def _log_lambda( self, - tag: Optional[str], - value: Optional[Union[float, str]], - values: Optional[Dict[str, float]], - step: Optional[int], - wall_time: Optional[float], - level: Optional[Union[int, str]] = None, + tag: str | None, + value: float | str | None, + values: dict[str, float] | None, + step: int | None, + wall_time: float | None, + level: int | str | None = None, ) -> bool: """ :param tag: identifying tag to log the values with @@ -386,7 +386,7 @@ def _log_lambda( if not level: level = "DEBUG" - def is_higher_than_debug(lev: Optional[Union[int, str]] = None) -> bool: + def is_higher_than_debug(lev: int | str | None = None) -> bool: """Check if the given level is higher than DEBUG level.""" debug_level_no = logger.level("DEBUG").no if isinstance(lev, int): @@ -417,11 +417,11 @@ def is_higher_than_debug(lev: Optional[Union[int, str]] = None) -> bool: def log_string( self, - tag: Optional[str], - string: Optional[str], - step: Optional[int], - wall_time: Optional[float] = None, - level: Optional[Union[int, str]] = None, + tag: str | None, + string: str | None, + step: int | None, + wall_time: float | None = None, + level: int | str | None = None, ) -> bool: """ :param tag: identifying tag to log the values with @@ -513,12 +513,12 @@ def writer(self) -> SummaryWriter: def _log_lambda( self, - tag: Optional[str], - value: Optional[float], - values: Optional[Dict[str, float]], - step: Optional[int], - wall_time: Optional[float], - level: Optional[Union[int, str]] = None, + tag: str | None, + value: float | None, + values: dict[str, float] | None, + step: int | None, + wall_time: float | None, + level: int | str | None = None, ) -> bool: if value is not None: self._writer.add_scalar(tag, value, step, wall_time) @@ -553,10 +553,10 @@ def available() -> bool: def __init__( self, - init_kwargs: Optional[Dict] = None, + init_kwargs: dict | None = None, name: str = "wandb", enabled: bool = True, - wandb_err: Optional[Exception] = wandb_err, + wandb_err: Exception | None = wandb_err, ): if wandb_err: raise wandb_err @@ -587,12 +587,12 @@ def __init__( def _log_lambda( self, - tag: Optional[str], - value: Optional[float], - values: Optional[Dict[str, float]], - step: Optional[int], - wall_time: Optional[float], - level: Optional[Union[int, str]] = None, + tag: str | None, + value: float | None, + values: dict[str, float] | None, + step: int | None, + wall_time: float | None, + level: int | str | None = None, ) -> bool: params = {} @@ -653,27 +653,26 @@ class SparsificationGroupLogger(BaseLogger): def __init__( self, - lambda_func: Optional[ - Callable[ - [ - Optional[str], - Optional[float], - Optional[Dict[str, float]], - Optional[int], - Optional[float], - ], - bool, - ] - ] = None, + lambda_func: Callable[ + [ + str | None, + float | None, + dict[str, float] | None, + int | None, + float | None, + ], + bool, + ] + | None = None, python: bool = False, - python_log_level: Optional[Union[int, str]] = "INFO", - tensorboard: Optional[Union[bool, str, SummaryWriter]] = None, - wandb_: Optional[Union[bool, Dict]] = None, + python_log_level: int | str | None = "INFO", + tensorboard: bool | str | SummaryWriter = None, + wandb_: bool | dict | None = None, name: str = "sparsification", enabled: bool = True, ): super().__init__(name, enabled) - self._loggers: List[BaseLogger] = [] + self._loggers: list[BaseLogger] = [] if lambda_func: self._loggers.append( @@ -703,7 +702,7 @@ def __init__( if wandb_ and WANDBLogger.available(): self._loggers.append( WANDBLogger( - init_kwargs=wandb_ if isinstance(wandb_, Dict) else None, + init_kwargs=wandb_ if isinstance(wandb_, dict) else None, name=name, enabled=enabled, ) @@ -720,13 +719,13 @@ def enabled(self, value: bool): log.enabled = value @property - def loggers(self) -> List[BaseLogger]: + def loggers(self) -> list[BaseLogger]: """ :return: the created metrics sub instances for this metrics """ return self._loggers - def log_hyperparams(self, params: Dict, level: Optional[Union[int, str]] = None): + def log_hyperparams(self, params: dict, level: int | str | None = None): """ :param params: Each key-value pair in the dictionary is the name of the hyper parameter and it's corresponding value. @@ -738,9 +737,9 @@ def log_scalar( self, tag: str, value: float, - step: Optional[int] = None, - wall_time: Optional[float] = None, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + level: int | str | None = None, ): """ :param tag: identifying tag to log the value with @@ -756,10 +755,10 @@ def log_scalar( def log_scalars( self, tag: str, - values: Dict[str, float], - step: Optional[int] = None, - wall_time: Optional[float] = None, - level: Optional[Union[int, str]] = None, + values: dict[str, float], + step: int | None = None, + wall_time: float | None = None, + level: int | str | None = None, ): """ :param tag: identifying tag to log the values with @@ -793,8 +792,8 @@ class LoggerManager(ABC): def __init__( self, - loggers: Optional[List[BaseLogger]] = None, - log_frequency: Union[float, None] = 0.1, + loggers: list[BaseLogger] | None = None, + log_frequency: float | None = 0.1, log_python: bool = True, name: str = "manager", mode: LoggingModeType = "exact", @@ -817,7 +816,7 @@ def __init__( log_frequency=log_frequency, ) - self.system = SystemLoggingWraper( + self.system = SystemLoggingWrapper( loggers=self._loggers, frequency_manager=self.frequency_manager ) self.metric = MetricLoggingWrapper( @@ -882,28 +881,28 @@ def epoch_to_step(epoch, steps_per_epoch): return round(epoch) if steps_per_epoch <= 0 else round(epoch * steps_per_epoch) @property - def loggers(self) -> List[BaseLogger]: + def loggers(self) -> list[BaseLogger]: """ :return: list of loggers assigned to this manager """ return self._loggers @loggers.setter - def loggers(self, value: List[BaseLogger]): + def loggers(self, value: list[BaseLogger]): """ :param value: list of loggers assigned to this manager """ self._loggers = value @property - def log_frequency(self) -> Union[str, float, None]: + def log_frequency(self) -> str | float | None: """ :return: number of epochs or fraction of epochs to wait between logs """ return self.frequency_manager._log_frequency @log_frequency.setter - def log_frequency(self, value: Union[str, float, None]): + def log_frequency(self, value: str | float | None): """ :param value: number of epochs or fraction of epochs to wait between logs """ @@ -917,7 +916,7 @@ def name(self) -> str: return self._name @property - def wandb(self) -> Optional[ModuleType]: + def wandb(self) -> ModuleType | None: """ :return: wandb module if initialized """ @@ -930,10 +929,10 @@ def log_scalar( self, tag: str, value: float, - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] | None = ALL_TOKEN, + level: int | str | None = None, ): """ (Note: this method is deprecated and will be removed in a future version, @@ -960,11 +959,11 @@ def log_scalar( def log_scalars( self, tag: str, - values: Dict[str, float], - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + values: dict[str, float], + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] | None = ALL_TOKEN, + level: int | str | None = None, ): """ (Note: this method is deprecated and will be removed in a future version, @@ -990,9 +989,9 @@ def log_scalars( def log_hyperparams( self, - params: Dict, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + params: dict, + log_types: str | list[str] | None = ALL_TOKEN, + level: int | str | None = None, ): """ (Note: this method is deprecated and will be removed in a future version, @@ -1012,10 +1011,10 @@ def log_string( self, tag: str, string: str, - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ (Note: this method is deprecated and will be removed in a future version, @@ -1052,7 +1051,7 @@ def save( log.save(file_path, **kwargs) @contextmanager - def time(self, tag: Optional[str] = None, *args, **kwargs): + def time(self, tag: str | None = None, *args, **kwargs): """ Context manager to log the time it takes to run the block of code @@ -1076,7 +1075,7 @@ class LoggingWrapperBase: Base class that holds a reference to the loggers and frequency manager """ - def __init__(self, loggers: List[BaseLogger], frequency_manager: FrequencyManager): + def __init__(self, loggers: list[BaseLogger], frequency_manager: FrequencyManager): self.loggers = loggers self._frequency_manager = frequency_manager @@ -1087,7 +1086,7 @@ def __repr__(self): ) -class SystemLoggingWraper(LoggingWrapperBase): +class SystemLoggingWrapper(LoggingWrapperBase): """ Wraps utilities and convenience methods for logging strings to the system """ @@ -1096,10 +1095,10 @@ def log_string( self, tag: str, string: str, - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ :param tag: identifying tag to log the values with @@ -1211,9 +1210,9 @@ class MetricLoggingWrapper(LoggingWrapperBase): def log_hyperparams( self, - params: Dict, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + params: dict, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ :param params: Each key-value pair in the dictionary is the name of the @@ -1228,10 +1227,10 @@ def log_scalar( self, tag: str, value: float, - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ :param tag: identifying tag to log the value with @@ -1255,11 +1254,11 @@ def log_scalar( def log_scalars( self, tag: str, - values: Dict[str, float], - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + values: dict[str, float], + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ :param tag: identifying tag to log the values with @@ -1284,8 +1283,8 @@ def add_scalar( self, value, tag: str = DEFAULT_TAG, - step: Optional[int] = None, - wall_time: Union[int, float, None] = None, + step: int | None = None, + wall_time: int | float | None = None, **kwargs, ): """ @@ -1302,10 +1301,10 @@ def add_scalar( def add_scalars( self, - values: Dict[str, Any], + values: dict[str, Any], tag: str = DEFAULT_TAG, - step: Optional[int] = None, - wall_time: Union[int, float, None] = None, + step: int | None = None, + wall_time: int | float | None = None, **kwargs, ): """ @@ -1325,9 +1324,9 @@ def add_scalars( def log( self, - data: Dict[str, Any], - step: Optional[int] = None, - tag: Optional[str] = DEFAULT_TAG, + data: dict[str, Any], + step: int | None = None, + tag: str | None = DEFAULT_TAG, **kwargs, ) -> None: """ diff --git a/src/llmcompressor/metrics/utils/frequency_manager.py b/src/llmcompressor/metrics/utils/frequency_manager.py index 2234bb2eb8..ba58d3b01a 100644 --- a/src/llmcompressor/metrics/utils/frequency_manager.py +++ b/src/llmcompressor/metrics/utils/frequency_manager.py @@ -7,7 +7,7 @@ with configurable modes and intervals. """ -from typing import Literal, Optional, Union +from typing import Literal __all__ = [ "FrequencyManager", @@ -17,7 +17,7 @@ "log_ready", ] -LogStepType = Union[int, float, None] +LogStepType = int | float | None LoggingModeType = Literal["on_change", "exact"] FrequencyType = Literal["epoch", "step"] @@ -253,10 +253,10 @@ def _set_frequency_type(self, frequency_type: FrequencyType) -> FrequencyType: def log_ready( - current_log_step: Optional[LogStepType], - last_log_step: Optional[LogStepType], - log_frequency: Optional[LogStepType], - last_model_update_step: Optional[LogStepType] = None, + current_log_step: LogStepType | None, + last_log_step: LogStepType | None, + log_frequency: LogStepType | None, + last_model_update_step: LogStepType | None = None, check_model_update: bool = False, ): """ From 2e24d6daf3c25fedde7c2e3192a81e46ceebc59b Mon Sep 17 00:00:00 2001 From: Sugat Mahanti Date: Sun, 9 Nov 2025 17:02:32 -0500 Subject: [PATCH 2/3] Addressing comments: Enforce and clean up type hints --- src/llmcompressor/metrics/logger.py | 10 +++++----- src/llmcompressor/metrics/utils/frequency_manager.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/metrics/logger.py b/src/llmcompressor/metrics/logger.py index 47a09be4c4..85e9e45123 100644 --- a/src/llmcompressor/metrics/logger.py +++ b/src/llmcompressor/metrics/logger.py @@ -223,7 +223,7 @@ def lambda_func( def log_hyperparams( self, - params: dict, + params: dict[str, float], level: int | str | None = None, ) -> bool: """ @@ -931,7 +931,7 @@ def log_scalar( value: float, step: int | None = None, wall_time: float | None = None, - log_types: str | list[str] | None = ALL_TOKEN, + log_types: str | list[str] = ALL_TOKEN, level: int | str | None = None, ): """ @@ -962,7 +962,7 @@ def log_scalars( values: dict[str, float], step: int | None = None, wall_time: float | None = None, - log_types: str | list[str] | None = ALL_TOKEN, + log_types: str | list[str] = ALL_TOKEN, level: int | str | None = None, ): """ @@ -990,7 +990,7 @@ def log_scalars( def log_hyperparams( self, params: dict, - log_types: str | list[str] | None = ALL_TOKEN, + log_types: str | list[str] = ALL_TOKEN, level: int | str | None = None, ): """ @@ -1210,7 +1210,7 @@ class MetricLoggingWrapper(LoggingWrapperBase): def log_hyperparams( self, - params: dict, + params: dict[str, float], log_types: str | list[str] = ALL_TOKEN, level: int | str | None = None, ): diff --git a/src/llmcompressor/metrics/utils/frequency_manager.py b/src/llmcompressor/metrics/utils/frequency_manager.py index ba58d3b01a..20ddd7b9e7 100644 --- a/src/llmcompressor/metrics/utils/frequency_manager.py +++ b/src/llmcompressor/metrics/utils/frequency_manager.py @@ -256,7 +256,7 @@ def log_ready( current_log_step: LogStepType | None, last_log_step: LogStepType | None, log_frequency: LogStepType | None, - last_model_update_step: LogStepType | None = None, + last_model_update_step: LogStepType = None, check_model_update: bool = False, ): """ From 2637edf20042e4cd793edbf99989c17fde8262f0 Mon Sep 17 00:00:00 2001 From: Sugat Mahanti Date: Sun, 9 Nov 2025 17:11:06 -0500 Subject: [PATCH 3/3] Using match-case instead of if-else. Also fixed a bug --- src/llmcompressor/metrics/logger.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/metrics/logger.py b/src/llmcompressor/metrics/logger.py index 85e9e45123..fd6fcc0200 100644 --- a/src/llmcompressor/metrics/logger.py +++ b/src/llmcompressor/metrics/logger.py @@ -389,11 +389,13 @@ def _log_lambda( def is_higher_than_debug(lev: int | str | None = None) -> bool: """Check if the given level is higher than DEBUG level.""" debug_level_no = logger.level("DEBUG").no - if isinstance(lev, int): - return level > debug_level_no - elif isinstance(lev, str): - return logger.level(lev).no > debug_level_no - return False + match lev: + case int(): + return lev > debug_level_no + case str(): + return logger.level(lev).no > debug_level_no + case _: + return False if is_higher_than_debug(level): if step is not None: