Skip to content

Commit 20a4d6d

Browse files
committed
add exception tests
1 parent 80ce23f commit 20a4d6d

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

doubleml/irm/iivm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,8 @@ def robust_confset(self, level=0.95):
570570
bounds of an interval. The union of this intervals forms the confidence set.
571571
"""
572572

573+
if self.framework is None:
574+
raise ValueError("Apply fit() before robust_confset().")
573575
if not isinstance(level, float):
574576
raise TypeError(f"The confidence level must be of float type. {str(level)} of type {str(type(level))} was passed.")
575577
if (level <= 0) | (level >= 1):

doubleml/irm/tests/test_iivm_unif_confset.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,36 @@ def test_coverage_robust_confset():
5353
# Calculate coverage rate
5454
coverage_rate = np.mean(coverage)
5555
assert coverage_rate >= 0.9, f"Coverage rate {coverage_rate} is below 0.9"
56+
57+
58+
@pytest.mark.ci
59+
def test_exceptions_robust_confset():
60+
# Test parameters
61+
true_ATE = 0.5
62+
instrument_size = 0.005
63+
n_samples = 1000
64+
65+
np.random.seed(3141)
66+
data = generate_weak_iv_data(n_samples, instrument_size, true_ATE)
67+
68+
# create new model
69+
learner_g = LinearRegression()
70+
classifier_m = LogisticRegression()
71+
classifier_r = RandomForestClassifier(n_estimators=20, max_depth=5)
72+
dml_iivm_obj = dml.DoubleMLIIVM(data, learner_g, classifier_m, classifier_r)
73+
74+
# Check if the robust_confset method raises an exception when called before fitting
75+
msg = r"Apply fit\(\) before robust_confset\(\)."
76+
with pytest.raises(ValueError, match=msg):
77+
dml_iivm_obj.robust_confset()
78+
79+
# Fit the model
80+
dml_iivm_obj.fit()
81+
82+
# Check invalid inputs
83+
msg = "The confidence level must be of float type. 0.95 of type <class 'str'> was passed."
84+
with pytest.raises(TypeError, match=msg):
85+
dml_iivm_obj.robust_confset(level="0.95")
86+
msg = r"The confidence level must be in \(0,1\). 1.5 was passed."
87+
with pytest.raises(ValueError, match=msg):
88+
dml_iivm_obj.robust_confset(level=1.5)

0 commit comments

Comments
 (0)