Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ jobs:
install-mlx: [0]
install-xarray: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor"
- "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This removes tensor/test_inplace.py from the ignores and reorders the ignore flags / parts to be more readable. The ignores now appear in the same order as the parts that reintroduce them.

This may still change after the PR's CI runs, in order to rebalance the workload so that all parts take roughly the same time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jobs look well balanced

- "tests/scan"
- "tests/tensor --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/conv --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py"
- "tests/tensor/rewriting"
- "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting"
- "tests/tensor/test_basic.py tests/tensor/test_elemwise.py"
- "tests/tensor/test_math.py"
- "tests/tensor/test_basic.py tests/tensor/test_inplace.py tests/tensor/conv"
- "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
- "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/conv"
- "tests/tensor/rewriting"
exclude:
- python-version: "3.11"
fast-compile: 1
Expand Down Expand Up @@ -167,7 +167,7 @@ jobs:
install-numba: 0
install-jax: 0
install-torch: 0
part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
part: "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"

steps:
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
Expand Down
58 changes: 5 additions & 53 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,30 +1101,6 @@ def same_out_float_only(type) -> tuple[ScalarType]:
return (type,)


class transfer_type(MetaObject):
__props__ = ("transfer",)

def __init__(self, *transfer):
assert all(isinstance(x, int | str) or x is None for x in transfer)
self.transfer = transfer

def __str__(self):
return f"transfer_type{self.transfer}"

def __call__(self, *types):
upcast = upcast_out(*types)
retval = []
for i in self.transfer:
if i is None:
retval += [upcast]
elif isinstance(i, str):
retval += [i]
else:
retval += [types[i]]
return retval
# return [upcast if i is None else types[i] for i in self.transfer]


class specific_out(MetaObject):
__props__ = ("spec",)

Expand Down Expand Up @@ -2446,6 +2422,10 @@ def handle_int(v):


class Second(BinaryScalarOp):
@staticmethod
def output_types_preference(_first_type, second_type):
return [second_type]

def impl(self, x, y):
return y

Expand Down Expand Up @@ -2474,7 +2454,7 @@ def grad(self, inputs, gout):
return DisconnectedType()(), y.zeros_like(dtype=config.floatX)


second = Second(transfer_type(1), name="second")
second = Second(name="second")


class Identity(UnaryScalarOp):
Expand Down Expand Up @@ -2515,18 +2495,6 @@ def clone_float32(self):
return convert_to_float32
return self

def make_new_inplace(self, output_types_preference=None, name=None):
"""
This op.__init__ fct don't have the same parameter as other scalar op.
This breaks the insert_inplace_optimizer optimization.
This function is a fix to patch this, by ignoring the
output_types_preference passed by the optimization, and replacing it
by the current output type. This should only be triggered when
both input and output have the same dtype anyway.

"""
return self.__class__(self.o_type, name)

def impl(self, input):
return self.ctor(input)

Expand Down Expand Up @@ -4322,22 +4290,6 @@ def __str__(self):

return self._name

def make_new_inplace(self, output_types_preference=None, name=None):
"""
This op.__init__ fct don't have the same parameter as other scalar op.
This break the insert_inplace_optimizer optimization.
This fct allow fix patch this.

"""
d = {k: getattr(self, k) for k in self.init_param}
out = self.__class__(**d)
if name:
out.name = name
else:
name = out.name
super(Composite, out).__init__(output_types_preference, name)
return out

@property
def fgraph(self):
if hasattr(self, "_fgraph"):
Expand Down
3 changes: 0 additions & 3 deletions pytensor/scalar/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,6 @@ def clone(self, name=None, **kwargs):
def fn(self):
raise NotImplementedError

def make_new_inplace(self, output_types_preference=None, name=None):
return self.clone(output_types_preference=output_types_preference, name=name)

def make_node(self, n_steps, *inputs):
assert len(inputs) == self.nin - 1

Expand Down
17 changes: 6 additions & 11 deletions pytensor/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import int64, transfer_type, upcast
from pytensor.scalar.basic import int64, upcast
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
Expand Down Expand Up @@ -1634,17 +1634,12 @@ def construct(symbol):
symbolname = symbolname or symbol.__name__

if symbolname.endswith("_inplace"):
base_symbol_name = symbolname[: -len("_inplace")]
scalar_op = getattr(scalar, base_symbol_name)
inplace_scalar_op = scalar_op.__class__(transfer_type(0))
rval = Elemwise(
inplace_scalar_op,
{0: 0},
nfunc_spec=(nfunc and (nfunc, nin, nout)),
raise ValueError(
"Creation of automatic inplace elemwise operations deprecated"
)
else:
scalar_op = getattr(scalar, symbolname)
rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))

scalar_op = getattr(scalar, symbolname)
rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))

if getattr(symbol, "__doc__"):
rval.__doc__ = symbol.__doc__
Expand Down
Loading
Loading