1616
1717The primary aim here is simplicity and minimal dependencies.
1818"""
19-
20-
2119import time
20+ from functools import partial
2221
2322import datasets
2423import jax
3130import jax_scaled_arithmetics as jsa
3231
3332
def print_mean_std(name, v):
    """Print ``name`` followed by the mean and standard deviation of ``v``.

    ``v`` is first converted with ``jsa.asarray`` to ``np.float32`` so the
    statistics are always computed in float32, avoiding floating-point
    errors in descaling and in the statistics themselves.
    """
    values = jsa.asarray(v, dtype=np.float32)
    mean = np.mean(values)
    std = np.std(values)
    print(name, mean, std)
38+
39+
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
    """Build randomly initialised ``(weights, biases)`` pairs for each layer.

    One pair is produced per consecutive pair ``(m, n)`` of entries in
    ``layer_sizes``: weights are ``scale * rng.randn(m, n)`` and biases are
    ``scale * rng.randn(n)``. The result is a list of those tuples, in layer
    order.

    Note: the default ``rng`` is created once, when the function is defined,
    so calls that rely on the default share (and advance) one random stream.
    """
    params = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        # Draw weights before biases so the random stream is consumed in the
        # same order for every layer.
        weights = scale * rng.randn(fan_in, fan_out)
        biases = scale * rng.randn(fan_out)
        params.append((weights, biases))
    return params
3642
3743
3844def predict (params , inputs ):
3945 activations = inputs
4046 for w , b in params [:- 1 ]:
47+ jsa .ops .debug_callback (partial (print_mean_std , "W:" ), w )
48+
4149 # Matmul + relu
4250 outputs = jnp .dot (activations , w ) + b
4351 activations = jnp .maximum (outputs , 0 )
@@ -66,7 +74,7 @@ def accuracy(params, batch):
6674 step_size = 0.001
6775 num_epochs = 10
6876 batch_size = 128
69- training_dtype = np .float16
77+ training_dtype = np .float32
7078
7179 train_images , train_labels , test_images , test_labels = datasets .mnist ()
7280 num_train = train_images .shape [0 ]
@@ -93,9 +101,13 @@ def update(params, batch):
93101 grads = grad (loss )(params , batch )
94102 return [(w - step_size * dw , b - step_size * db ) for (w , b ), (dw , db ) in zip (params , grads )]
95103
104+ num_batches = 2
105+ num_epochs = 2
96106 for epoch in range (num_epochs ):
107+ print ("EPOCH:" , epoch )
97108 start_time = time .time ()
98109 for _ in range (num_batches ):
110+ print ("BATCH..." )
99111 batch = next (batches )
100112 # Scaled micro-batch + training dtype cast.
101113 batch = jsa .as_scaled_array (batch )
@@ -108,8 +120,8 @@ def update(params, batch):
108120
109121 # Evaluation in float32, for consistency.
110122 raw_params = jsa .asarray (params , dtype = np .float32 )
111- train_acc = accuracy (raw_params , (train_images , train_labels ))
112- test_acc = accuracy (raw_params , (test_images , test_labels ))
113- print (f"Epoch { epoch } in { epoch_time :0.2f} sec" )
114- print (f"Training set accuracy { train_acc :0.5f} " )
115- print (f"Test set accuracy { test_acc :0.5f} " )
123+ # train_acc = accuracy(raw_params, (train_images, train_labels))
124+ # test_acc = accuracy(raw_params, (test_images, test_labels))
125+ # print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
126+ # print(f"Training set accuracy {train_acc:0.5f}")
127+ # print(f"Test set accuracy {test_acc:0.5f}")
0 commit comments