Skip to content

Commit 54a9ff2

Browse files
committed
wip
1 parent 8949497 commit 54a9ff2

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

experiments/mnist/mnist_classifier_from_scratch.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
1717
The primary aim here is simplicity and minimal dependencies.
1818
"""
19-
20-
2119
import time
20+
from functools import partial
2221

2322
import datasets
2423
import jax
@@ -31,13 +30,22 @@
3130
import jax_scaled_arithmetics as jsa
3231

3332

33+
def print_mean_std(name, v):
34+
# Always use np.float32, to avoid floating errors in descaling + stats.
35+
v = jsa.asarray(v, dtype=np.float32)
36+
m, s = np.mean(v), np.std(v)
37+
print(name, m, s)
38+
39+
3440
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
3541
return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
3642

3743

3844
def predict(params, inputs):
3945
activations = inputs
4046
for w, b in params[:-1]:
47+
jsa.ops.debug_callback(partial(print_mean_std, "W:"), w)
48+
4149
# Matmul + relu
4250
outputs = jnp.dot(activations, w) + b
4351
activations = jnp.maximum(outputs, 0)
@@ -66,7 +74,7 @@ def accuracy(params, batch):
6674
step_size = 0.001
6775
num_epochs = 10
6876
batch_size = 128
69-
training_dtype = np.float16
77+
training_dtype = np.float32
7078

7179
train_images, train_labels, test_images, test_labels = datasets.mnist()
7280
num_train = train_images.shape[0]
@@ -93,9 +101,13 @@ def update(params, batch):
93101
grads = grad(loss)(params, batch)
94102
return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
95103

104+
num_batches = 2
105+
num_epochs = 2
96106
for epoch in range(num_epochs):
107+
print("EPOCH:", epoch)
97108
start_time = time.time()
98109
for _ in range(num_batches):
110+
print("BATCH...")
99111
batch = next(batches)
100112
# Scaled micro-batch + training dtype cast.
101113
batch = jsa.as_scaled_array(batch)
@@ -108,8 +120,8 @@ def update(params, batch):
108120

109121
# Evaluation in float32, for consistency.
110122
raw_params = jsa.asarray(params, dtype=np.float32)
111-
train_acc = accuracy(raw_params, (train_images, train_labels))
112-
test_acc = accuracy(raw_params, (test_images, test_labels))
113-
print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
114-
print(f"Training set accuracy {train_acc:0.5f}")
115-
print(f"Test set accuracy {test_acc:0.5f}")
123+
# train_acc = accuracy(raw_params, (train_images, train_labels))
124+
# test_acc = accuracy(raw_params, (test_images, test_labels))
125+
# print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
126+
# print(f"Training set accuracy {train_acc:0.5f}")
127+
# print(f"Test set accuracy {test_acc:0.5f}")

0 commit comments

Comments
 (0)