Skip to content

Commit 57beea1

Browse files
committed
import fix
1 parent fa00d7a commit 57beea1

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

keras/src/utils/jax_layer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
import numpy as np
77
import string
88

9-
import jax
109
from jax.experimental import jax2tf
11-
import keras
1210
from keras.src import backend
1311
from keras.src import random
1412
from keras.src import tree
@@ -19,8 +17,9 @@
1917
from keras.src.saving import serialization_lib
2018
from keras.src.utils import jax_utils
2119
from keras.src.utils import tracking
20+
from keras.src import ops
2221
from keras.src.utils.module_utils import jax
23-
import tensorflow as tf
22+
from keras.src.utils.module_utils import tensorflow as tf
2423

2524

2625

@@ -484,7 +483,7 @@ def _initialize_weights(self, input_shape):
484483
# Initialize `params` and `state` if needed by calling `init_fn`.
485484
def create_input(shape):
486485
shape = [d if d is not None else 1 for d in shape]
487-
return keras.ops.ones(shape)
486+
return ops.ones(shape)
488487

489488
init_inputs = tree.map_shape_structure(create_input, input_shape)
490489
init_args = []

0 commit comments

Comments
 (0)