
Commit 16bb84d

wip
1 parent 730ba81 commit 16bb84d

File tree

1 file changed: +19 -16 lines changed


experiments/mnist/mnist_classifier_from_scratch.py

Lines changed: 19 additions & 16 deletions
@@ -35,7 +35,7 @@ def print_mean_std(name, v):
     # Always use np.float32, to avoid floating errors in descaling + stats.
     v = jsa.asarray(data, dtype=np.float32)
     m, s = np.mean(v), np.std(v)
-    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")
+    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale.dtype})")


 def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
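
For context, the full helper after this hunk should read roughly as below. The unpacking line sits above the hunk and is reconstructed here as an assumption (the `jsa.lax.get_data_scale` call is a guess at the project's API); the edit itself switches the log from the scale's value to its dtype:

def print_mean_std(name, v):
    # Assumed unpacking of a ScaledArray into data and scale components.
    data, scale = jsa.lax.get_data_scale(v)
    # Always use np.float32, to avoid floating errors in descaling + stats.
    v = jsa.asarray(data, dtype=np.float32)
    m, s = np.mean(v), np.std(v)
    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale.dtype})")
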
@@ -45,40 +45,43 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
 def predict(params, inputs):
     activations = inputs
     for w, b in params[:-1]:
-        # jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
-        # jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
-        # (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
+        jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
+        jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
+        (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)

         # Matmul + relu
         outputs = jnp.dot(activations, w) + b
         activations = jnp.maximum(outputs, 0)
+        jsa.ops.debug_callback(partial(print_mean_std, "Act"), activations)
         # activations = jsa.ops.dynamic_rescale_l2_grad(activations)

     final_w, final_b = params[-1]
-    logits = jnp.dot(activations, final_w) + final_b
+    logits = jnp.dot(activations, final_w)
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits0"), logits)
+    logits = logits + final_b

-    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
-    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits1"), logits)
+    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

     logits = jsa.ops.dynamic_rescale_l2_grad(logits)
     # logits = logits.astype(np.float32)
-    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

     logits = logits - logsumexp(logits, axis=1, keepdims=True)
     jsa.ops.debug_callback(partial(print_mean_std, "Logits2"), logits)
-    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
     return logits


 def loss(params, batch):
     inputs, targets = batch
     preds = predict(params, inputs)
-    jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds)
+    # jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds)
     loss = jnp.sum(preds * targets, axis=1)
     # loss = jsa.ops.dynamic_rescale_l2(loss)
-    jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss)
+    # jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss)
     loss = -jnp.mean(loss)
-    jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss)
+    # jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss)
     return loss
     return -jnp.mean(jnp.sum(preds * targets, axis=1))

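A minimal sketch (not part of the commit) of the instrumentation pattern this hunk toggles, inferred from its use above: `jsa.ops.debug_callback` fires on a value during the forward pass, while `jsa.ops.debug_callback_grad` returns the value unchanged and fires on the gradient flowing back through it. The toy function and shapes below are hypothetical, and running under the library's scaled-array transform is assumed:

from functools import partial

import jax
import jax.numpy as jnp

def toy_loss(w, x):
    y = jnp.dot(x, w)
    # Forward hook: logs mean/std/scale of y each time toy_loss runs.
    jsa.ops.debug_callback(partial(print_mean_std, "Y"), y)
    # Backward hook: y passes through unchanged; the callback sees dL/dy.
    (y,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "YGrad"), y)
    return jnp.mean(y**2)

w, x = jnp.ones((4, 4)), jnp.ones((2, 4))
dw = jax.grad(toy_loss)(w, x)  # evaluating the grad triggers both hooks
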
@@ -92,7 +95,7 @@ def accuracy(params, batch):

 if __name__ == "__main__":
     layer_sizes = [784, 1024, 1024, 10]
-    param_scale = 1.0
+    param_scale = 2.0
     step_size = 0.001
     num_epochs = 10
     batch_size = 128
@@ -114,7 +117,7 @@ def data_stream():
     batches = data_stream()
     params = init_random_params(param_scale, layer_sizes)
     # Transform parameters to `ScaledArray` and proper dtype.
-    params = jsa.as_scaled_array(params)
+    params = jsa.as_scaled_array(params, scale=np.float32(2))
     params = jax.tree_map(lambda v: v.astype(training_dtype), params)

     @jit
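
The change above pins the parameters' initial scale to 2 rather than the previous default. A small sketch of what the conversion presumably does to one tensor; the `.data`/`.scale` attribute names and the `jsa` import alias are assumptions, and 2 is a power of two, so dividing it out costs no precision in the data component:

import numpy as np

w = np.random.randn(784, 1024).astype(np.float32)
sw = jsa.as_scaled_array(w, scale=np.float32(2))
# Presumably sw.scale == 2 and sw.data == w / 2, so sw still
# represents w exactly once the scale is multiplied back in.
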
@@ -127,8 +130,8 @@ def update(params, batch):
         for (w, b), (dw, db) in zip(params, grads)
     ]

-    # num_batches = 4
-    # num_epochs = 2
+    num_batches = 1
+    num_epochs = 1
     for epoch in range(num_epochs):
         # print("EPOCH:", epoch)
         start_time = time.time()
