@@ -254,13 +254,13 @@ def identity(x):
254254 )
255255
256256 new_x = fake_X [:, var ]
257- p_d = np .array (y_pred )
257+ p_d = func ( np .array (y_pred ) )
258258
259259 for s_i in range (shape ):
260260 if centered :
261- p_di = func ( p_d [:, :, s_i ]) - func ( p_d [:, :, s_i ][:, 0 ][:, None ])
261+ p_di = p_d [:, :, s_i ] - p_d [:, :, s_i ][:, 0 ][:, None ]
262262 else :
263- p_di = func ( p_d [:, :, s_i ])
263+ p_di = p_d [:, :, s_i ]
264264 if var in var_discrete :
265265 axes [count ].plot (new_x , p_di .mean (0 ), "o" , color = color_mean )
266266 axes [count ].plot (new_x , p_di .T , "." , color = color , alpha = alpha )
@@ -393,14 +393,17 @@ def identity(x):
393393 for var in range (len (var_idx )):
394394 excluded = indices [:]
395395 excluded .remove (var )
396- p_d = _sample_posterior (
397- all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
396+ p_d = func (
397+ _sample_posterior (
398+ all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
399+ )
398400 )
401+
399402 with warnings .catch_warnings ():
400403 warnings .filterwarnings ("ignore" , message = "hdi currently interprets 2d data" )
401404 new_x = fake_X [:, var ]
402405 for s_i in range (shape ):
403- p_di = func ( p_d [:, :, s_i ])
406+ p_di = p_d [:, :, s_i ]
404407 null_pd .append (p_di .mean ())
405408 if var in var_discrete :
406409 _ , idx_uni = np .unique (new_x , return_index = True )
@@ -1125,8 +1128,11 @@ def plot_scatter_submodels(
11251128 plot_kwargs : dict
11261129 Additional keyword arguments for the plot. Defaults to None.
11271130 Valid keys are:
1128- - color_ref : matplotlib valid color for the 45 degree line
1131+ - marker_scatter : matplotlib valid marker for the scatter plot
11291132 - color_scatter: matplotlib valid color for the scatter plot
1133+ - alpha_scatter: matplotlib valid alpha for the scatter plot
1134+ - color_ref: matplotlib valid color for the 45 degree line
1135+ - ls_ref: matplotlib valid linestyle for the reference line
11301136 axes : axes
11311137 Matplotlib axes.
11321138
@@ -1140,41 +1146,69 @@ def plot_scatter_submodels(
11401146 submodels = np .sort (submodels )
11411147
11421148 indices = vi_results ["indices" ][submodels ]
1143- preds = vi_results ["preds" ][submodels ]
1149+ preds_sub = vi_results ["preds" ][submodels ]
11441150 preds_all = vi_results ["preds_all" ]
11451151
1152+ if labels is None :
1153+ labels = vi_results ["labels" ][submodels ]
1154+
1155+ # handle categorical regression case:
1156+ n_cats = None
1157+ if preds_all .ndim > 2 :
1158+ n_cats = preds_all .shape [- 1 ]
1159+ indices = np .tile (indices , n_cats )
1160+
11461161 if ax is None :
11471162 _ , ax = _get_axes (grid , len (indices ), True , True , figsize )
11481163
11491164 if plot_kwargs is None :
11501165 plot_kwargs = {}
11511166
1152- if labels is None :
1153- labels = vi_results ["labels" ][submodels ]
1154-
11551167 if func is not None :
1156- preds = func (preds )
1168+ preds_sub = func (preds_sub )
11571169 preds_all = func (preds_all )
11581170
1159- min_ = min (np .min (preds ), np .min (preds_all ))
1160- max_ = max (np .max (preds ), np .max (preds_all ))
1161-
1162- for pred , x_label , axi in zip (preds , labels , ax .ravel ()):
1163- axi .plot (
1164- pred ,
1165- preds_all ,
1166- marker = plot_kwargs .get ("marker_scatter" , "." ),
1167- ls = "" ,
1168- color = plot_kwargs .get ("color_scatter" , "C0" ),
1169- alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1170- )
1171- axi .set_xlabel (x_label )
1172- axi .axline (
1173- [min_ , min_ ],
1174- [max_ , max_ ],
1175- color = plot_kwargs .get ("color_ref" , "0.5" ),
1176- ls = plot_kwargs .get ("ls_ref" , "--" ),
1177- )
1171+ min_ = min (np .min (preds_sub ), np .min (preds_all ))
1172+ max_ = max (np .max (preds_sub ), np .max (preds_all ))
1173+
1174+ # handle categorical regression case:
1175+ if n_cats is not None :
1176+ i = 0
1177+ for cat in range (n_cats ):
1178+ for pred_sub , x_label in zip (preds_sub , labels ):
1179+ ax [i ].plot (
1180+ pred_sub [..., cat ],
1181+ preds_all [..., cat ],
1182+ marker = plot_kwargs .get ("marker_scatter" , "." ),
1183+ ls = "" ,
1184+ color = plot_kwargs .get ("color_scatter" , f"C{ cat } " ),
1185+ alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1186+ )
1187+ ax [i ].set (xlabel = x_label , ylabel = "ref model" , title = f"Category { cat } " )
1188+ ax [i ].axline (
1189+ [min_ , min_ ],
1190+ [max_ , max_ ],
1191+ color = plot_kwargs .get ("color_ref" , "0.5" ),
1192+ ls = plot_kwargs .get ("ls_ref" , "--" ),
1193+ )
1194+ i += 1
1195+ else :
1196+ for pred_sub , x_label , axi in zip (preds_sub , labels , ax .ravel ()):
1197+ axi .plot (
1198+ pred_sub ,
1199+ preds_all ,
1200+ marker = plot_kwargs .get ("marker_scatter" , "." ),
1201+ ls = "" ,
1202+ color = plot_kwargs .get ("color_scatter" , "C0" ),
1203+ alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1204+ )
1205+ axi .set (xlabel = x_label , ylabel = "ref model" )
1206+ axi .axline (
1207+ [min_ , min_ ],
1208+ [max_ , max_ ],
1209+ color = plot_kwargs .get ("color_ref" , "0.5" ),
1210+ ls = plot_kwargs .get ("ls_ref" , "--" ),
1211+ )
11781212 return ax
11791213
11801214
0 commit comments