Skip to content

Commit bcace83

Browse files
committed
Fixes PyTensor compatibility for normalize_axis_index
Handles compatibility issues with PyTensor's `normalize_axis_index` function. PyTensor 2.35+ no longer exports `normalize_axis_index`, so the code now attempts to import it from `numpy._core.numeric` or `numpy.core.numeric` if the initial import fails.
1 parent 0fb6474 commit bcace83

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

pymc_marketing/mmm/transformers.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,17 @@
2121
import pymc as pm
2222
import pytensor.tensor as pt
2323
from pymc.distributions.dist_math import check_parameters
24-
from pytensor.npy_2_compat import normalize_axis_index
24+
25+
# Import normalize_axis_index - handle compatibility
26+
try:
27+
from pytensor.npy_2_compat import normalize_axis_index
28+
except ImportError:
29+
# PyTensor 2.35+ no longer exports this, use numpy directly
30+
try:
31+
from numpy._core.numeric import normalize_axis_index
32+
except (ImportError, AttributeError):
33+
# Older numpy versions
34+
from numpy.core.numeric import normalize_axis_index
2535

2636

2737
class ConvMode(str, Enum):

0 commit comments

Comments
 (0)