6 changes: 6 additions & 0 deletions pytensor/link/basic.py
@@ -283,6 +283,9 @@ class PerformLinker(LocalLinker):

"""

required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
incompatible_rewrites: tuple[str, ...] = ("cxx",)

def __init__(
self, allow_gc: bool | None = None, schedule: Callable | None = None
) -> None:
@@ -584,6 +587,9 @@ class JITLinker(PerformLinker):

"""

required_rewrites: tuple[str, ...] = ("minimum_compile",)
incompatible_rewrites: tuple[str, ...] = ()

@abstractmethod
def fgraph_convert(
self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
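The two new class attributes let each linker declare which rewrite tags it requires and which it cannot work with: the pure-Python PerformLinker needs the "py_only" rewrites and is incompatible with "cxx" ones. A minimal sketch of how a Mode could fold these declarations into its rewrite query follows; the helper name `query_for_linker` is hypothetical, only the attribute names come from this diff.

# Hypothetical sketch: fold a linker's declared rewrite constraints into an
# include/exclude query. `query_for_linker` is an illustrative name, not an API.
from pytensor.link.basic import PerformLinker

def query_for_linker(linker, base_include=("merge",)):
    include = set(base_include) | set(getattr(linker, "required_rewrites", ()))
    exclude = set(getattr(linker, "incompatible_rewrites", ()))
    return include, exclude

include, exclude = query_for_linker(PerformLinker())
assert include == {"minimum_compile", "py_only", "merge"}  # matches the updated test_mode.py expectation
assert exclude == {"cxx"}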
2 changes: 0 additions & 2 deletions pytensor/link/c/basic.py
@@ -1787,8 +1787,6 @@ class OpWiseCLinker(LocalLinker):

"""

__cache__: dict = {}

def __init__(
self, fallback_on_perform=True, allow_gc=None, nice_errors=True, schedule=None
):
7 changes: 7 additions & 0 deletions pytensor/link/vm.py
@@ -812,6 +812,10 @@ class VMLinker(LocalLinker):

"""

# These may need to be overridden in `__init__`, as the correct values depend on `c_thunks`
required_rewrites: tuple[str, ...] = ("minimum_compile",)
incompatible_rewrites: tuple[str, ...] = ()

def __init__(
self,
allow_gc=None,
@@ -834,6 +838,9 @@ def __init__(
self.lazy = lazy
if c_thunks is None:
c_thunks = bool(config.cxx)
if not c_thunks:
self.required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
self.incompatible_rewrites: tuple[str, ...] = ("cxx",)
self.c_thunks = c_thunks
self.allow_partial_eval = allow_partial_eval
self.updated_vars = {}
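Because the correct values depend on `c_thunks`, which is only resolved in `__init__`, VMLinker shadows the class-level defaults with stricter instance attributes in the pure-Python case. A small sketch of the resulting behaviour, assuming the constructor keywords used elsewhere in this diff:

# Sketch of the behaviour added above: without C thunks the instance advertises
# the same constraints as the pure-Python PerformLinker; with C thunks the
# class-level defaults are left in place.
from pytensor.link.vm import VMLinker

py_linker = VMLinker(use_cloop=False, c_thunks=False)
assert py_linker.required_rewrites == ("minimum_compile", "py_only")
assert py_linker.incompatible_rewrites == ("cxx",)

c_linker = VMLinker(use_cloop=False, c_thunks=True)
assert c_linker.required_rewrites == ("minimum_compile",)
assert c_linker.incompatible_rewrites == ()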
94 changes: 41 additions & 53 deletions pytensor/scan/op.py
@@ -76,6 +76,7 @@
from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError, MissingInputError
from pytensor.link.c.basic import CLinker
from pytensor.link.vm import VMLinker
from pytensor.printing import op_debug_information
from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
from pytensor.tensor.basic import as_tensor_variable
@@ -884,16 +885,24 @@ def tensorConstructor(shape, dtype):
self.nit_sot_arg_offset = (
self.untraced_sit_sot_arg_offset + info.n_untraced_sit_sot_outs
)
# XXX: This doesn't include `info.n_nit_sot`s, so it's really a count
# Note: This doesn't include `info.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs
self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
self.n_tap_outs = info.n_mit_mot + info.n_mit_sot

# TODO: These can be moved to thunk/function compilation
(
_,
self.mitmots_preallocated,
) = self._mitmot_preallocations()
# The Python and Cython perform methods hand the VM the array location where each MIT-MOT
# output should be stored, via a symbolic update. This helper variable is used in the perform method for validation
mitmots_preallocated = [False] * info.n_mit_mot_outs
if config.scan__allow_output_prealloc:
for mitmot_idx in range(info.n_mit_mot):
for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
# Figure out the index of the corresponding output
output_idx = sum(
len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
) + info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
mitmots_preallocated[output_idx] = True
self.mitmots_preallocated = tuple(mitmots_preallocated)

self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs
@@ -908,39 +917,6 @@ def tensorConstructor(shape, dtype):
)
self._hash_inner_graph = hash(self._cmodule_key)

def _mitmot_preallocations(self):
if config.scan__allow_output_prealloc:
preallocated_mitmot_outs = []

info = self.info
input_idx = info.n_seqs
for mitmot_idx in range(info.n_mit_mot):
for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
# Figure out the index of the corresponding output
output_idx = sum(
len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
)
output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
preallocated_mitmot_outs.append(output_idx)

input_idx += 1

preallocated_mitmot_outs.sort()

else:
# Output preallocation is not activated. Mark every mitmot output
# tap as not being preallocated
preallocated_mitmot_outs = []

# Store the list of mitmot output taps that have been altered so they
# can be preallocated
mitmots_preallocated = [
i in preallocated_mitmot_outs for i in range(info.n_mit_mot_outs)
]

return preallocated_mitmot_outs, mitmots_preallocated

def __setstate__(self, d):
self.__dict__.update(d)
# Ensure that the graph associated with the inner function is valid.
@@ -999,8 +975,8 @@ def make_node(self, *inputs):

if n_outer_ins != n_inner_ins:
raise ValueError(
"The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan."
f"The number of inputs given to the inner function of scan {n_inner_ins} "
f"does not match the number of inputs given to scan {n_outer_ins}."
)

# Force the inputs to be on the CPU
@@ -1483,11 +1459,26 @@ def fn(self):

# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
mode_instance = get_mode(self.mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc),
message=f"{self.name or 'Scan'} sub profile",
)

mode = self.mode
if mode in (None, "FAST_RUN"):
mode_instance = Mode("cvm", "fast_run")
elif mode == "FAST_COMPILE":
mode_instance = Mode(
VMLinker(use_cloop=False, c_thunks=False), "fast_compile"
)
else:
mode_instance = get_mode(mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc),
message=f"{self.name or 'Scan'} sub profile",
)
# Scan's Python and Cython perform methods rely on the VM being able to set updates for preallocated MIT-MOTs,
# which only the VMs produced by VMLinker do
if any(self.mitmots_preallocated) and not isinstance(
mode_instance.linker, VMLinker
):
raise NotImplementedError(
f"Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker, got {mode_instance.linker}"
)
self._fn = pfunc(
wrapped_inputs,
wrapped_outputs,
@@ -2007,6 +1998,9 @@ def perform(self, node, inputs, output_storage):
new_var = inner_input_storage[inner_inp_idx].storage[0]
if old_var is new_var:
old_data = old_mitmot_input_data[mitmot_inp_idx]
# This check is only valid if the VM performs updates.
# Otherwise the output value may remain the same as the input,
# but that doesn't mean it has been set up correctly
same_data = new_var.data == old_data
else:
same_data = False
@@ -2051,14 +2045,8 @@ def perform(self, node, inputs, output_storage):
old_data = old_inner_output_data[offset_out + j]
if old_data is None:
output_reused = False
elif isinstance(
self.fn.maker.fgraph.outputs[offset_out + j], TensorVariable
):
output_reused = new_var.data == old_data
else:
raise RuntimeError(
"FIXME: output_reused = new_var.gpudata == old_data"
)
output_reused = new_var.data == old_data
else:
output_reused = False

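The bookkeeping formerly hidden in `_mitmot_preallocations` is now computed inline in `__init__`: an output is flagged as preallocated when the tap it writes is also read as an input, and its flat index is the number of output taps of all preceding MIT-MOTs plus the tap's position within its own output slice. A standalone sketch of that computation, using the same slices as the new regression test in tests/scan/test_basic.py:

# Standalone sketch of the inlined `mitmots_preallocated` computation: two
# MIT-MOTs, each reading taps (0, 1) and writing tap (1,).
mit_mot_in_slices = ((0, 1), (0, 1))
mit_mot_out_slices = ((1,), (1,))
n_mit_mot_outs = sum(len(s) for s in mit_mot_out_slices)

flags = [False] * n_mit_mot_outs
for idx, (in_taps, out_taps) in enumerate(zip(mit_mot_in_slices, mit_mot_out_slices)):
    for tap in in_taps:
        if tap in out_taps:
            # Offset of this MIT-MOT's first output, plus the tap's position
            offset = sum(len(m) for m in mit_mot_out_slices[:idx])
            flags[offset + out_taps.index(tap)] = True

assert tuple(flags) == (True, True)  # both outputs reuse a preallocated input buffer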
3 changes: 2 additions & 1 deletion tests/compile/test_mode.py
@@ -56,11 +56,12 @@ def test_NoOutputFromInplace():

def test_including():
mode = Mode(linker="py", optimizer="merge")
assert set(mode._optimizer.include) == {"minimum_compile", "merge"}
assert set(mode._optimizer.include) == {"minimum_compile", "py_only", "merge"}

new_mode = mode.including("fast_compile")
assert set(new_mode._optimizer.include) == {
"minimum_compile",
"py_only",
"merge",
"fast_compile",
}
92 changes: 91 additions & 1 deletion tests/scan/test_basic.py
@@ -34,10 +34,12 @@
from pytensor.graph.rewriting.basic import MergeOptimizer
from pytensor.graph.traversal import ancestors
from pytensor.graph.utils import MissingInputError
from pytensor.link.vm import VMLinker
from pytensor.raise_op import assert_op
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import until
from pytensor.tensor import as_tensor
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh
from pytensor.tensor.math import sum as pt_sum
@@ -4308,3 +4310,91 @@ def test_return_updates_api_change():

with pytest.raises(ValueError, match=err_msg):
scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False)


@pytest.mark.parametrize(
"scan_mode",
[
None,
"FAST_RUN",
"FAST_COMPILE",
Mode("cvm", optimizer=None),
Mode("vm", optimizer=None),
Mode("c", optimizer=None),
Mode("py", optimizer=None),
],
)
def test_scan_mode_compatibility(scan_mode):
# Regression test for the case where using Scan with a non-updating VM failed

# Build a scan with one sequence and two MIT-MOTs
info = ScanInfo(
n_seqs=1,
mit_mot_in_slices=((0, 1), (0, 1)),
mit_mot_out_slices=((1,), (1,)),
mit_sot_in_slices=(),
sit_sot_in_slices=(),
n_nit_sot=0,
n_untraced_sit_sot_outs=0,
n_non_seqs=0,
as_while=False,
)
bool_seq = pt.scalar(dtype="bool")
mitmot_A0, mitmot_A1, mitmot_B0, mitmot_B1 = [
pt.matrix(shape=(2, 2)) for i in range(4)
]
inputs = [
bool_seq,
mitmot_A0,
mitmot_A1,
mitmot_B0,
mitmot_B1,
]
outputs = [
pt.add(bool_seq + mitmot_A0, mitmot_A1),
pt.add(bool_seq * mitmot_B0, mitmot_B1),
]

scan_op = Scan(
inputs,
outputs,
info=info,
mode=scan_mode,
)

n_steps = 5
numerical_inputs = [
np.array(n_steps, dtype="int64"),
np.array([1, 1, 0, 1, 0], dtype="bool"),
np.zeros(n_steps + 1)[:, None, None] * np.eye(2),
np.arange(n_steps + 1)[:, None, None] * np.eye(2),
]
tensor_inputs = [as_tensor(inp, dtype=inp.dtype).type() for inp in numerical_inputs]
tensor_outputs = [o.sum() for o in scan_op(*tensor_inputs)]

no_opt_mode = Mode(linker="py", optimizer=None)
# NotImplementedError should only be triggered when we try to compile the function
if (
# Abstract modes should never fail
scan_mode not in (None, "FAST_RUN", "FAST_COMPILE")
# Only if the user tries something specific and incompatible
and not isinstance(get_mode(scan_mode).linker, VMLinker)
):
with pytest.raises(
NotImplementedError,
match="Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker",
):
function(tensor_inputs, tensor_outputs, mode=no_opt_mode)
return

fn = function(tensor_inputs, tensor_outputs, mode=no_opt_mode)

# Check we have the expected Scan in the compiled function
[fn_scan_op] = [
node.op for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
assert fn_scan_op.info == info
assert fn_scan_op.mitmots_preallocated == (True, True)

# Expected value computed by running a correct Scan once
np.testing.assert_allclose(fn(*numerical_inputs), [44, 38])