 from ._utils_boot import boot_manual, draw_weights
 from ._utils import fit_predict, fit_predict_proba, tune_grid_search
 
-from .._utils import _check_is_propensity
-
 
 def fit_did(y, x, d,
             learner_g, learner_m, all_smpls, dml_procedure, score, in_sample_normalization,
@@ -75,6 +73,10 @@ def fit_nuisance_did(y, x, d, learner_g, learner_m, smpls, score,
         train_cond1 = np.where(d == 1)[0]
         g_hat1_list = fit_predict(y, x, ml_g1, g1_params, smpls,
                                   train_cond=train_cond1)
+        m_hat_list = list()
+        for idx, _ in enumerate(smpls):
+            # fill it up, but it's not further used
+            m_hat_list.append(np.zeros_like(g_hat0_list[idx], dtype='float64'))
 
     else:
         assert score == 'observational'
@@ -83,9 +85,9 @@ def fit_nuisance_did(y, x, d, learner_g, learner_m, smpls, score,
             # fill it up, but it's not further used
             g_hat1_list.append(np.zeros_like(g_hat0_list[idx], dtype='float64'))
 
-    ml_m = clone(learner_m)
-    m_hat_list = fit_predict_proba(d, x, ml_m, m_params, smpls,
-                                   trimming_threshold=trimming_threshold)
+        ml_m = clone(learner_m)
+        m_hat_list = fit_predict_proba(d, x, ml_m, m_params, smpls,
+                                       trimming_threshold=trimming_threshold)
 
     p_hat_list = []
     for (train_index, _) in smpls:
@@ -107,7 +109,6 @@ def compute_did_residuals(y, g_hat0_list, g_hat1_list, m_hat_list, p_hat_list, s
         m_hat[test_index] = m_hat_list[idx]
         p_hat[test_index] = p_hat_list[idx]
 
-    _check_is_propensity(m_hat, 'learner_m', 'ml_m', smpls, eps=1e-12)
     return resid_d0, g_hat0, g_hat1, m_hat, p_hat
 
 
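Context for the deleted check: under the `'experimental'` score, `m_hat` is now only a zero-filled placeholder (see the first hunk), so any strict propensity check on it would fail by construction. A minimal standalone sketch of that failure mode, using a hypothetical `check_propensity` stand-in for `_check_is_propensity` (whose implementation is not shown in this diff):

```python
import numpy as np

def check_propensity(preds, eps=1e-12):
    # hypothetical stand-in: propensity estimates must lie strictly in (eps, 1 - eps)
    if not np.all((preds > eps) & (preds < 1.0 - eps)):
        raise ValueError('propensity predictions are not strictly inside (0, 1)')

m_hat = np.zeros(5)          # the zero placeholder used for the experimental score
try:
    check_propensity(m_hat)
except ValueError as err:
    print(err)               # the placeholder can never pass such a check
```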
@@ -157,13 +158,12 @@ def did_score_elements(g_hat0, g_hat1, m_hat, p_hat, resid_d0, d, score, in_samp
         weight_psi_a = np.ones_like(d)
         weight_g0 = np.divide(d, np.mean(d)) - 1.0
         weight_g1 = 1.0 - np.divide(d, np.mean(d))
-        propensity_weight = np.multiply(1.0 - d, np.divide(m_hat, 1.0 - m_hat))
-        weight_resid_d0 = np.divide(d, np.mean(d)) - np.divide(propensity_weight, np.mean(propensity_weight))
+        weight_resid_d0 = np.divide(d, np.mean(d)) - np.divide(1.0 - d, np.mean(1.0 - d))
     else:
         weight_psi_a = np.ones_like(d)
         weight_g0 = np.divide(d, p_hat) - 1.0
         weight_g1 = 1.0 - np.divide(d, p_hat)
-        weight_resid_d0 = np.divide(d - m_hat, np.multiply(p_hat, 1.0 - m_hat))
+        weight_resid_d0 = np.divide(d - p_hat, np.multiply(p_hat, 1.0 - p_hat))
 
     psi_b_1 = np.multiply(weight_g0, g_hat0) + np.multiply(weight_g1, g_hat1)
 
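The new experimental-score weights are the old expressions with the fitted `m_hat` replaced by `p_hat`, which, judging by the loop over `smpls` above, is a per-fold constant built from `d` alone (the treatment share). A standalone numeric check of that equivalence (a sketch; the variable names mirror the diff and are not part of any library API):

```python
import numpy as np

rng = np.random.default_rng(42)
d = rng.binomial(1, 0.4, size=1000).astype('float64')
p_hat = np.full_like(d, np.mean(d))  # constant propensity: the treatment share

# new non-normalized weight and its closed form d / p - (1 - d) / (1 - p)
weight_new = np.divide(d - p_hat, np.multiply(p_hat, 1.0 - p_hat))
assert np.allclose(weight_new, np.divide(d, p_hat) - np.divide(1.0 - d, 1.0 - p_hat))

# the new in-sample-normalized weight equals the old propensity_weight
# construction at m_hat = p_hat: the constant m / (1 - m) factor cancels
propensity_weight = np.multiply(1.0 - d, np.divide(p_hat, 1.0 - p_hat))
weight_old = np.divide(d, np.mean(d)) - np.divide(propensity_weight, np.mean(propensity_weight))
weight_norm = np.divide(d, np.mean(d)) - np.divide(1.0 - d, np.mean(1.0 - d))
assert np.allclose(weight_norm, weight_old)
```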
@@ -223,18 +223,18 @@ def tune_nuisance_did(y, x, d, ml_g, ml_m, smpls, score, n_folds_tune,
     train_cond0 = np.where(d == 0)[0]
     g0_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune,
                                    train_cond=train_cond0)
-
+    g0_best_params = [xx.best_params_ for xx in g0_tune_res]
     if score == 'experimental':
         train_cond1 = np.where(d == 1)[0]
         g1_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune,
                                        train_cond=train_cond1)
         g1_best_params = [xx.best_params_ for xx in g1_tune_res]
+        m_best_params = None
     else:
+        assert score == 'observational'
         g1_best_params = None
 
-    m_tune_res = tune_grid_search(d, x, ml_m, smpls, param_grid_m, n_folds_tune)
-
-    g0_best_params = [xx.best_params_ for xx in g0_tune_res]
-    m_best_params = [xx.best_params_ for xx in m_tune_res]
+        m_tune_res = tune_grid_search(d, x, ml_m, smpls, param_grid_m, n_folds_tune)
+        m_best_params = [xx.best_params_ for xx in m_tune_res]
 
     return g0_best_params, g1_best_params, m_best_params
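After the restructuring, the helper only tunes what each score actually uses: `m_best_params` is `None` under `'experimental'` and `g1_best_params` is `None` under `'observational'`. A minimal stub of that return contract (illustration only, with dummy tuning results):

```python
def tune_stub(score):
    # mirrors the branch structure above with placeholder "best params"
    g0_best_params = [{'alpha': 0.1}]
    if score == 'experimental':
        g1_best_params = [{'alpha': 0.1}]
        m_best_params = None             # the propensity model is never fit
    else:
        assert score == 'observational'
        g1_best_params = None            # g_hat1 is only a zero placeholder
        m_best_params = [{'C': 1.0}]
    return g0_best_params, g1_best_params, m_best_params

assert tune_stub('experimental')[2] is None
assert tune_stub('observational')[1] is None
```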