Skip to content

Commit 9722dfc

Browse files
committed
Remove predefined inplace Elemwise Ops and redundant tests
1 parent a5fb911 commit 9722dfc

File tree

10 files changed

+161
-1205
lines changed

10 files changed

+161
-1205
lines changed

.github/workflows/test.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,13 @@ jobs:
8484
install-mlx: [0]
8585
install-xarray: [0]
8686
part:
87-
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor"
87+
- "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor"
8888
- "tests/scan"
89-
- "tests/tensor --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/conv --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py"
90-
- "tests/tensor/rewriting"
89+
- "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting"
90+
- "tests/tensor/test_basic.py tests/tensor/test_elemwise.py"
9191
- "tests/tensor/test_math.py"
92-
- "tests/tensor/test_basic.py tests/tensor/test_inplace.py tests/tensor/conv"
93-
- "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
92+
- "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/conv"
93+
- "tests/tensor/rewriting"
9494
exclude:
9595
- python-version: "3.11"
9696
fast-compile: 1
@@ -167,7 +167,7 @@ jobs:
167167
install-numba: 0
168168
install-jax: 0
169169
install-torch: 0
170-
part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
170+
part: "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"
171171

172172
steps:
173173
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

pytensor/tensor/elemwise.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pytensor.printing import Printer, pprint
2121
from pytensor.scalar import get_scalar_type
2222
from pytensor.scalar.basic import identity as scalar_identity
23-
from pytensor.scalar.basic import int64, transfer_type, upcast
23+
from pytensor.scalar.basic import int64, upcast
2424
from pytensor.tensor import elemwise_cgen as cgen
2525
from pytensor.tensor import get_vector_length
2626
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
@@ -1634,17 +1634,12 @@ def construct(symbol):
16341634
symbolname = symbolname or symbol.__name__
16351635

16361636
if symbolname.endswith("_inplace"):
1637-
base_symbol_name = symbolname[: -len("_inplace")]
1638-
scalar_op = getattr(scalar, base_symbol_name)
1639-
inplace_scalar_op = scalar_op.__class__(transfer_type(0))
1640-
rval = Elemwise(
1641-
inplace_scalar_op,
1642-
{0: 0},
1643-
nfunc_spec=(nfunc and (nfunc, nin, nout)),
1637+
raise ValueError(
1638+
"Creation of automatic inplace elemwise operations deprecated"
16441639
)
1645-
else:
1646-
scalar_op = getattr(scalar, symbolname)
1647-
rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))
1640+
1641+
scalar_op = getattr(scalar, symbolname)
1642+
rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))
16481643

16491644
if getattr(symbol, "__doc__"):
16501645
rval.__doc__ = symbol.__doc__

0 commit comments

Comments
 (0)