77from sklearn .base import clone
88from sklearn .linear_model import LogisticRegression
99from sklearn .ensemble import RandomForestClassifier
10+ from statsmodels .nonparametric .kde import KDEUnivariate
1011
1112from ._utils import draw_smpls
1213from ._utils_lpq_manual import fit_lpq
14+ from .._utils import _default_kde
15+
16+
17+ def custom_kde (u , weights ):
18+ dens = KDEUnivariate (u )
19+ dens .fit (kernel = 'epa' , bw = 'silverman' , weights = weights , fft = False )
20+
21+ return dens .evaluate (0 )
1322
1423
1524@pytest .fixture (scope = 'module' ,
@@ -19,14 +28,13 @@ def treatment(request):
1928
2029
2130@pytest .fixture (scope = 'module' ,
22- params = [0.25 , 0.5 , 0. 75 ])
31+ params = [0.25 , 0.75 ])
2332def quantile (request ):
2433 return request .param
2534
2635
2736@pytest .fixture (scope = 'module' ,
28- params = [RandomForestClassifier (max_depth = 2 , n_estimators = 5 , random_state = 42 ),
29- LogisticRegression ()])
37+ params = [LogisticRegression ()])
3038def learner (request ):
3139 return request .param
3240
@@ -44,14 +52,20 @@ def normalize_ipw(request):
4452
4553
4654@pytest .fixture (scope = 'module' ,
47- params = [0.01 , 0. 05 ])
55+ params = [0.05 ])
4856def trimming_threshold (request ):
4957 return request .param
5058
5159
60+ @pytest .fixture (scope = 'module' ,
61+ params = ['default' , custom_kde ])
62+ def kde (request ):
63+ return request .param
64+
65+
5266@pytest .fixture (scope = "module" )
5367def dml_lpq_fixture (generate_data_local_quantiles , treatment , quantile , learner ,
54- dml_procedure , normalize_ipw , trimming_threshold ):
68+ dml_procedure , normalize_ipw , trimming_threshold , kde ):
5569 n_folds = 3
5670
5771 # collect data
@@ -63,26 +77,48 @@ def dml_lpq_fixture(generate_data_local_quantiles, treatment, quantile, learner,
6377 all_smpls = draw_smpls (n_obs , n_folds , n_rep = 1 , groups = strata )
6478
6579 np .random .seed (42 )
66- dml_lpq_obj = dml .DoubleMLLPQ (obj_dml_data ,
67- clone (learner ), clone (learner ),
68- treatment = treatment ,
69- quantile = quantile ,
70- n_folds = n_folds ,
71- n_rep = 1 ,
72- dml_procedure = dml_procedure ,
73- normalize_ipw = normalize_ipw ,
74- trimming_threshold = trimming_threshold ,
75- draw_sample_splitting = False )
76-
77- # synchronize the sample splitting
78- dml_lpq_obj .set_sample_splitting (all_smpls = all_smpls )
79- dml_lpq_obj .fit ()
80-
81- np .random .seed (42 )
82- res_manual = fit_lpq (y , x , d , z , quantile , clone (learner ), clone (learner ),
83- all_smpls , treatment , dml_procedure ,
84- normalize_ipw = normalize_ipw ,
85- n_rep = 1 , trimming_threshold = trimming_threshold )
80+ if kde == 'default' :
81+ dml_lpq_obj = dml .DoubleMLLPQ (obj_dml_data ,
82+ clone (learner ), clone (learner ),
83+ treatment = treatment ,
84+ quantile = quantile ,
85+ n_folds = n_folds ,
86+ n_rep = 1 ,
87+ dml_procedure = dml_procedure ,
88+ normalize_ipw = normalize_ipw ,
89+ trimming_threshold = trimming_threshold ,
90+ draw_sample_splitting = False )
91+ # synchronize the sample splitting
92+ dml_lpq_obj .set_sample_splitting (all_smpls = all_smpls )
93+ dml_lpq_obj .fit ()
94+
95+ np .random .seed (42 )
96+ res_manual = fit_lpq (y , x , d , z , quantile , clone (learner ), clone (learner ),
97+ all_smpls , treatment , dml_procedure ,
98+ normalize_ipw = normalize_ipw , kde = _default_kde ,
99+ n_rep = 1 , trimming_threshold = trimming_threshold )
100+ else :
101+ dml_lpq_obj = dml .DoubleMLLPQ (obj_dml_data ,
102+ clone (learner ), clone (learner ),
103+ treatment = treatment ,
104+ quantile = quantile ,
105+ n_folds = n_folds ,
106+ n_rep = 1 ,
107+ dml_procedure = dml_procedure ,
108+ normalize_ipw = normalize_ipw ,
109+ kde = kde ,
110+ trimming_threshold = trimming_threshold ,
111+ draw_sample_splitting = False )
112+
113+ # synchronize the sample splitting
114+ dml_lpq_obj .set_sample_splitting (all_smpls = all_smpls )
115+ dml_lpq_obj .fit ()
116+
117+ np .random .seed (42 )
118+ res_manual = fit_lpq (y , x , d , z , quantile , clone (learner ), clone (learner ),
119+ all_smpls , treatment , dml_procedure ,
120+ normalize_ipw = normalize_ipw , kde = kde ,
121+ n_rep = 1 , trimming_threshold = trimming_threshold )
86122
87123 res_dict = {'coef' : dml_lpq_obj .coef ,
88124 'coef_manual' : res_manual ['lpq' ],
0 commit comments