diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index db96669..7f841b3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -236,6 +236,10 @@ jobs: pytest -m "stan and not flow" --arraydiff uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall uv pip install jax + # Install MLX only on Apple Silicon (aarch64) + if [ "${{ matrix.platform.target }}" = "aarch64" ]; then + uv pip install mlx + fi pytest -m "pymc and not flow" --arraydiff uv pip install 'nutpie[all]' --find-links dist --force-reinstall pytest -m flow --arraydiff diff --git a/pyproject.toml b/pyproject.toml index 84e5e33..adad2a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ Repository = "https://github.com/pymc-devs/nutpie" stan = ["bridgestan >= 2.7.0", "stanio >= 0.5.1"] pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"] pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"] +pymc-mlx = ["pymc >= 5.20.1", "mlx >= 0.29.0"] nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"] dev = [ "bridgestan >= 2.7.0", @@ -41,6 +42,7 @@ dev = [ "pymc >= 5.20.1", "numba >= 0.60.0", "jax >= 0.4.27", + "mlx >= 0.29.0", "flowjax >= 17.0.2", "pytest", "pytest-timeout", @@ -52,6 +54,7 @@ all = [ "pymc >= 5.20.1", "numba >= 0.60.0", "jax >= 0.4.27", + "mlx >= 0.29.0", "flowjax >= 17.1.0", "equinox >= 0.11.12", ] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 08bbc94..4492976 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -470,11 +470,121 @@ def expand(_x, **shared): ) +def _compile_pymc_model_mlx( + model, + *, + gradient_backend=None, + pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]], + var_names: Iterable[str] | None = None, + **kwargs, +): + if find_spec("mlx") is None: + raise ImportError( + "MLX is not installed in the current environment. " + "Please install it with something like " + "'pip install mlx' " + "and restart your kernel in case you are in an interactive session." 
+ ) + import mlx.core as mx + + if gradient_backend is None: + gradient_backend = "pytensor" + elif gradient_backend not in ["mlx", "pytensor"]: + raise ValueError(f"Unknown gradient backend: {gradient_backend}") + + ( + n_dim, + _, + logp_fn_pt, + expand_fn_pt, + initial_point_fn, + shape_info, + ) = _make_functions( + model, + mode="MLX", + compute_grad=gradient_backend == "pytensor", + join_expanded=False, + pymc_initial_point_fn=pymc_initial_point_fn, + var_names=var_names, + ) + + logp_fn = logp_fn_pt.vm.jit_fn + expand_fn = expand_fn_pt.vm.jit_fn + + logp_shared_names = [var.name for var in logp_fn_pt.get_shared()] + expand_shared_names = [var.name for var in expand_fn_pt.get_shared()] + + if gradient_backend == "mlx": + orig_logp_fn = logp_fn + + def logp_fn_mlx_grad(x, *shared): + return mx.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x) + + # JIT compile for performance, similar to jax.jit() in JAX backend + logp_fn_mlx_grad = mx.compile(logp_fn_mlx_grad) + + logp_fn = logp_fn_mlx_grad + else: + orig_logp_fn = None + + shared_data = {} + shared_vars = {} + seen = set() + for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]: + if val.name in shared_data and val not in seen: + raise ValueError(f"Shared variables must have unique names: {val.name}") + shared_data[val.name] = mx.array(val.get_value()) + shared_vars[val.name] = val + seen.add(val) + + def make_logp_func(): + def logp(_x, **shared): + _x_mlx = mx.array(_x) + logp, grad = logp_fn(_x_mlx, *[shared[name] for name in logp_shared_names]) + return float(logp), np.asarray(grad, dtype="float64", order="C") + + return logp + + names, slices, shapes = shape_info + # TODO do not cast to float64 + dtypes = [np.dtype("float64")] * len(names) + + def make_expand_func(seed1, seed2, chain): + # TODO handle seeds + def expand(_x, **shared): + values = expand_fn(_x, *[shared[name] for name in expand_shared_names]) + return { + name: np.asarray(val, order="C", dtype=dtype).reshape(shape) + for name, val, dtype, shape in zip( + names, values, dtypes, shapes, strict=True + ) + } + + return expand + + dims, coords = _prepare_dims_and_coords(model, shape_info) + + return from_pyfunc( + ndim=n_dim, + make_logp_fn=make_logp_func, + make_expand_fn=make_expand_func, + make_initial_point_fn=initial_point_fn, + expanded_dtypes=dtypes, + expanded_shapes=shapes, + expanded_names=names, + shared_data=shared_data, + dims=dims, + coords=coords, + raw_logp_fn=orig_logp_fn, + force_single_core=(gradient_backend == "mlx"), + ) + + def compile_pymc_model( model: "pm.Model", *, - backend: Literal["numba", "jax"] = "numba", - gradient_backend: Literal["pytensor", "jax"] = "pytensor", + backend: Literal["numba", "jax", "mlx"] = "numba", + gradient_backend: Literal["pytensor", "jax", "mlx"] = "pytensor", initial_points: dict[Union["Variable", str], np.ndarray | float | int] | None = None, jitter_rvs: set["TensorVariable"] | None = None, @@ -491,11 +601,11 @@ def compile_pymc_model( ---------- model : pymc.Model The model to compile. - backend : ["jax", "numba"] + backend : ["jax", "numba", "mlx"] The pytensor backend that is used to compile the logp function. - gradient_backend: ["pytensor", "jax"] + gradient_backend: ["pytensor", "jax", "mlx"] Which library is used to compute the gradients. This can only be changed - to "jax" if the jax backend is used. + to "jax" if the jax backend is used, or "mlx" if the mlx backend is used. 
jitter_rvs : set The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be added to the initial value. Only available for @@ -534,7 +644,7 @@ def compile_pymc_model( from pymc.initial_point import make_initial_point_fn if freeze_model is None: - freeze_model = backend == "jax" + freeze_model = backend in ["jax", "mlx"] if freeze_model: model = freeze_dims_and_data(model) @@ -553,8 +663,10 @@ def compile_pymc_model( initial_point_fn = _wrap_with_lock(initial_point_fn) if backend.lower() == "numba": - if gradient_backend == "jax": - raise ValueError("Gradient backend cannot be jax when using numba backend") + if gradient_backend in ["jax", "mlx"]: + raise ValueError( + f"Gradient backend cannot be {gradient_backend} when using numba backend" + ) return _compile_pymc_model_numba( model=model, pymc_initial_point_fn=initial_point_fn, @@ -569,8 +681,16 @@ def compile_pymc_model( var_names=var_names, **kwargs, ) + elif backend.lower() == "mlx": + return _compile_pymc_model_mlx( + model=model, + gradient_backend=gradient_backend, + pymc_initial_point_fn=initial_point_fn, + var_names=var_names, + **kwargs, + ) else: - raise ValueError(f"Backend must be one of numba and jax. Got {backend}") + raise ValueError(f"Backend must be one of numba, jax, and mlx. Got {backend}") def _wrap_with_lock(func: Callable) -> Callable: @@ -616,7 +736,7 @@ def _compute_shapes(model) -> dict[str, tuple[int, ...]]: def _make_functions( model: "pm.Model", *, - mode: Literal["JAX", "NUMBA"], + mode: Literal["JAX", "NUMBA", "MLX"], compute_grad: bool, join_expanded: bool, pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]], @@ -637,7 +757,7 @@ def _make_functions( model: pymc.Model The model to compile mode: str - Pytensor compile mode. One of "NUMBA" or "JAX" + Pytensor compile mode. One of "NUMBA", "JAX", or "MLX" compute_grad: bool Whether to compute gradients using pytensor. 
Must be True if mode is "NUMBA", otherwise False implies Jax will be used to compute gradients diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index db58c28..1219299 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -23,6 +23,7 @@ class PyFuncModel(CompiledModel): _coords: dict[str, Any] _raw_logp_fn: Callable | None _transform_adapt_args: dict | None = None + _force_single_core: bool = False @property def shapes(self) -> dict[str, tuple[int, ...]]: @@ -42,13 +43,25 @@ def with_data(self, **updates): raise ValueError(f"Unknown data variable: {name}") updated = self._shared_data.copy() - updated.update(**updates) + + # Convert to MLX arrays if using MLX backend (indicated by force_single_core) + if self._force_single_core: + import mlx.core as mx + + for name, value in updates.items(): + updated[name] = mx.array(value) + else: + updated.update(**updates) + return dataclasses.replace(self, _shared_data=updated) def with_transform_adapt(self, **kwargs): return dataclasses.replace(self, _transform_adapt_args=kwargs) def _make_sampler(self, settings, init_mean, cores, progress_type, store): + # Force single-core execution if required (e.g., for MLX backend) + if self._force_single_core: + cores = 1 model = self._make_model(init_mean) return _lib.PySampler.from_pyfunc( settings, @@ -108,6 +121,7 @@ def from_pyfunc( make_initial_point_fn: Callable[[SeedType], np.ndarray] | None = None, make_transform_adapter=None, raw_logp_fn=None, + force_single_core: bool = False, ): if coords is None: coords = {} @@ -139,4 +153,5 @@ def from_pyfunc( _variables=variables, _shared_data=shared_data, _raw_logp_fn=raw_logp_fn, + _force_single_core=force_single_core, ) diff --git a/tests/reference/test_deterministic_sampling_mlx.txt b/tests/reference/test_deterministic_sampling_mlx.txt new file mode 100644 index 0000000..7fbb8e7 --- /dev/null +++ b/tests/reference/test_deterministic_sampling_mlx.txt @@ -0,0 +1,200 @@ +2.38754 +0.138255 +0.0696831 +0.0204736 +0.0308912 +0.54282 +0.393642 +0.703939 +0.482711 +2.22404 +0.510405 +1.74111 +1.74111 +0.364023 +2.5632 +0.684026 +0.414764 +0.414764 +0.26368 +0.283684 +0.203853 +0.542914 +0.646859 +0.861701 +0.861701 +0.861701 +0.728978 +1.0528 +1.0528 +1.0528 +1.0528 +1.0528 +1.19056 +0.149186 +0.199848 +0.795704 +0.795704 +0.734627 +0.84465 +0.915367 +0.704601 +0.704601 +0.704601 +0.414619 +0.227516 +0.22453 +0.889432 +0.889432 +0.735724 +0.719196 +0.688977 +0.0724534 +0.0457166 +0.400116 +0.974773 +0.974773 +0.870367 +0.870367 +0.870367 +0.870367 +0.870367 +0.870367 +0.840601 +0.787697 +0.787697 +0.411171 +0.506744 +0.335273 +0.309355 +0.495757 +0.632767 +0.632767 +0.632767 +0.516386 +0.550656 +1.42213 +2.26821 +2.26821 +0.618058 +0.736866 +0.736866 +0.0860633 +0.0539637 +0.0327673 +0.0810014 +0.0956654 +0.134825 +0.224479 +0.309928 +1.03209 +1.03209 +1.03209 +1.03209 +1.03209 +1.03209 +1.03209 +1.03209 +1.03209 +1.16806 +1.16806 +1.31512 +1.31512 +1.03726 +0.25523 +2.44138 +1.07632 +0.655629 +0.601123 +0.542795 +0.319316 +1.10565 +0.597321 +0.33601 +0.359843 +0.761331 +0.425976 +0.825186 +0.340091 +0.0631592 +0.0734237 +0.0862233 +0.879708 +0.3444 +0.634993 +0.634993 +0.13069 +0.135324 +0.114996 +0.346445 +0.181392 +0.155654 +0.178757 +1.06726 +0.640434 +2.9661 +0.619286 +0.686098 +1.03444 +1.08534 +1.44254 +0.932663 +0.547982 +1.63447 +1.23931 +2.36775 +1.46697 +1.28671 +1.41629 +0.918363 +1.11078 +0.695998 +0.755846 +0.761555 +1.09839 +1.32175 +0.3281 +0.272174 +0.301852 +0.347011 +0.359253 +0.573898 
+0.573569 +0.154932 +0.329482 +0.315104 +0.144827 +0.223074 +0.118261 +0.496447 +0.499072 +0.312409 +0.0513948 +0.264955 +0.31759 +0.151413 +0.0926338 +0.0650902 +0.166325 +0.238831 +0.0268235 +0.0269092 +0.0349006 +0.043314 +0.0135705 +0.0363067 +0.0207646 +0.0672809 +0.224202 +0.450867 +0.970648 +1.06552 +0.865336 +0.747244 +0.706772 +0.815929 +0.814764 +0.578404 +0.316513 +0.584274 +0.862195 diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 677c63c..b4a2bd6 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -12,9 +12,28 @@ import nutpie import nutpie.compile_pymc +# Check if MLX is available (macOS only, optional dependency) +MLX_AVAILABLE = find_spec("mlx") is not None + +# Build backend list dynamically based on availability +backend_params = [ + ("numba", None), + ("jax", "pytensor"), + ("jax", "jax"), +] + +# Only add MLX backends if MLX is available +if MLX_AVAILABLE: + backend_params.extend( + [ + ("mlx", "pytensor"), + ("mlx", "mlx"), + ] + ) + parameterize_backends = pytest.mark.parametrize( "backend, gradient_backend", - [("numba", None), ("jax", "pytensor"), ("jax", "jax")], + backend_params, ) @@ -465,6 +484,20 @@ def test_deterministic_sampling_jax(): return trace.posterior.a.values.ravel() +@pytest.mark.pymc +@pytest.mark.array_compare(atol=1e-6, rtol=1e-6) +def test_deterministic_sampling_mlx(): + if not MLX_AVAILABLE: + pytest.skip("MLX not installed") + + with pm.Model() as model: + pm.HalfNormal("a") + + compiled = nutpie.compile_pymc_model(model, backend="mlx", gradient_backend="mlx") + trace = nutpie.sample(compiled, chains=2, seed=123, draws=100, tune=100) + return trace.posterior.a.values.ravel() + + @pytest.mark.pymc def test_zarr_store(tmp_path): with pm.Model() as model:
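
Example usage of the MLX backend introduced in this diff (a minimal sketch based on the new `test_deterministic_sampling_mlx` test; the toy model and observed data below are purely illustrative, and MLX itself is only installable on Apple Silicon, as reflected in the CI change above):

    import numpy as np
    import pymc as pm
    import nutpie

    # Any PyMC model works; note that for backend="mlx" the model is frozen
    # by default (freeze_model defaults to True, as for the jax backend).
    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("y", mu, 1.0, observed=np.array([0.1, -0.3, 0.5]))

    # gradient_backend="mlx" computes gradients via mx.value_and_grad and
    # forces single-core sampling; gradient_backend="pytensor" (the default)
    # keeps the pytensor-computed gradient instead.
    compiled = nutpie.compile_pymc_model(
        model, backend="mlx", gradient_backend="mlx"
    )
    trace = nutpie.sample(compiled, chains=2, seed=123, draws=100, tune=100)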