@@ -238,11 +238,12 @@ def verify_weights_and_params(layer):
         )
         model1.summary()
 
+        verify_weights_and_params(layer1)
+
         model1.compile(
             loss="categorical_crossentropy",
             optimizer="adam",
             metrics=[metrics.CategoricalAccuracy()],
-            run_eagerly=True,
         )
 
         tw1_before_fit = tree.map_structure(
@@ -259,8 +260,6 @@ def verify_weights_and_params(layer):
             backend.convert_to_numpy, layer1.non_trainable_weights
         )
 
-        verify_weights_and_params(layer1)
-
         # verify both trainable and non-trainable weights did change after fit
         for before, after in zip(tw1_before_fit, tw1_after_fit):
             self.assertNotAllClose(before, after)
@@ -357,59 +356,59 @@ def call(self, inputs):
         output5 = model5(x_test)
         self.assertNotAllClose(output5, 0.0)
 
-    # @parameterized.named_parameters(
-    #     {
-    #         "testcase_name": "training_independent",
-    #         "init_kwargs": {
-    #             "call_fn": jax_stateless_apply,
-    #             "init_fn": jax_stateless_init,
-    #         },
-    #         "trainable_weights": 6,
-    #         "trainable_params": 266610,
-    #         "non_trainable_weights": 0,
-    #         "non_trainable_params": 0,
-    #     },
-    #     {
-    #         "testcase_name": "training_state",
-    #         "init_kwargs": {
-    #             "call_fn": jax_stateful_apply,
-    #             "init_fn": jax_stateful_init,
-    #         },
-    #         "trainable_weights": 6,
-    #         "trainable_params": 266610,
-    #         "non_trainable_weights": 1,
-    #         "non_trainable_params": 1,
-    #     },
-    #     {
-    #         "testcase_name": "training_state_dtype_policy",
-    #         "init_kwargs": {
-    #             "call_fn": jax_stateful_apply,
-    #             "init_fn": jax_stateful_init,
-    #             "dtype": DTypePolicy("mixed_float16"),
-    #         },
-    #         "trainable_weights": 6,
-    #         "trainable_params": 266610,
-    #         "non_trainable_weights": 1,
-    #         "non_trainable_params": 1,
-    #     },
-    # )
-    # def test_jax_layer(
-    #     self,
-    #     init_kwargs,
-    #     trainable_weights,
-    #     trainable_params,
-    #     non_trainable_weights,
-    #     non_trainable_params,
-    # ):
-    #     self._test_layer(
-    #         init_kwargs["call_fn"].__name__,
-    #         JaxLayer,
-    #         init_kwargs,
-    #         trainable_weights,
-    #         trainable_params,
-    #         non_trainable_weights,
-    #         non_trainable_params,
-    #     )
+    @parameterized.named_parameters(
+        {
+            "testcase_name": "training_independent",
+            "init_kwargs": {
+                "call_fn": jax_stateless_apply,
+                "init_fn": jax_stateless_init,
+            },
+            "trainable_weights": 6,
+            "trainable_params": 266610,
+            "non_trainable_weights": 0,
+            "non_trainable_params": 0,
+        },
+        {
+            "testcase_name": "training_state",
+            "init_kwargs": {
+                "call_fn": jax_stateful_apply,
+                "init_fn": jax_stateful_init,
+            },
+            "trainable_weights": 6,
+            "trainable_params": 266610,
+            "non_trainable_weights": 1,
+            "non_trainable_params": 1,
+        },
+        {
+            "testcase_name": "training_state_dtype_policy",
+            "init_kwargs": {
+                "call_fn": jax_stateful_apply,
+                "init_fn": jax_stateful_init,
+                "dtype": DTypePolicy("mixed_float16"),
+            },
+            "trainable_weights": 6,
+            "trainable_params": 266610,
+            "non_trainable_weights": 1,
+            "non_trainable_params": 1,
+        },
+    )
+    def test_jax_layer(
+        self,
+        init_kwargs,
+        trainable_weights,
+        trainable_params,
+        non_trainable_weights,
+        non_trainable_params,
+    ):
+        self._test_layer(
+            init_kwargs["call_fn"].__name__,
+            JaxLayer,
+            init_kwargs,
+            trainable_weights,
+            trainable_params,
+            non_trainable_weights,
+            non_trainable_params,
+        )
 
     @parameterized.named_parameters(
         {