Skip to content

Commit 1dfe1d3

Browse files
committed
fix test
1 parent 4d484d9 commit 1dfe1d3

File tree

1 file changed

+55
-56
lines changed

1 file changed

+55
-56
lines changed

keras/src/utils/jax_layer_test.py

Lines changed: 55 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,12 @@ def verify_weights_and_params(layer):
238238
)
239239
model1.summary()
240240

241+
verify_weights_and_params(layer1)
242+
241243
model1.compile(
242244
loss="categorical_crossentropy",
243245
optimizer="adam",
244246
metrics=[metrics.CategoricalAccuracy()],
245-
run_eagerly=True,
246247
)
247248

248249
tw1_before_fit = tree.map_structure(
@@ -259,8 +260,6 @@ def verify_weights_and_params(layer):
259260
backend.convert_to_numpy, layer1.non_trainable_weights
260261
)
261262

262-
verify_weights_and_params(layer1)
263-
264263
# verify both trainable and non-trainable weights did change after fit
265264
for before, after in zip(tw1_before_fit, tw1_after_fit):
266265
self.assertNotAllClose(before, after)
@@ -357,59 +356,59 @@ def call(self, inputs):
357356
output5 = model5(x_test)
358357
self.assertNotAllClose(output5, 0.0)
359358

360-
# @parameterized.named_parameters(
361-
# {
362-
# "testcase_name": "training_independent",
363-
# "init_kwargs": {
364-
# "call_fn": jax_stateless_apply,
365-
# "init_fn": jax_stateless_init,
366-
# },
367-
# "trainable_weights": 6,
368-
# "trainable_params": 266610,
369-
# "non_trainable_weights": 0,
370-
# "non_trainable_params": 0,
371-
# },
372-
# {
373-
# "testcase_name": "training_state",
374-
# "init_kwargs": {
375-
# "call_fn": jax_stateful_apply,
376-
# "init_fn": jax_stateful_init,
377-
# },
378-
# "trainable_weights": 6,
379-
# "trainable_params": 266610,
380-
# "non_trainable_weights": 1,
381-
# "non_trainable_params": 1,
382-
# },
383-
# {
384-
# "testcase_name": "training_state_dtype_policy",
385-
# "init_kwargs": {
386-
# "call_fn": jax_stateful_apply,
387-
# "init_fn": jax_stateful_init,
388-
# "dtype": DTypePolicy("mixed_float16"),
389-
# },
390-
# "trainable_weights": 6,
391-
# "trainable_params": 266610,
392-
# "non_trainable_weights": 1,
393-
# "non_trainable_params": 1,
394-
# },
395-
# )
396-
# def test_jax_layer(
397-
# self,
398-
# init_kwargs,
399-
# trainable_weights,
400-
# trainable_params,
401-
# non_trainable_weights,
402-
# non_trainable_params,
403-
# ):
404-
# self._test_layer(
405-
# init_kwargs["call_fn"].__name__,
406-
# JaxLayer,
407-
# init_kwargs,
408-
# trainable_weights,
409-
# trainable_params,
410-
# non_trainable_weights,
411-
# non_trainable_params,
412-
# )
359+
@parameterized.named_parameters(
360+
{
361+
"testcase_name": "training_independent",
362+
"init_kwargs": {
363+
"call_fn": jax_stateless_apply,
364+
"init_fn": jax_stateless_init,
365+
},
366+
"trainable_weights": 6,
367+
"trainable_params": 266610,
368+
"non_trainable_weights": 0,
369+
"non_trainable_params": 0,
370+
},
371+
{
372+
"testcase_name": "training_state",
373+
"init_kwargs": {
374+
"call_fn": jax_stateful_apply,
375+
"init_fn": jax_stateful_init,
376+
},
377+
"trainable_weights": 6,
378+
"trainable_params": 266610,
379+
"non_trainable_weights": 1,
380+
"non_trainable_params": 1,
381+
},
382+
{
383+
"testcase_name": "training_state_dtype_policy",
384+
"init_kwargs": {
385+
"call_fn": jax_stateful_apply,
386+
"init_fn": jax_stateful_init,
387+
"dtype": DTypePolicy("mixed_float16"),
388+
},
389+
"trainable_weights": 6,
390+
"trainable_params": 266610,
391+
"non_trainable_weights": 1,
392+
"non_trainable_params": 1,
393+
},
394+
)
395+
def test_jax_layer(
396+
self,
397+
init_kwargs,
398+
trainable_weights,
399+
trainable_params,
400+
non_trainable_weights,
401+
non_trainable_params,
402+
):
403+
self._test_layer(
404+
init_kwargs["call_fn"].__name__,
405+
JaxLayer,
406+
init_kwargs,
407+
trainable_weights,
408+
trainable_params,
409+
non_trainable_weights,
410+
non_trainable_params,
411+
)
413412

414413
@parameterized.named_parameters(
415414
{

0 commit comments

Comments
 (0)