-
Notifications
You must be signed in to change notification settings - Fork 99
LPLR model #365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
jer2ig
wants to merge
44
commits into
main
Choose a base branch
from
jh-logistic-model
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
LPLR model #365
Changes from all commits
Commits
Show all changes
44 commits
Select commit
Hold shift + click to select a range
df03887
Logistic regression implementation WIP
jer2ig f5521f1
First WIP of implementation
jer2ig bfa756c
Working implementation. Started on test set-up.
jer2ig d729d0a
Changed data type of arrays
jer2ig 8fe7ca6
Fix variable name
jer2ig 18bac23
Moved into plm folder, started testing setup
jer2ig c6e600d
Fixed bug in score computation
jer2ig 6f556e0
Reverted from ensure_all_finite to force_all_finite
jer2ig 3a332bf
Fixes to instrument score
jer2ig b41a773
Added option for exception on convergence failure
jer2ig c434667
Added unbalanced dataset option, bug fixes
jer2ig 443d82d
Added binary treatment dataset, fixed bug for model check
jer2ig 774c74d
Adjusted dataset balancing
jer2ig 9695820
Renamed Logistic to LPLR
jer2ig dbfea73
Clean-up of branch
jer2ig 29114ce
Ruff checks and formatting
jer2ig 5d2d1ed
Unit tests work and bug fix in lplr
jer2ig 2c626a0
Cleanup
jer2ig 9819436
Tests updated
jer2ig 5a7e279
Pre-commit checks
jer2ig fc03cc6
Pre-commit checks on all files
jer2ig 5dae651
Changed function signature, test
jer2ig 13fca2f
Argument fix
jer2ig ff4c75b
Updated tests for improved coverage
jer2ig 8a181cd
Unused var removed
jer2ig f2ecea7
Fixed resampling
jer2ig a9a2959
External predictions
jer2ig cd6055b
Bugfix and addtl text
jer2ig 4a8be08
Change to ext predictions
jer2ig 0472f1c
Change to targets data type
jer2ig 2fc1f53
DoubleResamplin integrated into mixin, small changes
jer2ig ecfe2c7
Added attribute to sample mixin
jer2ig a9c0deb
Smpls inner access adjusted
jer2ig 6abff49
Docstring, complexity reduction
jer2ig 0f08e37
Weights updated, seed corrected
jer2ig 430f4a6
Fix
jer2ig 5b92395
Renaming
jer2ig 042aa26
Doctest
jer2ig 3b6f3b7
Test updated and comments implemented
jer2ig 883aa77
Merge branch 'main' into jh-logistic-model
jer2ig 74b1caa
Sample splitting exceptions
jer2ig 46b575b
Merge remote-tracking branch 'origin/jh-logistic-model' into jh-logis…
jer2ig 72be054
Test coverage increase
jer2ig 5d9e0eb
Exception fixed
jer2ig File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ | |
| class DoubleML(SampleSplittingMixin, ABC): | ||
| """Double Machine Learning.""" | ||
|
|
||
| def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting): | ||
| def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting, double_sample_splitting=False): | ||
| # check and pick up obj_dml_data | ||
| if not isinstance(obj_dml_data, DoubleMLBaseData): | ||
| raise TypeError( | ||
|
|
@@ -34,18 +34,10 @@ def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting): | |
| if obj_dml_data.n_cluster_vars > 2: | ||
| raise NotImplementedError("Multi-way (n_ways > 2) clustering not yet implemented.") | ||
| self._is_cluster_data = True | ||
| self._is_panel_data = False | ||
| if isinstance(obj_dml_data, DoubleMLPanelData): | ||
| self._is_panel_data = True | ||
| self._is_did_data = False | ||
| if isinstance(obj_dml_data, DoubleMLDIDData): | ||
| self._is_did_data = True | ||
| self._is_ssm_data = False | ||
| if isinstance(obj_dml_data, DoubleMLSSMData): | ||
| self._is_ssm_data = True | ||
| self._is_rdd_data = False | ||
| if isinstance(obj_dml_data, DoubleMLRDDData): | ||
| self._is_rdd_data = True | ||
| self._is_panel_data = isinstance(obj_dml_data, DoubleMLPanelData) | ||
| self._is_did_data = isinstance(obj_dml_data, DoubleMLDIDData) | ||
| self._is_ssm_data = isinstance(obj_dml_data, DoubleMLSSMData) | ||
| self._is_rdd_data = isinstance(obj_dml_data, DoubleMLRDDData) | ||
|
|
||
| self._dml_data = obj_dml_data | ||
| self._n_obs = self._dml_data.n_obs | ||
|
|
@@ -108,6 +100,9 @@ def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting): | |
| self._smpls = None | ||
| self._smpls_cluster = None | ||
| self._n_obs_sample_splitting = self.n_obs | ||
| self._double_sample_splitting = double_sample_splitting | ||
| if self._double_sample_splitting: | ||
| self._smpls_inner = None | ||
| if draw_sample_splitting: | ||
| self.draw_sample_splitting() | ||
| self._score_dim = (self._dml_data.n_obs, self.n_rep, self._dml_data.n_coefs) | ||
|
|
@@ -263,6 +258,13 @@ def learner(self): | |
| """ | ||
| return self._learner | ||
|
|
||
| @property | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is a good idea to include a more obvious property for the name, but this should be identical to |
||
| def predictions_names(self): | ||
| """ | ||
| The names of predictions for the nuisance functions. | ||
| """ | ||
| return list(self.params_names) | ||
|
|
||
| @property | ||
| def learner_names(self): | ||
| """ | ||
|
|
@@ -359,6 +361,21 @@ def smpls(self): | |
| raise ValueError(err_msg) | ||
| return self._smpls | ||
|
|
||
| @property | ||
| def smpls_inner(self): | ||
jer2ig marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| The partition used for cross-fitting. | ||
| """ | ||
| if not self._double_sample_splitting: | ||
| raise ValueError("smpls_inner is only available for double sample splitting.") | ||
| if self._smpls_inner is None: | ||
| err_msg = ( | ||
| "Sample splitting not specified. Either draw samples via .draw_sample splitting() " | ||
| + "or set external samples via .set_sample_splitting()." | ||
| ) | ||
| raise ValueError(err_msg) | ||
| return self._smpls_inner | ||
|
|
||
| @property | ||
| def smpls_cluster(self): | ||
| """ | ||
|
|
@@ -507,6 +524,18 @@ def summary(self): | |
| def __smpls(self): | ||
| return self._smpls[self._i_rep] | ||
|
|
||
| @property | ||
| def __smpls__inner(self): | ||
| if not self._double_sample_splitting: | ||
| raise ValueError("smpls_inner is only available for double sample splitting.") | ||
| if self._smpls_inner is None: | ||
| err_msg = ( | ||
| "Sample splitting not specified. Either draw samples via .draw_sample splitting() " | ||
| + "or set external samples via .set_sample_splitting()." | ||
| ) | ||
| raise ValueError(err_msg) | ||
| return self._smpls_inner[self._i_rep] | ||
|
|
||
| @property | ||
| def __smpls_cluster(self): | ||
| return self._smpls_cluster[self._i_rep] | ||
|
|
@@ -1059,7 +1088,7 @@ def _check_fit(self, n_jobs_cv, store_predictions, external_predictions, store_m | |
| _check_external_predictions( | ||
| external_predictions=external_predictions, | ||
| valid_treatments=self._dml_data.d_cols, | ||
| valid_learners=self.params_names, | ||
| valid_learners=self.predictions_names, | ||
| n_obs=self.n_obs, | ||
| n_rep=self.n_rep, | ||
| ) | ||
|
|
@@ -1081,7 +1110,10 @@ def _initalize_fit(self, store_predictions, store_models): | |
|
|
||
| def _fit_nuisance_and_score_elements(self, n_jobs_cv, store_predictions, external_predictions, store_models): | ||
| ext_prediction_dict = _set_external_predictions( | ||
| external_predictions, learners=self.params_names, treatment=self._dml_data.d_cols[self._i_treat], i_rep=self._i_rep | ||
| external_predictions, | ||
| learners=self.predictions_names, | ||
| treatment=self._dml_data.d_cols[self._i_treat], | ||
| i_rep=self._i_rep, | ||
| ) | ||
|
|
||
| # ml estimation of nuisance models and computation of score elements | ||
|
|
@@ -1146,8 +1178,8 @@ def _initialize_arrays(self): | |
| self._all_se = np.full((n_thetas, n_rep), np.nan) | ||
|
|
||
| def _initialize_predictions_and_targets(self): | ||
| self._predictions = {learner: np.full(self._score_dim, np.nan) for learner in self.params_names} | ||
| self._nuisance_targets = {learner: np.full(self._score_dim, np.nan) for learner in self.params_names} | ||
| self._predictions = {learner: np.full(self._score_dim, np.nan) for learner in self.predictions_names} | ||
| self._nuisance_targets = {learner: np.full(self._score_dim, np.nan) for learner in self.predictions_names} | ||
|
|
||
| def _initialize_nuisance_loss(self): | ||
| self._nuisance_loss = {learner: np.full((self.n_rep, self._dml_data.n_coefs), np.nan) for learner in self.params_names} | ||
|
|
@@ -1158,7 +1190,7 @@ def _initialize_models(self): | |
| } | ||
|
|
||
| def _store_predictions_and_targets(self, preds, targets): | ||
| for learner in self.params_names: | ||
| for learner in self.predictions_names: | ||
| self._predictions[learner][:, self._i_rep, self._i_treat] = preds[learner] | ||
| self._nuisance_targets[learner][:, self._i_rep, self._i_treat] = targets[learner] | ||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.