Skip to content

Commit 42bfcda

Browse files
juanitorduzCopilot
andauthored
Add MyPy Checks (#556)
* init * gemini * rm unwanted change * do not use assert and raise instead * add info to AGENTG.md * fix mypy * Add type hints to all code base (#557) * init * rm file * update badge * Update causalpy/experiments/regression_kink.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * update docstrings * fix: apply ruff formatting after rebase --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent bc2fbae commit 42bfcda

19 files changed

+603
-299
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,10 @@ repos:
4848
additional_dependencies:
4949
# Support pyproject.toml configuration
5050
- tomli
51+
- repo: https://github.com/pre-commit/mirrors-mypy
52+
rev: v1.18.2
53+
hooks:
54+
- id: mypy
55+
args: [--ignore-missing-imports]
56+
files: ^causalpy/
57+
additional_dependencies: [numpy>=1.20, pandas-stubs]

AGENTS.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,13 @@
3737
- **Formulas**: Use patsy for formula parsing (via `dmatrices()`)
3838
- **Custom exceptions**: Use project-specific exceptions from `causalpy.custom_exceptions`: `FormulaException`, `DataException`, `BadIndexException`
3939
- **File organization**: Experiments in `causalpy/experiments/`, PyMC models in `causalpy/pymc_models.py`, scikit-learn models in `causalpy/skl_models.py`
40+
41+
## Type Checking
42+
43+
- **Tool**: MyPy
44+
- **Configuration**: Integrated as a pre-commit hook.
45+
- **Scope**: Checks Python files within the `causalpy/` directory.
46+
- **Settings**:
47+
- `ignore-missing-imports`: Enabled to allow for gradual adoption of type hints without requiring all third-party libraries to have stubs.
48+
- `additional_dependencies`: Includes `numpy` and `pandas-stubs` to provide type information for these libraries.
49+
- **Execution**: Run automatically via `pre-commit run --all-files` or on commit.

causalpy/data/datasets.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,28 @@
4343
}
4444

4545

46-
def _get_data_home() -> pathlib.PosixPath:
46+
def _get_data_home() -> pathlib.Path:
4747
"""Return the path of the data directory"""
4848
return pathlib.Path(cp.__file__).parents[1] / "causalpy" / "data"
4949

5050

51-
def load_data(dataset: str = None) -> pd.DataFrame:
52-
"""Loads the requested dataset and returns a pandas DataFrame.
51+
def load_data(dataset: str | None = None) -> pd.DataFrame:
52+
"""Load the requested dataset and return a pandas DataFrame.
5353
54-
:param dataset: The desired dataset to load
54+
Parameters
55+
----------
56+
dataset : str, optional
57+
The desired dataset to load. If None, raises ValueError.
58+
59+
Returns
60+
-------
61+
pd.DataFrame
62+
The loaded dataset as a pandas DataFrame.
63+
64+
Raises
65+
------
66+
ValueError
67+
If the requested dataset is not found.
5568
"""
5669

5770
if dataset in DATASETS:

causalpy/data/simulate_data.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
Functions that generate data sets used in examples
1616
"""
1717

18-
from typing import Any
19-
2018
import numpy as np
2119
import pandas as pd
2220
from scipy.stats import dirichlet, gamma, norm, uniform
@@ -31,7 +29,7 @@ def _smoothed_gaussian_random_walk(
3129
gaussian_random_walk_mu: float,
3230
gaussian_random_walk_sigma: float,
3331
N: int,
34-
lowess_kwargs: dict[str, Any],
32+
lowess_kwargs: dict,
3533
) -> tuple[np.ndarray, np.ndarray]:
3634
"""
3735
Generates Gaussian random walk data and applies LOWESS.
@@ -57,7 +55,7 @@ def generate_synthetic_control_data(
5755
treatment_time: int = 70,
5856
grw_mu: float = 0.25,
5957
grw_sigma: float = 1,
60-
lowess_kwargs: dict[str, Any] | None = None,
58+
lowess_kwargs: dict = default_lowess_kwargs,
6159
) -> tuple[pd.DataFrame, np.ndarray]:
6260
"""
6361
Generates data for synthetic control example.
@@ -78,9 +76,6 @@ def generate_synthetic_control_data(
7876
>>> from causalpy.data.simulate_data import generate_synthetic_control_data
7977
>>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
8078
"""
81-
if lowess_kwargs is None:
82-
lowess_kwargs = default_lowess_kwargs
83-
8479
# 1. Generate non-treated variables
8580
df = pd.DataFrame(
8681
{
@@ -166,7 +161,9 @@ def generate_time_series_data(
166161
return df
167162

168163

169-
def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataFrame:
164+
def generate_time_series_data_seasonal(
165+
treatment_time: pd.Timestamp,
166+
) -> pd.DataFrame:
170167
"""
171168
Generates 10 years of monthly data with seasonality
172169
"""
@@ -180,11 +177,13 @@ def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataF
180177
t=df.index,
181178
).set_index("date", drop=True)
182179
month_effect = np.array([11, 13, 12, 15, 19, 23, 21, 28, 20, 17, 15, 12])
183-
df["y"] = 0.2 * df["t"] + 2 * month_effect[df.month.values - 1]
180+
df["y"] = 0.2 * df["t"] + 2 * month_effect[np.asarray(df.month.values) - 1]
184181

185182
N = df.shape[0]
186183
idx = np.arange(N)[df.index > treatment_time]
187-
df["causal effect"] = 100 * gamma(10).pdf(np.arange(0, N, 1) - np.min(idx))
184+
df["causal effect"] = 100 * gamma(10).pdf(
185+
np.array(np.arange(0, N, 1)) - int(np.min(idx))
186+
)
188187

189188
df["y"] += df["causal effect"]
190189
df["y"] += norm(0, 2).rvs(N)
@@ -263,13 +262,13 @@ def outcome(
263262
df["post_treatment"] = df["t"] > intervention_time
264263

265264
df["y"] = outcome(
266-
df["t"],
265+
np.asarray(df["t"]),
267266
control_intercept,
268267
treat_intercept_delta,
269268
trend,
270269
Δ,
271-
df["group"],
272-
df["post_treatment"],
270+
np.asarray(df["group"]),
271+
np.asarray(df["post_treatment"]),
273272
)
274273
df["y"] += rng.normal(0, 0.1, df.shape[0])
275274
return df
@@ -310,8 +309,8 @@ def impact(x: np.ndarray) -> np.ndarray:
310309
def generate_ancova_data(
311310
N: int = 200,
312311
pre_treatment_means: np.ndarray = np.array([10, 12]),
313-
treatment_effect: float = 2,
314-
sigma: float = 1,
312+
treatment_effect: int = 2,
313+
sigma: int = 1,
315314
) -> pd.DataFrame:
316315
"""
317316
Generate ANCOVA example data
@@ -445,7 +444,7 @@ def generate_multicell_geolift_data() -> pd.DataFrame:
445444

446445

447446
def generate_seasonality(
448-
n: int = 12, amplitude: float = 1, length_scale: float = 0.5
447+
n: int = 12, amplitude: int = 1, length_scale: float = 0.5
449448
) -> np.ndarray:
450449
"""Generate monthly seasonality by sampling from a Gaussian process with a
451450
Gaussian kernel, using numpy code"""
@@ -463,9 +462,9 @@ def generate_seasonality(
463462
def periodic_kernel(
464463
x1: np.ndarray,
465464
x2: np.ndarray,
466-
period: float = 1,
467-
length_scale: float = 1,
468-
amplitude: float = 1,
465+
period: int = 1,
466+
length_scale: float = 1.0,
467+
amplitude: int = 1,
469468
) -> np.ndarray:
470469
"""Generate a periodic kernel for gaussian process"""
471470
return amplitude**2 * np.exp(
@@ -475,10 +474,10 @@ def periodic_kernel(
475474

476475
def create_series(
477476
n: int = 52,
478-
amplitude: float = 1,
479-
length_scale: float = 2,
477+
amplitude: int = 1,
478+
length_scale: int = 2,
480479
n_years: int = 4,
481-
intercept: float = 3,
480+
intercept: int = 3,
482481
) -> np.ndarray:
483482
"""
484483
Returns numpy tile with generated seasonality data repeated over

causalpy/experiments/base.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""
1717

1818
from abc import abstractmethod
19+
from typing import Any, Union
1920

2021
import arviz as az
2122
import matplotlib.pyplot as plt
@@ -29,10 +30,12 @@
2930
class BaseExperiment:
3031
"""Base class for quasi experimental designs."""
3132

33+
labels: list[str]
34+
3235
supports_bayes: bool
3336
supports_ols: bool
3437

35-
def __init__(self, model=None):
38+
def __init__(self, model: Union[PyMCModel, RegressorMixin] | None = None) -> None:
3639
# Ensure we've made any provided Scikit Learn model (as identified as being type
3740
# RegressorMixin) compatible with CausalPy by appending our custom methods.
3841
if isinstance(model, RegressorMixin):
@@ -50,16 +53,26 @@ def __init__(self, model=None):
5053
if self.model is None:
5154
raise ValueError("model not set or passed.")
5255

56+
def fit(self, *args: Any, **kwargs: Any) -> None:
57+
raise NotImplementedError("fit method not implemented")
58+
5359
@property
54-
def idata(self):
60+
def idata(self) -> az.InferenceData:
5561
"""Return the InferenceData object of the model. Only relevant for PyMC models."""
5662
return self.model.idata
5763

58-
def print_coefficients(self, round_to=None):
59-
"""Ask the model to print its coefficients."""
64+
def print_coefficients(self, round_to: int | None = None) -> None:
65+
"""Ask the model to print its coefficients.
66+
67+
Parameters
68+
----------
69+
round_to : int, optional
70+
Number of significant figures to round to. Defaults to None,
71+
in which case 2 significant figures are used.
72+
"""
6073
self.model.print_coefficients(self.labels, round_to)
6174

62-
def plot(self, *args, **kwargs) -> tuple:
75+
def plot(self, *args: Any, **kwargs: Any) -> tuple:
6376
"""Plot the model.
6477
6578
Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
@@ -75,16 +88,16 @@ def plot(self, *args, **kwargs) -> tuple:
7588
raise ValueError("Unsupported model type")
7689

7790
@abstractmethod
78-
def _bayesian_plot(self, *args, **kwargs):
91+
def _bayesian_plot(self, *args: Any, **kwargs: Any) -> tuple:
7992
"""Abstract method for plotting the model."""
8093
raise NotImplementedError("_bayesian_plot method not yet implemented")
8194

8295
@abstractmethod
83-
def _ols_plot(self, *args, **kwargs):
96+
def _ols_plot(self, *args: Any, **kwargs: Any) -> tuple:
8497
"""Abstract method for plotting the model."""
8598
raise NotImplementedError("_ols_plot method not yet implemented")
8699

87-
def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
100+
def get_plot_data(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
88101
"""Recover the data of an experiment along with the prediction and causal impact information.
89102
90103
Internally, this function dispatches to either :func:`get_plot_data_bayesian` or :func:`get_plot_data_ols`
@@ -98,11 +111,11 @@ def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
98111
raise ValueError("Unsupported model type")
99112

100113
@abstractmethod
101-
def get_plot_data_bayesian(self, *args, **kwargs):
114+
def get_plot_data_bayesian(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
102115
"""Abstract method for recovering plot data."""
103116
raise NotImplementedError("get_plot_data_bayesian method not yet implemented")
104117

105118
@abstractmethod
106-
def get_plot_data_ols(self, *args, **kwargs):
119+
def get_plot_data_ols(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
107120
"""Abstract method for recovering plot data."""
108121
raise NotImplementedError("get_plot_data_ols method not yet implemented")

0 commit comments

Comments
 (0)