@@ -29,13 +29,13 @@ def _shap_explain_cme(cme_model, X, d_t, d_y,
2929 cme_models: function
3030 const_marginal_effect function.
3131 X: (m, d_x) matrix
32- Features for each sample. Should be in the same shape of fitted X in final stage .
32+ Features for each sample. Should have the same shape as the X used during fit.
3333 d_t: tuple of int
3434 Tuple of number of treatment (exclude control in discrete treatment scenario).
3535 d_y: tuple of int
3636 Tuple of number of outcome.
3737 feature_names: optional None or list of strings of length X.shape[1] (Default=None)
38- The names of input features.
38+ The names of raw input features.
3939 treatment_names: optional None or list (Default=None)
4040 The name of treatment. In discrete treatment scenario, the name should not include the name of
4141 the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
@@ -54,8 +54,9 @@ def _shap_explain_cme(cme_model, X, d_t, d_y,
5454 and the shap_values explanation object as value.
5555
5656 """
57- (dt , dy , treatment_names , output_names , feature_names ) = _define_names (d_t , d_y , treatment_names , output_names ,
58- feature_names , input_names )
57+ (dt , dy , treatment_names , output_names , feature_names , _ ) = _define_names (d_t , d_y , treatment_names ,
58+ output_names , feature_names ,
59+ input_names , None )
5960 # define masker by using entire dataset, otherwise Explainer will only sample 100 obs by default.
6061 bg_samples = X .shape [0 ] if background_samples is None else min (background_samples , X .shape [0 ])
6162 background = shap .maskers .Independent (X , max_samples = bg_samples )
@@ -108,7 +109,7 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, fe
108109 featurizer: optional None or instance of featurizer
109110 Fitted Featurizer of feature X.
110111 feature_names: optional None or list of strings of length X.shape[1] (Default=None)
111- The names of input features.
112+ The names of raw input features.
112113 treatment_names: optional None or list (Default=None)
113114 The name of treatment. In discrete treatment scenario, the name should not include the name of
114115 the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
@@ -129,8 +130,12 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, fe
129130 d_t_ , d_y_ = d_t , d_y
130131 feature_names_ , treatment_names_ = feature_names , treatment_names ,
131132 output_names_ , input_names_ = output_names , input_names
132- (dt , dy , treatment_names , output_names , feature_names ) = _define_names (d_t , d_y , treatment_names , output_names ,
133- feature_names , input_names )
133+ (dt , dy , treatment_names , output_names , feature_names , transformed_feature_names ) = _define_names (d_t , d_y ,
134+ treatment_names ,
135+ output_names ,
136+ feature_names ,
137+ input_names ,
138+ featurizer )
134139 if featurizer is not None :
135140 F = featurizer .transform (X )
136141 else :
@@ -146,11 +151,11 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, fe
146151 for i in range (dt ):
147152 try :
148153 explainer = shap .Explainer (models [i ], background ,
149- feature_names = feature_names )
154+ feature_names = transformed_feature_names )
150155 except Exception as e :
151156 print ("Final model can't be parsed, explain const_marginal_effect() instead!" , repr (e ))
152157 return _shap_explain_cme (cme_model , X , d_t_ , d_y_ ,
153- feature_names = None ,
158+ feature_names = feature_names_ ,
154159 treatment_names = treatment_names_ ,
155160 output_names = output_names_ ,
156161 input_names = input_names_ ,
@@ -183,16 +188,17 @@ def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_int
183188 model_final: a single estimator
184189 the model's final stage model.
185190 X: matrix
186- Featurized X
191+ Featurized X.
187192 d_t: tuple of int
188193 Tuple of number of treatment (exclude control in discrete treatment scenario).
189194 d_y: tuple of int
190195 Tuple of number of outcome.
191196 fit_cate_intercept: bool
192197 Whether the first entry of the coefficient of the joint linear model associated with
193198 each treatment, is an intercept.
194- feature_names: optional None or list of strings of length X.shape[1] (Default=None)
195- The names of input features.
199+ feature_names: optional None or list of strings of length X.shape[1] or X.shape[1]-1 (Default=None)
200+ The name of featurized X (exclude intercept). Length is X.shape[1] if fit_cate_intercept=False, otherwise
201+ length is X.shape[1]-1.
196202 treatment_names: optional None or list (Default=None)
197203 The name of treatment. In discrete treatment scenario, the name should not include the name of
198204 the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
@@ -210,8 +216,11 @@ def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_int
210216 each treatment name (e.g. "T0" when `treatment_names=None`) as key
211217 and the shap_values explanation object as value.
212218 """
213- (d_t , d_y , treatment_names , output_names , feature_names ) = _define_names (d_t , d_y , treatment_names , output_names ,
214- feature_names , input_names )
219+ # input feature name is already updated by cate_feature_names.
220+ (d_t , d_y , treatment_names , output_names , _ , _ ) = _define_names (d_t , d_y , treatment_names ,
221+ output_names ,
222+ feature_names ,
223+ input_names , None )
215224 X , T = broadcast_unit_treatments (X , d_t )
216225 X = cross_product (X , T )
217226 d_x = X .shape [1 ]
@@ -226,7 +235,7 @@ def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_int
226235 # define masker by using entire dataset, otherwise Explainer will only sample 100 obs by default.
227236 bg_samples = X_sub .shape [0 ] if background_samples is None else min (background_samples , X_sub .shape [0 ])
228237 background = shap .maskers .Independent (X_sub , max_samples = bg_samples )
229- explainer = shap .Explainer (model_final , background )
238+ explainer = shap .Explainer (model_final , background , feature_names = feature_names )
230239 shap_out = explainer (X_sub )
231240
232241 data = shap_out .data [:, ind_x [i ]]
@@ -236,14 +245,14 @@ def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_int
236245 main_effects = None if shap_out .main_effects is None else shap_out .main_effects [..., ind_x [i ], j ]
237246 values = shap_out .values [..., ind_x [i ], j ]
238247 shap_out_new = shap .Explanation (values , base_values = base_values , data = data , main_effects = main_effects ,
239- feature_names = feature_names )
248+ feature_names = shap_out . feature_names )
240249 shap_outs [output_names [j ]][treatment_names [i ]] = shap_out_new
241250 else :
242251 values = shap_out .values [..., ind_x [i ]]
243252 main_effects = shap_out .main_effects [..., ind_x [i ], 0 ]
244253 shap_out_new = shap .Explanation (values , base_values = shap_out .base_values , data = data ,
245254 main_effects = main_effects ,
246- feature_names = feature_names )
255+ feature_names = shap_out . feature_names )
247256 shap_outs [output_names [0 ]][treatment_names [i ]] = shap_out_new
248257
249258 return shap_outs
@@ -274,7 +283,7 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
274283 featurizer: optional None or instance of featurizer
275284 Fitted Featurizer of feature X.
276285 feature_names: optional None or list of strings of length X.shape[1] (Default=None)
277- The names of input features.
286+ The names of raw input features.
278287 treatment_names: optional None or list (Default=None)
279288 The name of treatment. In discrete treatment scenario, the name should not include the name of
280289 the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
@@ -295,8 +304,12 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
295304 d_t_ , d_y_ = d_t , d_y
296305 feature_names_ , treatment_names_ = feature_names , treatment_names ,
297306 output_names_ , input_names_ = output_names , input_names
298- (dt , dy , treatment_names , output_names , feature_names ) = _define_names (d_t , d_y , treatment_names , output_names ,
299- feature_names , input_names )
307+ (dt , dy , treatment_names , output_names , feature_names , transformed_feature_names ) = _define_names (d_t , d_y ,
308+ treatment_names ,
309+ output_names ,
310+ feature_names ,
311+ input_names ,
312+ featurizer )
300313 if featurizer is not None :
301314 F = featurizer .transform (X )
302315 else :
@@ -311,11 +324,11 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
311324 for j in range (dy ):
312325 try :
313326 explainer = shap .Explainer (multitask_model_cate [j ], background ,
314- feature_names = feature_names )
327+ feature_names = transformed_feature_names )
315328 except Exception as e :
316329 print ("Final model can't be parsed, explain const_marginal_effect() instead!" , repr (e ))
317330 return _shap_explain_cme (cme_model , X , d_t_ , d_y_ ,
318- feature_names = None ,
331+ feature_names = feature_names_ ,
319332 treatment_names = treatment_names_ ,
320333 output_names = output_names_ ,
321334 input_names = input_names_ ,
@@ -336,7 +349,7 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
336349 return shap_outs
337350
338351
339- def _define_names (d_t , d_y , treatment_names , output_names , feature_names , input_names ):
352+ def _define_names (d_t , d_y , treatment_names , output_names , feature_names , input_names , featurizer ):
340353 """
341354 Helper function to get treatment and output names
342355
@@ -355,28 +368,33 @@ def _define_names(d_t, d_y, treatment_names, output_names, feature_names, input_
355368 The user provided names of the features
356369 input_names: dictionary
357370 The names of the features, outputs and treatments parsed from the fit input at fit time.
371+ featurizer: optional None or instance of featurizer
372+ Fitted Featurizer of feature X.
358373
359374 Returns
360375 -------
361376 d_t: int
362377 d_y: int
363378 treatment_names: List
364379 output_names: List
365- feature_names: List or None
380+ feature_names: List
381+ transformed_feature_names: List or None
366382 """
367383
368384 d_t = d_t [0 ] if d_t else 1
369385 d_y = d_y [0 ] if d_y else 1
386+
370387 if treatment_names is None :
371- if (input_names is None ) or (input_names ['treatment_names' ] is None ):
372- treatment_names = [f"T{ i } " for i in range (d_t )]
373- else :
374- treatment_names = input_names ['treatment_names' ]
388+ treatment_names = input_names ['treatment_names' ]
375389 if output_names is None :
376- if (input_names is None ) or (input_names ['output_names' ] is None ):
377- output_names = [f"Y{ i } " for i in range (d_y )]
378- else :
379- output_names = input_names ['output_names' ]
380- if (feature_names is None ) and (input_names is not None ):
390+ output_names = input_names ['output_names' ]
391+ if feature_names is None :
381392 feature_names = input_names ['feature_names' ]
382- return (d_t , d_y , treatment_names , output_names , feature_names )
393+ if featurizer is None :
394+ transformed_feature_names = feature_names
395+ elif featurizer is not None and hasattr (featurizer , 'get_feature_names' ):
396+ transformed_feature_names = featurizer .get_feature_names (feature_names )
397+ else :
398+ transformed_feature_names = None
399+
400+ return (d_t , d_y , treatment_names , output_names , feature_names , transformed_feature_names )
0 commit comments