Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pymc_marketing/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
from pymc_marketing.clv.models.basic import CLVModel
from pymc_marketing.mmm import MMM
from pymc_marketing.mmm.evaluation import compute_summary_metrics
from pymc_marketing.mmm.multidimensional import MMM as MultiDimensionalMMM
from pymc_marketing.version import __version__

FLAVOR_NAME = "pymc"
Expand Down Expand Up @@ -1257,6 +1258,7 @@ def new_fit(self, *args, **kwargs):

if log_mmm:
MMM.fit = patch_mmm_fit(MMM.fit)
MultiDimensionalMMM.fit = patch_mmm_fit(MultiDimensionalMMM.fit)

def patch_clv_fit(fit):
@wraps(fit)
Expand Down
81 changes: 81 additions & 0 deletions tests/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
log_sample_diagnostics,
)
from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import MMM as MultiDimensionalMMM
from pymc_marketing.version import __version__

seed = sum(map(ord, "mlflow-with-pymc"))
Expand All @@ -61,6 +62,8 @@ def setup_module():
pm.sample = pm.sample.__wrapped__
while hasattr(MMM.fit, "__wrapped__"):
MMM.fit = MMM.fit.__wrapped__
while hasattr(MultiDimensionalMMM.fit, "__wrapped__"):
MultiDimensionalMMM.fit = MultiDimensionalMMM.fit.__wrapped__


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -539,6 +542,84 @@ def test_autolog_mmm(mmm, toy_X, toy_y) -> None:
}


@pytest.fixture(scope="module")
def multidimensional_mmm() -> MultiDimensionalMMM:
return MultiDimensionalMMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
target_column="y",
adstock=GeometricAdstock(l_max=4),
saturation=LogisticSaturation(),
)


@pytest.fixture(scope="module")
def toy_multidim_X() -> pd.DataFrame:
# Simple data for multidimensional MMM test
n_obs = 20
date_data = pd.DataFrame(
{
"date": pd.date_range(start="2020-01-01", periods=n_obs, freq="W-MON"),
"channel_1": rng.integers(low=0, high=100, size=n_obs),
"channel_2": rng.integers(low=0, high=100, size=n_obs),
}
)
return date_data


@pytest.fixture(scope="module")
def toy_multidim_y(toy_multidim_X: pd.DataFrame) -> pd.Series:
return pd.Series(
data=rng.integers(low=0, high=100, size=toy_multidim_X.shape[0]), name="y"
)


def test_autolog_multidimensional_mmm(
multidimensional_mmm, toy_multidim_X, toy_multidim_y
) -> None:
mlflow.set_experiment("pymc-marketing-test-suite-multidimensional-mmm")
with mlflow.start_run() as run:
draws = 10
tune = 5
chains = 1
multidimensional_mmm.fit(
toy_multidim_X,
toy_multidim_y,
draws=draws,
chains=chains,
tune=tune,
)

assert mlflow.active_run() is None

run_id = run.info.run_id
inputs, params, metrics, tags, artifacts = get_run_data(run_id)

param_checks(
params=params,
draws=draws,
chains=chains,
tune=tune,
nuts_sampler="pymc",
)

assert params["adstock_name"] == "geometric"
assert params["saturation_name"] == "logistic"

metric_checks(metrics, "pymc")

assert set(artifacts) == {
"coords.json",
"idata.nc",
"model_graph.pdf",
"model_repr.txt",
"summary.html",
}
assert tags == {}

assert len(inputs) == 1


@pytest.fixture(scope="function")
def mock_idata() -> az.InferenceData:
chains = 4
Expand Down