|
1 | 1 | import inspect |
2 | 2 |
|
| 3 | + |
| 4 | + |
| 5 | + |
| 6 | +import collections |
| 7 | +import functools |
| 8 | +import itertools |
| 9 | +import keras |
3 | 10 | import numpy as np |
| 11 | +import string |
| 12 | +import tensorflow as tf |
4 | 13 |
|
| 14 | +from jax.experimental import jax2tf |
5 | 15 | from keras.src import backend |
| 16 | +from keras.src import random |
6 | 17 | from keras.src import tree |
7 | 18 | from keras.src.api_export import keras_export |
8 | 19 | from keras.src.backend.common.variables import is_float_dtype |
|
12 | 23 | from keras.src.utils import jax_utils |
13 | 24 | from keras.src.utils import tracking |
14 | 25 | from keras.src.utils.module_utils import jax |
15 | | -import tensorflow as tf |
16 | | -from jax.experimental import jax2tf |
17 | | -import keras |
18 | | -import itertools |
19 | | -import string |
20 | | -import functools |
21 | | -from keras.src import random |
22 | | -import logging |
23 | | -import collections |
| 26 | + |
| 27 | + |
| 28 | + |
24 | 29 |
|
25 | 30 |
|
26 | 31 | def standardize_pytree_collections(pytree): |
27 | 32 | if isinstance(pytree, collections.abc.Mapping): |
28 | | - return {k: standardize_pytree_collections(v) for k, v in pytree.items()} |
| 33 | + return {k: standardize_pytree_collections(v) |
| 34 | + for k, v in pytree.items()} |
29 | 35 | elif isinstance(pytree, collections.abc.Sequence): |
30 | | - return [standardize_pytree_collections(v) for v in pytree] |
| 36 | + return [standardize_pytree_collections(v) |
| 37 | + for v in pytree] |
31 | 38 | else: |
32 | 39 | return pytree |
33 | 40 |
|
@@ -343,7 +350,6 @@ def get_single_jax2tf_shape(shape): |
343 | 350 | return "(" + ", ".join(jax2tf_shape) + ")" |
344 | 351 |
|
345 | 352 | res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape) |
346 | | - logging.info("_get_jax2tf_input_shape res:", res) |
347 | 353 | return res |
348 | 354 |
|
349 | 355 | def _jax2tf_convert(self, fn, polymorphic_shapes): |
@@ -475,9 +481,9 @@ def _initialize_weights(self, input_shape): |
475 | 481 | if jax_utils.is_in_jax_tracing_scope() or tf.inside_function(): |
476 | 482 | # This exception is not actually shown, it is caught and a detailed |
477 | 483 | # warning about calling 'build' is printed. |
478 | | - raise ValueError("'JaxLayer' cannot be built in tracing scope or inside tf function") |
| 484 | + raise ValueError("'JaxLayer' cannot be built in tracing scope" |
| 485 | + "or inside tf function") |
479 | 486 |
|
480 | | - logging.info("_initialize_weights input_shape:", input_shape) |
481 | 487 | # Initialize `params` and `state` if needed by calling `init_fn`. |
482 | 488 | def create_input(shape): |
483 | 489 | shape = [d if d is not None else 1 for d in shape] |
@@ -582,29 +588,10 @@ def call_with_fn(fn): |
582 | 588 | assign_state_to_variable, new_state, self.state |
583 | 589 | ) |
584 | 590 | elif backend.backend() == "tensorflow": |
585 | | - # self.state = standardize_pytree_collections(self.state) |
586 | | - print("\nself.state:", self.state) |
587 | | - print("new_state:", new_state) |
588 | | - print("self.state after: ", standardize_pytree_collections(self.state)) |
589 | | - print("pytree name", type(self.state).__name__) |
590 | | - print("pytree name", type(new_state).__name__) |
591 | | - jax.tree_util.tree_map(assign_state_to_variable, standardize_pytree_collections(new_state), standardize_pytree_collections(self.state)) |
592 | | - # jax.tree_util.tree_map( |
593 | | - # assign_state_to_variable, new_state, self.state |
594 | | - # ) |
595 | | - # new_state_leaves = jax.tree_util.tree_leaves(new_state) |
596 | | - # state_leaves = jax.tree_util.tree_leaves(self.state) |
597 | | - # if len(new_state_leaves) != len(state_leaves): |
598 | | - # # This indicates a more fundamental structure divergence. |
599 | | - # raise ValueError( |
600 | | - # "State leaf count mismatch between jax2tf output and layer state: " |
601 | | - # f"{len(new_state_leaves)} vs {len(state_leaves)}. " |
602 | | - # f"new_state structure: {jax.tree_util.tree_structure(new_state)}, " |
603 | | - # f"self.state structure: {jax.tree_util.tree_structure(self.state)}" |
604 | | - # ) |
605 | | - # for new_val, state_leaf in zip(new_state_leaves, state_leaves): |
606 | | - # assign_state_to_variable(new_val, state_leaf) |
607 | | - |
| 591 | + jax.tree_util.tree_map( |
| 592 | + assign_state_to_variable, |
| 593 | + standardize_pytree_collections(new_state), |
| 594 | + standardize_pytree_collections(self.state)) |
608 | 595 | return predictions |
609 | 596 | else: |
610 | 597 | return fn(*call_args) |
|
0 commit comments