Skip to content

Commit 65a903b

Browse files
committed
Validate compatible linker in Scan make_thunk
1 parent 856d412 commit 65a903b

File tree

2 files changed

+129
-45
lines changed

2 files changed

+129
-45
lines changed

pytensor/scan/op.py

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from pytensor.graph.type import HasShape
7777
from pytensor.graph.utils import InconsistencyError, MissingInputError
7878
from pytensor.link.c.basic import CLinker
79+
from pytensor.link.vm import VMLinker
7980
from pytensor.printing import op_debug_information
8081
from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
8182
from pytensor.tensor.basic import as_tensor_variable
@@ -884,16 +885,24 @@ def tensorConstructor(shape, dtype):
884885
self.nit_sot_arg_offset = (
885886
self.untraced_sit_sot_arg_offset + info.n_untraced_sit_sot_outs
886887
)
887-
# XXX: This doesn't include `info.n_nit_sot`s, so it's really a count
888+
# Note: This doesn't include `info.n_nit_sot`s, so it's really a count
888889
# of the number of outputs generated by taps with inputs
889890
self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
890891
self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
891892

892-
# TODO: These can be moved to thunk/function compilation
893-
(
894-
_,
895-
self.mitmots_preallocated,
896-
) = self._mitmot_preallocations()
893+
# The Python and Cython perform methods pass the VM, as a symbolic update, the array location where
894+
# a mitmot output should be stored. This helper variable is used in the perform method for validation.
895+
mitmots_preallocated = [False] * info.n_mit_mot_outs
896+
if config.scan__allow_output_prealloc:
897+
for mitmot_idx in range(info.n_mit_mot):
898+
for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
899+
if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
900+
# Figure out the index of the corresponding output
901+
output_idx = sum(
902+
len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
903+
) + info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
904+
mitmots_preallocated[output_idx] = True
905+
self.mitmots_preallocated = tuple(mitmots_preallocated)
897906

898907
self.n_outer_inputs = info.n_outer_inputs
899908
self.n_outer_outputs = info.n_outer_outputs
@@ -908,39 +917,6 @@ def tensorConstructor(shape, dtype):
908917
)
909918
self._hash_inner_graph = hash(self._cmodule_key)
910919

911-
def _mitmot_preallocations(self):
912-
if config.scan__allow_output_prealloc:
913-
preallocated_mitmot_outs = []
914-
915-
info = self.info
916-
input_idx = info.n_seqs
917-
for mitmot_idx in range(info.n_mit_mot):
918-
for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
919-
if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
920-
# Figure out the index of the corresponding output
921-
output_idx = sum(
922-
len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
923-
)
924-
output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
925-
preallocated_mitmot_outs.append(output_idx)
926-
927-
input_idx += 1
928-
929-
preallocated_mitmot_outs.sort()
930-
931-
else:
932-
# Output preallocation is not activated. Mark every mitmot output
933-
# tap as not being preallocated
934-
preallocated_mitmot_outs = []
935-
936-
# Store the list of mitmot output taps that have been altered so they
937-
# can be preallocated
938-
mitmots_preallocated = [
939-
i in preallocated_mitmot_outs for i in range(info.n_mit_mot_outs)
940-
]
941-
942-
return preallocated_mitmot_outs, mitmots_preallocated
943-
944920
def __setstate__(self, d):
945921
self.__dict__.update(d)
946922
# Ensure that the graph associated with the inner function is valid.
@@ -1483,11 +1459,26 @@ def fn(self):
14831459

14841460
# Clone mode_instance, altering "allow_gc" for the linker,
14851461
# and adding a message if we profile
1486-
mode_instance = get_mode(self.mode).clone(
1487-
link_kwargs=dict(allow_gc=self.allow_gc),
1488-
message=f"{self.name or 'Scan'} sub profile",
1489-
)
1490-
1462+
mode = self.mode
1463+
if mode in (None, "FAST_RUN"):
1464+
mode_instance = Mode("cvm", "fast_run")
1465+
elif mode == "FAST_COMPILE":
1466+
mode_instance = Mode(
1467+
VMLinker(use_cloop=False, c_thunks=False), "fast_compile"
1468+
)
1469+
else:
1470+
mode_instance = get_mode(mode).clone(
1471+
link_kwargs=dict(allow_gc=self.allow_gc),
1472+
message=f"{self.name or 'Scan'} sub profile",
1473+
)
1474+
# Scan's Python and Cython perform methods rely on the VM being able to set updates for preallocated MIT-MOTs,
1475+
# which only the VMs produced by VMLinker do
1476+
if any(self.mitmots_preallocated) and not isinstance(
1477+
mode_instance.linker, VMLinker
1478+
):
1479+
raise NotImplementedError(
1480+
f"Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker, got {mode_instance.linker}"
1481+
)
14911482
self._fn = pfunc(
14921483
wrapped_inputs,
14931484
wrapped_outputs,
@@ -2007,6 +1998,9 @@ def perform(self, node, inputs, output_storage):
20071998
new_var = inner_input_storage[inner_inp_idx].storage[0]
20081999
if old_var is new_var:
20092000
old_data = old_mitmot_input_data[mitmot_inp_idx]
2001+
# This check is only valid if the VM performs updates.
2002+
# Otherwise the output value may remain the same as the input,
2003+
# but that doesn't mean that it has been set up correctly
20102004
same_data = new_var.data == old_data
20112005
else:
20122006
same_data = False

tests/scan/test_basic.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@
3434
from pytensor.graph.rewriting.basic import MergeOptimizer
3535
from pytensor.graph.traversal import ancestors
3636
from pytensor.graph.utils import MissingInputError
37+
from pytensor.link.vm import VMLinker
3738
from pytensor.raise_op import assert_op
3839
from pytensor.scan.basic import scan
39-
from pytensor.scan.op import Scan
40+
from pytensor.scan.op import Scan, ScanInfo
4041
from pytensor.scan.utils import until
42+
from pytensor.tensor import as_tensor
4143
from pytensor.tensor.math import all as pt_all
4244
from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh
4345
from pytensor.tensor.math import sum as pt_sum
@@ -4308,3 +4310,91 @@ def test_return_updates_api_change():
43084310

43094311
with pytest.raises(ValueError, match=err_msg):
43104312
scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False)
4313+
4314+
4315+
@pytest.mark.parametrize(
4316+
"scan_mode",
4317+
[
4318+
None,
4319+
"FAST_RUN",
4320+
"FAST_COMPILE",
4321+
Mode("cvm", optimizer=None),
4322+
Mode("vm", optimizer=None),
4323+
Mode("c", optimizer=None),
4324+
Mode("py", optimizer=None),
4325+
],
4326+
)
4327+
def test_scan_mode_compatibility(scan_mode):
4328+
# Regression test for the case where using Scan with a non-updating VM failed
4329+
4330+
# Build a scan with one sequence and two MIT-MOTs
4331+
info = ScanInfo(
4332+
n_seqs=1,
4333+
mit_mot_in_slices=((0, 1), (0, 1)),
4334+
mit_mot_out_slices=((1,), (1,)),
4335+
mit_sot_in_slices=(),
4336+
sit_sot_in_slices=(),
4337+
n_nit_sot=0,
4338+
n_untraced_sit_sot_outs=0,
4339+
n_non_seqs=0,
4340+
as_while=False,
4341+
)
4342+
bool_seq = pt.scalar(dtype="bool")
4343+
mitmot_A0, mitmot_A1, mitmot_B0, mitmot_B1 = [
4344+
pt.matrix(shape=(2, 2)) for i in range(4)
4345+
]
4346+
inputs = [
4347+
bool_seq,
4348+
mitmot_A0,
4349+
mitmot_A1,
4350+
mitmot_B0,
4351+
mitmot_B1,
4352+
]
4353+
outputs = [
4354+
pt.add(bool_seq + mitmot_A0, mitmot_A1),
4355+
pt.add(bool_seq * mitmot_B0, mitmot_B1),
4356+
]
4357+
4358+
scan_op = Scan(
4359+
inputs,
4360+
outputs,
4361+
info=info,
4362+
mode=scan_mode,
4363+
)
4364+
4365+
n_steps = 5
4366+
numerical_inputs = [
4367+
np.array(n_steps, dtype="int64"),
4368+
np.array([1, 1, 0, 1, 0], dtype="bool"),
4369+
np.zeros(n_steps + 1)[:, None, None] * np.eye(2),
4370+
np.arange(n_steps + 1)[:, None, None] * np.eye(2),
4371+
]
4372+
tensor_inputs = [as_tensor(inp, dtype=inp.dtype).type() for inp in numerical_inputs]
4373+
tensor_outputs = [o.sum() for o in scan_op(*tensor_inputs)]
4374+
4375+
no_opt_mode = Mode(linker="py", optimizer=None)
4376+
# NotImplementedError should only be triggered when we try to compile the function
4377+
if (
4378+
# Abstract modes should never fail
4379+
scan_mode not in (None, "FAST_RUN", "FAST_COMPILE")
4380+
# Only if the user tries something specific and incompatible
4381+
and not isinstance(get_mode(scan_mode).linker, VMLinker)
4382+
):
4383+
with pytest.raises(
4384+
NotImplementedError,
4385+
match="Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker",
4386+
):
4387+
function(tensor_inputs, tensor_outputs, mode=no_opt_mode)
4388+
return
4389+
4390+
fn = function(tensor_inputs, tensor_outputs, mode=no_opt_mode)
4391+
4392+
# Check we have the expected Scan in the compiled function
4393+
[fn_scan_op] = [
4394+
node.op for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
4395+
]
4396+
assert fn_scan_op.info == info
4397+
assert fn_scan_op.mitmots_preallocated == (True, True)
4398+
4399+
# Expected values were computed by running a correct Scan once
4400+
np.testing.assert_allclose(fn(*numerical_inputs), [44, 38])

0 commit comments

Comments
 (0)