From 0e95705f9f3edabb31b7e2cf0193ccd1d580ed61 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 27 Oct 2025 20:21:20 +0200 Subject: [PATCH 1/8] Add MLX backend support for Nutpie compilation Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or Pytensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX. --- pyproject.toml | 3 + python/nutpie/compile_pymc.py | 137 +++++++++++++++++++++++++++++++--- tests/test_pymc.py | 22 +++++- 3 files changed, 150 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 84e5e33..f52b4a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ Repository = "https://github.com/pymc-devs/nutpie" stan = ["bridgestan >= 2.7.0", "stanio >= 0.5.1"] pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"] pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"] +pymc-mlx = ["pymc >= 5.20.1", "mlx >= 0.20.0"] nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"] dev = [ "bridgestan >= 2.7.0", @@ -41,6 +42,7 @@ dev = [ "pymc >= 5.20.1", "numba >= 0.60.0", "jax >= 0.4.27", + "mlx >= 0.20.0", "flowjax >= 17.0.2", "pytest", "pytest-timeout", @@ -52,6 +54,7 @@ all = [ "pymc >= 5.20.1", "numba >= 0.60.0", "jax >= 0.4.27", + "mlx >= 0.20.0", "flowjax >= 17.1.0", "equinox >= 0.11.12", ] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 08bbc94..1c866db 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -470,11 +470,116 @@ def expand(_x, **shared): ) +def _compile_pymc_model_mlx( + model, + *, + gradient_backend=None, + pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]], + var_names: Iterable[str] | None = None, + **kwargs, +): + if find_spec("mlx") is None: + raise ImportError( + "MLX is not installed in the current environment. " + "Please install it with something like " + "'pip install mlx' " + "and restart your kernel in case you are in an interactive session." + ) + import mlx.core as mx + + if gradient_backend is None: + gradient_backend = "pytensor" + elif gradient_backend not in ["mlx", "pytensor"]: + raise ValueError(f"Unknown gradient backend: {gradient_backend}") + + ( + n_dim, + _, + logp_fn_pt, + expand_fn_pt, + initial_point_fn, + shape_info, + ) = _make_functions( + model, + mode="MLX", + compute_grad=gradient_backend == "pytensor", + join_expanded=False, + pymc_initial_point_fn=pymc_initial_point_fn, + var_names=var_names, + ) + + logp_fn = logp_fn_pt.vm.jit_fn + expand_fn = expand_fn_pt.vm.jit_fn + + logp_shared_names = [var.name for var in logp_fn_pt.get_shared()] + expand_shared_names = [var.name for var in expand_fn_pt.get_shared()] + + if gradient_backend == "mlx": + orig_logp_fn = logp_fn + + def logp_fn_mlx_grad(x, *shared): + return mx.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x) + + logp_fn = logp_fn_mlx_grad + else: + orig_logp_fn = None + + shared_data = {} + shared_vars = {} + seen = set() + for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]: + if val.name in shared_data and val not in seen: + raise ValueError(f"Shared variables must have unique names: {val.name}") + shared_data[val.name] = mx.array(val.get_value()) + shared_vars[val.name] = val + seen.add(val) + + def make_logp_func(): + def logp(_x, **shared): + logp, grad = logp_fn(_x, *[shared[name] for name in logp_shared_names]) + return float(logp), np.asarray(grad, dtype="float64", order="C") + + return logp + + names, slices, shapes = shape_info + # TODO do not cast to float64 + dtypes = [np.dtype("float64")] * len(names) + + def make_expand_func(seed1, seed2, chain): + # TODO handle seeds + def expand(_x, **shared): + values = expand_fn(_x, *[shared[name] for name in expand_shared_names]) + return { + name: np.asarray(val, order="C", dtype=dtype).reshape(shape) + for name, val, dtype, shape in zip( + names, values, dtypes, shapes, strict=True + ) + } + + return expand + + dims, coords = _prepare_dims_and_coords(model, shape_info) + + return from_pyfunc( + ndim=n_dim, + make_logp_fn=make_logp_func, + make_expand_fn=make_expand_func, + make_initial_point_fn=initial_point_fn, + expanded_dtypes=dtypes, + expanded_shapes=shapes, + expanded_names=names, + shared_data=shared_data, + dims=dims, + coords=coords, + raw_logp_fn=orig_logp_fn, + ) + + def compile_pymc_model( model: "pm.Model", *, - backend: Literal["numba", "jax"] = "numba", - gradient_backend: Literal["pytensor", "jax"] = "pytensor", + backend: Literal["numba", "jax", "mlx"] = "numba", + gradient_backend: Literal["pytensor", "jax", "mlx"] = "pytensor", initial_points: dict[Union["Variable", str], np.ndarray | float | int] | None = None, jitter_rvs: set["TensorVariable"] | None = None, @@ -491,11 +596,11 @@ def compile_pymc_model( ---------- model : pymc.Model The model to compile. - backend : ["jax", "numba"] + backend : ["jax", "numba", "mlx"] The pytensor backend that is used to compile the logp function. - gradient_backend: ["pytensor", "jax"] + gradient_backend: ["pytensor", "jax", "mlx"] Which library is used to compute the gradients. This can only be changed - to "jax" if the jax backend is used. + to "jax" if the jax backend is used, or "mlx" if the mlx backend is used. jitter_rvs : set The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be added to the initial value. Only available for @@ -534,7 +639,7 @@ def compile_pymc_model( from pymc.initial_point import make_initial_point_fn if freeze_model is None: - freeze_model = backend == "jax" + freeze_model = backend in ["jax", "mlx"] if freeze_model: model = freeze_dims_and_data(model) @@ -553,8 +658,10 @@ def compile_pymc_model( initial_point_fn = _wrap_with_lock(initial_point_fn) if backend.lower() == "numba": - if gradient_backend == "jax": - raise ValueError("Gradient backend cannot be jax when using numba backend") + if gradient_backend in ["jax", "mlx"]: + raise ValueError( + f"Gradient backend cannot be {gradient_backend} when using numba backend" + ) return _compile_pymc_model_numba( model=model, pymc_initial_point_fn=initial_point_fn, @@ -569,8 +676,16 @@ def compile_pymc_model( var_names=var_names, **kwargs, ) + elif backend.lower() == "mlx": + return _compile_pymc_model_mlx( + model=model, + gradient_backend=gradient_backend, + pymc_initial_point_fn=initial_point_fn, + var_names=var_names, + **kwargs, + ) else: - raise ValueError(f"Backend must be one of numba and jax. Got {backend}") + raise ValueError(f"Backend must be one of numba, jax, and mlx. Got {backend}") def _wrap_with_lock(func: Callable) -> Callable: @@ -616,7 +731,7 @@ def _compute_shapes(model) -> dict[str, tuple[int, ...]]: def _make_functions( model: "pm.Model", *, - mode: Literal["JAX", "NUMBA"], + mode: Literal["JAX", "NUMBA", "MLX"], compute_grad: bool, join_expanded: bool, pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]], @@ -637,7 +752,7 @@ def _make_functions( model: pymc.Model The model to compile mode: str - Pytensor compile mode. One of "NUMBA" or "JAX" + Pytensor compile mode. One of "NUMBA", "JAX", or "MLX" compute_grad: bool Whether to compute gradients using pytensor. Must be True if mode is "NUMBA", otherwise False implies Jax will be used to compute gradients diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 677c63c..2adfdc9 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -14,7 +14,13 @@ parameterize_backends = pytest.mark.parametrize( "backend, gradient_backend", - [("numba", None), ("jax", "pytensor"), ("jax", "jax")], + [ + ("numba", None), + ("jax", "pytensor"), + ("jax", "jax"), + ("mlx", "pytensor"), + ("mlx", "mlx"), + ], ) @@ -465,6 +471,20 @@ def test_deterministic_sampling_jax(): return trace.posterior.a.values.ravel() +@pytest.mark.pymc +@pytest.mark.array_compare(atol=1e-4, rtol=1e-4) +def test_deterministic_sampling_mlx(): + if find_spec("mlx") is None: + pytest.skip("MLX not installed") + + with pm.Model() as model: + pm.HalfNormal("a") + + compiled = nutpie.compile_pymc_model(model, backend="mlx", gradient_backend="mlx") + trace = nutpie.sample(compiled, chains=2, seed=123, draws=100, tune=100) + return trace.posterior.a.values.ravel() + + @pytest.mark.pymc def test_zarr_store(tmp_path): with pm.Model() as model: From d57f9303f8517c3f09b169e116add179e23f1d3c Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 27 Oct 2025 20:46:12 +0200 Subject: [PATCH 2/8] Update MLX dependency and JIT compile logp function Bump MLX version requirement to >=0.29.0 in pyproject.toml for all relevant extras. In compile_pymc.py, JIT compile the logp function using mx.compile for improved performance, aligning with JAX backend behavior. --- pyproject.toml | 6 +++--- python/nutpie/compile_pymc.py | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f52b4a3..adad2a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ Repository = "https://github.com/pymc-devs/nutpie" stan = ["bridgestan >= 2.7.0", "stanio >= 0.5.1"] pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"] pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"] -pymc-mlx = ["pymc >= 5.20.1", "mlx >= 0.20.0"] +pymc-mlx = ["pymc >= 5.20.1", "mlx >= 0.29.0"] nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"] dev = [ "bridgestan >= 2.7.0", @@ -42,7 +42,7 @@ dev = [ "pymc >= 5.20.1", "numba >= 0.60.0", "jax >= 0.4.27", - "mlx >= 0.20.0", + "mlx >= 0.29.0", "flowjax >= 17.0.2", "pytest", "pytest-timeout", @@ -54,7 +54,7 @@ all = [ "pymc >= 5.20.1", "numba >= 0.60.0", "jax >= 0.4.27", - "mlx >= 0.20.0", + "mlx >= 0.29.0", "flowjax >= 17.1.0", "equinox >= 0.11.12", ] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 1c866db..9aa5277 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -520,6 +520,9 @@ def _compile_pymc_model_mlx( def logp_fn_mlx_grad(x, *shared): return mx.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x) + # JIT compile for performance, similar to jax.jit() in JAX backend + logp_fn_mlx_grad = mx.compile(logp_fn_mlx_grad) + logp_fn = logp_fn_mlx_grad else: orig_logp_fn = None From f73cb55d192a94a6327bf89c32a9bed05a7559c2 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 27 Oct 2025 21:28:10 +0200 Subject: [PATCH 3/8] issue with tests --- .github/workflows/ci.yml | 4 ++++ tests/test_pymc.py | 29 ++++++++++++++++++++--------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index db96669..7f841b3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -236,6 +236,10 @@ jobs: pytest -m "stan and not flow" --arraydiff uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall uv pip install jax + # Install MLX only on Apple Silicon (aarch64) + if [ "${{ matrix.platform.target }}" = "aarch64" ]; then + uv pip install mlx + fi pytest -m "pymc and not flow" --arraydiff uv pip install 'nutpie[all]' --find-links dist --force-reinstall pytest -m flow --arraydiff diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 2adfdc9..d3fc9a7 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -12,15 +12,26 @@ import nutpie import nutpie.compile_pymc -parameterize_backends = pytest.mark.parametrize( - "backend, gradient_backend", - [ - ("numba", None), - ("jax", "pytensor"), - ("jax", "jax"), +# Check if MLX is available (macOS only, optional dependency) +MLX_AVAILABLE = find_spec("mlx") is not None + +# Build backend list dynamically based on availability +backend_params = [ + ("numba", None), + ("jax", "pytensor"), + ("jax", "jax"), +] + +# Only add MLX backends if MLX is available +if MLX_AVAILABLE: + backend_params.extend([ ("mlx", "pytensor"), ("mlx", "mlx"), - ], + ]) + +parameterize_backends = pytest.mark.parametrize( + "backend, gradient_backend", + backend_params, ) @@ -474,9 +485,9 @@ def test_deterministic_sampling_jax(): @pytest.mark.pymc @pytest.mark.array_compare(atol=1e-4, rtol=1e-4) def test_deterministic_sampling_mlx(): - if find_spec("mlx") is None: + if not MLX_AVAILABLE: pytest.skip("MLX not installed") - + with pm.Model() as model: pm.HalfNormal("a") From 13b9a432795f464acf0000ec5ce9a581a91ebdc4 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Mon, 27 Oct 2025 21:29:23 +0200 Subject: [PATCH 4/8] Update test_pymc.py --- tests/test_pymc.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_pymc.py b/tests/test_pymc.py index d3fc9a7..7f22b13 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -24,10 +24,12 @@ # Only add MLX backends if MLX is available if MLX_AVAILABLE: - backend_params.extend([ - ("mlx", "pytensor"), - ("mlx", "mlx"), - ]) + backend_params.extend( + [ + ("mlx", "pytensor"), + ("mlx", "mlx"), + ] + ) parameterize_backends = pytest.mark.parametrize( "backend, gradient_backend", @@ -487,7 +489,7 @@ def test_deterministic_sampling_jax(): def test_deterministic_sampling_mlx(): if not MLX_AVAILABLE: pytest.skip("MLX not installed") - + with pm.Model() as model: pm.HalfNormal("a") From f97d00235a468f0e9ec24ca0e4e150acffb7220f Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Wed, 29 Oct 2025 01:22:01 +0200 Subject: [PATCH 5/8] Changes --- python/nutpie/compile_pymc.py | 4 +++- python/nutpie/compiled_pyfunc.py | 16 +++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 9aa5277..fc06b48 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -539,7 +539,8 @@ def logp_fn_mlx_grad(x, *shared): def make_logp_func(): def logp(_x, **shared): - logp, grad = logp_fn(_x, *[shared[name] for name in logp_shared_names]) + _x_mlx = mx.array(_x) + logp, grad = logp_fn(_x_mlx, *[shared[name] for name in logp_shared_names]) return float(logp), np.asarray(grad, dtype="float64", order="C") return logp @@ -575,6 +576,7 @@ def expand(_x, **shared): dims=dims, coords=coords, raw_logp_fn=orig_logp_fn, + force_single_core=(gradient_backend == "mlx") ) diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index db58c28..57f1756 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -23,6 +23,7 @@ class PyFuncModel(CompiledModel): _coords: dict[str, Any] _raw_logp_fn: Callable | None _transform_adapt_args: dict | None = None + _force_single_core: bool = False @property def shapes(self) -> dict[str, tuple[int, ...]]: @@ -42,13 +43,24 @@ def with_data(self, **updates): raise ValueError(f"Unknown data variable: {name}") updated = self._shared_data.copy() - updated.update(**updates) + + # Convert to MLX arrays if using MLX backend (indicated by force_single_core) + if self._force_single_core: + import mlx.core as mx + for name, value in updates.items(): + updated[name] = mx.array(value) + else: + updated.update(**updates) + return dataclasses.replace(self, _shared_data=updated) def with_transform_adapt(self, **kwargs): return dataclasses.replace(self, _transform_adapt_args=kwargs) def _make_sampler(self, settings, init_mean, cores, progress_type, store): + # Force single-core execution if required (e.g., for MLX backend) + if self._force_single_core: + cores = 1 model = self._make_model(init_mean) return _lib.PySampler.from_pyfunc( settings, @@ -108,6 +120,7 @@ def from_pyfunc( make_initial_point_fn: Callable[[SeedType], np.ndarray] | None = None, make_transform_adapter=None, raw_logp_fn=None, + force_single_core: bool = False, ): if coords is None: coords = {} @@ -139,4 +152,5 @@ def from_pyfunc( _variables=variables, _shared_data=shared_data, _raw_logp_fn=raw_logp_fn, + _force_single_core=force_single_core, ) From 7543d765c57c0124d1ee829e0f495de4dfced33c Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Wed, 29 Oct 2025 01:22:39 +0200 Subject: [PATCH 6/8] precommit --- python/nutpie/compile_pymc.py | 2 +- python/nutpie/compiled_pyfunc.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index fc06b48..4492976 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -576,7 +576,7 @@ def expand(_x, **shared): dims=dims, coords=coords, raw_logp_fn=orig_logp_fn, - force_single_core=(gradient_backend == "mlx") + force_single_core=(gradient_backend == "mlx"), ) diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index 57f1756..1219299 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -43,15 +43,16 @@ def with_data(self, **updates): raise ValueError(f"Unknown data variable: {name}") updated = self._shared_data.copy() - + # Convert to MLX arrays if using MLX backend (indicated by force_single_core) if self._force_single_core: import mlx.core as mx + for name, value in updates.items(): updated[name] = mx.array(value) else: updated.update(**updates) - + return dataclasses.replace(self, _shared_data=updated) def with_transform_adapt(self, **kwargs): From 905f61fdd861a06d325034dbe6ff2cc8f33d1e70 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Wed, 29 Oct 2025 02:03:24 +0200 Subject: [PATCH 7/8] Create test_deterministic_sampling_mlx.txt --- .../test_deterministic_sampling_mlx.txt | 200 ++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 tests/reference/test_deterministic_sampling_mlx.txt diff --git a/tests/reference/test_deterministic_sampling_mlx.txt b/tests/reference/test_deterministic_sampling_mlx.txt new file mode 100644 index 0000000..7fbb8e7 --- /dev/null +++ b/tests/reference/test_deterministic_sampling_mlx.txt @@ -0,0 +1,200 @@ +2.38754 +0.138255 +0.0696831 +0.0204736 +0.0308912 +0.54282 +0.393642 +0.703939 +0.482711 +2.22404 +0.510405 +1.74111 +1.74111 +0.364023 +2.5632 +0.684026 +0.414764 +0.414764 +0.26368 +0.283684 +0.203853 +0.542914 +0.646859 +0.861701 +0.861701 +0.861701 +0.728978 +1.0528 +1.0528 +1.0528 +1.0528 +1.0528 +1.19056 +0.149186 +0.199848 +0.795704 +0.795704 +0.734627 +0.84465 +0.915367 +0.704601 +0.704601 +0.704601 +0.414619 +0.227516 +0.22453 +0.889432 +0.889432 +0.735724 +0.719196 +0.688977 +0.0724534 +0.0457166 +0.400116 +0.974773 +0.974773 +0.870367 +0.870367 +0.870367 +0.870367 +0.870367 +0.870367 +0.840601 +0.787697 +0.787697 +0.411171 +0.506744 +0.335273 +0.309355 +0.495757 +0.632767 +0.632767 +0.632767 +0.516386 +0.550656 +1.42213 +2.26821 +2.26821 +0.618058 +0.736866 +0.736866 +0.0860633 +0.0539637 +0.0327673 +0.0810014 +0.0956654 +0.134825 +0.224479 +0.309928 +1.03209 +1.03209 +1.03209 +1.03209 +1.03209 +1.03209 +1.03209 +1.03209 +1.03209 +1.16806 +1.16806 +1.31512 +1.31512 +1.03726 +0.25523 +2.44138 +1.07632 +0.655629 +0.601123 +0.542795 +0.319316 +1.10565 +0.597321 +0.33601 +0.359843 +0.761331 +0.425976 +0.825186 +0.340091 +0.0631592 +0.0734237 +0.0862233 +0.879708 +0.3444 +0.634993 +0.634993 +0.13069 +0.135324 +0.114996 +0.346445 +0.181392 +0.155654 +0.178757 +1.06726 +0.640434 +2.9661 +0.619286 +0.686098 +1.03444 +1.08534 +1.44254 +0.932663 +0.547982 +1.63447 +1.23931 +2.36775 +1.46697 +1.28671 +1.41629 +0.918363 +1.11078 +0.695998 +0.755846 +0.761555 +1.09839 +1.32175 +0.3281 +0.272174 +0.301852 +0.347011 +0.359253 +0.573898 +0.573569 +0.154932 +0.329482 +0.315104 +0.144827 +0.223074 +0.118261 +0.496447 +0.499072 +0.312409 +0.0513948 +0.264955 +0.31759 +0.151413 +0.0926338 +0.0650902 +0.166325 +0.238831 +0.0268235 +0.0269092 +0.0349006 +0.043314 +0.0135705 +0.0363067 +0.0207646 +0.0672809 +0.224202 +0.450867 +0.970648 +1.06552 +0.865336 +0.747244 +0.706772 +0.815929 +0.814764 +0.578404 +0.316513 +0.584274 +0.862195 From ad3e2d6ce0aaf602736d0fd60457f0428a3d74cb Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Wed, 29 Oct 2025 23:41:45 +0200 Subject: [PATCH 8/8] Update test_pymc.py --- tests/test_pymc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 7f22b13..b4a2bd6 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -485,7 +485,7 @@ def test_deterministic_sampling_jax(): @pytest.mark.pymc -@pytest.mark.array_compare(atol=1e-4, rtol=1e-4) +@pytest.mark.array_compare(atol=1e-6, rtol=1e-6) def test_deterministic_sampling_mlx(): if not MLX_AVAILABLE: pytest.skip("MLX not installed")