diff --git a/src/llmcompressor/metrics/logger.py b/src/llmcompressor/metrics/logger.py index 10a030cc99..fd6fcc0200 100644 --- a/src/llmcompressor/metrics/logger.py +++ b/src/llmcompressor/metrics/logger.py @@ -11,7 +11,7 @@ from datetime import datetime from pathlib import Path from types import ModuleType -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable from loguru import logger @@ -94,7 +94,7 @@ def enabled(self, value: bool): def __repr__(self): return f"{self.__class__.__name__}(name={self._name}, enabled={self._enabled})" - def log_hyperparams(self, params: Dict[str, float]) -> bool: + def log_hyperparams(self, params: dict[str, float]) -> bool: """ :param params: Each key-value pair in the dictionary is the name of the hyper parameter and it's corresponding value. @@ -106,8 +106,8 @@ def log_scalar( self, tag: str, value: float, - step: Optional[int] = None, - wall_time: Optional[float] = None, + step: int | None = None, + wall_time: float | None = None, **kwargs, ) -> bool: """ @@ -123,9 +123,9 @@ def log_scalar( def log_scalars( self, tag: str, - values: Dict[str, float], - step: Optional[int] = None, - wall_time: Optional[float] = None, + values: dict[str, float], + step: int | None = None, + wall_time: float | None = None, **kwargs, ) -> bool: """ @@ -142,8 +142,8 @@ def log_string( self, tag: str, string: str, - step: Optional[int] = None, - wall_time: Optional[float] = None, + step: int | None = None, + wall_time: float | None = None, **kwargs, ) -> bool: """ @@ -185,12 +185,12 @@ def __init__( self, lambda_func: Callable[ [ - Optional[str], - Optional[Union[float, str]], - Optional[Dict[str, float]], - Optional[int], - Optional[float], - Optional[int], + str | None, + float | str | None, + dict[str, float] | None, + int | None, + float | None, + int | None, ], bool, ], @@ -206,12 +206,12 @@ def lambda_func( self, ) -> Callable[ [ - Optional[str], - Optional[Union[float, str]], - Optional[Dict[str, float]], - Optional[int], - Optional[float], - Optional[int], + str | None, + float | str | None, + dict[str, float] | None, + int | None, + float | None, + int | None, ], bool, ]: @@ -223,8 +223,8 @@ def lambda_func( def log_hyperparams( self, - params: Dict, - level: Optional[Union[int, str]] = None, + params: dict[str, float], + level: int | str | None = None, ) -> bool: """ :param params: Each key-value pair in the dictionary is the name of the @@ -248,9 +248,9 @@ def log_scalar( self, tag: str, value: float, - step: Optional[int] = None, - wall_time: Optional[float] = None, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + level: int | str | None = None, ) -> bool: """ :param tag: identifying tag to log the value with @@ -277,10 +277,10 @@ def log_scalar( def log_scalars( self, tag: str, - values: Dict[str, float], - step: Optional[int] = None, - wall_time: Optional[float] = None, - level: Optional[Union[int, str]] = None, + values: dict[str, float], + step: int | None = None, + wall_time: float | None = None, + level: int | str | None = None, ) -> bool: """ :param tag: identifying tag to log the values with @@ -366,12 +366,12 @@ def _create_default_logger(self) -> None: def _log_lambda( self, - tag: Optional[str], - value: Optional[Union[float, str]], - values: Optional[Dict[str, float]], - step: Optional[int], - wall_time: Optional[float], - level: Optional[Union[int, str]] = None, + tag: str | None, + value: float | str | None, + values: dict[str, float] | None, + step: int | None, + wall_time: float | None, + level: int | str | None = None, ) -> bool: """ :param tag: identifying tag to log the values with @@ -386,14 +386,16 @@ def _log_lambda( if not level: level = "DEBUG" - def is_higher_than_debug(lev: Optional[Union[int, str]] = None) -> bool: + def is_higher_than_debug(lev: int | str | None = None) -> bool: """Check if the given level is higher than DEBUG level.""" debug_level_no = logger.level("DEBUG").no - if isinstance(lev, int): - return level > debug_level_no - elif isinstance(lev, str): - return logger.level(lev).no > debug_level_no - return False + match lev: + case int(): + return lev > debug_level_no + case str(): + return logger.level(lev).no > debug_level_no + case _: + return False if is_higher_than_debug(level): if step is not None: @@ -417,11 +419,11 @@ def is_higher_than_debug(lev: Optional[Union[int, str]] = None) -> bool: def log_string( self, - tag: Optional[str], - string: Optional[str], - step: Optional[int], - wall_time: Optional[float] = None, - level: Optional[Union[int, str]] = None, + tag: str | None, + string: str | None, + step: int | None, + wall_time: float | None = None, + level: int | str | None = None, ) -> bool: """ :param tag: identifying tag to log the values with @@ -513,12 +515,12 @@ def writer(self) -> SummaryWriter: def _log_lambda( self, - tag: Optional[str], - value: Optional[float], - values: Optional[Dict[str, float]], - step: Optional[int], - wall_time: Optional[float], - level: Optional[Union[int, str]] = None, + tag: str | None, + value: float | None, + values: dict[str, float] | None, + step: int | None, + wall_time: float | None, + level: int | str | None = None, ) -> bool: if value is not None: self._writer.add_scalar(tag, value, step, wall_time) @@ -553,10 +555,10 @@ def available() -> bool: def __init__( self, - init_kwargs: Optional[Dict] = None, + init_kwargs: dict | None = None, name: str = "wandb", enabled: bool = True, - wandb_err: Optional[Exception] = wandb_err, + wandb_err: Exception | None = wandb_err, ): if wandb_err: raise wandb_err @@ -587,12 +589,12 @@ def __init__( def _log_lambda( self, - tag: Optional[str], - value: Optional[float], - values: Optional[Dict[str, float]], - step: Optional[int], - wall_time: Optional[float], - level: Optional[Union[int, str]] = None, + tag: str | None, + value: float | None, + values: dict[str, float] | None, + step: int | None, + wall_time: float | None, + level: int | str | None = None, ) -> bool: params = {} @@ -653,27 +655,26 @@ class SparsificationGroupLogger(BaseLogger): def __init__( self, - lambda_func: Optional[ - Callable[ - [ - Optional[str], - Optional[float], - Optional[Dict[str, float]], - Optional[int], - Optional[float], - ], - bool, - ] - ] = None, + lambda_func: Callable[ + [ + str | None, + float | None, + dict[str, float] | None, + int | None, + float | None, + ], + bool, + ] + | None = None, python: bool = False, - python_log_level: Optional[Union[int, str]] = "INFO", - tensorboard: Optional[Union[bool, str, SummaryWriter]] = None, - wandb_: Optional[Union[bool, Dict]] = None, + python_log_level: int | str | None = "INFO", + tensorboard: bool | str | SummaryWriter = None, + wandb_: bool | dict | None = None, name: str = "sparsification", enabled: bool = True, ): super().__init__(name, enabled) - self._loggers: List[BaseLogger] = [] + self._loggers: list[BaseLogger] = [] if lambda_func: self._loggers.append( @@ -703,7 +704,7 @@ def __init__( if wandb_ and WANDBLogger.available(): self._loggers.append( WANDBLogger( - init_kwargs=wandb_ if isinstance(wandb_, Dict) else None, + init_kwargs=wandb_ if isinstance(wandb_, dict) else None, name=name, enabled=enabled, ) @@ -720,13 +721,13 @@ def enabled(self, value: bool): log.enabled = value @property - def loggers(self) -> List[BaseLogger]: + def loggers(self) -> list[BaseLogger]: """ :return: the created metrics sub instances for this metrics """ return self._loggers - def log_hyperparams(self, params: Dict, level: Optional[Union[int, str]] = None): + def log_hyperparams(self, params: dict, level: int | str | None = None): """ :param params: Each key-value pair in the dictionary is the name of the hyper parameter and it's corresponding value. @@ -738,9 +739,9 @@ def log_scalar( self, tag: str, value: float, - step: Optional[int] = None, - wall_time: Optional[float] = None, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + level: int | str | None = None, ): """ :param tag: identifying tag to log the value with @@ -756,10 +757,10 @@ def log_scalar( def log_scalars( self, tag: str, - values: Dict[str, float], - step: Optional[int] = None, - wall_time: Optional[float] = None, - level: Optional[Union[int, str]] = None, + values: dict[str, float], + step: int | None = None, + wall_time: float | None = None, + level: int | str | None = None, ): """ :param tag: identifying tag to log the values with @@ -793,8 +794,8 @@ class LoggerManager(ABC): def __init__( self, - loggers: Optional[List[BaseLogger]] = None, - log_frequency: Union[float, None] = 0.1, + loggers: list[BaseLogger] | None = None, + log_frequency: float | None = 0.1, log_python: bool = True, name: str = "manager", mode: LoggingModeType = "exact", @@ -817,7 +818,7 @@ def __init__( log_frequency=log_frequency, ) - self.system = SystemLoggingWraper( + self.system = SystemLoggingWrapper( loggers=self._loggers, frequency_manager=self.frequency_manager ) self.metric = MetricLoggingWrapper( @@ -882,28 +883,28 @@ def epoch_to_step(epoch, steps_per_epoch): return round(epoch) if steps_per_epoch <= 0 else round(epoch * steps_per_epoch) @property - def loggers(self) -> List[BaseLogger]: + def loggers(self) -> list[BaseLogger]: """ :return: list of loggers assigned to this manager """ return self._loggers @loggers.setter - def loggers(self, value: List[BaseLogger]): + def loggers(self, value: list[BaseLogger]): """ :param value: list of loggers assigned to this manager """ self._loggers = value @property - def log_frequency(self) -> Union[str, float, None]: + def log_frequency(self) -> str | float | None: """ :return: number of epochs or fraction of epochs to wait between logs """ return self.frequency_manager._log_frequency @log_frequency.setter - def log_frequency(self, value: Union[str, float, None]): + def log_frequency(self, value: str | float | None): """ :param value: number of epochs or fraction of epochs to wait between logs """ @@ -917,7 +918,7 @@ def name(self) -> str: return self._name @property - def wandb(self) -> Optional[ModuleType]: + def wandb(self) -> ModuleType | None: """ :return: wandb module if initialized """ @@ -930,10 +931,10 @@ def log_scalar( self, tag: str, value: float, - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ (Note: this method is deprecated and will be removed in a future version, @@ -960,11 +961,11 @@ def log_scalar( def log_scalars( self, tag: str, - values: Dict[str, float], - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + values: dict[str, float], + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ (Note: this method is deprecated and will be removed in a future version, @@ -990,9 +991,9 @@ def log_scalars( def log_hyperparams( self, - params: Dict, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + params: dict, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ (Note: this method is deprecated and will be removed in a future version, @@ -1012,10 +1013,10 @@ def log_string( self, tag: str, string: str, - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ (Note: this method is deprecated and will be removed in a future version, @@ -1052,7 +1053,7 @@ def save( log.save(file_path, **kwargs) @contextmanager - def time(self, tag: Optional[str] = None, *args, **kwargs): + def time(self, tag: str | None = None, *args, **kwargs): """ Context manager to log the time it takes to run the block of code @@ -1076,7 +1077,7 @@ class LoggingWrapperBase: Base class that holds a reference to the loggers and frequency manager """ - def __init__(self, loggers: List[BaseLogger], frequency_manager: FrequencyManager): + def __init__(self, loggers: list[BaseLogger], frequency_manager: FrequencyManager): self.loggers = loggers self._frequency_manager = frequency_manager @@ -1087,7 +1088,7 @@ def __repr__(self): ) -class SystemLoggingWraper(LoggingWrapperBase): +class SystemLoggingWrapper(LoggingWrapperBase): """ Wraps utilities and convenience methods for logging strings to the system """ @@ -1096,10 +1097,10 @@ def log_string( self, tag: str, string: str, - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ :param tag: identifying tag to log the values with @@ -1211,9 +1212,9 @@ class MetricLoggingWrapper(LoggingWrapperBase): def log_hyperparams( self, - params: Dict, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + params: dict[str, float], + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ :param params: Each key-value pair in the dictionary is the name of the @@ -1228,10 +1229,10 @@ def log_scalar( self, tag: str, value: float, - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ :param tag: identifying tag to log the value with @@ -1255,11 +1256,11 @@ def log_scalar( def log_scalars( self, tag: str, - values: Dict[str, float], - step: Optional[int] = None, - wall_time: Optional[float] = None, - log_types: Union[str, List[str]] = ALL_TOKEN, - level: Optional[Union[int, str]] = None, + values: dict[str, float], + step: int | None = None, + wall_time: float | None = None, + log_types: str | list[str] = ALL_TOKEN, + level: int | str | None = None, ): """ :param tag: identifying tag to log the values with @@ -1284,8 +1285,8 @@ def add_scalar( self, value, tag: str = DEFAULT_TAG, - step: Optional[int] = None, - wall_time: Union[int, float, None] = None, + step: int | None = None, + wall_time: int | float | None = None, **kwargs, ): """ @@ -1302,10 +1303,10 @@ def add_scalar( def add_scalars( self, - values: Dict[str, Any], + values: dict[str, Any], tag: str = DEFAULT_TAG, - step: Optional[int] = None, - wall_time: Union[int, float, None] = None, + step: int | None = None, + wall_time: int | float | None = None, **kwargs, ): """ @@ -1325,9 +1326,9 @@ def add_scalars( def log( self, - data: Dict[str, Any], - step: Optional[int] = None, - tag: Optional[str] = DEFAULT_TAG, + data: dict[str, Any], + step: int | None = None, + tag: str | None = DEFAULT_TAG, **kwargs, ) -> None: """ diff --git a/src/llmcompressor/metrics/utils/frequency_manager.py b/src/llmcompressor/metrics/utils/frequency_manager.py index 2234bb2eb8..20ddd7b9e7 100644 --- a/src/llmcompressor/metrics/utils/frequency_manager.py +++ b/src/llmcompressor/metrics/utils/frequency_manager.py @@ -7,7 +7,7 @@ with configurable modes and intervals. """ -from typing import Literal, Optional, Union +from typing import Literal __all__ = [ "FrequencyManager", @@ -17,7 +17,7 @@ "log_ready", ] -LogStepType = Union[int, float, None] +LogStepType = int | float | None LoggingModeType = Literal["on_change", "exact"] FrequencyType = Literal["epoch", "step"] @@ -253,10 +253,10 @@ def _set_frequency_type(self, frequency_type: FrequencyType) -> FrequencyType: def log_ready( - current_log_step: Optional[LogStepType], - last_log_step: Optional[LogStepType], - log_frequency: Optional[LogStepType], - last_model_update_step: Optional[LogStepType] = None, + current_log_step: LogStepType | None, + last_log_step: LogStepType | None, + log_frequency: LogStepType | None, + last_model_update_step: LogStepType = None, check_model_update: bool = False, ): """