Commit 621d85e

Fix MNIST training example with dynamic L2 rescale logits. (#76)
Training set accuracy: 0.96670
Test set accuracy: 0.93730
Note: test set accuracy can change slightly if using an `eps` in dynamic rescaling.
1 parent 23a74b1 commit 621d85e

1 file changed: +2 -0 lines changed
experiments/mnist/mnist_classifier_from_scratch.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,8 @@ def predict(params, inputs):
 
     final_w, final_b = params[-1]
     logits = jnp.dot(activations, final_w) + final_b
+    # Dynamic rescaling of the gradient, as logits gradient not properly scaled.
+    logits = jsa.ops.dynamic_rescale_l2_grad(logits)
     return logits - logsumexp(logits, axis=1, keepdims=True)
 
 
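The diff only shows the call site, so the following is a rough sketch of what an L2-based dynamic rescale could look like, not the library's implementation. It assumes a tensor is carried as a `(data, scale)` pair whose represented value is `data * scale`, and that rescaling moves the root-mean-square magnitude of `data` into `scale`. The helper name `rescale_l2`, the pair representation, and the way `eps` enters are all illustrative assumptions; the added line in the diff applies its rescaling to the gradient of `logits` rather than to the forward value.

```python
import jax.numpy as jnp


def rescale_l2(data, scale, eps=0.0):
    """Hypothetical helper: rebalance a (data, scale) pair.

    After the call, `data` has roughly unit RMS and the represented
    value `data * scale` is unchanged up to floating-point rounding.
    """
    # RMS of the data, i.e. the L2 norm divided by sqrt(number of elements).
    rms = jnp.sqrt(jnp.mean(jnp.square(data))) + eps
    return data / rms, scale * rms


# A badly scaled "gradient": tiny entries carried with a scale of 1.
grad_data = jnp.array([[3e-4, -4e-4], [0.0, 5e-4]], dtype=jnp.float32)
grad_scale = jnp.float32(1.0)

new_data, new_scale = rescale_l2(grad_data, grad_scale)
```

Because `eps` is added to the divisor and to the scale alike, the product stays the same in exact arithmetic; any difference comes from rounding, which is one plausible reading of the commit's note that accuracy can shift slightly with an `eps`.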
