@@ -11,6 +11,7 @@ def fit_lpq(y, x, d, z, quantile,
1111 learner_g , learner_m , all_smpls , treatment , dml_procedure , n_rep = 1 ,
1212 trimming_rule = 'truncate' ,
1313 trimming_threshold = 1e-2 ,
14+ kde = _default_kde ,
1415 normalize_ipw = True , m_z_params = None ,
1516 m_d_z0_params = None , m_d_z1_params = None ,
1617 g_du_z0_params = None , g_du_z1_params = None ):
@@ -37,10 +38,10 @@ def fit_lpq(y, x, d, z, quantile,
3738 g_du_z1_params = g_du_z1_params )
3839 if dml_procedure == 'dml1' :
3940 lpqs [i_rep ], ses [i_rep ] = lpq_dml1 (y , d , z , m_z_hat , g_du_z0_hat , g_du_z1_hat , comp_prob_hat ,
40- treatment , quantile , ipw_vec , coef_bounds , smpls )
41+ treatment , quantile , ipw_vec , coef_bounds , smpls , kde )
4142 else :
4243 lpqs [i_rep ], ses [i_rep ] = lpq_dml2 (y , d , z , m_z_hat , g_du_z0_hat , g_du_z1_hat , comp_prob_hat ,
43- treatment , quantile , ipw_vec , coef_bounds )
44+ treatment , quantile , ipw_vec , coef_bounds , kde )
4445
4546 lpq = np .median (lpqs )
4647 se = np .sqrt (np .median (np .power (ses , 2 ) * n_obs + np .power (lpqs - lpq , 2 )) / n_obs )
@@ -200,7 +201,7 @@ def ipw_score(theta):
200201 return m_z_hat , g_du_z0_hat , g_du_z1_hat , comp_prob_hat , ipw_vec , coef_bounds
201202
202203
203- def lpq_dml1 (y , d , z , m_z , g_du_z0 , g_du_z1 , comp_prob , treatment , quantile , ipw_vec , coef_bounds , smpls ):
204+ def lpq_dml1 (y , d , z , m_z , g_du_z0 , g_du_z1 , comp_prob , treatment , quantile , ipw_vec , coef_bounds , smpls , kde ):
204205 thetas = np .zeros (len (smpls ))
205206 n_obs = len (y )
206207 ipw_est = ipw_vec .mean ()
@@ -211,17 +212,17 @@ def lpq_dml1(y, d, z, m_z, g_du_z0, g_du_z1, comp_prob, treatment, quantile, ipw
211212
212213 theta_hat = np .mean (thetas )
213214
214- se = np .sqrt (lpq_var_est (theta_hat , m_z , g_du_z0 , g_du_z1 , comp_prob , d , y , z , treatment , quantile , n_obs ))
215+ se = np .sqrt (lpq_var_est (theta_hat , m_z , g_du_z0 , g_du_z1 , comp_prob , d , y , z , treatment , quantile , n_obs , kde ))
215216
216217 return theta_hat , se
217218
218219
219- def lpq_dml2 (y , d , z , m_z , g_du_z0 , g_du_z1 , comp_prob , treatment , quantile , ipw_vec , coef_bounds ):
220+ def lpq_dml2 (y , d , z , m_z , g_du_z0 , g_du_z1 , comp_prob , treatment , quantile , ipw_vec , coef_bounds , kde ):
220221 n_obs = len (y )
221222 ipw_est = ipw_vec .mean ()
222223 theta_hat = lpq_est (m_z , g_du_z0 , g_du_z1 , comp_prob , d , y , z , treatment , quantile , ipw_est , coef_bounds )
223224
224- se = np .sqrt (lpq_var_est (theta_hat , m_z , g_du_z0 , g_du_z1 , comp_prob , d , y , z , treatment , quantile , n_obs ))
225+ se = np .sqrt (lpq_var_est (theta_hat , m_z , g_du_z0 , g_du_z1 , comp_prob , d , y , z , treatment , quantile , n_obs , kde ))
225226
226227 return theta_hat , se
227228
0 commit comments