Skip to content

Commit 7484879

Browse files
committed
wip
1 parent 54a9ff2 commit 7484879

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

experiments/mnist/mnist_classifier_from_scratch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,18 @@ def predict(params, inputs):
4545
activations = inputs
4646
for w, b in params[:-1]:
4747
jsa.ops.debug_callback(partial(print_mean_std, "W:"), w)
48+
(w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WG:"), w)
4849

4950
# Matmul + relu
5051
outputs = jnp.dot(activations, w) + b
5152
activations = jnp.maximum(outputs, 0)
5253

5354
final_w, final_b = params[-1]
5455
logits = jnp.dot(activations, final_w) + final_b
56+
57+
jsa.ops.debug_callback(partial(print_mean_std, "LOGITS:"), logits)
58+
(logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LOGITSG:"), logits)
59+
5560
return logits - logsumexp(logits, axis=1, keepdims=True)
5661

5762

0 commit comments

Comments
 (0)