
Commit 4fc71b2 ("wip")
1 parent: 7484879

1 file changed: 6 additions, 5 deletions


experiments/mnist/mnist_classifier_from_scratch.py
Lines changed: 6 additions & 5 deletions

@@ -31,10 +31,11 @@
 
 
 def print_mean_std(name, v):
+    _, scale = jsa.lax.get_data_scale(v)
     # Always use np.float32, to avoid floating errors in descaling + stats.
     v = jsa.asarray(v, dtype=np.float32)
     m, s = np.mean(v), np.std(v)
-    print(name, m, s)
+    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")
 
 
 def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
@@ -44,8 +45,8 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
 def predict(params, inputs):
     activations = inputs
     for w, b in params[:-1]:
-        jsa.ops.debug_callback(partial(print_mean_std, "W:"), w)
-        (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WG:"), w)
+        jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
+        (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
 
         # Matmul + relu
         outputs = jnp.dot(activations, w) + b
@@ -54,8 +55,8 @@ def predict(params, inputs):
     final_w, final_b = params[-1]
     logits = jnp.dot(activations, final_w) + final_b
 
-    jsa.ops.debug_callback(partial(print_mean_std, "LOGITS:"), logits)
-    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LOGITSG:"), logits)
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
+    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
 
     return logits - logsumexp(logits, axis=1, keepdims=True)
 
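For reference, this is how the print_mean_std helper reads after the change, stitched together from the hunks above. It is only a sketch: it assumes the script's existing imports (NumPy as np, functools.partial) and that jsa is the scaled-arithmetic module alias already used elsewhere in the file; that import line is not part of this diff.

from functools import partial  # used by the debug callbacks noted below
import numpy as np
# `jsa` is assumed to be the scaled-arithmetic package imported earlier in the file.

def print_mean_std(name, v):
    # Read the scale factor attached to the scaled array (the line added in this commit).
    _, scale = jsa.lax.get_data_scale(v)
    # Always use np.float32, to avoid floating errors in descaling + stats.
    v = jsa.asarray(v, dtype=np.float32)
    m, s = np.mean(v), np.std(v)
    # New log format, e.g. "W: MEAN(0.0123) / STD(0.4567) / SCALE(1.0000)".
    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")

# Wired into predict() exactly as in the hunks above, with the renamed labels:
#   jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
#   (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)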
