 from keras.src.utils import jax_utils
 from keras.src.utils import tracking
 from keras.src.utils.module_utils import jax
-
+import tensorflow as tf
+from jax.experimental import jax2tf
+import keras
+import itertools
+import string
+import functools
+from keras.src import random
+import logging
+# from flax.core import FrozenDict, DictWrapper, ListWrapper
 
 @keras_export("keras.layers.JaxLayer")
 class JaxLayer(Layer):
@@ -196,6 +204,9 @@ def my_haiku_module_fn(inputs, training):
         init_fn: the function to call to initialize the model. See description
             above for the list of arguments it takes and the outputs it returns.
             If `None`, then `params` and/or `state` must be provided.
+        compute_output_shape_fn: Function that takes the input shape
+            (a tuple or nested structure of tuples) and returns the output
+            shape (a tuple or nested structure of tuples).
         params: A `PyTree` containing all the model trainable parameters. This
             allows passing trained parameters or controlling the initialization.
             If both `params` and `state` are `None`, `init_fn` is called at
@@ -214,14 +225,15 @@ def __init__(
         self,
         call_fn,
         init_fn=None,
+        compute_output_shape_fn=None,
         params=None,
         state=None,
         seed=None,
         **kwargs,
     ):
-        if backend.backend() != "jax":
+        if backend.backend() not in ["jax", "tensorflow"]:
             raise ValueError(
-                "JaxLayer is only supported with the JAX backend. Current "
236+ "JaxLayer is only supported with the JAX or Tensorflow backend. Current "
225237 f"backend: { backend .backend ()} "
226238 )
227239
@@ -233,7 +245,10 @@ def __init__(
         super().__init__(**kwargs)
         self.call_fn = call_fn
         self.init_fn = init_fn
-        self.seed_generator = backend.random.SeedGenerator(seed)
+        self.compute_output_shape_fn = compute_output_shape_fn
+        if seed is None:
+            seed = random.seed_generator.make_default_seed()
+        self.jax_rng = jax.random.PRNGKey(seed)
         self.tracked_params = self._create_variables(params, trainable=True)
         self.tracked_state = self._create_variables(state, trainable=False)
         if self.params is not None or self.state is not None:
@@ -251,7 +266,12 @@ def __init__(
         self.init_fn_arguments = self._validate_signature(
             init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
         )
+
+        # Attributes for jax2tf functions
+        self.jax2tf_training_false_fn = None
+        self.jax2tf_training_true_fn = None
 
+
     def _validate_signature(self, fn, fn_name, allowed, required):
         fn_parameters = inspect.signature(fn).parameters
         for parameter_name in required:
@@ -271,6 +291,78 @@ def _validate_signature(self, fn, fn_name, allowed, required):
             parameter_names.append(parameter.name)
 
         return parameter_names
+
+    def _get_jax2tf_input_shape(self, input_shape):
296+ """Convert input shape in a format suitable for `jax2tf`.
+
+        `jax2tf` expects a letter for each unknown dimension, which allows
+        correlated dimensions. Since correlated dimensions are not supported by
+        Keras, we simply use 'a', 'b', 'c'..., for each unknown dimension. We
+        however use 'batch' for dimension 0 if not defined to correlate the
+        batch size across inputs.
+
+        Example (spaces added for readability):
+        ```
+        input_shape:  (None , 4 , None, None, 5 )
+        result:      "(batch, 4 , a   , b   , 5 )"
+        ```
+
+        Args:
+            input_shape: a single shape or a structure of shapes for the inputs.
+        Returns:
+            the shape or shapes structure in the `jax2tf` format as strings.
+        """
+        dim_names = itertools.chain(
+            string.ascii_lowercase,  # a, b, ... z
+            itertools.starmap(  # aa, ab, ... az, ba, bb, ... zz
+                lambda a, b: a + b,
+                itertools.product(string.ascii_lowercase, repeat=2),
+            ),
+        )
+
+        def get_single_jax2tf_shape(shape):
+            jax2tf_shape = []
+
+            for index, dim in enumerate(shape):
+                if dim is not None:
+                    jax2tf_shape.append(str(dim))
+                elif index == 0:
+                    jax2tf_shape.append("batch")
+                else:
+                    jax2tf_shape.append(next(dim_names))
+
+            return "(" + ", ".join(jax2tf_shape) + ")"
+
+        res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
+        logging.info("_get_jax2tf_input_shape res: %s", res)
+        return res
+
+    def _jax2tf_convert(self, fn, polymorphic_shapes):
+        converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
+        # Autograph won't work with the output of jax2tf.
+        converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)
+        return converted_fn
+
+    def _partial_with_positional(self, fn, index, value):
+        """Return a new partial with one positional argument set to a value.
+
+        This is needed because `jax2tf` only supports positional arguments and
+        `functools.partial` only supports setting positional arguments starting
+        from the left. Our use case is the `training` argument which is
+        typically the rightmost argument.
+
+        Args:
+            fn: the function to wrap.
+            index: the index of the positional argument to set to `value`.
+            value: the value for the positional argument at `index`.
+        """
+
+        @functools.wraps(fn)
+        def wrapper(*args):
+            args = args[0:index] + (value,) + args[index:]
+            return fn(*args)
+
+        return wrapper
 
     @tracking.no_automatic_dependency_tracking
     def _create_variables(self, values, trainable):
@@ -296,14 +388,14 @@ def _create_variables(self, values, trainable):
 
         def create_variable(value):
             if backend.is_tensor(value) or isinstance(
-                value, (np.ndarray, np.generic)
+                value, (np.ndarray, np.generic, jax.Array)
             ):
                 dtype = value.dtype
                 if is_float_dtype(dtype):
                     dtype = None  # Use the layer dtype policy
                 return self.add_weight(
                     value.shape,
-                    initializer=value,
+                    initializer=backend.convert_to_tensor(value) if value is not None else None,
                     dtype=dtype,
                     trainable=trainable,
                 )
@@ -328,8 +420,15 @@ def create_variable(value):
         else:
             self.state = variables
 
-        flat_variables, _ = jax.tree_util.tree_flatten(variables)
-        return flat_variables
+        if backend.backend() == "jax":
+            flat_variables, _ = jax.tree_util.tree_flatten(variables)
+            return flat_variables
+        elif backend.backend() == "tensorflow":
+            return variables
+
+    def _split_jax_rng(self):
+        self.jax_rng, subkey = jax.random.split(self.jax_rng)
+        return subkey
 
     def _get_init_rng(self):
         """
@@ -343,7 +442,7 @@ def _get_init_rng(self):
             a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
             the `rng` argument of `init_fn`.
         """
-        return self.seed_generator.next()
+        return self._split_jax_rng()
 
     def _get_call_rng(self, training):
         """
@@ -359,24 +458,22 @@ def _get_call_rng(self, training):
             the `rng` argument of `call_fn`.
         """
         if training:
-            return self.seed_generator.next()
+            return self._split_jax_rng()
         else:
             return None
 
-    def build(self, input_shape):
-        if self.params is not None or self.state is not None:
-            return
-
-        if jax_utils.is_in_jax_tracing_scope():
+    def _initialize_weights(self, input_shape):
+        if jax_utils.is_in_jax_tracing_scope() or tf.inside_function():
             # This exception is not actually shown, it is caught and a detailed
             # warning about calling 'build' is printed.
-            raise ValueError("'JaxLayer' cannot be built in tracing scope")
+            raise ValueError("'JaxLayer' cannot be built in a tracing scope or inside a tf.function")
 
+        logging.info("_initialize_weights input_shape: %s", input_shape)
         # Initialize `params` and `state` if needed by calling `init_fn`.
         def create_input(shape):
             shape = [d if d is not None else 1 for d in shape]
-            return jax.numpy.ones(shape)
-
+            return keras.ops.ones(shape)
+
         init_inputs = tree.map_shape_structure(create_input, input_shape)
         init_args = []
         for argument_name in self.init_fn_arguments:
@@ -398,6 +495,44 @@ def create_input(shape):
         )
         self.tracked_state = self._create_variables(init_state, trainable=False)
 
+
+    def build(self, input_shape):
+        if self.params is None and self.state is None:
+            self._initialize_weights(input_shape)
+
+        if backend.backend() == "tensorflow":
+            polymorphic_shapes = []
+            for argument in self.call_fn_arguments:
+                if argument == "inputs":
+                    polymorphic_shapes.append(
+                        self._get_jax2tf_input_shape(input_shape)
+                    )
+                elif argument != "training":
+                    # params, state, rng
+                    polymorphic_shapes.append("...")
+
+            if "training" in self.call_fn_arguments:
+                training_argument_index = self.call_fn_arguments.index("training")
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, False
+                    ),
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, True
+                    ),
+                    polymorphic_shapes,
+                )
+            else:
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self.call_fn,
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = None
+        super().build(input_shape)
+
     def call(self, inputs, training=False):
         def unwrap_variable(variable):
             return None if variable is None else variable.value
@@ -417,7 +552,8 @@ def unwrap_variable(variable):
             elif argument_name == "inputs":
                 call_args.append(inputs)
             elif argument_name == "training":
-                call_args.append(training)
+                if backend.backend() == "jax":
+                    call_args.append(training)
 
         def assign_state_to_variable(value, variable):
             # This exists only to make debugging this error case easier.
@@ -429,14 +565,50 @@ def assign_state_to_variable(value, variable):
                 )
             variable.assign(value)
 
-        if self.has_state:
-            predictions, new_state = self.call_fn(*call_args)
-            jax.tree_util.tree_map(
-                assign_state_to_variable, new_state, self.state
-            )
-            return predictions
+        def call_with_fn(fn):
+            if self.has_state:
+                predictions, new_state = fn(*call_args)
+                if backend.backend() == "jax":
+                    jax.tree_util.tree_map(
+                        assign_state_to_variable, new_state, self.state
+                    )
+                elif backend.backend() == "tensorflow":
+                    # tf.nest.map_structure(
+                    #     assign_state_to_variable, new_state, self.state
+                    # )
+                    new_state_leaves = jax.tree_util.tree_leaves(new_state)
+                    state_leaves = jax.tree_util.tree_leaves(self.state)
+                    if len(new_state_leaves) != len(state_leaves):
+                        # This indicates a more fundamental structure divergence.
+                        raise ValueError(
+                            "State leaf count mismatch between jax2tf output and layer state: "
+                            f"{len(new_state_leaves)} vs {len(state_leaves)}. "
+                            f"new_state structure: {jax.tree_util.tree_structure(new_state)}, "
+                            f"self.state structure: {jax.tree_util.tree_structure(self.state)}"
+                        )
+                    for new_val, state_leaf in zip(new_state_leaves, state_leaves):
+                        assign_state_to_variable(new_val, state_leaf)
+
+                return predictions
+            else:
+                return fn(*call_args)
+        if backend.backend() == "jax":
+            return call_with_fn(self.call_fn)
+        elif backend.backend() == "tensorflow":
+            if self.jax2tf_training_true_fn is None:
+                return call_with_fn(self.jax2tf_training_false_fn)
+            else:
+                if training:
+                    return call_with_fn(self.jax2tf_training_true_fn)
+                else:
+                    return call_with_fn(self.jax2tf_training_false_fn)
+
+    def compute_output_shape(self, input_shape):
+        if self.compute_output_shape_fn:
+            return self.compute_output_shape_fn(input_shape)
         else:
-            return self.call_fn(*call_args)
+            return super().compute_output_shape(input_shape)
+
 
     def get_config(self):
         config = {
@@ -549,19 +721,14 @@ def my_flax_module_wrapper(module, inputs, training):
     def __init__(
         self,
         module,
+        compute_output_shape_fn=None,
         method=None,
         variables=None,
         **kwargs,
     ):
         # Late import to only require Flax when this is used.
         from flax.core import scope as flax_scope
 
-        if backend.backend() != "jax":
-            raise ValueError(
-                "FlaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
-            )
-
         self.module = module
         self.method = method
 
@@ -618,6 +785,7 @@ def init_without_training(rng, inputs):
         super().__init__(
             call_fn=call_fn,
             init_fn=init_fn,
+            compute_output_shape_fn=compute_output_shape_fn,
             params=params,
             state=state,
             **kwargs,
@@ -650,13 +818,13 @@ def _variables_to_params_and_state(self, variables):
 
     def _get_init_rng(self):
         return {
-            "params": self.seed_generator.next(),
-            "dropout": self.seed_generator.next(),
+            "params": self._split_jax_rng(),
+            "dropout": self._split_jax_rng(),
         }
 
     def _get_call_rng(self, training):
         if training:
-            return {"dropout": self.seed_generator.next()}
+            return {"dropout": self._split_jax_rng()}
         else:
             return {}
 
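
For reference, a minimal usage sketch of the layer as extended by this diff, run with the TensorFlow backend. The model function, shapes, and the `compute_output_shape_fn` lambda below are illustrative assumptions, not part of the commit:

import numpy as np
import jax.numpy as jnp
from keras.layers import JaxLayer

# Hypothetical stateless JAX model: a single dense projection.
def my_call_fn(params, inputs):
    return jnp.matmul(inputs, params["w"])

layer = JaxLayer(
    call_fn=my_call_fn,
    # Lets Keras report the output shape without tracing call_fn.
    compute_output_shape_fn=lambda input_shape: (input_shape[0], 8),
    params={"w": np.zeros((4, 8), dtype="float32")},
)

# With KERAS_BACKEND="tensorflow", build() wraps my_call_fn with
# jax2tf.convert() using polymorphic shapes such as "(batch, 4)", and
# call() dispatches to the converted function.
outputs = layer(np.ones((2, 4), dtype="float32"))  # shape (2, 8)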