
Commit 730ba81 ("wip")
Parent: 621d85e

File tree: 2 files changed, +44 −6 lines

experiments/mnist/mnist_classifier_from_scratch.py

Lines changed: 40 additions & 4 deletions
@@ -16,9 +16,8 @@

 The primary aim here is simplicity and minimal dependencies.
 """
-
-
 import time
+from functools import partial

 import datasets
 import jax
@@ -31,27 +30,56 @@
 import jax_scaled_arithmetics as jsa


+def print_mean_std(name, v):
+    data, scale = jsa.lax.get_data_scale(v)
+    # Always use np.float32, to avoid floating errors in descaling + stats.
+    v = jsa.asarray(data, dtype=np.float32)
+    m, s = np.mean(v), np.std(v)
+    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")
+
+
 def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
     return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]


 def predict(params, inputs):
     activations = inputs
     for w, b in params[:-1]:
+        # jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
+        # jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
+        # (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
+
         # Matmul + relu
         outputs = jnp.dot(activations, w) + b
         activations = jnp.maximum(outputs, 0)
+        # activations = jsa.ops.dynamic_rescale_l2_grad(activations)

     final_w, final_b = params[-1]
     logits = jnp.dot(activations, final_w) + final_b
-    # Dynamic rescaling of the gradient, as logits gradient not properly scaled.
+
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+
     logits = jsa.ops.dynamic_rescale_l2_grad(logits)
-    return logits - logsumexp(logits, axis=1, keepdims=True)
+    # logits = logits.astype(np.float32)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+
+    logits = logits - logsumexp(logits, axis=1, keepdims=True)
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits2"), logits)
+    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    return logits


 def loss(params, batch):
     inputs, targets = batch
     preds = predict(params, inputs)
+    jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds)
+    loss = jnp.sum(preds * targets, axis=1)
+    # loss = jsa.ops.dynamic_rescale_l2(loss)
+    jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss)
+    loss = -jnp.mean(loss)
+    jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss)
+    return loss
     return -jnp.mean(jnp.sum(preds * targets, axis=1))


@@ -94,10 +122,18 @@ def data_stream():
     def update(params, batch):
         grads = grad(loss)(params, batch)
         return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
+        return [
+            (jsa.ops.dynamic_rescale_l1(w - step_size * dw), jsa.ops.dynamic_rescale_l1(b - step_size * db))
+            for (w, b), (dw, db) in zip(params, grads)
+        ]

+    # num_batches = 4
+    # num_epochs = 2
     for epoch in range(num_epochs):
+        # print("EPOCH:", epoch)
         start_time = time.time()
         for _ in range(num_batches):
+            # print("BATCH...")
             batch = next(batches)
             # Scaled micro-batch + training dtype cast.
             batch = jsa.as_scaled_array(batch)
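
Note on the debug instrumentation above: print_mean_std reports statistics of a ScaledArray's data component alongside its scale, which is what the Logits/Preds/LOSS callbacks print during training. Below is a minimal, self-contained NumPy sketch of the same statistics, assuming only that a scaled tensor is carried as a (data, scale) pair; the helper name is hypothetical and the jsa.lax.get_data_scale / jsa.asarray calls from the diff are not reproduced here.

# Stand-alone sketch (hypothetical helper, not part of the repository).
import numpy as np

def print_mean_std_sketch(name, data, scale):
    # Cast the data component to float32 before computing stats, mirroring the
    # "avoid floating errors in descaling + stats" comment in the diff.
    d = np.asarray(data, dtype=np.float32)
    print(f"{name}: MEAN({d.mean():.4f}) / STD({d.std():.4f}) / SCALE({float(scale):.4f})")

# Example: float16 logits data with a power-of-two scale factor.
rng = np.random.default_rng(0)
print_mean_std_sketch("Logits", rng.normal(size=(128, 10)).astype(np.float16), np.float32(0.25))

Reading the mean/std of the data component together with the scale is what indicates whether activations and gradients are staying inside the representable range of the low-precision dtype.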

jax_scaled_arithmetics/ops/rescaling.py

Lines changed: 4 additions & 2 deletions
@@ -2,6 +2,8 @@
 from functools import partial

 import jax
+
+# import jax.numpy as jnp
 import numpy as np

 from jax_scaled_arithmetics.core import ScaledArray, pow2_round
@@ -48,7 +50,7 @@ def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray:
     data_sq = jax.lax.abs(data)
     axes = tuple(range(data.ndim))
     # Get MAX norm + pow2 rounding.
-    norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes)
+    norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes) + np.float32(1e-3)
     norm = pow2_round(norm.astype(scale.dtype))
     # Rebalancing based on norm.
     return rebalance(arr, norm)
@@ -63,7 +65,7 @@ def dynamic_rescale_l1_base(arr: ScaledArray) -> ScaledArray:
     data_sq = jax.lax.abs(data.astype(np.float32))
     axes = tuple(range(data.ndim))
     # Get L1 norm + pow2 rounding.
-    norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size
+    norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size + np.float32(1e-3)
     norm = pow2_round(norm.astype(scale.dtype))
     # Rebalancing based on norm.
     return rebalance(arr, norm)
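
The + np.float32(1e-3) added to the reduced max/L1 norms acts as a floor before power-of-two rounding, so an all-zero or near-zero tensor cannot hand a zero (or badly underflowed) scale to rebalance. Below is a small NumPy sketch of that effect, with pow2_round approximated as round-to-nearest power of two; the library's actual pow2_round in jax_scaled_arithmetics.core may use a different rounding mode, so treat this as an illustration only.

import numpy as np

def pow2_round_sketch(x):
    # Approximate power-of-two rounding; clamp to float32 tiny to keep log2 finite.
    x = np.maximum(np.float32(x), np.finfo(np.float32).tiny)
    return np.float32(2.0) ** np.round(np.log2(x))

data = np.zeros((4, 4), dtype=np.float32)      # degenerate all-zero tensor
l1_norm = np.abs(data).sum() / data.size       # exactly 0.0
print(pow2_round_sketch(l1_norm))                       # ~1e-38: the scale collapses without the floor
print(pow2_round_sketch(l1_norm + np.float32(1e-3)))    # 2**-10 = 0.0009765625 with the floor

With the floor in place, the division by the norm inside rebalance stays well-defined even when a batch produces zero gradients.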
