Skip to content

Commit 8885767

Browse files
committed
Add support for TypedList in numba backend
Note: Numba object mode fallback is not safe with lists
1 parent c54bef1 commit 8885767

File tree

4 files changed

+273
-1
lines changed

4 files changed

+273
-1
lines changed

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pytensor.link.numba.dispatch.sparse
1818
import pytensor.link.numba.dispatch.subtensor
1919
import pytensor.link.numba.dispatch.tensor_basic
20+
import pytensor.link.numba.dispatch.typed_list
2021

2122

2223
# isort: on

pytensor/link/numba/dispatch/basic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pytensor.tensor.random.type import RandomGeneratorType
2424
from pytensor.tensor.type import TensorType
2525
from pytensor.tensor.utils import hash_from_ndarray
26+
from pytensor.typed_list import TypedListType
2627

2728

2829
def _filter_numba_warnings():
@@ -132,6 +133,8 @@ def get_numba_type(
132133
return CSCMatrixType(numba_dtype)
133134
elif isinstance(pytensor_type, RandomGeneratorType):
134135
return numba.types.NumPyRandomGeneratorType("NumPyRandomGeneratorType")
136+
elif isinstance(pytensor_type, TypedListType):
137+
return numba.types.List(get_numba_type(pytensor_type.ttype))
135138
else:
136139
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
137140

@@ -260,7 +263,10 @@ def numba_typify(data, dtype=None, **kwargs):
260263

261264

262265
def generate_fallback_impl(op, node, storage_map=None, **kwargs):
263-
"""Create a Numba compatible function from a Pytensor `Op`."""
266+
"""Create a Numba compatible function from a Pytensor `Op`.
267+
268+
Note limitations: https://numba.pydata.org/numba-doc/dev/user/withobjmode.html#the-objmode-context-manager
269+
"""
264270

265271
warnings.warn(
266272
f"Numba will use object mode to run {op}'s perform method. "
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
import numba
2+
import numpy as np
3+
4+
import pytensor.link.numba.dispatch.basic as numba_basic
5+
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
6+
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
7+
from pytensor.tensor.type_other import SliceType
8+
from pytensor.typed_list import (
9+
Append,
10+
Count,
11+
Extend,
12+
GetItem,
13+
Index,
14+
Insert,
15+
Length,
16+
MakeList,
17+
Remove,
18+
Reverse,
19+
)
20+
21+
22+
def numba_all_equal(x, y):
23+
if isinstance(x, np.ndarray) or isinstance(y, np.ndarray):
24+
if not (isinstance(x, np.ndarray) and isinstance(y, np.ndarray)):
25+
return False
26+
return (x == y).all()
27+
if isinstance(x, list) or isinstance(y, list):
28+
if not (isinstance(x, list) and isinstance(y, list)):
29+
return False
30+
if len(x) != len(y):
31+
return False
32+
return all(numba_all_equal(xi, yi) for xi, yi in zip(x, y))
33+
return x == y
34+
35+
36+
@numba.extending.overload(numba_all_equal)
37+
def list_all_equal(x, y):
38+
all_equal = None
39+
40+
if isinstance(x, numba.types.List) and isinstance(y, numba.types.List):
41+
42+
def all_equal(x, y):
43+
if len(x) != len(y):
44+
return False
45+
for xi, yi in zip(x, y):
46+
if not numba_all_equal(xi, yi):
47+
return False
48+
return True
49+
50+
if isinstance(x, numba.types.Array) and isinstance(y, numba.types.Array):
51+
52+
def all_equal(x, y):
53+
return (x == y).all()
54+
55+
if isinstance(x, numba.types.Number) and isinstance(y.numba.types.Number):
56+
57+
def all_equal(x, y):
58+
return x == y
59+
60+
return all_equal
61+
62+
63+
@numba.extending.overload(numba_deepcopy)
64+
def numba_deepcopy_list(x):
65+
if isinstance(x, numba.types.List):
66+
67+
def deepcopy_list(x):
68+
return [numba_deepcopy(xi) for xi in x]
69+
70+
return deepcopy_list
71+
72+
73+
@register_funcify_default_op_cache_key(MakeList)
74+
def numba_funcify_make_list(op, node, **kwargs):
75+
@numba_basic.numba_njit
76+
def make_list(*args):
77+
return [numba_deepcopy(arg) for arg in args]
78+
79+
return make_list
80+
81+
82+
@register_funcify_default_op_cache_key(Length)
83+
def numba_funcify_list_length(op, node, **kwargs):
84+
@numba_basic.numba_njit
85+
def list_length(x):
86+
return np.array(len(x), dtype=np.int64)
87+
88+
return list_length
89+
90+
91+
@register_funcify_default_op_cache_key(GetItem)
92+
def numba_funcify_list_get_item(op, node, **kwargs):
93+
if isinstance(node.inputs[1].type, SliceType):
94+
95+
@numba_basic.numba_njit
96+
def list_get_item_slice(x, index):
97+
return x[index]
98+
99+
return list_get_item_slice
100+
101+
else:
102+
103+
@numba_basic.numba_njit
104+
def list_get_item_index(x, index):
105+
return x[index.item()]
106+
107+
return list_get_item_index
108+
109+
110+
@register_funcify_default_op_cache_key(Reverse)
111+
def numba_funcify_list_reverse(op, node, **kwargs):
112+
inplace = op.inplace
113+
114+
@numba_basic.numba_njit
115+
def list_reverse(x):
116+
if inplace:
117+
z = x
118+
else:
119+
z = numba_deepcopy(x)
120+
z.reverse()
121+
return z
122+
123+
return list_reverse
124+
125+
126+
@register_funcify_default_op_cache_key(Append)
127+
def numba_funcify_list_append(op, node, **kwargs):
128+
inplace = op.inplace
129+
130+
@numba_basic.numba_njit
131+
def list_append(x, to_append):
132+
if inplace:
133+
z = x
134+
else:
135+
z = numba_deepcopy(x)
136+
z.append(numba_deepcopy(to_append))
137+
return z
138+
139+
return list_append
140+
141+
142+
@register_funcify_default_op_cache_key(Extend)
143+
def numba_funcify_list_extend(op, node, **kwargs):
144+
inplace = op.inplace
145+
146+
@numba_basic.numba_njit
147+
def list_extend(x, to_append):
148+
if inplace:
149+
z = x
150+
else:
151+
z = numba_deepcopy(x)
152+
z.extend(numba_deepcopy(to_append))
153+
return z
154+
155+
return list_extend
156+
157+
158+
@register_funcify_default_op_cache_key(Insert)
159+
def numba_funcify_list_insert(op, node, **kwargs):
160+
inplace = op.inplace
161+
162+
@numba_basic.numba_njit
163+
def list_insert(x, index, to_insert):
164+
if inplace:
165+
z = x
166+
else:
167+
z = numba_deepcopy(x)
168+
z.insert(index.item(), numba_deepcopy(to_insert))
169+
return z
170+
171+
return list_insert
172+
173+
174+
@register_funcify_default_op_cache_key(Index)
175+
def numba_funcify_list_index(op, node, **kwargs):
176+
@numba_basic.numba_njit
177+
def list_index(x, elem):
178+
for idx, xi in enumerate(x):
179+
if numba_all_equal(xi, elem):
180+
break
181+
return np.array(idx, dtype=np.int64)
182+
183+
return list_index
184+
185+
186+
@register_funcify_default_op_cache_key(Count)
187+
def numba_funcify_list_count(op, node, **kwargs):
188+
@numba_basic.numba_njit
189+
def list_count(x, elem):
190+
c = 0
191+
for xi in x:
192+
if numba_all_equal(xi, elem):
193+
c += 1
194+
return np.array(c, dtype=np.int64)
195+
196+
return list_count
197+
198+
199+
@register_funcify_default_op_cache_key(Remove)
200+
def numba_funcify_list_remove(op, node, **kwargs):
201+
inplace = op.inplace
202+
203+
@numba_basic.numba_njit
204+
def list_remove(x, to_remove):
205+
if inplace:
206+
z = x
207+
else:
208+
z = numba_deepcopy(x)
209+
index_to_remove = -1
210+
for i, zi in enumerate(z):
211+
if numba_all_equal(zi, to_remove):
212+
index_to_remove = i
213+
break
214+
if index_to_remove == -1:
215+
raise ValueError("list.remove(x): x not in list")
216+
z.pop(index_to_remove)
217+
return z
218+
219+
return list_remove
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
3+
from pytensor.tensor import matrix
4+
from pytensor.typed_list import make_list
5+
from tests.link.numba.test_basic import compare_numba_and_py
6+
7+
8+
def test_list_basic_ops():
9+
x = matrix("x", shape=(3, None), dtype="int64")
10+
l = make_list([x[0], x[2]])
11+
12+
x_test = np.arange(12).reshape(3, 4)
13+
compare_numba_and_py([x], [l, l.length()], [x_test])
14+
15+
# Test nested list
16+
ll = make_list([l, l, l])
17+
compare_numba_and_py([x], [ll, ll.length()], [x_test])
18+
19+
20+
def test_make_list_index_ops():
21+
x = matrix("x", shape=(3, None), dtype="int64")
22+
l = make_list([x[0], x[2]])
23+
24+
x_test = np.arange(12).reshape(3, 4)
25+
compare_numba_and_py([x], [l[-1], l[:-1], l.reverse()], [x_test])
26+
27+
28+
def test_make_list_extend_ops():
29+
x = matrix("x", shape=(3, None), dtype="int64")
30+
l = make_list([x[0], x[2]])
31+
32+
x_test = np.arange(12).reshape(3, 4)
33+
compare_numba_and_py(
34+
[x], [l.append(x[1]), l.extend(l), l.insert(0, x[1])], [x_test]
35+
)
36+
37+
38+
def test_make_list_find_ops():
39+
# Remove requires to first find it
40+
x = matrix("x", shape=(3, None), dtype="int64")
41+
y = x[0].type("y")
42+
l = make_list([x[0], x[2], x[0], x[2]])
43+
44+
x_test = np.arange(12).reshape(3, 4)
45+
test_y = x_test[2]
46+
compare_numba_and_py([x, y], [l.ind(y), l.count(y), l.remove(y)], [x_test, test_y])

0 commit comments

Comments
 (0)