7676from pytensor .graph .type import HasShape
7777from pytensor .graph .utils import InconsistencyError , MissingInputError
7878from pytensor .link .c .basic import CLinker
79+ from pytensor .link .vm import VMLinker
7980from pytensor .printing import op_debug_information
8081from pytensor .scan .utils import ScanProfileStats , Validator , forced_replace , safe_new
8182from pytensor .tensor .basic import as_tensor_variable
@@ -884,16 +885,24 @@ def tensorConstructor(shape, dtype):
884885 self .nit_sot_arg_offset = (
885886 self .untraced_sit_sot_arg_offset + info .n_untraced_sit_sot_outs
886887 )
887- # XXX : This doesn't include `info.n_nit_sot`s, so it's really a count
888+ # Note : This doesn't include `info.n_nit_sot`s, so it's really a count
888889 # of the number of outputs generated by taps with inputs
889890 self .n_outs = info .n_mit_mot + info .n_mit_sot + info .n_sit_sot
890891 self .n_tap_outs = info .n_mit_mot + info .n_mit_sot
891892
892- # TODO: These can be moved to thunk/function compilation
893- (
894- _ ,
895- self .mitmots_preallocated ,
896- ) = self ._mitmot_preallocations ()
893+ # Python and Cython perform methods pass the array location where a mitmot output should be stored
894+ # to the VM as a symbolic update. This helper variable is used in the perform method for validation.
895+ mitmots_preallocated = [False ] * info .n_mit_mot_outs
896+ if config .scan__allow_output_prealloc :
897+ for mitmot_idx in range (info .n_mit_mot ):
898+ for inp_tap in info .mit_mot_in_slices [mitmot_idx ]:
899+ if inp_tap in info .mit_mot_out_slices [mitmot_idx ]:
900+ # Figure out the index of the corresponding output
901+ output_idx = sum (
902+ len (m ) for m in info .mit_mot_out_slices [:mitmot_idx ]
903+ ) + info .mit_mot_out_slices [mitmot_idx ].index (inp_tap )
904+ mitmots_preallocated [output_idx ] = True
905+ self .mitmots_preallocated = tuple (mitmots_preallocated )
897906
898907 self .n_outer_inputs = info .n_outer_inputs
899908 self .n_outer_outputs = info .n_outer_outputs
@@ -908,39 +917,6 @@ def tensorConstructor(shape, dtype):
908917 )
909918 self ._hash_inner_graph = hash (self ._cmodule_key )
910919
911- def _mitmot_preallocations (self ):
912- if config .scan__allow_output_prealloc :
913- preallocated_mitmot_outs = []
914-
915- info = self .info
916- input_idx = info .n_seqs
917- for mitmot_idx in range (info .n_mit_mot ):
918- for inp_tap in info .mit_mot_in_slices [mitmot_idx ]:
919- if inp_tap in info .mit_mot_out_slices [mitmot_idx ]:
920- # Figure out the index of the corresponding output
921- output_idx = sum (
922- len (m ) for m in info .mit_mot_out_slices [:mitmot_idx ]
923- )
924- output_idx += info .mit_mot_out_slices [mitmot_idx ].index (inp_tap )
925- preallocated_mitmot_outs .append (output_idx )
926-
927- input_idx += 1
928-
929- preallocated_mitmot_outs .sort ()
930-
931- else :
932- # Output preallocation is not activated. Mark every mitmot output
933- # tap as not being preallocated
934- preallocated_mitmot_outs = []
935-
936- # Store the list of mitmot output taps that have been altered so they
937- # can be preallocated
938- mitmots_preallocated = [
939- i in preallocated_mitmot_outs for i in range (info .n_mit_mot_outs )
940- ]
941-
942- return preallocated_mitmot_outs , mitmots_preallocated
943-
944920 def __setstate__ (self , d ):
945921 self .__dict__ .update (d )
946922 # Ensure that the graph associated with the inner function is valid.
@@ -1483,11 +1459,26 @@ def fn(self):
14831459
14841460 # Clone mode_instance, altering "allow_gc" for the linker,
14851461 # and adding a message if we profile
1486- mode_instance = get_mode (self .mode ).clone (
1487- link_kwargs = dict (allow_gc = self .allow_gc ),
1488- message = f"{ self .name or 'Scan' } sub profile" ,
1489- )
1490-
1462+ mode = self .mode
1463+ if mode in (None , "FAST_RUN" ):
1464+ mode_instance = Mode ("cvm" , "fast_run" )
1465+ elif mode == "FAST_COMPILE" :
1466+ mode_instance = Mode (
1467+ VMLinker (use_cloop = False , c_thunks = False ), "fast_compile"
1468+ )
1469+ else :
1470+ mode_instance = get_mode (mode ).clone (
1471+ link_kwargs = dict (allow_gc = self .allow_gc ),
1472+ message = f"{ self .name or 'Scan' } sub profile" ,
1473+ )
1474+ # Scan's Python and Cython perform methods rely on the VM being able to set updates for preallocated MIT-MOT outputs,
1475+ # which only the VMs produced by VMLinker do.
1476+ if any (self .mitmots_preallocated ) and not isinstance (
1477+ mode_instance .linker , VMLinker
1478+ ):
1479+ raise NotImplementedError (
1480+ f"Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker, got { mode_instance .linker } "
1481+ )
14911482 self ._fn = pfunc (
14921483 wrapped_inputs ,
14931484 wrapped_outputs ,
@@ -2007,6 +1998,9 @@ def perform(self, node, inputs, output_storage):
20071998 new_var = inner_input_storage [inner_inp_idx ].storage [0 ]
20081999 if old_var is new_var :
20092000 old_data = old_mitmot_input_data [mitmot_inp_idx ]
2001+ # This check is only valid if the VM performs updates.
2002+ # Otherwise the output value may remain the same as the input,
2003+ # but that doesn't mean that it has been set up correctly.
20102004 same_data = new_var .data == old_data
20112005 else :
20122006 same_data = False
0 commit comments