Skip to content

Commit ac51f01

Browse files
committed
Implement extensible deepcopy in numba
1 parent 6956753 commit ac51f01

File tree

3 files changed

+58
-23
lines changed

3 files changed

+58
-23
lines changed

pytensor/link/numba/dispatch/compile_ops.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from copy import deepcopy
12
from hashlib import sha256
23

4+
import numba
35
import numpy as np
46

57
from pytensor.compile.builders import OpFromGraph
@@ -15,7 +17,34 @@
1517
register_funcify_default_op_cache_key,
1618
)
1719
from pytensor.raise_op import CheckAndRaise
18-
from pytensor.tensor.type import TensorType
20+
21+
22+
def numba_deepcopy(x):
23+
return deepcopy(x)
24+
25+
26+
@numba.extending.overload(numba_deepcopy)
27+
def numba_deepcopy_tensor(x):
28+
if isinstance(x, numba.types.Number):
29+
30+
def number_deepcopy(x):
31+
return x
32+
33+
return number_deepcopy
34+
35+
if isinstance(x, numba.types.Array):
36+
37+
def array_deepcopy(x):
38+
return np.copy(x)
39+
40+
return array_deepcopy
41+
42+
if isinstance(x, numba.types.UnicodeType):
43+
44+
def string_deepcopy(x):
45+
return x
46+
47+
return string_deepcopy
1948

2049

2150
@register_funcify_and_cache_key(OpFromGraph)
@@ -64,19 +93,11 @@ def identity(x):
6493

6594
@register_funcify_default_op_cache_key(DeepCopyOp)
6695
def numba_funcify_DeepCopyOp(op, node, **kwargs):
67-
if isinstance(node.inputs[0].type, TensorType):
68-
69-
@numba_basic.numba_njit
70-
def deepcopy(x):
71-
return np.copy(x)
72-
73-
else:
74-
75-
@numba_basic.numba_njit
76-
def deepcopy(x):
77-
return x
96+
@numba_basic.numba_njit
97+
def deepcopy(x):
98+
return numba_deepcopy(x)
7899

79-
return deepcopy
100+
return deepcopy, 1
80101

81102

82103
@register_funcify_default_op_cache_key(IfElse)

pytensor/link/numba/dispatch/random.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable
2-
from copy import copy, deepcopy
2+
from copy import deepcopy
33
from functools import singledispatch
44
from hashlib import sha256
55
from textwrap import dedent
@@ -20,6 +20,7 @@
2020
numba_funcify,
2121
register_funcify_and_cache_key,
2222
)
23+
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
2324
from pytensor.link.numba.dispatch.vectorize_codegen import (
2425
_jit_options,
2526
_vectorized,
@@ -35,16 +36,16 @@
3536
from pytensor.tensor.utils import _parse_gufunc_signature
3637

3738

38-
@overload(copy)
39-
def copy_NumPyRandomGenerator(rng):
40-
def impl(rng):
41-
# TODO: Open issue on Numba?
42-
with numba.objmode(new_rng=types.npy_rng):
43-
new_rng = deepcopy(rng)
39+
@numba.extending.overload(numba_deepcopy)
40+
def numba_deepcopy_random_generator(x):
41+
if isinstance(x, numba.types.NumPyRandomGeneratorType):
4442

45-
return new_rng
43+
def random_generator_deepcopy(x):
44+
with numba.objmode(new_rng=types.npy_rng):
45+
new_rng = deepcopy(x)
46+
return new_rng
4647

47-
return impl
48+
return random_generator_deepcopy
4849

4950

5051
@singledispatch
@@ -449,7 +450,7 @@ def random(core_shape, rng, size, *dist_params):
449450
def ov_random(core_shape, rng, size, *dist_params):
450451
def impl(core_shape, rng, size, *dist_params):
451452
if not inplace:
452-
rng = copy(rng)
453+
rng = numba_deepcopy(rng)
453454

454455
draws = _vectorized(
455456
core_op_fn,

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
register_funcify_and_cache_key,
1919
register_funcify_default_op_cache_key,
2020
)
21+
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
2122
from pytensor.tensor import TensorType
2223
from pytensor.tensor.rewriting.subtensor import is_full_slice
2324
from pytensor.tensor.subtensor import (
@@ -104,6 +105,18 @@ def in_seq_empty_tuple(x, y):
104105
enable_slice_boxing()
105106

106107

108+
@numba.extending.overload(numba_deepcopy)
109+
def numba_deepcopy_slice(x):
110+
if isinstance(x, types.SliceType):
111+
112+
def deepcopy_slice(x):
113+
return slice(
114+
numba_deepcopy(x.start), numba_deepcopy(x.stop), numba_deepcopy(x.step)
115+
)
116+
117+
return deepcopy_slice
118+
119+
107120
@register_funcify_default_op_cache_key(MakeSlice)
108121
def numba_funcify_MakeSlice(op, **kwargs):
109122
@numba_basic.numba_njit

0 commit comments

Comments
 (0)