@@ -31,10 +31,11 @@


 def print_mean_std(name, v):
+    _, scale = jsa.lax.get_data_scale(v)
     # Always use np.float32, to avoid floating errors in descaling + stats.
     v = jsa.asarray(v, dtype=np.float32)
     m, s = np.mean(v), np.std(v)
-    print(name, m, s)
+    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")


 def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
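The first hunk threads the tensor's scale into the logging callback: jsa.lax.get_data_scale splits a scaled tensor into its raw data and its scale, and the new f-string reports mean, std, and scale together. Assembled, the callback after this change reads as follows (a sketch; the module behind the jsa alias is an assumption, as is get_data_scale's behaviour on plain unscaled arrays):

    import numpy as np
    import jax_scalify as jsa  # assumption: the example imports the library under the alias jsa

    def print_mean_std(name, v):
        # Split the scaled tensor into (data, scale); only the scale is used here.
        _, scale = jsa.lax.get_data_scale(v)
        # Always use np.float32, to avoid floating errors in descaling + stats.
        v = jsa.asarray(v, dtype=np.float32)
        m, s = np.mean(v), np.std(v)
        print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")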
@@ -44,8 +45,8 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
 def predict(params, inputs):
     activations = inputs
     for w, b in params[:-1]:
-        jsa.ops.debug_callback(partial(print_mean_std, "W: "), w)
-        (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WG: "), w)
+        jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
+        (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)

     # Matmul + relu
     outputs = jnp.dot(activations, w) + b
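In the second hunk the labels are shortened and the tap pair becomes easier to read: debug_callback reports the weights on the forward pass, while debug_callback_grad is assigned back into w, i.e. it behaves as an identity whose callback fires on the gradient during the backward pass. A minimal pure-JAX sketch of that gradient-tap pattern, assuming exactly that identity-with-callback semantics (an illustration of the pattern, not the jsa implementation):

    from functools import partial

    import jax
    import jax.numpy as jnp

    @partial(jax.custom_vjp, nondiff_argnums=(0,))
    def grad_tap(fn, x):
        # Identity on the forward pass.
        return x

    def _grad_tap_fwd(fn, x):
        return x, None

    def _grad_tap_bwd(fn, _res, g):
        # Run the Python callback on the incoming gradient, then pass it through unchanged.
        jax.debug.callback(fn, g)
        return (g,)

    grad_tap.defvjp(_grad_tap_fwd, _grad_tap_bwd)

    # Usage: gradient statistics are printed, values are untouched.
    def loss(w):
        w = grad_tap(lambda g: print("WGrad mean:", g.mean()), w)
        return jnp.sum(w * w)

    print(jax.grad(loss)(jnp.arange(3.0)))  # prints the tap, then [0. 2. 4.]

The reassignment ((w,) = ... in the diff) is what splices the tap into the autodiff graph; without consuming the returned value, the backward callback would never fire.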
@@ -54,8 +55,8 @@ def predict(params, inputs):
     final_w, final_b = params[-1]
     logits = jnp.dot(activations, final_w) + final_b

-    jsa.ops.debug_callback(partial(print_mean_std, "LOGITS: "), logits)
-    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LOGITSG: "), logits)
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
+    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

     return logits - logsumexp(logits, axis=1, keepdims=True)

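The unchanged return line is a numerically stable log-softmax: subtracting the per-row logsumexp avoids exponentiating large logits. A quick sketch of why that form is preferred (the values are illustrative):

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    x = jnp.array([[1000.0, 1001.0]])
    # Naive log-softmax overflows in float32: exp(1000) is inf, so the result is nan.
    naive = jnp.log(jnp.exp(x) / jnp.sum(jnp.exp(x), axis=1, keepdims=True))
    # The subtraction form stays finite: [[-1.3133, -0.3133]].
    stable = x - logsumexp(x, axis=1, keepdims=True)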