
Commit 891a455

update unit tests
1 parent f6fdbc2 commit 891a455

8 files changed: +128 additions, -42 deletions


doubleml/double_ml_did.py

Lines changed: 6 additions & 6 deletions
@@ -21,7 +21,7 @@ class DoubleMLDID(LinearScoreMixin, DoubleML):
 
     ml_g : estimator implementing ``fit()`` and ``predict()``
         A machine learner implementing ``fit()`` and ``predict()`` methods (e.g.
-        :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`g_0(d,X) = E[\Delta Y|D=d, X]`.
+        :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`g_0(d,X) = E[Y_1-Y_0|D=d, X]`.
         For a binary outcome variable :math:`Y` (with values 0 and 1), a classifier implementing ``fit()`` and
         ``predict_proba()`` can also be specified. If :py:func:`sklearn.base.is_classifier` returns ``True``,
         ``predict_proba()`` is used otherwise ``predict()``.

@@ -126,19 +126,19 @@ def __init__(self,
         else:
             assert self.score == 'experimental'
             if ml_m is not None:
-                warnings.warn(('A learner ml_m has been provided for score = "experimental" but will be ignored. "'
+                warnings.warn(('A learner ml_m has been provided for score = "experimental" but will be ignored. '
                                'A learner ml_m is not required for estimation.'))
             self._learner = {'ml_g': ml_g}
 
         if ml_g_is_classifier:
-            if obj_dml_data.binary_outcome:
+            if obj_dml_data.binary_outcome:
                 self._predict_method = {'ml_g': 'predict_proba'}
             else:
                 raise ValueError(f'The ml_g learner {str(ml_g)} was identified as classifier '
                                  'but the outcome variable is not binary with values 0 and 1.')
         else:
             self._predict_method = {'ml_g': 'predict'}
-
+
         if 'ml_m' in self._learner:
             self._predict_method['ml_m'] = 'predict_proba'
         self._initialize_ml_nuisance_params()

@@ -313,8 +313,8 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
         m_tune_res = list()
         if self.score == 'observational':
             m_tune_res = _dml_tune(d, x, train_inds,
-                                self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
-                                n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
+                                   self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
+                                   n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
         g1_tune_res = list()
         if self.score == 'experimental':
             g1_tune_res = _dml_tune(y, x, train_inds_d1,
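
A minimal usage sketch of the behaviour exercised by the new tests further down (not part of this commit; the data are synthetic and RandomForestRegressor is only one admissible choice for ml_g, named as an example in the docstring above). With score='experimental' the propensity learner ml_m can be omitted, and supplying one only triggers the warning whose stray quote is corrected in this commit.

# Hypothetical example, not from the repository's test suite.
import numpy as np
import doubleml as dml
from sklearn.ensemble import RandomForestRegressor

np.random.seed(3141)
n = 500
x = np.random.normal(size=(n, 5))
d = np.random.binomial(1, 0.5, size=n)        # randomized treatment indicator
y = d * x[:, 0] + np.random.normal(size=n)    # synthetic outcome change (post minus pre)

obj_dml_data = dml.DoubleMLData.from_arrays(x, y, d)
dml_did_obj = dml.DoubleMLDID(obj_dml_data,
                              ml_g=RandomForestRegressor(n_estimators=100),
                              score='experimental')   # no ml_m required for this score
dml_did_obj.fit()
print(dml_did_obj.summary)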

doubleml/double_ml_did_cs.py

Lines changed: 10 additions & 10 deletions
@@ -126,7 +126,7 @@ def __init__(self,
         else:
             assert self.score == 'experimental'
             if ml_m is not None:
-                warnings.warn(('A learner ml_m has been provided for score = "experimental" but will be ignored. "'
+                warnings.warn(('A learner ml_m has been provided for score = "experimental" but will be ignored. '
                                'A learner ml_m is not required for estimation.'))
             self._learner = {'ml_g': ml_g}
 

@@ -427,8 +427,8 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
         m_tune_res = list()
         if self.score == 'observational':
             m_tune_res = _dml_tune(d, x, train_inds,
-                                self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
-                                n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
+                                   self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
+                                   n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
 
         g_d0_t0_best_params = [xx.best_params_ for xx in g_d0_t0_tune_res]
         g_d0_t1_best_params = [xx.best_params_ for xx in g_d0_t1_tune_res]

@@ -438,20 +438,20 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
         if self.score == 'observational':
             m_best_params = [xx.best_params_ for xx in m_tune_res]
             params = {'ml_g_d0_t0': g_d0_t0_best_params,
-                      'ml_g_d0_t1': g_d0_t1_best_params,
-                      'ml_g_d1_t0': g_d1_t0_best_params,
-                      'ml_g_d1_t1': g_d1_t1_best_params,
-                      'ml_m': m_best_params}
+                      'ml_g_d0_t1': g_d0_t1_best_params,
+                      'ml_g_d1_t0': g_d1_t0_best_params,
+                      'ml_g_d1_t1': g_d1_t1_best_params,
+                      'ml_m': m_best_params}
             tune_res = {'g_d0_t0_tune': g_d0_t0_tune_res,
                         'g_d0_t1_tune': g_d0_t1_tune_res,
                         'g_d1_t0_tune': g_d1_t0_tune_res,
                         'g_d1_t1_tune': g_d1_t1_tune_res,
                         'm_tune': m_tune_res}
         else:
             params = {'ml_g_d0_t0': g_d0_t0_best_params,
-                      'ml_g_d0_t1': g_d0_t1_best_params,
-                      'ml_g_d1_t0': g_d1_t0_best_params,
-                      'ml_g_d1_t1': g_d1_t1_best_params}
+                      'ml_g_d0_t1': g_d0_t1_best_params,
+                      'ml_g_d1_t0': g_d1_t0_best_params,
+                      'ml_g_d1_t1': g_d1_t1_best_params}
             tune_res = {'g_d0_t0_tune': g_d0_t0_tune_res,
                         'g_d0_t1_tune': g_d0_t1_tune_res,
                         'g_d1_t0_tune': g_d1_t0_tune_res,
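
The repeated cross-section model changed in this file follows the same pattern. A minimal sketch (again not from the commit, synthetic data, arbitrary learner choice) differs only in passing the time indicator t:

# Hypothetical example for DoubleMLDIDCS with the experimental score.
import numpy as np
import doubleml as dml
from sklearn.ensemble import RandomForestRegressor

np.random.seed(3141)
n = 1000
x = np.random.normal(size=(n, 5))
d = np.random.binomial(1, 0.5, size=n)   # treatment group indicator
t = np.random.binomial(1, 0.5, size=n)   # time period (0 = pre, 1 = post)
y = 0.5 * d * t + x[:, 0] + np.random.normal(size=n)

obj_dml_data = dml.DoubleMLData.from_arrays(x, y, d, t=t)
dml_did_cs_obj = dml.DoubleMLDIDCS(obj_dml_data,
                                   ml_g=RandomForestRegressor(n_estimators=100),
                                   score='experimental')
dml_did_cs_obj.fit()
print(dml_did_cs_obj.summary)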

doubleml/tests/_utils_did_cs_manual.py

Lines changed: 17 additions & 10 deletions
@@ -4,8 +4,6 @@
 from ._utils import fit_predict, fit_predict_proba, tune_grid_search
 from ._utils_did_manual import did_dml1, did_dml2
 
-from .._utils import _check_is_propensity
-
 
 def fit_did_cs(y, x, d, t,
                learner_g, learner_m, all_smpls, dml_procedure, score, in_sample_normalization,

@@ -105,10 +103,16 @@ def fit_nuisance_did_cs(y, x, d, t,
     train_cond_d1_t1 = np.intersect1d(np.where(d == 1)[0], np.where(t == 1)[0])
     g_hat_d1_t1_list = fit_predict(y, x, ml_g_d1_t1, g_d1_t1_params, smpls,
                                    train_cond=train_cond_d1_t1)
-
-    ml_m = clone(learner_m)
-    m_hat_list = fit_predict_proba(d, x, ml_m, m_params, smpls,
-                                   trimming_threshold=trimming_threshold)
+    if score == 'observational':
+        ml_m = clone(learner_m)
+        m_hat_list = fit_predict_proba(d, x, ml_m, m_params, smpls,
+                                       trimming_threshold=trimming_threshold)
+    else:
+        assert score == 'experimental'
+        m_hat_list = list()
+        for idx, _ in enumerate(smpls):
+            # fill it up, but its not further used
+            m_hat_list.append(np.zeros_like(g_hat_d1_t1_list[idx], dtype='float64'))
 
     p_hat_list = []
     for (train_index, _) in smpls:

@@ -145,7 +149,6 @@ def compute_did_cs_residuals(y, g_hat_d0_t0_list, g_hat_d0_t1_list,
     resid_d0_t1 = y - g_hat_d0_t1
     resid_d1_t0 = y - g_hat_d1_t0
     resid_d1_t1 = y - g_hat_d1_t1
-    _check_is_propensity(m_hat, 'learner_m', 'ml_m', smpls, eps=1e-12)
     return resid_d0_t0, resid_d0_t1, resid_d1_t0, resid_d1_t1, \
         g_hat_d0_t0, g_hat_d0_t1, g_hat_d1_t0, g_hat_d1_t1, \
         m_hat, p_hat, lambda_hat

@@ -259,13 +262,17 @@ def tune_nuisance_did_cs(y, x, d, t, ml_g, ml_m, smpls, score, n_folds_tune,
     g_d1_t1_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune,
                                         train_cond=smpls_d1_t1)
 
-    m_tune_res = tune_grid_search(d, x, ml_m, smpls, param_grid_m, n_folds_tune)
-
     g_d0_t0_best_params = [xx.best_params_ for xx in g_d0_t0_tune_res]
     g_d0_t1_best_params = [xx.best_params_ for xx in g_d0_t1_tune_res]
     g_d1_t0_best_params = [xx.best_params_ for xx in g_d1_t0_tune_res]
     g_d1_t1_best_params = [xx.best_params_ for xx in g_d1_t1_tune_res]
-    m_best_params = [xx.best_params_ for xx in m_tune_res]
+
+    if score == 'observational':
+        m_tune_res = tune_grid_search(d, x, ml_m, smpls, param_grid_m, n_folds_tune)
+        m_best_params = [xx.best_params_ for xx in m_tune_res]
+    else:
+        assert score == 'experimental'
+        m_best_params = None
 
     return g_d0_t0_best_params, g_d0_t1_best_params, \
         g_d1_t0_best_params, g_d1_t1_best_params, m_best_params

doubleml/tests/_utils_did_manual.py

Lines changed: 14 additions & 14 deletions
@@ -4,8 +4,6 @@
 from ._utils_boot import boot_manual, draw_weights
 from ._utils import fit_predict, fit_predict_proba, tune_grid_search
 
-from .._utils import _check_is_propensity
-
 
 def fit_did(y, x, d,
             learner_g, learner_m, all_smpls, dml_procedure, score, in_sample_normalization,

@@ -75,6 +73,10 @@ def fit_nuisance_did(y, x, d, learner_g, learner_m, smpls, score,
         train_cond1 = np.where(d == 1)[0]
         g_hat1_list = fit_predict(y, x, ml_g1, g1_params, smpls,
                                   train_cond=train_cond1)
+        m_hat_list = list()
+        for idx, _ in enumerate(smpls):
+            # fill it up, but its not further used
+            m_hat_list.append(np.zeros_like(g_hat0_list[idx], dtype='float64'))
 
     else:
         assert score == 'observational'

@@ -83,9 +85,9 @@ def fit_nuisance_did(y, x, d, learner_g, learner_m, smpls, score,
             # fill it up, but its not further used
             g_hat1_list.append(np.zeros_like(g_hat0_list[idx], dtype='float64'))
 
-    ml_m = clone(learner_m)
-    m_hat_list = fit_predict_proba(d, x, ml_m, m_params, smpls,
-                                   trimming_threshold=trimming_threshold)
+        ml_m = clone(learner_m)
+        m_hat_list = fit_predict_proba(d, x, ml_m, m_params, smpls,
+                                       trimming_threshold=trimming_threshold)
 
     p_hat_list = []
     for (train_index, _) in smpls:

@@ -107,7 +109,6 @@ def compute_did_residuals(y, g_hat0_list, g_hat1_list, m_hat_list, p_hat_list, s
         m_hat[test_index] = m_hat_list[idx]
         p_hat[test_index] = p_hat_list[idx]
 
-    _check_is_propensity(m_hat, 'learner_m', 'ml_m', smpls, eps=1e-12)
     return resid_d0, g_hat0, g_hat1, m_hat, p_hat
 
 

@@ -157,13 +158,12 @@ def did_score_elements(g_hat0, g_hat1, m_hat, p_hat, resid_d0, d, score, in_samp
         weight_psi_a = np.ones_like(d)
         weight_g0 = np.divide(d, np.mean(d)) - 1.0
         weight_g1 = 1.0 - np.divide(d, np.mean(d))
-        propensity_weight = np.multiply(1.0-d, np.divide(m_hat, 1.0-m_hat))
-        weight_resid_d0 = np.divide(d, np.mean(d)) - np.divide(propensity_weight, np.mean(propensity_weight))
+        weight_resid_d0 = np.divide(d, np.mean(d)) - np.divide(1.0-d, np.mean(1.0-d))
     else:
         weight_psi_a = np.ones_like(d)
         weight_g0 = np.divide(d, p_hat) - 1.0
         weight_g1 = 1.0 - np.divide(d, p_hat)
-        weight_resid_d0 = np.divide(d-m_hat, np.multiply(p_hat, 1.0-m_hat))
+        weight_resid_d0 = np.divide(d-p_hat, np.multiply(p_hat, 1.0-p_hat))
 
     psi_b_1 = np.multiply(weight_g0, g_hat0) + np.multiply(weight_g1, g_hat1)
 

@@ -223,18 +223,18 @@ def tune_nuisance_did(y, x, d, ml_g, ml_m, smpls, score, n_folds_tune,
     train_cond0 = np.where(d == 0)[0]
     g0_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune,
                                    train_cond=train_cond0)
-
+    g0_best_params = [xx.best_params_ for xx in g0_tune_res]
     if score == 'experimental':
         train_cond1 = np.where(d == 1)[0]
         g1_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune,
                                        train_cond=train_cond1)
         g1_best_params = [xx.best_params_ for xx in g1_tune_res]
+        m_best_params = None
     else:
+        assert score == 'observational'
         g1_best_params = None
 
-    m_tune_res = tune_grid_search(d, x, ml_m, smpls, param_grid_m, n_folds_tune)
-
-    g0_best_params = [xx.best_params_ for xx in g0_tune_res]
-    m_best_params = [xx.best_params_ for xx in m_tune_res]
+        m_tune_res = tune_grid_search(d, x, ml_m, smpls, param_grid_m, n_folds_tune)
+        m_best_params = [xx.best_params_ for xx in m_tune_res]
 
     return g0_best_params, g1_best_params, m_best_params
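
Reading the did_score_elements change above (a summary of the new code, not text from the commit): with score = 'experimental' the weight on the outcome-regression residual no longer involves the conditional propensity m_hat. Writing \bar{d} for the in-sample treatment share and \hat{p} for the value passed as p_hat, the two variants in the new code appear to be

% with in_sample_normalization:
\omega_i = \frac{d_i}{\bar{d}} - \frac{1 - d_i}{1 - \bar{d}},
% without in_sample_normalization:
\omega_i = \frac{d_i - \hat{p}}{\hat{p}\,(1 - \hat{p})}.

Both expressions depend only on the treatment indicator and its unconditional share, which is consistent with ml_m becoming optional for score = 'experimental' elsewhere in this commit.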

doubleml/tests/test_did.py

Lines changed: 37 additions & 0 deletions
@@ -130,3 +130,40 @@ def test_dml_did_boot(dml_did_fixture):
         assert np.allclose(dml_did_fixture['boot_t_stat' + bootstrap],
                            dml_did_fixture['boot_t_stat' + bootstrap + '_manual'],
                            rtol=1e-9, atol=1e-4)
+
+
+@pytest.mark.ci
+def test_dml_did_experimental(generate_data_did, in_sample_normalization, learner):
+    # collect data
+    (x, y, d) = generate_data_did
+
+    # Set machine learning methods for m & g
+    ml_g = clone(learner[0])
+    ml_m = clone(learner[1])
+
+    np.random.seed(3141)
+    obj_dml_data = dml.DoubleMLData.from_arrays(x, y, d)
+
+    np.random.seed(3141)
+    dml_did_obj_without_ml_m = dml.DoubleMLDID(obj_dml_data,
+                                               ml_g,
+                                               score='experimental',
+                                               in_sample_normalization=in_sample_normalization)
+    dml_did_obj_without_ml_m.fit()
+
+    np.random.seed(3141)
+    dml_did_obj_with_ml_m = dml.DoubleMLDID(obj_dml_data,
+                                            ml_g, ml_m,
+                                            score='experimental',
+                                            in_sample_normalization=in_sample_normalization)
+    dml_did_obj_with_ml_m.fit()
+    assert math.isclose(dml_did_obj_with_ml_m.coef,
+                        dml_did_obj_without_ml_m.coef,
+                        rel_tol=1e-9, abs_tol=1e-4)
+
+    msg = ('A learner ml_m has been provided for score = "experimental" but will be ignored. '
+           'A learner ml_m is not required for estimation.')
+    with pytest.warns(UserWarning, match=msg):
+        dml.DoubleMLDID(obj_dml_data, ml_g, ml_m,
+                        score='experimental',
+                        in_sample_normalization=in_sample_normalization)

doubleml/tests/test_did_cs.py

Lines changed: 37 additions & 0 deletions
@@ -131,3 +131,40 @@ def test_dml_did_cs_boot(dml_did_cs_fixture):
         assert np.allclose(dml_did_cs_fixture['boot_t_stat' + bootstrap],
                            dml_did_cs_fixture['boot_t_stat' + bootstrap + '_manual'],
                            rtol=1e-9, atol=1e-4)
+
+
+@pytest.mark.ci
+def test_dml_did_cs_experimental(generate_data_did_cs, in_sample_normalization, learner):
+    # collect data
+    (x, y, d, t) = generate_data_did_cs
+
+    # Set machine learning methods for m & g
+    ml_g = clone(learner[0])
+    ml_m = clone(learner[1])
+
+    np.random.seed(3141)
+    obj_dml_data = dml.DoubleMLData.from_arrays(x, y, d, t=t)
+
+    np.random.seed(3141)
+    dml_did_obj_without_ml_m = dml.DoubleMLDIDCS(obj_dml_data,
+                                                 ml_g,
+                                                 score='experimental',
+                                                 in_sample_normalization=in_sample_normalization)
+    dml_did_obj_without_ml_m.fit()
+
+    np.random.seed(3141)
+    dml_did_obj_with_ml_m = dml.DoubleMLDIDCS(obj_dml_data,
+                                              ml_g, ml_m,
+                                              score='experimental',
+                                              in_sample_normalization=in_sample_normalization)
+    dml_did_obj_with_ml_m.fit()
+    assert math.isclose(dml_did_obj_with_ml_m.coef,
+                        dml_did_obj_without_ml_m.coef,
+                        rel_tol=1e-9, abs_tol=1e-4)
+
+    msg = ('A learner ml_m has been provided for score = "experimental" but will be ignored. '
+           'A learner ml_m is not required for estimation.')
+    with pytest.warns(UserWarning, match=msg):
+        dml.DoubleMLDIDCS(obj_dml_data, ml_g, ml_m,
+                          score='experimental',
+                          in_sample_normalization=in_sample_normalization)

doubleml/tests/test_did_cs_tune.py

Lines changed: 5 additions & 1 deletion
@@ -116,7 +116,11 @@ def dml_did_cs_fixture(generate_data_did_cs, learner_g, learner_m, score, in_sam
     g_d0_t1_params = g_d0_t1_params * n_folds
     g_d1_t0_params = g_d1_t0_params * n_folds
     g_d1_t1_params = g_d1_t1_params * n_folds
-    m_params = m_params * n_folds
+    if score == 'observational':
+        m_params = m_params * n_folds
+    else:
+        assert score == 'experimental'
+        m_params = None
 
     res_manual = fit_did_cs(y, x, d, t, clone(learner_g), clone(learner_m),
                             all_smpls, dml_procedure, score, in_sample_normalization,

doubleml/tests/test_did_tune.py

Lines changed: 2 additions & 1 deletion
@@ -109,12 +109,13 @@ def dml_did_fixture(generate_data_did, learner_g, learner_m, score, in_sample_no
                                               n_folds_tune,
                                               par_grid['ml_g'], par_grid['ml_m'])
     g0_params = g0_params * n_folds
-    m_params = m_params * n_folds
     if score == 'experimental':
         g1_params = g1_params * n_folds
+        m_params = None
     else:
         assert score == 'observational'
         g1_params = None
+        m_params = m_params * n_folds
 
     res_manual = fit_did(y, x, d, clone(learner_g), clone(learner_m),
                          all_smpls, dml_procedure, score, in_sample_normalization,
