Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions causalpy/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,22 @@ def _get_data_home() -> pathlib.Path:


def load_data(dataset: str | None = None) -> pd.DataFrame:
"""Loads the requested dataset and returns a pandas DataFrame.
"""Load the requested dataset and return a pandas DataFrame.

:param dataset: The desired dataset to load
Parameters
----------
dataset : str, optional
The desired dataset to load. If None, raises ValueError.

Returns
-------
pd.DataFrame
The loaded dataset as a pandas DataFrame.

Raises
------
ValueError
If the requested dataset is not found.
"""

if dataset in DATASETS:
Expand Down
35 changes: 17 additions & 18 deletions causalpy/data/simulate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
Functions that generate data sets used in examples
"""

from typing import Any

import numpy as np
import pandas as pd
from scipy.stats import dirichlet, gamma, norm, uniform
Expand All @@ -31,7 +29,7 @@ def _smoothed_gaussian_random_walk(
gaussian_random_walk_mu: float,
gaussian_random_walk_sigma: float,
N: int,
lowess_kwargs: dict[str, Any],
lowess_kwargs: dict,
) -> tuple[np.ndarray, np.ndarray]:
"""
Generates Gaussian random walk data and applies LOWESS.
Expand All @@ -57,7 +55,7 @@ def generate_synthetic_control_data(
treatment_time: int = 70,
grw_mu: float = 0.25,
grw_sigma: float = 1,
lowess_kwargs: dict[str, Any] | None = None,
lowess_kwargs: dict = default_lowess_kwargs,
) -> tuple[pd.DataFrame, np.ndarray]:
"""
Generates data for synthetic control example.
Expand All @@ -78,9 +76,6 @@ def generate_synthetic_control_data(
>>> from causalpy.data.simulate_data import generate_synthetic_control_data
>>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
"""
if lowess_kwargs is None:
lowess_kwargs = default_lowess_kwargs

# 1. Generate non-treated variables
df = pd.DataFrame(
{
Expand Down Expand Up @@ -166,7 +161,9 @@ def generate_time_series_data(
return df


def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataFrame:
def generate_time_series_data_seasonal(
treatment_time: pd.Timestamp,
) -> pd.DataFrame:
"""
Generates 10 years of monthly data with seasonality
"""
Expand All @@ -184,7 +181,9 @@ def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataF

N = df.shape[0]
idx = np.arange(N)[df.index > treatment_time]
df["causal effect"] = 100 * gamma(10).pdf(np.arange(0, N, 1) - np.min(idx))
df["causal effect"] = 100 * gamma(10).pdf(
np.array(np.arange(0, N, 1)) - int(np.min(idx))
)

df["y"] += df["causal effect"]
df["y"] += norm(0, 2).rvs(N)
Expand Down Expand Up @@ -310,8 +309,8 @@ def impact(x: np.ndarray) -> np.ndarray:
def generate_ancova_data(
N: int = 200,
pre_treatment_means: np.ndarray = np.array([10, 12]),
treatment_effect: float = 2,
sigma: float = 1,
treatment_effect: int = 2,
sigma: int = 1,
) -> pd.DataFrame:
"""
Generate ANCOVA example data
Expand Down Expand Up @@ -445,7 +444,7 @@ def generate_multicell_geolift_data() -> pd.DataFrame:


def generate_seasonality(
n: int = 12, amplitude: float = 1, length_scale: float = 0.5
n: int = 12, amplitude: int = 1, length_scale: float = 0.5
) -> np.ndarray:
"""Generate monthly seasonality by sampling from a Gaussian process with a
Gaussian kernel, using numpy code"""
Expand All @@ -463,9 +462,9 @@ def generate_seasonality(
def periodic_kernel(
x1: np.ndarray,
x2: np.ndarray,
period: float = 1,
length_scale: float = 1,
amplitude: float = 1,
period: int = 1,
length_scale: float = 1.0,
amplitude: int = 1,
) -> np.ndarray:
"""Generate a periodic kernel for gaussian process"""
return amplitude**2 * np.exp(
Expand All @@ -475,10 +474,10 @@ def periodic_kernel(

def create_series(
n: int = 52,
amplitude: float = 1,
length_scale: float = 2,
amplitude: int = 1,
length_scale: int = 2,
n_years: int = 4,
intercept: float = 3,
intercept: int = 3,
) -> np.ndarray:
"""
Returns numpy tile with generated seasonality data repeated over
Expand Down
33 changes: 23 additions & 10 deletions causalpy/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

from abc import abstractmethod
from typing import Any, Union

import arviz as az
import matplotlib.pyplot as plt
Expand All @@ -29,10 +30,12 @@
class BaseExperiment:
"""Base class for quasi experimental designs."""

labels: list[str]

supports_bayes: bool
supports_ols: bool

def __init__(self, model=None):
def __init__(self, model: Union[PyMCModel, RegressorMixin] | None = None) -> None:
# Make any provided scikit-learn model (identified as an instance of
# RegressorMixin) compatible with CausalPy by appending our custom methods.
if isinstance(model, RegressorMixin):
Expand All @@ -50,16 +53,26 @@ def __init__(self, model=None):
if self.model is None:
raise ValueError("model not set or passed.")

def fit(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("fit method not implemented")

@property
def idata(self):
def idata(self) -> az.InferenceData:
"""Return the InferenceData object of the model. Only relevant for PyMC models."""
return self.model.idata

def print_coefficients(self, round_to=None):
"""Ask the model to print its coefficients."""
def print_coefficients(self, round_to: int | None = None) -> None:
"""Ask the model to print its coefficients.

Parameters
----------
round_to : int, optional
Number of significant figures to round to. Defaults to None,
in which case 2 significant figures are used.
"""
self.model.print_coefficients(self.labels, round_to)

def plot(self, *args, **kwargs) -> tuple:
def plot(self, *args: Any, **kwargs: Any) -> tuple:
"""Plot the model.

Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
Expand All @@ -75,16 +88,16 @@ def plot(self, *args, **kwargs) -> tuple:
raise ValueError("Unsupported model type")

@abstractmethod
def _bayesian_plot(self, *args, **kwargs):
def _bayesian_plot(self, *args: Any, **kwargs: Any) -> tuple:
"""Abstract method for plotting the model."""
raise NotImplementedError("_bayesian_plot method not yet implemented")

@abstractmethod
def _ols_plot(self, *args, **kwargs):
def _ols_plot(self, *args: Any, **kwargs: Any) -> tuple:
"""Abstract method for plotting the model."""
raise NotImplementedError("_ols_plot method not yet implemented")

def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
def get_plot_data(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
"""Recover the data of an experiment along with the prediction and causal impact information.

Internally, this function dispatches to either :func:`get_plot_data_bayesian` or :func:`get_plot_data_ols`
Expand All @@ -98,11 +111,11 @@ def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
raise ValueError("Unsupported model type")

@abstractmethod
def get_plot_data_bayesian(self, *args, **kwargs):
def get_plot_data_bayesian(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
"""Abstract method for recovering plot data."""
raise NotImplementedError("get_plot_data_bayesian method not yet implemented")

@abstractmethod
def get_plot_data_ols(self, *args, **kwargs):
def get_plot_data_ols(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
"""Abstract method for recovering plot data."""
raise NotImplementedError("get_plot_data_ols method not yet implemented")
63 changes: 38 additions & 25 deletions causalpy/experiments/diff_in_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
Difference in differences
"""

from typing import Union

import arviz as az
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -47,20 +49,24 @@ class DifferenceInDifferences(BaseExperiment):

.. note::

There is no pre/post intervention data distinction for DiD, we fit all the
data available.
:param data:
A pandas dataframe
:param formula:
A statistical model formula
:param time_variable_name:
Name of the data column for the time variable
:param group_variable_name:
Name of the data column for the group variable
:param post_treatment_variable_name:
Name of the data column indicating post-treatment period (default: "post_treatment")
:param model:
A PyMC model for difference in differences
There is no pre/post intervention data distinction for DiD, we fit
all the data available.

Parameters
----------
data : pd.DataFrame
A pandas dataframe.
formula : str
A statistical model formula.
time_variable_name : str
Name of the data column for the time variable.
group_variable_name : str
Name of the data column for the group variable.
post_treatment_variable_name : str, optional
Name of the data column indicating post-treatment period.
Defaults to "post_treatment".
model : PyMCModel or RegressorMixin, optional
A PyMC or scikit-learn model for difference in differences. Defaults to None.

Examples
--------
Expand Down Expand Up @@ -92,8 +98,8 @@ def __init__(
time_variable_name: str,
group_variable_name: str,
post_treatment_variable_name: str = "post_treatment",
model=None,
**kwargs,
model: Union[PyMCModel, RegressorMixin] | None = None,
**kwargs: dict,
) -> None:
super().__init__(model=model)
self.causal_impact: xr.DataArray | float | None
Expand Down Expand Up @@ -234,14 +240,14 @@ def __init__(
f"{self.group_variable_name}:{self.post_treatment_variable_name}"
)
matched_key = next((k for k in coef_map if interaction_term in k), None)
att = coef_map.get(matched_key)
att = coef_map.get(matched_key) if matched_key is not None else None
self.causal_impact = att
else:
raise ValueError("Model type not recognized")

return

def input_validation(self):
def input_validation(self) -> None:
# Validate formula structure and interaction terms
self._validate_formula_interaction_terms()

Expand Down Expand Up @@ -269,7 +275,7 @@ def input_validation(self):
coded. Consisting of 0's and 1's only."""
)

def _validate_formula_interaction_terms(self):
def _validate_formula_interaction_terms(self) -> None:
"""
Validate that the formula contains at most one interaction term and no three-way or higher-order interactions.
Raises FormulaException if more than one interaction term is found or if any interaction term has more than 2 variables.
Expand Down Expand Up @@ -299,7 +305,7 @@ def _validate_formula_interaction_terms(self):
"Multiple interaction terms are not currently supported as they complicate interpretation of the causal effect."
)

def summary(self, round_to=None) -> None:
def summary(self, round_to: int | None = 2) -> None:
"""Print summary of main results and model coefficients.

:param round_to:
Expand All @@ -311,11 +317,13 @@ def summary(self, round_to=None) -> None:
print(self._causal_impact_summary_stat(round_to))
self.print_coefficients(round_to)

def _causal_impact_summary_stat(self, round_to=None) -> str:
def _causal_impact_summary_stat(self, round_to: int | None = None) -> str:
"""Computes the mean and 94% credible interval bounds for the causal impact."""
return f"Causal impact = {convert_to_string(self.causal_impact, round_to=round_to)}"

def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
def _bayesian_plot(
self, round_to: int | None = None, **kwargs: dict
) -> tuple[plt.Figure, plt.Axes]:
"""
Plot the results

Expand Down Expand Up @@ -463,9 +471,10 @@ def _plot_causal_impact_arrow(results, ax):
)
return fig, ax

def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
def _ols_plot(
self, round_to: int | None = 2, **kwargs: dict
) -> tuple[plt.Figure, plt.Axes]:
Copy link

Copilot AI Nov 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type annotation **kwargs: dict is incorrect. The correct type annotation for **kwargs should be **kwargs: Any (after importing Any from typing). The current annotation would imply each keyword argument should be of type dict, which is not the intended behavior.

Copilot uses AI. Check for mistakes.
Comment on lines +474 to +476
Copy link

Copilot AI Nov 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent default value for round_to: The _ols_plot method has a default of 2, but the _bayesian_plot method on line 229 has a default of None. For consistency across plotting methods, both should use the same default value (either 2 or None).

Copilot uses AI. Check for mistakes.
"""Generate plot for difference-in-differences"""
round_to = kwargs.get("round_to")
fig, ax = plt.subplots()

# Plot raw data
Expand Down Expand Up @@ -528,11 +537,15 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
va="center",
)
# formatting
# In OLS context, causal_impact should be a float, but mypy doesn't know this
causal_impact_value = (
float(self.causal_impact) if self.causal_impact is not None else 0.0
)
ax.set(
xlim=[-0.05, 1.1],
xticks=[0, 1],
xticklabels=["pre", "post"],
title=f"Causal impact = {round_num(self.causal_impact, round_to)}",
title=f"Causal impact = {round_num(causal_impact_value, round_to)}",
)
ax.legend(fontsize=LEGEND_FONT_SIZE)
return fig, ax
Loading