Skip to content

Commit 4d484d9

Browse files
committed
format
1 parent 3d977fb commit 4d484d9

File tree

1 file changed

+24
-37
lines changed

1 file changed

+24
-37
lines changed

keras/src/utils/jax_layer.py

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
11
import inspect
22

3+
4+
5+
6+
import collections
7+
import functools
8+
import itertools
9+
import keras
310
import numpy as np
11+
import string
12+
import tensorflow as tf
413

14+
from jax.experimental import jax2tf
515
from keras.src import backend
16+
from keras.src import random
617
from keras.src import tree
718
from keras.src.api_export import keras_export
819
from keras.src.backend.common.variables import is_float_dtype
@@ -12,22 +23,18 @@
1223
from keras.src.utils import jax_utils
1324
from keras.src.utils import tracking
1425
from keras.src.utils.module_utils import jax
15-
import tensorflow as tf
16-
from jax.experimental import jax2tf
17-
import keras
18-
import itertools
19-
import string
20-
import functools
21-
from keras.src import random
22-
import logging
23-
import collections
26+
27+
28+
2429

2530

2631
def standardize_pytree_collections(pytree):
2732
if isinstance(pytree, collections.abc.Mapping):
28-
return {k: standardize_pytree_collections(v) for k, v in pytree.items()}
33+
return {k: standardize_pytree_collections(v)
34+
for k, v in pytree.items()}
2935
elif isinstance(pytree, collections.abc.Sequence):
30-
return [standardize_pytree_collections(v) for v in pytree]
36+
return [standardize_pytree_collections(v)
37+
for v in pytree]
3138
else:
3239
return pytree
3340

@@ -343,7 +350,6 @@ def get_single_jax2tf_shape(shape):
343350
return "(" + ", ".join(jax2tf_shape) + ")"
344351

345352
res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
346-
logging.info("_get_jax2tf_input_shape res:", res)
347353
return res
348354

349355
def _jax2tf_convert(self, fn, polymorphic_shapes):
@@ -475,9 +481,9 @@ def _initialize_weights(self, input_shape):
475481
if jax_utils.is_in_jax_tracing_scope() or tf.inside_function():
476482
# This exception is not actually shown, it is caught and a detailed
477483
# warning about calling 'build' is printed.
478-
raise ValueError("'JaxLayer' cannot be built in tracing scope or inside tf function")
484+
raise ValueError("'JaxLayer' cannot be built in tracing scope"
485+
"or inside tf function")
479486

480-
logging.info("_initialize_weights input_shape:", input_shape)
481487
# Initialize `params` and `state` if needed by calling `init_fn`.
482488
def create_input(shape):
483489
shape = [d if d is not None else 1 for d in shape]
@@ -582,29 +588,10 @@ def call_with_fn(fn):
582588
assign_state_to_variable, new_state, self.state
583589
)
584590
elif backend.backend() == "tensorflow":
585-
# self.state = standardize_pytree_collections(self.state)
586-
print("\nself.state:", self.state)
587-
print("new_state:", new_state)
588-
print("self.state after: ", standardize_pytree_collections(self.state))
589-
print("pytree name", type(self.state).__name__)
590-
print("pytree name", type(new_state).__name__)
591-
jax.tree_util.tree_map(assign_state_to_variable, standardize_pytree_collections(new_state), standardize_pytree_collections(self.state))
592-
# jax.tree_util.tree_map(
593-
# assign_state_to_variable, new_state, self.state
594-
# )
595-
# new_state_leaves = jax.tree_util.tree_leaves(new_state)
596-
# state_leaves = jax.tree_util.tree_leaves(self.state)
597-
# if len(new_state_leaves) != len(state_leaves):
598-
# # This indicates a more fundamental structure divergence.
599-
# raise ValueError(
600-
# "State leaf count mismatch between jax2tf output and layer state: "
601-
# f"{len(new_state_leaves)} vs {len(state_leaves)}. "
602-
# f"new_state structure: {jax.tree_util.tree_structure(new_state)}, "
603-
# f"self.state structure: {jax.tree_util.tree_structure(self.state)}"
604-
# )
605-
# for new_val, state_leaf in zip(new_state_leaves, state_leaves):
606-
# assign_state_to_variable(new_val, state_leaf)
607-
591+
jax.tree_util.tree_map(
592+
assign_state_to_variable,
593+
standardize_pytree_collections(new_state),
594+
standardize_pytree_collections(self.state))
608595
return predictions
609596
else:
610597
return fn(*call_args)

0 commit comments

Comments
 (0)