@@ -196,6 +196,40 @@ class OrbaxCheckpoint(MonitorCallback):
     directory=checkpoint_dir,
     save_decision_policy=policy)  # Save every 5 epochs
 
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+
+    # JAX-specific features: sharding and multi-host checkpointing.
+    # Note: these features are only available with the JAX backend.
+
+    # Example with sharding support (JAX only):
+    from keras.distribution import DeviceMesh, TensorLayout
+    devices = keras.distribution.list_devices()
+    device_mesh = DeviceMesh(shape=(len(devices),), axis_names=('x',),
+                             devices=devices)
+    tensor_layout = TensorLayout(axes=(None,), device_mesh=device_mesh)
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        sharding=tensor_layout.backend_layout,
+    )  # Enable sharding for distributed arrays
+
+    # Example with multi-host checkpointing (JAX only):
+    # Enables distributed checkpointing where each host writes its data
+    # shards while the primary process coordinates metadata and finalization.
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        multi_host=True)  # Enable multi-host checkpointing
+
+    # Combined sharding and multi-host checkpointing (JAX only):
+    from keras.distribution import DeviceMesh, TensorLayout
+    devices = keras.distribution.list_devices()
+    device_mesh = DeviceMesh(shape=(len(devices),), axis_names=('x',),
+                             devices=devices)
+    tensor_layout = TensorLayout(axes=(None,), device_mesh=device_mesh)
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        sharding=tensor_layout.backend_layout,
+        multi_host=True)  # Enable both features
+
     model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
     ```
 
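The docstring examples above build the sharding spec through `keras.distribution.TensorLayout` and pass its `backend_layout` to the `sharding` argument. As a rough sketch (not part of this commit), the same kind of spec can be constructed directly with `jax.sharding`; the single-axis mesh, `PartitionSpec`, and directory below are illustrative assumptions rather than values taken from the change.

```
# Sketch only: build a sharding spec with jax.sharding instead of going
# through keras.distribution.TensorLayout. Assumes a 1-D mesh over all local
# devices; PartitionSpec() replicates, PartitionSpec("x") would shard the
# leading axis across the mesh.
import jax
import keras

devices = jax.devices()
mesh = jax.sharding.Mesh(devices, axis_names=("x",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
    directory="/tmp/orbax_ckpt",  # hypothetical path
    sharding=sharding,
)
```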
@@ -241,6 +275,16 @@ class OrbaxCheckpoint(MonitorCallback):
             overrides the default save frequency logic. Defaults to None.
         save_interval: Integer, save checkpoints every N steps. If provided,
             overrides save_freq. Defaults to None.
+        sharding: JAX sharding specification for distributed checkpointing.
+            Only supported with the JAX backend. If provided with the
+            TensorFlow or PyTorch backend, an error is raised. Defaults to None.
+        multi_host: Boolean, whether to enable multi-host checkpointing for
+            distributed training across multiple processes/hosts. When enabled,
+            the primary process (rank 0) coordinates the checkpoint operation
+            while all processes write their data shards in parallel to create
+            a complete distributed checkpoint. Only supported with the JAX
+            backend. If enabled with the TensorFlow or PyTorch backend, an
+            error is raised. Defaults to False.
     """
 
     def __init__(
@@ -265,6 +309,8 @@ def __init__(
         save_transforms=None,
         save_decision_policy=None,
         save_interval=None,
+        sharding=None,
+        multi_host=False,
     ):
         # Ensure orbax is available
         ocp.initialize()
@@ -287,6 +333,18 @@ def __init__(
         self.save_transforms = save_transforms
         self.save_decision_policy = save_decision_policy
         self.save_interval = save_interval
+
+        # JAX-specific features
+        self.sharding = sharding
+        self.multi_host = multi_host
+
+        # Validate JAX-only features
+        if sharding is not None or multi_host:
+            if backend.backend() != "jax":
+                raise ValueError(
+                    "sharding and multi_host parameters are only supported "
+                    "with JAX backend. Current backend: " + backend.backend()
+                )
         self._batches_seen_since_last_saving = 0
         self._last_batch_seen = 0
         self._current_epoch = 0  # Keep track of epoch
@@ -326,6 +384,28 @@ def __init__(
             should_save_fn=should_save_fn,
             save_decision_policy=save_decision_policy,
         )
+
+        # Multi-host setup for JAX
+        if self.multi_host and backend.backend() == "jax":
+            try:
+                # Enable multi-host checkpointing via the Keras distribution API.
+                from keras.src import distribution
+
+                distribution.initialize()
+            except RuntimeError as e:
+                # If distributed initialization fails (e.g., JAX is already
+                # initialized), continue anyway; the multi_host flag is
+                # mainly a hint to Orbax.
+                if "must be called before" in str(e):
+                    pass  # Expected in test environments.
+                else:
+                    raise
+            # Orbax handles multi-host coordination automatically:
+            # - The primary process (rank 0) coordinates and writes the
+            #   metadata/manifest.
+            # - All processes write their data shards in parallel to the
+            #   checkpoint directory.
+
         # Ensure directory exists (only needed on one process in multi-host)
         if backend.get_process_index() == 0:
             os.makedirs(directory, exist_ok=True)
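The setup above only initializes the Keras distribution layer on the current process. For `multi_host=True` to have anything to coordinate, the training script itself has to run as one JAX process per host, which is normally arranged before the model is built. A minimal sketch follows, assuming a four-process job; the coordinator address, process count, and process id are placeholders that a real launcher would supply.

```
# Sketch only: per-host JAX initialization for a multi-host job. Values are
# placeholders; real jobs usually get them from the launcher or environment.
import jax

jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # placeholder coordinator
    num_processes=4,                      # placeholder world size
    process_id=0,                         # placeholder, unique per host
)

# Every process now sees the global device set and can take part in the
# distributed checkpoint write coordinated by the primary process.
print(jax.process_index(), jax.process_count(), len(jax.devices()))
```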
@@ -447,6 +527,16 @@ def _save_checkpoint(self, step, logs=None):
         save_args = ocp.args.StandardSave(
             composite_state, save_args=self.save_transforms
         )
+
+        # Apply sharding if specified (JAX only).
+        if self.sharding is not None and backend.backend() == "jax":
+            # Sharded saves are typically handled automatically by Orbax
+            # when JAX arrays carry sharding metadata; the explicit
+            # assignment below only applies when the save args expose a
+            # sharding attribute.
+            if hasattr(save_args, "sharding"):
+                save_args.sharding = self.sharding
+
         self.manager.save(step, args=save_args)
 
     def on_train_batch_end(self, batch, logs=None):
@@ -539,8 +629,15 @@ def load_checkpoint(self, step, model=None):
             was successful, False otherwise, and iterator_state is the saved
             data iterator state dict if available, None otherwise.
         """
-        # In distributed training, only load on primary process
-        if backend.get_process_index() != 0:
+        # In multi-host distributed training, all processes participate in
+        # loading so that they can read their data shards in parallel; only
+        # the primary process coordinates metadata reading and broadcasting.
+        if self.multi_host and backend.backend() == "jax":
+            # Multi-host loading: all processes participate.
+            pass  # Continue with loading on all processes.
+        elif backend.get_process_index() != 0:
+            # Single-host or non-multi-host distributed training: only the
+            # primary process loads.
             return True  # Return True to indicate no error, but no loading
 
         if self.verbose > 0:
@@ -552,6 +649,13 @@ def load_checkpoint(self, step, model=None):
         # template
         restore_args = ocp.args.StandardRestore()
 
+
+        # Apply sharding if specified (JAX only).
+        if self.sharding is not None and backend.backend() == "jax":
+            # Restore with the same sharding specification that was used
+            # during the save.
+            if hasattr(restore_args, "sharding"):
+                restore_args.sharding = self.sharding
         # Load the checkpoint
         checkpoint_data = self.manager.restore(step, args=restore_args)
 
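For completeness, here is a usage sketch of the restore path above, based only on the `load_checkpoint(step, model=None)` signature shown in this diff. The model, directory, and step number are placeholders; per the docstring, the call reports whether loading succeeded along with any saved data-iterator state on the primary process.

```
# Sketch only: restoring weights written by OrbaxCheckpoint. The model,
# directory, and step are placeholders and must match what was saved.
import keras

model = keras.Sequential([keras.layers.Dense(1)])  # placeholder model
model.compile(optimizer="adam", loss="mse")

callback = keras.callbacks.OrbaxCheckpoint(
    directory="/tmp/orbax_ckpt",  # hypothetical path
    multi_host=True,              # every host joins a multi-host restore
)
callback.set_model(model)
# Per the docstring above, this reports whether loading succeeded and any
# saved data-iterator state.
result = callback.load_checkpoint(step=4, model=model)
```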