Skip to content

Commit baa0739

Browse files
emekaokoli19ricardoV94
authored andcommitted
Faster RNG deepcopy
1 parent 370b172 commit baa0739

File tree

4 files changed

+44
-4
lines changed

4 files changed

+44
-4
lines changed

pytensor/link/numba/dispatch/random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections.abc import Callable
2-
from copy import deepcopy
32
from functools import singledispatch
43
from hashlib import sha256
54
from textwrap import dedent
@@ -32,6 +31,7 @@
3231
)
3332
from pytensor.tensor import get_vector_length
3433
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
34+
from pytensor.tensor.random.utils import custom_rng_deepcopy
3535
from pytensor.tensor.type_other import NoneTypeT
3636
from pytensor.tensor.utils import _parse_gufunc_signature
3737

@@ -42,7 +42,7 @@ def numba_deepcopy_random_generator(x):
4242

4343
def random_generator_deepcopy(x):
4444
with numba.objmode(new_rng=types.npy_rng):
45-
new_rng = deepcopy(x)
45+
new_rng = custom_rng_deepcopy(x)
4646
return new_rng
4747

4848
return random_generator_deepcopy

pytensor/tensor/random/op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import abc
22
import warnings
33
from collections.abc import Sequence
4-
from copy import deepcopy
54
from typing import Any, cast
65

76
import numpy as np
@@ -23,6 +22,7 @@
2322
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
2423
from pytensor.tensor.random.utils import (
2524
compute_batch_shape,
25+
custom_rng_deepcopy,
2626
explicit_expand_dims,
2727
normalize_size_param,
2828
)
@@ -423,7 +423,7 @@ def perform(self, node, inputs, outputs):
423423

424424
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
425425
if not self.inplace:
426-
rng = deepcopy(rng)
426+
rng = custom_rng_deepcopy(rng)
427427

428428
outputs[0][0] = rng
429429
outputs[1][0] = np.asarray(

pytensor/tensor/random/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from collections.abc import Callable, Sequence
2+
from copy import deepcopy
23
from functools import wraps
34
from itertools import zip_longest
45
from types import ModuleType
56
from typing import TYPE_CHECKING
67

78
import numpy as np
9+
from numpy.random import Generator
810

911
from pytensor.compile.sharedvalue import shared
1012
from pytensor.graph.basic import Variable
@@ -204,6 +206,16 @@ def normalize_size_param(
204206
return shape
205207

206208

209+
def custom_rng_deepcopy(rng):
210+
# This helper exists because copying numpy.random.Generator via deepcopy is slow.
211+
# NumPy may implement a faster clone/copy API in the future:
212+
# https://github.com/numpy/numpy/issues/24086
213+
old_bitgen = rng.bit_generator
214+
new_bitgen = type(old_bitgen)(deepcopy(old_bitgen._seed_seq))
215+
new_bitgen.state = old_bitgen.state
216+
return Generator(new_bitgen)
217+
218+
207219
class RandomStream:
208220
"""Module component with similar interface to `numpy.random.Generator`.
209221

tests/tensor/random/test_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from copy import deepcopy
2+
13
import numpy as np
24
import pytest
35

@@ -7,6 +9,7 @@
79
from pytensor.tensor.random.utils import (
810
RandomStream,
911
broadcast_params,
12+
custom_rng_deepcopy,
1013
normalize_size_param,
1114
supp_shape_from_ref_param_shape,
1215
)
@@ -348,3 +351,28 @@ def test_normalize_size_param():
348351

349352
sym_tensor_size = tensor(shape=(3,), dtype="int64")
350353
assert normalize_size_param(sym_tensor_size) is sym_tensor_size
354+
355+
356+
def test_custom_rng_deepcopy_matches_deepcopy():
357+
rng = np.random.default_rng(123)
358+
359+
dp = deepcopy(rng).bit_generator
360+
fc = custom_rng_deepcopy(rng).bit_generator
361+
362+
# Same state
363+
assert dp.state == fc.state
364+
# Same seed sequence
365+
assert dp.seed_seq.state == fc.seed_seq.state
366+
367+
368+
def test_custom_rng_deepcopy_output_identical():
369+
rng = np.random.default_rng(123)
370+
371+
rng1 = deepcopy(rng)
372+
rng2 = custom_rng_deepcopy(rng)
373+
374+
# Generate numbers from each
375+
x1 = rng1.normal(size=10)
376+
x2 = rng2.normal(size=10)
377+
378+
assert np.allclose(x1, x2)

0 commit comments

Comments
 (0)