Commit 5e31584

A couple of remaining small fixes (#419)

* copy the attributes for inference results
* pass the raw feature names through when parsing const_marginal_effect
* fix the broadcasting issue in the DRLearner score function
* fix the final-model CATE prediction shape when multitask_model_final is True for DRLearner, and add a test
1 parent: 05fb595

8 files changed: +139 −50 lines

econml/_shap.py

Lines changed: 52 additions & 34 deletions
@@ -29,13 +29,13 @@ def _shap_explain_cme(cme_model, X, d_t, d_y,
     cme_models: function
         const_marginal_effect function.
     X: (m, d_x) matrix
-        Features for each sample. Should be in the same shape of fitted X in final stage.
+        Features for each sample. Should be in the same shape of X during fit.
     d_t: tuple of int
         Tuple of number of treatment (exclude control in discrete treatment scenario).
     d_y: tuple of int
         Tuple of number of outcome.
     feature_names: optional None or list of strings of length X.shape[1] (Default=None)
-        The names of input features.
+        The names of raw input features.
     treatment_names: optional None or list (Default=None)
         The name of treatment. In discrete treatment scenario, the name should not include the name of
         the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
@@ -54,8 +54,9 @@ def _shap_explain_cme(cme_model, X, d_t, d_y,
         and the shap_values explanation object as value.

     """
-    (dt, dy, treatment_names, output_names, feature_names) = _define_names(d_t, d_y, treatment_names, output_names,
-                                                                           feature_names, input_names)
+    (dt, dy, treatment_names, output_names, feature_names, _) = _define_names(d_t, d_y, treatment_names,
+                                                                              output_names, feature_names,
+                                                                              input_names, None)
     # define masker by using entire dataset, otherwise Explainer will only sample 100 obs by default.
     bg_samples = X.shape[0] if background_samples is None else min(background_samples, X.shape[0])
     background = shap.maskers.Independent(X, max_samples=bg_samples)
@@ -108,7 +109,7 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, fe
     featurizer: optional None or instance of featurizer
         Fitted Featurizer of feature X.
     feature_names: optional None or list of strings of length X.shape[1] (Default=None)
-        The names of input features.
+        The names of raw input features.
     treatment_names: optional None or list (Default=None)
         The name of treatment. In discrete treatment scenario, the name should not include the name of
         the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
@@ -129,8 +130,12 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, fe
     d_t_, d_y_ = d_t, d_y
     feature_names_, treatment_names_ = feature_names, treatment_names,
     output_names_, input_names_ = output_names, input_names
-    (dt, dy, treatment_names, output_names, feature_names) = _define_names(d_t, d_y, treatment_names, output_names,
-                                                                           feature_names, input_names)
+    (dt, dy, treatment_names, output_names, feature_names, transformed_feature_names) = _define_names(d_t, d_y,
+                                                                                                      treatment_names,
+                                                                                                      output_names,
+                                                                                                      feature_names,
+                                                                                                      input_names,
+                                                                                                      featurizer)
     if featurizer is not None:
         F = featurizer.transform(X)
     else:
@@ -146,11 +151,11 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, fe
     for i in range(dt):
         try:
             explainer = shap.Explainer(models[i], background,
-                                       feature_names=feature_names)
+                                       feature_names=transformed_feature_names)
         except Exception as e:
             print("Final model can't be parsed, explain const_marginal_effect() instead!", repr(e))
             return _shap_explain_cme(cme_model, X, d_t_, d_y_,
-                                     feature_names=None,
+                                     feature_names=feature_names_,
                                      treatment_names=treatment_names_,
                                      output_names=output_names_,
                                      input_names=input_names_,
@@ -183,16 +188,17 @@ def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_int
     model_final: a single estimator
         the model's final stage model.
     X: matrix
-        Featurized X
+        Featurized X.
     d_t: tuple of int
         Tuple of number of treatment (exclude control in discrete treatment scenario).
     d_y: tuple of int
         Tuple of number of outcome.
     fit_cate_intercept: bool
         Whether the first entry of the coefficient of the joint linear model associated with
         each treatment, is an intercept.
-    feature_names: optional None or list of strings of length X.shape[1] (Default=None)
-        The names of input features.
+    feature_names: optional None or list of strings of length X.shape[1] or X.shape[1]-1 (Default=None)
+        The name of featurized X (exclude intercept). Length is X.shape[1] if fit_cate_intercpet=False, otherwise
+        length is X.shape[1]-1.
     treatment_names: optional None or list (Default=None)
         The name of treatment. In discrete treatment scenario, the name should not include the name of
         the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
@@ -210,8 +216,11 @@ def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_int
         each treatment name (e.g. "T0" when `treatment_names=None`) as key
         and the shap_values explanation object as value.
     """
-    (d_t, d_y, treatment_names, output_names, feature_names) = _define_names(d_t, d_y, treatment_names, output_names,
-                                                                             feature_names, input_names)
+    # input feature name is already updated by cate_feature_names.
+    (d_t, d_y, treatment_names, output_names, _, _) = _define_names(d_t, d_y, treatment_names,
+                                                                    output_names,
+                                                                    feature_names,
+                                                                    input_names, None)
     X, T = broadcast_unit_treatments(X, d_t)
     X = cross_product(X, T)
     d_x = X.shape[1]
@@ -226,7 +235,7 @@ def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_int
     # define masker by using entire dataset, otherwise Explainer will only sample 100 obs by default.
     bg_samples = X_sub.shape[0] if background_samples is None else min(background_samples, X_sub.shape[0])
     background = shap.maskers.Independent(X_sub, max_samples=bg_samples)
-    explainer = shap.Explainer(model_final, background)
+    explainer = shap.Explainer(model_final, background, feature_names=feature_names)
     shap_out = explainer(X_sub)

     data = shap_out.data[:, ind_x[i]]
@@ -236,14 +245,14 @@ def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_int
             main_effects = None if shap_out.main_effects is None else shap_out.main_effects[..., ind_x[i], j]
             values = shap_out.values[..., ind_x[i], j]
             shap_out_new = shap.Explanation(values, base_values=base_values, data=data, main_effects=main_effects,
-                                            feature_names=feature_names)
+                                            feature_names=shap_out.feature_names)
             shap_outs[output_names[j]][treatment_names[i]] = shap_out_new
     else:
         values = shap_out.values[..., ind_x[i]]
         main_effects = shap_out.main_effects[..., ind_x[i], 0]
         shap_out_new = shap.Explanation(values, base_values=shap_out.base_values, data=data,
                                         main_effects=main_effects,
-                                        feature_names=feature_names)
+                                        feature_names=shap_out.feature_names)
         shap_outs[output_names[0]][treatment_names[i]] = shap_out_new

     return shap_outs
@@ -274,7 +283,7 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
     featurizer: optional None or instance of featurizer
         Fitted Featurizer of feature X.
     feature_names: optional None or list of strings of length X.shape[1] (Default=None)
-        The names of input features.
+        The names of raw input features.
     treatment_names: optional None or list (Default=None)
         The name of treatment. In discrete treatment scenario, the name should not include the name of
         the baseline treatment (i.e. the control treatment, which by default is the alphabetically smaller)
@@ -295,8 +304,12 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
     d_t_, d_y_ = d_t, d_y
     feature_names_, treatment_names_ = feature_names, treatment_names,
     output_names_, input_names_ = output_names, input_names
-    (dt, dy, treatment_names, output_names, feature_names) = _define_names(d_t, d_y, treatment_names, output_names,
-                                                                           feature_names, input_names)
+    (dt, dy, treatment_names, output_names, feature_names, transformed_feature_names) = _define_names(d_t, d_y,
+                                                                                                      treatment_names,
+                                                                                                      output_names,
+                                                                                                      feature_names,
+                                                                                                      input_names,
+                                                                                                      featurizer)
     if featurizer is not None:
         F = featurizer.transform(X)
     else:
@@ -311,11 +324,11 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
     for j in range(dy):
         try:
             explainer = shap.Explainer(multitask_model_cate[j], background,
-                                       feature_names=feature_names)
+                                       feature_names=transformed_feature_names)
         except Exception as e:
             print("Final model can't be parsed, explain const_marginal_effect() instead!", repr(e))
             return _shap_explain_cme(cme_model, X, d_t_, d_y_,
-                                     feature_names=None,
+                                     feature_names=feature_names_,
                                      treatment_names=treatment_names_,
                                      output_names=output_names_,
                                      input_names=input_names_,
@@ -336,7 +349,7 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
     return shap_outs


-def _define_names(d_t, d_y, treatment_names, output_names, feature_names, input_names):
+def _define_names(d_t, d_y, treatment_names, output_names, feature_names, input_names, featurizer):
     """
     Helper function to get treatment and output names

@@ -355,28 +368,33 @@ def _define_names(d_t, d_y, treatment_names, output_names, feature_names, input_
         The user provided names of the features
     input_names: dicitionary
         The names of the features, outputs and treatments parsed from the fit input at fit time.
+    featurizer: optional None or instance of featurizer
+        Fitted Featurizer of feature X.

     Returns
     -------
     d_t: int
     d_y: int
     treament_names: List
     output_names: List
-    feature_names: List or None
+    feature_names: List
+    transformed_feature_names: List or None
     """

     d_t = d_t[0] if d_t else 1
     d_y = d_y[0] if d_y else 1
+
     if treatment_names is None:
-        if (input_names is None) or (input_names['treatment_names'] is None):
-            treatment_names = [f"T{i}" for i in range(d_t)]
-        else:
-            treatment_names = input_names['treatment_names']
+        treatment_names = input_names['treatment_names']
     if output_names is None:
-        if (input_names is None) or (input_names['output_names'] is None):
-            output_names = [f"Y{i}" for i in range(d_y)]
-        else:
-            output_names = input_names['output_names']
-    if (feature_names is None) and (input_names is not None):
+        output_names = input_names['output_names']
+    if feature_names is None:
         feature_names = input_names['feature_names']
-    return (d_t, d_y, treatment_names, output_names, feature_names)
+    if featurizer is None:
+        transformed_feature_names = feature_names
+    elif featurizer is not None and hasattr(featurizer, 'get_feature_names'):
+        transformed_feature_names = featurizer.get_feature_names(feature_names)
+    else:
+        transformed_feature_names = None
+
+    return (d_t, d_y, treatment_names, output_names, feature_names, transformed_feature_names)
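The effect of the new featurizer argument is easiest to see in isolation. Below is a minimal standalone sketch of the name-resolution branch added to _define_names; it is illustration only, not econml code, and it assumes an older scikit-learn release where featurizers expose get_feature_names (later renamed get_feature_names_out):

import numpy as np
from sklearn.preprocessing import PolynomialFeatures

def resolve_names(feature_names, featurizer):
    # Mirrors the branch added above: raw names pass through untouched, and
    # transformed names are derived from the fitted featurizer when possible.
    if featurizer is None:
        return feature_names, feature_names
    if hasattr(featurizer, 'get_feature_names'):
        return feature_names, featurizer.get_feature_names(feature_names)
    return feature_names, None  # explainer falls back to its default names

featurizer = PolynomialFeatures(degree=2, include_bias=False)
featurizer.fit(np.random.normal(size=(100, 2)))
raw, transformed = resolve_names(['age', 'income'], featurizer)
print(raw)          # ['age', 'income']
print(transformed)  # ['age', 'income', 'age^2', 'age income', 'income^2']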

econml/dml/causal_forest.py

Lines changed: 0 additions & 2 deletions
@@ -555,8 +555,6 @@ def feature_importances(self, max_depth=4, depth_decay_exponent=2.0):
         return imps.reshape(self._d_y + (-1,))

     def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
-        feature_names = self.cate_feature_names(feature_names)
-
         return _shap_explain_multitask_model_cate(self.const_marginal_effect, self.model_cate.estimators_, X,
                                                   self._d_t, self._d_y, featurizer=self.featurizer_,
                                                   feature_names=feature_names,

econml/dml/dml.py

Lines changed: 0 additions & 2 deletions
@@ -1177,8 +1177,6 @@ def refit_final(self, *, inference='auto'):
     refit_final.__doc__ = _OrthoLearner.refit_final.__doc__

     def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
-        feature_names = self.cate_feature_names(feature_names)
-
         return _shap_explain_model_cate(self.const_marginal_effect, self.model_cate, X, self._d_t, self._d_y,
                                         featurizer=self.featurizer_,
                                         feature_names=feature_names,
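This file and causal_forest.py drop the same pre-transformation line: shap_values no longer converts names via cate_feature_names before delegating, because _define_names now applies the featurizer exactly once. A minimal usage sketch, assuming the econml LinearDML API around the time of this commit (data and feature names are made up):

import numpy as np
from econml.dml import LinearDML

X = np.random.normal(size=(200, 2))
T = np.random.binomial(1, 0.5, size=200)
Y = (1 + 2 * X[:, 0]) * T + X[:, 1] + np.random.normal(size=200)

est = LinearDML(discrete_treatment=True)
est.fit(Y, T, X=X)

# Raw feature names now flow straight through; transformation of the names
# happens once, inside the SHAP helpers, alongside featurization of X itself.
shap_vals = est.shap_values(X, feature_names=['age', 'income'])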

econml/dr/_drlearner.py

Lines changed: 16 additions & 9 deletions
@@ -127,6 +127,7 @@ def __init__(self, model_final, featurizer, multitask_model_final):
     def fit(self, Y, T, X=None, W=None, *, nuisances, sample_weight=None, sample_var=None):
         Y_pred, = nuisances
         self.d_y = Y_pred.shape[1:-1]  # track whether there's a Y dimension (must be a singleton)
+        self.d_t = Y_pred.shape[-1] - 1  # track # of treatment (exclude baseline treatment)
         if (X is not None) and (self._featurizer is not None):
             X = self._featurizer.fit_transform(X)
         filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight, sample_var=sample_var)
@@ -145,7 +146,7 @@ def predict(self, X=None):
         if (X is not None) and (self._featurizer is not None):
             X = self._featurizer.transform(X)
         if self._multitask_model_final:
-            pred = self.model_cate.predict(X)
+            pred = self.model_cate.predict(X).reshape((-1, self.d_t))
             if self.d_y:  # need to reintroduce singleton Y dimension
                 return pred[:, np.newaxis, :]
             return pred
@@ -158,13 +159,21 @@ def score(self, Y, T, X=None, W=None, *, nuisances, sample_weight=None, sample_v
             X = self._featurizer.transform(X)
         Y_pred, = nuisances
         if self._multitask_model_final:
-            return np.mean(np.average((Y_pred[..., 1:] - Y_pred[..., [0]] - self.model_cate.predict(X))**2,
-                                      weights=sample_weight, axis=0))
+            Y_pred_diff = Y_pred[..., 1:] - Y_pred[..., [0]]
+            cate_pred = self.model_cate.predict(X).reshape((-1, self.d_t))
+            if self.d_y:
+                cate_pred = cate_pred[:, np.newaxis, :]
+            return np.mean(np.average((Y_pred_diff - cate_pred)**2, weights=sample_weight, axis=0))
+
         else:
-            return np.mean([np.average((Y_pred[..., t] - Y_pred[..., 0] -
-                                        self.models_cate[t - 1].predict(X))**2,
-                                       weights=sample_weight, axis=0)
-                            for t in np.arange(1, Y_pred.shape[-1])])
+            scores = []
+            for t in np.arange(1, Y_pred.shape[-1]):
+                # since we only allow single dimensional y, we could flatten the prediction
+                Y_pred_diff = (Y_pred[..., t] - Y_pred[..., 0]).flatten()
+                cate_pred = self.models_cate[t - 1].predict(X).flatten()
+                score = np.average((Y_pred_diff - cate_pred)**2, weights=sample_weight, axis=0)
+                scores.append(score)
+            return np.mean(scores)


 class DRLearner(_OrthoLearner):
@@ -637,8 +646,6 @@ def fitted_models_final(self):
         return self.ortho_learner_model_final_.models_cate

     def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
-        feature_names = self.cate_feature_names(feature_names)
-
         if self.ortho_learner_model_final_._multitask_model_final:
             return _shap_explain_multitask_model_cate(self.const_marginal_effect, self.multitask_model_cate, X,
                                                       self._d_t, self._d_y,
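The bug fixed in score (and the related shape fix in predict) is a classic NumPy broadcasting pitfall: a multi-output final model returns a 1-D array of shape (n,) when there is only one non-baseline treatment. A self-contained sketch of the failure mode and the reshape that cures it (shapes are illustrative only, not econml code):

import numpy as np

n, d_t = 5, 1                                  # binary treatment: one non-baseline column
Y_pred_diff = np.random.normal(size=(n, d_t))  # doubly robust pseudo-outcome differences
cate_pred = np.random.normal(size=n)           # final model returns shape (n,) when d_t == 1

# Without the fix, (n, d_t) minus (n,) broadcasts to an (n, n) residual matrix,
# silently inflating the score instead of raising an error:
print(((Y_pred_diff - cate_pred) ** 2).shape)                     # (5, 5) -- wrong

# The commit pins the prediction to (n, d_t) before differencing:
print(((Y_pred_diff - cate_pred.reshape((-1, d_t))) ** 2).shape)  # (5, 1) -- correct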

econml/inference/_inference.py

Lines changed: 4 additions & 3 deletions
@@ -559,7 +559,7 @@ def __init__(self, d_t, d_y, pred, inf_type, fname_transformer=None,
         # For effect summaries, d_t is None, but the result arrays behave as if d_t=1
         self._d_t = d_t or 1
         self.d_y = d_y
-        self.pred = pred
+        self.pred = np.copy(pred) if pred is not None and not np.isscalar(pred) else pred
         self.inf_type = inf_type
         self.fname_transformer = fname_transformer
         self.feature_names = feature_names
@@ -848,7 +848,8 @@ class NormalInferenceResults(InferenceResults):

     def __init__(self, d_t, d_y, pred, pred_stderr, inf_type, fname_transformer=None,
                  feature_names=None, output_names=None, treatment_names=None):
-        self.pred_stderr = pred_stderr
+        self.pred_stderr = np.copy(pred_stderr) if pred_stderr is not None and not np.isscalar(
+            pred_stderr) else pred_stderr
         super().__init__(d_t, d_y, pred, inf_type, fname_transformer, feature_names, output_names, treatment_names)

     @property
@@ -948,7 +949,7 @@ class EmpiricalInferenceResults(InferenceResults):

     def __init__(self, d_t, d_y, pred, pred_dist, inf_type, fname_transformer=None,
                  feature_names=None, output_names=None, treatment_names=None):
-        self.pred_dist = pred_dist
+        self.pred_dist = np.copy(pred_dist) if pred_dist is not None and not np.isscalar(pred_dist) else pred_dist
         super().__init__(d_t, d_y, pred, inf_type, fname_transformer, feature_names, output_names, treatment_names)

     @property
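All three __init__ changes apply the same defensive-copy guard, so mutating an array the caller passed in can no longer silently alter a stored inference result. A sketch of the pattern with a hypothetical helper name (the commit inlines the conditional expression rather than factoring it out):

import numpy as np

def copy_if_array(value):
    # np.copy would coerce None and scalars into 0-d arrays, so only
    # genuine array inputs are copied; None and scalars pass through.
    return np.copy(value) if value is not None and not np.isscalar(value) else value

pred = np.array([1.0, 2.0])
stored = copy_if_array(pred)
pred[0] = 99.0     # caller keeps mutating its own buffer...
print(stored[0])   # 1.0 -- ...but the stored result is unaffected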
