Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
df03887
Logistic regression implementation WIP
jer2ig Jan 13, 2025
f5521f1
First WIP of implementation
jer2ig Jan 27, 2025
bfa756c
Working implementation. Started on test set-up.
jer2ig Feb 21, 2025
d729d0a
Changed data type of arrays
jer2ig Feb 27, 2025
8fe7ca6
Fix variable name
jer2ig Feb 27, 2025
18bac23
Moved into plm folder, started testing setup
jer2ig Aug 27, 2025
c6e600d
Fixed bug in score computation
jer2ig Aug 27, 2025
6f556e0
Reverted from ensure_all_finite to force_all_finite
jer2ig Aug 27, 2025
3a332bf
Fixes to instrument score
jer2ig Aug 28, 2025
b41a773
Added option for exception on convergence failure
jer2ig Sep 3, 2025
c434667
Added unbalanced dataset option, bug fixes
jer2ig Sep 29, 2025
443d82d
Added binary treatment dataset, fixed bug for model check
jer2ig Oct 7, 2025
774c74d
Adjusted dataset balancing
jer2ig Oct 7, 2025
9695820
Renamed Logistic to LPLR
jer2ig Oct 27, 2025
dbfea73
Clean-up of branch
jer2ig Oct 27, 2025
29114ce
Ruff checks and formatting
jer2ig Oct 27, 2025
5d2d1ed
Unit tests work and bug fix in lplr
jer2ig Oct 28, 2025
2c626a0
Cleanup
jer2ig Oct 28, 2025
9819436
Tests updated
jer2ig Nov 6, 2025
5a7e279
Pre-commit checks
jer2ig Nov 6, 2025
fc03cc6
Pre-commit checks on all files
jer2ig Nov 6, 2025
5dae651
Changed function signature, test
jer2ig Nov 7, 2025
13fca2f
Argument fix
jer2ig Nov 7, 2025
ff4c75b
Updated tests for improved coverage
jer2ig Nov 7, 2025
8a181cd
Unused var removed
jer2ig Nov 7, 2025
f2ecea7
Fixed resampling
jer2ig Nov 7, 2025
a9a2959
External predictions
jer2ig Nov 8, 2025
cd6055b
Bugfix and addtl text
jer2ig Nov 8, 2025
4a8be08
Change to ext predictions
jer2ig Nov 10, 2025
0472f1c
Change to targets data type
jer2ig Nov 10, 2025
2fc1f53
DoubleResampling integrated into mixin, small changes
jer2ig Nov 10, 2025
ecfe2c7
Added attribute to sample mixin
jer2ig Nov 10, 2025
a9c0deb
Smpls inner access adjusted
jer2ig Nov 10, 2025
6abff49
Docstring, complexity reduction
jer2ig Nov 11, 2025
0f08e37
Weights updated, seed corrected
jer2ig Nov 11, 2025
430f4a6
Fix
jer2ig Nov 11, 2025
5b92395
Renaming
jer2ig Nov 11, 2025
042aa26
Doctest
jer2ig Nov 11, 2025
3b6f3b7
Test updated and comments implemented
jer2ig Nov 12, 2025
883aa77
Merge branch 'main' into jh-logistic-model
jer2ig Nov 12, 2025
74b1caa
Sample splitting exceptions
jer2ig Nov 12, 2025
46b575b
Merge remote-tracking branch 'origin/jh-logistic-model' into jh-logis…
jer2ig Nov 12, 2025
72be054
Test coverage increase
jer2ig Nov 12, 2025
5d9e0eb
Exception fixed
jer2ig Nov 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doubleml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .irm.pq import DoubleMLPQ
from .irm.qte import DoubleMLQTE
from .irm.ssm import DoubleMLSSM
from .plm.lplr import DoubleMLLPLR
from .plm.pliv import DoubleMLPLIV
from .plm.plr import DoubleMLPLR
from .utils.blp import DoubleMLBLP
Expand Down Expand Up @@ -42,6 +43,7 @@
"DoubleMLBLP",
"DoubleMLPolicyTree",
"DoubleMLSSM",
"DoubleMLLPLR",
]

__version__ = importlib.metadata.version("doubleml")
68 changes: 50 additions & 18 deletions doubleml/double_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
class DoubleML(SampleSplittingMixin, ABC):
"""Double Machine Learning."""

def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting):
def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting, double_sample_splitting=False):
# check and pick up obj_dml_data
if not isinstance(obj_dml_data, DoubleMLBaseData):
raise TypeError(
Expand All @@ -34,18 +34,10 @@ def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting):
if obj_dml_data.n_cluster_vars > 2:
raise NotImplementedError("Multi-way (n_ways > 2) clustering not yet implemented.")
self._is_cluster_data = True
self._is_panel_data = False
if isinstance(obj_dml_data, DoubleMLPanelData):
self._is_panel_data = True
self._is_did_data = False
if isinstance(obj_dml_data, DoubleMLDIDData):
self._is_did_data = True
self._is_ssm_data = False
if isinstance(obj_dml_data, DoubleMLSSMData):
self._is_ssm_data = True
self._is_rdd_data = False
if isinstance(obj_dml_data, DoubleMLRDDData):
self._is_rdd_data = True
self._is_panel_data = isinstance(obj_dml_data, DoubleMLPanelData)
self._is_did_data = isinstance(obj_dml_data, DoubleMLDIDData)
self._is_ssm_data = isinstance(obj_dml_data, DoubleMLSSMData)
self._is_rdd_data = isinstance(obj_dml_data, DoubleMLRDDData)

self._dml_data = obj_dml_data
self._n_obs = self._dml_data.n_obs
Expand Down Expand Up @@ -108,6 +100,9 @@ def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting):
self._smpls = None
self._smpls_cluster = None
self._n_obs_sample_splitting = self.n_obs
self._double_sample_splitting = double_sample_splitting
if self._double_sample_splitting:
self._smpls_inner = None
if draw_sample_splitting:
self.draw_sample_splitting()
self._score_dim = (self._dml_data.n_obs, self.n_rep, self._dml_data.n_coefs)
Expand Down Expand Up @@ -263,6 +258,13 @@ def learner(self):
"""
return self._learner

@property
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is a good idea to include a more obvious property for the name, but this should be identical to params_names right?
Especially how it is handled in evaluate_learners(). So do we need an extra list?

def predictions_names(self):
    """
    The names of predictions for the nuisance functions.

    Currently this mirrors ``params_names``; it is kept as a separate
    property so models can expose prediction names independently.
    """
    prediction_names = [name for name in self.params_names]
    return prediction_names

@property
def learner_names(self):
"""
Expand Down Expand Up @@ -359,6 +361,21 @@ def smpls(self):
raise ValueError(err_msg)
return self._smpls

@property
def smpls_inner(self):
    """
    The inner partition used for cross-fitting with double sample splitting.

    Only available when the model was constructed with
    ``double_sample_splitting=True``; raises ``ValueError`` otherwise, or if
    no sample splitting has been drawn yet.
    """
    if not self._double_sample_splitting:
        raise ValueError("smpls_inner is only available for double sample splitting.")
    if self._smpls_inner is None:
        # Fixed typo in the original message: ".draw_sample splitting()" -> ".draw_sample_splitting()"
        err_msg = (
            "Sample splitting not specified. Either draw samples via .draw_sample_splitting() "
            + "or set external samples via .set_sample_splitting()."
        )
        raise ValueError(err_msg)
    return self._smpls_inner

@property
def smpls_cluster(self):
"""
Expand Down Expand Up @@ -507,6 +524,18 @@ def summary(self):
def __smpls(self):
    # Sample splits for the currently active repetition (indexed by self._i_rep).
    return self._smpls[self._i_rep]

@property
def __smpls__inner(self):
    # Inner sample splits for the currently active repetition; guards mirror
    # the public `smpls_inner` property.
    if not self._double_sample_splitting:
        raise ValueError("smpls_inner is only available for double sample splitting.")
    if self._smpls_inner is None:
        # Fixed typo in the original message: ".draw_sample splitting()" -> ".draw_sample_splitting()"
        err_msg = (
            "Sample splitting not specified. Either draw samples via .draw_sample_splitting() "
            + "or set external samples via .set_sample_splitting()."
        )
        raise ValueError(err_msg)
    return self._smpls_inner[self._i_rep]

@property
def __smpls_cluster(self):
return self._smpls_cluster[self._i_rep]
Expand Down Expand Up @@ -1059,7 +1088,7 @@ def _check_fit(self, n_jobs_cv, store_predictions, external_predictions, store_m
_check_external_predictions(
external_predictions=external_predictions,
valid_treatments=self._dml_data.d_cols,
valid_learners=self.params_names,
valid_learners=self.predictions_names,
n_obs=self.n_obs,
n_rep=self.n_rep,
)
Expand All @@ -1081,7 +1110,10 @@ def _initalize_fit(self, store_predictions, store_models):

def _fit_nuisance_and_score_elements(self, n_jobs_cv, store_predictions, external_predictions, store_models):
ext_prediction_dict = _set_external_predictions(
external_predictions, learners=self.params_names, treatment=self._dml_data.d_cols[self._i_treat], i_rep=self._i_rep
external_predictions,
learners=self.predictions_names,
treatment=self._dml_data.d_cols[self._i_treat],
i_rep=self._i_rep,
)

# ml estimation of nuisance models and computation of score elements
Expand Down Expand Up @@ -1146,8 +1178,8 @@ def _initialize_arrays(self):
self._all_se = np.full((n_thetas, n_rep), np.nan)

def _initialize_predictions_and_targets(self):
    """Allocate NaN-filled prediction and nuisance-target arrays, one per prediction name."""
    # The diff residue left a dead duplicate pair of assignments keyed on
    # params_names; only the predictions_names versions are kept.
    self._predictions = {learner: np.full(self._score_dim, np.nan) for learner in self.predictions_names}
    self._nuisance_targets = {learner: np.full(self._score_dim, np.nan) for learner in self.predictions_names}

def _initialize_nuisance_loss(self):
    """Allocate a NaN-filled (n_rep, n_coefs) loss array for every learner."""
    loss_shape = (self.n_rep, self._dml_data.n_coefs)
    self._nuisance_loss = {name: np.full(loss_shape, np.nan) for name in self.params_names}
Expand All @@ -1158,7 +1190,7 @@ def _initialize_models(self):
}

def _store_predictions_and_targets(self, preds, targets):
    """Store per-learner predictions and nuisance targets for the current repetition and treatment.

    The diff residue left two consecutive ``for`` headers (old ``params_names``
    and new ``predictions_names``), which does not parse; the new header is kept.
    """
    for learner in self.predictions_names:
        self._predictions[learner][:, self._i_rep, self._i_treat] = preds[learner]
        self._nuisance_targets[learner][:, self._i_rep, self._i_treat] = targets[learner]

Expand Down
27 changes: 22 additions & 5 deletions doubleml/double_ml_sampling_mixins.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod

from doubleml.utils._checks import _check_sample_splitting
from doubleml.utils.resampling import DoubleMLClusterResampling, DoubleMLResampling
from doubleml.utils.resampling import DoubleMLClusterResampling, DoubleMLDoubleResampling, DoubleMLResampling


class SampleSplittingMixin:
Expand All @@ -17,6 +17,8 @@ class SampleSplittingMixin:
`sample splitting <https://docs.doubleml.org/stable/guide/resampling.html>`_ in the DoubleML user guide.
"""

_double_sample_splitting = False

def draw_sample_splitting(self):
"""
Draw sample splitting for DoubleML models.
Expand All @@ -29,6 +31,8 @@ def draw_sample_splitting(self):
self : object
"""
if self._is_cluster_data:
if self._double_sample_splitting:
raise ValueError("Cluster data not supported for double sample splitting.")
obj_dml_resampling = DoubleMLClusterResampling(
n_folds=self._n_folds_per_cluster,
n_rep=self.n_rep,
Expand All @@ -38,10 +42,20 @@ def draw_sample_splitting(self):
)
self._smpls, self._smpls_cluster = obj_dml_resampling.split_samples()
else:
obj_dml_resampling = DoubleMLResampling(
n_folds=self.n_folds, n_rep=self.n_rep, n_obs=self._n_obs_sample_splitting, stratify=self._strata
)
self._smpls = obj_dml_resampling.split_samples()
if self._double_sample_splitting:
obj_dml_resampling = DoubleMLDoubleResampling(
n_folds=self.n_folds,
n_folds_inner=self.n_folds_inner,
n_rep=self.n_rep,
n_obs=self._dml_data.n_obs,
stratify=self._strata,
)
self._smpls, self._smpls_inner = obj_dml_resampling.split_samples()
else:
obj_dml_resampling = DoubleMLResampling(
n_folds=self.n_folds, n_rep=self.n_rep, n_obs=self._n_obs_sample_splitting, stratify=self._strata
)
self._smpls = obj_dml_resampling.split_samples()

return self

Expand Down Expand Up @@ -104,6 +118,9 @@ def set_sample_splitting(self, all_smpls, all_smpls_cluster=None):
>>> dml_plr_obj.set_sample_splitting(smpls) # doctest: +ELLIPSIS
<doubleml.plm.plr.DoubleMLPLR object at 0x...>
"""
if self._double_sample_splitting:
raise ValueError("set_sample_splitting not supported for double sample splitting.")

self._smpls, self._smpls_cluster, self._n_rep, self._n_folds = _check_sample_splitting(
all_smpls, all_smpls_cluster, self._dml_data, self._is_cluster_data, n_obs=self._n_obs_sample_splitting
)
Expand Down
19 changes: 16 additions & 3 deletions doubleml/double_ml_score_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class NonLinearScoreMixin:
_score_type = "nonlinear"
_coef_start_val = np.nan
_coef_bounds = None
_error_on_convergence_failure = False

@property
@abstractmethod
Expand Down Expand Up @@ -149,12 +150,16 @@ def score_deriv(theta):
theta_hat = root_res.root
if not root_res.converged:
score_val = score(theta_hat)
warnings.warn(
msg = (
"Could not find a root of the score function.\n "
f"Flag: {root_res.flag}.\n"
f"Score value found is {score_val} "
f"for parameter theta equal to {theta_hat}."
)
if self._error_on_convergence_failure:
raise ValueError(msg)
else:
warnings.warn(msg)
else:
signs_different, bracket_guess = _get_bracket_guess(score, self._coef_start_val, self._coef_bounds)

Expand Down Expand Up @@ -186,12 +191,16 @@ def score_squared(theta):
score, self._coef_start_val, approx_grad=True, bounds=[self._coef_bounds]
)
theta_hat = theta_hat_array.item()
warnings.warn(
msg = (
"Could not find a root of the score function.\n "
f"Minimum score value found is {score_val} "
f"for parameter theta equal to {theta_hat}.\n "
"No theta found such that the score function evaluates to a negative value."
)
if self._error_on_convergence_failure:
raise ValueError(msg)
else:
warnings.warn(msg)
else:

def neg_score(theta):
Expand All @@ -202,11 +211,15 @@ def neg_score(theta):
neg_score, self._coef_start_val, approx_grad=True, bounds=[self._coef_bounds]
)
theta_hat = theta_hat_array.item()
warnings.warn(
msg = (
"Could not find a root of the score function. "
f"Maximum score value found is {-1 * neg_score_val} "
f"for parameter theta equal to {theta_hat}. "
"No theta found such that the score function evaluates to a positive value."
)
if self._error_on_convergence_failure:
raise ValueError(msg)
else:
warnings.warn(msg)

return theta_hat
Loading