@@ -710,8 +710,7 @@ def test_invalid_scenarios():
710710 # Giving a list, tuple, or Series when a matrix of data is expected should always raise
711711 with pytest .raises (
712712 ValueError ,
713- match = "Scenario data for variable 'a' has the wrong number of columns. "
714- "Expected 2, got 1" ,
713+ match = "Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1" ,
715714 ):
716715 for data_type in [list , tuple , pd .Series ]:
717716 ss_mod ._validate_scenario_data (data_type (np .zeros (10 )))
@@ -720,15 +719,14 @@ def test_invalid_scenarios():
720719 # Providing irrevelant data raises
721720 with pytest .raises (
722721 ValueError ,
723- match = "Scenario data provided for variable 'jk lol', which is not an exogenous " " variable" ,
722+ match = "Scenario data provided for variable 'jk lol', which is not an exogenous variable" ,
724723 ):
725724 ss_mod ._validate_scenario_data ({"jk lol" : np .zeros (10 )})
726725
727726 # Incorrect 2nd dimension of a non-dataframe
728727 with pytest .raises (
729728 ValueError ,
730- match = "Scenario data for variable 'a' has the wrong number of columns. Expected "
731- "2, got 1" ,
729+ match = "Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1" ,
732730 ):
733731 scenario = np .zeros (10 ).tolist ()
734732 ss_mod ._validate_scenario_data (scenario )
@@ -870,3 +868,13 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
870868 regression_effect_expected = (betas * scenario_xr ).sum (dim = ["state" ])
871869
872870 assert_allclose (regression_effect , regression_effect_expected )
871+
872+
873+ @pytest .mark .parametrize ("batch_size" , [(10 ,), (10 , 3 , 5 )])
874+ def test_insert_batched_rvs (ss_mod , batch_size ):
875+ with pm .Model ():
876+ rho = pm .Normal ("rho" , shape = batch_size )
877+ zeta = pm .Normal ("zeta" , shape = batch_size )
878+ ss_mod ._insert_random_variables ()
879+ matrices = ss_mod .unpack_statespace ()
880+ assert matrices [4 ].type .shape == (* batch_size , 2 , 2 )
0 commit comments