From c79aa43d91993f1d199f2a1f03fac5864465e590 Mon Sep 17 00:00:00 2001
From: emekaokoli19
Date: Fri, 7 Nov 2025 22:49:28 +0100
Subject: [PATCH 1/2] added filter to scan

---
 pytensor/scan/views.py   | 43 ++++++++++++++++++++++++++++++++++++++++
 tests/scan/test_views.py | 20 +++++++++++++++++++
 2 files changed, 63 insertions(+)

diff --git a/pytensor/scan/views.py b/pytensor/scan/views.py
index b86476b330..e3e656f7cc 100644
--- a/pytensor/scan/views.py
+++ b/pytensor/scan/views.py
@@ -170,3 +170,46 @@ def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
         mode=mode,
         name=name,
     )
+
+
+def filter(
+    fn,
+    sequences,
+    non_sequences=None,
+    go_backwards=False,
+    mode=None,
+    name=None,
+):
+    """Construct a `Scan` `Op` that functions like `filter`.
+
+    Parameters
+    ----------
+    fn : callable
+        Predicate function returning a boolean tensor.
+    sequences : list
+        Sequences to filter.
+    non_sequences : list
+        Non-iterated arguments passed to `fn`.
+    go_backwards : bool
+        Whether to iterate in reverse.
+    mode : str or None
+        See ``scan``.
+    name : str or None
+        See ``scan``.
+    """
+    mask, _ = scan(
+        fn=fn,
+        sequences=sequences,
+        outputs_info=None,
+        non_sequences=non_sequences,
+        go_backwards=go_backwards,
+        mode=mode,
+        name=f"{name or ''}_mask",
+    )
+
+    if isinstance(sequences, (list, tuple)):
+        filtered_sequences = [seq[mask] for seq in sequences]
+    else:
+        filtered_sequences = sequences[mask]
+
+    return filtered_sequences
diff --git a/tests/scan/test_views.py b/tests/scan/test_views.py
index 38c9b9cfcd..6717016704 100644
--- a/tests/scan/test_views.py
+++ b/tests/scan/test_views.py
@@ -133,3 +133,23 @@ def test_foldr_memory_consumption():
     gx = grad(o, x)
     f2 = function([], gx)
     utt.assert_allclose(f2(), np.ones((10,)))
+
+
+def test_filter():
+    import pytensor.tensor as pt
+
+    v = pt.vector("v")
+
+    def fn(x):
+        return pt.eq(x % 2, 0)
+
+    from pytensor.scan.views import filter as pt_filter
+
+    filtered = pt_filter(fn, v)
+    f = function([v], filtered, allow_input_downcast=True)
+
+    rng = np.random.default_rng(utt.fetch_seed())
+    vals = rng.integers(0, 10, size=(10,))
+    expected = vals[vals % 2 == 0]
+    result = f(vals)
+    utt.assert_allclose(expected, result)

From 5e4ac7620c303cb812fc3d43098e41a9693e1765 Mon Sep 17 00:00:00 2001
From: emekaokoli19
Date: Mon, 10 Nov 2025 11:02:35 +0100
Subject: [PATCH 2/2] Address review comments for filter in scan

---
 pytensor/scan/views.py   | 25 +++++++++++++++++++++----
 tests/scan/test_views.py | 28 ++++++++++++++++++++++++----
 2 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/pytensor/scan/views.py b/pytensor/scan/views.py
index e3e656f7cc..7d9365bb47 100644
--- a/pytensor/scan/views.py
+++ b/pytensor/scan/views.py
@@ -196,6 +196,12 @@ def filter(
         See ``scan``.
     name : str or None
         See ``scan``.
+
+    Notes
+    -----
+    If the predicate function `fn` returns multiple boolean masks (one per sequence),
+    each mask will be applied to its corresponding sequence. If it returns a single mask,
+    that mask will be broadcast to all sequences.
     """
     mask, _ = scan(
         fn=fn,
@@ -204,12 +210,23 @@ def filter(
         non_sequences=non_sequences,
         go_backwards=go_backwards,
         mode=mode,
-        name=f"{name or ''}_mask",
+        name=name,
     )
 
-    if isinstance(sequences, (list, tuple)):
-        filtered_sequences = [seq[mask] for seq in sequences]
+    if isinstance(mask, (list, tuple)):
+        # One mask per sequence
+        if not isinstance(sequences, (list, tuple)):
+            raise TypeError(
+                "If multiple masks are returned, sequences must be a list or tuple."
+            )
+        if len(mask) != len(sequences):
+            raise ValueError("Number of masks must match number of sequences.")
+        filtered_sequences = [seq[m] for seq, m in zip(sequences, mask)]
     else:
-        filtered_sequences = sequences[mask]
+        # Single mask applied to all sequences
+        if isinstance(sequences, (list, tuple)):
+            filtered_sequences = [seq[mask] for seq in sequences]
+        else:
+            filtered_sequences = sequences[mask]
 
     return filtered_sequences
diff --git a/tests/scan/test_views.py b/tests/scan/test_views.py
index 6717016704..3002f9cd3a 100644
--- a/tests/scan/test_views.py
+++ b/tests/scan/test_views.py
@@ -3,6 +3,7 @@
 import pytensor.tensor as pt
 from pytensor import config, function, grad, shared
 from pytensor.compile.mode import FAST_RUN
+from pytensor.scan.views import filter as pt_filter
 from pytensor.scan.views import foldl, foldr
 from pytensor.scan.views import map as pt_map
 from pytensor.scan.views import reduce as pt_reduce
@@ -136,15 +137,11 @@ def test_foldr_memory_consumption():
 
 
 def test_filter():
-    import pytensor.tensor as pt
-
     v = pt.vector("v")
 
     def fn(x):
         return pt.eq(x % 2, 0)
 
-    from pytensor.scan.views import filter as pt_filter
-
     filtered = pt_filter(fn, v)
     f = function([v], filtered, allow_input_downcast=True)
 
@@ -153,3 +150,26 @@ def fn(x):
     expected = vals[vals % 2 == 0]
     result = f(vals)
     utt.assert_allclose(expected, result)
+
+
+def test_filter_multiple_masks():
+    v1 = pt.vector("v1")
+    v2 = pt.vector("v2")
+
+    def fn(x1, x2):
+        # Mask v1 for even numbers, mask v2 for numbers > 5
+        return pt.eq(x1 % 2, 0), pt.gt(x2, 5)
+
+    filtered_v1, filtered_v2 = pt_filter(fn, [v1, v2])
+    f = function([v1, v2], [filtered_v1, filtered_v2], allow_input_downcast=True)
+
+    vals1 = np.arange(10)
+    vals2 = np.arange(10)
+
+    expected_v1 = vals1[vals1 % 2 == 0]
+    expected_v2 = vals2[vals2 > 5]
+
+    result_v1, result_v2 = f(vals1, vals2)
+
+    utt.assert_allclose(expected_v1, result_v1)
+    utt.assert_allclose(expected_v2, result_v2)