|
1 | 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. |
2 | 2 | import chex |
| 3 | +import jax.numpy as jnp |
3 | 4 | import numpy as np |
4 | 5 | import numpy.testing as npt |
5 | 6 | from absl.testing import parameterized |
6 | 7 |
|
7 | | -from jax_scaled_arithmetics.core import pow2_round_down, pow2_round_up |
8 | | -from jax_scaled_arithmetics.core.utils import _exponent_bits_mask, get_mantissa |
| 8 | +from jax_scaled_arithmetics.core import Array, pow2_round_down, pow2_round_up |
| 9 | +from jax_scaled_arithmetics.core.utils import _exponent_bits_mask, get_mantissa, safe_div, safe_reciprocal |
9 | 10 |
|
10 | 11 |
|
11 | 12 | class Pow2RoundingUtilTests(chex.TestCase): |
@@ -54,3 +55,38 @@ def test__pow2_round_up__proper_rounding__multi_dtypes(self, val_exp, dtype): |
54 | 55 | assert pow2_val.dtype == val.dtype |
55 | 56 | assert type(pow2_val) in {type(val), np.ndarray} |
56 | 57 | npt.assert_equal(pow2_val, exp) |
| 58 | + |
| 59 | + |
| 60 | +class SafeDivOpTests(chex.TestCase): |
| 61 | + @parameterized.parameters( |
| 62 | + {"lhs": np.float16(0), "rhs": np.float16(0)}, |
| 63 | + {"lhs": np.float32(0), "rhs": np.float32(0)}, |
| 64 | + {"lhs": np.float16(2), "rhs": np.float16(0)}, |
| 65 | + {"lhs": np.float32(4), "rhs": np.float32(0)}, |
| 66 | + ) |
| 67 | + def test__safe_div__zero_div__numpy_inputs(self, lhs, rhs): |
| 68 | + out = safe_div(lhs, rhs) |
| 69 | + assert isinstance(out, (np.number, np.ndarray)) |
| 70 | + assert out.dtype == lhs.dtype |
| 71 | + npt.assert_equal(out, 0) |
| 72 | + |
| 73 | + @parameterized.parameters( |
| 74 | + {"lhs": np.float16(0), "rhs": jnp.float16(0)}, |
| 75 | + {"lhs": jnp.float32(0), "rhs": np.float32(0)}, |
| 76 | + {"lhs": jnp.float16(2), "rhs": np.float16(0)}, |
| 77 | + {"lhs": np.float32(4), "rhs": jnp.float32(0)}, |
| 78 | + ) |
| 79 | + def test__safe_div__zero_div__jax_inputs(self, lhs, rhs): |
| 80 | + out = safe_div(lhs, rhs) |
| 81 | + assert isinstance(out, Array) |
| 82 | + assert out.dtype == lhs.dtype |
| 83 | + npt.assert_almost_equal(out, 0) |
| 84 | + |
| 85 | + @parameterized.parameters( |
| 86 | + {"val": np.float16(0)}, |
| 87 | + {"val": jnp.float16(0)}, |
| 88 | + ) |
| 89 | + def test__safe_reciprocal__zero_div(self, val): |
| 90 | + out = safe_reciprocal(val) |
| 91 | + assert out.dtype == val.dtype |
| 92 | + npt.assert_almost_equal(out, 0) |
0 commit comments