@@ -35,7 +35,7 @@ def print_mean_std(name, v):
     # Always use np.float32, to avoid floating errors in descaling + stats.
     v = jsa.asarray(data, dtype=np.float32)
     m, s = np.mean(v), np.std(v)
-    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")
+    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale.dtype})")


 def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
@@ -45,40 +45,43 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
 def predict(params, inputs):
     activations = inputs
     for w, b in params[:-1]:
-        # jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
-        # jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
-        # (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
+        jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
+        jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
+        (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)

         # Matmul + relu
         outputs = jnp.dot(activations, w) + b
         activations = jnp.maximum(outputs, 0)
+        jsa.ops.debug_callback(partial(print_mean_std, "Act"), activations)
         # activations = jsa.ops.dynamic_rescale_l2_grad(activations)

     final_w, final_b = params[-1]
-    logits = jnp.dot(activations, final_w) + final_b
+    logits = jnp.dot(activations, final_w)
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits0"), logits)
+    logits = logits + final_b

-    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
-    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits1"), logits)
+    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

     logits = jsa.ops.dynamic_rescale_l2_grad(logits)
     # logits = logits.astype(np.float32)
-    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

     logits = logits - logsumexp(logits, axis=1, keepdims=True)
     jsa.ops.debug_callback(partial(print_mean_std, "Logits2"), logits)
-    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
     return logits


 def loss(params, batch):
     inputs, targets = batch
     preds = predict(params, inputs)
-    jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds)
+    # jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds)
     loss = jnp.sum(preds * targets, axis=1)
     # loss = jsa.ops.dynamic_rescale_l2(loss)
-    jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss)
+    # jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss)
     loss = -jnp.mean(loss)
-    jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss)
+    # jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss)
     return loss
     return -jnp.mean(jnp.sum(preds * targets, axis=1))

@@ -92,7 +95,7 @@ def accuracy(params, batch):

 if __name__ == "__main__":
     layer_sizes = [784, 1024, 1024, 10]
-    param_scale = 1.0
+    param_scale = 2.0
     step_size = 0.001
     num_epochs = 10
     batch_size = 128
@@ -114,7 +117,7 @@ def data_stream():
     batches = data_stream()
     params = init_random_params(param_scale, layer_sizes)
     # Transform parameters to `ScaledArray` and proper dtype.
-    params = jsa.as_scaled_array(params)
+    params = jsa.as_scaled_array(params, scale=np.float32(2))
     params = jax.tree_map(lambda v: v.astype(training_dtype), params)

     @jit
@@ -127,8 +130,8 @@ def update(params, batch):
             for (w, b), (dw, db) in zip(params, grads)
         ]

-    # num_batches = 4
-    # num_epochs = 2
+    num_batches = 1
+    num_epochs = 1
     for epoch in range(num_epochs):
         # print("EPOCH:", epoch)
         start_time = time.time()
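
A note on what the `print_mean_std` callbacks in this diff report: in `jax_scaled_arithmetics`, a tensor is carried as a `(data, scale)` pair, and the statistics are taken on the `data` component after an explicit float32 cast. Below is a minimal self-contained sketch of that representation; `SimpleScaledArray` is a hypothetical stand-in for illustration, not the library's actual `ScaledArray` type.

import numpy as np

class SimpleScaledArray:
    """Hypothetical stand-in: the tensor's value is data * scale, with the
    scale factored out so `data` stays well distributed in low precision."""
    def __init__(self, data, scale):
        self.data = data    # e.g. a float16 payload
        self.scale = scale  # e.g. a float32 scalar

def print_mean_std(name, v):
    # Cast to float32 before computing stats, mirroring the diff above,
    # to avoid floating-point error in the statistics themselves.
    data = np.asarray(v.data, dtype=np.float32)
    print(f"{name}: MEAN({np.mean(data):.4f}) / STD({np.std(data):.4f}) / SCALE({v.scale:.4f})")

w = SimpleScaledArray(np.random.randn(512, 512).astype(np.float16), np.float32(2.0))
print_mean_std("W", w)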
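The `jsa.ops.debug_callback_grad` probes this commit enables act as identity ops on the forward pass and report the gradient flowing through them on the backward pass. The library's internals aren't shown here, but the general mechanism can be sketched in plain JAX with `jax.custom_vjp` and `jax.debug.callback`; the `grad_probe` name and signature are made up for illustration.

from functools import partial
import jax
import jax.numpy as jnp
import numpy as np

def print_mean_std(name, v):
    v = np.asarray(v, dtype=np.float32)
    print(f"{name}: MEAN({np.mean(v):.4f}) / STD({np.std(v):.4f})")

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def grad_probe(name, x):
    return x  # identity in the forward pass

def _grad_probe_fwd(name, x):
    return x, None

def _grad_probe_bwd(name, _, g):
    # Report the incoming cotangent's statistics, then pass it through unchanged.
    jax.debug.callback(partial(print_mean_std, name), g)
    return (g,)

grad_probe.defvjp(_grad_probe_fwd, _grad_probe_bwd)

x = jnp.ones((4,))
print(jax.grad(lambda x: jnp.sum(grad_probe("XGrad", x) ** 2))(x))

Dropping such a probe before and after an op, as the diff does around `jsa.ops.dynamic_rescale_l2_grad(logits)`, makes it possible to see whether that op changes the gradient's distribution in the backward pass.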