Commit 77d1b35

Further changes pertaining to Seed Generator
1 parent a9831a8 commit 77d1b35

8 files changed: 116 additions, 79 deletions

keras/src/backend/jax/core.py

Lines changed: 40 additions & 16 deletions
@@ -33,41 +33,40 @@ def __init__(self, *args, layout=None, **kwargs):
     def set_tensor_layout(self):
         # We can't import the keras/distribution/distribution_lib
         # due to circular dependency.
-        if self._layout is None:
-            distribution = global_state.get_global_attribute("distribution")
-            if distribution is not None:
-                tensor_layout = distribution.get_variable_layout(self)
-                from keras.src.distribution import TensorLayout
-
-                if isinstance(tensor_layout, TensorLayout):
-                    self._layout = tensor_layout.backend_layout
-                else:
-                    self._layout = tensor_layout
+        distribution = global_state.get_global_attribute("distribution")
+        if self._layout is None and distribution is not None:
+            tensor_layout = distribution.get_variable_layout(self)
+            from keras.src.distribution import TensorLayout
+
+            if isinstance(tensor_layout, TensorLayout):
+                self._layout = tensor_layout.backend_layout
+            else:
+                self._layout = tensor_layout

     def _initialize(self, value):
         # Note that variable.shape is needed by distribution_lib
         self._shape = self._validate_shape(value.shape)
         self.set_tensor_layout()
         self._direct_assign(value)

-    def check_distributed_init(self, initializer):
+    def check_distributed_init(self, initializer, init_layout):
         # Check if 'layout' parameter is supported in the initializer call
         import inspect

         sig = inspect.signature(initializer.__call__)
         layout_supported = "layout" in sig.parameters
         # Check if PartitionSpec has any non-None values
-        spec = getattr(self._layout, "spec", None)
+        spec = getattr(init_layout, "spec", None)
         partition_spec = spec if spec is not None else ()
         is_partitioned = any(dim is not None for dim in partition_spec)
-        return layout_supported and is_partitioned
+        return layout_supported and init_layout is not None and is_partitioned

     def _initialize_with_initializer(self, initializer):
-        self.set_tensor_layout()
+        init_layout = get_initialization_layout(self.path)
         # Use layout-aware initialization for distributed embeddings
-        if self.check_distributed_init(initializer):
+        if self.check_distributed_init(initializer, init_layout):
             value = self._convert_to_tensor(
-                initializer(self._shape, dtype=self._dtype, layout=self._layout)
+                initializer(self._shape, dtype=self._dtype, layout=init_layout)
             )
         else:
             value = self._convert_to_tensor(

@@ -141,6 +140,12 @@ def __init__(
         # The real value is now set in self._value, sync it to raw_value
         object.__setattr__(self, "raw_value", self._value)

+    def _initialize_with_initializer(self, initializer):
+        value = self._convert_to_tensor(
+            initializer(self._shape, dtype=self._dtype)
+        )
+        self._initialize(value)
+
     @property
     def _value(self):
         if hasattr(self, "raw_value"):

@@ -264,6 +269,25 @@ def value(self):
 Variable = NnxVariable


+def get_initialization_layout(path):
+    distribution = global_state.get_global_attribute("distribution")
+    if distribution is None:
+        return None
+    layout_map = getattr(distribution, "_layout_map", None)
+    if layout_map is None:
+        return None
+    layout_obj = layout_map.get(path)
+    if layout_obj is None:
+        return None
+    from keras.src.distribution import TensorLayout
+
+    if isinstance(layout_obj, TensorLayout):
+        layout_obj = layout_obj.backend_layout
+    if isinstance(layout_obj, jax.sharding.NamedSharding):
+        return layout_obj
+    return None
+
+
 def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
     if ragged:
         raise ValueError("`ragged=True` is not supported with jax backend")
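
For context, a minimal sketch of the lookup that the new get_initialization_layout helper performs, written with plain JAX objects. The dict below stands in for the distribution's internal _layout_map; the variable paths and the single-device mesh are hypothetical.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical single-axis mesh and layout map (stand-ins for the real
# distribution object and its _layout_map attribute).
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("model",))
layout_map = {
    "embedding/embeddings": NamedSharding(mesh, PartitionSpec("model")),
}

def lookup_init_layout(path):
    # Return a NamedSharding for the variable path, or None when the path
    # is not covered by the layout map.
    layout = layout_map.get(path)
    return layout if isinstance(layout, NamedSharding) else None

def is_distributed_init(layout):
    # Mirrors check_distributed_init: layout-aware initialization is only
    # used when at least one axis of the PartitionSpec is actually sharded.
    spec = getattr(layout, "spec", None)
    partition_spec = spec if spec is not None else ()
    return layout is not None and any(dim is not None for dim in partition_spec)

print(is_distributed_init(lookup_init_layout("embedding/embeddings")))  # True
print(is_distributed_init(lookup_init_layout("dense/kernel")))          # False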

keras/src/backend/jax/distribution_lib.py

Lines changed: 3 additions & 3 deletions
@@ -38,10 +38,10 @@ def _distribute_initializer(
     from functools import partial

     # Draw seed from the seed generator if seed is not a Jax Array
+    # It is imperative for seed generation to happen before jit compilation
     if seed is None or not isinstance(seed, jax.Array):
-        jax_compatible_seed = seed_generator.draw_seed(None)
-        # Convert to JAX PRNG key format (swap counter and seed value)
-        seed = jax_compatible_seed[::-1]
+        seed = seed_generator.draw_seed(None)[0]
+        seed = jax.random.key(seed)

     # Validate all required arguments
     if init_func is None or init_func.func.__name__ not in [
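
The important property of this change is that the seed is drawn eagerly and wrapped into a typed PRNG key with jax.random.key before any jitted work runs. A hedged sketch of that pattern; draw_host_seed is a hypothetical stand-in for seed_generator.draw_seed(None)[0].

import jax
import numpy as np

def draw_host_seed():
    # Hypothetical stand-in for the Keras seed generator: any eager,
    # host-side source of a fresh integer seed.
    return int(np.random.randint(0, 2**31 - 1))

@jax.jit
def sample(key):
    return jax.random.normal(key, (3,))

seed = draw_host_seed()     # drawn before tracing/compilation
key = jax.random.key(seed)  # typed PRNG key accepted by jax.random ops
print(sample(key))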

keras/src/backend/jax/random.py

Lines changed: 35 additions & 33 deletions
@@ -9,12 +9,14 @@


 def jax_draw_seed(seed):
-    # Convert to JAX PRNG key format (swap counter and seed value)
     if isinstance(seed, jax.Array):
-        return seed[::-1]
+        if seed.ndim == 0:
+            return jax.random.key(seed)
+        elif seed.ndim == 1 and seed.shape == (2,):
+            return seed
     else:
-        seed_array = draw_seed(seed)
-        return seed_array[::-1]
+        seed = draw_seed(seed)
+        return seed


 def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):

@@ -40,6 +42,35 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
         return sample * stddev + mean


+def truncated_normal(
+    shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
+):
+    dtype = dtype or floatx()
+    seed = jax_draw_seed(seed)
+    if layout is not None:
+        from keras.src.backend import distribution_lib
+
+        init_func = partial(
+            jax.random.truncated_normal,
+            shape=shape,
+            dtype=dtype,
+            lower=-2.0,
+            upper=2.0,
+        )
+        return distribution_lib._distribute_initializer(
+            init_func=init_func,
+            mean=mean,
+            stddev=stddev,
+            seed=seed,
+            layout=layout,
+        )
+    else:
+        sample = jax.random.truncated_normal(
+            seed, shape=shape, lower=-2.0, upper=2.0, dtype=dtype
+        )
+        return sample * stddev + mean
+
+
 def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed = jax_draw_seed(seed)

@@ -84,35 +115,6 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
     )


-def truncated_normal(
-    shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
-):
-    dtype = dtype or floatx()
-    seed = jax_draw_seed(seed)
-    if layout is not None:
-        from keras.src.backend import distribution_lib
-
-        init_func = partial(
-            jax.random.truncated_normal,
-            shape=shape,
-            dtype=dtype,
-            lower=-2.0,
-            upper=2.0,
-        )
-        return distribution_lib._distribute_initializer(
-            init_func=init_func,
-            mean=mean,
-            stddev=stddev,
-            seed=seed,
-            layout=layout,
-        )
-    else:
-        sample = jax.random.truncated_normal(
-            seed, shape=shape, lower=-2.0, upper=2.0, dtype=dtype
-        )
-        return sample * stddev + mean
-
-
 def _get_concrete_noise_shape(inputs, noise_shape):
     if noise_shape is None:
         return inputs.shape
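
A small sketch of the seed normalization jax_draw_seed now performs, assuming the two input forms handled in the diff: a scalar jax.Array seed, or a (2,)-shaped array that is already in key form.

import jax
import jax.numpy as jnp

def normalize_seed(seed):
    # Scalar arrays become typed PRNG keys; (2,)-shaped key-style arrays
    # pass through unchanged. Other inputs are out of scope for this sketch.
    if isinstance(seed, jax.Array):
        if seed.ndim == 0:
            return jax.random.key(seed)
        if seed.ndim == 1 and seed.shape == (2,):
            return seed
    raise ValueError("unsupported seed in this sketch")

key = normalize_seed(jnp.asarray(42))
print(jax.random.truncated_normal(key, lower=-2.0, upper=2.0, shape=(4,)))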

keras/src/initializers/random_initializers.py

Lines changed: 3 additions & 2 deletions
@@ -6,18 +6,19 @@
 from keras.src.backend import random
 from keras.src.initializers.initializer import Initializer
 from keras.src.saving import serialization_lib
+from keras.src.utils import jax_utils


 class RandomInitializer(Initializer):
     def __init__(self, seed=None):
         self._init_seed = seed
         if seed is None and backend() == "jax":
-            seed = int(random.draw_seed(None)[0])
+            seed = jax_utils.get_jax_random_seed(seed)
         elif seed is None:
             seed = random.make_default_seed()
         elif isinstance(seed, dict):
             seed = serialization_lib.deserialize_keras_object(seed)
-        elif not isinstance(seed, (int, random.SeedGenerator)):
+        elif not isinstance(seed, (random.SeedGenerator, int)):
             raise ValueError(
                 "`seed` argument should be an instance of "
                 "`keras.random.SeedGenerator()` or an integer. "

keras/src/random/seed_generator.py

Lines changed: 5 additions & 20 deletions
@@ -118,29 +118,14 @@ def from_config(cls, config):

 def global_seed_generator():
     if jax_utils.is_in_jax_tracing_scope():
-        raise ValueError(
-            "[JAX RNG] When tracing a JAX function, "
-            "you should only use seeded random ops, e.g. "
-            "you should create a `SeedGenerator` instance, attach it "
-            "to your layer/model, and pass the instance as the `seed` "
-            "argument when calling random ops. Unseeded random ops "
-            "would get incorrectly traced by JAX and would become constant "
-            "after tracing. Example:\n\n"
-            "```\n"
-            "# Make sure to set the seed generator as a layer attribute\n"
-            "self.seed_generator = keras.random.SeedGenerator(seed=1337)\n"
-            "...\n"
-            "out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n"
-            "```"
-        )
+        # When we are in Jax Tracing mode, we provide a lightweight
+        # object of the shape and dtype expected
+        return jax_utils.JAXTracingSeedGenerator()
+
     gen = global_state.get_global_attribute("global_seed_generator")
-    global_seed = global_state.get_global_attribute("global_random_seed")
-    if gen is None and global_seed is None:
+    if gen is None:
         gen = SeedGenerator()
         global_state.set_global_attribute("global_seed_generator", gen)
-    elif gen is None and global_seed is not None:
-        gen = SeedGenerator(global_seed)
-        global_state.set_global_attribute("global_seed_generator", gen)
     return gen

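
The behavioral change: inside a JAX tracing scope, global_seed_generator() now returns a lightweight stand-in instead of raising, so unseeded random ops can still trace. A minimal, self-contained sketch of that idea; the tracing check mirrors jax_utils.is_in_jax_tracing_scope, and the key values are illustrative only.

import jax
import jax.numpy as jnp

def in_tracing_scope(x):
    # Mirrors jax_utils.is_in_jax_tracing_scope: walk the MRO looking for a
    # jax Tracer class instead of importing JAX internals.
    for c in type(x).__mro__:
        if c.__name__ == "Tracer" and c.__module__.startswith("jax"):
            return True
    return False

def next_key(x):
    if in_tracing_scope(x):
        return jax.random.key(0)   # dummy key, only meaningful for tracing
    return jax.random.key(1337)    # stand-in for the real global generator

@jax.jit
def add_noise(x):
    return x + jax.random.normal(next_key(x), x.shape)

print(add_noise(jnp.ones((2,))))  # traces and runs instead of raising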
keras/src/random/seed_generator_test.py

Lines changed: 2 additions & 5 deletions
@@ -84,11 +84,8 @@ def test_jax_tracing_with_global_seed_generator(self):
         def traced_function():
             return seed_generator.global_seed_generator().next()

-        with self.assertRaisesRegex(
-            ValueError,
-            "When tracing a JAX function, you should only use seeded random",
-        ):
-            traced_function()
+        result = traced_function()
+        self.assertIsNotNone(result)

     def test_seed_generator_serialization(self):
         random_generator = seed_generator.SeedGenerator(seed=42)

keras/src/utils/jax_utils.py

Lines changed: 23 additions & 0 deletions
@@ -1,3 +1,5 @@
+from jax import random
+
 from keras.src import backend


@@ -9,3 +11,24 @@ def is_in_jax_tracing_scope(x=None):
         if c.__name__ == "Tracer" and c.__module__.startswith("jax"):
             return True
     return False
+
+
+def get_jax_random_seed(seed=None):
+    if is_in_jax_tracing_scope():
+        # Constant dummy seed for Tracing
+        seed = 0
+    else:
+        # Gathering seed from a seed generator
+        seed = backend.random.draw_seed(None)[0]
+    return seed
+
+
+# Create a lightweight class that only provides shape/dtype info
+class JAXTracingSeedGenerator:
+    def __init__(self):
+        self._shape = (2,)
+        self._dtype = "uint32"
+
+    def next(self, ordered=False):
+        # Return a dummy key for tracing
+        return random.key(0)
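
A usage sketch of the stand-in pattern JAXTracingSeedGenerator relies on; TracingSeedStub below is illustrative, not the Keras class. It mirrors the seed state's shape and dtype and returns a constant dummy key from next(), which is enough for random ops to trace (it is not a source of fresh randomness across calls).

import jax

class TracingSeedStub:
    # Illustrative stand-in: only the metadata and the next() contract matter.
    def __init__(self):
        self._shape = (2,)       # shape of the real generator's seed state
        self._dtype = "uint32"   # dtype of the real generator's seed state

    def next(self, ordered=False):
        # Constant dummy key: fine for tracing, constant after compilation.
        return jax.random.key(0)

stub = TracingSeedStub()
print(jax.random.uniform(stub.next(), (3,)))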

keras/src/utils/rng_utils.py

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,7 @@
 from keras.src import backend
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
+from keras.src.random import seed_generator
 from keras.src.utils.module_utils import tensorflow as tf

 GLOBAL_RANDOM_SEED = "global_random_seed"

@@ -60,6 +61,10 @@ def set_random_seed(seed):
         import torch

         torch.manual_seed(seed)
+    if backend.backend() == "jax":
+        # We create a global seed generator using the global random seed
+        gen = seed_generator.SeedGenerator(seed)
+        global_state.set_global_attribute("global_seed_generator", gen)


 def get_random_seed():
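
A usage sketch of the effect on the JAX backend, assuming a Keras install with the JAX backend selected: set_random_seed now also installs a global SeedGenerator seeded with the given value, so unseeded keras.random ops become reproducible across runs.

import keras

keras.utils.set_random_seed(42)
a = keras.random.normal(shape=(2,))

keras.utils.set_random_seed(42)
b = keras.random.normal(shape=(2,))

# With the global generator re-seeded identically, the two draws are
# expected to match.
print(a)
print(b)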
