2626
2727
2828def _smoothed_gaussian_random_walk (
29- gaussian_random_walk_mu , gaussian_random_walk_sigma , N , lowess_kwargs
30- ):
29+ gaussian_random_walk_mu : float ,
30+ gaussian_random_walk_sigma : float ,
31+ N : int ,
32+ lowess_kwargs : dict ,
33+ ) -> tuple [np .ndarray , np .ndarray ]:
3134 """
3235 Generates Gaussian random walk data and applies LOWESS
3336
@@ -48,12 +51,12 @@ def _smoothed_gaussian_random_walk(
4851
4952
5053def generate_synthetic_control_data (
51- N = 100 ,
52- treatment_time = 70 ,
53- grw_mu = 0.25 ,
54- grw_sigma = 1 ,
55- lowess_kwargs = default_lowess_kwargs ,
56- ):
54+ N : int = 100 ,
55+ treatment_time : int = 70 ,
56+ grw_mu : float = 0.25 ,
57+ grw_sigma : float = 1 ,
58+ lowess_kwargs : dict = default_lowess_kwargs ,
59+ ) -> tuple [ pd . DataFrame , np . ndarray ] :
5760 """
5861 Generates data for synthetic control example.
5962
@@ -108,8 +111,12 @@ def generate_synthetic_control_data(
108111
109112
110113def generate_time_series_data (
111- N = 100 , treatment_time = 70 , beta_temp = - 1 , beta_linear = 0.5 , beta_intercept = 3
112- ):
114+ N : int = 100 ,
115+ treatment_time : int = 70 ,
116+ beta_temp : float = - 1 ,
117+ beta_linear : float = 0.5 ,
118+ beta_intercept : float = 3 ,
119+ ) -> pd .DataFrame :
113120 """
114121 Generates interrupted time series example data
115122
@@ -155,7 +162,9 @@ def generate_time_series_data(
155162 return df
156163
157164
158- def generate_time_series_data_seasonal (treatment_time ):
165+ def generate_time_series_data_seasonal (
166+ treatment_time : pd .Timestamp ,
167+ ) -> pd .DataFrame :
159168 """
160169 Generates 10 years of monthly data with seasonality
161170 """
@@ -169,11 +178,13 @@ def generate_time_series_data_seasonal(treatment_time):
169178 t = df .index ,
170179 ).set_index ("date" , drop = True )
171180 month_effect = np .array ([11 , 13 , 12 , 15 , 19 , 23 , 21 , 28 , 20 , 17 , 15 , 12 ])
172- df ["y" ] = 0.2 * df ["t" ] + 2 * month_effect [df .month .values - 1 ]
181+ df ["y" ] = 0.2 * df ["t" ] + 2 * month_effect [np . asarray ( df .month .values ) - 1 ]
173182
174183 N = df .shape [0 ]
175184 idx = np .arange (N )[df .index > treatment_time ]
176- df ["causal effect" ] = 100 * gamma (10 ).pdf (np .arange (0 , N , 1 ) - np .min (idx ))
185+ df ["causal effect" ] = 100 * gamma (10 ).pdf (
186+ np .array (np .arange (0 , N , 1 )) - int (np .min (idx ))
187+ )
177188
178189 df ["y" ] += df ["causal effect" ]
179190 df ["y" ] += norm (0 , 2 ).rvs (N )
@@ -183,7 +194,9 @@ def generate_time_series_data_seasonal(treatment_time):
183194 return df
184195
185196
186- def generate_time_series_data_simple (treatment_time , slope = 0.0 ):
197+ def generate_time_series_data_simple (
198+ treatment_time : pd .Timestamp , slope : float = 0.0
199+ ) -> pd .DataFrame :
187200 """Generate simple interrupted time series data, with no seasonality or temporal
188201 structure.
189202 """
@@ -205,7 +218,7 @@ def generate_time_series_data_simple(treatment_time, slope=0.0):
205218 return df
206219
207220
208- def generate_did ():
221+ def generate_did () -> pd . DataFrame :
209222 """
210223 Generate Difference in Differences data
211224
@@ -257,8 +270,8 @@ def outcome(
257270
258271
259272def generate_regression_discontinuity_data (
260- N = 100 , true_causal_impact = 0.5 , true_treatment_threshold = 0.0
261- ):
273+ N : int = 100 , true_causal_impact : float = 0.5 , true_treatment_threshold : float = 0.0
274+ ) -> pd . DataFrame :
262275 """
263276 Generate regression discontinuity example data
264277
@@ -289,8 +302,11 @@ def impact(x):
289302
290303
291304def generate_ancova_data (
292- N = 200 , pre_treatment_means = np .array ([10 , 12 ]), treatment_effect = 2 , sigma = 1
293- ):
305+ N : int = 200 ,
306+ pre_treatment_means : np .ndarray = np .array ([10 , 12 ]),
307+ treatment_effect : int = 2 ,
308+ sigma : int = 1 ,
309+ ) -> pd .DataFrame :
294310 """
295311 Generate ANCOVA example data
296312
@@ -310,7 +326,7 @@ def generate_ancova_data(
310326 return df
311327
312328
313- def generate_geolift_data ():
329+ def generate_geolift_data () -> pd . DataFrame :
 314330 """Generate synthetic data for a geolift example. This will consist of 6 untreated
315331 countries. The treated unit `Denmark` is a weighted combination of the untreated
316332 units. We additionally specify a treatment effect which takes effect after the
@@ -360,7 +376,7 @@ def generate_geolift_data():
360376 return df
361377
362378
363- def generate_multicell_geolift_data ():
379+ def generate_multicell_geolift_data () -> pd . DataFrame :
 364380 """Generate synthetic data for a geolift example. This will consist of 6 untreated
365381 countries. The treated unit `Denmark` is a weighted combination of the untreated
366382 units. We additionally specify a treatment effect which takes effect after the
@@ -422,7 +438,9 @@ def generate_multicell_geolift_data():
422438# -----------------
423439
424440
425- def generate_seasonality (n = 12 , amplitude = 1 , length_scale = 0.5 ):
441+ def generate_seasonality (
442+ n : int = 12 , amplitude : int = 1 , length_scale : float = 0.5
443+ ) -> np .ndarray :
426444 """Generate monthly seasonality by sampling from a Gaussian process with a
427445 Gaussian kernel, using numpy code"""
428446 # Generate the covariance matrix
@@ -436,14 +454,26 @@ def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
436454 return seasonality
437455
438456
439- def periodic_kernel (x1 , x2 , period = 1 , length_scale = 1 , amplitude = 1 ):
457+ def periodic_kernel (
458+ x1 : np .ndarray ,
459+ x2 : np .ndarray ,
460+ period : int = 1 ,
461+ length_scale : float = 1.0 ,
462+ amplitude : int = 1 ,
463+ ) -> np .ndarray :
440464 """Generate a periodic kernel for gaussian process"""
441465 return amplitude ** 2 * np .exp (
442466 - 2 * np .sin (np .pi * np .abs (x1 - x2 ) / period ) ** 2 / length_scale ** 2
443467 )
444468
445469
446- def create_series (n = 52 , amplitude = 1 , length_scale = 2 , n_years = 4 , intercept = 3 ):
470+ def create_series (
471+ n : int = 52 ,
472+ amplitude : int = 1 ,
473+ length_scale : int = 2 ,
474+ n_years : int = 4 ,
475+ intercept : int = 3 ,
476+ ) -> np .ndarray :
447477 """
448478 Returns numpy tile with generated seasonality data repeated over
449479 multiple years
0 commit comments