@@ -17,7 +17,6 @@
 The primary aim here is simplicity and minimal dependencies.
 """
 import time
-from functools import partial

 import datasets
 import jax
@@ -29,6 +28,8 @@

 import jax_scaled_arithmetics as jsa

+# from functools import partial
+

 def print_mean_std(name, v):
     data, scale = jsa.lax.get_data_scale(v)
@@ -60,7 +61,8 @@ def predict(params, inputs):
     # jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
     # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

-    logits = jsa.ops.dynamic_rescale_l1_grad(logits)
+    logits = jsa.ops.dynamic_rescale_l2_grad(logits)
+    # logits = logits.astype(np.float32)
     # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

     return logits - logsumexp(logits, axis=1, keepdims=True)
@@ -110,7 +112,7 @@ def data_stream():
 @jsa.autoscale
 def update(params, batch):
     grads = grad(loss)(params, batch)
-    # return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
+    return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
     return [
         (jsa.ops.dynamic_rescale_l1(w - step_size * dw), jsa.ops.dynamic_rescale_l1(b - step_size * db))
         for (w, b), (dw, db) in zip(params, grads)
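
For context, a minimal sketch (not part of the commit) of how the two pieces touched by this diff fit together: dynamic_rescale_l2_grad applied to the logits in the forward pass, and the plain SGD rule now returned from update under jsa.autoscale. The one-layer model, loss and step_size below are illustrative stand-ins for the example's real definitions, and, per its name, the _grad op is assumed to rescale only the gradient flowing back through the logits.

import jax.numpy as jnp
from jax import grad
from jax.scipy.special import logsumexp

import jax_scaled_arithmetics as jsa

step_size = 0.1  # stand-in hyper-parameter, not the example's actual value


def predict(params, inputs):
    # Single linear layer standing in for the example's MLP.
    w, b = params[0]
    logits = jnp.dot(inputs, w) + b
    # Op switched to by this commit (previously the L1-based variant).
    logits = jsa.ops.dynamic_rescale_l2_grad(logits)
    return logits - logsumexp(logits, axis=1, keepdims=True)


def loss(params, batch):
    inputs, targets = batch
    return -jnp.mean(jnp.sum(predict(params, inputs) * targets, axis=1))


@jsa.autoscale
def update(params, batch):
    grads = grad(loss)(params, batch)
    # Plain SGD update, re-enabled by this commit; the dynamic_rescale_l1 variant
    # remains below an early return in the actual file and is omitted here.
    return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]

Relative to the previous revision, the only behavioural changes are the L1 to L2 switch on the logits' gradient rescaling and the early return of the unscaled SGD update; the commented debug_callback hooks and the functools.partial import stay disabled.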
|