Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/sampling/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import pytensor.tensor as pt
import pytest

from numpyro.infer import MCMC
from pytensor.compile import SharedVariable
from pytensor.graph import graph_inputs

Expand All @@ -45,6 +44,8 @@
sample_numpyro_nuts,
)

MCMC = pytest.importorskip("numpyro.infer.MCMC")


def test_jax_PosDefMatrix():
x = pt.tensor(name="x", shape=(2, 2), dtype="float32")
Expand Down
5 changes: 5 additions & 0 deletions tests/sampling/test_mcmc_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):


def test_step_args():
pytest.importorskip("numpyro")

with Model() as model:
a = Normal("a")
idata = sample(
Expand All @@ -91,6 +93,9 @@ def test_step_args():

@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
def test_sample_var_names(nuts_sampler):
if nuts_sampler != "pymc":
pytest.importorskip(nuts_sampler)

seed = 1234
kwargs = {
"chains": 1,
Expand Down