Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,10 @@ repos:
additional_dependencies:
# Support pyproject.toml configuration
- tomli
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
hooks:
- id: mypy
args: [--ignore-missing-imports]
files: ^causalpy/
additional_dependencies: [numpy>=1.20, pandas-stubs]
10 changes: 10 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,13 @@
- **Formulas**: Use patsy for formula parsing (via `dmatrices()`)
- **Custom exceptions**: Use project-specific exceptions from `causalpy.custom_exceptions`: `FormulaException`, `DataException`, `BadIndexException`
- **File organization**: Experiments in `causalpy/experiments/`, PyMC models in `causalpy/pymc_models.py`, scikit-learn models in `causalpy/skl_models.py`

## Type Checking

- **Tool**: MyPy
- **Configuration**: Integrated as a pre-commit hook.
- **Scope**: Checks Python files within the `causalpy/` directory.
- **Settings**:
- `ignore-missing-imports`: Enabled to allow for gradual adoption of type hints without requiring all third-party libraries to have stubs.
- `additional_dependencies`: Includes `numpy` and `pandas-stubs` to provide type information for these libraries.
- **Execution**: Run automatically via `pre-commit run --all-files` or on commit.
4 changes: 2 additions & 2 deletions causalpy/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@
}


def _get_data_home() -> pathlib.PosixPath:
def _get_data_home() -> pathlib.Path:
"""Return the path of the data directory"""
return pathlib.Path(cp.__file__).parents[1] / "causalpy" / "data"


def load_data(dataset: str = None) -> pd.DataFrame:
def load_data(dataset: str | None = None) -> pd.DataFrame:
"""Loads the requested dataset and returns a pandas DataFrame.

:param dataset: The desired dataset to load
Expand Down
2 changes: 2 additions & 0 deletions causalpy/experiments/diff_in_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
**kwargs,
) -> None:
super().__init__(model=model)
self.causal_impact: xr.DataArray | float | None
# rename the index to "obs_ind"
data.index.name = "obs_ind"
self.data = data
Expand Down Expand Up @@ -213,6 +214,7 @@ def __init__(

# calculate causal impact
if isinstance(self.model, PyMCModel):
assert self.model.idata is not None
# This is the coefficient on the interaction term
coeff_names = self.model.idata.posterior.coords["coeffs"].data
for i, label in enumerate(coeff_names):
Expand Down
2 changes: 2 additions & 0 deletions causalpy/experiments/interrupted_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def __init__(
**kwargs,
) -> None:
super().__init__(model=model)
self.pre_y: xr.DataArray
self.post_y: xr.DataArray
# rename the index to "obs_ind"
data.index.name = "obs_ind"
self.input_validation(data, treatment_time)
Expand Down
5 changes: 5 additions & 0 deletions causalpy/experiments/prepostnegd.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def __init__(
**kwargs,
):
super().__init__(model=model)
self.causal_impact: xr.DataArray
self.pred_xi: np.ndarray
self.pred_untreated: az.InferenceData
self.pred_treated: az.InferenceData
self.data = data
self.expt_type = "Pretest/posttest Nonequivalent Group Design"
self.formula = formula
Expand Down Expand Up @@ -140,6 +144,7 @@ def __init__(
else:
raise ValueError("Model type not recognized")

assert self.model.idata is not None
# Calculate the posterior predictive for the treatment and control for an
# interpolated set of pretest values
# get the model predictions of the observed data
Expand Down
3 changes: 2 additions & 1 deletion causalpy/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
import xarray as xr
from matplotlib.collections import PolyCollection
from matplotlib.lines import Line2D
from pandas.api.extensions import ExtensionArray


def plot_xY(
x: Union[pd.DatetimeIndex, np.array],
x: Union[pd.DatetimeIndex, np.ndarray, pd.Index, pd.Series, ExtensionArray],
Y: xr.DataArray,
ax: plt.Axes,
plot_hdi_kwargs: Optional[Dict[str, Any]] = None,
Expand Down
7 changes: 6 additions & 1 deletion causalpy/pymc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class PyMCModel(pm.Model):
Inference data...
"""

default_priors = {}
default_priors: Dict[str, Prior] = {}

def priors_from_data(self, X, y) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -236,6 +236,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
self.build_model(X, y, coords)
with self:
self.idata = pm.sample(**self.sample_kwargs)
if self.idata is None:
raise RuntimeError("pm.sample() returned None")
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
self.idata.extend(
pm.sample_posterior_predictive(
Expand Down Expand Up @@ -349,6 +351,9 @@ def calculate_cumulative_impact(self, impact):
return impact.cumsum(dim="obs_ind")

def print_coefficients(self, labels, round_to=None) -> None:
if self.idata is None:
raise RuntimeError("Model has not been fit")

def print_row(
max_label_length: int, name: str, coeff_samples: xr.DataArray, round_to: int
) -> None:
Expand Down
36 changes: 30 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ dependencies = [
#
# Similar to `dependencies` above, these must be valid existing projects.
[project.optional-dependencies]
dev = ["pathlib", "pre-commit", "twine", "interrogate", "codespell", "nbformat", "nbconvert"]
dev = [
"pathlib",
"pre-commit",
"twine",
"interrogate",
"codespell",
"nbformat",
"nbconvert",
]
docs = [
"ipykernel",
"daft-pgm",
Expand All @@ -71,7 +79,7 @@ docs = [
"sphinx-design",
"sphinx-togglebutton",
]
lint = ["interrogate", "pre-commit", "ruff"]
lint = ["interrogate", "pre-commit", "ruff", "mypy"]
test = ["pytest", "pytest-cov", "codespell", "nbformat", "nbconvert"]

[project.urls]
Expand Down Expand Up @@ -129,10 +137,7 @@ ignore-words = "./docs/source/.codespell/codespell-whitelist.txt"
skip = "*.ipynb,*.csv,pyproject.toml,docs/source/.codespell/codespell-whitelist.txt"

[tool.coverage.run]
omit = [
"*/conftest.py",
"*/tests/conftest.py",
]
omit = ["*/conftest.py", "*/tests/conftest.py"]

[tool.coverage.report]
exclude_lines = [
Expand All @@ -147,3 +152,22 @@ exclude_lines = [
"class .*\\bProtocol\\):",
"@(abc\\.)?abstractmethod",
]

[tool.mypy]
files = "causalpy/*.py"
exclude = "build|dist|docs|notebooks|tests|setup.py"

# Per-module settings in pyproject.toml must use [[tool.mypy.overrides]];
# mypy silently ignores "[tool.mypy-<pkg>]" tables (that section style is
# only valid in mypy.ini / setup.cfg).
[[tool.mypy.overrides]]
module = [
    "matplotlib.*",
    "pymc.*",
    "seaborn.*",
    "sklearn.*",
    "scipy.*",
]
ignore_missing_imports = true