Skip to content

Commit 9a88b01

Browse files
authored
Don't change matplotlib style on import of causalpy + misc dev fixes (#538)
* fix environment issue * update makefile to avoid local dev errors * hide arviz style change in BaseExperiment.plot * move imports to top
1 parent 2d06be5 commit 9a88b01

File tree

4 files changed

+12
-10
lines changed

4 files changed

+12
-10
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ check_lint:
1313
interrogate .
1414

1515
doctest:
16-
pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
16+
python -m pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
1717

1818
test:
1919
python -m pytest

causalpy/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import arviz as az
1514

1615
import causalpy.pymc_models as pymc_models
1716
import causalpy.skl_models as skl_models
@@ -28,8 +27,6 @@
2827
from .experiments.regression_kink import RegressionKink
2928
from .experiments.synthetic_control import SyntheticControl
3029

31-
az.style.use("arviz-darkgrid")
32-
3330
__all__ = [
3431
"__version__",
3532
"DifferenceInDifferences",

causalpy/experiments/base.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from abc import abstractmethod
1919

20+
import arviz as az
21+
import matplotlib.pyplot as plt
2022
import pandas as pd
2123
from sklearn.base import RegressorMixin
2224

@@ -63,12 +65,14 @@ def plot(self, *args, **kwargs) -> tuple:
6365
Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
6466
depending on the model type.
6567
"""
66-
if isinstance(self.model, PyMCModel):
67-
return self._bayesian_plot(*args, **kwargs)
68-
elif isinstance(self.model, RegressorMixin):
69-
return self._ols_plot(*args, **kwargs)
70-
else:
71-
raise ValueError("Unsupported model type")
68+
# Apply arviz-darkgrid style only during plotting, then revert
69+
with plt.style.context(az.style.library["arviz-darkgrid"]):
70+
if isinstance(self.model, PyMCModel):
71+
return self._bayesian_plot(*args, **kwargs)
72+
elif isinstance(self.model, RegressorMixin):
73+
return self._ols_plot(*args, **kwargs)
74+
else:
75+
raise ValueError("Unsupported model type")
7276

7377
@abstractmethod
7478
def _bayesian_plot(self, *args, **kwargs):

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ dependencies:
1616
- statsmodels
1717
- xarray>=v2022.11.0
1818
- pymc-extras>=0.3.0
19+
- python>=3.11

0 commit comments

Comments
 (0)