Skip to content
31 changes: 29 additions & 2 deletions pymc_marketing/mmm/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,24 @@

Plot the prior fourier seasonality trend.

.. code-block:: python
.. plot::
:context: close-figs

import pandas as pd
import pymc as pm
import matplotlib.pyplot as plt

from pymc_marketing.mmm import YearlyFourier

yearly = YearlyFourier(n_order=3)

dates = pd.date_range("2023-01-01", periods=52, freq="W-MON")

dayofyear = dates.dayofyear.to_numpy()

with pm.Model() as model:
fourier_trend = yearly.apply(dayofyear)

prior = yearly.sample_prior()
curve = yearly.sample_curve(prior)
yearly.plot_curve(curve)
Expand Down Expand Up @@ -107,10 +121,23 @@

All the plotting will still work! Just pass any coords.

.. code-block:: python
.. plot::
:context: close-figs

import matplotlib.pyplot as plt

from pymc_marketing.mmm import YearlyFourier
from pymc_extras.prior import Prior

# "fourier" is the default prefix!
prior = Prior(
"Laplace",
mu=Prior("Normal", dims="fourier"),
b=Prior("HalfNormal", sigma=0.1, dims="fourier"),
dims=("fourier", "hierarchy"),
)
yearly = YearlyFourier(n_order=3, prior=prior)

coords = {"hierarchy": ["A", "B", "C"]}
prior = yearly.sample_prior(coords=coords)
curve = yearly.sample_curve(prior)
Expand Down
Loading