 import numpy as np
 from sklearn.utils import check_X_y
 from sklearn.utils.multiclass import type_of_target
+import warnings
 
 from .double_ml import DoubleML
 from .double_ml_data import DoubleMLData
@@ -27,7 +28,8 @@ class DoubleMLDID(LinearScoreMixin, DoubleML):
 
     ml_m : classifier implementing ``fit()`` and ``predict_proba()``
         A machine learner implementing ``fit()`` and ``predict_proba()`` methods (e.g.
-        :py:class:`sklearn.ensemble.RandomForestClassifier`) for the nuisance function :math:`m_0(X) = E[D|X]`.
+        :py:class:`sklearn.ensemble.RandomForestClassifier`) for the nuisance function :math:`m_0(X) = E[D=1|X]`.
+        Only relevant for ``score='observational'``.
 
     n_folds : int
         Number of folds.
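Usage implied by the docstring change above (a minimal construction sketch, not part of the diff; it assumes the public `doubleml` API with `DoubleMLDID` exported at package level and the `make_did_SZ2020` data generator matching this branch):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

import doubleml as dml
from doubleml.datasets import make_did_SZ2020  # assumed DID data generator

np.random.seed(42)
dml_data = make_did_SZ2020(n_obs=500)

# observational score: a propensity learner ml_m is required
dml_did_obs = dml.DoubleMLDID(dml_data,
                              ml_g=RandomForestRegressor(),
                              ml_m=RandomForestClassifier(),
                              score='observational')

# experimental score: ml_m can now be omitted; the treatment probability is
# estimated as the unconditional mean (p_hat) instead of a conditional model
dml_did_exp = dml.DoubleMLDID(dml_data,
                              ml_g=RandomForestRegressor(),
                              score='experimental')
```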
@@ -86,7 +88,7 @@ class DoubleMLDID(LinearScoreMixin, DoubleML):
     def __init__(self,
                  obj_dml_data,
                  ml_g,
-                 ml_m,
+                 ml_m=None,
                  n_folds=5,
                  n_rep=1,
                  score='observational',
@@ -116,18 +118,29 @@ def __init__(self,
         # set stratification for resampling
         self._strata = self._dml_data.d
 
+        # check learners
         ml_g_is_classifier = self._check_learner(ml_g, 'ml_g', regressor=True, classifier=True)
-        _ = self._check_learner(ml_m, 'ml_m', regressor=False, classifier=True)
-        self._learner = {'ml_g': ml_g, 'ml_m': ml_m}
+        if self.score == 'observational':
+            _ = self._check_learner(ml_m, 'ml_m', regressor=False, classifier=True)
+            self._learner = {'ml_g': ml_g, 'ml_m': ml_m}
+        else:
+            assert self.score == 'experimental'
+            if ml_m is not None:
+                warnings.warn(('A learner ml_m has been provided for score = "experimental" but will be ignored. '
+                               'A learner ml_m is not required for estimation.'))
+            self._learner = {'ml_g': ml_g}
 
         if ml_g_is_classifier:
-            if obj_dml_data.binary_outcome:
-                self._predict_method = {'ml_g': 'predict_proba', 'ml_m': 'predict_proba'}
+            if obj_dml_data.binary_outcome:
+                self._predict_method = {'ml_g': 'predict_proba'}
             else:
                 raise ValueError(f'The ml_g learner {str(ml_g)} was identified as classifier '
                                  'but the outcome variable is not binary with values 0 and 1.')
         else:
-            self._predict_method = {'ml_g': 'predict', 'ml_m': 'predict_proba'}
+            self._predict_method = {'ml_g': 'predict'}
+
+        if 'ml_m' in self._learner:
+            self._predict_method['ml_m'] = 'predict_proba'
         self._initialize_ml_nuisance_params()
 
         self._trimming_rule = trimming_rule
@@ -197,8 +210,18 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
         g_hat0['targets'] = g_hat0['targets'].astype(float)
         g_hat0['targets'][d == 1] = np.nan
 
-        # only relevant for experimental setting
+        # only relevant for observational or experimental setting
+        m_hat = {'preds': None, 'targets': None, 'models': None}
         g_hat1 = {'preds': None, 'targets': None, 'models': None}
+        if self.score == 'observational':
+            # nuisance m
+            m_hat = _dml_cv_predict(self._learner['ml_m'], x, d, smpls=smpls, n_jobs=n_jobs_cv,
+                                    est_params=self._get_params('ml_m'), method=self._predict_method['ml_m'],
+                                    return_models=return_models)
+            _check_finite_predictions(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls)
+            _check_is_propensity(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls, eps=1e-12)
+            m_hat['preds'] = _trimm(m_hat['preds'], self.trimming_rule, self.trimming_threshold)
+
         if self.score == 'experimental':
             g_hat1 = _dml_cv_predict(self._learner['ml_g'], x, y, smpls=smpls_d1, n_jobs=n_jobs_cv,
                                      est_params=self._get_params('ml_g1'), method=self._predict_method['ml_g'],
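Note that the propensity estimates are now trimmed right after cross-fitting (the `_trimm` call on `m_hat['preds']` above) rather than later in `_score_elements`. For readers unfamiliar with trimming, here is a standalone sketch of truncation-style trimming under the 'truncate' rule; the helper name is made up for illustration and is not the library's `_trimm`:

```python
# Illustration only: clamp estimated propensities away from 0 and 1 so that
# inverse-propensity weights in the score stay numerically stable.
import numpy as np

def trim_propensities(preds, trimming_rule='truncate', trimming_threshold=1e-2):
    preds = np.asarray(preds, dtype=float).copy()
    if trimming_rule == 'truncate':
        preds[preds < trimming_threshold] = trimming_threshold
        preds[preds > 1.0 - trimming_threshold] = 1.0 - trimming_threshold
    return preds

m_hat_preds = np.array([0.001, 0.25, 0.97, 0.9999])
print(trim_propensities(m_hat_preds))  # [0.01 0.25 0.97 0.99]
```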
@@ -209,13 +232,6 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
             g_hat1['targets'] = g_hat1['targets'].astype(float)
             g_hat1['targets'][d == 0] = np.nan
 
-        # nuisance m
-        m_hat = _dml_cv_predict(self._learner['ml_m'], x, d, smpls=smpls, n_jobs=n_jobs_cv,
-                                est_params=self._get_params('ml_m'), method=self._predict_method['ml_m'],
-                                return_models=return_models)
-        _check_finite_predictions(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls)
-        _check_is_propensity(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls, eps=1e-12)
-
         # nuisance estimates of the uncond. treatment prob.
         p_hat = np.full_like(d, np.nan, dtype='float64')
         for train_index, test_index in smpls:
@@ -240,8 +256,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
         return psi_elements, preds
 
     def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, p_hat):
-        # trimm propensities and calc residuals
-        m_hat = _trimm(m_hat, self.trimming_rule, self.trimming_threshold)
+        # calc residuals
         resid_d0 = y - g_hat0
 
         if self.score == 'observational':
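The next hunk replaces `m_hat` with `p_hat` in the experimental-score weights. A quick consistency check (illustration only, not part of the diff): with in-sample normalization `p_hat` is just `np.mean(d)`, and `d/mean(d) - (1 - d)/mean(1 - d)` is algebraically identical to `(d - p_hat) / (p_hat * (1 - p_hat))`, so the two `+` branches below agree:

```python
# Quick numerical check: with p_hat = d.mean(), the in-sample-normalized weight
# d/mean(d) - (1-d)/mean(1-d) coincides with (d - p_hat) / (p_hat * (1 - p_hat)).
import numpy as np

rng = np.random.default_rng(0)
d = rng.binomial(1, 0.4, size=1000).astype(float)
p_hat = d.mean()

w_in_sample = d / d.mean() - (1.0 - d) / (1.0 - d).mean()
w_plug_in = (d - p_hat) / (p_hat * (1.0 - p_hat))

assert np.allclose(w_in_sample, w_plug_in)
```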
@@ -261,13 +276,12 @@ def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, p_hat):
                 weight_psi_a = np.ones_like(y)
                 weight_g0 = np.divide(d, np.mean(d)) - 1.0
                 weight_g1 = 1.0 - np.divide(d, np.mean(d))
-                propensity_weight = np.multiply(1.0 - d, np.divide(m_hat, 1.0 - m_hat))
-                weight_resid_d0 = np.divide(d, np.mean(d)) - np.divide(propensity_weight, np.mean(propensity_weight))
+                weight_resid_d0 = np.divide(d, np.mean(d)) - np.divide(1.0 - d, np.mean(1.0 - d))
             else:
                 weight_psi_a = np.ones_like(y)
                 weight_g0 = np.divide(d, p_hat) - 1.0
                 weight_g1 = 1.0 - np.divide(d, p_hat)
-                weight_resid_d0 = np.divide(d - m_hat, np.multiply(p_hat, 1.0 - m_hat))
+                weight_resid_d0 = np.divide(d - p_hat, np.multiply(p_hat, 1.0 - p_hat))
 
         psi_b_1 = np.multiply(weight_g0, g_hat0) + np.multiply(weight_g1, g_hat1)
 
@@ -296,32 +310,30 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
         g0_tune_res = _dml_tune(y, x, train_inds_d0,
                                 self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'],
                                 n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
-
+        m_tune_res = list()
+        if self.score == 'observational':
+            m_tune_res = _dml_tune(d, x, train_inds,
+                                   self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
+                                   n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
         g1_tune_res = list()
         if self.score == 'experimental':
             g1_tune_res = _dml_tune(y, x, train_inds_d1,
                                     self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'],
                                     n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
 
-        m_tune_res = _dml_tune(d, x, train_inds,
-                               self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
-                               n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
-
         g0_best_params = [xx.best_params_ for xx in g0_tune_res]
-        m_best_params = [xx.best_params_ for xx in m_tune_res]
         if self.score == 'observational':
+            m_best_params = [xx.best_params_ for xx in m_tune_res]
             params = {'ml_g0': g0_best_params,
                       'ml_m': m_best_params}
             tune_res = {'g0_tune': g0_tune_res,
                         'm_tune': m_tune_res}
         else:
             g1_best_params = [xx.best_params_ for xx in g1_tune_res]
             params = {'ml_g0': g0_best_params,
-                      'ml_g1': g1_best_params,
-                      'ml_m': m_best_params}
+                      'ml_g1': g1_best_params}
             tune_res = {'g0_tune': g0_tune_res,
-                        'g1_tune': g1_tune_res,
-                        'm_tune': m_tune_res}
+                        'g1_tune': g1_tune_res}
 
         res = {'params': params,
                'tune_res': tune_res}
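A hedged tuning sketch reflecting the change above (it reuses the `dml_did_obs` and `dml_did_exp` objects from the construction sketch earlier and assumes the standard `tune()` interface): a parameter grid for `'ml_m'` is only needed with the observational score.

```python
# Grids keyed by learner name, as used in _nuisance_tuning above.
param_grids_obs = {'ml_g': {'n_estimators': [50, 100], 'max_depth': [3, 5]},
                   'ml_m': {'n_estimators': [50, 100], 'max_depth': [3, 5]}}
param_grids_exp = {'ml_g': {'n_estimators': [50, 100], 'max_depth': [3, 5]}}

dml_did_obs.tune(param_grids_obs, n_folds_tune=3)
dml_did_exp.tune(param_grids_exp, n_folds_tune=3)  # no 'ml_m' grid needed

dml_did_obs.fit()
dml_did_exp.fit()
```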