Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6328350
Added OrbaxCheckpoint for keras 3.0 for Data centric saving and resto…
amitsrivastava78 Oct 21, 2025
ca71da6
Fix unused variable in orbax checkpoint test
amitsrivastava78 Oct 22, 2025
4dfa903
fixed failing cases
amitsrivastava78 Oct 22, 2025
7742139
fixed review comments
amitsrivastava78 Oct 22, 2025
822396f
Improve OrbaxCheckpoint implementation
amitsrivastava78 Oct 24, 2025
61bd5e6
Fix code formatting and remove unused variable
amitsrivastava78 Oct 24, 2025
19d2495
Add OrbaxCheckpoint callback with conditional exports and improved te…
amitsrivastava78 Oct 24, 2025
b56dc7b
Improve OrbaxCheckpoint: preserve nested structures, enhance tests
amitsrivastava78 Oct 28, 2025
7722e30
Fixed review comments
amitsrivastava78 Oct 31, 2025
eb7855d
Migration to Orbax V1
amitsrivastava78 Nov 5, 2025
aaf6e20
Fix sklearn wrapper CI tests by marking pipeline consistency checks a…
amitsrivastava78 Nov 10, 2025
cd881dd
made distributed structure proper
amitsrivastava78 Nov 10, 2025
9417027
Fixed save decision between keras and orbax
amitsrivastava78 Nov 11, 2025
b7a0dff
Optimize Orbax checkpoint for JAX backend
amitsrivastava78 Nov 11, 2025
33f4e66
Optimize Orbax checkpoint for JAX backend with compatibility check
amitsrivastava78 Nov 11, 2025
d7884ef
added checkpointer.wait()
amitsrivastava78 Nov 12, 2025
13aec2e
Improve OrbaxCheckpoint callback with optimizations and cleanup
amitsrivastava78 Nov 13, 2025
a2938ea
Simplify OrbaxCheckpoint API to match ModelCheckpoint parity
amitsrivastava78 Nov 13, 2025
4d659f4
Removed the experimental import
amitsrivastava78 Nov 13, 2025
ce30b36
Add comprehensive OrbaxCheckpoint tests with loading verification
amitsrivastava78 Nov 14, 2025
be35fdd
Improve OrbaxCheckpoint: complete state preservation, cross-backend c…
amitsrivastava78 Nov 24, 2025
e6c54e2
Add back try-except fallback for wait() method to support older Orbax…
amitsrivastava78 Nov 24, 2025
b876e11
Use hasattr check instead of try-except for wait() method compatibility
amitsrivastava78 Nov 24, 2025
124142c
Add JAX monitoring compatibility: mock jax.monitoring.record_scalar w…
amitsrivastava78 Nov 24, 2025
98cff1a
Re-run CI
amitsrivastava78 Nov 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
3 changes: 3 additions & 0 deletions keras/api/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
35 changes: 35 additions & 0 deletions keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,38 @@ class name_scope(backend_name_scope):
@keras_export("keras.device")
def device(device_name):
    """Return the backend device scope for `device_name`.

    Args:
        device_name: Device identifier, forwarded unchanged to
            `device_scope`.

    Returns:
        Whatever `device_scope(device_name)` returns.
    """
    scope = device_scope(device_name)  # noqa: F405
    return scope


def get_process_index():
    """Return the index of the current process for the active backend.

    The lookup depends on the value returned by `backend()`:

    - `"jax"`: `jax.process_index()`.
    - `"tensorflow"`: `replica_id_in_sync_group` of the current replica
      context.
    - `"torch"`: `torch.distributed.get_rank()`, but only when the
      distributed package is both available and initialized.

    Any other backend, or a failed lookup on a supported one, yields `0`.

    Returns:
        The index reported by the backend, or `0` as the fallback.
    """
    active = backend()

    if active == "jax":
        try:
            import jax

            index = jax.process_index()
        except (ImportError, AttributeError):
            index = 0
        return index

    if active == "tensorflow":
        try:
            import tensorflow as tf

            replica_context = tf.distribute.get_replica_context()
            index = replica_context.replica_id_in_sync_group
        except (ImportError, AttributeError, RuntimeError):
            index = 0
        return index

    if active == "torch":
        try:
            import torch.distributed as dist

            initialized = dist.is_available() and dist.is_initialized()
            index = dist.get_rank() if initialized else 0
        except (ImportError, AttributeError):
            index = 0
        return index

    # Backends without a lookup above are reported as index 0.
    return 0
2 changes: 1 addition & 1 deletion keras/src/backend/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from keras.src.backend.config import is_nnx_enabled
from keras.src.backend.jax import core
from keras.src.backend.jax import distribution_lib
from keras.src.backend.jax import image
from keras.src.backend.jax import linalg
from keras.src.backend.jax import math
Expand All @@ -25,6 +24,7 @@
from keras.src.backend.jax.core import shape
from keras.src.backend.jax.core import stop_gradient
from keras.src.backend.jax.core import vectorized_map
from keras.src.backend.jax.distribution_lib import process_id
from keras.src.backend.jax.rnn import cudnn_ok
from keras.src.backend.jax.rnn import gru
from keras.src.backend.jax.rnn import lstm
Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from keras.src.backend.numpy.core import random_seed_dtype
from keras.src.backend.numpy.core import shape
from keras.src.backend.numpy.core import vectorized_map
from keras.src.backend.numpy.distribution_lib import process_id
from keras.src.backend.numpy.rnn import cudnn_ok
from keras.src.backend.numpy.rnn import gru
from keras.src.backend.numpy.rnn import lstm
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/numpy/distribution_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Utilities for distribution strategy with NumPy backend."""


def process_id():
    """Report the process ID used by the distribution setting.

    Returns:
        Always `0`; this implementation performs no lookup.
    """
    fixed_process_id = 0
    return fixed_process_id
2 changes: 2 additions & 0 deletions keras/src/backend/openvino/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from keras.src.backend.common.name_scope import name_scope
from keras.src.backend.openvino import core
from keras.src.backend.openvino import distribution_lib
from keras.src.backend.openvino import image
from keras.src.backend.openvino import linalg
from keras.src.backend.openvino import math
Expand All @@ -19,6 +20,7 @@
from keras.src.backend.openvino.core import random_seed_dtype
from keras.src.backend.openvino.core import shape
from keras.src.backend.openvino.core import vectorized_map
from keras.src.backend.openvino.distribution_lib import process_id
from keras.src.backend.openvino.rnn import cudnn_ok
from keras.src.backend.openvino.rnn import gru
from keras.src.backend.openvino.rnn import lstm
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/openvino/distribution_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Utilities for distribution strategy with OpenVINO backend."""


def process_id():
    """Report the process ID used by the distribution setting.

    Returns:
        Always `0`; this implementation performs no lookup.
    """
    fixed_process_id = 0
    return fixed_process_id
2 changes: 1 addition & 1 deletion keras/src/backend/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from keras.src.backend.tensorflow import core
from keras.src.backend.tensorflow import distribution_lib
from keras.src.backend.tensorflow import image
from keras.src.backend.tensorflow import linalg
from keras.src.backend.tensorflow import math
Expand All @@ -24,6 +23,7 @@
from keras.src.backend.tensorflow.core import shape
from keras.src.backend.tensorflow.core import stop_gradient
from keras.src.backend.tensorflow.core import vectorized_map
from keras.src.backend.tensorflow.distribution_lib import process_id
from keras.src.backend.tensorflow.rnn import cudnn_ok
from keras.src.backend.tensorflow.rnn import gru
from keras.src.backend.tensorflow.rnn import lstm
Expand Down
10 changes: 10 additions & 0 deletions keras/src/backend/tensorflow/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,13 @@ def _to_backend_layout(tensor_layout):
]
dtensor_mesh = tensor_layout.device_mesh.backend_mesh
return dtensor.Layout(sharding_specs=sharding_specs, mesh=dtensor_mesh)


def process_id():
    """Return the replica ID of the current TensorFlow replica context.

    Reads `replica_id_in_sync_group` from
    `tf.distribute.get_replica_context()`.

    Returns:
        The value of `replica_id_in_sync_group`, or `0` when TensorFlow
        cannot be imported or the lookup raises `AttributeError` or
        `RuntimeError`.
    """
    # NOTE(review): this is a replica ID within the sync group, which may
    # not be the same thing as a host process index -- confirm with callers.
    try:
        import tensorflow as tf

        replica_context = tf.distribute.get_replica_context()
        replica_id = replica_context.replica_id_in_sync_group
    except (ImportError, AttributeError, RuntimeError):
        replica_id = 0
    return replica_id
1 change: 1 addition & 0 deletions keras/src/backend/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from keras.src.backend.torch.core import stop_gradient
from keras.src.backend.torch.core import to_torch_dtype
from keras.src.backend.torch.core import vectorized_map
from keras.src.backend.torch.distribution_lib import process_id
from keras.src.backend.torch.rnn import cudnn_ok
from keras.src.backend.torch.rnn import gru
from keras.src.backend.torch.rnn import lstm
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/torch/distribution_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Utilities for distribution strategy with PyTorch backend."""


def process_id():
    """Return the `torch.distributed` rank of the current process.

    Returns:
        `dist.get_rank()` when the distributed package is available and
        initialized; otherwise `0`. Also `0` when the import or any of the
        calls raises `ImportError` or `AttributeError`.
    """
    try:
        import torch.distributed as dist

        # Same short-circuit order as checking availability first, then
        # initialization.
        if not dist.is_available() or not dist.is_initialized():
            return 0
        return dist.get_rank()
    except (ImportError, AttributeError):
        return 0
1 change: 1 addition & 0 deletions keras/src/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
from keras.src.callbacks.monitor_callback import MonitorCallback
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
from keras.src.callbacks.progbar_logger import ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
from keras.src.callbacks.remote_monitor import RemoteMonitor
Expand Down
Loading