-
Notifications
You must be signed in to change notification settings - Fork 87
Add type hints to all code base #557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7962b50
b226a9d
aa0ca3b
1aaabf1
82ed06d
92c791e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,8 @@ | |
| Difference in differences | ||
| """ | ||
|
|
||
| from typing import Union | ||
|
|
||
| import arviz as az | ||
| import numpy as np | ||
| import pandas as pd | ||
|
|
@@ -47,20 +49,24 @@ class DifferenceInDifferences(BaseExperiment): | |
|
|
||
| .. note:: | ||
|
|
||
| There is no pre/post intervention data distinction for DiD, we fit all the | ||
| data available. | ||
| :param data: | ||
| A pandas dataframe | ||
| :param formula: | ||
| A statistical model formula | ||
| :param time_variable_name: | ||
| Name of the data column for the time variable | ||
| :param group_variable_name: | ||
| Name of the data column for the group variable | ||
| :param post_treatment_variable_name: | ||
| Name of the data column indicating post-treatment period (default: "post_treatment") | ||
| :param model: | ||
| A PyMC model for difference in differences | ||
| There is no pre/post intervention data distinction for DiD, we fit | ||
| all the data available. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data : pd.DataFrame | ||
| A pandas dataframe. | ||
| formula : str | ||
| A statistical model formula. | ||
| time_variable_name : str | ||
| Name of the data column for the time variable. | ||
| group_variable_name : str | ||
| Name of the data column for the group variable. | ||
| post_treatment_variable_name : str, optional | ||
| Name of the data column indicating post-treatment period. | ||
| Defaults to "post_treatment". | ||
| model : PyMCModel or RegressorMixin, optional | ||
| A PyMC model for difference in differences. Defaults to None. | ||
|
|
||
| Example | ||
| -------- | ||
|
|
@@ -92,8 +98,8 @@ def __init__( | |
| time_variable_name: str, | ||
| group_variable_name: str, | ||
| post_treatment_variable_name: str = "post_treatment", | ||
| model=None, | ||
| **kwargs, | ||
| model: Union[PyMCModel, RegressorMixin] | None = None, | ||
| **kwargs: dict, | ||
| ) -> None: | ||
| super().__init__(model=model) | ||
| self.causal_impact: xr.DataArray | float | None | ||
|
|
@@ -234,14 +240,14 @@ def __init__( | |
| f"{self.group_variable_name}:{self.post_treatment_variable_name}" | ||
| ) | ||
| matched_key = next((k for k in coef_map if interaction_term in k), None) | ||
| att = coef_map.get(matched_key) | ||
| att = coef_map.get(matched_key) if matched_key is not None else None | ||
| self.causal_impact = att | ||
| else: | ||
| raise ValueError("Model type not recognized") | ||
|
|
||
| return | ||
|
|
||
| def input_validation(self): | ||
| def input_validation(self) -> None: | ||
| # Validate formula structure and interaction interaction terms | ||
| self._validate_formula_interaction_terms() | ||
|
|
||
|
|
@@ -269,7 +275,7 @@ def input_validation(self): | |
| coded. Consisting of 0's and 1's only.""" | ||
| ) | ||
|
|
||
| def _validate_formula_interaction_terms(self): | ||
| def _validate_formula_interaction_terms(self) -> None: | ||
| """ | ||
| Validate that the formula contains at most one interaction term and no three-way or higher-order interactions. | ||
| Raises FormulaException if more than one interaction term is found or if any interaction term has more than 2 variables. | ||
|
|
@@ -299,7 +305,7 @@ def _validate_formula_interaction_terms(self): | |
| "Multiple interaction terms are not currently supported as they complicate interpretation of the causal effect." | ||
| ) | ||
|
|
||
| def summary(self, round_to=None) -> None: | ||
| def summary(self, round_to: int | None = 2) -> None: | ||
| """Print summary of main results and model coefficients. | ||
|
|
||
| :param round_to: | ||
|
|
@@ -311,11 +317,13 @@ def summary(self, round_to=None) -> None: | |
| print(self._causal_impact_summary_stat(round_to)) | ||
| self.print_coefficients(round_to) | ||
|
|
||
| def _causal_impact_summary_stat(self, round_to=None) -> str: | ||
| def _causal_impact_summary_stat(self, round_to: int | None = None) -> str: | ||
| """Computes the mean and 94% credible interval bounds for the causal impact.""" | ||
| return f"Causal impact = {convert_to_string(self.causal_impact, round_to=round_to)}" | ||
|
|
||
| def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: | ||
| def _bayesian_plot( | ||
| self, round_to: int | None = None, **kwargs: dict | ||
| ) -> tuple[plt.Figure, plt.Axes]: | ||
juanitorduz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| Plot the results | ||
|
|
||
|
|
@@ -463,9 +471,10 @@ def _plot_causal_impact_arrow(results, ax): | |
| ) | ||
| return fig, ax | ||
|
|
||
| def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: | ||
| def _ols_plot( | ||
| self, round_to: int | None = 2, **kwargs: dict | ||
| ) -> tuple[plt.Figure, plt.Axes]: | ||
|
||
| """Generate plot for difference-in-differences""" | ||
| round_to = kwargs.get("round_to") | ||
| fig, ax = plt.subplots() | ||
|
|
||
| # Plot raw data | ||
|
|
@@ -528,11 +537,15 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]: | |
| va="center", | ||
| ) | ||
| # formatting | ||
| # In OLS context, causal_impact should be a float, but mypy doesn't know this | ||
| causal_impact_value = ( | ||
| float(self.causal_impact) if self.causal_impact is not None else 0.0 | ||
| ) | ||
| ax.set( | ||
| xlim=[-0.05, 1.1], | ||
| xticks=[0, 1], | ||
| xticklabels=["pre", "post"], | ||
| title=f"Causal impact = {round_num(self.causal_impact, round_to)}", | ||
| title=f"Causal impact = {round_num(causal_impact_value, round_to)}", | ||
| ) | ||
| ax.legend(fontsize=LEGEND_FONT_SIZE) | ||
| return fig, ax | ||
Uh oh!
There was an error while loading. Please reload this page.