@@ -708,3 +708,47 @@ def normal_shifted(mu, size):
708708 observed_logp .eval ({latent_vv : latent_vv_test , observed_vv : observed_vv_test }),
709709 expected_logp ,
710710 )
711+
712+ def test_explicit_rng (self ):
713+ def custom_dist (mu , size ):
714+ return Normal .dist (mu , size = size )
715+
716+ x = CustomDist .dist (0 , dist = custom_dist )
717+ assert len (x .owner .op .rng_params (x .owner )) == 1 # Rng created by default
718+
719+ explicit_rng = pt .random .type .random_generator_type ("rng" )
720+ x_explicit = CustomDist .dist (0 , dist = custom_dist , rng = explicit_rng )
721+ [used_rng ] = x_explicit .owner .op .rng_params (x_explicit .owner )
722+ assert used_rng is explicit_rng
723+
724+ # API for passing multiple explicit RNGs not supported
725+ def custom_dist_multi_rng (mu , size ):
726+ return Normal .dist (mu , size = size ) + Normal .dist (0 , size = size )
727+
728+ x = CustomDist .dist (0 , dist = custom_dist_multi_rng )
729+ assert len (x .owner .op .rng_params (x .owner )) == 2
730+
731+ with pytest .raises (
732+ ValueError ,
733+ match = "CustomDist received an explicit rng but it actually requires 2 rngs" ,
734+ ):
735+ CustomDist .dist (
736+ 0 ,
737+ dist = custom_dist_multi_rng ,
738+ rng = explicit_rng ,
739+ )
740+
741+ # But it can be done if the custom_dist uses only one RNG internally
742+ def custom_dist_multi_rng_fixed (mu , size ):
743+ next_rng , x = Normal .dist (mu , size = size ).owner .outputs
744+ return x + Normal .dist (0 , size = size , rng = next_rng )
745+
746+ x = CustomDist .dist (0 , dist = custom_dist_multi_rng_fixed )
747+ assert len (x .owner .op .rng_params (x .owner )) == 1
748+ x_explicit = CustomDist .dist (
749+ 0 ,
750+ dist = custom_dist_multi_rng_fixed ,
751+ rng = explicit_rng ,
752+ )
753+ [used_rng ] = x_explicit .owner .op .rng_params (x_explicit .owner )
754+ assert used_rng is explicit_rng
0 commit comments