Skip to content

Commit 9cc82f4

Browse files
Copilot, williambdean, cetagostini, and juanitorduz
authored
Integrate multidimensional MMM class with MLflow autologging (#2072)
* Initial plan

* Integrate new multidimensional MMM class with MLflow autologging

  Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com>

* Add test for multidimensional MMM MLflow integration

  Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com>

* Fix unused variable in test_autolog_multidimensional_mmm

  Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com>
Co-authored-by: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com>
Co-authored-by: Juan Orduz <juanitorduz@gmail.com>
1 parent f59d7e3 commit 9cc82f4

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

pymc_marketing/mlflow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@
171171
from pymc_marketing.clv.models.basic import CLVModel
172172
from pymc_marketing.mmm import MMM
173173
from pymc_marketing.mmm.evaluation import compute_summary_metrics
174+
from pymc_marketing.mmm.multidimensional import MMM as MultiDimensionalMMM
174175
from pymc_marketing.version import __version__
175176

176177
FLAVOR_NAME = "pymc"
@@ -1257,6 +1258,7 @@ def new_fit(self, *args, **kwargs):
12571258

12581259
if log_mmm:
12591260
MMM.fit = patch_mmm_fit(MMM.fit)
1261+
MultiDimensionalMMM.fit = patch_mmm_fit(MultiDimensionalMMM.fit)
12601262

12611263
def patch_clv_fit(fit):
12621264
@wraps(fit)

tests/test_mlflow.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
log_sample_diagnostics,
4343
)
4444
from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation
45+
from pymc_marketing.mmm.multidimensional import MMM as MultiDimensionalMMM
4546
from pymc_marketing.version import __version__
4647

4748
seed = sum(map(ord, "mlflow-with-pymc"))
@@ -61,6 +62,8 @@ def setup_module():
6162
pm.sample = pm.sample.__wrapped__
6263
while hasattr(MMM.fit, "__wrapped__"):
6364
MMM.fit = MMM.fit.__wrapped__
65+
while hasattr(MultiDimensionalMMM.fit, "__wrapped__"):
66+
MultiDimensionalMMM.fit = MultiDimensionalMMM.fit.__wrapped__
6467

6568

6669
@pytest.fixture(scope="module")
@@ -539,6 +542,84 @@ def test_autolog_mmm(mmm, toy_X, toy_y) -> None:
539542
}
540543

541544

545+
@pytest.fixture(scope="module")
def multidimensional_mmm() -> MultiDimensionalMMM:
    """Build the multidimensional MMM instance shared by this module's tests.

    The model is configured with a ``date`` column, two media channels
    (``channel_1`` and ``channel_2``), target ``y``, a geometric adstock
    with ``l_max=4`` and a default logistic saturation.
    """
    media_channels = ["channel_1", "channel_2"]
    carryover = GeometricAdstock(l_max=4)
    response_curve = LogisticSaturation()

    return MultiDimensionalMMM(
        date_column="date",
        channel_columns=media_channels,
        target_column="y",
        adstock=carryover,
        saturation=response_curve,
    )
554+
555+
556+
@pytest.fixture(scope="module")
def toy_multidim_X() -> pd.DataFrame:
    """Return a small feature frame for the multidimensional MMM test.

    The frame has 20 rows: a ``date`` column of consecutive Mondays starting
    from 2020-01-01 (``freq="W-MON"``) and two integer channel columns drawn
    from ``[0, 100)`` with the ``rng`` generator defined elsewhere in this
    module.
    """
    n_weeks = 20
    weeks = pd.date_range(start="2020-01-01", periods=n_weeks, freq="W-MON")

    # channel_1 is drawn before channel_2, so the shared generator is
    # consumed in a fixed order.
    spend = {
        channel: rng.integers(low=0, high=100, size=n_weeks)
        for channel in ("channel_1", "channel_2")
    }

    return pd.DataFrame({"date": weeks, **spend})
568+
569+
570+
@pytest.fixture(scope="module")
def toy_multidim_y(toy_multidim_X: pd.DataFrame) -> pd.Series:
    """Return a random integer target named ``y`` with one value per feature row.

    Values are drawn from ``[0, 100)`` using the ``rng`` generator referenced
    by this module.
    """
    n_rows = len(toy_multidim_X)
    target = rng.integers(low=0, high=100, size=n_rows)
    return pd.Series(target, name="y")
575+
576+
577+
def test_autolog_multidimensional_mmm(
    multidimensional_mmm, toy_multidim_X, toy_multidim_y
) -> None:
    """Fit the multidimensional MMM inside an MLflow run and check what was logged.

    Verifies the sampler parameters, the adstock/saturation names, the
    sampling metrics, the set of artifacts, the (empty) tags and that exactly
    one input dataset was recorded.

    NOTE(review): presumably relies on autologging having been enabled
    elsewhere in this module -- nothing in this test turns it on.
    """
    # One source of truth for the sampler settings: they are passed to
    # ``fit`` and then expected back among the logged params.
    sampler_kwargs = {"draws": 10, "chains": 1, "tune": 5}
    expected_artifacts = {
        "coords.json",
        "idata.nc",
        "model_graph.pdf",
        "model_repr.txt",
        "summary.html",
    }

    mlflow.set_experiment("pymc-marketing-test-suite-multidimensional-mmm")
    with mlflow.start_run() as run:
        multidimensional_mmm.fit(toy_multidim_X, toy_multidim_y, **sampler_kwargs)

    # Leaving the ``with`` block must have ended the run.
    assert mlflow.active_run() is None

    inputs, params, metrics, tags, artifacts = get_run_data(run.info.run_id)

    param_checks(params=params, nuts_sampler="pymc", **sampler_kwargs)

    assert params["adstock_name"] == "geometric"
    assert params["saturation_name"] == "logistic"

    metric_checks(metrics, "pymc")

    assert set(artifacts) == expected_artifacts
    assert tags == {}

    assert len(inputs) == 1
621+
622+
542623
@pytest.fixture(scope="function")
543624
def mock_idata() -> az.InferenceData:
544625
chains = 4

0 commit comments

Comments
 (0)