Skip to content

Commit 8fee905

Browse files
CopilotwilliambdeanjuanitorduzCopilot
authored
Remove Prior classes and functions from pymc_marketing.prior module (#2060)
* Initial plan * Remove deprecated functions and classes from prior module - Removed deprecated class wrappers: Prior, Censored, Scaled - Removed deprecated function wrappers: sample_prior, create_dim_handler, handle_dims, register_tensor_transform - Removed deprecation warning decorators - Removed unused imports (functools, warnings) - Kept alternative serialization functionality (is_alternative_prior, deserialize_alternative_prior) - Updated imports in test_fourier.py and Fourier_Seasonality.py to use pymc_extras.prior - Removed deprecation tests from test_prior.py Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com> * Fix import ordering in Fourier_Seasonality.py Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com> * Update nested_logit notebook to use pymc_extras.prior Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com> * Fix end-of-file in nested_logit.ipynb for pre-commit lint Co-authored-by: juanitorduz <22996444+juanitorduz@users.noreply.github.com> * Fix import ordering in nested_logit.ipynb for ruff pre-commit Co-authored-by: juanitorduz <22996444+juanitorduz@users.noreply.github.com> * Fix test_prior.py by restoring import of pymc_marketing.prior The import is needed to trigger registration of the alternative prior deserializer. Without it, the deserialize() function doesn't know how to handle dict-based prior specifications. Co-authored-by: juanitorduz <22996444+juanitorduz@users.noreply.github.com> * Update tests/test_prior.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fix import formatting in test_prior.py for ruff compliance Ruff requires imports to be properly grouped and formatted. The local import was reformatted into a multi-line format for consistency. Co-authored-by: juanitorduz <22996444+juanitorduz@users.noreply.github.com> * Ensure deserializer registration on pymc_marketing import Added import of pymc_marketing.prior in the main __init__.py to ensure the alternative prior deserializer is registered when users import pymc_marketing or pymc_marketing.mmm. This makes the deserializer available for YAML-based configurations without requiring explicit import of the prior module. Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com> * Revert "Ensure deserializer registration on pymc_marketing import" This reverts commit 0adb6c4. The import of pymc_marketing.prior in __init__.py is not needed since the deserializer is already registered via pymc_marketing/mmm/builders/factories.py when using YAML-based configurations. Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com> Co-authored-by: juanitorduz <22996444+juanitorduz@users.noreply.github.com> Co-authored-by: Juan Orduz <juanitorduz@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 7baa55e commit 8fee905

File tree

5 files changed

+7
-88
lines changed

5 files changed

+7
-88
lines changed

docs/source/notebooks/customer_choice/nested_logit.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
"import arviz as az\n",
2929
"import matplotlib.pyplot as plt\n",
3030
"import pandas as pd\n",
31+
"from pymc_extras.prior import Prior\n",
3132
"\n",
3233
"from pymc_marketing.customer_choice.nested_logit import NestedLogit\n",
33-
"from pymc_marketing.paths import data_dir\n",
34-
"from pymc_marketing.prior import Prior"
34+
"from pymc_marketing.paths import data_dir"
3535
]
3636
},
3737
{

pymc_marketing/prior.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,6 @@ def custom_transform(x):
103103
from __future__ import annotations
104104

105105
import copy
106-
import functools
107-
import warnings
108106
from typing import Any
109107

110108
from pymc_extras import prior
@@ -162,70 +160,3 @@ def deserialize_alternative_prior(data: dict[str, Any]) -> prior.Prior:
162160

163161
# Register the alternative prior deserializer for more complex nested cases
164162
register_deserialization(is_alternative_prior, deserialize_alternative_prior)
165-
166-
167-
def warn_class_deprecation(func):
168-
"""Warn about the deprecation of this module."""
169-
170-
@functools.wraps(func)
171-
def wrapper(self, *args, **kwargs):
172-
name = self.__class__.__name__
173-
warnings.warn(
174-
f"The {name} class has moved to pymc_extras.prior module and will be removed in a future release. "
175-
f"Import it from `from pymc_extras.prior import {name}`. ",
176-
DeprecationWarning,
177-
stacklevel=2,
178-
)
179-
return func(self, *args, **kwargs)
180-
181-
return wrapper
182-
183-
184-
def warn_function_deprecation(func):
185-
"""Warn about the deprecation of this function."""
186-
187-
@functools.wraps(func)
188-
def wrapper(*args, **kwargs):
189-
name = func.__name__
190-
warnings.warn(
191-
f"The {name} function has moved to pymc_extras.prior module and will be removed in a future release. "
192-
f"Import it from `from pymc_extras.prior import {name}`.",
193-
DeprecationWarning,
194-
stacklevel=2,
195-
)
196-
return func(*args, **kwargs)
197-
198-
return wrapper
199-
200-
201-
class Prior(prior.Prior):
202-
"""Backwards-compatible wrapper for the Prior class."""
203-
204-
@warn_class_deprecation
205-
def __init__(self, *args, **kwargs):
206-
"""Initialize the Prior class with the given arguments."""
207-
super().__init__(*args, **kwargs)
208-
209-
210-
class Censored(prior.Censored):
211-
"""Backwards-compatible wrapper for the CensoredPrior class."""
212-
213-
@warn_class_deprecation
214-
def __init__(self, *args, **kwargs):
215-
"""Initialize the CensoredPrior class with the given arguments."""
216-
super().__init__(*args, **kwargs)
217-
218-
219-
class Scaled(prior.Scaled):
220-
"""Backwards-compatible wrapper for the ScaledPrior class."""
221-
222-
@warn_class_deprecation
223-
def __init__(self, *args, **kwargs):
224-
"""Initialize the ScaledPrior class with the given arguments."""
225-
super().__init__(*args, **kwargs)
226-
227-
228-
sample_prior = warn_function_deprecation(prior.sample_prior)
229-
create_dim_handler = warn_function_deprecation(prior.create_dim_handler)
230-
handle_dims = warn_function_deprecation(prior.handle_dims)
231-
register_tensor_transform = warn_function_deprecation(prior.register_tensor_transform)

streamlit/mmm-explainer/pages/Fourier_Seasonality.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
"""Streamlit page for fourier modes."""
1515

1616
import plotly.graph_objects as go
17+
from pymc_extras.prior import Prior
1718

1819
import streamlit as st
1920
from pymc_marketing.mmm import MonthlyFourier, YearlyFourier
20-
from pymc_marketing.prior import Prior
2121

2222
# Constants
2323
PLOT_HEIGHT = 500

tests/mmm/test_fourier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
deserialize,
2424
register_deserialization,
2525
)
26+
from pymc_extras.prior import Prior
2627

2728
from pymc_marketing.mmm.fourier import (
2829
FourierBase,
@@ -31,7 +32,6 @@
3132
YearlyFourier,
3233
generate_fourier_modes,
3334
)
34-
from pymc_marketing.prior import Prior
3535

3636

3737
@pytest.mark.parametrize(

tests/test_prior.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from pymc_extras.deserialize import deserialize
1717
from pymc_extras.prior import Prior
1818

19-
from pymc_marketing import prior
19+
from pymc_marketing import (
20+
prior, # noqa: F401 - import needed to register custom deserializers
21+
)
2022

2123

2224
@pytest.mark.parametrize(
@@ -51,17 +53,3 @@
5153
)
5254
def test_alternative_prior_deserialize(data, expected) -> None:
5355
assert deserialize(data) == expected
54-
55-
56-
@pytest.mark.parametrize(
57-
"obj, args, kwargs, match",
58-
[
59-
(prior.Prior, ["Normal"], {}, "The Prior class"),
60-
(prior.Censored, [Prior("Normal")], dict(lower=0), "The Censored"),
61-
(prior.Scaled, [Prior("Normal")], dict(factor=2), "The Scaled"),
62-
(prior.create_dim_handler, [("date", "channel")], {}, "The create_dim_handler"),
63-
],
64-
)
65-
def test_deprecation_warnings(obj, args, kwargs, match):
66-
with pytest.warns(DeprecationWarning, match=match):
67-
obj(*args, **kwargs)

0 commit comments

Comments
 (0)