|
16 | 16 |
|
17 | 17 | The primary aim here is simplicity and minimal dependencies. |
18 | 18 | """ |
19 | | - |
20 | | - |
21 | 19 | import time |
| 20 | +from functools import partial |
22 | 21 |
|
23 | 22 | import datasets |
24 | 23 | import jax |
|
31 | 30 | import jax_scaled_arithmetics as jsa |
32 | 31 |
|
33 | 32 |
|
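| | +# Debug helper: print mean/std statistics and the scale of a scaled tensor.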
| 33 | +def print_mean_std(name, v):
| 34 | +    data, scale = jsa.lax.get_data_scale(v)
| 35 | +    # Always use np.float32, to avoid floating-point errors in descaling + stats.
| 36 | +    v = jsa.asarray(data, dtype=np.float32)
| 37 | +    m, s = np.mean(v), np.std(v)
| 38 | +    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")
| 39 | +
| 40 | +
34 | 41 | def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): |
35 | 42 |     return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
36 | 43 |
|
37 | 44 |
|
38 | 45 | def predict(params, inputs): |
39 | 46 |     activations = inputs
40 | 47 |     for w, b in params[:-1]:
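| | +        # Optional debug callbacks (commented out): print statistics of the layer weights, biases and weight gradient.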
| 48 | +        # jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
| 49 | +        # jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
| 50 | +        # (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
| 51 | +
41 | 52 |         # Matmul + relu
42 | 53 |         outputs = jnp.dot(activations, w) + b
43 | 54 |         activations = jnp.maximum(outputs, 0)
| 55 | +        # activations = jsa.ops.dynamic_rescale_l2_grad(activations)
44 | 56 |
|
45 | 57 |     final_w, final_b = params[-1]
46 | 58 |     logits = jnp.dot(activations, final_w) + final_b
47 | | -    # Dynamic rescaling of the gradient, as logits gradient not properly scaled.
| 59 | +
| 60 | +    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
| 61 | +    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
| 62 | +
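| | +    # Dynamic rescaling of the gradient, as the logits gradient is otherwise not properly scaled.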
48 | 63 |     logits = jsa.ops.dynamic_rescale_l2_grad(logits)
49 | | -    return logits - logsumexp(logits, axis=1, keepdims=True)
| 64 | +    # logits = logits.astype(np.float32)
| 65 | +    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
| 66 | +
| 67 | +    logits = logits - logsumexp(logits, axis=1, keepdims=True)
| 68 | +    jsa.ops.debug_callback(partial(print_mean_std, "Logits2"), logits)
| 69 | +    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
| 70 | +    return logits
50 | 71 |
|
51 | 72 |
|
52 | 73 | def loss(params, batch): |
53 | 74 |     inputs, targets = batch
54 | 75 |     preds = predict(params, inputs)
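| | +    # Debug callbacks: print statistics of the predictions and of the loss, before and after the mean reduction.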
| 76 | +    jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds)
| 77 | +    loss = jnp.sum(preds * targets, axis=1)
| 78 | +    # loss = jsa.ops.dynamic_rescale_l2(loss)
| 79 | +    jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss)
| 80 | +    loss = -jnp.mean(loss)
| 81 | +    jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss)
| 82 | +    return loss
55 | | -    return -jnp.mean(jnp.sum(preds * targets, axis=1))
56 | 84 |
|
57 | 85 |
|
@@ -94,10 +122,18 @@ def data_stream(): |
94 | 122 |     def update(params, batch):
95 | 123 |         grads = grad(loss)(params, batch)
96 | 124 |         return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
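| | +        # Disabled alternative: additionally rescale the updated parameters with jsa.ops.dynamic_rescale_l1.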
| 125 | +        # return [
| 126 | +        #     (jsa.ops.dynamic_rescale_l1(w - step_size * dw), jsa.ops.dynamic_rescale_l1(b - step_size * db))
| 127 | +        #     for (w, b), (dw, db) in zip(params, grads)
| 128 | +        # ]
97 | 129 |
|
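| | +    # Shorter schedule, useful for quick debugging runs.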
| 130 | +    # num_batches = 4
| 131 | +    # num_epochs = 2
98 | 132 |     for epoch in range(num_epochs):
| 133 | +        # print("EPOCH:", epoch)
99 | 134 |         start_time = time.time()
100 | 135 |         for _ in range(num_batches):
| 136 | +            # print("BATCH...")
101 | 137 |             batch = next(batches)
102 | 138 |             # Scaled micro-batch + training dtype cast.
103 | 139 |             batch = jsa.as_scaled_array(batch)
|