diff --git a/experiments/mnist/mnist_classifier_from_scratch.py b/experiments/mnist/mnist_classifier_from_scratch.py
index 7e78cc6..c3ae1ac 100644
--- a/experiments/mnist/mnist_classifier_from_scratch.py
+++ b/experiments/mnist/mnist_classifier_from_scratch.py
@@ -16,9 +16,8 @@
 The primary aim here is simplicity and minimal dependencies.
 """

-
-
 import time
+from functools import partial

 import datasets
 import jax
@@ -31,6 +30,15 @@
 import jax_scaled_arithmetics as jsa


+def print_mean_std(name, v):
+    data, scale = jsa.lax.get_data_scale(v)
+    # Always use np.float32, to avoid floating errors in descaling + stats.
+    v = jsa.asarray(data, dtype=np.float32)
+    m, s = np.mean(v), np.std(v)
+    # print(data)
+    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")
+
+
 def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
     return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]

@@ -38,20 +46,44 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
 def predict(params, inputs):
     activations = inputs
     for w, b in params[:-1]:
+        jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
+        jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
+        (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
+        # Matmul + relu
         outputs = jnp.dot(activations, w) + b
         activations = jnp.maximum(outputs, 0)
+        jsa.ops.debug_callback(partial(print_mean_std, "Act"), activations)
+        # activations = jsa.ops.dynamic_rescale_l2_grad(activations)

     final_w, final_b = params[-1]
-    logits = jnp.dot(activations, final_w) + final_b
-    # Dynamic rescaling of the gradient, as logits gradient not properly scaled.
+    logits = jnp.dot(activations, final_w)
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits0"), logits)
+    logits = logits + final_b
+
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits1"), logits)
+    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+
     logits = jsa.ops.dynamic_rescale_l2_grad(logits)
-    return logits - logsumexp(logits, axis=1, keepdims=True)
+    # logits = logits.astype(np.float32)
+    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+
+    logits = logits - logsumexp(logits, axis=1, keepdims=True)
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits2"), logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    return logits


 def loss(params, batch):
     inputs, targets = batch
     preds = predict(params, inputs)
+    # jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds)
+    loss = jnp.sum(preds * targets, axis=1)
+    # loss = jsa.ops.dynamic_rescale_l2(loss)
+    # jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss)
+    loss = -jnp.mean(loss)
+    # jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss)
+    return loss
     return -jnp.mean(jnp.sum(preds * targets, axis=1))


@@ -64,7 +96,7 @@ def accuracy(params, batch):

 if __name__ == "__main__":
     layer_sizes = [784, 1024, 1024, 10]
-    param_scale = 1.0
+    param_scale = 2.0
     step_size = 0.001
     num_epochs = 10
     batch_size = 128
@@ -88,18 +120,26 @@ def data_stream():
     batches = data_stream()
     params = init_random_params(param_scale, layer_sizes)
     # Transform parameters to `ScaledArray` and proper dtype.
-    params = jsa.as_scaled_array(params, scale=scale_dtype(1))
+    params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
     params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

-    @jit
+    # @jit
     @jsa.autoscale
     def update(params, batch):
         grads = grad(loss)(params, batch)
         return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
+        return [
+            (jsa.ops.dynamic_rescale_l1(w - step_size * dw), jsa.ops.dynamic_rescale_l1(b - step_size * db))
+            for (w, b), (dw, db) in zip(params, grads)
+        ]

+    num_batches = 1
+    num_epochs = 1
     for epoch in range(num_epochs):
+        # print("EPOCH:", epoch)
         start_time = time.time()
         for _ in range(num_batches):
+            # print("BATCH...")
             batch = next(batches)
             # Scaled micro-batch + training dtype cast.
             batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
@@ -111,9 +151,9 @@ def update(params, batch):
         epoch_time = time.time() - start_time

         # Evaluation in float32, for consistency.
-        raw_params = jsa.asarray(params, dtype=np.float32)
-        train_acc = accuracy(raw_params, (train_images, train_labels))
-        test_acc = accuracy(raw_params, (test_images, test_labels))
-        print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
-        print(f"Training set accuracy {train_acc:0.5f}")
-        print(f"Test set accuracy {test_acc:0.5f}")
+        # raw_params = jsa.asarray(params, dtype=np.float32)
+        # train_acc = accuracy(raw_params, (train_images, train_labels))
+        # test_acc = accuracy(raw_params, (test_images, test_labels))
+        # print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
+        # print(f"Training set accuracy {train_acc:0.5f}")
+        # print(f"Test set accuracy {test_acc:0.5f}")
diff --git a/jax_scaled_arithmetics/ops/rescaling.py b/jax_scaled_arithmetics/ops/rescaling.py
index cd7687b..35f6716 100644
--- a/jax_scaled_arithmetics/ops/rescaling.py
+++ b/jax_scaled_arithmetics/ops/rescaling.py
@@ -2,6 +2,8 @@
 from functools import partial

 import jax
+
+# import jax.numpy as jnp
 import numpy as np

 from jax_scaled_arithmetics.core import ScaledArray, pow2_round
@@ -48,7 +50,7 @@ def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray:
     data_sq = jax.lax.abs(data)
     axes = tuple(range(data.ndim))
     # Get MAX norm + pow2 rounding.
-    norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes)
+    norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes) + np.float32(1e-3)
     norm = pow2_round(norm.astype(scale.dtype))
     # Rebalancing based on norm.
     return rebalance(arr, norm)
@@ -63,7 +65,7 @@ def dynamic_rescale_l1_base(arr: ScaledArray) -> ScaledArray:
     data_sq = jax.lax.abs(data.astype(np.float32))
     axes = tuple(range(data.ndim))
     # Get L1 norm + pow2 rounding.
-    norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size
+    norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size + np.float32(1e-3)
     norm = pow2_round(norm.astype(scale.dtype))
     # Rebalancing based on norm.
     return rebalance(arr, norm)
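Below is a minimal, self-contained sketch of the debug-callback pattern this patch wires into `predict`: print MEAN/STD/SCALE of a scaled tensor on the forward pass, print the same statistics for its incoming gradient on the backward pass, and dynamically rescale the gradient. Only `jsa.*` functions already used in the patch appear here; the toy function `fn`, the tensor shapes, and the `np.float32` scale values are illustrative assumptions, not part of the repository.

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

import jax_scaled_arithmetics as jsa


def print_mean_std(name, v):
    # Unpack the ScaledArray into raw data + scale; descale in float32 for stable statistics.
    data, scale = jsa.lax.get_data_scale(v)
    v = jsa.asarray(data, dtype=np.float32)
    print(f"{name}: MEAN({np.mean(v):.4f}) / STD({np.std(v):.4f}) / SCALE({scale:.4f})")


def fn(w, x):
    y = jnp.dot(x, w)
    # Forward-pass statistics of the scaled activation.
    jsa.ops.debug_callback(partial(print_mean_std, "Y"), y)
    # Backward-pass statistics of the gradient flowing into `y`.
    (y,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "YGrad"), y)
    # Dynamic rescaling of the gradient, as done on the logits in `predict`.
    y = jsa.ops.dynamic_rescale_l2_grad(y)
    return jnp.mean(y)


@jsa.autoscale
def grad_fn(w, x):
    # Differentiate under `autoscale`, mirroring `grad(loss)` inside `update`.
    return jax.grad(fn)(w, x)


# Scaled inputs, mirroring `jsa.as_scaled_array(..., scale=scale_dtype(1))` in the patch
# (np.float32 is assumed here as the scale dtype).
w = jsa.as_scaled_array(np.random.randn(16, 4).astype(np.float32), scale=np.float32(1))
x = jsa.as_scaled_array(np.random.randn(8, 16).astype(np.float32), scale=np.float32(1))
dw = grad_fn(w, x)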