File tree Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change 66import numpy as np
77import string
88
9- import jax
109from jax .experimental import jax2tf
11- import keras
1210from keras .src import backend
1311from keras .src import random
1412from keras .src import tree
1917from keras .src .saving import serialization_lib
2018from keras .src .utils import jax_utils
2119from keras .src .utils import tracking
20+ from keras .src import ops
2221from keras .src .utils .module_utils import jax
23- import tensorflow as tf
22+ from keras . src . utils . module_utils import tensorflow as tf
2423
2524
2625
@@ -484,7 +483,7 @@ def _initialize_weights(self, input_shape):
484483 # Initialize `params` and `state` if needed by calling `init_fn`.
485484 def create_input (shape ):
486485 shape = [d if d is not None else 1 for d in shape ]
487- return keras . ops .ones (shape )
486+ return ops .ones (shape )
488487
489488 init_inputs = tree .map_shape_structure (create_input , input_shape )
490489 init_args = []
You can’t perform that action at this time.
0 commit comments