Skip to content

Commit 83dd94f

Browse files
authored
Merge pull request #318 from esmucler/unifconfset
Methods for confidence sets for IV models that are robust to weak instruments
2 parents 2e38a03 + 776e644 commit 83dd94f

File tree

4 files changed

+250
-1
lines changed

4 files changed

+250
-1
lines changed

doubleml/irm/iivm.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from scipy.stats import norm
23
from sklearn.utils import check_X_y
34
from sklearn.utils.multiclass import type_of_target
45

@@ -12,7 +13,7 @@
1213
_check_score,
1314
_check_trimming,
1415
)
15-
from doubleml.utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls
16+
from doubleml.utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls, _solve_quadratic_inequality
1617
from doubleml.utils._propensity_score import _normalize_ipw, _trimm
1718

1819

@@ -196,6 +197,23 @@ def __init__(
196197
self.subgroups = subgroups
197198
self._external_predictions_implemented = True
198199

200+
def __str__(self):
201+
parent_str = super().__str__()
202+
203+
# add robust confset
204+
if self.framework is None:
205+
confset_str = ""
206+
else:
207+
confset = self.robust_confset()
208+
formatted_confset = ", ".join([f"[{lower:.4f}, {upper:.4f}]" for lower, upper in confset])
209+
confset_str = (
210+
"\n\n--------------- Additional Information ----------------\n"
211+
+ f"Robust Confidence Set: {formatted_confset}\n"
212+
)
213+
214+
res = parent_str + confset_str
215+
return res
216+
199217
@property
200218
def normalize_ipw(self):
201219
"""
@@ -550,3 +568,45 @@ def _nuisance_tuning(
550568

551569
def _sensitivity_element_est(self, preds):
552570
pass
571+
572+
def robust_confset(self, level=0.95):
573+
"""
574+
Confidence sets for non-parametric instrumental variable models that are uniformly valid under weak instruments.
575+
These are obtained by inverting a score-like test statistic based on estimated influence function.
576+
577+
Parameters
578+
----------
579+
level : float
580+
The confidence level.
581+
Default is ``0.95``.
582+
583+
Returns
584+
-------
585+
list_confset : List
586+
A list that contains tuples. Each tuple contains the lower and upper
587+
bounds of an interval. The union of this intervals forms the confidence set.
588+
"""
589+
590+
if self.framework is None:
591+
raise ValueError("Apply fit() before robust_confset().")
592+
if not isinstance(level, float):
593+
raise TypeError(f"The confidence level must be of float type. {str(level)} of type {str(type(level))} was passed.")
594+
if (level <= 0) | (level >= 1):
595+
raise ValueError(f"The confidence level must be in (0,1). {str(level)} was passed.")
596+
597+
# compute critical values
598+
alpha = 1 - level
599+
critical_value = norm.ppf(1.0 - alpha / 2)
600+
601+
# We need to find the thetas that solve the equation
602+
# n * np.mean(score(theta))/np.mean(score(theta)**2) <= critical_value**2.
603+
# This is equivalent to solving the equation
604+
# a theta^2 + b theta + c <= 0
605+
# for some a, b, c, which we calculate next, and then solve the equation.
606+
n = self.psi_elements["psi_a"].shape[0]
607+
a = n * np.mean(self.psi_elements["psi_a"]) ** 2 - critical_value**2 * np.mean(np.square(self.psi_elements["psi_a"]))
608+
b = 2 * n * np.mean(self.psi_elements["psi_a"]) * np.mean(
609+
self.psi_elements["psi_b"]
610+
) - 2 * critical_value**2 * np.mean(np.multiply(self.psi_elements["psi_a"], self.psi_elements["psi_b"]))
611+
c = n * np.mean(self.psi_elements["psi_b"]) ** 2 - critical_value**2 * np.mean(np.square(self.psi_elements["psi_b"]))
612+
return _solve_quadratic_inequality(a, b, c)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import numpy as np
2+
import pytest
3+
from sklearn.ensemble import RandomForestClassifier
4+
from sklearn.linear_model import LinearRegression, LogisticRegression
5+
6+
import doubleml as dml
7+
8+
9+
def generate_weak_iv_data(n_samples, instrument_size, true_ATE):
10+
u = np.random.normal(0, 2, size=n_samples)
11+
X = np.random.normal(0, 1, size=n_samples)
12+
Z = np.random.binomial(1, 0.5, size=n_samples)
13+
A = instrument_size * Z + u
14+
A = np.array(A > 0, dtype=int)
15+
Y = true_ATE * A + np.sign(u)
16+
dml_data = dml.DoubleMLData.from_arrays(x=X, y=Y, d=A, z=Z)
17+
return dml_data
18+
19+
20+
@pytest.mark.ci
21+
def test_coverage_robust_confset():
22+
# Test parameters
23+
true_ATE = 0.5
24+
instrument_size = 0.005
25+
n_samples = 1000
26+
n_simulations = 100
27+
28+
np.random.seed(3141)
29+
coverage = []
30+
for _ in range(n_simulations):
31+
data = generate_weak_iv_data(n_samples, instrument_size, true_ATE)
32+
33+
# Set machine learning methods
34+
learner_g = LinearRegression()
35+
classifier_m = LogisticRegression()
36+
classifier_r = RandomForestClassifier(n_estimators=20, max_depth=5)
37+
38+
# Create and fit new model
39+
dml_iivm_obj = dml.DoubleMLIIVM(data, learner_g, classifier_m, classifier_r)
40+
dml_iivm_obj.fit()
41+
42+
# Get confidence set
43+
conf_set = dml_iivm_obj.robust_confset()
44+
45+
# check if conf_set is a list of tuples
46+
assert isinstance(conf_set, list)
47+
assert all(isinstance(x, tuple) and len(x) == 2 for x in conf_set)
48+
49+
# Check if true ATE is in confidence set
50+
ate_in_confset = any(x[0] < true_ATE < x[1] for x in conf_set)
51+
coverage.append(ate_in_confset)
52+
53+
# Calculate coverage rate
54+
coverage_rate = np.mean(coverage)
55+
assert coverage_rate >= 0.9, f"Coverage rate {coverage_rate} is below 0.9"
56+
57+
58+
@pytest.mark.ci
59+
def test_exceptions_robust_confset():
60+
# Test parameters
61+
true_ATE = 0.5
62+
instrument_size = 0.005
63+
n_samples = 1000
64+
65+
np.random.seed(3141)
66+
data = generate_weak_iv_data(n_samples, instrument_size, true_ATE)
67+
68+
# create new model
69+
learner_g = LinearRegression()
70+
classifier_m = LogisticRegression()
71+
classifier_r = RandomForestClassifier(n_estimators=20, max_depth=5)
72+
dml_iivm_obj = dml.DoubleMLIIVM(data, learner_g, classifier_m, classifier_r)
73+
74+
# Check if the robust_confset method raises an exception when called before fitting
75+
msg = r"Apply fit\(\) before robust_confset\(\)."
76+
with pytest.raises(ValueError, match=msg):
77+
dml_iivm_obj.robust_confset()
78+
79+
# Check if str representation of the object is working
80+
str_repr = str(dml_iivm_obj)
81+
assert isinstance(str_repr, str)
82+
assert "Robust" not in str_repr
83+
84+
# Fit the model
85+
dml_iivm_obj.fit()
86+
87+
# Check invalid inputs
88+
msg = "The confidence level must be of float type. 0.95 of type <class 'str'> was passed."
89+
with pytest.raises(TypeError, match=msg):
90+
dml_iivm_obj.robust_confset(level="0.95")
91+
msg = r"The confidence level must be in \(0,1\). 1.5 was passed."
92+
with pytest.raises(ValueError, match=msg):
93+
dml_iivm_obj.robust_confset(level=1.5)
94+
95+
# Check if str representation of the object is working
96+
str_repr = str(dml_iivm_obj)
97+
assert isinstance(str_repr, str)
98+
assert "Robust Confidence Set" in str_repr

doubleml/utils/_estimation.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,53 @@ def _set_external_predictions(external_predictions, learners, treatment, i_rep):
341341
else:
342342
ext_prediction_dict[learner] = None
343343
return ext_prediction_dict
344+
345+
346+
def _solve_quadratic_inequality(a: float, b: float, c: float):
347+
"""
348+
Solves the quadratic inequation a*x^2 + b*x + c <= 0 and returns the intervals.
349+
350+
Parameters
351+
----------
352+
a : float
353+
Coefficient of x^2.
354+
b : float
355+
Coefficient of x.
356+
c : float
357+
Constant term.
358+
359+
Returns
360+
-------
361+
List[Tuple[float, float]]
362+
A list of intervals where the inequation holds.
363+
"""
364+
365+
# Handle special cases
366+
if abs(a) < 1e-12: # a is effectively zero
367+
if abs(b) < 1e-12: # constant case
368+
return [(-np.inf, np.inf)] if c <= 0 else []
369+
# Linear case:
370+
else:
371+
root = -c / b
372+
return [(-np.inf, root)] if b > 0 else [(root, np.inf)]
373+
374+
# Standard case: quadratic equation
375+
roots = np.polynomial.polynomial.polyroots([c, b, a])
376+
real_roots = np.sort(roots[np.isreal(roots)].real)
377+
378+
if len(real_roots) == 0: # No real roots
379+
if a > 0: # parabola opens upwards, no real roots
380+
return []
381+
else: # parabola opens downwards, always <= 0
382+
return [(-np.inf, np.inf)]
383+
elif len(real_roots) == 1 or np.allclose(real_roots[0], real_roots[1]): # One real root
384+
if a > 0:
385+
return [(real_roots[0], real_roots[0])] # parabola touches x-axis at one point
386+
else:
387+
return [(-np.inf, np.inf)] # parabola is always <= 0
388+
else:
389+
assert len(real_roots) == 2
390+
if a > 0: # happy quadratic (parabola opens upwards)
391+
return [(real_roots[0], real_roots[1])]
392+
else: # sad quadratic (parabola opens downwards)
393+
return [(-np.inf, real_roots[0]), (real_roots[1], np.inf)]
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import numpy as np
2+
import pytest
3+
4+
from doubleml.utils._estimation import _solve_quadratic_inequality
5+
6+
7+
@pytest.mark.parametrize(
8+
"a, b, c, expected",
9+
[
10+
(1, 0, -4, [(-2.0, 2.0)]), # happy quadratic, determinant > 0
11+
(-1, 0, 4, [(-np.inf, -2), (2, np.inf)]), # sad quadratic, determinant > 0
12+
(1, 0, 4, []), # happy quadratic, determinant < 0
13+
(-1, 0, -4, [(-np.inf, np.inf)]), # sad quadratic, determinant < 0
14+
(1, 0, 0, [(0.0, 0.0)]), # happy quadratic, determinant = 0
15+
(-1, 0, 0, [(-np.inf, np.inf)]), # sad quadratic, determinant = 0
16+
(1, 3, -4, [(-4.0, 1.0)]), # happy quadratic, determinant > 0
17+
(-1, 3, 4, [(-np.inf, -1), (4, np.inf)]), # sad quadratic, determinant > 0
18+
(-1, -3, 4, [(-np.inf, -4), (1, np.inf)]), # sad quadratic, determinant > 0
19+
(1, 3, 4, []), # happy quadratic, determinant < 0
20+
(-1, 3, -4, [(-np.inf, np.inf)]), # sad quadratic, determinant < 0
21+
(1, 4, 4, [(-2.0, -2.0)]), # happy quadratic, determinant = 0
22+
(-1, 4, -4, [(-np.inf, np.inf)]), # sad quadratic, determinant = 0
23+
(0, 0, 0, [(-np.inf, np.inf)]), # constant and equal to zero
24+
(0, 0, 1, []), # constant and larger than zero
25+
(0, 1, 0, [(-np.inf, 0.0)]), # increasing linear function
26+
(0, -1, -1, [(-1.0, np.inf)]), # decreasing linear function
27+
],
28+
)
29+
def test_solve_quadratic_inequation(a, b, c, expected):
30+
result = _solve_quadratic_inequality(a, b, c)
31+
32+
assert len(result) == len(expected), f"Expected {len(expected)} intervals but got {len(result)}"
33+
34+
for i, tpl in enumerate(result):
35+
if tpl[0] == -np.inf:
36+
assert np.isinf(tpl[0])
37+
if tpl[1] == np.inf:
38+
assert np.isinf(tpl[1])
39+
else:
40+
assert np.isclose(tpl[0], expected[i][0]), f"Expected {expected[i][0]} but got {tpl[0]}"
41+
assert np.isclose(tpl[1], expected[i][1]), f"Expected {expected[i][1]} but got {tpl[1]}"

0 commit comments

Comments
 (0)