Skip to content

Commit 80ce23f

Browse files
committed
update cov test
1 parent a2566a5 commit 80ce23f

File tree

1 file changed

+32
-47
lines changed

1 file changed

+32
-47
lines changed

doubleml/irm/tests/test_iivm_unif_confset.py

Lines changed: 32 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,66 +5,51 @@
55

66
import doubleml as dml
77

8-
np.random.seed(3141)
98

10-
11-
@pytest.fixture(scope="module")
12-
def true_ATE():
13-
return 0.5
14-
15-
16-
@pytest.fixture(scope="module")
17-
def instrument_size():
18-
return 0.005
19-
20-
21-
@pytest.fixture(scope="module")
22-
def n_samples():
23-
return 1000
24-
25-
26-
@pytest.fixture(scope="module")
27-
def n_simulations():
28-
return 100
29-
30-
31-
@pytest.fixture(scope="module")
32-
def weakiv_data(n_samples, instrument_size, true_ATE):
33-
# Generate data
9+
def generate_weak_iv_data(n_samples, instrument_size, true_ATE):
3410
u = np.random.normal(0, 2, size=n_samples)
3511
X = np.random.normal(0, 1, size=n_samples)
3612
Z = np.random.binomial(1, 0.5, size=n_samples)
37-
A = instrument_size * Z + u # Continuous treatment A
13+
A = instrument_size * Z + u
3814
A = np.array(A > 0, dtype=int)
39-
Y = true_ATE * A + np.sign(u) # Outcome Y
40-
return dml.DoubleMLData.from_arrays(x=X, y=Y, d=A, z=Z)
41-
15+
Y = true_ATE * A + np.sign(u)
16+
dml_data = dml.DoubleMLData.from_arrays(x=X, y=Y, d=A, z=Z)
17+
return dml_data
4218

43-
@pytest.fixture(scope="module")
44-
def iivm_obj(weakiv_data):
45-
# Set machine learning methods for m, r & g
46-
learner_g = LinearRegression()
47-
classifier_m = LogisticRegression()
48-
classifier_r = RandomForestClassifier(n_estimators=20, max_depth=5)
4919

50-
# Create DoubleMLIIVM object
51-
obj_dml_data = weakiv_data
52-
dml_iivm_obj = dml.DoubleMLIIVM(obj_dml_data, learner_g, classifier_m, classifier_r)
53-
return dml_iivm_obj
20+
@pytest.mark.ci
21+
def test_coverage_robust_confset():
22+
# Test parameters
23+
true_ATE = 0.5
24+
instrument_size = 0.005
25+
n_samples = 1000
26+
n_simulations = 100
5427

55-
56-
def test_coverage(iivm_obj, true_ATE, n_simulations):
28+
np.random.seed(3141)
5729
coverage = []
5830
for _ in range(n_simulations):
59-
# Fit the model
60-
iivm_obj.fit()
31+
data = generate_weak_iv_data(n_samples, instrument_size, true_ATE)
32+
33+
# Set machine learning methods
34+
learner_g = LinearRegression()
35+
classifier_m = LogisticRegression()
36+
classifier_r = RandomForestClassifier(n_estimators=20, max_depth=5)
6137

62-
# Get the confidence set
63-
conf_set = iivm_obj.robust_confset()
38+
# Create and fit new model
39+
dml_iivm_obj = dml.DoubleMLIIVM(data, learner_g, classifier_m, classifier_r)
40+
dml_iivm_obj.fit()
6441

65-
# Check if the true ATE is in the confidence set
42+
# Get confidence set
43+
conf_set = dml_iivm_obj.robust_confset()
44+
45+
# check if conf_set is a list of tuples
46+
assert isinstance(conf_set, list)
47+
assert all(isinstance(x, tuple) and len(x) == 2 for x in conf_set)
48+
49+
# Check if true ATE is in confidence set
6650
ate_in_confset = any(x[0] < true_ATE < x[1] for x in conf_set)
6751
coverage.append(ate_in_confset)
68-
# Calculate the coverage rate
52+
53+
# Calculate coverage rate
6954
coverage_rate = np.mean(coverage)
7055
assert coverage_rate >= 0.9, f"Coverage rate {coverage_rate} is below 0.9"

0 commit comments

Comments
 (0)