From 72ed8f5ca40f8929b0c4e10b3a4b2d67fcead061 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Fri, 5 Jan 2024 11:09:32 +0000 Subject: [PATCH 1/4] wip --- .../mnist/mnist_classifier_from_scratch.py | 44 +++++++++++++++++-- jax_scaled_arithmetics/ops/rescaling.py | 6 ++- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/experiments/mnist/mnist_classifier_from_scratch.py b/experiments/mnist/mnist_classifier_from_scratch.py index 7e78cc6..d335204 100644 --- a/experiments/mnist/mnist_classifier_from_scratch.py +++ b/experiments/mnist/mnist_classifier_from_scratch.py @@ -16,9 +16,8 @@ The primary aim here is simplicity and minimal dependencies. """ - - import time +from functools import partial import datasets import jax @@ -31,6 +30,14 @@ import jax_scaled_arithmetics as jsa +def print_mean_std(name, v): + data, scale = jsa.lax.get_data_scale(v) + # Always use np.float32, to avoid floating errors in descaling + stats. + v = jsa.asarray(data, dtype=np.float32) + m, s = np.mean(v), np.std(v) + print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})") + + def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])] @@ -38,20 +45,41 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): def predict(params, inputs): activations = inputs for w, b in params[:-1]: + # jsa.ops.debug_callback(partial(print_mean_std, "W"), w) + # jsa.ops.debug_callback(partial(print_mean_std, "B"), b) + # (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w) + # Matmul + relu outputs = jnp.dot(activations, w) + b activations = jnp.maximum(outputs, 0) + # activations = jsa.ops.dynamic_rescale_l2_grad(activations) final_w, final_b = params[-1] logits = jnp.dot(activations, final_w) + final_b - # Dynamic rescaling of the gradient, as logits gradient not properly scaled. 
+ + jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits) + # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits) + logits = jsa.ops.dynamic_rescale_l2_grad(logits) - return logits - logsumexp(logits, axis=1, keepdims=True) + # logits = logits.astype(np.float32) + # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits) + + logits = logits - logsumexp(logits, axis=1, keepdims=True) + jsa.ops.debug_callback(partial(print_mean_std, "Logits2"), logits) + (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits) + return logits def loss(params, batch): inputs, targets = batch preds = predict(params, inputs) + jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds) + loss = jnp.sum(preds * targets, axis=1) + # loss = jsa.ops.dynamic_rescale_l2(loss) + jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss) + loss = -jnp.mean(loss) + jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss) + return loss return -jnp.mean(jnp.sum(preds * targets, axis=1)) @@ -96,10 +124,18 @@ def data_stream(): def update(params, batch): grads = grad(loss)(params, batch) return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)] + return [ + (jsa.ops.dynamic_rescale_l1(w - step_size * dw), jsa.ops.dynamic_rescale_l1(b - step_size * db)) + for (w, b), (dw, db) in zip(params, grads) + ] + # num_batches = 4 + # num_epochs = 2 for epoch in range(num_epochs): + # print("EPOCH:", epoch) start_time = time.time() for _ in range(num_batches): + # print("BATCH...") batch = next(batches) # Scaled micro-batch + training dtype cast. batch = jsa.as_scaled_array(batch, scale=scale_dtype(1)) diff --git a/jax_scaled_arithmetics/ops/rescaling.py b/jax_scaled_arithmetics/ops/rescaling.py index cd7687b..35f6716 100644 --- a/jax_scaled_arithmetics/ops/rescaling.py +++ b/jax_scaled_arithmetics/ops/rescaling.py @@ -2,6 +2,8 @@ from functools import partial import jax + +# import jax.numpy as jnp import numpy as np from jax_scaled_arithmetics.core import ScaledArray, pow2_round @@ -48,7 +50,7 @@ def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray: data_sq = jax.lax.abs(data) axes = tuple(range(data.ndim)) # Get MAX norm + pow2 rounding. - norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes) + norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes) + np.float32(1e-3) norm = pow2_round(norm.astype(scale.dtype)) # Rebalancing based on norm. return rebalance(arr, norm) @@ -63,7 +65,7 @@ def dynamic_rescale_l1_base(arr: ScaledArray) -> ScaledArray: data_sq = jax.lax.abs(data.astype(np.float32)) axes = tuple(range(data.ndim)) # Get L1 norm + pow2 rounding. - norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size + norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size + np.float32(1e-3) norm = pow2_round(norm.astype(scale.dtype)) # Rebalancing based on norm. 
return rebalance(arr, norm) From 31d1ff4b7f83a3675cacc855130b9030e3d60a8e Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Tue, 9 Jan 2024 17:27:45 +0000 Subject: [PATCH 2/4] wip --- .../mnist/mnist_classifier_from_scratch.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/experiments/mnist/mnist_classifier_from_scratch.py b/experiments/mnist/mnist_classifier_from_scratch.py index d335204..a3dac38 100644 --- a/experiments/mnist/mnist_classifier_from_scratch.py +++ b/experiments/mnist/mnist_classifier_from_scratch.py @@ -35,7 +35,7 @@ def print_mean_std(name, v): # Always use np.float32, to avoid floating errors in descaling + stats. v = jsa.asarray(data, dtype=np.float32) m, s = np.mean(v), np.std(v) - print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})") + print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale.dtype})") def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): @@ -45,40 +45,43 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): def predict(params, inputs): activations = inputs for w, b in params[:-1]: - # jsa.ops.debug_callback(partial(print_mean_std, "W"), w) - # jsa.ops.debug_callback(partial(print_mean_std, "B"), b) - # (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w) + jsa.ops.debug_callback(partial(print_mean_std, "W"), w) + jsa.ops.debug_callback(partial(print_mean_std, "B"), b) + (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w) # Matmul + relu outputs = jnp.dot(activations, w) + b activations = jnp.maximum(outputs, 0) + jsa.ops.debug_callback(partial(print_mean_std, "Act"), activations) # activations = jsa.ops.dynamic_rescale_l2_grad(activations) final_w, final_b = params[-1] - logits = jnp.dot(activations, final_w) + final_b + logits = jnp.dot(activations, final_w) + jsa.ops.debug_callback(partial(print_mean_std, "Logits0"), logits) + logits = logits + final_b - jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits) - # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits) + jsa.ops.debug_callback(partial(print_mean_std, "Logits1"), logits) + (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits) logits = jsa.ops.dynamic_rescale_l2_grad(logits) # logits = logits.astype(np.float32) - # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits) + (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits) logits = logits - logsumexp(logits, axis=1, keepdims=True) jsa.ops.debug_callback(partial(print_mean_std, "Logits2"), logits) - (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits) + # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits) return logits def loss(params, batch): inputs, targets = batch preds = predict(params, inputs) - jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds) + # jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds) loss = jnp.sum(preds * targets, axis=1) # loss = jsa.ops.dynamic_rescale_l2(loss) - jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss) + # jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss) loss = -jnp.mean(loss) - jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss) + # jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss) return loss return -jnp.mean(jnp.sum(preds * targets, axis=1)) @@ -92,7 +95,7 @@ def accuracy(params, batch): if 
__name__ == "__main__": layer_sizes = [784, 1024, 1024, 10] - param_scale = 1.0 + param_scale = 2.0 step_size = 0.001 num_epochs = 10 batch_size = 128 @@ -129,8 +132,8 @@ def update(params, batch): for (w, b), (dw, db) in zip(params, grads) ] - # num_batches = 4 - # num_epochs = 2 + num_batches = 1 + num_epochs = 1 for epoch in range(num_epochs): # print("EPOCH:", epoch) start_time = time.time() From b1cdf1db45bfb43cf5577eb0380f8e9e84f7d154 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Tue, 9 Jan 2024 21:06:11 +0000 Subject: [PATCH 3/4] wip --- experiments/mnist/mnist_classifier_from_scratch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experiments/mnist/mnist_classifier_from_scratch.py b/experiments/mnist/mnist_classifier_from_scratch.py index a3dac38..f55177f 100644 --- a/experiments/mnist/mnist_classifier_from_scratch.py +++ b/experiments/mnist/mnist_classifier_from_scratch.py @@ -35,7 +35,7 @@ def print_mean_std(name, v): # Always use np.float32, to avoid floating errors in descaling + stats. v = jsa.asarray(data, dtype=np.float32) m, s = np.mean(v), np.std(v) - print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale.dtype})") + print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})") def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): From b42d63a057819393f901ef06b8467c6b3f44a285 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 11 Jan 2024 17:57:34 +0000 Subject: [PATCH 4/4] wip --- .../mnist/mnist_classifier_from_scratch.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/experiments/mnist/mnist_classifier_from_scratch.py b/experiments/mnist/mnist_classifier_from_scratch.py index f55177f..c3ae1ac 100644 --- a/experiments/mnist/mnist_classifier_from_scratch.py +++ b/experiments/mnist/mnist_classifier_from_scratch.py @@ -35,6 +35,7 @@ def print_mean_std(name, v): # Always use np.float32, to avoid floating errors in descaling + stats. v = jsa.asarray(data, dtype=np.float32) m, s = np.mean(v), np.std(v) + # print(data) print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})") @@ -119,10 +120,10 @@ def data_stream(): batches = data_stream() params = init_random_params(param_scale, layer_sizes) # Transform parameters to `ScaledArray` and proper dtype. - params = jsa.as_scaled_array(params, scale=scale_dtype(1)) + params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale)) params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf) - @jit + # @jit @jsa.autoscale def update(params, batch): grads = grad(loss)(params, batch) @@ -150,9 +151,9 @@ def update(params, batch): epoch_time = time.time() - start_time # Evaluation in float32, for consistency. - raw_params = jsa.asarray(params, dtype=np.float32) - train_acc = accuracy(raw_params, (train_images, train_labels)) - test_acc = accuracy(raw_params, (test_images, test_labels)) - print(f"Epoch {epoch} in {epoch_time:0.2f} sec") - print(f"Training set accuracy {train_acc:0.5f}") - print(f"Test set accuracy {test_acc:0.5f}") + # raw_params = jsa.asarray(params, dtype=np.float32) + # train_acc = accuracy(raw_params, (train_images, train_labels)) + # test_acc = accuracy(raw_params, (test_images, test_labels)) + # print(f"Epoch {epoch} in {epoch_time:0.2f} sec") + # print(f"Training set accuracy {train_acc:0.5f}") + # print(f"Test set accuracy {test_acc:0.5f}")
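Note on the rescaling hunks: the jax_scaled_arithmetics/ops/rescaling.py changes in PATCH 1/4 add `+ np.float32(1e-3)` to the reduced max/L1 norm before `pow2_round`, so an all-zero (or vanishing) tensor can no longer produce a zero scale. Below is a minimal NumPy sketch of that idea, not the library code itself: `pow2_round_down` is a hypothetical stand-in for `pow2_round` (assumed here to round down to the nearest power of two), and `rescale_l1` assumes `rebalance` divides the data by the norm while multiplying the scale by it.

import numpy as np

def pow2_round_down(x):
    # Hypothetical stand-in for jax_scaled_arithmetics.core.pow2_round:
    # round a positive scalar down to the nearest power of two.
    return np.float32(2.0) ** np.floor(np.log2(np.float32(x)))

def rescale_l1(data, scale, eps=np.float32(1e-3)):
    # Sketch of dynamic_rescale_l1_base: mean absolute value of the data,
    # plus a small epsilon so an all-zero tensor keeps a finite, non-zero norm,
    # rounded to a power of two and moved from the data into the scale.
    norm = np.mean(np.abs(data.astype(np.float32))) + eps
    norm = pow2_round_down(norm)
    return (data / norm).astype(data.dtype), scale * norm

data = np.float16(64.0) * np.random.randn(128, 128).astype(np.float16)
new_data, new_scale = rescale_l1(data, np.float32(1.0))
print(new_scale, float(np.mean(np.abs(new_data))))  # data back to ~unit magnitude

Rounding to a power of two keeps the rebalancing exact in floating point (only the exponent changes), which is presumably why the epsilon is added to the norm before the rounding step rather than after it.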