@@ -33,41 +33,40 @@ def __init__(self, *args, layout=None, **kwargs):
     def set_tensor_layout(self):
         # We can't import the keras/distribution/distribution_lib
         # due to circular dependency.
-        if self._layout is None:
-            distribution = global_state.get_global_attribute("distribution")
-            if distribution is not None:
-                tensor_layout = distribution.get_variable_layout(self)
-                from keras.src.distribution import TensorLayout
-
-                if isinstance(tensor_layout, TensorLayout):
-                    self._layout = tensor_layout.backend_layout
-                else:
-                    self._layout = tensor_layout
+        distribution = global_state.get_global_attribute("distribution")
+        if self._layout is None and distribution is not None:
+            tensor_layout = distribution.get_variable_layout(self)
+            from keras.src.distribution import TensorLayout
+
+            if isinstance(tensor_layout, TensorLayout):
+                self._layout = tensor_layout.backend_layout
+            else:
+                self._layout = tensor_layout

     def _initialize(self, value):
         # Note that variable.shape is needed by distribution_lib
         self._shape = self._validate_shape(value.shape)
         self.set_tensor_layout()
         self._direct_assign(value)

-    def check_distributed_init(self, initializer):
+    def check_distributed_init(self, initializer, init_layout):
         # Check if 'layout' parameter is supported in the initializer call
         import inspect

         sig = inspect.signature(initializer.__call__)
         layout_supported = "layout" in sig.parameters
         # Check if PartitionSpec has any non-None values
-        spec = getattr(self._layout, "spec", None)
+        spec = getattr(init_layout, "spec", None)
         partition_spec = spec if spec is not None else ()
         is_partitioned = any(dim is not None for dim in partition_spec)
-        return layout_supported and is_partitioned
+        return layout_supported and init_layout is not None and is_partitioned

     def _initialize_with_initializer(self, initializer):
-        self.set_tensor_layout()
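+        # Resolve the NamedSharding for this variable's path from the active
+        # distribution's layout map, if any (see get_initialization_layout).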
+        init_layout = get_initialization_layout(self.path)
         # Use layout-aware initialization for distributed embeddings
-        if self.check_distributed_init(initializer):
+        if self.check_distributed_init(initializer, init_layout):
             value = self._convert_to_tensor(
-                initializer(self._shape, dtype=self._dtype, layout=self._layout)
+                initializer(self._shape, dtype=self._dtype, layout=init_layout)
             )
         else:
             value = self._convert_to_tensor(
@@ -141,6 +140,12 @@ def __init__(
         # The real value is now set in self._value, sync it to raw_value
         object.__setattr__(self, "raw_value", self._value)

+    def _initialize_with_initializer(self, initializer):
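+        # Initialize without a layout argument and assign the value via
+        # _initialize.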
+        value = self._convert_to_tensor(
+            initializer(self._shape, dtype=self._dtype)
+        )
+        self._initialize(value)
+
     @property
     def _value(self):
         if hasattr(self, "raw_value"):
@@ -264,6 +269,25 @@ def value(self):
 Variable = NnxVariable


+def get_initialization_layout(path):
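+    # Return the NamedSharding registered for `path` in the active
+    # distribution's layout map, or None if there is no distribution, no
+    # matching entry, or the layout is not a jax.sharding.NamedSharding.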
+    distribution = global_state.get_global_attribute("distribution")
+    if distribution is None:
+        return None
+    layout_map = getattr(distribution, "_layout_map", None)
+    if layout_map is None:
+        return None
+    layout_obj = layout_map.get(path)
+    if layout_obj is None:
+        return None
+    from keras.src.distribution import TensorLayout
+
+    if isinstance(layout_obj, TensorLayout):
+        layout_obj = layout_obj.backend_layout
+    if isinstance(layout_obj, jax.sharding.NamedSharding):
+        return layout_obj
+    return None
+
+
 def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
     if ragged:
         raise ValueError("`ragged=True` is not supported with jax backend")