Skip to content

Commit 23a74b1

Browse files
authored
Improve robustness of scaled translation with zero scale inputs. (#75)
Adding `safe_div` and `safe_reciprocal` in order to avoid generating `inf` and `nan` when input scales are zero.
1 parent 8949497 commit 23a74b1

File tree

10 files changed

+128
-13
lines changed

10 files changed

+128
-13
lines changed

jax_scaled_arithmetics/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@
2323
register_scaled_op,
2424
)
2525
from .typing import Array, ArrayTypes, get_numpy_api # noqa: F401
26-
from .utils import Pow2RoundMode, pow2_round, pow2_round_down, pow2_round_up # noqa: F401
26+
from .utils import Pow2RoundMode, pow2_round, pow2_round_down, pow2_round_up, safe_div, safe_reciprocal # noqa: F401

jax_scaled_arithmetics/core/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from enum import IntEnum
33
from typing import Any, Dict
44

5+
import jax
6+
import jax.numpy as jnp
57
import numpy as np
68
from numpy.typing import NDArray
79

@@ -77,3 +79,26 @@ def pow2_round(val: Array, mode: Pow2RoundMode = Pow2RoundMode.DOWN) -> Array:
7779
elif mode == Pow2RoundMode.UP:
7880
return pow2_round_up(val)
7981
raise NotImplementedError(f"Unsupported power-of-2 rounding mode '{mode}'.")
82+
83+
84+
def safe_div(lhs: Array, rhs: Array) -> Array:
85+
"""Safe (scalar) div: if rhs is zero, returns zero."""
86+
assert lhs.shape == ()
87+
assert rhs.shape == ()
88+
# assert lhs.dtype == rhs.dtype
89+
# Numpy inputs => direct computation.
90+
is_npy_inputs = isinstance(lhs, (np.number, np.ndarray)) and isinstance(rhs, (np.number, np.ndarray))
91+
if is_npy_inputs:
92+
return np.divide(lhs, rhs, out=np.array(0, dtype=rhs.dtype), where=rhs != 0)
93+
# JAX general implementation.
94+
return jax.lax.select(rhs == 0, rhs, jnp.divide(lhs, rhs))
95+
96+
97+
def safe_reciprocal(val: Array) -> Array:
98+
"""Safe (scalar) reciprocal: if val is zero, returns zero."""
99+
assert val.shape == ()
100+
# Numpy inputs => direct computation.
101+
if isinstance(val, (np.number, np.ndarray)):
102+
return np.reciprocal(val, out=np.array(0, dtype=val.dtype), where=val != 0)
103+
# JAX general implementation.
104+
return jax.lax.select(val == 0, val, jax.lax.reciprocal(val))

jax_scaled_arithmetics/lax/base_scaling_primitives.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
asarray,
1515
is_static_one_scalar,
1616
register_scaled_op,
17+
safe_div,
18+
safe_reciprocal,
1719
)
1820

1921
set_scaling_p = core.Primitive("set_scaling_p")
@@ -24,6 +26,9 @@
2426
2527
In JAX Scaled Arithmetics/AutoScale mode, it will rebalance the data term to
2628
return a ScaledArray semantically equivalent.
29+
30+
NOTE: there is specific corner case of passing zero to `set_scaling`. In this
31+
situation, the tensor is assumed to be zeroed by the user.
2732
"""
2833

2934

@@ -46,7 +51,7 @@ def set_scaling_impl(values: Array, scale: Array) -> Array:
4651
# Automatic promotion should ensure we always get a scaled scalar here!
4752
scale_value = asarray(scale)
4853
# Rebalancing data tensor using the new scale.
49-
data = values.data * (values.scale / scale_value).astype(values.dtype)
54+
data = values.data * safe_div(values.scale, scale_value).astype(values.dtype)
5055
return ScaledArray(data, scale_value)
5156
# No scaled array => no-op.
5257
return values
@@ -75,9 +80,9 @@ def scaled_set_scaling(values: ScaledArray, scale: ScaledArray) -> ScaledArray:
7580
scale_value = asarray(scale)
7681
if not isinstance(values, ScaledArray):
7782
# Simple case, with no pre-existing scale.
78-
return ScaledArray(values / scale_value.astype(values.dtype), scale_value)
83+
return ScaledArray(values * safe_reciprocal(scale_value.astype(values.dtype)), scale_value)
7984
# Rebalancing data tensor using the new scale.
80-
data = values.data * (values.scale / scale_value).astype(values.dtype)
85+
data = values.data * safe_div(values.scale, scale_value).astype(values.dtype)
8186
return ScaledArray(data, scale_value)
8287

8388

jax_scaled_arithmetics/lax/scaled_ops_common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
as_scaled_array,
1717
get_scale_dtype,
1818
is_static_zero,
19+
safe_div,
1920
)
2021

2122
from .base_scaling_primitives import scaled_set_scaling
@@ -89,7 +90,7 @@ def scaled_concatenate(operands: Sequence[ScaledArray], dimension: int) -> Scale
8990
# TODO: explore alternative strategies?
9091
outdtype = operands[0].dtype
9192
scale_max = jnp.max(scales)
92-
datas = [v.data * (v.scale / scale_max).astype(outdtype) for v in operands]
93+
datas = [v.data * safe_div(v.scale, scale_max).astype(outdtype) for v in operands]
9394
data_concat = lax.concatenate(datas, dimension=dimension)
9495
return ScaledArray(data_concat, scale_max)
9596

@@ -219,8 +220,8 @@ def scaled_minmax(prim: jax.core.Primitive, lhs: ScaledArray, rhs: ScaledArray)
219220
output_scale = lax.max(lhs.scale, rhs.scale)
220221
# TODO: isolate this "binary" rescale logic into separate function.
221222
outdtype = jnp.promote_types(lhs.dtype, rhs.dtype)
222-
lhs_rescale = (lhs.scale / output_scale).astype(outdtype)
223-
rhs_rescale = (rhs.scale / output_scale).astype(outdtype)
223+
lhs_rescale = safe_div(lhs.scale, output_scale).astype(outdtype)
224+
rhs_rescale = safe_div(rhs.scale, output_scale).astype(outdtype)
224225
output_data = prim.bind(lhs_rescale * lhs.data, rhs_rescale * rhs.data)
225226
return ScaledArray(output_data, output_scale)
226227

jax_scaled_arithmetics/lax/scaled_ops_l2.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_autoscale_config,
1515
pow2_round,
1616
register_scaled_op,
17+
safe_div,
1718
)
1819

1920
from .scaled_ops_common import check_scalar_scales, promote_scale_types
@@ -32,14 +33,14 @@ def scaled_add_sub(A: ScaledArray, B: ScaledArray, binary_op: Any) -> ScaledArra
3233
# More stable than direct L2 norm, to avoid scale overflow.
3334
ABscale_max = lax.max(A.scale, B.scale)
3435
ABscale_min = lax.min(A.scale, B.scale)
35-
ABscale_ratio = ABscale_min / ABscale_max
36+
ABscale_ratio = safe_div(ABscale_min, ABscale_max)
3637
output_scale = ABscale_max * lax.sqrt(1 + ABscale_ratio * ABscale_ratio)
3738
# Transform back to power-of-2
3839
output_scale = pow2_round(output_scale, pow2_rounding_mode)
3940
# Output dtype => promotion of A and B dtypes.
4041
outdtype = jnp.promote_types(A.dtype, B.dtype)
41-
Arescale = (A.scale / output_scale).astype(outdtype)
42-
Brescale = (B.scale / output_scale).astype(outdtype)
42+
Arescale = safe_div(A.scale, output_scale).astype(outdtype)
43+
Brescale = safe_div(B.scale, output_scale).astype(outdtype)
4344
# check correct type output if mismatch between data and scale precision
4445
output_data = binary_op(Arescale * A.data, Brescale * B.data)
4546
return ScaledArray(output_data, output_scale)

jax_scaled_arithmetics/ops/rescaling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def dynamic_rescale_l2_base(arr: ScaledArray) -> ScaledArray:
7878
data_sq = jax.lax.integer_pow(data.astype(np.float32), 2)
7979
axes = tuple(range(data.ndim))
8080
# Get L2 norm + pow2 rounding.
81-
norm = jax.lax.sqrt(jax.lax.reduce_sum_p.bind(data_sq, axes=axes)) / data.size
81+
norm = jax.lax.sqrt(jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size)
8282
norm = pow2_round(norm.astype(scale.dtype))
8383
# Rebalancing based on norm.
8484
return rebalance(arr, norm)

tests/core/test_utils.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
22
import chex
3+
import jax.numpy as jnp
34
import numpy as np
45
import numpy.testing as npt
56
from absl.testing import parameterized
67

7-
from jax_scaled_arithmetics.core import pow2_round_down, pow2_round_up
8-
from jax_scaled_arithmetics.core.utils import _exponent_bits_mask, get_mantissa
8+
from jax_scaled_arithmetics.core import Array, pow2_round_down, pow2_round_up
9+
from jax_scaled_arithmetics.core.utils import _exponent_bits_mask, get_mantissa, safe_div, safe_reciprocal
910

1011

1112
class Pow2RoundingUtilTests(chex.TestCase):
@@ -54,3 +55,38 @@ def test__pow2_round_up__proper_rounding__multi_dtypes(self, val_exp, dtype):
5455
assert pow2_val.dtype == val.dtype
5556
assert type(pow2_val) in {type(val), np.ndarray}
5657
npt.assert_equal(pow2_val, exp)
58+
59+
60+
class SafeDivOpTests(chex.TestCase):
61+
@parameterized.parameters(
62+
{"lhs": np.float16(0), "rhs": np.float16(0)},
63+
{"lhs": np.float32(0), "rhs": np.float32(0)},
64+
{"lhs": np.float16(2), "rhs": np.float16(0)},
65+
{"lhs": np.float32(4), "rhs": np.float32(0)},
66+
)
67+
def test__safe_div__zero_div__numpy_inputs(self, lhs, rhs):
68+
out = safe_div(lhs, rhs)
69+
assert isinstance(out, (np.number, np.ndarray))
70+
assert out.dtype == lhs.dtype
71+
npt.assert_equal(out, 0)
72+
73+
@parameterized.parameters(
74+
{"lhs": np.float16(0), "rhs": jnp.float16(0)},
75+
{"lhs": jnp.float32(0), "rhs": np.float32(0)},
76+
{"lhs": jnp.float16(2), "rhs": np.float16(0)},
77+
{"lhs": np.float32(4), "rhs": jnp.float32(0)},
78+
)
79+
def test__safe_div__zero_div__jax_inputs(self, lhs, rhs):
80+
out = safe_div(lhs, rhs)
81+
assert isinstance(out, Array)
82+
assert out.dtype == lhs.dtype
83+
npt.assert_almost_equal(out, 0)
84+
85+
@parameterized.parameters(
86+
{"val": np.float16(0)},
87+
{"val": jnp.float16(0)},
88+
)
89+
def test__safe_reciprocal__zero_div(self, val):
90+
out = safe_reciprocal(val)
91+
assert out.dtype == val.dtype
92+
npt.assert_almost_equal(out, 0)

tests/lax/test_base_scaling_primitives.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,22 @@ def test__set_scaling_primitive__scaled_array__eager_mode(self, npapi):
3131
npt.assert_equal(output.scale, npapi.float16(4))
3232
npt.assert_array_equal(output, values)
3333

34+
@chex.variants(with_jit=True, without_jit=True)
35+
@parameterized.parameters(
36+
{"arr": np.array([-1.0, 2.0], dtype=np.float32)},
37+
{"arr": scaled_array([-1.0, 2.0], 1.0, dtype=np.float16)},
38+
{"arr": scaled_array([-1.0, 2.0], 0.0, dtype=np.float32)},
39+
)
40+
def test__set_scaling_primitive__zero_scaling(self, arr):
41+
def fn(arr, scale):
42+
return set_scaling(arr, scale)
43+
44+
scale = np.array(0, dtype=arr.dtype)
45+
out = self.variant(autoscale(fn))(arr, scale)
46+
assert isinstance(out, ScaledArray)
47+
npt.assert_array_almost_equal(out.scale, 0)
48+
npt.assert_array_almost_equal(out.data, 0)
49+
3450
@chex.variants(with_jit=True, without_jit=True)
3551
def test__set_scaling_primitive__proper_result_without_autoscale(self):
3652
def fn(arr, scale):

tests/lax/test_scaled_ops_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ def test__scaled_concatenate__proper_scaling(self):
8181
npt.assert_array_equal(z.scale, y.scale)
8282
npt.assert_array_almost_equal(z, lax.concatenate([np.asarray(x), np.asarray(y)], dimension=0))
8383

84+
def test__scaled_concatenate__zero_input_scales(self):
85+
x = scaled_array(self.rs.rand(2, 3), 0.0, dtype=np.float16)
86+
y = scaled_array(self.rs.rand(5, 3), 0.0, dtype=np.float16)
87+
z = scaled_concatenate([x, y], dimension=0)
88+
assert isinstance(z, ScaledArray)
89+
assert z.dtype == x.dtype
90+
npt.assert_array_equal(z.scale, 0)
91+
npt.assert_array_almost_equal(z, lax.concatenate([np.asarray(x), np.asarray(y)], dimension=0))
92+
8493
def test__scaled_convert_element_type__proper_scaling(self):
8594
x = scaled_array(self.rs.rand(5), 2, dtype=np.float32)
8695
z = scaled_convert_element_type(x, new_dtype=np.float16)

tests/lax/test_scaled_ops_l2.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,28 @@ def test__scaled_binary_op__proper_result_and_promotion(self, prim, dtype, sdtyp
9292
assert z.scale.dtype == sdtype
9393
npt.assert_array_almost_equal(z, expected_z, decimal=4)
9494

95+
@chex.variants(with_jit=True, without_jit=True)
96+
@parameterized.product(
97+
prim=[lax.add_p, lax.sub_p, lax.mul_p, lax.min_p, lax.max_p],
98+
dtype=[np.float16, np.float32],
99+
sdtype=[np.float16, np.float32],
100+
)
101+
def test__scaled_binary_op__proper_zero_scale_handling(self, prim, dtype, sdtype):
102+
scaled_op, _ = find_registered_scaled_op(prim)
103+
# NOTE: direct construction to avoid weirdity between NumPy array and scalar!
104+
x = ScaledArray(np.array([-1.0, 2.0], dtype), sdtype(0.0))
105+
y = ScaledArray(np.array([1.5, 4.5], dtype), sdtype(0.0))
106+
# Ensure scale factor has the right dtype.
107+
assert x.scale.dtype == sdtype
108+
assert y.scale.dtype == sdtype
109+
110+
z = self.variant(scaled_op)(x, y)
111+
expected_z = prim.bind(np.asarray(x), np.asarray(y))
112+
113+
assert z.dtype == x.dtype
114+
assert z.scale.dtype == sdtype
115+
npt.assert_array_almost_equal(z, expected_z, decimal=4)
116+
95117
@parameterized.parameters(
96118
{"prim": lax.add_p},
97119
{"prim": lax.sub_p},

0 commit comments

Comments
 (0)