diff --git a/keras/api/_tf_keras/keras/callbacks/__init__.py b/keras/api/_tf_keras/keras/callbacks/__init__.py index 4e165cddb6a8..ce5f900d80f5 100644 --- a/keras/api/_tf_keras/keras/callbacks/__init__.py +++ b/keras/api/_tf_keras/keras/callbacks/__init__.py @@ -19,6 +19,9 @@ from keras.src.callbacks.model_checkpoint import ( ModelCheckpoint as ModelCheckpoint, ) +from keras.src.callbacks.orbax_checkpoint import ( + OrbaxCheckpoint as OrbaxCheckpoint, +) from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger from keras.src.callbacks.reduce_lr_on_plateau import ( ReduceLROnPlateau as ReduceLROnPlateau, diff --git a/keras/api/callbacks/__init__.py b/keras/api/callbacks/__init__.py index 4e165cddb6a8..ce5f900d80f5 100644 --- a/keras/api/callbacks/__init__.py +++ b/keras/api/callbacks/__init__.py @@ -19,6 +19,9 @@ from keras.src.callbacks.model_checkpoint import ( ModelCheckpoint as ModelCheckpoint, ) +from keras.src.callbacks.orbax_checkpoint import ( + OrbaxCheckpoint as OrbaxCheckpoint, +) from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger from keras.src.callbacks.reduce_lr_on_plateau import ( ReduceLROnPlateau as ReduceLROnPlateau, diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 15f1af2145d5..6a4879098197 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -75,3 +75,38 @@ class name_scope(backend_name_scope): @keras_export("keras.device") def device(device_name): return device_scope(device_name) # noqa: F405 + + +def get_process_index(): + """Get the index of the current process in a distributed setup. + + Returns: + int: The process index (0 for primary process, >0 for others). + Returns 0 if not in a distributed setup. + """ + backend_name = backend() + if backend_name == "jax": + try: + import jax + + return jax.process_index() + except (ImportError, AttributeError): + return 0 + elif backend_name == "tensorflow": + try: + import tensorflow as tf + + return tf.distribute.get_replica_context().replica_id_in_sync_group + except (ImportError, AttributeError, RuntimeError): + return 0 + elif backend_name == "torch": + try: + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() + return 0 + except (ImportError, AttributeError): + return 0 + else: + return 0 diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index 89ac0fa71c8c..9050723c0546 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -1,6 +1,5 @@ from keras.src.backend.config import is_nnx_enabled from keras.src.backend.jax import core -from keras.src.backend.jax import distribution_lib from keras.src.backend.jax import image from keras.src.backend.jax import linalg from keras.src.backend.jax import math @@ -25,6 +24,7 @@ from keras.src.backend.jax.core import shape from keras.src.backend.jax.core import stop_gradient from keras.src.backend.jax.core import vectorized_map +from keras.src.backend.jax.distribution_lib import process_id from keras.src.backend.jax.rnn import cudnn_ok from keras.src.backend.jax.rnn import gru from keras.src.backend.jax.rnn import lstm diff --git a/keras/src/backend/numpy/__init__.py b/keras/src/backend/numpy/__init__.py index 1a9d8eeb7916..8eadb54d77fb 100644 --- a/keras/src/backend/numpy/__init__.py +++ b/keras/src/backend/numpy/__init__.py @@ -20,6 +20,7 @@ from keras.src.backend.numpy.core import random_seed_dtype from keras.src.backend.numpy.core import shape from 
keras.src.backend.numpy.core import vectorized_map
+from keras.src.backend.numpy.distribution_lib import process_id
 from keras.src.backend.numpy.rnn import cudnn_ok
 from keras.src.backend.numpy.rnn import gru
 from keras.src.backend.numpy.rnn import lstm
diff --git a/keras/src/backend/numpy/distribution_lib.py b/keras/src/backend/numpy/distribution_lib.py
new file mode 100644
index 000000000000..ea04795255ee
--- /dev/null
+++ b/keras/src/backend/numpy/distribution_lib.py
@@ -0,0 +1,6 @@
+"""Utilities for distribution strategy with NumPy backend."""
+
+
+def process_id():
+    """Return the current process ID for the distribution setting."""
+    return 0
diff --git a/keras/src/backend/openvino/__init__.py b/keras/src/backend/openvino/__init__.py
index 0612260452ea..2282d65e80cf 100644
--- a/keras/src/backend/openvino/__init__.py
+++ b/keras/src/backend/openvino/__init__.py
@@ -1,5 +1,6 @@
 from keras.src.backend.common.name_scope import name_scope
 from keras.src.backend.openvino import core
+from keras.src.backend.openvino import distribution_lib
 from keras.src.backend.openvino import image
 from keras.src.backend.openvino import linalg
 from keras.src.backend.openvino import math
@@ -19,6 +20,7 @@
 from keras.src.backend.openvino.core import random_seed_dtype
 from keras.src.backend.openvino.core import shape
 from keras.src.backend.openvino.core import vectorized_map
+from keras.src.backend.openvino.distribution_lib import process_id
 from keras.src.backend.openvino.rnn import cudnn_ok
 from keras.src.backend.openvino.rnn import gru
 from keras.src.backend.openvino.rnn import lstm
diff --git a/keras/src/backend/openvino/distribution_lib.py b/keras/src/backend/openvino/distribution_lib.py
new file mode 100644
index 000000000000..3307d371682b
--- /dev/null
+++ b/keras/src/backend/openvino/distribution_lib.py
@@ -0,0 +1,6 @@
+"""Utilities for distribution strategy with OpenVINO backend."""
+
+
+def process_id():
+    """Return the current process ID for the distribution setting."""
+    return 0
diff --git a/keras/src/backend/tensorflow/__init__.py b/keras/src/backend/tensorflow/__init__.py
index ea4eed39b8da..31c55e87b2cc 100644
--- a/keras/src/backend/tensorflow/__init__.py
+++ b/keras/src/backend/tensorflow/__init__.py
@@ -1,5 +1,4 @@
 from keras.src.backend.tensorflow import core
-from keras.src.backend.tensorflow import distribution_lib
 from keras.src.backend.tensorflow import image
 from keras.src.backend.tensorflow import linalg
 from keras.src.backend.tensorflow import math
@@ -24,6 +23,7 @@
 from keras.src.backend.tensorflow.core import shape
 from keras.src.backend.tensorflow.core import stop_gradient
 from keras.src.backend.tensorflow.core import vectorized_map
+from keras.src.backend.tensorflow.distribution_lib import process_id
 from keras.src.backend.tensorflow.rnn import cudnn_ok
 from keras.src.backend.tensorflow.rnn import gru
 from keras.src.backend.tensorflow.rnn import lstm
diff --git a/keras/src/backend/tensorflow/distribution_lib.py b/keras/src/backend/tensorflow/distribution_lib.py
index b306fd07dd0e..37a14f2c019c 100644
--- a/keras/src/backend/tensorflow/distribution_lib.py
+++ b/keras/src/backend/tensorflow/distribution_lib.py
@@ -85,3 +85,13 @@ def _to_backend_layout(tensor_layout):
         ]
     dtensor_mesh = tensor_layout.device_mesh.backend_mesh
     return dtensor.Layout(sharding_specs=sharding_specs, mesh=dtensor_mesh)
+
+
+def process_id():
+    """Return the current process ID for the distribution setting."""
+    try:
+        import tensorflow as tf
+
+        return tf.distribute.get_replica_context().replica_id_in_sync_group
+    except (ImportError, AttributeError, RuntimeError):
+        return 0
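Taken together, the per-backend `process_id()` helpers being added here give `keras.src.backend.get_process_index()` a single contract: the primary process sees 0, other ranks see their rank, and any backend without an active distributed runtime falls back to 0. A minimal sketch of calling code relying on that contract (the helper and message are illustrative, not part of this change):

```python
from keras.src import backend


def run_on_primary_only(fn):
    # Gate side effects (directory creation, metadata writes, logging)
    # to a single process, regardless of backend or world size.
    if backend.get_process_index() == 0:
        fn()


run_on_primary_only(lambda: print("primary process prepares checkpoint dir"))
```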
diff --git a/keras/src/backend/torch/__init__.py b/keras/src/backend/torch/__init__.py
index 371a62cd0f52..3b3bc16cf1de 100644
--- a/keras/src/backend/torch/__init__.py
+++ b/keras/src/backend/torch/__init__.py
@@ -39,6 +39,7 @@
 from keras.src.backend.torch.core import stop_gradient
 from keras.src.backend.torch.core import to_torch_dtype
 from keras.src.backend.torch.core import vectorized_map
+from keras.src.backend.torch.distribution_lib import process_id
 from keras.src.backend.torch.rnn import cudnn_ok
 from keras.src.backend.torch.rnn import gru
 from keras.src.backend.torch.rnn import lstm
diff --git a/keras/src/backend/torch/distribution_lib.py b/keras/src/backend/torch/distribution_lib.py
new file mode 100644
index 000000000000..7043cc9b3540
--- /dev/null
+++ b/keras/src/backend/torch/distribution_lib.py
@@ -0,0 +1,13 @@
+"""Utilities for distribution strategy with PyTorch backend."""
+
+
+def process_id():
+    """Return the current process ID for the distribution setting."""
+    try:
+        import torch.distributed as dist
+
+        if dist.is_available() and dist.is_initialized():
+            return dist.get_rank()
+        return 0
+    except (ImportError, AttributeError):
+        return 0
diff --git a/keras/src/callbacks/__init__.py b/keras/src/callbacks/__init__.py
index 427c4f6da95f..c62aed69ee63 100644
--- a/keras/src/callbacks/__init__.py
+++ b/keras/src/callbacks/__init__.py
@@ -8,6 +8,7 @@
 from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
 from keras.src.callbacks.model_checkpoint import ModelCheckpoint
 from keras.src.callbacks.monitor_callback import MonitorCallback
+from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
 from keras.src.callbacks.progbar_logger import ProgbarLogger
 from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
 from keras.src.callbacks.remote_monitor import RemoteMonitor
diff --git a/keras/src/callbacks/orbax_checkpoint.py b/keras/src/callbacks/orbax_checkpoint.py
new file mode 100644
index 000000000000..0097fa5ce3a5
--- /dev/null
+++ b/keras/src/callbacks/orbax_checkpoint.py
@@ -0,0 +1,810 @@
+import os
+import warnings
+
+import numpy as np
+
+from keras.src import backend
+from keras.src import ops
+from keras.src.api_export import keras_export
+from keras.src.callbacks.monitor_callback import (
+    MonitorCallback,  # For metric monitoring logic
+)
+from keras.src.utils.io_utils import print_msg
+from keras.src.utils.module_utils import LazyModule
+
+ocp = LazyModule(
+    "orbax.checkpoint",
+    pip_name="orbax-checkpoint",
+    import_error_msg=(
+        "OrbaxCheckpoint requires the 'orbax-checkpoint' package. 
" + "Install it with: pip install orbax-checkpoint" + ), +) + +# Note: Advanced Orbax functionality is available through the ocp LazyModule +# Users can access it via: from keras.src.utils.module_utils import LazyModule +# ocp = LazyModule("orbax.checkpoint"); ocp.CheckpointManager + + +def _get_state_tree(model): + """Get the complete model state as a nested tree structure.""" + state_tree = model.get_state_tree(value_format="numpy_array") + + # Convert numpy scalar types to Python types for Orbax compatibility + def convert_scalars(obj): + if isinstance(obj, np.ndarray) and obj.ndim == 0: + # Convert 0-dimensional numpy arrays (scalars) to Python types + return obj.item() + elif isinstance(obj, np.generic): + # Convert numpy scalar types (like np.float32) to Python types + return obj.item() + elif isinstance(obj, dict): + return {k: convert_scalars(v) for k, v in obj.items()} + else: + return obj + + return convert_scalars(state_tree) + + +def _flatten_state_tree_values(state_tree): + """Flatten nested state tree into a list of values in consistent order.""" + values = [] + + def _flatten(obj): + if isinstance(obj, dict): + for key in sorted(obj.keys()): # Sort for consistent ordering + _flatten(obj[key]) + else: + # Save any non-dict value (numpy arrays, lists, scalars, etc.) + values.append(obj) + + _flatten(state_tree) + return values + + +def _reconstruct_state_tree_with_values(structure, values): + """Reconstruct state tree structure with provided values.""" + value_iter = iter(values) + + def _reconstruct(obj): + if isinstance(obj, dict): + new_dict = {} + for key in sorted(obj.keys()): + new_dict[key] = _reconstruct(obj[key]) + return new_dict + else: + value = next(value_iter) + # Handle different cases for value conversion + if isinstance(obj, np.generic): + # obj is a numpy scalar (0-dimensional) + if isinstance(value, (int, float)): + # Convert Python scalar to numpy scalar + return np.array(value, dtype=obj.dtype) + elif isinstance(value, np.ndarray): + # value is a numpy array, convert to scalar if needed + if value.ndim == 0: + return np.array(value.item(), dtype=obj.dtype) + elif value.ndim == 1 and value.size == 1: + return np.array(value.item(), dtype=obj.dtype) + else: + return value.astype(obj.dtype).reshape(obj.shape) + else: + return np.array(value, dtype=obj.dtype) + elif isinstance(obj, np.ndarray): + # obj is a numpy array + if isinstance(value, np.ndarray): + return value.astype(obj.dtype).reshape(obj.shape) + else: + return np.array(value, dtype=obj.dtype).reshape(obj.shape) + else: + return value + + return _reconstruct(structure) + + +def _restore_legacy_format( + checkpoint_data, target_model, save_optimizer_state, save_metrics_state +): + """Restore from the old flat format for backward compatibility.""" + # Restore model weights + if "model_weights" in checkpoint_data: + model_weights_np = checkpoint_data["model_weights"] + # Convert NumPy arrays back to backend tensors and assign to + # model + for i, weight_np in enumerate(model_weights_np): + # Convert numpy array back to appropriate backend tensor + weight_tensor = ops.convert_to_tensor(weight_np) + target_model.weights[i].assign(weight_tensor) + + # Restore optimizer state if available + if "optimizer_state" in checkpoint_data and save_optimizer_state: + optimizer_vars_np = checkpoint_data["optimizer_state"] + # Only restore if the variable counts match + if len(optimizer_vars_np) == len(target_model.optimizer.variables): + # Convert NumPy arrays back to backend tensors and assign to + # optimizer + 
for i, var_np in enumerate(optimizer_vars_np): + var_tensor = ops.convert_to_tensor(var_np) + target_model.optimizer.variables[i].assign(var_tensor) + + # Restore metrics state if available + if ( + "metrics_state" in checkpoint_data + and save_metrics_state + and hasattr(target_model, "metrics") + ): + metrics_vars_np = checkpoint_data["metrics_state"] + metric_idx = 0 + for metric in target_model.metrics: + if ( + hasattr(metric, "variables") + and metric.variables + and metric_idx < len(metrics_vars_np) + ): + metric_vars_np = metrics_vars_np[metric_idx] + # Restore metric variables + for i, var_np in enumerate(metric_vars_np): + if i < len(metric.variables): + var_tensor = ops.convert_to_tensor(var_np) + metric.variables[i].assign(var_tensor) + metric_idx += 1 + + +@keras_export("keras.callbacks.OrbaxCheckpoint") +class OrbaxCheckpoint(MonitorCallback): + """Callback to save and load model state using Orbax with a similar API to + ModelCheckpoint. + + This callback saves the model's weights and optimizer state asynchronously + using Orbax, allowing training to continue without blocking for I/O. + It also provides methods to load checkpoints for resuming training or + inference. + It supports policies for keeping checkpoints and deciding when to save. + + Example: + + ```python + model.compile(loss=..., optimizer=..., + metrics=['accuracy']) + + EPOCHS = 10 + checkpoint_dir = '/tmp/ckpt' + orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint( + directory=checkpoint_dir, + monitor='val_accuracy', + mode='max', + save_best_only=True) + + # Model is saved at the end of every epoch, if it's the best seen so far. + model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback]) + + # The model can be loaded from a specific checkpoint step as - + checkpoint = keras.callbacks.OrbaxCheckpoint(directory=checkpoint_dir) + checkpoint.load_checkpoint(step=5, model=model) # Load from step 5 + + # Alternatively, save checkpoints every N batches - + orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint( + directory=checkpoint_dir, + save_freq=100) # Save every 100 batches + + model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback]) + + # Or use a SaveDecisionPolicy for more control - + from orbax.checkpoint import checkpoint_managers + policy = checkpoint_managers.FixedIntervalPolicy(interval=5) + orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint( + directory=checkpoint_dir, + save_decision_policy=policy) # Save every 5 epochs + + model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback]) + + # JAX-specific features: Sharding and Multi-Host Checkpointing + # Note: These features are only available with JAX backend + + # Example with sharding support (JAX only): + from keras.distribution import DeviceMesh, TensorLayout + devices = keras.distribution.list_devices() + device_mesh = DeviceMesh(shape=(len(devices),), axis_names=('x',), + devices=devices) + tensor_layout = TensorLayout(axes=(None,), device_mesh=device_mesh) + orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint( + directory=checkpoint_dir, + sharding=tensor_layout.backend_layout + ) # Enable sharding for distributed arrays + + # Example with multi-host checkpointing (JAX only): + # Enables distributed checkpointing where each host writes its data shards + # while the primary process coordinates metadata and finalization + orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint( + directory=checkpoint_dir, + multi_host=True) # Enable multi-host checkpointing + + # Combined sharding and multi-host 
(JAX only): + from keras.distribution import DeviceMesh, TensorLayout + devices = keras.distribution.list_devices() + device_mesh = DeviceMesh(shape=(len(devices),), axis_names=('x',), + devices=devices) + tensor_layout = TensorLayout(axes=(None,), device_mesh=device_mesh) + orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint( + directory=checkpoint_dir, + sharding=tensor_layout.backend_layout, + multi_host=True) # Enable both features + + model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback]) + ``` + + Args: + directory: string, path to the directory where to save the checkpoints. + monitor: The metric name to monitor (e.g., 'val_loss'). + verbose: Verbosity mode, 0 or 1. + save_best_only: if `save_best_only=True`, it only saves when the model + is considered the "best" based on the monitored quantity. + mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`. + save_freq: `'epoch'` or integer. Frequency to save checkpoints. + max_to_keep: Integer, maximum number of recent checkpoints to keep. + If None, keeps all. Defaults to 5. + keep_period: Integer, keep one checkpoint every `keep_period` saves. + Useful for keeping checkpoints less frequently over long runs. + initial_value_threshold: Floating point initial "best" value for the + monitor, used with `save_best_only`. + save_optimizer_state: Boolean, whether to include optimizer variables + in the checkpoint. Defaults to True. + save_on_background: Boolean, whether to save asynchronously in the + background. Defaults to True. + save_metadata: Dict or callable, additional metadata to save with each + checkpoint. If callable, it will be called with (epoch, logs) and + should return a dict. Defaults to None. + save_data_iterator: Dict or callable, data iterator state to save with + each checkpoint. If callable, it will be called with (epoch, logs) + and should return a dict with serializable iterator state. + Defaults to None. + save_metrics_state: Boolean, whether to include stateful metrics + variables in the checkpoint. Defaults to False. + async_timeout_secs: Integer, timeout in seconds for async checkpointing + operations. Defaults to 600 (10 minutes). + enable_background_delete: Boolean, whether to delete old checkpoints in + the background. Defaults to False. + post_finalization_callback: Callable, function to call after async + checkpointing operations complete. Defaults to None. + save_transforms: Dict of orbax.checkpoint.Transform objects to apply + during saving. Keys should match composite_state keys (e.g., + 'model_weights', 'optimizer_state'). Defaults to None. + save_decision_policy: orbax.checkpoint.SaveDecisionPolicy object to + control when checkpoints are saved. Currently supports + FixedIntervalPolicy for saving at regular intervals. If provided, + overrides the default save frequency logic. Defaults to None. + save_interval: Integer, save checkpoints every N steps. If provided, + overrides save_freq. Defaults to None. + sharding: JAX sharding specification for distributed checkpointing. + Only supported with JAX backend. If provided with TensorFlow or + PyTorch backends, will raise an error. Defaults to None. + multi_host: Boolean, whether to enable multi-host checkpointing for + distributed training across multiple processes/hosts. When enabled, + the primary process (rank 0) coordinates the checkpoint operation + while all processes write their data shards in parallel to create a + complete distributed checkpoint. Only supported with JAX backend. 
+ If enabled with TensorFlow or PyTorch backends, will raise an error. + Defaults to False. + """ + + def __init__( + self, + directory, + monitor="val_loss", + verbose=0, + save_best_only=False, + mode="auto", + save_freq="epoch", + max_to_keep=5, + keep_period=None, + initial_value_threshold=None, + save_optimizer_state=True, + save_on_background=True, + save_metadata=None, + save_data_iterator=None, + save_metrics_state=False, + async_timeout_secs=600, + enable_background_delete=False, + post_finalization_callback=None, + save_transforms=None, + save_decision_policy=None, + save_interval=None, + sharding=None, + multi_host=False, + ): + # Ensure orbax is available + ocp.initialize() + + # Initialize MonitorCallback for handling 'monitor', 'mode', 'best' + # logic + super().__init__(monitor, mode, initial_value_threshold) + + self.directory = directory + self.verbose = verbose + self.save_best_only = save_best_only + self.save_freq = save_freq + self.save_optimizer_state = save_optimizer_state + self.save_metadata = save_metadata + self.save_data_iterator = save_data_iterator + self.save_metrics_state = save_metrics_state + self.async_timeout_secs = async_timeout_secs + self.enable_background_delete = enable_background_delete + self.post_finalization_callback = post_finalization_callback + self.save_transforms = save_transforms + self.save_decision_policy = save_decision_policy + self.save_interval = save_interval + + # JAX-specific features validation + self.sharding = sharding + self.multi_host = multi_host + + # Validate JAX-only features + if sharding is not None or multi_host: + if backend.backend() != "jax": + raise ValueError( + "sharding and multi_host parameters are only supported " + "with JAX backend. Current backend: " + backend.backend() + ) + + # Validate sharding object type + if sharding is not None and backend.backend() == "jax": + # Basic validation: sharding should not be a string or other + # primitive type + if isinstance(sharding, (str, int, float, bool)): + raise TypeError( + f"sharding parameter must be a valid JAX sharding object, " + f"got {type(sharding).__name__}: {sharding}" + ) + self._batches_seen_since_last_saving = 0 + self._last_batch_seen = 0 + self._current_epoch = 0 # Keep track of epoch + self._total_batches_seen = 0 # Global batch counter for step tracking + + if self.save_freq != "epoch" and not isinstance(self.save_freq, int): + raise ValueError("Unrecognized save_freq") + + # Create should_save_fn from save_decision_policy or save_interval + # if provided + should_save_fn = None + if save_decision_policy is not None: + # When using save_decision_policy, let Orbax handle + # should_save_fn internally + # Don't override should_save_fn + pass + elif save_interval is not None: + # Create should_save_fn that saves every N steps + should_save_fn = ( + lambda step, prev_step=None: step % save_interval == 0 + ) + + # --- Orbax CheckpointManager Setup --- + from orbax.checkpoint import AsyncOptions + + async_options = AsyncOptions( + timeout_secs=self.async_timeout_secs, + post_finalization_callback=self.post_finalization_callback, + ) + + options = ocp.CheckpointManagerOptions( + max_to_keep=max_to_keep, + keep_period=keep_period, + enable_async_checkpointing=save_on_background, + enable_background_delete=self.enable_background_delete, + async_options=async_options, + should_save_fn=should_save_fn, + save_decision_policy=save_decision_policy, + ) + + # Multi-host setup for JAX + if self.multi_host and backend.backend() == "jax": + try: + # Enable 
multi-host checkpointing using Keras distribution API + from keras.src import distribution + + distribution.initialize() + except RuntimeError as e: + # If distributed cannot be initialized (e.g., JAX already + # initialized), continue anyway - the multi_host flag is mainly + # a hint to Orbax. + # We check for messages related to initialization state. + error_str = str(e).lower() + if ( + "already been initialized" in error_str + or "must be called before" in error_str + ): + pass # This is expected in some environments. + else: + raise + # Orbax will automatically handle multi-host coordination: + # - Primary process (rank 0) coordinates and writes + # metadata/manifest + # - All processes write their data shards in parallel to the + # checkpoint directory + + # Ensure directory exists (only needed on one process in multi-host) + if backend.get_process_index() == 0: + os.makedirs(directory, exist_ok=True) + + # Create the CheckpointManager + self.manager = ocp.CheckpointManager( + directory=directory, + options=options, + ) + + def set_model(self, model): + self._model = model + + def _should_save_on_batch(self, batch): + """Check if we should save on this batch.""" + if self.save_freq == "epoch": + return False + + if batch <= self._last_batch_seen: # New epoch. + add_batches = batch + 1 + else: + add_batches = batch - self._last_batch_seen + self._batches_seen_since_last_saving += add_batches + self._last_batch_seen = batch + self._total_batches_seen += add_batches + + if self._batches_seen_since_last_saving >= self.save_freq: + self._batches_seen_since_last_saving = 0 + return True + return False + + def _get_current_step(self): + # A reliable way to get a global step count + # Using optimizer iterations is common + if hasattr(self.model, "optimizer") and hasattr( + self.model.optimizer, "iterations" + ): + # Convert potential backend tensor to int + return int( + backend.convert_to_numpy(self.model.optimizer.iterations) + ) + else: + # Fallback: use global batch count + return self._total_batches_seen + + def _save_checkpoint(self, step, logs=None): + """Save a checkpoint at the given step.""" + if self.model is None: + return + + # --- Prepare Composite State (Backend-Agnostic) --- + state_tree = _get_state_tree(self.model) + + if state_tree is None: + if self.verbose > 0: + print_msg( + "OrbaxCheckpoint: Skipping save due to state tree error" + ) + return + + # Flatten the trainable variables values for cross-model compatibility + trainable_values = _flatten_state_tree_values( + state_tree["trainable_variables"] + ) + + # Save optimizer and metrics state if requested + optimizer_values = None + if self.save_optimizer_state and "optimizer_variables" in state_tree: + optimizer_values = _flatten_state_tree_values( + state_tree["optimizer_variables"] + ) + + metrics_values = None + if self.save_metrics_state and "metrics_variables" in state_tree: + metrics_values = _flatten_state_tree_values( + state_tree["metrics_variables"] + ) + + composite_state = { + "model_weights": trainable_values, + } + + if optimizer_values is not None: + composite_state["optimizer_state"] = optimizer_values + if metrics_values is not None: + composite_state["metrics_variables"] = metrics_values + + # Add metadata if specified + if self.save_metadata is not None: + if callable(self.save_metadata): + metadata = self.save_metadata(self._current_epoch, logs) + else: + metadata = self.save_metadata + if metadata: + composite_state["metadata"] = metadata + + # Add data iterator state if specified + if 
self.save_data_iterator is not None: + if callable(self.save_data_iterator): + iterator_state = self.save_data_iterator( + self._current_epoch, logs + ) + else: + iterator_state = self.save_data_iterator + if iterator_state: + composite_state["data_iterator"] = iterator_state + + # --- Save Logic --- + # Only save on the primary process (rank 0) in distributed setups + is_primary_host = backend.get_process_index() == 0 + + if is_primary_host: + if self.verbose > 0: + print_msg( + f"OrbaxCheckpoint: Triggering async save for step {step}..." + ) + + # Save the checkpoint + save_args = ocp.args.StandardSave( + composite_state, save_args=self.save_transforms + ) + + # Apply sharding if specified (JAX only) + # Note: Sharding is handled automatically by Orbax when saving + # sharded JAX arrays. No explicit sharding parameter needed. + self.manager.save(step, args=save_args) + + def on_train_batch_end(self, batch, logs=None): + if self._should_save_on_batch(batch): + # Handle save_best_only logic for batch-level saving + should_save = True + if self.save_best_only: + current = logs.get(self.monitor) if logs else None + if current is None: + warnings.warn( + f"Can save best model only with {self.monitor} " + f"available, skipping save at batch {batch}.", + stacklevel=2, + ) + should_save = False + elif not self._is_improvement(current, self.best): + should_save = False + else: + # Update best value when there's improvement + self.best = current + + if should_save: + # Use step number (e.g., optimizer iterations) for Orbax save + # step + step = self._get_current_step() + self._save_checkpoint(step=step, logs=logs) + + def on_epoch_end(self, epoch, logs=None): + self._current_epoch = epoch + if self.monitor_op is None: + self._set_monitor_op() # From MonitorCallback + + should_save = False + if self.save_decision_policy is not None: + # Handle FixedIntervalPolicy by extracting its interval + from orbax.checkpoint import checkpoint_managers + + if isinstance( + self.save_decision_policy, + checkpoint_managers.FixedIntervalPolicy, + ): + should_save = epoch % self.save_decision_policy.interval == 0 + else: + # For other policies, fall back to saving every epoch + # TODO: Implement full support for other SaveDecisionPolicy + # types + should_save = True + elif self.save_interval is not None: + # Save every N epochs + should_save = epoch % self.save_interval == 0 + elif self.save_freq == "epoch": + should_save = True + + # Handle save_best_only logic + if should_save and self.save_best_only: + current = logs.get(self.monitor) if logs else None + if current is None: + warnings.warn( + f"Can save best model only with {self.monitor} available, " + f"skipping save at epoch {epoch}.", + stacklevel=2, + ) + should_save = False + elif not self._is_improvement(current, self.best): + should_save = False + else: + # Update best value when there's improvement + self.best = current + + if should_save: + # Use epoch number as the step for Orbax save + self._save_checkpoint(step=epoch, logs=logs) + + def on_train_end(self, logs=None): + if self.verbose > 0: + print_msg("OrbaxCheckpoint: Waiting for final saves to complete...") + self.manager.wait_until_finished() + if self.verbose > 0: + print_msg("OrbaxCheckpoint: All saves finalized.") + + def load_checkpoint(self, step, model=None): + """Load model and optimizer state from a specific checkpoint step. + + Args: + step: The checkpoint step to load from. + model: Optional model to load into. If None, loads into self.model. 
+
+        Returns:
+            tuple: (success, iterator_state) where success is True if loading
+            was successful, False otherwise, and iterator_state is the saved
+            data iterator state dict if available, None otherwise.
+        """
+        # In multi-host distributed training, all processes participate in
+        # loading to read their respective data shards in parallel. Only the
+        # primary process coordinates the metadata reading and broadcasting.
+        if self.multi_host and backend.backend() == "jax":
+            # Multi-host loading: all processes participate
+            pass  # Continue with loading on all processes
+        elif backend.get_process_index() != 0:
+            # Single-host or non-multi-host distributed: only the primary
+            # process loads. Return the documented (success, iterator_state)
+            # tuple so callers can always unpack the result; nothing was
+            # loaded on this process, but there was no error.
+            return True, None
+
+        if self.verbose > 0:
+            print_msg(
+                f"OrbaxCheckpoint: Loading checkpoint from step {step}..."
+            )
+
+        # Prepare restore arguments - Orbax can restore without an explicit
+        # template
+        restore_args = ocp.args.StandardRestore()
+
+        # Apply sharding if specified (JAX only)
+        # Note: Sharding is handled automatically by Orbax when loading
+        # sharded JAX arrays. No explicit sharding parameter needed.
+        checkpoint_data = self.manager.restore(step, args=restore_args)
+
+        # Restore the model state
+        target_model = model if model is not None else self.model
+        success = self._restore_model_state(checkpoint_data, target_model)
+
+        # Extract iterator state if available
+        iterator_state = checkpoint_data.get("data_iterator", None)
+
+        return success, iterator_state
+
+    def load_latest(self, model=None):
+        """Load the most recent checkpoint.
+
+        Args:
+            model: Optional model to load into. If None, loads into
+                self.model.
+
+        Returns:
+            tuple: (success, iterator_state) where success is True if loading
+            was successful, False otherwise, and iterator_state is the saved
+            data iterator state dict if available, None otherwise.
+        """
+        # Get the latest step
+        latest_step = self.manager.latest_step()
+        if latest_step is None:
+            raise FileNotFoundError("OrbaxCheckpoint: No checkpoints found")
+
+        return self.load_checkpoint(latest_step, model)
+
+    def _restore_model_state(self, checkpoint_data, model=None):
+        """Restore model state from checkpoint data.
+
+        Args:
+            checkpoint_data: The checkpoint data loaded from Orbax.
+            model: Optional model to restore into. If None, uses self.model.
+
+        Returns:
+            bool: True if restoration was successful.
+ """ + target_model = model if model is not None else self.model + + # Check if this is the new flattened format + if "model_weights" in checkpoint_data and isinstance( + checkpoint_data["model_weights"], list + ): + # New format: flattened values + return self._restore_from_flattened_values( + checkpoint_data, target_model + ) + elif "model_state" in checkpoint_data: + # Old format: full state tree (for backward compatibility) + return self._restore_from_state_tree( + checkpoint_data["model_state"], target_model + ) + else: + # Fallback to legacy format + _restore_legacy_format( + checkpoint_data, + target_model, + self.save_optimizer_state, + self.save_metrics_state, + ) + return True + + def _restore_from_flattened_values(self, checkpoint_data, target_model): + """Restore from the new flattened values format.""" + # Get the target model's state tree structure (without convert_scalars) + target_state_tree = target_model.get_state_tree( + value_format="numpy_array" + ) + if target_state_tree is None: + if self.verbose > 0: + print_msg( + "OrbaxCheckpoint: Could not get target model state tree" + ) + return False + + # Reconstruct state tree with saved values + reconstructed_state = {} + + # Restore trainable variables + if "model_weights" in checkpoint_data: + saved_trainable_values = checkpoint_data["model_weights"] + target_trainable_structure = target_state_tree[ + "trainable_variables" + ] + reconstructed_state["trainable_variables"] = ( + _reconstruct_state_tree_with_values( + target_trainable_structure, saved_trainable_values + ) + ) + + # Restore optimizer variables if available + if ( + "optimizer_state" in checkpoint_data + and self.save_optimizer_state + and "optimizer_variables" in target_state_tree + ): + saved_optimizer_values = checkpoint_data["optimizer_state"] + target_optimizer_structure = target_state_tree[ + "optimizer_variables" + ] + reconstructed_state["optimizer_variables"] = ( + _reconstruct_state_tree_with_values( + target_optimizer_structure, saved_optimizer_values + ) + ) + + # Restore metrics variables if available + if ( + "metrics_variables" in checkpoint_data + and self.save_metrics_state + and "metrics_variables" in target_state_tree + ): + saved_metrics_values = checkpoint_data["metrics_variables"] + target_metrics_structure = target_state_tree["metrics_variables"] + reconstructed_state["metrics_variables"] = ( + _reconstruct_state_tree_with_values( + target_metrics_structure, saved_metrics_values + ) + ) + + # Use set_state_tree to restore the reconstructed state + target_model.set_state_tree(reconstructed_state) + + if self.verbose > 0: + print_msg("OrbaxCheckpoint: Successfully restored model state") + return True + + def _restore_from_state_tree(self, state_tree, target_model): + """Restore from the old full state tree format + (for backward compatibility).""" + target_model.set_state_tree(state_tree) + if self.verbose > 0: + print_msg("OrbaxCheckpoint: Successfully restored model state") + return True + + +# Export additional Orbax functionality for advanced users (only if available) +if ocp.available: + CheckpointManager = ocp.CheckpointManager + PyTreeCheckpointer = ocp.PyTreeCheckpointer + SaveArgs = ocp.SaveArgs + StandardRestore = ocp.args.StandardRestore + TypeHandler = ocp.type_handlers.TypeHandler + metadata = ocp.metadata + register_type_handler = ocp.type_handlers.register_type_handler diff --git a/keras/src/callbacks/orbax_checkpoint_test.py b/keras/src/callbacks/orbax_checkpoint_test.py new file mode 100644 index 
diff --git a/keras/src/callbacks/orbax_checkpoint_test.py b/keras/src/callbacks/orbax_checkpoint_test.py
new file mode 100644
index 000000000000..086e3de7a4bb
--- /dev/null
+++ b/keras/src/callbacks/orbax_checkpoint_test.py
@@ -0,0 +1,2058 @@
+import os
+import shutil
+import tempfile
+
+import numpy as np
+import pytest
+
+from keras.src import backend
+from keras.src import layers
+from keras.src import models
+from keras.src import testing
+
+# Import advanced Orbax functionality through the Keras bridge
+# These will only be available if orbax-checkpoint is installed
+try:
+    from keras.src.callbacks.orbax_checkpoint import CheckpointManager
+    from keras.src.callbacks.orbax_checkpoint import PyTreeCheckpointer
+    from keras.src.callbacks.orbax_checkpoint import SaveArgs
+    from keras.src.callbacks.orbax_checkpoint import StandardRestore
+    from keras.src.callbacks.orbax_checkpoint import TypeHandler
+    from keras.src.callbacks.orbax_checkpoint import metadata
+    from keras.src.callbacks.orbax_checkpoint import register_type_handler
+except ImportError:
+    # If orbax is not available, these won't be exported
+    pass
+
+from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
+
+# Import distribution for sharding tests
+try:
+    from keras.src import distribution
+except ImportError:
+    distribution = None
+
+# Skip the entire test module if orbax-checkpoint is not available
+pytest.importorskip("orbax.checkpoint")
+
+
+class OrbaxCheckpointTest(testing.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.temp_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        super().tearDown()
+        shutil.rmtree(self.temp_dir, ignore_errors=True)
+
+    def _create_test_model(self):
+        """Create a simple test model."""
+        inputs = layers.Input(shape=(10,))
+        x = layers.Dense(5)(inputs)
+        outputs = layers.Dense(1)(x)
+        model = models.Model(inputs, outputs)
+        model.compile(optimizer="adam", loss="mse")
+        return model
+
+    def _create_dummy_data(self, num_samples=100):
+        """Create dummy training data."""
+        x = np.random.randn(num_samples, 10)
+        y = np.random.randn(num_samples, 1)
+        return x, y
+
+    @pytest.mark.requires_trainable_backend
+    def test_basic_save_and_load(self):
+        """Test basic save and load functionality."""
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_basic")
+        callback = OrbaxCheckpoint(directory=checkpoint_dir, save_freq="epoch")
+
+        # Train for a few epochs
+        model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+
+        # Create a new model and load the checkpoint
+        new_model = self._create_test_model()
+        success, _ = callback.load_latest(model=new_model)
+
+        self.assertTrue(success, "Loading checkpoint should succeed")
+
+        # Check that weights are loaded (rough check)
+        original_weights = [w.numpy() for w in model.weights]
+        loaded_weights = [w.numpy() for w in new_model.weights]
+
+        # Loaded weights should match the trained model's weights
+        self.assertTrue(np.allclose(original_weights[0], loaded_weights[0]))
+
+    @pytest.mark.requires_trainable_backend
+    def test_save_best_only(self):
+        """Test save_best_only functionality."""
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_best_only")
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            monitor="loss",  # Monitor training loss
+            save_best_only=True,  # Only save when loss improves
+            mode="min",  # Lower loss is better
+            save_freq="epoch",  # Check every epoch
+        )
+
+        # Train for a few epochs - losses should generally decrease
+        model.fit(x, y, epochs=3, callbacks=[callback], verbose=0)
+
+        # Verify checkpoints were saved only when loss
improved + # With save_best_only=True, should save on each improvement + # (typically each epoch for decreasing loss) + all_steps = callback.manager.all_steps() + self.assertGreaterEqual( + len(all_steps), + 1, + f"Should save at least 1 checkpoint with save_best_only=True, " + f"got {len(all_steps)}", + ) + # In practice, with decreasing loss, we expect 3 checkpoints + # (one per epoch) but the exact number depends on when + # improvements occur + self.assertLessEqual( + len(all_steps), + 3, + f"Should save at most 3 checkpoints (one per epoch), " + f"got {len(all_steps)}", + ) + + # Verify that checkpoints correspond to valid epoch steps + for step in all_steps: + self.assertGreaterEqual( + step, 0, f"Checkpoint step should be >= 0, got {step}" + ) + self.assertLessEqual( + step, + 2, + f"Checkpoint step should be <= 2 (epochs are 0-indexed), " + f"got {step}", + ) + + @pytest.mark.requires_trainable_backend + def test_save_freq_batch(self): + """Test batch-level saving.""" + model = self._create_test_model() + x, y = self._create_dummy_data(num_samples=50) + + checkpoint_dir = os.path.join(self.temp_dir, "test_batch_freq") + callback = OrbaxCheckpoint(directory=checkpoint_dir, save_freq=10) + + # Train for one epoch with batch saving + model.fit(x, y, epochs=1, batch_size=5, callbacks=[callback], verbose=0) + + # Should have saved checkpoints + checkpoints = [] + for root, dirs, files in os.walk(checkpoint_dir): + checkpoints.extend(dirs) + + self.assertGreater( + len(checkpoints), + 0, + "Should have saved checkpoints at batch intervals", + ) + + @pytest.mark.requires_trainable_backend + def test_max_to_keep(self): + """Test max_to_keep parameter.""" + model = self._create_test_model() + x, y = self._create_dummy_data() + + checkpoint_dir = os.path.join(self.temp_dir, "test_max_keep") + callback = OrbaxCheckpoint( + directory=checkpoint_dir, save_freq="epoch", max_to_keep=2 + ) + + # Train for more epochs than max_to_keep + model.fit(x, y, epochs=5, callbacks=[callback], verbose=0) + + # Check that max_to_keep is respected + all_steps = callback.manager.all_steps() + self.assertLessEqual( + len(all_steps), + 2, + f"Should keep at most 2 checkpoints, found {len(all_steps)}: " + f"{all_steps}", + ) + + @pytest.mark.requires_trainable_backend + def test_synchronous_checkpointing(self): + """Test synchronous checkpointing (save_on_background=False).""" + + model = self._create_test_model() + x, y = self._create_dummy_data() + + # Test synchronous checkpointing + checkpoint_dir_sync = os.path.join(self.temp_dir, "test_sync") + callback_sync = OrbaxCheckpoint( + directory=checkpoint_dir_sync, + save_freq="epoch", + save_on_background=False, # Synchronous saving + ) + + # Measure time for synchronous saving + model.fit(x, y, epochs=3, callbacks=[callback_sync], verbose=0) + + # Check that checkpoints were saved + all_steps_sync = callback_sync.manager.all_steps() + self.assertEqual( + len(all_steps_sync), + 3, + f"Should have 3 checkpoints, found {len(all_steps_sync)}", + ) + + # Verify we can load the checkpoints immediately (no need to wait) + success = callback_sync.load_latest() + self.assertTrue(success, "Should successfully load latest checkpoint") + + # Test asynchronous checkpointing for comparison + model2 = self._create_test_model() + checkpoint_dir_async = os.path.join(self.temp_dir, "test_async") + callback_async = OrbaxCheckpoint( + directory=checkpoint_dir_async, + save_freq="epoch", + save_on_background=True, # Asynchronous saving (default) + ) + + # Measure time for 
asynchronous saving + model2.fit(x, y, epochs=3, callbacks=[callback_async], verbose=0) + # async_time = time.time() - start_time + + # For async mode, ensure background operations complete + callback_async.manager.wait_until_finished() + + # Check that checkpoints were saved + all_steps_async = callback_async.manager.all_steps() + self.assertEqual( + len(all_steps_async), + 3, + f"Should have 3 checkpoints, found {len(all_steps_async)}", + ) + + # Verify we can load the checkpoints + success = callback_async.load_latest() + self.assertTrue(success, "Should successfully load latest checkpoint") + + # Both sync and async modes should work correctly + # (async allows training to continue while saving happens in background, + # but in this small test the timing difference may not be measurable) + + @pytest.mark.requires_trainable_backend + def test_keep_period_functionality(self): + """Test keep_period parameter keeps checkpoints every Nth save + plus recent ones.""" + model = self._create_test_model() + x, y = self._create_dummy_data() + + checkpoint_dir = os.path.join(self.temp_dir, "test_keep_period") + callback = OrbaxCheckpoint( + directory=checkpoint_dir, + save_freq="epoch", + max_to_keep=5, # Keep last 5 checkpoints + keep_period=3, # Keep every 3rd checkpoint + ) + + # Train for 10 epochs + model.fit(x, y, epochs=10, callbacks=[callback], verbose=0) + + # Check that checkpoints follow keep_period pattern + all_steps = sorted(callback.manager.all_steps()) + + # With keep_period=3 and training for 10 epochs (steps 0-9), + # multiples of 3 that should be kept: 0, 3, 6, 9 + expected_periodic_checkpoints = [0, 3, 6, 9] + + # Verify ALL expected periodic checkpoints are kept + for periodic_step in expected_periodic_checkpoints: + self.assertIn( + periodic_step, + all_steps, + f"Periodic checkpoint {periodic_step} " + f"(multiple of keep_period=3) should be kept, " + f"but only found {all_steps}", + ) + + # Verify that some recent checkpoints are also kept + # (the most recent ones within max_to_keep limit) + recent_steps = [step for step in all_steps if step >= 5] # steps 5-9 + self.assertGreater( + len(recent_steps), + 0, + f"Should keep some recent checkpoints, found {all_steps}", + ) + + # The total should be reasonable (periodic + recent, but may exceed + # max_to_keep) + # In this case, we expect at least the 4 periodic + some recent = + # at least 5 + self.assertGreaterEqual( + len(all_steps), + 4, # At minimum, all periodic checkpoints + f"Should keep at least periodic checkpoints, found " + f"{len(all_steps)}: {all_steps}", + ) + + @pytest.mark.requires_trainable_backend + def test_keep_period_vs_no_keep_period(self): + """Test that keep_period preserves periodic checkpoints that would + otherwise be deleted.""" + # First, test WITHOUT keep_period + model1 = self._create_test_model() + x, y = self._create_dummy_data() + + checkpoint_dir_no_period = os.path.join(self.temp_dir, "test_no_period") + callback_no_period = OrbaxCheckpoint( + directory=checkpoint_dir_no_period, + save_freq="epoch", + max_to_keep=3, # Keep only last 3 checkpoints + ) + + # Train for 10 epochs + model1.fit(x, y, epochs=10, callbacks=[callback_no_period], verbose=0) + steps_no_period = sorted(callback_no_period.manager.all_steps()) + + # Without keep_period, should keep only the most recent max_to_keep=3 + expected_recent_only = [7, 8, 9] # Last 3 epochs (0-indexed) + self.assertEqual( + steps_no_period, + expected_recent_only, + f"Without keep_period, should keep only recent checkpoints: " + 
f"{expected_recent_only}, got {steps_no_period}", + ) + + # Now test WITH keep_period + model2 = self._create_test_model() + checkpoint_dir_with_period = os.path.join( + self.temp_dir, "test_with_period" + ) + callback_with_period = OrbaxCheckpoint( + directory=checkpoint_dir_with_period, + save_freq="epoch", + max_to_keep=3, # Same max_to_keep + keep_period=4, # Keep every 4th checkpoint + ) + + # Train for 10 epochs + model2.fit(x, y, epochs=10, callbacks=[callback_with_period], verbose=0) + steps_with_period = sorted(callback_with_period.manager.all_steps()) + + # With keep_period=4, should keep multiples of 4: 0, 4, 8 + # Plus recent ones within max_to_keep limit + periodic_checkpoints = [0, 4, 8] + for periodic_step in periodic_checkpoints: + self.assertIn( + periodic_step, + steps_with_period, + f"Periodic checkpoint {periodic_step} should be kept with " + f"keep_period=4, found {steps_with_period}", + ) + + # Should have more checkpoints than without keep_period + self.assertGreater( + len(steps_with_period), + len(steps_no_period), + f"With keep_period should keep more checkpoints than without. " + f"With period: {steps_with_period}, without: {steps_no_period}", + ) + + @pytest.mark.requires_trainable_backend + def test_checkpoint_error_handling(self): + """Test error handling when checkpoint operations fail.""" + x, y = self._create_dummy_data() + + # Test: Try to load from a non-existent checkpoint + checkpoint_dir = os.path.join(self.temp_dir, "test_error_handling") + callback = OrbaxCheckpoint(directory=checkpoint_dir, save_freq="epoch") + + # Try to load a checkpoint that doesn't exist - should raise exception + with self.assertRaises(Exception): + callback.load_checkpoint(step=999) + + # Test: Try to load latest when no checkpoints exist - + # should raise FileNotFoundError + with self.assertRaises(FileNotFoundError): + callback.load_latest() + + @pytest.mark.requires_trainable_backend + def test_partial_checkpoint_loading(self): + """Test loading individual components from composite checkpoints.""" + model = self._create_test_model() + x, y = self._create_dummy_data() + + checkpoint_dir = os.path.join(self.temp_dir, "test_partial_load") + callback = OrbaxCheckpoint( + directory=checkpoint_dir, + save_freq="epoch", + save_metadata={"epoch": 1, "custom_value": 42.5}, + save_data_iterator={"batch_index": 42}, + ) + + # Train for a few epochs to create checkpoints + model.fit(x, y, epochs=2, callbacks=[callback], verbose=0) + + # Manually load checkpoint data to test partial access + manager = CheckpointManager(directory=checkpoint_dir) + restore_args = StandardRestore() + checkpoint_data = manager.restore(step=1, args=restore_args) + + # Verify we can access individual components + self.assertIn( + "model_weights", + checkpoint_data, + "Model weights should be available", + ) + self.assertIn( + "optimizer_state", + checkpoint_data, + "Optimizer state should be available", + ) + self.assertIn( + "metadata", checkpoint_data, "Metadata should be available" + ) + self.assertIn( + "data_iterator", + checkpoint_data, + "Data iterator should be available", + ) + + # Check metadata content + self.assertEqual(checkpoint_data["metadata"]["epoch"], 1) + self.assertEqual(checkpoint_data["metadata"]["custom_value"], 42.5) + + # Check iterator state content + self.assertEqual(checkpoint_data["data_iterator"]["batch_index"], 42) + + # Verify model weights have the right shape (without loading them) + model_weights = checkpoint_data["model_weights"] + self.assertEqual( + len(model_weights), + 
len(model.weights), + "Should have weights for all model parameters", + ) + + @pytest.mark.requires_trainable_backend + def test_background_delete_functionality(self): + """Test background deletion of old checkpoints.""" + # Test WITHOUT background deletion (synchronous) + model1 = self._create_test_model() + x, y = self._create_dummy_data() + + checkpoint_dir_sync = os.path.join(self.temp_dir, "test_sync_delete") + callback_sync = OrbaxCheckpoint( + directory=checkpoint_dir_sync, + save_freq="epoch", + max_to_keep=2, # Keep only 2 checkpoints + enable_background_delete=False, # Synchronous deletion (default) + ) + + # Train for more epochs than max_to_keep + model1.fit(x, y, epochs=5, callbacks=[callback_sync], verbose=0) + + # Check that max_to_keep is respected + all_steps_sync = sorted(callback_sync.manager.all_steps()) + self.assertLessEqual( + len(all_steps_sync), + 2, + f"Should keep at most 2 checkpoints with sync delete, " + f"found {len(all_steps_sync)}: {all_steps_sync}", + ) + + # Now test WITH background deletion + model2 = self._create_test_model() + checkpoint_dir_async = os.path.join(self.temp_dir, "test_async_delete") + callback_async = OrbaxCheckpoint( + directory=checkpoint_dir_async, + save_freq="epoch", + max_to_keep=2, # Keep only 2 checkpoints + enable_background_delete=True, # Asynchronous background deletion + ) + + # Train for more epochs than max_to_keep + model2.fit(x, y, epochs=5, callbacks=[callback_async], verbose=0) + + # Check that max_to_keep is still respected + all_steps_async = sorted(callback_async.manager.all_steps()) + self.assertLessEqual( + len(all_steps_async), + 2, + f"Should keep at most 2 checkpoints with background delete, " + f"found {len(all_steps_async)}: {all_steps_async}", + ) + + # Wait for background operations to complete + callback_async.manager.wait_until_finished() + + # Both should have the same result (same max_to_keep) + # The difference is that background deletion doesn't block training + self.assertEqual( + len(all_steps_sync), + len(all_steps_async), + f"Both sync and async deletion should keep same number of " + f"checkpoints. 
Sync: {all_steps_sync}, Async: {all_steps_async}", + ) + + @pytest.mark.requires_trainable_backend + def test_post_finalization_callback(self): + """Test post-finalization callbacks.""" + model = self._create_test_model() + x, y = self._create_dummy_data() + + callback_called = [] + + def post_callback(): + callback_called.append(True) + + checkpoint_dir = os.path.join(self.temp_dir, "test_post_callback") + callback = OrbaxCheckpoint( + directory=checkpoint_dir, + save_freq="epoch", + post_finalization_callback=post_callback, + ) + + # Train for a few epochs + model.fit(x, y, epochs=2, callbacks=[callback], verbose=0) + + # Wait for async operations to complete + callback.manager.wait_until_finished() + + # Check that the callback was called + self.assertTrue( + len(callback_called) > 0, + "Post-finalization callback should have been called", + ) + + @pytest.mark.requires_trainable_backend + def test_async_with_custom_options(self): + """Test async checkpointing with custom AsyncOptions.""" + model = self._create_test_model() + x, y = self._create_dummy_data() + + checkpoint_dir = os.path.join(self.temp_dir, "test_custom_async") + callback = OrbaxCheckpoint( + directory=checkpoint_dir, + save_freq="epoch", + async_timeout_secs=1200, # Custom timeout: 20 minutes + enable_background_delete=True, # Enable background delete + ) + + # Train for a few epochs + model.fit(x, y, epochs=3, callbacks=[callback], verbose=0) + + # Verify checkpoints were saved successfully + all_steps = callback.manager.all_steps() + self.assertEqual( + len(all_steps), + 3, + f"Should have 3 checkpoints with custom async options, " + f"found {len(all_steps)}", + ) + + # Wait for all operations to complete + callback.manager.wait_until_finished() + + @pytest.mark.requires_trainable_backend + def test_async_timeout_parameter(self): + """Test that async timeout parameter is properly configured.""" + model = self._create_test_model() + x, y = self._create_dummy_data() + + checkpoint_dir = os.path.join(self.temp_dir, "test_timeout") + callback = OrbaxCheckpoint( + directory=checkpoint_dir, + save_freq="epoch", + async_timeout_secs=300, # Short timeout: 5 minutes + ) + + # Train for a few epochs + model.fit(x, y, epochs=2, callbacks=[callback], verbose=0) + + # Verify that the timeout setting doesn't break normal operation + all_steps = callback.manager.all_steps() + self.assertEqual( + len(all_steps), + 2, + f"Should have 2 checkpoints with timeout setting, " + f"found {len(all_steps)}", + ) + + # Wait for completion + callback.manager.wait_until_finished() + + @pytest.mark.requires_trainable_backend + def test_metrics_state_saving(self): + """Test saving and loading of metrics state.""" + model = self._create_test_model() + x, y = self._create_dummy_data() + + checkpoint_dir = os.path.join(self.temp_dir, "test_metrics_state") + callback = OrbaxCheckpoint( + directory=checkpoint_dir, + save_freq="epoch", + save_metrics_state=True, + ) + + # Train for a few epochs to update metrics + model.fit(x, y, epochs=2, callbacks=[callback], verbose=0) + + # Check that metrics have state after training + original_metrics_state = [] + for metric in model.metrics: + if hasattr(metric, "variables") and metric.variables: + original_metrics_state.append( + [var.numpy() for var in metric.variables] + ) + + self.assertGreater( + len(original_metrics_state), 0, "Should have metrics with state" + ) + + # Create new model and load checkpoint + new_model = self._create_test_model() + success, _ = callback.load_latest(model=new_model) + 
self.assertTrue( + success, "Should successfully load checkpoint with metrics state" + ) + + # Check that metrics state was restored in the new model + for i, original_state in enumerate(original_metrics_state): + if i < len(new_model.metrics): + new_metric = new_model.metrics[i] + if hasattr(new_metric, "variables") and new_metric.variables: + new_state = [var.numpy() for var in new_metric.variables] + # States should match (allowing for some floating point + # differences) + for orig, new in zip(original_state, new_state): + np.testing.assert_allclose(orig, new, rtol=1e-5) + + @pytest.mark.requires_trainable_backend + def test_checkpoint_transformations(self): + """Test applying transformations during checkpoint saving.""" + model = self._create_test_model() + x, y = self._create_dummy_data() + + checkpoint_dir = os.path.join(self.temp_dir, "test_transforms") + + # Train for one step first to initialize optimizer variables + model.fit(x, y, epochs=1, verbose=0) + + # Create save_args that converts float32 to float16 + # Note: save_args structure must match composite_state structure (lists) + save_args = { + "model_weights": [ + SaveArgs(dtype=np.dtype(np.float16)), # weights + SaveArgs(dtype=np.dtype(np.float16)), # bias + SaveArgs(dtype=np.dtype(np.float16)), # output weights + SaveArgs(dtype=np.dtype(np.float16)), # output bias + ], + "optimizer_state": [None] * len(model.optimizer.variables), + } + + callback = OrbaxCheckpoint( + directory=checkpoint_dir, + save_freq="epoch", + save_transforms=save_args, + ) + + # Train for one more epoch to trigger save + model.fit(x, y, epochs=1, callbacks=[callback], verbose=0) + + # Load checkpoint data to verify transformation was applied + checkpoint_data = self._load_checkpoint_data(callback, step=0) + + # Check that model weights were saved in float16 + saved_weights = checkpoint_data["model_weights"] + self.assertEqual( + saved_weights[0].dtype, + np.float16, + "Weights should be saved in float16 due to transform", + ) + + # Verify we can still load the checkpoint normally + new_model = self._create_test_model() + success, _ = callback.load_latest(model=new_model) + self.assertTrue(success, "Should load transformed checkpoint") + + # Check that weights were converted back to original dtype + self.assertEqual( + new_model.weights[0].dtype, + model.weights[0].dtype, + "Loaded weights should be converted back to original dtype", + ) + + @pytest.mark.requires_trainable_backend + def test_save_decision_policy(self): + """Test using save_interval parameter for custom save logic.""" + model = self._create_test_model() + x, y = self._create_dummy_data() + + checkpoint_dir = os.path.join(self.temp_dir, "test_save_policy") + + callback = OrbaxCheckpoint( + directory=checkpoint_dir, + save_freq="epoch", # This will be overridden by the save_interval + save_interval=2, # Save every 2 epochs + ) + + # Train for 5 epochs + model.fit(x, y, epochs=5, callbacks=[callback], verbose=0) + + # Should have saved at epochs 0, 2, 4 (every 2 steps, 0-indexed) + all_steps = sorted(callback.manager.all_steps()) + expected_steps = [0, 2, 4] # 0-indexed epochs: 0, 2, 4 + self.assertEqual( + all_steps, + expected_steps, + f"Should save at steps {expected_steps}, got {all_steps}", + ) + + @pytest.mark.skipif( + backend.backend() == "torch", + reason="PyTorch train_on_batch has scalar loss issues", + ) + @pytest.mark.requires_trainable_backend + def test_optimizer_state_saving(self): + """Test that optimizer state is saved and loaded.""" + model = self._create_test_model() 
+    @pytest.mark.requires_trainable_backend
+    def test_save_decision_policy(self):
+        """Test using the save_interval parameter for custom save logic."""
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_save_policy")
+
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            save_freq="epoch",  # This will be overridden by save_interval
+            save_interval=2,  # Save every 2 epochs
+        )
+
+        # Train for 5 epochs
+        model.fit(x, y, epochs=5, callbacks=[callback], verbose=0)
+
+        # Should have saved at epochs 0, 2, 4 (every 2 steps, 0-indexed)
+        all_steps = sorted(callback.manager.all_steps())
+        expected_steps = [0, 2, 4]  # 0-indexed epochs: 0, 2, 4
+        self.assertEqual(
+            all_steps,
+            expected_steps,
+            f"Should save at steps {expected_steps}, got {all_steps}",
+        )
+
+    @pytest.mark.skipif(
+        backend.backend() == "torch",
+        reason="PyTorch train_on_batch has scalar loss issues",
+    )
+    @pytest.mark.requires_trainable_backend
+    def test_optimizer_state_saving(self):
+        """Test that optimizer state is saved and loaded."""
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_optimizer")
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            save_freq="epoch",
+            save_optimizer_state=True,
+        )
+
+        # Train for a few epochs to update optimizer state
+        model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+
+        # Create a new model and load the latest checkpoint into it
+        new_model = self._create_test_model()
+        success, _ = callback.load_latest(model=new_model)
+        self.assertTrue(success)
+
+        # Check optimizer iterations (rough check that state was loaded)
+        # Note: This is a basic check - more sophisticated tests could check
+        # specific optimizer variables
+        self.assertGreaterEqual(new_model.optimizer.iterations.numpy(), 0)
+
+    @pytest.mark.requires_trainable_backend
+    def test_load_specific_checkpoint(self):
+        """Test loading a specific checkpoint by step."""
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_specific")
+        callback = OrbaxCheckpoint(directory=checkpoint_dir, save_freq="epoch")
+
+        # Train for multiple epochs
+        model.fit(x, y, epochs=3, callbacks=[callback], verbose=0)
+
+        # Create a new model and load a specific checkpoint into it
+        new_model = self._create_test_model()
+        success, _ = callback.load_checkpoint(
+            step=1, model=new_model
+        )  # Load epoch 1
+
+        self.assertTrue(success, "Loading specific checkpoint should succeed")
+        # Verify the model was loaded by checking it has weights
+        self.assertGreater(len(new_model.weights), 0)
+
+    @pytest.mark.requires_trainable_backend
+    def test_no_checkpoint_found(self):
+        """Test behavior when no checkpoints exist."""
+        model = self._create_test_model()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_empty")
+        callback = OrbaxCheckpoint(directory=checkpoint_dir, save_freq="epoch")
+
+        # Try to load from empty directory - should raise FileNotFoundError
+        with self.assertRaises(FileNotFoundError):
+            callback.load_latest()
+        # Verify the model still has its original weights (not modified)
+        self.assertGreater(len(model.weights), 0)
+
+    @pytest.mark.requires_trainable_backend
+    def test_directory_creation(self):
+        """Test that the checkpoint directory is created if it doesn't
+        exist."""
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(
+            self.temp_dir, "test_create_dir", "subdir"
+        )
+        callback = OrbaxCheckpoint(directory=checkpoint_dir, save_freq="epoch")
+
+        # Directory should be created during training
+        model.fit(x, y, epochs=1, callbacks=[callback], verbose=0)
+
+        self.assertTrue(
+            os.path.exists(checkpoint_dir),
+            "Checkpoint directory should be created",
+        )
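Taken together, the tests above imply the basic save/resume round trip. A minimal sketch, assuming the API introduced in this PR (the directory is hypothetical):

    import numpy as np
    import keras
    from keras.callbacks import OrbaxCheckpoint

    def make_model():
        model = keras.Sequential([keras.Input((4,)), keras.layers.Dense(1)])
        model.compile(optimizer="adam", loss="mse")
        return model

    x = np.random.rand(64, 4).astype("float32")
    y = np.random.rand(64, 1).astype("float32")

    callback = OrbaxCheckpoint(directory="/tmp/resume_demo", save_freq="epoch")
    make_model().fit(x, y, epochs=3, callbacks=[callback], verbose=0)

    # Later, or in a new process: restore the newest checkpoint into a
    # freshly built model. load_latest returns (success, iterator_state).
    model = make_model()
    success, _ = callback.load_latest(model=model)
    assert success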
+    @pytest.mark.requires_trainable_backend
+    def test_save_and_load_composite_metadata(self):
+        """Test saving and loading checkpoints with custom metadata."""
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_metadata")
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            save_freq="epoch",
+            save_metadata={
+                "epoch": 5,
+                "learning_rate": 0.001,
+                "metrics": {"loss": 0.5, "accuracy": 0.8},
+            },
+        )
+
+        # Train for a few epochs
+        model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+
+        # Load the checkpoint and get the full data
+        checkpoint_data = self._load_checkpoint_data(callback, step=1)
+
+        # Verify metadata was saved
+        self.assertIn("metadata", checkpoint_data)
+        metadata = checkpoint_data["metadata"]
+        self.assertEqual(metadata["epoch"], 5)
+        self.assertEqual(metadata["learning_rate"], 0.001)
+        self.assertEqual(metadata["metrics"]["loss"], 0.5)
+        self.assertEqual(metadata["metrics"]["accuracy"], 0.8)
+
+        # Verify model weights are also present
+        self.assertIn("model_weights", checkpoint_data)
+        self.assertIn("optimizer_state", checkpoint_data)
+
+    @pytest.mark.requires_trainable_backend
+    def test_save_metadata_callable(self):
+        """Test saving metadata using a callable function."""
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_metadata_callable")
+
+        def metadata_func(epoch, logs):
+            return {
+                "epoch": epoch,
+                "learning_rate": 0.001,
+                "metrics": logs or {},
+            }
+
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            save_freq="epoch",
+            save_metadata=metadata_func,
+        )
+
+        # Train for a few epochs
+        model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+
+        # Load checkpoint data
+        checkpoint_data = self._load_checkpoint_data(callback, step=1)
+
+        # Verify metadata was saved with the callable
+        self.assertIn("metadata", checkpoint_data)
+        metadata = checkpoint_data["metadata"]
+        # The callable receives the 0-indexed epoch, so step 1 stores 1
+        self.assertEqual(metadata["epoch"], 1)
+        self.assertEqual(metadata["learning_rate"], 0.001)
+
+    @pytest.mark.requires_trainable_backend
+    def test_save_data_iterator_state(self):
+        """Test saving data iterator state with checkpoints."""
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_iterator")
+
+        def iterator_state_func(epoch, logs):
+            return {
+                "current_position": epoch * 100,
+                "shuffle_seed": 42,
+                "batch_size": 32,
+                "dataset_size": len(x),
+            }
+
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            save_freq="epoch",
+            save_data_iterator=iterator_state_func,
+        )
+
+        # Train for a few epochs
+        model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+
+        # Load checkpoint data
+        checkpoint_data = self._load_checkpoint_data(callback, step=1)
+
+        # Verify data iterator state was saved
+        self.assertIn("data_iterator", checkpoint_data)
+        iterator_state = checkpoint_data["data_iterator"]
+        self.assertEqual(iterator_state["current_position"], 100)  # epoch 1
+        self.assertEqual(iterator_state["shuffle_seed"], 42)
+        self.assertEqual(iterator_state["batch_size"], 32)
+        self.assertEqual(iterator_state["dataset_size"], len(x))
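The save_data_iterator hook stores whatever dict the callable returns; rebuilding the pipeline from it is the caller's job. A sketch of the intended round trip, assuming the API in this PR (directory hypothetical):

    import numpy as np
    import keras
    from keras.callbacks import OrbaxCheckpoint

    model = keras.Sequential([keras.Input((4,)), keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    x = np.random.rand(50, 4).astype("float32")
    y = np.random.rand(50, 1).astype("float32")

    callback = OrbaxCheckpoint(
        directory="/tmp/iter_demo",
        save_freq="epoch",
        # Store enough to rebuild the input pipeline deterministically.
        save_data_iterator=lambda epoch, logs: {
            "batches_processed": epoch * 5,
            "shuffle_seed": 42,
            "batch_size": 10,
        },
    )
    model.fit(x, y, epochs=2, batch_size=10, callbacks=[callback], verbose=0)

    # On resume, the dict is returned alongside the load status:
    success, state = callback.load_checkpoint(step=1)
    print(state["batches_processed"], state["shuffle_seed"])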
+    @pytest.mark.requires_trainable_backend
+    def test_load_checkpoint_with_iterator_state(self):
+        """Test that loading a checkpoint returns iterator state for
+        restoration."""
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_load_iterator")
+
+        def iterator_state_func(epoch, logs):
+            return {
+                "current_position": epoch * 100,
+                "shuffle_seed": 42,
+                "batch_size": 32,
+                "dataset_size": len(x),
+            }
+
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            save_freq="epoch",
+            save_data_iterator=iterator_state_func,
+        )
+
+        # Train for a few epochs
+        model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+
+        # Load the checkpoint and retrieve the saved iterator state
+        success, iterator_state = callback.load_checkpoint(step=1)
+
+        # Verify loading succeeded and iterator state was returned
+        self.assertTrue(success, "Loading checkpoint should succeed")
+        self.assertIsNotNone(
+            iterator_state, "Iterator state should be returned"
+        )
+        self.assertEqual(iterator_state["current_position"], 100)  # epoch 1
+        self.assertEqual(iterator_state["shuffle_seed"], 42)
+        self.assertEqual(iterator_state["batch_size"], 32)
+        self.assertEqual(iterator_state["dataset_size"], len(x))
+
+    @pytest.mark.skipif(
+        backend.backend() != "tensorflow",
+        reason="TensorFlow-specific iterator restoration test",
+    )
+    def test_tensorflow_iterator_restoration(self):
+        """Test iterator restoration with the TensorFlow backend."""
+        import tensorflow as tf
+
+        # Create simple test data
+        x, y = self._create_dummy_data(50)  # Smaller dataset
+
+        model = self._create_test_model()
+        checkpoint_dir = os.path.join(self.temp_dir, "test_tf_iterator")
+
+        def tf_iterator_state_func(epoch, logs):
+            return {
+                "batches_processed": epoch * 5,  # 5 batches per epoch
+                "shuffle_seed": 42,
+                "batch_size": 10,
+                "epoch": epoch,
+            }
+
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            save_freq="epoch",
+            save_data_iterator=tf_iterator_state_func,
+        )
+
+        # Train for 2 epochs using model.fit (simpler)
+        model.fit(
+            x, y, epochs=2, callbacks=[callback], verbose=0, batch_size=10
+        )
+
+        # Load checkpoint and verify iterator state
+        success, saved_iterator_state = callback.load_checkpoint(step=1)
+
+        self.assertTrue(success, "Checkpoint loading should succeed")
+        self.assertIsNotNone(
+            saved_iterator_state, "Iterator state should be returned"
+        )
+        self.assertEqual(saved_iterator_state["epoch"], 1)
+        self.assertEqual(
+            saved_iterator_state["batches_processed"], 5
+        )  # epoch 1 * 5 batches
+        self.assertEqual(saved_iterator_state["batch_size"], 10)
+
+        # Demonstrate iterator restoration
+        # Create a tf.data.Dataset similar to what a user would do.
+        # Note: shuffle's first argument is the buffer size; the saved
+        # seed goes in the `seed` keyword argument.
+        dataset = tf.data.Dataset.from_tensor_slices((x, y))
+        dataset = dataset.shuffle(
+            len(x), seed=saved_iterator_state["shuffle_seed"]
+        )
+        dataset = dataset.batch(saved_iterator_state["batch_size"])
+
+        # Create an iterator and skip to the saved position
+        iterator = iter(dataset)
+        for _ in range(saved_iterator_state["batches_processed"]):
+            try:
+                next(iterator)
+            except StopIteration:
+                break
+
+        # Verify we can get the next batch
+        try:
+            batch_x, batch_y = next(iterator)
+            self.assertEqual(
+                batch_x.shape[0], saved_iterator_state["batch_size"]
+            )
+        except StopIteration:
+            # End of dataset is also acceptable
+            pass
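Skipping batches replays the pipeline from the start. As an alternative (independent of this PR, and assuming standard TensorFlow behavior), tf.data iterators are trackable objects, so their position can be checkpointed directly:

    import tensorflow as tf

    ds = tf.data.Dataset.range(100).shuffle(100, seed=42).batch(10)
    it = iter(ds)
    next(it)  # consume one batch

    # tf.train.Checkpoint persists the iterator's position (and shuffle
    # RNG state) instead of requiring skipped-batch replay.
    ckpt = tf.train.Checkpoint(iterator=it)
    path = ckpt.save("/tmp/tf_iter_demo/ckpt")

    it2 = iter(ds)
    tf.train.Checkpoint(iterator=it2).restore(path)
    # it2 now resumes after the first batch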
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="JAX-specific iterator restoration test",
+    )
+    def test_jax_iterator_restoration(self):
+        """Test iterator restoration with the JAX backend."""
+        import jax.numpy as jnp
+
+        # Create simple test data
+        x, y = self._create_dummy_data(50)
+
+        model = self._create_test_model()
+        checkpoint_dir = os.path.join(self.temp_dir, "test_jax_iterator")
+
+        def jax_iterator_state_func(epoch, logs):
+            return {
+                "batches_processed": epoch * 5,  # 5 batches per epoch
+                "shuffle_seed": 42,
+                "batch_size": 10,
+                "epoch": epoch,
+            }
+
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            save_freq="epoch",
+            save_data_iterator=jax_iterator_state_func,
+        )
+
+        # Train for 2 epochs using model.fit
+        model.fit(
+            x, y, epochs=2, callbacks=[callback], verbose=0, batch_size=10
+        )
+
+        # Load checkpoint and verify iterator state
+        success, saved_iterator_state = callback.load_checkpoint(step=1)
+
+        self.assertTrue(success, "Checkpoint loading should succeed")
+        self.assertIsNotNone(
+            saved_iterator_state, "Iterator state should be returned"
+        )
+        self.assertEqual(saved_iterator_state["epoch"], 1)
+        self.assertEqual(saved_iterator_state["batches_processed"], 5)
+        self.assertEqual(saved_iterator_state["batch_size"], 10)
+
+        # Demonstrate iterator restoration for JAX
+        x_jax = jnp.array(x)
+
+        # Recreate the shuffled index permutation (same as during training)
+        perm = jnp.array(
+            np.random.RandomState(
+                saved_iterator_state["shuffle_seed"]
+            ).permutation(len(x_jax))
+        )
+
+        # Calculate the starting position
+        start_idx = (
+            saved_iterator_state["batches_processed"]
+            * saved_iterator_state["batch_size"]
+        )
+
+        # Get the remaining data from the correct position
+        remaining_indices = perm[start_idx:]
+        if len(remaining_indices) >= saved_iterator_state["batch_size"]:
+            batch_indices = remaining_indices[
+                : saved_iterator_state["batch_size"]
+            ]
+            batch_x = x_jax[batch_indices]
+            self.assertEqual(
+                batch_x.shape[0], saved_iterator_state["batch_size"]
+            )
+
+    @pytest.mark.skipif(
+        backend.backend() != "torch",
+        reason="PyTorch-specific iterator restoration test",
+    )
+    def test_pytorch_iterator_restoration(self):
+        """Test iterator restoration with the PyTorch backend."""
+        import torch
+
+        # Create simple test data
+        x, y = self._create_dummy_data(50)
+
+        model = self._create_test_model()
+        checkpoint_dir = os.path.join(self.temp_dir, "test_torch_iterator")
+
+        def torch_iterator_state_func(epoch, logs):
+            return {
+                "batches_processed": epoch * 5,  # 5 batches per epoch
+                "shuffle_seed": 42,
+                "batch_size": 10,
+                "epoch": epoch,
+            }
+
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            save_freq="epoch",
+            save_data_iterator=torch_iterator_state_func,
+        )
+
+        # Train for 2 epochs using model.fit
+        model.fit(
+            x, y, epochs=2, callbacks=[callback], verbose=0, batch_size=10
+        )
+
+        # Load checkpoint and verify iterator state
+        success, saved_iterator_state = callback.load_checkpoint(step=1)
+
+        self.assertTrue(success, "Checkpoint loading should succeed")
+        self.assertIsNotNone(
+            saved_iterator_state, "Iterator state should be returned"
+        )
+        self.assertEqual(saved_iterator_state["epoch"], 1)
+        self.assertEqual(saved_iterator_state["batches_processed"], 5)
+        self.assertEqual(saved_iterator_state["batch_size"], 10)
+
+        # Demonstrate iterator restoration for PyTorch
+        x_torch = torch.tensor(x, dtype=torch.float32)
+        y_torch = torch.tensor(y, dtype=torch.float32)
+
+        # Create dataset and dataloader (same as during training)
+        dataset = torch.utils.data.TensorDataset(x_torch, y_torch)
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            batch_size=saved_iterator_state["batch_size"],
+            shuffle=True,
+            generator=torch.Generator().manual_seed(
+                saved_iterator_state["shuffle_seed"]
+            ),
+        )
+
+        # Create an iterator and skip to the saved position
+        iterator = iter(dataloader)
+        for _ in range(saved_iterator_state["batches_processed"]):
+            try:
+                next(iterator)
+            except StopIteration:
+                break
+
+        # Verify we can get the next batch
+        try:
+            batch_x, batch_y = next(iterator)
+            self.assertEqual(
+                batch_x.shape[0], saved_iterator_state["batch_size"]
+            )
+        except StopIteration:
+            # End of dataset is also acceptable
+            pass
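The PyTorch restoration above depends on the DataLoader re-shuffling in exactly the same order; with a seeded torch.Generator that is deterministic, which a short standalone check confirms:

    import torch

    def batch_order(seed):
        g = torch.Generator().manual_seed(seed)
        data = torch.arange(50).float().unsqueeze(1)
        loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(data),
            batch_size=10,
            shuffle=True,
            generator=g,
        )
        return [batch[0][0].item() for batch in loader]

    # Same seed -> identical shuffle order, so skipping N batches lands
    # at the same position as the interrupted run.
    assert batch_order(42) == batch_order(42)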
+    @pytest.mark.requires_trainable_backend
+    def test_custom_handler_and_registry(self):
+        """Integration test demonstrating a complete training setup with
+        custom type handlers.
+
+        This test shows how MetadataHandler and ConfigHandler work together
+        in a real-world training workflow, including integration with
+        model.fit() and checkpoint/resume functionality. Individual handler
+        tests are in test_metadata_handler() and test_config_handler().
+        """
+        import asyncio
+        import json
+        import time
+        from dataclasses import dataclass
+
+        @dataclass
+        class TrainingMetadata:
+            """A custom object to hold arbitrary training info."""
+
+            experiment_id: str
+            start_time: float
+            backend: str
+            notes: str = ""
+            hyperparameters: dict = None
+
+        @dataclass
+        class ExperimentConfig:
+            """Another custom object for experiment configuration."""
+
+            model_architecture: str
+            dataset_name: str
+            batch_size: int
+            learning_rate: float
+            optimizer_name: str
+
+        # Use the classes imported through the Keras bridge
+        # (TypeHandler and metadata are already imported above).
+
+        class MetadataHandler(TypeHandler):
+            """A custom Orbax type handler that saves/loads the
+            TrainingMetadata object via JSON."""
+
+            def typestr(self) -> str:
+                return "training_metadata"
+
+            async def metadata(self, infos):
+                """Returns metadata for the parameters."""
+                return [
+                    metadata.Metadata(name=info.name, directory=info.parent_dir)
+                    for info in infos
+                ]
+
+            async def serialize(self, values, infos, args=None):
+                """Serializes the dataclass as a JSON dict."""
+                futures = []
+                for value, info in zip(values, infos):
+                    metadata_obj = value
+                    data = {
+                        "experiment_id": metadata_obj.experiment_id,
+                        "start_time": metadata_obj.start_time,
+                        "backend": metadata_obj.backend,
+                        "notes": metadata_obj.notes,
+                        "hyperparameters": metadata_obj.hyperparameters or {},
+                    }
+                    # Create the directory, then write the JSON file into it
+                    file_path = info.path / "metadata.json"
+                    file_path.parent.mkdir(parents=True, exist_ok=True)
+                    with open(file_path, "w") as f:
+                        json.dump(data, f)
+                    # Return a completed future
+                    future_obj = asyncio.Future()
+                    future_obj.set_result(None)
+                    futures.append(future_obj)
+                return futures
+
+            async def deserialize(self, infos, args=None):
+                """Deserializes the JSON dict and reconstructs the dataclass
+                object."""
+                futures = []
+                for info in infos:
+                    file_path = info.path / "metadata.json"
+                    with open(file_path, "r") as f:
+                        data = json.load(f)
+                    result = TrainingMetadata(**data)
+                    # Return a completed future with the result
+                    future_obj = asyncio.Future()
+                    future_obj.set_result(result)
+                    futures.append(future_obj)
+                return futures
+
+        class ConfigHandler(TypeHandler):
+            """Custom handler for ExperimentConfig objects."""
+
+            def typestr(self) -> str:
+                return "experiment_config"
+
+            async def metadata(self, infos):
+                return [
+                    metadata.Metadata(name=info.name, directory=info.parent_dir)
+                    for info in infos
+                ]
+
+            async def serialize(self, values, infos, args=None):
+                futures = []
+                for value, info in zip(values, infos):
+                    config_obj = value
+                    data = {
+                        "model_architecture": config_obj.model_architecture,
+                        "dataset_name": config_obj.dataset_name,
+                        "batch_size": config_obj.batch_size,
+                        "learning_rate": config_obj.learning_rate,
+                        "optimizer_name": config_obj.optimizer_name,
+                    }
+                    # Create the directory, then write the JSON file into it
+                    file_path = info.path / "config.json"
+                    file_path.parent.mkdir(parents=True, exist_ok=True)
+                    with open(file_path, "w") as f:
+                        json.dump(data, f)
+                    future_obj = asyncio.Future()
+                    future_obj.set_result(None)
+                    futures.append(future_obj)
+                return futures
+
+            async def deserialize(self, infos, args=None):
+                futures = []
+                for info in infos:
+                    file_path = info.path / "config.json"
+                    with open(file_path, "r") as f:
+                        data = json.load(f)
+                    result = ExperimentConfig(**data)
+                    future_obj = asyncio.Future()
+                    future_obj.set_result(result)
+                    futures.append(future_obj)
+                return futures
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_custom_handler")
+
+        # === REAL-WORLD TRAINING SETUP ===
+
+        # 1. Create experiment configuration and metadata
+        experiment_config = ExperimentConfig(
+            model_architecture="simple_mlp",
+            dataset_name="dummy_regression",
+            batch_size=32,
+            learning_rate=0.001,
+            optimizer_name="adam",
+        )
+
+        training_metadata = TrainingMetadata(
+            experiment_id="exp_123_complete_training",
+            start_time=time.time(),
+            backend=backend.backend(),
+            notes="Complete training setup with custom handlers",
+            hyperparameters={
+                "epochs": 3,
+                "validation_split": 0.2,
+                "early_stopping_patience": 5,
+            },
+        )
+
+        # 2. Register the type handlers globally
+        # Note: Each test is self-contained and registers its own handlers.
+        # The integration test needs both handlers for the complete workflow.
+        register_type_handler(
+            ty=TrainingMetadata, handler=MetadataHandler(), override=True
+        )
+        register_type_handler(
+            ty=ExperimentConfig, handler=ConfigHandler(), override=True
+        )
+
+        # 3. Set up the model and training data
+        model = self._create_test_model()
+        x, y = self._create_dummy_data(num_samples=200)
+
+        # 4. Create the checkpoint callback with standard metadata.
+        # Note: save_metadata should use simple serializable types (numbers,
+        # booleans). Complex objects and strings should be saved separately
+        # using PyTreeCheckpointer.
+        def metadata_func(epoch, logs):
+            """Standard metadata function with basic serializable data."""
+            return {
+                "experiment_id": 123,  # Use a number instead of a string
+                "epoch": epoch + 1,
+                "loss": float(logs.get("loss", 0.0)) if logs else 0.0,
+                "val_loss": float(logs.get("val_loss", 0.0)) if logs else 0.0,
+                # Use a number instead of a string to identify the backend
+                "backend_id": (
+                    1 if training_metadata.backend == "tensorflow" else 2
+                ),
+                "total_epochs": training_metadata.hyperparameters["epochs"],
+                "validation_split": training_metadata.hyperparameters[
+                    "validation_split"
+                ],
+            }
+
+        training_callback = OrbaxCheckpoint(
+            directory=os.path.join(checkpoint_dir, "training_checkpoints"),
+            save_freq="epoch",
+            save_metadata=metadata_func,  # Standard serializable metadata
+            save_metrics_state=True,
+            save_optimizer_state=True,
+        )
+
+        # 5. Train the model with custom metadata
+        model.fit(
+            x,
+            y,
+            epochs=3,
+            batch_size=32,
+            callbacks=[training_callback],
+            verbose=0,
+            validation_split=0.2,
+        )
+
+        # 6. Save the experiment config separately using PyTreeCheckpointer
+        config_checkpointer = PyTreeCheckpointer()
+        config_checkpointer.save(
+            os.path.join(checkpoint_dir, "experiment_config"),
+            experiment_config,
+        )
+
+        # 7. Save additional training state separately
+        final_training_state = {
+            "config": experiment_config,
+            "metadata": training_metadata,
+            "final_epoch": 3,
+            "total_samples": len(x),
+        }
+
+        state_checkpointer = PyTreeCheckpointer()
+        state_checkpointer.save(
+            os.path.join(checkpoint_dir, "training_state"),
+            final_training_state,
+        )
+
+        # === VERIFICATION: LOAD AND RESUME TRAINING ===
+
+        # 8. Load the experiment configuration
+        loaded_config = config_checkpointer.restore(
+            os.path.join(checkpoint_dir, "experiment_config")
+        )
+        if hasattr(loaded_config, "result"):
+            loaded_config = loaded_config.result()
+
+        self.assertIsInstance(loaded_config, ExperimentConfig)
+        self.assertEqual(loaded_config.model_architecture, "simple_mlp")
+        self.assertEqual(loaded_config.batch_size, 32)
+
+        # 9. Load the training state
+        loaded_state = state_checkpointer.restore(
+            os.path.join(checkpoint_dir, "training_state")
+        )
+        if hasattr(loaded_state, "result"):
+            loaded_state = loaded_state.result()
+
+        self.assertEqual(loaded_state["final_epoch"], 3)
+        self.assertEqual(loaded_state["total_samples"], 200)
+
+        # 10. Load checkpoint data directly to check metadata
+        checkpoint_data = self._load_checkpoint_data(training_callback, step=2)
+
+        # Verify metadata was saved and loaded
+        self.assertIn("metadata", checkpoint_data)
+        loaded_metadata = checkpoint_data["metadata"]
+
+        # Verify the loaded standard metadata (dict with basic types)
+        self.assertIsInstance(loaded_metadata, dict)
+        self.assertEqual(
+            loaded_metadata["experiment_id"], 123
+        )  # number, not string
+        self.assertEqual(loaded_metadata["epoch"], 3)  # 0-indexed epoch + 1
+        # backend_id was encoded as 1 for TensorFlow and 2 otherwise.
+        expected_backend_id = (
+            1 if training_metadata.backend == "tensorflow" else 2
+        )
+        self.assertEqual(
+            loaded_metadata["backend_id"],
+            expected_backend_id,
+            f"backend_id should match the saved training backend, "
+            f"got {loaded_metadata['backend_id']}",
+        )
+        self.assertIn("total_epochs", loaded_metadata)
+
+        # 11. Demonstrate resuming training with the loaded state
+        resumed_model = self._create_test_model()
+        resumed_callback = OrbaxCheckpoint(
+            directory=os.path.join(checkpoint_dir, "training_checkpoints"),
+            save_freq="epoch",
+            save_metadata=metadata_func,
+        )
+
+        # Load the latest checkpoint into the new model
+        success, _ = resumed_callback.load_latest(model=resumed_model)
+        self.assertTrue(success, "Should successfully resume from checkpoint")
+
+        # Continue training for 1 more epoch
+        resumed_model.fit(
+            x,
+            y,
+            epochs=1,  # Just 1 more epoch
+            batch_size=32,
+            callbacks=[resumed_callback],
+            verbose=0,
+            validation_split=0.2,
+            initial_epoch=3,  # Start from epoch 3
+        )
+
+        # Verify that standard metadata works seamlessly with model.fit():
+        # check which steps are available after resumed training.
+        available_steps = sorted(resumed_callback.manager.all_steps())
+
+        # Load the latest available checkpoint
+        if available_steps:
+            latest_step = available_steps[-1]
+            final_checkpoint_data = self._load_checkpoint_data(
+                resumed_callback, step=latest_step
+            )
+            self.assertIn("metadata", final_checkpoint_data)
+            final_metadata = final_checkpoint_data["metadata"]
+            self.assertIsInstance(final_metadata, dict)
+            self.assertIn("loss", final_metadata)
+        else:
+            self.fail("No checkpoints found after resumed training")
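The pattern above distills to: implement typestr/metadata/serialize/deserialize, then register the handler once per process. A condensed sketch for a hypothetical Note dataclass, mirroring the names the test imports through the Keras bridge (TypeHandler, metadata, register_type_handler); exact Orbax module paths may differ across versions:

    import asyncio
    import json
    from dataclasses import dataclass

    # Module paths are assumptions; the test imports these via a bridge.
    from orbax.checkpoint import metadata
    from orbax.checkpoint.type_handlers import TypeHandler, register_type_handler

    @dataclass
    class Note:
        text: str

    class NoteHandler(TypeHandler):
        def typestr(self):
            return "note"

        async def metadata(self, infos):
            return [
                metadata.Metadata(name=i.name, directory=i.parent_dir)
                for i in infos
            ]

        async def serialize(self, values, infos, args=None):
            futures = []
            for value, info in zip(values, infos):
                info.path.mkdir(parents=True, exist_ok=True)
                (info.path / "note.json").write_text(
                    json.dumps({"text": value.text})
                )
                f = asyncio.Future()  # already-completed future
                f.set_result(None)
                futures.append(f)
            return futures

        async def deserialize(self, infos, args=None):
            futures = []
            for info in infos:
                data = json.loads((info.path / "note.json").read_text())
                f = asyncio.Future()
                f.set_result(Note(**data))
                futures.append(f)
            return futures

    register_type_handler(ty=Note, handler=NoteHandler(), override=True)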
+    def _load_checkpoint_data_from_manager(self, manager, step):
+        """Helper method to load raw checkpoint data from a manager."""
+        try:
+            restore_args = StandardRestore()
+            return manager.restore(step, args=restore_args)
+        except Exception as e:
+            self.fail(f"Failed to load checkpoint data: {e}")
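The helper shows the raw-read path: any Orbax CheckpointManager pointed at the same directory can re-materialize the composite tree without going through the callback. A sketch, assuming Orbax's public args API (names may vary slightly across versions) and a hypothetical directory:

    import orbax.checkpoint as ocp

    mgr = ocp.CheckpointManager("/tmp/orbax_demo")
    step = mgr.latest_step()
    tree = mgr.restore(step, args=ocp.args.StandardRestore())
    # Expect keys such as model_weights / optimizer_state / metadata
    print(tree.keys())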
+    @pytest.mark.requires_trainable_backend
+    def test_save_decision_policy_integration(self):
+        """Test using orbax.checkpoint SaveDecisionPolicy objects."""
+        from orbax.checkpoint import checkpoint_managers
+
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(self.temp_dir, "test_decision_policy")
+
+        # Use FixedIntervalPolicy to save every 3 steps
+        policy = checkpoint_managers.FixedIntervalPolicy(
+            interval=3,  # Save every 3 steps
+        )
+
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            save_decision_policy=policy,
+        )
+
+        # Train for 10 epochs (steps 0-9)
+        model.fit(x, y, epochs=10, callbacks=[callback], verbose=0)
+
+        # Should have saved at steps 0, 3, 6, 9
+        all_steps = sorted(callback.manager.all_steps())
+        expected_steps = [0, 3, 6, 9]
+        self.assertEqual(
+            all_steps,
+            expected_steps,
+            f"Should save at steps {expected_steps}, got {all_steps}",
+        )
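save_decision_policy hands the save/skip decision to Orbax, so any policy from orbax.checkpoint.checkpoint_managers should compose the same way. A runnable sketch, assuming the PR's API (model, data, and directory are illustrative):

    import numpy as np
    import keras
    from orbax.checkpoint import checkpoint_managers
    from keras.callbacks import OrbaxCheckpoint

    model = keras.Sequential([keras.Input((4,)), keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    x = np.random.rand(32, 4).astype("float32")
    y = np.random.rand(32, 1).astype("float32")

    callback = OrbaxCheckpoint(
        directory="/tmp/policy_demo",
        save_decision_policy=checkpoint_managers.FixedIntervalPolicy(
            interval=3
        ),
    )
    model.fit(x, y, epochs=10, callbacks=[callback], verbose=0)
    print(sorted(callback.manager.all_steps()))  # expected: [0, 3, 6, 9]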
+    def _load_checkpoint_data(self, callback, step):
+        """Helper method to load raw checkpoint data for testing."""
+        try:
+            restore_args = StandardRestore()
+            return callback.manager.restore(step, args=restore_args)
+        except Exception as e:
+            self.fail(f"Failed to load checkpoint data: {e}")
+
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Sharding tests require JAX backend",
+    )
+    def test_jax_sharding_parameter_acceptance(self):
+        """Test that the sharding parameter is accepted with the JAX
+        backend."""
+        if distribution is None:
+            self.skipTest("Distribution module not available")
+
+        from keras.src.distribution import DeviceMesh
+        from keras.src.distribution import TensorLayout
+
+        devices = distribution.list_devices()
+        if len(devices) < 2:
+            self.skipTest("Sharding test requires at least 2 devices")
+
+        device_mesh = DeviceMesh(
+            shape=(2,), axis_names=("x",), devices=devices[:2]
+        )
+        tensor_layout = TensorLayout(axes=(None,), device_mesh=device_mesh)
+        sharding = tensor_layout.backend_layout
+
+        # Should not raise an error
+        callback = OrbaxCheckpoint(
+            directory=os.path.join(self.temp_dir, "test_sharding_acceptance"),
+            sharding=sharding,
+        )
+        self.assertIsNotNone(callback.sharding)
+
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Sharding tests require JAX backend",
+    )
+    def test_jax_sharding_with_virtual_devices(self):
+        """Test sharding functionality with a virtual devices setup."""
+        if distribution is None:
+            self.skipTest("Distribution module not available")
+
+        from keras.src.distribution import DeviceMesh
+        from keras.src.distribution import TensorLayout
+
+        devices = distribution.list_devices()
+        if len(devices) < 2:
+            self.skipTest("Sharding test requires at least 2 devices")
+
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        # Create sharding layout
+        device_mesh = DeviceMesh(
+            shape=(2,), axis_names=("x",), devices=devices[:2]
+        )
+        tensor_layout = TensorLayout(axes=("x",), device_mesh=device_mesh)
+        sharding = tensor_layout.backend_layout
+
+        checkpoint_dir = os.path.join(
+            self.temp_dir, "test_sharding_virtual_devices"
+        )
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir, sharding=sharding, save_freq="epoch"
+        )
+
+        # Train and save
+        model.fit(x, y, epochs=1, callbacks=[callback], verbose=0)
+
+        # Verify checkpoint was saved
+        self.assertTrue(os.path.exists(checkpoint_dir))
+        self.assertIsNotNone(callback.manager.latest_step())
+
+        # Load and verify
+        new_model = self._create_test_model()
+        success, _ = callback.load_latest(model=new_model)
+        self.assertTrue(success, "Should successfully load sharded checkpoint")
+
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Sharding tests require JAX backend",
+    )
+    def test_jax_sharding_and_multi_host_combined(self):
+        """Test combining sharding and multi-host checkpointing."""
+        if distribution is None:
+            self.skipTest("Distribution module not available")
+
+        from keras.src.distribution import DeviceMesh
+        from keras.src.distribution import TensorLayout
+
+        devices = distribution.list_devices()
+        if len(devices) < 2:
+            self.skipTest("Combined test requires at least 2 devices")
+
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        # Create sharding layout
+        device_mesh = DeviceMesh(
+            shape=(2,), axis_names=("x",), devices=devices[:2]
+        )
+        tensor_layout = TensorLayout(axes=("x",), device_mesh=device_mesh)
+        sharding = tensor_layout.backend_layout
+
+        checkpoint_dir = os.path.join(
+            self.temp_dir, "test_sharding_multi_host_combined"
+        )
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir,
+            sharding=sharding,
+            multi_host=True,
+            save_freq="epoch",
+        )
+
+        # Train and save
+        model.fit(x, y, epochs=1, callbacks=[callback], verbose=0)
+
+        # Verify checkpoint was saved
+        self.assertTrue(os.path.exists(checkpoint_dir))
+        self.assertIsNotNone(callback.manager.latest_step())
+
+        # Load and verify
+        new_model = self._create_test_model()
+        success, _ = callback.load_latest(model=new_model)
+        self.assertTrue(
+            success, "Should successfully load combined checkpoint"
+        )
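The sharding argument in these tests is always derived the same way: a Keras DeviceMesh plus TensorLayout, lowered to the JAX backend layout. A minimal sketch of that derivation, assuming a JAX host with multiple devices:

    import jax
    from keras.distribution import DeviceMesh, TensorLayout

    devices = jax.devices()
    mesh = DeviceMesh(
        shape=(len(devices),), axis_names=("x",), devices=devices
    )
    # axes=("x",) shards dim 0 over "x"; axes=(None,) replicates instead.
    layout = TensorLayout(axes=("x",), device_mesh=mesh)
    sharding = layout.backend_layout  # a JAX sharding object
    # Pass sharding=sharding to OrbaxCheckpoint to place restored shards.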
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Sharding tests require JAX backend",
+    )
+    def test_jax_sharding_parameter_validation(self):
+        """Test that sharding parameter validation works correctly."""
+        if distribution is None:
+            self.skipTest("Distribution module not available")
+
+        from keras.src.distribution import DeviceMesh
+        from keras.src.distribution import TensorLayout
+
+        devices = distribution.list_devices()
+        if len(devices) < 2:
+            self.skipTest(
+                "Sharding validation test requires at least 2 devices"
+            )
+
+        device_mesh = DeviceMesh(
+            shape=(2,), axis_names=("x",), devices=devices[:2]
+        )
+        tensor_layout = TensorLayout(axes=(None,), device_mesh=device_mesh)
+        sharding = tensor_layout.backend_layout
+
+        # Valid sharding should work
+        callback = OrbaxCheckpoint(
+            directory=os.path.join(self.temp_dir, "test_valid_sharding"),
+            sharding=sharding,
+        )
+        self.assertEqual(callback.sharding, sharding)
+
+        # None sharding should work
+        callback_none = OrbaxCheckpoint(
+            directory=os.path.join(self.temp_dir, "test_none_sharding")
+        )
+        self.assertIsNone(callback_none.sharding)
+
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Sharding tests require JAX backend",
+    )
+    def test_jax_different_sharding_configurations(self):
+        """Test that different sharding configurations work correctly."""
+        if distribution is None:
+            self.skipTest("Distribution module not available")
+
+        from keras.src.distribution import DeviceMesh
+        from keras.src.distribution import TensorLayout
+
+        devices = distribution.list_devices()
+        if len(devices) < 4:
+            self.skipTest(
+                "Different sharding configs test requires at least 4 devices"
+            )
+
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        # Test different sharding configurations
+        configs = [
+            # 2-device mesh, sharded along "x"
+            {"shape": (2,), "axis_names": ("x",), "axes": ("x",)},
+            # 4-device mesh, replicated layout (no sharded axis)
+            {"shape": (4,), "axis_names": ("x",), "axes": (None,)},
+        ]
+
+        for i, config in enumerate(configs):
+            device_mesh = DeviceMesh(
+                shape=config["shape"],
+                axis_names=config["axis_names"],
+                devices=devices[: config["shape"][0]],
+            )
+            tensor_layout = TensorLayout(
+                axes=config["axes"], device_mesh=device_mesh
+            )
+            sharding = tensor_layout.backend_layout
+
+            checkpoint_dir = os.path.join(
+                self.temp_dir, f"test_sharding_config_{i}"
+            )
+            callback = OrbaxCheckpoint(
+                directory=checkpoint_dir, sharding=sharding, save_freq="epoch"
+            )
+
+            # Train and save
+            model.fit(x, y, epochs=1, callbacks=[callback], verbose=0)
+
+            # Verify checkpoint was saved
+            self.assertTrue(os.path.exists(checkpoint_dir))
+            self.assertIsNotNone(callback.manager.latest_step())
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Sharding compatibility tests require JAX backend",
+    )
+    def test_jax_sharding_compatibility_across_save_load(self):
+        """Test sharding compatibility across save and load operations."""
+        if distribution is None:
+            self.skipTest("Distribution module not available")
+
+        from keras.src.distribution import DeviceMesh
+        from keras.src.distribution import TensorLayout
+
+        devices = distribution.list_devices()
+        if len(devices) < 2:
+            self.skipTest(
+                "Sharding compatibility test requires at least 2 devices"
+            )
+
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        # Save with sharding
+        device_mesh = DeviceMesh(
+            shape=(2,), axis_names=("x",), devices=devices[:2]
+        )
+        tensor_layout = TensorLayout(axes=("x",), device_mesh=device_mesh)
+        sharding = tensor_layout.backend_layout
+
+        checkpoint_dir = os.path.join(
+            self.temp_dir, "test_sharding_compatibility"
+        )
+        save_callback = OrbaxCheckpoint(
+            directory=checkpoint_dir, sharding=sharding, save_freq="epoch"
+        )
+
+        model.fit(x, y, epochs=1, callbacks=[save_callback], verbose=0)
+
+        # Load with the same sharding
+        load_callback = OrbaxCheckpoint(
+            directory=checkpoint_dir, sharding=sharding
+        )
+        new_model = self._create_test_model()
+        success, _ = load_callback.load_latest(model=new_model)
+        self.assertTrue(success, "Should successfully load with same sharding")
+
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Sharding edge case tests require JAX backend",
+    )
+    def test_jax_single_device_sharding_edge_cases(self):
+        """Test edge cases for single device sharding scenarios."""
+        if distribution is None:
+            self.skipTest("Distribution module not available")
+
+        from keras.src.distribution import DeviceMesh
+        from keras.src.distribution import TensorLayout
+
+        devices = distribution.list_devices()
+        if len(devices) < 2:
+            self.skipTest(
+                "Single device sharding test requires at least 2 devices"
+            )
+
+        # Test with a single device in the mesh (effectively no sharding)
+        device_mesh = DeviceMesh(
+            shape=(1,), axis_names=("x",), devices=devices[:1]
+        )
+        tensor_layout = TensorLayout(axes=(None,), device_mesh=device_mesh)
+        sharding = tensor_layout.backend_layout
+
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        checkpoint_dir = os.path.join(
+            self.temp_dir, "test_single_device_sharding"
+        )
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir, sharding=sharding, save_freq="epoch"
+        )
+
+        # Should work without errors
+        model.fit(x, y, epochs=1, callbacks=[callback], verbose=0)
+        self.assertIsNotNone(callback.manager.latest_step())
+
+    def test_tensorflow_backend_rejects_sharding(self):
+        """Test that the TensorFlow backend rejects the sharding
+        parameter."""
+        if backend.backend() == "tensorflow":
+            with self.assertRaises((ValueError, TypeError)) as cm:
+                OrbaxCheckpoint(
+                    directory=os.path.join(self.temp_dir, "test_tf_reject"),
+                    sharding="invalid_sharding",  # Any non-None value
+                )
+            self.assertIn("JAX backend", str(cm.exception))
+
+    def test_pytorch_backend_rejects_sharding(self):
+        """Test that the PyTorch backend rejects the sharding parameter."""
+        if backend.backend() == "torch":
+            with self.assertRaises((ValueError, TypeError)) as cm:
+                OrbaxCheckpoint(
+                    directory=os.path.join(self.temp_dir, "test_torch_reject"),
+                    sharding="invalid_sharding",  # Any non-None value
+                )
+            self.assertIn("JAX backend", str(cm.exception))
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Sharding functionality validation requires JAX backend",
+    )
+    def test_jax_sharding_functionality_validation(self):
+        """Comprehensive test of JAX sharding functionality."""
+        if distribution is None:
+            self.skipTest("Distribution module not available")
+
+        from keras.src.distribution import DeviceMesh
+        from keras.src.distribution import TensorLayout
+
+        devices = distribution.list_devices()
+        if len(devices) < 2:
+            self.skipTest(
+                "Sharding functionality test requires at least 2 devices"
+            )
+
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        # Create sharding
+        device_mesh = DeviceMesh(
+            shape=(2,), axis_names=("x",), devices=devices[:2]
+        )
+        tensor_layout = TensorLayout(axes=("x",), device_mesh=device_mesh)
+        sharding = tensor_layout.backend_layout
+
+        checkpoint_dir = os.path.join(
+            self.temp_dir, "test_sharding_functionality"
+        )
+        callback = OrbaxCheckpoint(
+            directory=checkpoint_dir, sharding=sharding, save_freq="epoch"
+        )
+
+        # Train and checkpoint
+        model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+
+        # Verify multiple checkpoints
+        all_steps = sorted(callback.manager.all_steps())
+        self.assertEqual(
+            len(all_steps), 2, f"Expected 2 checkpoints, got {len(all_steps)}"
+        )
+
+        # Load from a specific step
+        new_model = self._create_test_model()
+        success, _ = callback.load_checkpoint(
+            step=all_steps[0], model=new_model
+        )
+        self.assertTrue(success, "Should load from specific step")
+
+        # Load latest
+        latest_model = self._create_test_model()
+        success, _ = callback.load_latest(model=latest_model)
+        self.assertTrue(success, "Should load latest checkpoint")
+
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Multi-host error handling tests require JAX backend",
+    )
+    def test_multi_host_error_handling_with_invalid_sharding(self):
+        """Test that multi_host=True is accepted when sharding is left as
+        None."""
+        callback = OrbaxCheckpoint(
+            directory=os.path.join(self.temp_dir, "test_multi_host_none"),
+            multi_host=True,
+        )
+        self.assertTrue(callback.multi_host)
+        self.assertIsNone(callback.sharding)
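On an actual multi-host JAX job, every process constructs the callback over the same directory, and multi_host=True opts into Orbax's coordinated all-host save. A per-process sketch under those assumptions (jax.distributed initialization and a shared filesystem or bucket are presumed to exist):

    import jax
    from keras.callbacks import OrbaxCheckpoint

    # Same directory on every process, e.g. shared filesystem or GCS.
    callback = OrbaxCheckpoint(directory="/shared/ckpts", multi_host=True)
    if jax.process_index() == 0:
        print("primary process coordinates checkpoint finalization")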
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Sharding interoperability tests require JAX backend",
+    )
+    def test_restore_sharded_checkpoint_to_unsharded_model(self):
+        """Test restoring a sharded checkpoint to an unsharded model."""
+        if distribution is None:
+            self.skipTest("Distribution module not available")
+
+        from keras.src.distribution import DeviceMesh
+        from keras.src.distribution import TensorLayout
+
+        devices = distribution.list_devices()
+        if len(devices) < 2:
+            self.skipTest(
+                "Sharded to unsharded test requires at least 2 devices"
+            )
+
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        # Save with 2-way sharding
+        device_mesh = DeviceMesh(
+            shape=(2,), axis_names=("x",), devices=devices[:2]
+        )
+        tensor_layout = TensorLayout(axes=("x",), device_mesh=device_mesh)
+        sharding = tensor_layout.backend_layout
+
+        checkpoint_dir = os.path.join(
+            self.temp_dir, "test_sharded_to_unsharded"
+        )
+        save_callback = OrbaxCheckpoint(
+            directory=checkpoint_dir, sharding=sharding, save_freq="epoch"
+        )
+
+        model.fit(x, y, epochs=1, callbacks=[save_callback], verbose=0)
+
+        # Capture the original weights
+        original_weights = [w.numpy() for w in model.weights]
+
+        # Load with an unsharded model (sharding=None)
+        load_callback = OrbaxCheckpoint(directory=checkpoint_dir)
+
+        new_model = self._create_test_model()
+        success, _ = load_callback.load_latest(model=new_model)
+        self.assertTrue(
+            success,
+            "Should successfully load sharded checkpoint to unsharded model",
+        )
+
+        # Assert: Unsharded weights should match the original
+        restored_weights = [w.numpy() for w in new_model.weights]
+        for original, restored in zip(original_weights, restored_weights):
+            np.testing.assert_allclose(
+                original,
+                restored,
+                rtol=1e-5,
+                atol=1e-6,
+                err_msg="Unsharded weights should match original after "
+                "loading sharded checkpoint",
+            )
+
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Sharding interoperability tests require JAX backend",
+    )
+    def test_restore_unsharded_checkpoint_to_sharded_model(self):
+        """Test restoring an unsharded checkpoint to a sharded model."""
+        if distribution is None:
+            self.skipTest("Distribution module not available")
+
+        from keras.src.distribution import DeviceMesh
+        from keras.src.distribution import TensorLayout
+
+        devices = distribution.list_devices()
+        if len(devices) < 2:
+            self.skipTest(
+                "Unsharded to sharded test requires at least 2 devices"
+            )
+
+        model = self._create_test_model()
+        x, y = self._create_dummy_data()
+
+        # Save with an unsharded model
+        checkpoint_dir = os.path.join(
+            self.temp_dir, "test_unsharded_to_sharded"
+        )
+        save_callback = OrbaxCheckpoint(
+            directory=checkpoint_dir, save_freq="epoch"
+        )
+
+        model.fit(x, y, epochs=1, callbacks=[save_callback], verbose=0)
+
+        # Capture the original weights
+        original_weights = [w.numpy() for w in model.weights]
+
+        # Load with 2-way sharding
+        device_mesh = DeviceMesh(
+            shape=(2,), axis_names=("x",), devices=devices[:2]
+        )
+        tensor_layout = TensorLayout(axes=("x",), device_mesh=device_mesh)
+        sharding = tensor_layout.backend_layout
+
+        load_callback = OrbaxCheckpoint(
+            directory=checkpoint_dir, sharding=sharding
+        )
+
+        new_model = self._create_test_model()
+        success, _ = load_callback.load_latest(model=new_model)
+        self.assertTrue(
+            success,
+            "Should successfully load unsharded checkpoint to sharded model",
+        )
+
+        # Assert: Sharded weights should match the original
+        restored_weights = [w.numpy() for w in new_model.weights]
+        for original, restored in zip(original_weights, restored_weights):
+            np.testing.assert_allclose(
+                original,
+                restored,
+                rtol=1e-5,
+                atol=1e-6,
+                err_msg="Sharded weights should match original after "
+                "loading unsharded checkpoint",
+            )
+
+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Sharding validation tests require JAX backend",
+    )
+    def test_invalid_sharding_argument_raises_error(self):
+        """Test that invalid sharding arguments raise TypeError."""
+        # Test with a string (an invalid sharding object)
+        with self.assertRaises(TypeError):
+            OrbaxCheckpoint(
+                directory=os.path.join(self.temp_dir, "test_invalid_sharding"),
+                sharding="invalid_sharding_string",
+            )
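The two interoperability tests reduce to one rule: the serialized tree is layout-agnostic, and whichever sharding (or None) is passed to the loading callback decides placement. A sketch of a cross-layout restore under that assumption, with a hypothetical directory and model:

    import jax
    from keras.callbacks import OrbaxCheckpoint
    from keras.distribution import DeviceMesh, TensorLayout

    # A checkpoint written earlier without sharding:
    ckpt_dir = "/tmp/interop_demo"

    # Restore the same directory onto a 2-way sharded layout.
    devices = jax.devices()[:2]
    mesh = DeviceMesh(shape=(2,), axis_names=("x",), devices=devices)
    sharding = TensorLayout(axes=("x",), device_mesh=mesh).backend_layout

    loader = OrbaxCheckpoint(directory=ckpt_dir, sharding=sharding)
    # success, _ = loader.load_latest(model=new_model)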