@@ -31,9 +31,9 @@
 
 
 def print_mean_std(name, v):
-    _, scale = jsa.lax.get_data_scale(v)
+    data, scale = jsa.lax.get_data_scale(v)
     # Always use np.float32, to avoid floating errors in descaling + stats.
-    v = jsa.asarray(v, dtype=np.float32)
+    v = jsa.asarray(data, dtype=np.float32)
     m, s = np.mean(v), np.std(v)
     print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")
 
@@ -45,19 +45,23 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
 def predict(params, inputs):
     activations = inputs
     for w, b in params[:-1]:
-        jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
-        jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
-        (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
+        # jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
+        # jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
+        # (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
 
         # Matmul + relu
         outputs = jnp.dot(activations, w) + b
         activations = jnp.maximum(outputs, 0)
+        # activations = jsa.ops.dynamic_rescale_l2_grad(activations)
 
     final_w, final_b = params[-1]
     logits = jnp.dot(activations, final_w) + final_b
 
-    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
-    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    # jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+
+    logits = jsa.ops.dynamic_rescale_l1_grad(logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
 
     return logits - logsumexp(logits, axis=1, keepdims=True)
 
@@ -81,7 +85,7 @@ def accuracy(params, batch):
     step_size = 0.001
     num_epochs = 10
     batch_size = 128
-    training_dtype = np.float32
+    training_dtype = np.float16
 
     train_images, train_labels, test_images, test_labels = datasets.mnist()
     num_train = train_images.shape[0]
@@ -106,15 +110,19 @@ def data_stream():
     @jsa.autoscale
     def update(params, batch):
         grads = grad(loss)(params, batch)
-        return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
-
-    num_batches = 4
-    num_epochs = 2
+        # return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
+        return [
+            (jsa.ops.dynamic_rescale_l1(w - step_size * dw), jsa.ops.dynamic_rescale_l1(b - step_size * db))
+            for (w, b), (dw, db) in zip(params, grads)
+        ]
+
+    # num_batches = 4
+    # num_epochs = 2
     for epoch in range(num_epochs):
-        print("EPOCH:", epoch)
+        # print("EPOCH:", epoch)
         start_time = time.time()
         for _ in range(num_batches):
-            print("BATCH...")
+            # print("BATCH...")
             batch = next(batches)
             # Scaled micro-batch + training dtype cast.
             batch = jsa.as_scaled_array(batch)
@@ -127,8 +135,8 @@ def update(params, batch):
 
         # Evaluation in float32, for consistency.
         raw_params = jsa.asarray(params, dtype=np.float32)
-        # train_acc = accuracy(raw_params, (train_images, train_labels))
-        # test_acc = accuracy(raw_params, (test_images, test_labels))
-        # print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
-        # print(f"Training set accuracy {train_acc:0.5f}")
-        # print(f"Test set accuracy {test_acc:0.5f}")
+        train_acc = accuracy(raw_params, (train_images, train_labels))
+        test_acc = accuracy(raw_params, (test_images, test_labels))
+        print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
+        print(f"Training set accuracy {train_acc:0.5f}")
+        print(f"Test set accuracy {test_acc:0.5f}")
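
For reviewers unfamiliar with the scaled-array ops touched in this diff, here is a minimal sketch of how they combine, mirroring the `predict`/`update` changes above. Assumptions: the `jsa` / `np` / `jnp` aliases come from the example's existing imports (not shown in this diff), and the shapes, dtypes and the placement of `dynamic_rescale_l1` are purely illustrative, not part of the change itself.

```python
# Sketch only, not part of the diff.

@jsa.autoscale
def scaled_relu_layer(x, w):
    # Inside an autoscale-traced function, each tensor carries a separate
    # scale factor next to its (here float16) data payload.
    y = jnp.maximum(jnp.dot(x, w), 0)
    # Re-centre the scale of the result so the low-precision data stays well
    # distributed; the L1 variant bases the new scale on the mean |value|.
    return jsa.ops.dynamic_rescale_l1(y)

x = jsa.as_scaled_array(np.random.rand(8, 16).astype(np.float16))
w = jsa.as_scaled_array(np.random.rand(16, 4).astype(np.float16))
out = scaled_relu_layer(x, w)
# Descale to float32 before computing statistics, as the evaluation code above does.
print(np.std(jsa.asarray(out, dtype=np.float32)))
```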