Skip to content

Commit 6034cb3

Browse files
committed
init
1 parent f918e84 commit 6034cb3

15 files changed

+332
-184
lines changed

causalpy/data/simulate_data.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@
2626

2727

2828
def _smoothed_gaussian_random_walk(
29-
gaussian_random_walk_mu, gaussian_random_walk_sigma, N, lowess_kwargs
30-
):
29+
gaussian_random_walk_mu: float,
30+
gaussian_random_walk_sigma: float,
31+
N: int,
32+
lowess_kwargs: dict,
33+
) -> tuple[np.ndarray, np.ndarray]:
3134
"""
3235
Generates Gaussian random walk data and applies LOWESS
3336
@@ -48,12 +51,12 @@ def _smoothed_gaussian_random_walk(
4851

4952

5053
def generate_synthetic_control_data(
51-
N=100,
52-
treatment_time=70,
53-
grw_mu=0.25,
54-
grw_sigma=1,
55-
lowess_kwargs=default_lowess_kwargs,
56-
):
54+
N: int = 100,
55+
treatment_time: int = 70,
56+
grw_mu: float = 0.25,
57+
grw_sigma: float = 1,
58+
lowess_kwargs: dict = default_lowess_kwargs,
59+
) -> tuple[pd.DataFrame, np.ndarray]:
5760
"""
5861
Generates data for synthetic control example.
5962
@@ -108,8 +111,12 @@ def generate_synthetic_control_data(
108111

109112

110113
def generate_time_series_data(
111-
N=100, treatment_time=70, beta_temp=-1, beta_linear=0.5, beta_intercept=3
112-
):
114+
N: int = 100,
115+
treatment_time: int = 70,
116+
beta_temp: float = -1,
117+
beta_linear: float = 0.5,
118+
beta_intercept: float = 3,
119+
) -> pd.DataFrame:
113120
"""
114121
Generates interrupted time series example data
115122
@@ -155,7 +162,9 @@ def generate_time_series_data(
155162
return df
156163

157164

158-
def generate_time_series_data_seasonal(treatment_time):
165+
def generate_time_series_data_seasonal(
166+
treatment_time: pd.Timestamp,
167+
) -> pd.DataFrame:
159168
"""
160169
Generates 10 years of monthly data with seasonality
161170
"""
@@ -169,11 +178,13 @@ def generate_time_series_data_seasonal(treatment_time):
169178
t=df.index,
170179
).set_index("date", drop=True)
171180
month_effect = np.array([11, 13, 12, 15, 19, 23, 21, 28, 20, 17, 15, 12])
172-
df["y"] = 0.2 * df["t"] + 2 * month_effect[df.month.values - 1]
181+
df["y"] = 0.2 * df["t"] + 2 * month_effect[np.asarray(df.month.values) - 1]
173182

174183
N = df.shape[0]
175184
idx = np.arange(N)[df.index > treatment_time]
176-
df["causal effect"] = 100 * gamma(10).pdf(np.arange(0, N, 1) - np.min(idx))
185+
df["causal effect"] = 100 * gamma(10).pdf(
186+
np.array(np.arange(0, N, 1)) - int(np.min(idx))
187+
)
177188

178189
df["y"] += df["causal effect"]
179190
df["y"] += norm(0, 2).rvs(N)
@@ -183,7 +194,9 @@ def generate_time_series_data_seasonal(treatment_time):
183194
return df
184195

185196

186-
def generate_time_series_data_simple(treatment_time, slope=0.0):
197+
def generate_time_series_data_simple(
198+
treatment_time: pd.Timestamp, slope: float = 0.0
199+
) -> pd.DataFrame:
187200
"""Generate simple interrupted time series data, with no seasonality or temporal
188201
structure.
189202
"""
@@ -205,7 +218,7 @@ def generate_time_series_data_simple(treatment_time, slope=0.0):
205218
return df
206219

207220

208-
def generate_did():
221+
def generate_did() -> pd.DataFrame:
209222
"""
210223
Generate Difference in Differences data
211224
@@ -257,8 +270,8 @@ def outcome(
257270

258271

259272
def generate_regression_discontinuity_data(
260-
N=100, true_causal_impact=0.5, true_treatment_threshold=0.0
261-
):
273+
N: int = 100, true_causal_impact: float = 0.5, true_treatment_threshold: float = 0.0
274+
) -> pd.DataFrame:
262275
"""
263276
Generate regression discontinuity example data
264277
@@ -289,8 +302,11 @@ def impact(x):
289302

290303

291304
def generate_ancova_data(
292-
N=200, pre_treatment_means=np.array([10, 12]), treatment_effect=2, sigma=1
293-
):
305+
N: int = 200,
306+
pre_treatment_means: np.ndarray = np.array([10, 12]),
307+
treatment_effect: int = 2,
308+
sigma: int = 1,
309+
) -> pd.DataFrame:
294310
"""
295311
Generate ANCOVA example data
296312
@@ -310,7 +326,7 @@ def generate_ancova_data(
310326
return df
311327

312328

313-
def generate_geolift_data():
329+
def generate_geolift_data() -> pd.DataFrame:
314330
"""Generate synthetic data for a geolift example. This will consist of 6 untreated
315331
countries. The treated unit `Denmark` is a weighted combination of the untreated
316332
units. We additionally specify a treatment effect which takes effect after the
@@ -360,7 +376,7 @@ def generate_geolift_data():
360376
return df
361377

362378

363-
def generate_multicell_geolift_data():
379+
def generate_multicell_geolift_data() -> pd.DataFrame:
364380
"""Generate synthetic data for a geolift example. This will consist of 6 untreated
365381
countries. The treated unit `Denmark` is a weighted combination of the untreated
366382
units. We additionally specify a treatment effect which takes effect after the
@@ -422,7 +438,9 @@ def generate_multicell_geolift_data():
422438
# -----------------
423439

424440

425-
def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
441+
def generate_seasonality(
442+
n: int = 12, amplitude: int = 1, length_scale: float = 0.5
443+
) -> np.ndarray:
426444
"""Generate monthly seasonality by sampling from a Gaussian process with a
427445
Gaussian kernel, using numpy code"""
428446
# Generate the covariance matrix
@@ -436,14 +454,26 @@ def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
436454
return seasonality
437455

438456

439-
def periodic_kernel(x1, x2, period=1, length_scale=1, amplitude=1):
457+
def periodic_kernel(
458+
x1: np.ndarray,
459+
x2: np.ndarray,
460+
period: int = 1,
461+
length_scale: float = 1.0,
462+
amplitude: int = 1,
463+
) -> np.ndarray:
440464
"""Generate a periodic kernel for gaussian process"""
441465
return amplitude**2 * np.exp(
442466
-2 * np.sin(np.pi * np.abs(x1 - x2) / period) ** 2 / length_scale**2
443467
)
444468

445469

446-
def create_series(n=52, amplitude=1, length_scale=2, n_years=4, intercept=3):
470+
def create_series(
471+
n: int = 52,
472+
amplitude: int = 1,
473+
length_scale: int = 2,
474+
n_years: int = 4,
475+
intercept: int = 3,
476+
) -> np.ndarray:
447477
"""
448478
Returns numpy tile with generated seasonality data repeated over
449479
multiple years

causalpy/experiments/base.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""
1717

1818
from abc import abstractmethod
19+
from typing import Any, Union
1920

2021
import arviz as az
2122
import matplotlib.pyplot as plt
@@ -29,10 +30,12 @@
2930
class BaseExperiment:
3031
"""Base class for quasi experimental designs."""
3132

33+
labels: list[str]
34+
3235
supports_bayes: bool
3336
supports_ols: bool
3437

35-
def __init__(self, model=None):
38+
def __init__(self, model: Union[PyMCModel, RegressorMixin] | None = None) -> None:
3639
# Ensure we've made any provided Scikit Learn model (as identified as being type
3740
# RegressorMixin) compatible with CausalPy by appending our custom methods.
3841
if isinstance(model, RegressorMixin):
@@ -50,16 +53,19 @@ def __init__(self, model=None):
5053
if self.model is None:
5154
raise ValueError("model not set or passed.")
5255

56+
def fit(self, *args: Any, **kwargs: Any) -> None:
57+
raise NotImplementedError("fit method not implemented")
58+
5359
@property
54-
def idata(self):
60+
def idata(self) -> az.InferenceData:
5561
"""Return the InferenceData object of the model. Only relevant for PyMC models."""
5662
return self.model.idata
5763

58-
def print_coefficients(self, round_to=None):
64+
def print_coefficients(self, round_to: int | None = None) -> None:
5965
"""Ask the model to print its coefficients."""
6066
self.model.print_coefficients(self.labels, round_to)
6167

62-
def plot(self, *args, **kwargs) -> tuple:
68+
def plot(self, *args: Any, **kwargs: Any) -> tuple:
6369
"""Plot the model.
6470
6571
Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
@@ -75,16 +81,16 @@ def plot(self, *args, **kwargs) -> tuple:
7581
raise ValueError("Unsupported model type")
7682

7783
@abstractmethod
78-
def _bayesian_plot(self, *args, **kwargs):
84+
def _bayesian_plot(self, *args: Any, **kwargs: Any) -> tuple:
7985
"""Abstract method for plotting the model."""
8086
raise NotImplementedError("_bayesian_plot method not yet implemented")
8187

8288
@abstractmethod
83-
def _ols_plot(self, *args, **kwargs):
89+
def _ols_plot(self, *args: Any, **kwargs: Any) -> tuple:
8490
"""Abstract method for plotting the model."""
8591
raise NotImplementedError("_ols_plot method not yet implemented")
8692

87-
def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
93+
def get_plot_data(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
8894
"""Recover the data of an experiment along with the prediction and causal impact information.
8995
9096
Internally, this function dispatches to either :func:`get_plot_data_bayesian` or :func:`get_plot_data_ols`
@@ -98,11 +104,11 @@ def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
98104
raise ValueError("Unsupported model type")
99105

100106
@abstractmethod
101-
def get_plot_data_bayesian(self, *args, **kwargs):
107+
def get_plot_data_bayesian(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
102108
"""Abstract method for recovering plot data."""
103109
raise NotImplementedError("get_plot_data_bayesian method not yet implemented")
104110

105111
@abstractmethod
106-
def get_plot_data_ols(self, *args, **kwargs):
112+
def get_plot_data_ols(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
107113
"""Abstract method for recovering plot data."""
108114
raise NotImplementedError("get_plot_data_ols method not yet implemented")

causalpy/experiments/diff_in_diff.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
Difference in differences
1616
"""
1717

18+
from typing import Union
19+
1820
import arviz as az
1921
import numpy as np
2022
import pandas as pd
@@ -92,8 +94,8 @@ def __init__(
9294
time_variable_name: str,
9395
group_variable_name: str,
9496
post_treatment_variable_name: str = "post_treatment",
95-
model=None,
96-
**kwargs,
97+
model: Union[PyMCModel, RegressorMixin] | None = None,
98+
**kwargs: dict,
9799
) -> None:
98100
super().__init__(model=model)
99101
self.causal_impact: xr.DataArray | float | None
@@ -234,14 +236,14 @@ def __init__(
234236
f"{self.group_variable_name}:{self.post_treatment_variable_name}"
235237
)
236238
matched_key = next((k for k in coef_map if interaction_term in k), None)
237-
att = coef_map.get(matched_key)
239+
att = coef_map.get(matched_key) if matched_key is not None else None
238240
self.causal_impact = att
239241
else:
240242
raise ValueError("Model type not recognized")
241243

242244
return
243245

244-
def input_validation(self):
246+
def input_validation(self) -> None:
245247
# Validate formula structure and interaction terms
246248
self._validate_formula_interaction_terms()
247249

@@ -269,7 +271,7 @@ def input_validation(self):
269271
coded. Consisting of 0's and 1's only."""
270272
)
271273

272-
def _validate_formula_interaction_terms(self):
274+
def _validate_formula_interaction_terms(self) -> None:
273275
"""
274276
Validate that the formula contains at most one interaction term and no three-way or higher-order interactions.
275277
Raises FormulaException if more than one interaction term is found or if any interaction term has more than 2 variables.
@@ -299,7 +301,7 @@ def _validate_formula_interaction_terms(self):
299301
"Multiple interaction terms are not currently supported as they complicate interpretation of the causal effect."
300302
)
301303

302-
def summary(self, round_to=None) -> None:
304+
def summary(self, round_to: int | None = 2) -> None:
303305
"""Print summary of main results and model coefficients.
304306
305307
:param round_to:
@@ -311,11 +313,13 @@ def summary(self, round_to=None) -> None:
311313
print(self._causal_impact_summary_stat(round_to))
312314
self.print_coefficients(round_to)
313315

314-
def _causal_impact_summary_stat(self, round_to=None) -> str:
316+
def _causal_impact_summary_stat(self, round_to: int | None = None) -> str:
315317
"""Computes the mean and 94% credible interval bounds for the causal impact."""
316318
return f"Causal impact = {convert_to_string(self.causal_impact, round_to=round_to)}"
317319

318-
def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
320+
def _bayesian_plot(
321+
self, round_to: int | None = None, **kwargs: dict
322+
) -> tuple[plt.Figure, plt.Axes]:
319323
"""
320324
Plot the results
321325
@@ -463,9 +467,10 @@ def _plot_causal_impact_arrow(results, ax):
463467
)
464468
return fig, ax
465469

466-
def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
470+
def _ols_plot(
471+
self, round_to: int | None = 2, **kwargs: dict
472+
) -> tuple[plt.Figure, plt.Axes]:
467473
"""Generate plot for difference-in-differences"""
468-
round_to = kwargs.get("round_to")
469474
fig, ax = plt.subplots()
470475

471476
# Plot raw data
@@ -528,11 +533,15 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
528533
va="center",
529534
)
530535
# formatting
536+
# In OLS context, causal_impact should be a float, but mypy doesn't know this
537+
causal_impact_value = (
538+
float(self.causal_impact) if self.causal_impact is not None else 0.0
539+
)
531540
ax.set(
532541
xlim=[-0.05, 1.1],
533542
xticks=[0, 1],
534543
xticklabels=["pre", "post"],
535-
title=f"Causal impact = {round_num(self.causal_impact, round_to)}",
544+
title=f"Causal impact = {round_num(causal_impact_value, round_to)}",
536545
)
537546
ax.legend(fontsize=LEGEND_FONT_SIZE)
538547
return fig, ax

0 commit comments

Comments
 (0)