From f003c39a89d511ee2a942bb742eadfe364f12ff3 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 27 Nov 2025 13:28:53 +0100
Subject: [PATCH 1/2] Remove predefined inplace Elemwise Ops and redundant
 tests

---
 .github/workflows/test.yml          |  12 +-
 pytensor/tensor/elemwise.py         |  17 +-
 pytensor/tensor/inplace.py          | 427 ------------------------
 tests/link/numba/test_elemwise.py   |   8 +-
 tests/tensor/rewriting/test_math.py |  79 +++--
 tests/tensor/test_blas.py           |  10 +-
 tests/tensor/test_elemwise.py       |  93 +++++-
 tests/tensor/test_inplace.py        | 465 ----------------------------
 tests/tensor/test_math_scipy.py     | 232 +-------------
 tests/tensor/utils.py               |  35 ++-
 10 files changed, 172 insertions(+), 1206 deletions(-)
 delete mode 100644 pytensor/tensor/inplace.py
 delete mode 100644 tests/tensor/test_inplace.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 2c44c4d44f..a7ec2f6e64 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -84,13 +84,13 @@ jobs:
         install-mlx: [0]
         install-xarray: [0]
         part:
-          - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor"
+          - "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor"
           - "tests/scan"
-          - "tests/tensor --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/conv --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py"
-          - "tests/tensor/rewriting"
+          - "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting"
+          - "tests/tensor/test_basic.py tests/tensor/test_elemwise.py"
           - "tests/tensor/test_math.py"
-          - "tests/tensor/test_basic.py tests/tensor/test_inplace.py tests/tensor/conv"
-          - "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
+          - "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/conv"
+          - "tests/tensor/rewriting"
         exclude:
           - python-version: "3.11"
             fast-compile: 1
@@ -167,7 +167,7 @@ jobs:
           install-numba: 0
           install-jax: 0
           install-torch: 0
-          part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
+          part: "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"

     steps:
       - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 4388c110c8..f1d8bc09df 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -20,7 +20,7 @@
 from pytensor.printing import Printer, pprint
 from pytensor.scalar import get_scalar_type
 from pytensor.scalar.basic import identity as scalar_identity
-from pytensor.scalar.basic import int64, transfer_type, upcast
+from pytensor.scalar.basic import int64, upcast
 from pytensor.tensor import elemwise_cgen as cgen
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
@@ -1634,17 +1634,12 @@ def construct(symbol):
     symbolname = symbolname or symbol.__name__
     if symbolname.endswith("_inplace"):
-        base_symbol_name = symbolname[: -len("_inplace")]
-        scalar_op = getattr(scalar, base_symbol_name)
-        inplace_scalar_op = scalar_op.__class__(transfer_type(0))
-        rval = Elemwise(
-            inplace_scalar_op,
-            {0: 0},
-
nfunc_spec=(nfunc and (nfunc, nin, nout)), + raise ValueError( + "Creation of automatic inplace elemwise operations deprecated" ) - else: - scalar_op = getattr(scalar, symbolname) - rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout))) + + scalar_op = getattr(scalar, symbolname) + rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout))) if getattr(symbol, "__doc__"): rval.__doc__ = symbol.__doc__ diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py deleted file mode 100644 index 8c0df0e2e0..0000000000 --- a/pytensor/tensor/inplace.py +++ /dev/null @@ -1,427 +0,0 @@ -from pytensor import printing -from pytensor.printing import pprint -from pytensor.tensor.elemwise import scalar_elemwise - - -@scalar_elemwise -def lt_inplace(a, b): - """a < b (inplace on a)""" - - -@scalar_elemwise -def gt_inplace(a, b): - """a > b (inplace on a)""" - - -@scalar_elemwise -def le_inplace(a, b): - """a <= b (inplace on a)""" - - -@scalar_elemwise -def ge_inplace(a, b): - """a >= b (inplace on a)""" - - -@scalar_elemwise -def eq_inplace(a, b): - """a == b (inplace on a)""" - - -@scalar_elemwise -def neq_inplace(a, b): - """a != b (inplace on a)""" - - -@scalar_elemwise -def and__inplace(a, b): - """bitwise a & b (inplace on a)""" - - -@scalar_elemwise -def or__inplace(a, b): - """bitwise a | b (inplace on a)""" - - -@scalar_elemwise -def xor_inplace(a, b): - """bitwise a ^ b (inplace on a)""" - - -@scalar_elemwise -def invert_inplace(a): - """bitwise ~a (inplace on a)""" - - -@scalar_elemwise -def abs_inplace(a): - """|`a`| (inplace on `a`)""" - - -@scalar_elemwise -def exp_inplace(a): - """e^`a` (inplace on `a`)""" - - -@scalar_elemwise -def exp2_inplace(a): - """2^`a` (inplace on `a`)""" - - -@scalar_elemwise -def expm1_inplace(a): - """e^`a` - 1 (inplace on `a`)""" - - -@scalar_elemwise -def neg_inplace(a): - """-a (inplace on a)""" - - -@scalar_elemwise -def reciprocal_inplace(a): - """1.0/a (inplace on a)""" - - -@scalar_elemwise -def log_inplace(a): - """base e logarithm of a (inplace on a)""" - - -@scalar_elemwise -def log1p_inplace(a): - """log(1+a)""" - - -@scalar_elemwise -def log2_inplace(a): - """base 2 logarithm of a (inplace on a)""" - - -@scalar_elemwise -def log10_inplace(a): - """base 10 logarithm of a (inplace on a)""" - - -@scalar_elemwise -def sign_inplace(a): - """sign of `a` (inplace on `a`)""" - - -@scalar_elemwise -def ceil_inplace(a): - """ceil of `a` (inplace on `a`)""" - - -@scalar_elemwise -def floor_inplace(a): - """floor of `a` (inplace on `a`)""" - - -@scalar_elemwise -def trunc_inplace(a): - """trunc of `a` (inplace on `a`)""" - - -@scalar_elemwise -def round_half_to_even_inplace(a): - """round_half_to_even_inplace(a) (inplace on `a`)""" - - -@scalar_elemwise -def round_half_away_from_zero_inplace(a): - """round_half_away_from_zero_inplace(a) (inplace on `a`)""" - - -@scalar_elemwise -def sqr_inplace(a): - """square of `a` (inplace on `a`)""" - - -@scalar_elemwise -def sqrt_inplace(a): - """square root of `a` (inplace on `a`)""" - - -@scalar_elemwise -def deg2rad_inplace(a): - """convert degree `a` to radian(inplace on `a`)""" - - -@scalar_elemwise -def rad2deg_inplace(a): - """convert radian `a` to degree(inplace on `a`)""" - - -@scalar_elemwise -def cos_inplace(a): - """cosine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arccos_inplace(a): - """arccosine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def sin_inplace(a): - """sine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arcsin_inplace(a): - """arcsine of `a` 
(inplace on `a`)""" - - -@scalar_elemwise -def tan_inplace(a): - """tangent of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arctan_inplace(a): - """arctangent of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arctan2_inplace(a, b): - """arctangent of `a` / `b` (inplace on `a`)""" - - -@scalar_elemwise -def cosh_inplace(a): - """hyperbolic cosine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arccosh_inplace(a): - """hyperbolic arc cosine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def sinh_inplace(a): - """hyperbolic sine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arcsinh_inplace(a): - """hyperbolic arc sine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def tanh_inplace(a): - """hyperbolic tangent of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arctanh_inplace(a): - """hyperbolic arc tangent of `a` (inplace on `a`)""" - - -@scalar_elemwise -def erf_inplace(a): - """error function""" - - -@scalar_elemwise -def erfc_inplace(a): - """complementary error function""" - - -@scalar_elemwise -def erfcx_inplace(a): - """scaled complementary error function""" - - -@scalar_elemwise -def owens_t_inplace(h, a): - """owens t function""" - - -@scalar_elemwise -def gamma_inplace(a): - """gamma function""" - - -@scalar_elemwise -def gammaln_inplace(a): - """log gamma function""" - - -@scalar_elemwise -def psi_inplace(a): - """derivative of log gamma function""" - - -@scalar_elemwise -def tri_gamma_inplace(a): - """second derivative of the log gamma function""" - - -@scalar_elemwise -def gammainc_inplace(k, x): - """regularized lower gamma function (P)""" - - -@scalar_elemwise -def gammaincc_inplace(k, x): - """regularized upper gamma function (Q)""" - - -@scalar_elemwise -def gammau_inplace(k, x): - """upper incomplete gamma function""" - - -@scalar_elemwise -def gammal_inplace(k, x): - """lower incomplete gamma function""" - - -@scalar_elemwise -def gammaincinv_inplace(k, x): - """Inverse to the regularized lower incomplete gamma function""" - - -@scalar_elemwise -def gammainccinv_inplace(k, x): - """Inverse of the regularized upper incomplete gamma function""" - - -@scalar_elemwise -def j0_inplace(x): - """Bessel function of the first kind of order 0.""" - - -@scalar_elemwise -def j1_inplace(x): - """Bessel function of the first kind of order 1.""" - - -@scalar_elemwise -def jv_inplace(v, x): - """Bessel function of the first kind of order v (real).""" - - -@scalar_elemwise -def i0_inplace(x): - """Modified Bessel function of the first kind of order 0.""" - - -@scalar_elemwise -def i1_inplace(x): - """Modified Bessel function of the first kind of order 1.""" - - -@scalar_elemwise -def iv_inplace(v, x): - """Modified Bessel function of the first kind of order v (real).""" - - -@scalar_elemwise -def ive_inplace(v, x): - """Exponentially scaled modified Bessel function of the first kind of order v (real).""" - - -@scalar_elemwise -def sigmoid_inplace(x): - """Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit""" - - -@scalar_elemwise -def softplus_inplace(x): - """Compute log(1 + exp(x)), also known as softplus or log1pexp""" - - -@scalar_elemwise -def log1mexp_inplace(x): - """Compute log(1 - exp(x)), also known as log1mexp""" - - -@scalar_elemwise -def betainc_inplace(a, b, x): - """Regularized incomplete beta function""" - - -@scalar_elemwise -def betaincinv_inplace(a, b, x): - """Inverse of the regularized incomplete beta function""" - - -@scalar_elemwise -def second_inplace(a): - """Fill `a` with `b`""" - - -fill_inplace = 
second_inplace -pprint.assign(fill_inplace, printing.FunctionPrinter(["fill="])) - - -@scalar_elemwise -def maximum_inplace(a, b): - """elementwise addition (inplace on `a`)""" - - -@scalar_elemwise -def minimum_inplace(a, b): - """elementwise addition (inplace on `a`)""" - - -@scalar_elemwise -def add_inplace(a, b): - """elementwise addition (inplace on `a`)""" - - -@scalar_elemwise -def sub_inplace(a, b): - """elementwise subtraction (inplace on `a`)""" - - -@scalar_elemwise -def mul_inplace(a, b): - """elementwise multiplication (inplace on `a`)""" - - -@scalar_elemwise -def true_div_inplace(a, b): - """elementwise division (inplace on `a`)""" - - -@scalar_elemwise -def int_div_inplace(a, b): - """elementwise division (inplace on `a`)""" - - -@scalar_elemwise -def mod_inplace(a, b): - """elementwise modulo (inplace on `a`)""" - - -@scalar_elemwise -def pow_inplace(a, b): - """elementwise power (inplace on `a`)""" - - -@scalar_elemwise -def conj_inplace(a): - """elementwise conjugate (inplace on `a`)""" - - -@scalar_elemwise -def hyp2f1_inplace(a, b, c, z): - """gaussian hypergeometric function""" - - -pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either")) -pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either")) -pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left")) -pprint.assign(neg_inplace, printing.OperatorPrinter("-=", 0, "either")) -pprint.assign(true_div_inplace, printing.OperatorPrinter("/=", -1, "left")) -pprint.assign(int_div_inplace, printing.OperatorPrinter("//=", -1, "left")) -pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right")) - - -def transpose_inplace(x, **kwargs): - "Perform a transpose on a tensor without copying the underlying storage" - dims = list(range(x.ndim - 1, -1, -1)) - return x.dimshuffle(dims) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 954656cebe..a068335d5b 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -6,13 +6,13 @@ import pytensor import pytensor.tensor as pt -import pytensor.tensor.inplace as pti import pytensor.tensor.math as ptm from pytensor import config, function from pytensor.compile import get_mode from pytensor.compile.ops import deep_copy_op from pytensor.gradient import grad from pytensor.scalar import Composite, float64 +from pytensor.scalar import add as scalar_add from pytensor.tensor import blas, tensor from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum @@ -30,6 +30,8 @@ rng = np.random.default_rng(42849) +add_inplace = Elemwise(scalar_add, {0: 0}) + @pytest.mark.parametrize( "inputs, input_vals, output_fn", @@ -80,7 +82,7 @@ np.array(1.0, dtype=config.floatX), np.array(1.0, dtype=config.floatX), ], - lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), + lambda x, y: add_inplace(deep_copy_op(x), deep_copy_op(y)), ), ( [pt.vector(), pt.vector()], @@ -88,7 +90,7 @@ rng.standard_normal(100).astype(config.floatX), rng.standard_normal(100).astype(config.floatX), ], - lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), + lambda x, y: add_inplace(deep_copy_op(x), deep_copy_op(y)), ), ( [pt.vector(), pt.vector()], diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 303cf970d4..0380f997e3 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -31,7 +31,6 @@ from pytensor.graph.traversal import 
ancestors from pytensor.printing import debugprint from pytensor.scalar import PolyGamma, Psi, TriGamma -from pytensor.tensor import inplace from pytensor.tensor.basic import Alloc, constant, join, second, switch from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas_c import CGemv @@ -1134,15 +1133,15 @@ def test_log1p(): f = function([x], log(1 + (x)), mode=m) assert [node.op for node in f.maker.fgraph.toposort()] == [log1p] f = function([x], log(1 + (-x)), mode=m) - assert [node.op for node in f.maker.fgraph.toposort()] == [ - neg, - inplace.log1p_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [ + ps.neg, + ps.log1p, ] f = function([x], -log(1 + (-x)), mode=m) - assert [node.op for node in f.maker.fgraph.toposort()] == [ - neg, - inplace.log1p_inplace, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [ + ps.neg, + ps.log1p, + ps.neg, ] # check trickier cases (and use different dtype) @@ -4035,27 +4034,27 @@ def test_exp_over_1_plus_exp(self): # todo: solve issue #4589 first # assert check_stack_trace( # f, ops_to_check=[sigmoid, neg_inplace]) - assert [node.op for node in f.maker.fgraph.toposort()] == [ - sigmoid, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [ + ps.sigmoid, + ps.neg, ] f(data) f = pytensor.function([x], pt.fill(x, -1.0) / (1 - exp(-x)), mode=m) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.neg, ] f(data) f = pytensor.function([x], pt.fill(x, -1.0) / (2 + exp(-x)), mode=m) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.neg, ] f(data) f = pytensor.function([x], pt.fill(x, -1.1) / (1 + exp(-x)), mode=m) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.neg, ] f(data) @@ -4077,10 +4076,10 @@ def test_exp_over_1_plus_exp(self): (pt.fill(x, -1.1) * exp(x)) / ((1 + exp(x)) * (1 + exp(-x))), mode=m, ) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - mul, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.mul, + ps.neg, ] f(data) f = pytensor.function( @@ -4088,10 +4087,10 @@ def test_exp_over_1_plus_exp(self): (pt.fill(x, -1.0) * exp(x)) / ((2 + exp(x)) * (1 + exp(-x))), mode=m, ) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - mul, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.mul, + ps.neg, ] f(data) f = pytensor.function( @@ -4099,10 +4098,10 @@ def test_exp_over_1_plus_exp(self): (pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (2 + exp(-x))), mode=m, ) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - mul, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.mul, + ps.neg, ] f(data) f = pytensor.function( @@ -4110,10 +4109,10 @@ def test_exp_over_1_plus_exp(self): (pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (1 + exp(x))), mode=m, ) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - mul, - inplace.neg_inplace, + assert [node.op.scalar_op 
for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.mul, + ps.neg, ] f(data) f = pytensor.function( @@ -4121,10 +4120,10 @@ def test_exp_over_1_plus_exp(self): (pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (2 + exp(-x))), mode=m, ) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - mul, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.mul, + ps.neg, ] f(data) diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py index 6d1e843a9e..60592d1b31 100644 --- a/tests/tensor/test_blas.py +++ b/tests/tensor/test_blas.py @@ -17,7 +17,6 @@ from pytensor.gradient import grad from pytensor.graph.rewriting.basic import in2out from pytensor.graph.utils import InconsistencyError -from pytensor.tensor import inplace from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.blas import ( BatchedDot, @@ -40,6 +39,7 @@ ger, ger_destructive, ) +from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import Dot, dot, mean, mul, outer, sigmoid from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger from pytensor.tensor.type import ( @@ -258,16 +258,20 @@ def test_destroy_map1(self): rng = np.random.default_rng(seed=utt.fetch_seed()) Z = as_tensor_variable(rng.random((2, 2))) A = as_tensor_variable(rng.random((2, 2))) + Zt = Z.transpose() + assert isinstance(Zt.owner.op, DimShuffle) and Zt.owner.op.view_map == {0: [0]} with pytest.raises(InconsistencyError, match=Gemm.E_z_uniq): - gemm_inplace(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0) + gemm_inplace(Z, 1.0, A, Zt, 1.0) def test_destroy_map2(self): # test that only first input can be overwritten. rng = np.random.default_rng(seed=utt.fetch_seed()) Z = as_tensor_variable(rng.random((2, 2))) A = as_tensor_variable(rng.random((2, 2))) + Zt = Z.transpose() + assert isinstance(Zt.owner.op, DimShuffle) and Zt.owner.op.view_map == {0: [0]} with pytest.raises(InconsistencyError, match=Gemm.E_z_uniq): - gemm_inplace(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0) + gemm_inplace(Z, 1.0, Zt, A, 1.0) def test_destroy_map3(self): # test that only first input can be overwritten diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 5d20bf837b..c7fd040cfb 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -20,6 +20,9 @@ from pytensor.link.basic import PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.scalar import ScalarOp, float32, float64, int32, int64 +from pytensor.scalar import add as scalar_add +from pytensor.scalar import exp as scalar_exp +from pytensor.scalar import xor as scalar_xor from pytensor.tensor import as_tensor_variable from pytensor.tensor.basic import get_scalar_constant_value, second from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -43,6 +46,16 @@ ) from tests import unittest_tools from tests.link.test_link import make_function +from tests.tensor.utils import ( + _bad_runtime_broadcast_binary_normal, + inplace_func, + integers, + integers_uint16, + integers_uint32, + makeBroadcastTester, + random, + random_complex, +) def reduce_bitwise_and(x, axis=-1, dtype="int8"): @@ -334,7 +347,7 @@ def with_linker_inplace(self, linker, op, type, rand_val): x = x_type("x") y = y_type("y") - e = op(ps.Add(ps.transfer_type(0)), {0: 0})(x, y) + e = op(ps.add, {0: 0})(x, y) f = make_function(copy(linker).accept(FunctionGraph([x, y], [e]))) xv = rand_val(xsh) yv = rand_val(ysh) @@ 
-348,7 +361,7 @@ def with_linker_inplace(self, linker, op, type, rand_val): if isinstance(linker, PerformLinker): x = x_type("x") y = y_type("y") - e = op(ps.Add(ps.transfer_type(0)), {0: 0})(x, y) + e = op(ps.add, {0: 0})(x, y) f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape]))) xv = rand_val(xsh) yv = rand_val(ysh) @@ -390,7 +403,10 @@ def test_fill(self): ): x = t(pytensor.config.floatX, shape=(None, None))("x") y = t(pytensor.config.floatX, shape=(1, 1))("y") - e = op(ps.Second(ps.transfer_type(0)), {0: 0})(x, y) + op1 = op(ps.second, {0: 0}) + op2 = op(ps.second, {0: 0}) + assert op1 == op2 + e = op(ps.Second(), {0: 0})(x, y) f = make_function(linker().accept(FunctionGraph([x, y], [e]))) xv = rval((5, 5)) yv = rval((1, 1)) @@ -1113,3 +1129,74 @@ def test_numpy_warning_suppressed(): y = pt.log(x) fn = pytensor.function([x], y, mode=Mode(linker="py")) assert fn(0) == -np.inf + + +rng = np.random.default_rng(18) +_good_add_inplace = dict( + same_shapes=(random(2, 3, rng=rng), random(2, 3, rng=rng)), + not_same_dimensions=(random(2, 2, rng=rng), random(2, rng=rng)), + scalar=(random(2, 3, rng=rng), random(1, 1, rng=rng)), + row=(random(2, 3, rng=rng), random(1, 3, rng=rng)), + column=(random(2, 3, rng=rng), random(2, 1, rng=rng)), + integers=(integers(2, 3, rng=rng), integers(2, 3, rng=rng)), + uint32=(integers_uint32(2, 3, rng=rng), integers_uint32(2, 3, rng=rng)), + uint16=(integers_uint16(2, 3, rng=rng), integers_uint16(2, 3, rng=rng)), + # (float32, >int16) upcasts to float64 by default + dtype_valid_mixup=( + random(2, 3, rng=rng), + integers(2, 3, rng=rng).astype( + "int16" if config.floatX == "float32" else "int64" + ), + ), + complex1=(random_complex(2, 3, rng=rng), random_complex(2, 3, rng=rng)), + complex2=(random_complex(2, 3, rng=rng), random(2, 3, rng=rng)), + empty=(np.asarray([], dtype=config.floatX), np.asarray([1], dtype=config.floatX)), +) +TestAddInplaceBroadcast = makeBroadcastTester( + op=Elemwise(scalar_add, {0: 0}), + expected=lambda x, y: x + y, + good=_good_add_inplace, + # Cannot inplace on first input if it doesn't match output dtype (upcast of inputs) + bad_build=dict(dtype_invalid_mixup=_good_add_inplace["dtype_valid_mixup"][::-1]), + bad_runtime=_bad_runtime_broadcast_binary_normal, + inplace=True, +) + + +@pytest.mark.xfail( + config.cycle_detection == "fast" and config.mode != "FAST_COMPILE", + reason="Cycle detection is fast and mode is FAST_COMPILE", +) +def test_exp_inplace_grad_1(): + utt.verify_grad( + Elemwise(scalar_exp, {0: 0}), + [ + np.asarray( + [ + [1.5089518, 1.48439076, -4.7820262], + [2.04832468, 0.50791564, -1.58892269], + ] + ) + ], + ) + + +def test_XOR_inplace(): + dtype = [ + "int8", + "int16", + "int32", + "int64", + ] + xor_inplace = Elemwise(scalar_xor, {0: 0}) + + for dtype in dtype: + x, y = vector(dtype=dtype), vector(dtype=dtype) + l = np.asarray([0, 0, 1, 1], dtype=dtype) + r = np.asarray([0, 1, 0, 1], dtype=dtype) + ix = x + ix = xor_inplace(ix, y) + gn = inplace_func([x, y], ix) + _ = gn(l, r) + # test the in-place stuff + assert np.all(l == np.asarray([0, 1, 1, 0])), l diff --git a/tests/tensor/test_inplace.py b/tests/tensor/test_inplace.py deleted file mode 100644 index a31a26df07..0000000000 --- a/tests/tensor/test_inplace.py +++ /dev/null @@ -1,465 +0,0 @@ -import numpy as np -import pytest - -from pytensor import config -from pytensor.scalar.basic import round_half_away_from_zero_vec, upcast -from pytensor.tensor.inplace import ( - abs_inplace, - add_inplace, - arccos_inplace, - arccosh_inplace, - 
arcsin_inplace, - arcsinh_inplace, - arctan2_inplace, - arctan_inplace, - arctanh_inplace, - ceil_inplace, - conj_inplace, - cos_inplace, - cosh_inplace, - deg2rad_inplace, - exp2_inplace, - exp_inplace, - expm1_inplace, - floor_inplace, - int_div_inplace, - log1p_inplace, - log2_inplace, - log10_inplace, - log_inplace, - maximum_inplace, - minimum_inplace, - mod_inplace, - mul_inplace, - neg_inplace, - pow_inplace, - rad2deg_inplace, - reciprocal_inplace, - round_half_away_from_zero_inplace, - round_half_to_even_inplace, - sign_inplace, - sin_inplace, - sinh_inplace, - sqr_inplace, - sqrt_inplace, - sub_inplace, - tan_inplace, - tanh_inplace, - true_div_inplace, - trunc_inplace, - xor_inplace, -) -from pytensor.tensor.type import vector -from tests import unittest_tools as utt -from tests.tensor.utils import ( - _bad_build_broadcast_binary_normal, - _bad_runtime_broadcast_binary_normal, - _bad_runtime_reciprocal, - _good_broadcast_binary_arctan2, - _good_broadcast_binary_normal, - _good_broadcast_div_mod_normal_float_inplace, - _good_broadcast_pow_normal_float_pow, - _good_broadcast_unary_arccosh, - _good_broadcast_unary_arcsin_float, - _good_broadcast_unary_arctanh, - _good_broadcast_unary_normal, - _good_broadcast_unary_normal_abs, - _good_broadcast_unary_normal_float, - _good_broadcast_unary_normal_float_no_complex, - _good_broadcast_unary_normal_float_no_empty_no_complex, - _good_broadcast_unary_normal_no_complex, - _good_broadcast_unary_positive_float, - _good_broadcast_unary_tan, - _good_broadcast_unary_wide_float, - _good_reciprocal_inplace, - _numpy_true_div, - angle_eps, - check_floatX, - copymod, - div_grad_rtol, - ignore_isfinite_mode, - inplace_func, - makeBroadcastTester, - upcast_float16_ufunc, -) - - -TestAddInplaceBroadcast = makeBroadcastTester( - op=add_inplace, - expected=lambda x, y: x + y, - good=_good_broadcast_binary_normal, - bad_build=_bad_build_broadcast_binary_normal, - bad_runtime=_bad_runtime_broadcast_binary_normal, - inplace=True, -) - -TestSubInplaceBroadcast = makeBroadcastTester( - op=sub_inplace, - expected=lambda x, y: x - y, - good=_good_broadcast_binary_normal, - bad_build=_bad_build_broadcast_binary_normal, - bad_runtime=_bad_runtime_broadcast_binary_normal, - inplace=True, -) - -TestMaximumInplaceBroadcast = makeBroadcastTester( - op=maximum_inplace, - expected=np.maximum, - good=_good_broadcast_binary_normal, - bad_build=_bad_build_broadcast_binary_normal, - bad_runtime=_bad_runtime_broadcast_binary_normal, - inplace=True, -) - -TestMinimumInplaceBroadcast = makeBroadcastTester( - op=minimum_inplace, - expected=np.minimum, - good=_good_broadcast_binary_normal, - bad_build=_bad_build_broadcast_binary_normal, - bad_runtime=_bad_runtime_broadcast_binary_normal, - inplace=True, -) - -TestMulInplaceBroadcast = makeBroadcastTester( - op=mul_inplace, - expected=lambda x, y: x * y, - good=_good_broadcast_binary_normal, - bad_build=_bad_build_broadcast_binary_normal, - bad_runtime=_bad_runtime_broadcast_binary_normal, - inplace=True, -) - -TestTrueDivInplaceBroadcast = makeBroadcastTester( - op=true_div_inplace, - expected=_numpy_true_div, - good=copymod( - _good_broadcast_div_mod_normal_float_inplace, - # The output is now in float, we cannot work inplace on an int. 
- without=["integer", "uint8", "uint16", "int8"], - ), - grad_rtol=div_grad_rtol, - inplace=True, -) - -TestReciprocalInplaceBroadcast = makeBroadcastTester( - op=reciprocal_inplace, - expected=lambda x: _numpy_true_div(np.int8(1), x), - good=_good_reciprocal_inplace, - bad_runtime=_bad_runtime_reciprocal, - grad_rtol=div_grad_rtol, - inplace=True, -) - -TestModInplaceBroadcast = makeBroadcastTester( - op=mod_inplace, - expected=lambda x, y: np.asarray(x % y, dtype=upcast(x.dtype, y.dtype)), - good=copymod( - _good_broadcast_div_mod_normal_float_inplace, ["complex1", "complex2"] - ), - grad_eps=1e-5, - inplace=True, -) - -TestPowInplaceBroadcast = makeBroadcastTester( - op=pow_inplace, - expected=lambda x, y: x**y, - good=_good_broadcast_pow_normal_float_pow, - inplace=True, - mode=ignore_isfinite_mode, -) - -TestNegInplaceBroadcast = makeBroadcastTester( - op=neg_inplace, - expected=lambda x: -x, - good=_good_broadcast_unary_normal, - inplace=True, -) - -TestSgnInplaceBroadcast = makeBroadcastTester( - op=sign_inplace, - expected=np.sign, - good=_good_broadcast_unary_normal_no_complex, - inplace=True, -) - -TestAbsInplaceBroadcast = makeBroadcastTester( - op=abs_inplace, - expected=lambda x: np.abs(x), - good=_good_broadcast_unary_normal_abs, - inplace=True, -) - -TestIntDivInplaceBroadcast = makeBroadcastTester( - op=int_div_inplace, - expected=lambda x, y: check_floatX((x, y), x // y), - good=_good_broadcast_div_mod_normal_float_inplace, - # I don't test the grad as the output is always an integer - # (this is not a continuous output). - # grad=_grad_broadcast_div_mod_normal, - inplace=True, -) - -TestCeilInplaceBroadcast = makeBroadcastTester( - op=ceil_inplace, - expected=upcast_float16_ufunc(np.ceil), - good=copymod( - _good_broadcast_unary_normal_no_complex, - without=["integers", "int8", "uint8", "uint16"], - ), - # corner cases includes a lot of integers: points where Ceil is not - # continuous (not differentiable) - inplace=True, -) - -TestFloorInplaceBroadcast = makeBroadcastTester( - op=floor_inplace, - expected=upcast_float16_ufunc(np.floor), - good=copymod( - _good_broadcast_unary_normal_no_complex, - without=["integers", "int8", "uint8", "uint16"], - ), - inplace=True, -) - -TestTruncInplaceBroadcast = makeBroadcastTester( - op=trunc_inplace, - expected=upcast_float16_ufunc(np.trunc), - good=_good_broadcast_unary_normal_no_complex, - inplace=True, -) - -TestRoundHalfToEvenInplaceBroadcast = makeBroadcastTester( - op=round_half_to_even_inplace, - expected=np.round, - good=_good_broadcast_unary_normal_float_no_complex, - inplace=True, -) - -TestRoundHalfAwayFromZeroInplaceBroadcast = makeBroadcastTester( - op=round_half_away_from_zero_inplace, - expected=lambda a: round_half_away_from_zero_vec(a), - good=_good_broadcast_unary_normal_float_no_empty_no_complex, - inplace=True, -) - -TestSqrInplaceBroadcast = makeBroadcastTester( - op=sqr_inplace, - expected=np.square, - good=_good_broadcast_unary_normal, - inplace=True, -) - -TestExpInplaceBroadcast = makeBroadcastTester( - op=exp_inplace, - expected=np.exp, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestExp2InplaceBroadcast = makeBroadcastTester( - op=exp2_inplace, - expected=np.exp2, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestExpm1InplaceBroadcast = makeBroadcastTester( - op=expm1_inplace, - expected=np.expm1, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestLogInplaceBroadcast = makeBroadcastTester( - op=log_inplace, - expected=np.log, - 
good=_good_broadcast_unary_positive_float, - inplace=True, -) - -TestLog2InplaceBroadcast = makeBroadcastTester( - op=log2_inplace, - expected=np.log2, - good=_good_broadcast_unary_positive_float, - inplace=True, -) - -TestLog10InplaceBroadcast = makeBroadcastTester( - op=log10_inplace, - expected=np.log10, - good=_good_broadcast_unary_positive_float, - inplace=True, -) - -TestLog1pInplaceBroadcast = makeBroadcastTester( - op=log1p_inplace, - expected=np.log1p, - good=_good_broadcast_unary_positive_float, - inplace=True, -) - -TestSqrtInplaceBroadcast = makeBroadcastTester( - op=sqrt_inplace, - expected=np.sqrt, - good=_good_broadcast_unary_positive_float, - inplace=True, -) - -TestDeg2radInplaceBroadcast = makeBroadcastTester( - op=deg2rad_inplace, - expected=np.deg2rad, - good=_good_broadcast_unary_normal_float_no_complex, - inplace=True, - eps=angle_eps, -) - -TestRad2degInplaceBroadcast = makeBroadcastTester( - op=rad2deg_inplace, - expected=np.rad2deg, - good=_good_broadcast_unary_normal_float_no_complex, - inplace=True, - eps=angle_eps, -) - -TestSinInplaceBroadcast = makeBroadcastTester( - op=sin_inplace, - expected=np.sin, - good=_good_broadcast_unary_wide_float, - inplace=True, -) - -TestArcsinInplaceBroadcast = makeBroadcastTester( - op=arcsin_inplace, - expected=np.arcsin, - good=_good_broadcast_unary_arcsin_float, - inplace=True, -) - -TestCosInplaceBroadcast = makeBroadcastTester( - op=cos_inplace, - expected=np.cos, - good=_good_broadcast_unary_wide_float, - inplace=True, -) - -TestArccosInplaceBroadcast = makeBroadcastTester( - op=arccos_inplace, - expected=np.arccos, - good=_good_broadcast_unary_arcsin_float, - inplace=True, -) - -TestTanInplaceBroadcast = makeBroadcastTester( - op=tan_inplace, - expected=np.tan, - good=copymod( - _good_broadcast_unary_tan, without=["integers", "int8", "uint8", "uint16"] - ), - inplace=True, -) - -TestArctanInplaceBroadcast = makeBroadcastTester( - op=arctan_inplace, - expected=np.arctan, - good=_good_broadcast_unary_wide_float, - inplace=True, -) - -TestArctan2InplaceBroadcast = makeBroadcastTester( - op=arctan2_inplace, - expected=np.arctan2, - good=copymod( - _good_broadcast_binary_arctan2, - without=["integers", "int8", "uint8", "uint16", "dtype_mixup_2"], - ), - inplace=True, -) - -TestCoshInplaceBroadcast = makeBroadcastTester( - op=cosh_inplace, - expected=np.cosh, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestArccoshInplaceBroadcast = makeBroadcastTester( - op=arccosh_inplace, - expected=np.arccosh, - good=copymod(_good_broadcast_unary_arccosh, without=["integers", "uint8"]), - inplace=True, -) - -TestSinhInplaceBroadcast = makeBroadcastTester( - op=sinh_inplace, - expected=np.sinh, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestArcsinhInplaceBroadcast = makeBroadcastTester( - op=arcsinh_inplace, - expected=np.arcsinh, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestTanhInplaceBroadcast = makeBroadcastTester( - op=tanh_inplace, - expected=np.tanh, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestArctanhInplaceBroadcast = makeBroadcastTester( - op=arctanh_inplace, - expected=np.arctanh, - good=copymod( - _good_broadcast_unary_arctanh, without=["integers", "int8", "uint8", "uint16"] - ), - inplace=True, -) - -TestConjInplaceBroadcast = makeBroadcastTester( - op=conj_inplace, - expected=np.conj, - good=_good_broadcast_unary_normal, - inplace=True, -) - - -@pytest.mark.xfail( - config.cycle_detection == "fast" and config.mode != "FAST_COMPILE", - 
reason="Cycle detection is fast and mode is FAST_COMPILE", -) -def test_exp_inplace_grad_1(): - utt.verify_grad( - exp_inplace, - [ - np.asarray( - [ - [1.5089518, 1.48439076, -4.7820262], - [2.04832468, 0.50791564, -1.58892269], - ] - ) - ], - ) - - -def test_XOR_inplace(): - dtype = [ - "int8", - "int16", - "int32", - "int64", - ] - - for dtype in dtype: - x, y = vector(dtype=dtype), vector(dtype=dtype) - l = np.asarray([0, 0, 1, 1], dtype=dtype) - r = np.asarray([0, 1, 0, 1], dtype=dtype) - ix = x - ix = xor_inplace(ix, y) - gn = inplace_func([x, y], ix) - _ = gn(l, r) - # test the in-place stuff - assert np.all(l == np.asarray([0, 1, 1, 0])), l diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index fbfa5fb77e..d4c2e3463f 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -12,13 +12,12 @@ from pytensor.configdefaults import config from pytensor.gradient import NullTypeGradError, verify_grad from pytensor.scalar import ScalarLoop -from pytensor.tensor import gammaincc, inplace, kn, kv, kve, vector +from pytensor.tensor import gammaincc, kn, kv, kve, vector from pytensor.tensor.elemwise import Elemwise from tests import unittest_tools as utt from tests.tensor.utils import ( _good_broadcast_unary_chi2sf, _good_broadcast_unary_normal, - _good_broadcast_unary_normal_float, _good_broadcast_unary_normal_float_no_complex, _good_broadcast_unary_normal_float_no_complex_small_neg_range, _good_broadcast_unary_normal_no_complex, @@ -85,14 +84,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestErfInplaceBroadcast = makeBroadcastTester( - op=inplace.erf_inplace, - expected=expected_erf, - good=_good_broadcast_unary_normal_float, - mode=mode_no_scipy, - eps=2e-10, - inplace=True, -) TestErfcBroadcast = makeBroadcastTester( op=pt.erfc, @@ -102,14 +93,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestErfcInplaceBroadcast = makeBroadcastTester( - op=inplace.erfc_inplace, - expected=expected_erfc, - good=_good_broadcast_unary_normal_float_no_complex, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) TestErfcxBroadcast = makeBroadcastTester( op=pt.erfcx, @@ -119,14 +102,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestErfcxInplaceBroadcast = makeBroadcastTester( - op=inplace.erfcx_inplace, - expected=expected_erfcx, - good=_good_broadcast_unary_normal_float_no_complex_small_neg_range, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) TestErfinvBroadcast = makeBroadcastTester( op=pt.erfinv, @@ -192,14 +167,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestOwensTInplaceBroadcast = makeBroadcastTester( - op=inplace.owens_t_inplace, - expected=expected_owenst, - good=_good_broadcast_binary_owenst, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_unary_gammaln = dict( @@ -223,14 +190,6 @@ def scipy_special_gammal(k, x): mode=mode_no_scipy, eps=1e-5, ) -TestGammaInplaceBroadcast = makeBroadcastTester( - op=inplace.gamma_inplace, - expected=expected_gamma, - good=_good_broadcast_unary_gammaln, - mode=mode_no_scipy, - eps=1e-5, - inplace=True, -) TestGammalnBroadcast = makeBroadcastTester( op=pt.gammaln, @@ -240,14 +199,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestGammalnInplaceBroadcast = makeBroadcastTester( - op=inplace.gammaln_inplace, - expected=expected_gammaln, - good=_good_broadcast_unary_gammaln, - eps=2e-10, - mode=mode_no_scipy, - 
inplace=True, -) rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_unary_psi = dict( @@ -265,14 +216,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestPsiInplaceBroadcast = makeBroadcastTester( - op=inplace.psi_inplace, - expected=expected_psi, - good=_good_broadcast_unary_psi, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) _good_broadcast_unary_tri_gamma = _good_broadcast_unary_psi @@ -283,14 +226,6 @@ def scipy_special_gammal(k, x): eps=2e-8, mode=mode_no_scipy, ) -TestTriGammaInplaceBroadcast = makeBroadcastTester( - op=inplace.tri_gamma_inplace, - expected=expected_tri_gamma, - good=_good_broadcast_unary_tri_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) TestChi2SFBroadcast = makeBroadcastTester( op=pt.chi2sf, @@ -343,15 +278,6 @@ def scipy_special_gammal(k, x): mode=mode_no_scipy, ) -TestGammaIncInplaceBroadcast = makeBroadcastTester( - op=inplace.gammainc_inplace, - expected=expected_gammainc, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - TestGammaInccBroadcast = makeBroadcastTester( op=pt.gammaincc, expected=expected_gammaincc, @@ -361,15 +287,6 @@ def scipy_special_gammal(k, x): mode=mode_no_scipy, ) -TestGammaInccInplaceBroadcast = makeBroadcastTester( - op=inplace.gammaincc_inplace, - expected=expected_gammaincc, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - def test_gammainc_ddk_tabulated_values(): # This test replicates part of the old STAN test: @@ -447,15 +364,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestGammaUInplaceBroadcast = makeBroadcastTester( - op=inplace.gammau_inplace, - expected=expected_gammau, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - TestGammaLBroadcast = makeBroadcastTester( op=pt.gammal, expected=expected_gammal, @@ -464,15 +372,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestGammaLInplaceBroadcast = makeBroadcastTester( - op=inplace.gammal_inplace, - expected=expected_gammal, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_binary_gamma = dict( normal=( @@ -490,15 +389,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestGammaIncInvInplaceBroadcast = makeBroadcastTester( - op=inplace.gammaincinv_inplace, - expected=expected_gammaincinv, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - TestGammaInccInvBroadcast = makeBroadcastTester( op=pt.gammainccinv, expected=expected_gammainccinv, @@ -507,15 +397,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestGammaInccInvInplaceBroadcast = makeBroadcastTester( - op=inplace.gammainccinv_inplace, - expected=expected_gammainccinv, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_unary_bessel = dict( normal=(random_ranged(-10, 10, (2, 3), rng=rng),), @@ -562,15 +443,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestJ0InplaceBroadcast = makeBroadcastTester( - op=inplace.j0_inplace, - expected=expected_j0, - good=_good_broadcast_unary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - TestJ1Broadcast = makeBroadcastTester( op=pt.j1, expected=expected_j1, @@ -580,15 +452,6 @@ def 
test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestJ1InplaceBroadcast = makeBroadcastTester( - op=inplace.j1_inplace, - expected=expected_j1, - good=_good_broadcast_unary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - TestJvBroadcast = makeBroadcastTester( op=pt.jv, expected=expected_jv, @@ -597,15 +460,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestJvInplaceBroadcast = makeBroadcastTester( - op=inplace.jv_inplace, - expected=expected_jv, - good=_good_broadcast_binary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - def test_verify_jv_grad(): # Verify Jv gradient. @@ -628,15 +482,6 @@ def fixed_first_input_jv(x): mode=mode_no_scipy, ) -TestI0InplaceBroadcast = makeBroadcastTester( - op=inplace.i0_inplace, - expected=expected_i0, - good=_good_broadcast_unary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - TestI1Broadcast = makeBroadcastTester( op=pt.i1, expected=expected_i1, @@ -646,15 +491,6 @@ def fixed_first_input_jv(x): mode=mode_no_scipy, ) -TestI1InplaceBroadcast = makeBroadcastTester( - op=inplace.i1_inplace, - expected=expected_i1, - good=_good_broadcast_unary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - TestIvBroadcast = makeBroadcastTester( op=pt.iv, expected=expected_iv, @@ -663,15 +499,6 @@ def fixed_first_input_jv(x): mode=mode_no_scipy, ) -TestIvInplaceBroadcast = makeBroadcastTester( - op=inplace.iv_inplace, - expected=expected_iv, - good=_good_broadcast_binary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - TestIveBroadcast = makeBroadcastTester( op=pt.ive, expected=expected_ive, @@ -680,15 +507,6 @@ def fixed_first_input_jv(x): mode=mode_no_scipy, ) -TestIveInplaceBroadcast = makeBroadcastTester( - op=inplace.ive_inplace, - expected=expected_ive, - good=_good_broadcast_binary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - def test_verify_iv_grad(): # Verify Iv gradient. 
@@ -721,15 +539,6 @@ def fixed_first_input_ive(x): eps=1e-8, ) -TestSigmoidInplaceBroadcast = makeBroadcastTester( - op=inplace.sigmoid_inplace, - expected=expected_sigmoid, - good=_good_broadcast_unary_normal_no_complex, - grad=_grad_broadcast_unary_normal, - eps=1e-8, - inplace=True, -) - class TestSigmoid: def test_elemwise(self): @@ -758,15 +567,6 @@ def test_elemwise(self): eps=1e-8, ) -TestSoftplusInplaceBroadcast = makeBroadcastTester( - op=inplace.softplus_inplace, - expected=expected_sofplus, - good=_good_broadcast_unary_softplus, - grad=_grad_broadcast_unary_normal, - eps=1e-8, - inplace=True, -) - class TestSoftplus: def test_elemwise(self): @@ -805,14 +605,6 @@ def expected_log1mexp(x): eps=1e-8, ) -TestLog1mexpInplaceBroadcast = makeBroadcastTester( - op=inplace.log1mexp_inplace, - expected=expected_log1mexp, - good=_good_broadcast_unary_log1mexp, - eps=1e-8, - inplace=True, -) - _good_broadcast_ternary_betainc = dict( normal=( random_ranged(0, 1000, (2, 3)), @@ -828,14 +620,6 @@ def expected_log1mexp(x): grad=_good_broadcast_ternary_betainc, ) -TestBetaincInplaceBroadcast = makeBroadcastTester( - op=inplace.betainc_inplace, - expected=special.betainc, - good=_good_broadcast_ternary_betainc, - grad=_good_broadcast_ternary_betainc, - inplace=True, -) - class TestBetaIncGrad: def test_stan_grad_partial(self): @@ -926,13 +710,6 @@ def test_beta_inc_stan_grad_combined(self): good=_good_broadcast_ternary_betaincinv, ) -TestBetaincinvInplaceBroadcast = makeBroadcastTester( - op=inplace.betaincinv_inplace, - expected=special.betaincinv, - good=_good_broadcast_ternary_betaincinv, - inplace=True, -) - _good_broadcast_quaternary_hyp2f1 = dict( normal=( random_ranged(0, 20, (2, 3)), @@ -949,13 +726,6 @@ def test_beta_inc_stan_grad_combined(self): grad=_good_broadcast_quaternary_hyp2f1, ) -TestHyp2F1InplaceBroadcast = makeBroadcastTester( - op=inplace.hyp2f1_inplace, - expected=expected_hyp2f1, - good=_good_broadcast_quaternary_hyp2f1, - inplace=True, -) - class TestHyp2F1Grad: few_iters_case = ( diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index 1a8b2455ec..8ebf25a1d9 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -672,7 +672,9 @@ def test_grad_none(self): return Checker -def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs): +def makeBroadcastTester( + op, expected, checks=None, name=None, *, inplace=False, **kwargs +): if checks is None: checks = {} if name is None: @@ -695,22 +697,20 @@ def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs): # cases we need to add it manually. 
     if not name.endswith("Tester"):
         name += "Tester"
-    if "inplace" in kwargs:
-        if kwargs["inplace"]:
-            _expected = expected
-            if not isinstance(_expected, dict):
-
-                def expected(*inputs):
-                    return np.array(_expected(*inputs), dtype=inputs[0].dtype)
-
-            def inplace_check(inputs, outputs):
-                # this used to be inputs[0] is output[0]
-                # I changed it so that it was easier to satisfy by the
-                # DebugMode
-                return np.all(inputs[0] == outputs[0])
-
-            checks = dict(checks, inplace_check=inplace_check)
-        del kwargs["inplace"]
+    if inplace:
+        _expected = expected
+        if not isinstance(_expected, dict):
+
+            def expected(*inputs):
+                return np.array(_expected(*inputs), dtype=inputs[0].dtype)
+
+        def inplace_check(inputs, outputs):
+            # this used to be inputs[0] is output[0]
+            # I changed it so that it was easier to satisfy by the
+            # DebugMode
+            return np.all(inputs[0] == outputs[0])
+
+        checks = dict(checks, inplace_check=inplace_check)

     return makeTester(name, op, expected, checks, **kwargs)

@@ -815,6 +815,7 @@ def inplace_check(inputs, outputs):
     big_scalar=[np.arange(17.0, 29.0, 0.5, dtype=config.floatX)],
 )

+# FIXME: Why is this empty?
 _bad_build_broadcast_binary_normal = dict()

 _bad_runtime_broadcast_binary_normal = dict(

From d882c8a2d4086b40d39e37469372f4b593bb43de Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 27 Nov 2025 13:26:20 +0100
Subject: [PATCH 2/2] Do not recreate Scalar Ops with custom TransferType for
 Elemwise inplacing

This helper could arbitrarily override the default output_type from
`ScalarOp.make_node` so that the output type matched one of the input
types. This could be used to create artificial Op signatures that don't
make sense or can't be cleanly implemented in other backends, for
instance an Add with signature (int8, int64) -> int8 (sketched below).

This helper was historically used in:

1. The Elemwise inplace rewrite, I assume as a preventive measure.
However, regular use should never require changing the ScalarOp
signature: we only try to inplace on inputs that match the output
dtype, and recreating the same Op with the same input types should
always return the same output type. ScalarOps don't have a concept of
inplace; only the Elemwise wrapper does, and it shouldn't require
recreating/mutating the original Op.

2. The Second Op. Here it makes sense, but a custom staticmethod works
just as well.

3. Inplace tests of the inplace variants created by the
`@scalar_elemwise` decorator. Those test classes were removed anyway;
it never made sense to force Ops into an artificial signature for the
sake of tests.
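For illustration, a minimal sketch of the pre-patch behavior being
removed here (it relies on the deleted `transfer_type` helper, so it
only runs against the old code; `add_keep_first` is a made-up name):

    import pytensor.tensor as pt
    from pytensor.scalar.basic import Add, transfer_type
    from pytensor.tensor.elemwise import Elemwise

    # transfer_type(0) forces the output dtype to match input 0,
    # bypassing the usual upcast rules (which would give int64 here)
    add_keep_first = Elemwise(Add(transfer_type(0)))
    out = add_keep_first(pt.scalar("x", dtype="int8"), pt.scalar("y", dtype="int64"))
    assert out.type.dtype == "int8"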
---
 pytensor/scalar/basic.py              | 58 +++------------------------
 pytensor/scalar/loop.py               |  3 --
 pytensor/tensor/rewriting/elemwise.py | 26 +++++-------
 tests/tensor/test_basic.py            |  1 -
 tests/tensor/test_elemwise.py         | 25 ++++++++++++
 5 files changed, 40 insertions(+), 73 deletions(-)

diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index 6adc16ec59..af0b0b7173 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -1101,30 +1101,6 @@ def same_out_float_only(type) -> tuple[ScalarType]:
     return (type,)


-class transfer_type(MetaObject):
-    __props__ = ("transfer",)
-
-    def __init__(self, *transfer):
-        assert all(isinstance(x, int | str) or x is None for x in transfer)
-        self.transfer = transfer
-
-    def __str__(self):
-        return f"transfer_type{self.transfer}"
-
-    def __call__(self, *types):
-        upcast = upcast_out(*types)
-        retval = []
-        for i in self.transfer:
-            if i is None:
-                retval += [upcast]
-            elif isinstance(i, str):
-                retval += [i]
-            else:
-                retval += [types[i]]
-        return retval
-        # return [upcast if i is None else types[i] for i in self.transfer]
-
-
 class specific_out(MetaObject):
     __props__ = ("spec",)

@@ -2446,6 +2422,10 @@ def handle_int(v):


 class Second(BinaryScalarOp):
+    @staticmethod
+    def output_types_preference(_first_type, second_type):
+        return [second_type]
+
     def impl(self, x, y):
         return y

@@ -2474,7 +2454,7 @@ def grad(self, inputs, gout):
         return DisconnectedType()(), y.zeros_like(dtype=config.floatX)


-second = Second(transfer_type(1), name="second")
+second = Second(name="second")


 class Identity(UnaryScalarOp):
@@ -2515,18 +2495,6 @@ def clone_float32(self):
             return convert_to_float32
         return self

-    def make_new_inplace(self, output_types_preference=None, name=None):
-        """
-        This op.__init__ fct don't have the same parameter as other scalar op.
-        This breaks the insert_inplace_optimizer optimization.
-        This function is a fix to patch this, by ignoring the
-        output_types_preference passed by the optimization, and replacing it
-        by the current output type. This should only be triggered when
-        both input and output have the same dtype anyway.
-
-        """
-        return self.__class__(self.o_type, name)
-
     def impl(self, input):
         return self.ctor(input)

@@ -4322,22 +4290,6 @@ def __str__(self):

         return self._name

-    def make_new_inplace(self, output_types_preference=None, name=None):
-        """
-        This op.__init__ fct don't have the same parameter as other scalar op.
-        This break the insert_inplace_optimizer optimization.
-        This fct allow fix patch this.
-
-        """
-        d = {k: getattr(self, k) for k in self.init_param}
-        out = self.__class__(**d)
-        if name:
-            out.name = name
-        else:
-            name = out.name
-        super(Composite, out).__init__(output_types_preference, name)
-        return out
-
     @property
     def fgraph(self):
         if hasattr(self, "_fgraph"):
diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py
index f23c4e1c42..e4bfc871fc 100644
--- a/pytensor/scalar/loop.py
+++ b/pytensor/scalar/loop.py
@@ -136,9 +136,6 @@ def clone(self, name=None, **kwargs):
     def fn(self):
         raise NotImplementedError

-    def make_new_inplace(self, output_types_preference=None, name=None):
-        return self.clone(output_types_preference=output_types_preference, name=name)
-
     def make_node(self, n_steps, *inputs):
         assert len(inputs) == self.nin - 1

diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index 1bcaa8624d..dc30beedf3 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -35,7 +35,6 @@
     Mul,
     ScalarOp,
     get_scalar_type,
-    transfer_type,
     upcast_out,
     upgrade_to_float,
 )
@@ -287,22 +286,17 @@ def create_inplace_node(self, node, inplace_pattern):
         op = node.op
         scalar_op = op.scalar_op
         inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
-        if hasattr(scalar_op, "make_new_inplace"):
-            new_scalar_op = scalar_op.make_new_inplace(
-                transfer_type(
-                    *[
-                        inplace_pattern.get(i, o.dtype)
-                        for i, o in enumerate(node.outputs)
-                    ]
+        try:
+            return type(op)(scalar_op, inplace_pattern).make_node(*node.inputs)
+        except TypeError:
+            # Elemwise raises TypeError if we try to inplace an output on an input of a different dtype
+            if config.optimizer_verbose:
+                print(  # noqa: T201
+                    f"InplaceElemwise failed because the output dtype of {node} changed when rebuilt. "
+                    "Perhaps due to a change in config.floatX or config.cast_policy"
                 )
-            )
-        else:
-            new_scalar_op = type(scalar_op)(
-                transfer_type(
-                    *[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
-                )
-            )
-        return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
+            # InplaceGraphOptimizer will chug along fine if we return the original node
+            return node


 optdb.register(
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index 0390fbbac8..46a5e2e4fa 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -2797,7 +2797,6 @@ def test_infer_shape(self, cast_policy):
             out = arange(start, stop, 1)
             f = function([start, stop], out.shape, mode=mode)
             assert len(f.maker.fgraph.toposort()) == 5
-            # 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
             if config.cast_policy == "custom":
                 assert out.dtype == "int64"
             elif config.cast_policy == "numpy+floatX":
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index c7fd040cfb..5a61bf8f8a 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -1200,3 +1200,28 @@ def test_XOR_inplace():
         _ = gn(l, r)
         # test the in-place stuff
         assert np.all(l == np.asarray([0, 1, 1, 0])), l
+
+
+def test_inplace_dtype_changed():
+    with pytensor.config.change_flags(cast_policy="numpy+floatX", floatX="float64"):
+        x = pt.vector("x", dtype="float32")
+        y = pt.vector("y", dtype="int32")
+        with pytensor.config.change_flags(floatX="float32"):
+            out = pt.add(x, y)
+
+    assert out.dtype == "float32"
+    with pytensor.config.change_flags(floatX="float32"):
+        fn32 = pytensor.function(
+            [In(x, mutable=True), In(y, mutable=True)],
+            out,
+            mode="fast_run",
+        )
+    assert fn32.maker.fgraph.outputs[0].owner.op.destroy_map == {0: [0]}
+
+    with pytensor.config.change_flags(floatX="float64"):
+        fn64 = pytensor.function(
+            [In(x, mutable=True), In(y, mutable=True)],
+            out,
+            mode="fast_run",
+        )
+    assert fn64.maker.fgraph.outputs[0].owner.op.destroy_map == {}
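Postscript (not part of the patch): with `pytensor.tensor.inplace`
gone, an inplace Elemwise is built directly from a scalar Op plus an
inplace pattern, mirroring what the updated tests above do. A minimal
sketch, assuming both inputs share the output dtype:

    import pytensor.tensor as pt
    from pytensor.scalar import add as scalar_add
    from pytensor.tensor.elemwise import Elemwise

    # {0: 0} declares that output 0 destructively reuses input 0
    add_inplace = Elemwise(scalar_add, {0: 0})
    out = add_inplace(pt.vector("x"), pt.vector("y"))
    assert out.owner.op.destroy_map == {0: [0]}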