Commit c8aef26

support jax2tf in JaxLayer
1 parent bea37c5 commit c8aef26

4 files changed: +247, -53 lines

keras/src/layers/layer.py

Lines changed: 4 additions & 1 deletion
@@ -1145,7 +1145,10 @@ def compute_output_spec(self, *args, **kwargs):
             call_spec=call_spec,
             class_name=self.__class__.__name__,
         )
-        output_shape = self.compute_output_shape(**shapes_dict)
+        try:
+            output_shape = self.compute_output_shape(**shapes_dict)
+        except NotImplementedError:
+            return super().compute_output_spec(*args, **kwargs)
 
         if (
             isinstance(output_shape, list)
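With this fallback, a layer whose `compute_output_shape` raises `NotImplementedError` can still be used symbolically: the base implementation infers the output spec by tracing `call`. A minimal sketch (hypothetical layer, not part of this commit):

```
import keras
from keras import layers

class NoShapeLayer(layers.Layer):
    def call(self, x):
        return x * 2

    def compute_output_shape(self, input_shape):
        # Pretend the shape cannot be derived symbolically.
        raise NotImplementedError

inputs = keras.Input(shape=(4,))
outputs = NoShapeLayer()(inputs)  # now falls back to tracing `call`
```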

keras/src/utils/jax_layer.py

Lines changed: 203 additions & 35 deletions
@@ -12,7 +12,15 @@
 from keras.src.utils import jax_utils
 from keras.src.utils import tracking
 from keras.src.utils.module_utils import jax
-
+import functools
+import itertools
+import logging
+import string
+
+import tensorflow as tf
+from jax.experimental import jax2tf
+
+import keras
+from keras.src import random
 
 @keras_export("keras.layers.JaxLayer")
 class JaxLayer(Layer):
@@ -196,6 +204,9 @@ def my_haiku_module_fn(inputs, training):
         init_fn: the function to call to initialize the model. See description
             above for the list of arguments it takes and the outputs it returns.
             If `None`, then `params` and/or `state` must be provided.
+        compute_output_shape_fn: Function that takes the input shape
+            (a tuple or nested structure of tuples) and returns the output
+            shape (a tuple or nested structure of tuples).
         params: A `PyTree` containing all the model trainable parameters. This
             allows passing trained parameters or controlling the initialization.
             If both `params` and `state` are `None`, `init_fn` is called at
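For instance (a sketch with a hypothetical `my_call_fn`; the shape function simply mirrors what the JAX function does to the last dimension):

```
import jax.numpy as jnp
from keras.layers import JaxLayer

def my_call_fn(params, inputs):
    return jnp.tanh(inputs @ params["w"])

layer = JaxLayer(
    call_fn=my_call_fn,
    params={"w": jnp.ones((4, 8))},
    compute_output_shape_fn=lambda input_shape: (*input_shape[:-1], 8),
)
```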
@@ -214,14 +225,15 @@ def __init__(
         self,
         call_fn,
         init_fn=None,
+        compute_output_shape_fn=None,
         params=None,
         state=None,
         seed=None,
         **kwargs,
     ):
-        if backend.backend() != "jax":
+        if backend.backend() not in ["jax", "tensorflow"]:
             raise ValueError(
-                "JaxLayer is only supported with the JAX backend. Current "
+                "JaxLayer is only supported with the JAX and TensorFlow "
+                "backends. Current "
                 f"backend: {backend.backend()}"
             )
 
@@ -233,7 +245,10 @@ def __init__(
         super().__init__(**kwargs)
         self.call_fn = call_fn
         self.init_fn = init_fn
-        self.seed_generator = backend.random.SeedGenerator(seed)
+        self.compute_output_shape_fn = compute_output_shape_fn
+        if seed is None:
+            seed = random.seed_generator.make_default_seed()
+        self.jax_rng = jax.random.PRNGKey(seed)
         self.tracked_params = self._create_variables(params, trainable=True)
         self.tracked_state = self._create_variables(state, trainable=False)
         if self.params is not None or self.state is not None:
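The layer now keeps a raw JAX `PRNGKey` and splits it on demand (see `_split_jax_rng` further down) instead of using a `SeedGenerator`, whose seeds are backend tensors rather than JAX keys. A sketch of the split discipline:

```
import jax

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)  # keep `key`, consume `subkey`
sample = jax.random.normal(subkey, (2, 3))
```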
@@ -251,7 +266,12 @@ def __init__(
         self.init_fn_arguments = self._validate_signature(
             init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
         )
 
+        # Attributes for jax2tf functions
+        self.jax2tf_training_false_fn = None
+        self.jax2tf_training_true_fn = None
+
     def _validate_signature(self, fn, fn_name, allowed, required):
         fn_parameters = inspect.signature(fn).parameters
         for parameter_name in required:
@@ -271,6 +291,78 @@ def _validate_signature(self, fn, fn_name, allowed, required):
             parameter_names.append(parameter.name)
 
         return parameter_names
 
+    def _get_jax2tf_input_shape(self, input_shape):
+        """Convert an input shape to a format suitable for `jax2tf`.
+
+        `jax2tf` expects a letter for each unknown dimension, which allows
+        correlated dimensions. Since correlated dimensions are not supported
+        by Keras, we simply use 'a', 'b', 'c'... for each unknown dimension.
+        We do, however, use 'batch' for dimension 0 if it is not defined, to
+        correlate the batch size across inputs.
+
+        Example (spaces added for readability):
+        ```
+        input_shape: (None , 4, None, None, 5)
+        result:     "(batch, 4, a   , b   , 5)"
+        ```
+
+        Args:
+            input_shape: a single shape or a structure of shapes for the
+                inputs.
+        Returns:
+            the shape or structure of shapes as `jax2tf` format strings.
+        """
+        dim_names = itertools.chain(
+            string.ascii_lowercase,  # a, b, ... z
+            itertools.starmap(  # aa, ab, ... az, ba, bb, ... zz
+                lambda a, b: a + b,
+                itertools.product(string.ascii_lowercase, repeat=2),
+            ),
+        )
+
+        def get_single_jax2tf_shape(shape):
+            jax2tf_shape = []
+
+            for index, dim in enumerate(shape):
+                if dim is not None:
+                    jax2tf_shape.append(str(dim))
+                elif index == 0:
+                    jax2tf_shape.append("batch")
+                else:
+                    jax2tf_shape.append(next(dim_names))
+
+            return "(" + ", ".join(jax2tf_shape) + ")"
+
+        res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
+        logging.info("_get_jax2tf_input_shape res: %s", res)
+        return res
+
+    def _jax2tf_convert(self, fn, polymorphic_shapes):
+        converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
+        # Autograph won't work with the output of jax2tf.
+        converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)
+        return converted_fn
+
+    def _partial_with_positional(self, fn, index, value):
+        """Return a new partial with one positional argument set to a value.
+
+        This is needed because `jax2tf` only supports positional arguments
+        and `functools.partial` only supports setting positional arguments
+        starting from the left. Our use case is the `training` argument,
+        which is typically the rightmost argument.
+
+        Args:
+            fn: the function to wrap.
+            index: the index of the positional argument to set to `value`.
+            value: the value for the positional argument at `index`.
+        """
+
+        @functools.wraps(fn)
+        def wrapper(*args):
+            args = args[0:index] + (value,) + args[index:]
+            return fn(*args)
+
+        return wrapper
 
     @tracking.no_automatic_dependency_tracking
     def _create_variables(self, values, trainable):
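To illustrate what these helpers feed into, here is a standalone sketch of `jax2tf.convert` with polymorphic shape strings like the ones `_get_jax2tf_input_shape` produces ('batch' and 'a' name symbolic dimensions):

```
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def fn(x):
    return jnp.sum(x, axis=-1)

tf_fn = jax2tf.convert(fn, polymorphic_shapes=["(batch, a)"])
tf_fn = tf.autograph.experimental.do_not_convert(tf_fn)

print(tf_fn(tf.ones((3, 5))))  # any batch size and any last dimension work
```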
@@ -296,14 +388,14 @@ def _create_variables(self, values, trainable):
 
         def create_variable(value):
             if backend.is_tensor(value) or isinstance(
-                value, (np.ndarray, np.generic)
+                value, (np.ndarray, np.generic, jax.Array)
             ):
                 dtype = value.dtype
                 if is_float_dtype(dtype):
                     dtype = None  # Use the layer dtype policy
                 return self.add_weight(
                     value.shape,
-                    initializer=value,
+                    initializer=backend.convert_to_tensor(value),
                     dtype=dtype,
                     trainable=trainable,
                 )
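A side effect of accepting `jax.Array` here: `params` produced directly by a JAX training loop can be passed to the constructor without a NumPy round trip. Sketch (reusing the hypothetical `my_call_fn` from above):

```
import jax.numpy as jnp

params = {"w": jnp.zeros((4, 8))}
layer = JaxLayer(call_fn=my_call_fn, params=params)
```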
@@ -328,8 +420,15 @@ def create_variable(value):
         else:
             self.state = variables
 
-        flat_variables, _ = jax.tree_util.tree_flatten(variables)
-        return flat_variables
+        if backend.backend() == "jax":
+            flat_variables, _ = jax.tree_util.tree_flatten(variables)
+            return flat_variables
+        elif backend.backend() == "tensorflow":
+            return variables
+
+    def _split_jax_rng(self):
+        self.jax_rng, subkey = jax.random.split(self.jax_rng)
+        return subkey
 
     def _get_init_rng(self):
         """
@@ -343,7 +442,7 @@ def _get_init_rng(self):
         a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
         the `rng` argument of `init_fn`.
         """
-        return self.seed_generator.next()
+        return self._split_jax_rng()
 
     def _get_call_rng(self, training):
         """
@@ -359,24 +458,22 @@ def _get_call_rng(self, training):
         the `rng` argument of `call_fn`.
         """
         if training:
-            return self.seed_generator.next()
+            return self._split_jax_rng()
         else:
             return None
 
-    def build(self, input_shape):
-        if self.params is not None or self.state is not None:
-            return
-
-        if jax_utils.is_in_jax_tracing_scope():
+    def _initialize_weights(self, input_shape):
+        if jax_utils.is_in_jax_tracing_scope() or tf.inside_function():
             # This exception is not actually shown, it is caught and a detailed
             # warning about calling 'build' is printed.
-            raise ValueError("'JaxLayer' cannot be built in tracing scope")
+            raise ValueError(
+                "'JaxLayer' cannot be built in a tracing scope or inside a "
+                "tf.function"
+            )
 
+        logging.info("_initialize_weights input_shape: %s", input_shape)
         # Initialize `params` and `state` if needed by calling `init_fn`.
         def create_input(shape):
             shape = [d if d is not None else 1 for d in shape]
-            return jax.numpy.ones(shape)
+            return keras.ops.ones(shape)
 
         init_inputs = tree.map_shape_structure(create_input, input_shape)
         init_args = []
         for argument_name in self.init_fn_arguments:
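Because `_initialize_weights` refuses to run inside a trace, a layer that initializes its own weights must be built eagerly before any `tf.function`-wrapped call. A sketch (assuming the TensorFlow backend and a `layer` constructed as above):

```
import tensorflow as tf

layer.build((None, 4))  # runs init_fn eagerly and creates the variables

@tf.function
def serve(x):
    return layer(x, training=False)  # safe: the layer is already built
```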
@@ -398,6 +495,44 @@ def create_input(shape):
             )
         self.tracked_state = self._create_variables(init_state, trainable=False)
 
+    def build(self, input_shape):
+        if self.params is None and self.state is None:
+            self._initialize_weights(input_shape)
+
+        if backend.backend() == "tensorflow":
+            polymorphic_shapes = []
+            for argument in self.call_fn_arguments:
+                if argument == "inputs":
+                    polymorphic_shapes.append(
+                        self._get_jax2tf_input_shape(input_shape)
+                    )
+                elif argument != "training":
+                    # params, state, rng
+                    polymorphic_shapes.append("...")
+
+            if "training" in self.call_fn_arguments:
+                training_argument_index = self.call_fn_arguments.index(
+                    "training"
+                )
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, False
+                    ),
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, True
+                    ),
+                    polymorphic_shapes,
+                )
+            else:
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self.call_fn,
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = None
+        super().build(input_shape)
+
     def call(self, inputs, training=False):
         def unwrap_variable(variable):
             return None if variable is None else variable.value
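Two separate conversions are needed because `training` is a Python bool: it is baked in when `jax2tf` traces the function, so one specialization per value is created up front. A standalone sketch of the effect:

```
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def call_fn(x, training):
    return x * 2.0 if training else x  # Python-level branch, fixed at trace time

infer_fn = jax2tf.convert(lambda x: call_fn(x, False),
                          polymorphic_shapes=["(batch, a)"])
train_fn = jax2tf.convert(lambda x: call_fn(x, True),
                          polymorphic_shapes=["(batch, a)"])

print(infer_fn(tf.ones((2, 3))))  # ones
print(train_fn(tf.ones((2, 3))))  # twos
```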
@@ -417,7 +552,8 @@ def unwrap_variable(variable):
             elif argument_name == "inputs":
                 call_args.append(inputs)
             elif argument_name == "training":
-                call_args.append(training)
+                if backend.backend() == "jax":
+                    call_args.append(training)
 
         def assign_state_to_variable(value, variable):
             # This exists only to make debugging this error case easier.
@@ -429,14 +565,50 @@ def assign_state_to_variable(value, variable):
             )
             variable.assign(value)
 
-        if self.has_state:
-            predictions, new_state = self.call_fn(*call_args)
-            jax.tree_util.tree_map(
-                assign_state_to_variable, new_state, self.state
-            )
-            return predictions
-        else:
-            return self.call_fn(*call_args)
+        def call_with_fn(fn):
+            if self.has_state:
+                predictions, new_state = fn(*call_args)
+                if backend.backend() == "jax":
+                    jax.tree_util.tree_map(
+                        assign_state_to_variable, new_state, self.state
+                    )
+                elif backend.backend() == "tensorflow":
+                    new_state_leaves = jax.tree_util.tree_leaves(new_state)
+                    state_leaves = jax.tree_util.tree_leaves(self.state)
+                    if len(new_state_leaves) != len(state_leaves):
+                        # A leaf count mismatch indicates a fundamental
+                        # structure divergence between the jax2tf output
+                        # and the tracked state.
+                        raise ValueError(
+                            "State leaf count mismatch between jax2tf "
+                            "output and layer state: "
+                            f"{len(new_state_leaves)} vs "
+                            f"{len(state_leaves)}. new_state structure: "
+                            f"{jax.tree_util.tree_structure(new_state)}, "
+                            "self.state structure: "
+                            f"{jax.tree_util.tree_structure(self.state)}"
+                        )
+                    for new_value, state_leaf in zip(
+                        new_state_leaves, state_leaves
+                    ):
+                        assign_state_to_variable(new_value, state_leaf)
+                return predictions
+            else:
+                return fn(*call_args)
+
+        if backend.backend() == "jax":
+            return call_with_fn(self.call_fn)
+        elif backend.backend() == "tensorflow":
+            if self.jax2tf_training_true_fn is None:
+                return call_with_fn(self.jax2tf_training_false_fn)
+            elif training:
+                return call_with_fn(self.jax2tf_training_true_fn)
+            else:
+                return call_with_fn(self.jax2tf_training_false_fn)
+
+    def compute_output_shape(self, input_shape):
+        if self.compute_output_shape_fn:
+            return self.compute_output_shape_fn(input_shape)
+        else:
+            return super().compute_output_shape(input_shape)
 
     def get_config(self):
         config = {
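Putting it together, an end-to-end sketch under the TensorFlow backend (hypothetical `call_fn`; `KERAS_BACKEND` must be set before importing keras):

```
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import jax.numpy as jnp
import keras
from keras.layers import JaxLayer

def call_fn(params, inputs):
    return jnp.dot(inputs, params["w"])

layer = JaxLayer(
    call_fn=call_fn,
    params={"w": jnp.ones((8, 2))},
    compute_output_shape_fn=lambda shape: (*shape[:-1], 2),
)
inputs = keras.Input(shape=(8,))
model = keras.Model(inputs, layer(inputs))  # call_fn runs through jax2tf
```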
@@ -549,19 +721,14 @@ def my_flax_module_wrapper(module, inputs, training):
     def __init__(
         self,
         module,
+        compute_output_shape_fn=None,
         method=None,
         variables=None,
         **kwargs,
     ):
         # Late import to only require Flax when this is used.
         from flax.core import scope as flax_scope
 
-        if backend.backend() != "jax":
-            raise ValueError(
-                "FlaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
-            )
-
         self.module = module
         self.method = method
 
@@ -618,6 +785,7 @@ def init_without_training(rng, inputs):
         super().__init__(
             call_fn=call_fn,
             init_fn=init_fn,
+            compute_output_shape_fn=compute_output_shape_fn,
             params=params,
             state=state,
             **kwargs,
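The forwarded argument makes the same shape hint available on `FlaxLayer`. A sketch with a hypothetical Flax module:

```
import flax.linen as nn
from keras.layers import FlaxLayer

class MLP(nn.Module):
    @nn.compact
    def __call__(self, inputs):
        return nn.Dense(16)(nn.relu(nn.Dense(32)(inputs)))

layer = FlaxLayer(
    MLP(),
    compute_output_shape_fn=lambda input_shape: (*input_shape[:-1], 16),
)
```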
@@ -650,13 +818,13 @@ def _variables_to_params_and_state(self, variables):
 
     def _get_init_rng(self):
         return {
-            "params": self.seed_generator.next(),
-            "dropout": self.seed_generator.next(),
+            "params": self._split_jax_rng(),
+            "dropout": self._split_jax_rng(),
         }
 
     def _get_call_rng(self, training):
         if training:
-            return {"dropout": self.seed_generator.next()}
+            return {"dropout": self._split_jax_rng()}
         else:
             return {}