
Commit f6fdbc2

update prop score for experimental setting
1 parent 1697e90 commit f6fdbc2

2 files changed: +100 -64 lines

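In practice the change makes ml_m optional once score='experimental' is selected: with randomized treatment assignment no propensity model is estimated, and a provided ml_m is ignored with a warning. A minimal usage sketch, assuming the package's documented entry points (the data helper make_did_SZ2020 and all shown values are illustrative, not part of this commit):

import numpy as np
from sklearn.ensemble import RandomForestRegressor

import doubleml as dml
from doubleml.datasets import make_did_SZ2020

# simulated difference-in-differences data (helper assumed from doubleml.datasets)
np.random.seed(42)
dml_data = make_did_SZ2020(n_obs=500)

# after this commit, ml_m can simply be omitted for the experimental score
dml_did = dml.DoubleMLDID(dml_data, ml_g=RandomForestRegressor(), score='experimental')
dml_did.fit()
print(dml_did.summary)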

doubleml/double_ml_did.py

Lines changed: 42 additions & 30 deletions
@@ -1,6 +1,7 @@
 import numpy as np
 from sklearn.utils import check_X_y
 from sklearn.utils.multiclass import type_of_target
+import warnings
 
 from .double_ml import DoubleML
 from .double_ml_data import DoubleMLData
@@ -27,7 +28,8 @@ class DoubleMLDID(LinearScoreMixin, DoubleML):
 
     ml_m : classifier implementing ``fit()`` and ``predict_proba()``
         A machine learner implementing ``fit()`` and ``predict_proba()`` methods (e.g.
-        :py:class:`sklearn.ensemble.RandomForestClassifier`) for the nuisance function :math:`m_0(X) = E[D|X]`.
+        :py:class:`sklearn.ensemble.RandomForestClassifier`) for the nuisance function :math:`m_0(X) = E[D=1|X]`.
+        Only relevant for ``score='observational'``.
 
     n_folds : int
         Number of folds.
@@ -86,7 +88,7 @@ class DoubleMLDID(LinearScoreMixin, DoubleML):
     def __init__(self,
                  obj_dml_data,
                  ml_g,
-                 ml_m,
+                 ml_m=None,
                  n_folds=5,
                  n_rep=1,
                  score='observational',
@@ -116,18 +118,29 @@ def __init__(self,
         # set stratication for resampling
         self._strata = self._dml_data.d
 
+        # check learners
         ml_g_is_classifier = self._check_learner(ml_g, 'ml_g', regressor=True, classifier=True)
-        _ = self._check_learner(ml_m, 'ml_m', regressor=False, classifier=True)
-        self._learner = {'ml_g': ml_g, 'ml_m': ml_m}
+        if self.score == 'observational':
+            _ = self._check_learner(ml_m, 'ml_m', regressor=False, classifier=True)
+            self._learner = {'ml_g': ml_g, 'ml_m': ml_m}
+        else:
+            assert self.score == 'experimental'
+            if ml_m is not None:
+                warnings.warn(('A learner ml_m has been provided for score = "experimental" but will be ignored. "'
+                               'A learner ml_m is not required for estimation.'))
+            self._learner = {'ml_g': ml_g}
 
         if ml_g_is_classifier:
-            if obj_dml_data.binary_outcome:
-                self._predict_method = {'ml_g': 'predict_proba', 'ml_m': 'predict_proba'}
+            if obj_dml_data.binary_outcome:
+                self._predict_method = {'ml_g': 'predict_proba'}
             else:
                 raise ValueError(f'The ml_g learner {str(ml_g)} was identified as classifier '
                                  'but the outcome variable is not binary with values 0 and 1.')
         else:
-            self._predict_method = {'ml_g': 'predict', 'ml_m': 'predict_proba'}
+            self._predict_method = {'ml_g': 'predict'}
+
+        if 'ml_m' in self._learner:
+            self._predict_method['ml_m'] = 'predict_proba'
         self._initialize_ml_nuisance_params()
 
         self._trimming_rule = trimming_rule
@@ -197,8 +210,18 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
         g_hat0['targets'] = g_hat0['targets'].astype(float)
         g_hat0['targets'][d == 1] = np.nan
 
-        # only relevant for experimental setting
+        # only relevant for observational or experimental setting
+        m_hat = {'preds': None, 'targets': None, 'models': None}
         g_hat1 = {'preds': None, 'targets': None, 'models': None}
+        if self.score == 'observational':
+            # nuisance m
+            m_hat = _dml_cv_predict(self._learner['ml_m'], x, d, smpls=smpls, n_jobs=n_jobs_cv,
+                                    est_params=self._get_params('ml_m'), method=self._predict_method['ml_m'],
+                                    return_models=return_models)
+            _check_finite_predictions(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls)
+            _check_is_propensity(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls, eps=1e-12)
+            m_hat['preds'] = _trimm(m_hat['preds'], self.trimming_rule, self.trimming_threshold)
+
         if self.score == 'experimental':
             g_hat1 = _dml_cv_predict(self._learner['ml_g'], x, y, smpls=smpls_d1, n_jobs=n_jobs_cv,
                                      est_params=self._get_params('ml_g1'), method=self._predict_method['ml_g'],
@@ -209,13 +232,6 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
             g_hat1['targets'] = g_hat1['targets'].astype(float)
             g_hat1['targets'][d == 0] = np.nan
 
-        # nuisance m
-        m_hat = _dml_cv_predict(self._learner['ml_m'], x, d, smpls=smpls, n_jobs=n_jobs_cv,
-                                est_params=self._get_params('ml_m'), method=self._predict_method['ml_m'],
-                                return_models=return_models)
-        _check_finite_predictions(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls)
-        _check_is_propensity(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls, eps=1e-12)
-
         # nuisance estimates of the uncond. treatment prob.
         p_hat = np.full_like(d, np.nan, dtype='float64')
         for train_index, test_index in smpls:
@@ -240,8 +256,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
         return psi_elements, preds
 
     def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, p_hat):
-        # trimm propensities and calc residuals
-        m_hat = _trimm(m_hat, self.trimming_rule, self.trimming_threshold)
+        # calc residuals
         resid_d0 = y - g_hat0
 
         if self.score == 'observational':
@@ -261,13 +276,12 @@ def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, p_hat):
                 weight_psi_a = np.ones_like(y)
                 weight_g0 = np.divide(d, np.mean(d)) - 1.0
                 weight_g1 = 1.0 - np.divide(d, np.mean(d))
-                propensity_weight = np.multiply(1.0-d, np.divide(m_hat, 1.0-m_hat))
-                weight_resid_d0 = np.divide(d, np.mean(d)) - np.divide(propensity_weight, np.mean(propensity_weight))
+                weight_resid_d0 = np.divide(d, np.mean(d)) - np.divide(1.0-d, np.mean(1.0-d))
             else:
                 weight_psi_a = np.ones_like(y)
                 weight_g0 = np.divide(d, p_hat) - 1.0
                 weight_g1 = 1.0 - np.divide(d, p_hat)
-                weight_resid_d0 = np.divide(d-m_hat, np.multiply(p_hat, 1.0-m_hat))
+                weight_resid_d0 = np.divide(d-p_hat, np.multiply(p_hat, 1.0-p_hat))
 
             psi_b_1 = np.multiply(weight_g0, g_hat0) + np.multiply(weight_g1, g_hat1)
 
@@ -296,32 +310,30 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
         g0_tune_res = _dml_tune(y, x, train_inds_d0,
                                 self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'],
                                 n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
-
+        m_tune_res = list()
+        if self.score == 'observational':
+            m_tune_res = _dml_tune(d, x, train_inds,
+                                   self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
+                                   n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
         g1_tune_res = list()
         if self.score == 'experimental':
             g1_tune_res = _dml_tune(y, x, train_inds_d1,
                                     self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'],
                                     n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
 
-        m_tune_res = _dml_tune(d, x, train_inds,
-                               self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
-                               n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
-
         g0_best_params = [xx.best_params_ for xx in g0_tune_res]
-        m_best_params = [xx.best_params_ for xx in m_tune_res]
         if self.score == 'observational':
+            m_best_params = [xx.best_params_ for xx in m_tune_res]
             params = {'ml_g0': g0_best_params,
                       'ml_m': m_best_params}
             tune_res = {'g0_tune': g0_tune_res,
                         'm_tune': m_tune_res}
         else:
             g1_best_params = [xx.best_params_ for xx in g1_tune_res]
             params = {'ml_g0': g0_best_params,
-                      'ml_g1': g1_best_params,
-                      'ml_m': m_best_params}
+                      'ml_g1': g1_best_params}
             tune_res = {'g0_tune': g0_tune_res,
-                        'g1_tune': g1_tune_res,
-                        'm_tune': m_tune_res}
+                        'g1_tune': g1_tune_res}
 
         res = {'params': params,
                'tune_res': tune_res}
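A note on the score change above (not part of the commit): for score='experimental' with in-sample normalization, the old code built the weight from a propensity ratio, (1-d)*m_hat/(1-m_hat), even though randomized treatment assignment makes m_0(X) constant; the new code uses the empirical treatment shares directly. Without normalization, m_hat is likewise replaced by the unconditional treatment probability p_hat. The two normalized forms coincide whenever m_hat is constant, which a small numpy sketch (variable names follow the diff) illustrates:

import numpy as np

rng = np.random.default_rng(0)
d = rng.binomial(1, 0.4, size=10000).astype(float)  # randomized treatment indicator

# old weight: propensity-ratio form, evaluated at a constant m_hat
m_hat = np.full_like(d, d.mean())
propensity_weight = np.multiply(1.0 - d, np.divide(m_hat, 1.0 - m_hat))
w_old = np.divide(d, np.mean(d)) - np.divide(propensity_weight, np.mean(propensity_weight))

# new weight: empirical shares only, no propensity estimate involved
w_new = np.divide(d, np.mean(d)) - np.divide(1.0 - d, np.mean(1.0 - d))

print(np.allclose(w_old, w_new))  # True: the constant factor m_hat/(1-m_hat) cancels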

doubleml/double_ml_did_cs.py

Lines changed: 58 additions & 34 deletions
@@ -1,6 +1,7 @@
 import numpy as np
 from sklearn.utils import check_X_y
 from sklearn.utils.multiclass import type_of_target
+import warnings
 
 from .double_ml import DoubleML
 from .double_ml_data import DoubleMLData
@@ -27,7 +28,8 @@ class DoubleMLDIDCS(LinearScoreMixin, DoubleML):
 
     ml_m : classifier implementing ``fit()`` and ``predict_proba()``
         A machine learner implementing ``fit()`` and ``predict_proba()`` methods (e.g.
-        :py:class:`sklearn.ensemble.RandomForestClassifier`) for the nuisance function :math:`m_0(X) = E[D|X]`.
+        :py:class:`sklearn.ensemble.RandomForestClassifier`) for the nuisance function :math:`m_0(X) = E[D=1|X]`.
+        Only relevant for ``score='observational'``.
 
     n_folds : int
         Number of folds.
@@ -86,7 +88,7 @@ class DoubleMLDIDCS(LinearScoreMixin, DoubleML):
     def __init__(self,
                  obj_dml_data,
                  ml_g,
-                 ml_m,
+                 ml_m=None,
                  n_folds=5,
                  n_rep=1,
                  score='observational',
@@ -114,23 +116,31 @@ def __init__(self,
                             f'Object of type {str(type(self.in_sample_normalization))} passed.')
 
         # set stratication for resampling
-        self._strata = self._dml_data.d.reshape(-1, 1) + \
-            2 * self._dml_data.t.reshape(-1, 1)
+        self._strata = self._dml_data.d.reshape(-1, 1) + 2 * self._dml_data.t.reshape(-1, 1)
 
-        ml_g_is_classifier = self._check_learner(
-            ml_g, 'ml_g', regressor=True, classifier=True)
-        _ = self._check_learner(ml_m, 'ml_m', regressor=False, classifier=True)
-        self._learner = {'ml_g': ml_g, 'ml_m': ml_m}
+        # check learners
+        ml_g_is_classifier = self._check_learner(ml_g, 'ml_g', regressor=True, classifier=True)
+        if self.score == 'observational':
+            _ = self._check_learner(ml_m, 'ml_m', regressor=False, classifier=True)
+            self._learner = {'ml_g': ml_g, 'ml_m': ml_m}
+        else:
+            assert self.score == 'experimental'
+            if ml_m is not None:
+                warnings.warn(('A learner ml_m has been provided for score = "experimental" but will be ignored. "'
+                               'A learner ml_m is not required for estimation.'))
+            self._learner = {'ml_g': ml_g}
 
         if ml_g_is_classifier:
             if obj_dml_data.binary_outcome:
-                self._predict_method = {
-                    'ml_g': 'predict_proba', 'ml_m': 'predict_proba'}
+                self._predict_method = {'ml_g': 'predict_proba'}
             else:
                 raise ValueError(f'The ml_g learner {str(ml_g)} was identified as classifier '
                                  'but the outcome variable is not binary with values 0 and 1.')
         else:
-            self._predict_method = {'ml_g': 'predict', 'ml_m': 'predict_proba'}
+            self._predict_method = {'ml_g': 'predict'}
+
+        if 'ml_m' in self._learner:
+            self._predict_method['ml_m'] = 'predict_proba'
         self._initialize_ml_nuisance_params()
 
         self._trimming_rule = trimming_rule
241251
g_hat_d1_t1['targets'] = g_hat_d1_t1['targets'].astype(float)
242252
g_hat_d1_t1['targets'][np.invert((d == 1) & (t == 1))] = np.nan
243253

244-
# nuisance m
245-
m_hat = _dml_cv_predict(self._learner['ml_m'], x, d, smpls=smpls, n_jobs=n_jobs_cv,
246-
est_params=self._get_params('ml_m'), method=self._predict_method['ml_m'],
247-
return_models=return_models)
248-
_check_finite_predictions(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls)
249-
_check_is_propensity(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls, eps=1e-12)
254+
# only relevant for observational or experimental setting
255+
m_hat = {'preds': None, 'targets': None, 'models': None}
256+
if self.score == 'observational':
257+
# nuisance m
258+
m_hat = _dml_cv_predict(self._learner['ml_m'], x, d, smpls=smpls, n_jobs=n_jobs_cv,
259+
est_params=self._get_params('ml_m'), method=self._predict_method['ml_m'],
260+
return_models=return_models)
261+
_check_finite_predictions(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls)
262+
_check_is_propensity(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls, eps=1e-12)
263+
m_hat['preds'] = _trimm(m_hat['preds'], self.trimming_rule, self.trimming_threshold)
250264

251265
psi_a, psi_b = self._score_elements(y, d, t,
252266
g_hat_d0_t0['preds'], g_hat_d0_t1['preds'],
@@ -279,8 +293,6 @@ def _score_elements(self, y, d, t,
                         g_hat_d1_t0, g_hat_d1_t1,
                         m_hat, p_hat, lambda_hat):
 
-        # trimm propensities
-        m_hat = _trimm(m_hat, self.trimming_rule, self.trimming_threshold)
         # calculate residuals
         resid_d0_t0 = y - g_hat_d0_t0
         resid_d0_t1 = y - g_hat_d0_t1
@@ -412,26 +424,38 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
                                      self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'],
                                      n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
 
-        m_tune_res = _dml_tune(d, x, train_inds,
-                               self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
-                               n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
+        m_tune_res = list()
+        if self.score == 'observational':
+            m_tune_res = _dml_tune(d, x, train_inds,
+                                   self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
+                                   n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
 
         g_d0_t0_best_params = [xx.best_params_ for xx in g_d0_t0_tune_res]
         g_d0_t1_best_params = [xx.best_params_ for xx in g_d0_t1_tune_res]
         g_d1_t0_best_params = [xx.best_params_ for xx in g_d1_t0_tune_res]
         g_d1_t1_best_params = [xx.best_params_ for xx in g_d1_t1_tune_res]
-        m_best_params = [xx.best_params_ for xx in m_tune_res]
-
-        params = {'ml_g_d0_t0': g_d0_t0_best_params,
-                  'ml_g_d0_t1': g_d0_t1_best_params,
-                  'ml_g_d1_t0': g_d1_t0_best_params,
-                  'ml_g_d1_t1': g_d1_t1_best_params,
-                  'ml_m': m_best_params}
-        tune_res = {'g_d0_t0_tune': g_d0_t0_tune_res,
-                    'g_d0_t1_tune': g_d0_t1_tune_res,
-                    'g_d1_t0_tune': g_d1_t0_tune_res,
-                    'g_d1_t1_tune': g_d1_t1_tune_res,
-                    'm_tune': m_tune_res}
+
+        if self.score == 'observational':
+            m_best_params = [xx.best_params_ for xx in m_tune_res]
+            params = {'ml_g_d0_t0': g_d0_t0_best_params,
+                      'ml_g_d0_t1': g_d0_t1_best_params,
+                      'ml_g_d1_t0': g_d1_t0_best_params,
+                      'ml_g_d1_t1': g_d1_t1_best_params,
+                      'ml_m': m_best_params}
+            tune_res = {'g_d0_t0_tune': g_d0_t0_tune_res,
+                        'g_d0_t1_tune': g_d0_t1_tune_res,
+                        'g_d1_t0_tune': g_d1_t0_tune_res,
+                        'g_d1_t1_tune': g_d1_t1_tune_res,
+                        'm_tune': m_tune_res}
+        else:
+            params = {'ml_g_d0_t0': g_d0_t0_best_params,
+                      'ml_g_d0_t1': g_d0_t1_best_params,
+                      'ml_g_d1_t0': g_d1_t0_best_params,
+                      'ml_g_d1_t1': g_d1_t1_best_params}
+            tune_res = {'g_d0_t0_tune': g_d0_t0_tune_res,
+                        'g_d0_t1_tune': g_d0_t1_tune_res,
+                        'g_d1_t0_tune': g_d1_t0_tune_res,
+                        'g_d1_t1_tune': g_d1_t1_tune_res}
 
         res = {'params': params,
                'tune_res': tune_res}
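With _nuisance_tuning now skipping ml_m for the experimental score, param_grids no longer needs an 'ml_m' entry in that setting. A hedged sketch against the usual tune() interface (the data helper, its cross_sectional_data flag, and the grid values are illustrative assumptions, not part of this commit):

from sklearn.ensemble import RandomForestRegressor

import doubleml as dml
from doubleml.datasets import make_did_SZ2020

# repeated cross-section DiD data (helper and flag assumed from doubleml.datasets)
dml_data = make_did_SZ2020(n_obs=500, cross_sectional_data=True)
dml_did_cs = dml.DoubleMLDIDCS(dml_data, ml_g=RandomForestRegressor(), score='experimental')

# only the outcome learner is tuned; a grid for 'ml_m' would be required
# for score='observational' but can be dropped here
dml_did_cs.tune(param_grids={'ml_g': {'n_estimators': [50, 100], 'max_depth': [3, 5]}})
dml_did_cs.fit()
print(dml_did_cs.summary)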
