 from sklearn.preprocessing import LabelEncoder
 from sklearn.model_selection import KFold, GridSearchCV, RandomizedSearchCV
 from sklearn.metrics import mean_squared_error
-from sklearn.utils.multiclass import type_of_target

 from statsmodels.nonparametric.kde import KDEUnivariate

 from joblib import Parallel, delayed

+from ._utils_checks import _check_is_partition
+

 def _assure_2d_array(x):
     if x.ndim == 1:
@@ -40,63 +41,6 @@ def _get_cond_smpls_2d(smpls, bin_var1, bin_var2):
     return smpls_00, smpls_01, smpls_10, smpls_11


-def _check_is_partition(smpls, n_obs):
-    test_indices = np.concatenate([test_index for _, test_index in smpls])
-    if len(test_indices) != n_obs:
-        return False
-    hit = np.zeros(n_obs, dtype=bool)
-    hit[test_indices] = True
-    if not np.all(hit):
-        return False
-    return True
-
-
-def _check_all_smpls(all_smpls, n_obs, check_intersect=False):
-    all_smpls_checked = list()
-    for smpl in all_smpls:
-        all_smpls_checked.append(_check_smpl_split(smpl, n_obs, check_intersect))
-    return all_smpls_checked
-
-
-def _check_smpl_split(smpl, n_obs, check_intersect=False):
-    smpl_checked = list()
-    for tpl in smpl:
-        smpl_checked.append(_check_smpl_split_tpl(tpl, n_obs, check_intersect))
-    return smpl_checked
-
-
-def _check_smpl_split_tpl(tpl, n_obs, check_intersect=False):
-    train_index = np.sort(np.array(tpl[0]))
-    test_index = np.sort(np.array(tpl[1]))
-
-    if not issubclass(train_index.dtype.type, np.integer):
-        raise TypeError('Invalid sample split. Train indices must be of type integer.')
-    if not issubclass(test_index.dtype.type, np.integer):
-        raise TypeError('Invalid sample split. Test indices must be of type integer.')
-
-    if check_intersect:
-        if set(train_index) & set(test_index):
-            raise ValueError('Invalid sample split. Intersection of train and test indices is not empty.')
-
-    if len(np.unique(train_index)) != len(train_index):
-        raise ValueError('Invalid sample split. Train indices contain non-unique entries.')
-    if len(np.unique(test_index)) != len(test_index):
-        raise ValueError('Invalid sample split. Test indices contain non-unique entries.')
-
-    # we sort the indices above
-    # if not np.all(np.diff(train_index) > 0):
-    #     raise NotImplementedError('Invalid sample split. Only sorted train indices are supported.')
-    # if not np.all(np.diff(test_index) > 0):
-    #     raise NotImplementedError('Invalid sample split. Only sorted test indices are supported.')
-
-    if not set(train_index).issubset(range(n_obs)):
-        raise ValueError('Invalid sample split. Train indices must be in [0, n_obs).')
-    if not set(test_index).issubset(range(n_obs)):
-        raise ValueError('Invalid sample split. Test indices must be in [0, n_obs).')
-
-    return train_index, test_index
-
-
 def _fit(estimator, x, y, train_index, idx=None):
     estimator.fit(x[train_index, :], y[train_index])
     return estimator, idx
@@ -238,13 +182,6 @@ def _draw_weights(method, n_rep_boot, n_obs):
     return weights


-def _check_finite_predictions(preds, learner, learner_name, smpls):
-    test_indices = np.concatenate([test_index for _, test_index in smpls])
-    if not np.all(np.isfinite(preds[test_indices])):
-        raise ValueError(f'Predictions from learner {str(learner)} for {learner_name} are not finite.')
-    return
-
-
 def _trimm(preds, trimming_rule, trimming_threshold):
     if trimming_rule == 'truncate':
         preds[preds < trimming_threshold] = trimming_threshold
@@ -261,14 +198,6 @@ def _normalize_ipw(propensity, treatment):
     return normalized_weights


-def _check_is_propensity(preds, learner, learner_name, smpls, eps=1e-12):
-    test_indices = np.concatenate([test_index for _, test_index in smpls])
-    if any((preds[test_indices] < eps) | (preds[test_indices] > 1 - eps)):
-        warnings.warn(f'Propensity predictions from learner {str(learner)} for'
-                      f' {learner_name} are close to zero or one (eps={eps}).')
-    return
-
-
 def _rmse(y_true, y_pred):
     subset = np.logical_not(np.isnan(y_true))
     rmse = mean_squared_error(y_true[subset], y_pred[subset], squared=False)
@@ -285,77 +214,6 @@ def _predict_zero_one_propensity(learner, X):
     return res


-def _check_contains_iv(obj_dml_data):
-    if obj_dml_data.z_cols is not None:
-        raise ValueError('Incompatible data. ' +
-                         ' and '.join(obj_dml_data.z_cols) +
-                         ' have been set as instrumental variable(s). '
-                         'To fit an local model see the documentation.')
-
-
-def _check_zero_one_treatment(obj_dml):
-    one_treat = (obj_dml._dml_data.n_treat == 1)
-    binary_treat = (type_of_target(obj_dml._dml_data.d) == 'binary')
-    zero_one_treat = np.all((np.power(obj_dml._dml_data.d, 2) - obj_dml._dml_data.d) == 0)
-    if not (one_treat & binary_treat & zero_one_treat):
-        raise ValueError('Incompatible data. '
-                         f'To fit an {str(obj_dml.score)} model with DML '
-                         'exactly one binary variable with values 0 and 1 '
-                         'needs to be specified as treatment variable.')
-
-
-def _check_quantile(quantile):
-    if not isinstance(quantile, float):
-        raise TypeError('Quantile has to be a float. ' +
-                        f'Object of type {str(type(quantile))} passed.')
-
-    if (quantile <= 0) | (quantile >= 1):
-        raise ValueError('Quantile has be between 0 or 1. ' +
-                         f'Quantile {str(quantile)} passed.')
-    return
-
-
-def _check_treatment(treatment):
-    if not isinstance(treatment, int):
-        raise TypeError('Treatment indicator has to be an integer. ' +
-                        f'Object of type {str(type(treatment))} passed.')
-
-    if (treatment != 0) & (treatment != 1):
-        raise ValueError('Treatment indicator has be either 0 or 1. ' +
-                         f'Treatment indicator {str(treatment)} passed.')
-    return
-
-
-def _check_trimming(trimming_rule, trimming_threshold):
-    valid_trimming_rule = ['truncate']
-    if trimming_rule not in valid_trimming_rule:
-        raise ValueError('Invalid trimming_rule ' + str(trimming_rule) + '. ' +
-                         'Valid trimming_rule ' + ' or '.join(valid_trimming_rule) + '.')
-    if not isinstance(trimming_threshold, float):
-        raise TypeError('trimming_threshold has to be a float. ' +
-                        f'Object of type {str(type(trimming_threshold))} passed.')
-    if (trimming_threshold <= 0) | (trimming_threshold >= 0.5):
-        raise ValueError('Invalid trimming_threshold ' + str(trimming_threshold) + '. ' +
-                         'trimming_threshold has to be between 0 and 0.5.')
-    return
-
-
-def _check_score(score, valid_score, allow_callable=True):
-    if isinstance(score, str):
-        if score not in valid_score:
-            raise ValueError('Invalid score ' + score + '. ' +
-                             'Valid score ' + ' or '.join(valid_score) + '.')
-    else:
-        if allow_callable:
-            if not callable(score):
-                raise TypeError('score should be either a string or a callable. '
-                                '%r was passed.' % score)
-        else:
-            raise TypeError('score should be a string. '
-                            '%r was passed.' % score)
-    return
-
-
 def _get_bracket_guess(score, coef_start, coef_bounds):
     max_bracket_length = coef_bounds[1] - coef_bounds[0]
     b_guess = coef_bounds
@@ -388,3 +246,90 @@ def abs_ipw_score(theta):
                            method='brent')
     ipw_est = res.x
     return ipw_est
+
+
+def _aggregate_coefs_and_ses(all_coefs, all_ses, var_scaling_factor):
+    # aggregation is done over dimension 1, such that the coefs and ses have to be of shape (n_coefs, n_rep)
+    n_rep = all_coefs.shape[1]
+    coefs = np.median(all_coefs, 1)
+
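+    # Standard errors are aggregated by rescaling each repetition's variance
+    # with var_scaling_factor (the number of observations or clusters, cf.
+    # _var_est below), adding the squared deviation of that repetition's
+    # coefficient from the median coefficient, taking the median over the
+    # repetitions, and rescaling back; the reported SE thus also reflects the
+    # spread of the estimates across sample splits.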
+    xx = np.tile(coefs.reshape(-1, 1), n_rep)
+    ses = np.sqrt(np.divide(np.median(np.multiply(np.power(all_ses, 2), var_scaling_factor) +
+                                      np.power(all_coefs - xx, 2), 1), var_scaling_factor))
+
+    return coefs, ses
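+
+# Illustrative example (editorial sketch, not part of the commit): a single
+# coefficient over two repetitions with var_scaling_factor = 100 observations.
+#     all_coefs = np.array([[0.5, 0.7]])
+#     all_ses = np.array([[0.1, 0.1]])
+#     coefs, ses = _aggregate_coefs_and_ses(all_coefs, all_ses, 100.)
+#     # coefs -> [0.6] (median); ses -> [~0.1005], slightly above 0.1 because
+#     # the dispersion of the coefficients across repetitions is added in.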
+
+
+def _var_est(psi, psi_deriv, apply_cross_fitting, smpls, is_cluster_data,
+             cluster_vars=None, smpls_cluster=None, n_folds_per_cluster=None):
+
+    if not is_cluster_data:
+        # psi and psi_deriv should be of shape (n_obs, ...)
+        if apply_cross_fitting:
+            var_scaling_factor = psi.shape[0]
+        else:
+            # Without cross-fitting, the score function was only evaluated on the test data set
+            test_index = smpls[0][1]
+            psi_deriv = psi_deriv[test_index]
+            psi = psi[test_index]
+            var_scaling_factor = len(test_index)
+
+        J = np.mean(psi_deriv)
+        gamma_hat = np.mean(np.square(psi))
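+        # J estimates the Jacobian of the score and gamma_hat its second
+        # moment; they are combined into the sandwich variance at the end.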
+
+    else:
+        assert cluster_vars is not None
+        assert smpls_cluster is not None
+        assert n_folds_per_cluster is not None
+        n_folds = len(smpls)
+
+        # one cluster
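+        # One-way clustering: for each fold, psi cross-products are summed
+        # within every cluster that appears in the fold's test set, giving a
+        # cluster-robust analogue of the mean squared score.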
+        if cluster_vars.shape[1] == 1:
+            first_cluster_var = cluster_vars[:, 0]
+            clusters = np.unique(first_cluster_var)
+            gamma_hat = 0
+            j_hat = 0
+            for i_fold in range(n_folds):
+                test_inds = smpls[i_fold][1]
+                test_cluster_inds = smpls_cluster[i_fold][1]
+                I_k = test_cluster_inds[0]
+                const = 1 / len(I_k)
+                for cluster_value in I_k:
+                    ind_cluster = (first_cluster_var == cluster_value)
+                    gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
+                j_hat += np.sum(psi_deriv[test_inds]) / len(I_k)
+
+            var_scaling_factor = len(clusters)
+            J = np.divide(j_hat, n_folds_per_cluster)
+            gamma_hat = np.divide(gamma_hat, n_folds_per_cluster)
+
+        else:
+            assert cluster_vars.shape[1] == 2
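+            # Two-way clustering: cross-products are accumulated along both
+            # cluster dimensions; the variance is later scaled by the smaller
+            # number of clusters, the dimension that binds asymptotically.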
+            first_cluster_var = cluster_vars[:, 0]
+            second_cluster_var = cluster_vars[:, 1]
+            gamma_hat = 0
+            j_hat = 0
+            for i_fold in range(n_folds):
+                test_inds = smpls[i_fold][1]
+                test_cluster_inds = smpls_cluster[i_fold][1]
+                I_k = test_cluster_inds[0]
+                J_l = test_cluster_inds[1]
+                const = np.divide(min(len(I_k), len(J_l)), np.square(len(I_k) * len(J_l)))
+                for cluster_value in I_k:
+                    ind_cluster = (first_cluster_var == cluster_value) & np.in1d(second_cluster_var, J_l)
+                    gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
+                for cluster_value in J_l:
+                    ind_cluster = (second_cluster_var == cluster_value) & np.in1d(first_cluster_var, I_k)
+                    gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
+                j_hat += np.sum(psi_deriv[test_inds]) / (len(I_k) * len(J_l))
+
+            n_first_clusters = len(np.unique(first_cluster_var))
+            n_second_clusters = len(np.unique(second_cluster_var))
+            var_scaling_factor = min(n_first_clusters, n_second_clusters)
+            J = np.divide(j_hat, np.square(n_folds_per_cluster))
+            gamma_hat = np.divide(gamma_hat, np.square(n_folds_per_cluster))
+
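+    # Sandwich form: sigma2_hat = gamma_hat / (var_scaling_factor * J**2).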
+    scaling = np.divide(1.0, np.multiply(var_scaling_factor, np.square(J)))
+    sigma2_hat = np.multiply(scaling, gamma_hat)
+
+    return sigma2_hat, var_scaling_factor
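+
+
+# Illustrative example (editorial sketch, not part of the commit): variance of
+# a cross-fitted linear score without clustering.
+#     psi = np.random.normal(size=500)    # evaluated score
+#     psi_deriv = np.ones(500)            # score derivative
+#     smpls = [(np.arange(250, 500), np.arange(250)),
+#              (np.arange(250), np.arange(250, 500))]
+#     sigma2_hat, n = _var_est(psi, psi_deriv, apply_cross_fitting=True,
+#                              smpls=smpls, is_cluster_data=False)
+#     # sigma2_hat approximates Var(psi) / (n * J**2) with J = mean(psi_deriv).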