Skip to content

Commit c2fe0ae

Browse files
committed
Allow explicit rng for CustomDist that only require one
1 parent 41f33ad commit c2fe0ae

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

pymc/distributions/custom.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def rv_op(
259259
size=None,
260260
signature: str,
261261
class_name: str,
262+
rng=None,
262263
):
263264
size = normalize_size_param(size)
264265
# If it's NoneConst, just use that as the dummy
@@ -270,7 +271,8 @@ def rv_op(
270271
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
271272
dummy_params = [dummy_size_param, *dummy_dist_params]
272273
# RNGs are not passed as explicit inputs (because we usually don't know how many are needed)
273-
# We retrieve them here. This will also raise if the user forgot to specify some update in a Scan Op
274+
# We retrieve them here. This will also raise if the user forgot to specify some update in an InnerGraphOp (e.g., Scan)
275+
# If the user passed an explicit rng we will respect that later when we instantiate the final rv_op
274276
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
275277

276278
rv_type = type(
@@ -357,6 +359,14 @@ def change_custom_dist_size(op, rv, new_size, expand):
357359
outputs=outputs,
358360
extended_signature=extended_signature,
359361
)
362+
if rng is not None:
363+
# User passed an RNG, use that if the graph only required one, raise otherwise
364+
if len(rngs) != 1:
365+
raise ValueError(
366+
f"CustomDist received an explicit rng but it actually requires {len(rngs)} rngs."
367+
" Please modify your dist function to only use one rng, or don't pass an explicitly rng."
368+
)
369+
rngs = (rng,)
360370
return rv_op(size, *dist_params, *rngs)
361371

362372
@staticmethod

tests/distributions/test_custom.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,3 +708,47 @@ def normal_shifted(mu, size):
708708
observed_logp.eval({latent_vv: latent_vv_test, observed_vv: observed_vv_test}),
709709
expected_logp,
710710
)
711+
712+
def test_explicit_rng(self):
713+
def custom_dist(mu, size):
714+
return Normal.dist(mu, size=size)
715+
716+
x = CustomDist.dist(0, dist=custom_dist)
717+
assert len(x.owner.op.rng_params(x.owner)) == 1 # Rng created by default
718+
719+
explicit_rng = pt.random.type.random_generator_type("rng")
720+
x_explicit = CustomDist.dist(0, dist=custom_dist, rng=explicit_rng)
721+
[used_rng] = x_explicit.owner.op.rng_params(x_explicit.owner)
722+
assert used_rng is explicit_rng
723+
724+
# API for passing multiple explicit RNGs not supported
725+
def custom_dist_multi_rng(mu, size):
726+
return Normal.dist(mu, size=size) + Normal.dist(0, size=size)
727+
728+
x = CustomDist.dist(0, dist=custom_dist_multi_rng)
729+
assert len(x.owner.op.rng_params(x.owner)) == 2
730+
731+
with pytest.raises(
732+
ValueError,
733+
match="CustomDist received an explicit rng but it actually requires 2 rngs",
734+
):
735+
CustomDist.dist(
736+
0,
737+
dist=custom_dist_multi_rng,
738+
rng=explicit_rng,
739+
)
740+
741+
# But it can be done if the custom_dist uses only one RNG internally
742+
def custom_dist_multi_rng_fixed(mu, size):
743+
next_rng, x = Normal.dist(mu, size=size).owner.outputs
744+
return x + Normal.dist(0, size=size, rng=next_rng)
745+
746+
x = CustomDist.dist(0, dist=custom_dist_multi_rng_fixed)
747+
assert len(x.owner.op.rng_params(x.owner)) == 1
748+
x_explicit = CustomDist.dist(
749+
0,
750+
dist=custom_dist_multi_rng_fixed,
751+
rng=explicit_rng,
752+
)
753+
[used_rng] = x_explicit.owner.op.rng_params(x_explicit.owner)
754+
assert used_rng is explicit_rng

0 commit comments

Comments
 (0)