4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
@@ -236,6 +236,10 @@ jobs:
pytest -m "stan and not flow" --arraydiff
uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall
uv pip install jax
# Install MLX only on Apple Silicon (aarch64)
if [ "${{ matrix.platform.target }}" = "aarch64" ]; then
uv pip install mlx
fi
pytest -m "pymc and not flow" --arraydiff
uv pip install 'nutpie[all]' --find-links dist --force-reinstall
pytest -m flow --arraydiff
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -34,13 +34,15 @@ Repository = "https://github.com/pymc-devs/nutpie"
stan = ["bridgestan >= 2.7.0", "stanio >= 0.5.1"]
pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"]
pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"]
pymc-mlx = ["pymc >= 5.20.1", "mlx >= 0.29.0"]
nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"]
dev = [
"bridgestan >= 2.7.0",
"stanio >= 0.5.1",
"pymc >= 5.20.1",
"numba >= 0.60.0",
"jax >= 0.4.27",
"mlx >= 0.29.0",
"flowjax >= 17.0.2",
"pytest",
"pytest-timeout",
@@ -52,6 +54,7 @@ all = [
"pymc >= 5.20.1",
"numba >= 0.60.0",
"jax >= 0.4.27",
"mlx >= 0.29.0",
"flowjax >= 17.1.0",
"equinox >= 0.11.12",
]
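With these packaging changes, the MLX dependency can be pulled in through the new `pymc-mlx` extra, e.g. `pip install 'nutpie[pymc-mlx]'`, or via the `dev` and `all` extras, which now include `mlx >= 0.29.0` alongside the existing JAX pin. As the CI change above notes, MLX targets Apple Silicon, which is why the install there is guarded behind the aarch64 build target.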
142 changes: 131 additions & 11 deletions python/nutpie/compile_pymc.py
@@ -470,11 +470,121 @@ def expand(_x, **shared):
)


def _compile_pymc_model_mlx(
model,
*,
gradient_backend=None,
pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
var_names: Iterable[str] | None = None,
**kwargs,
):
if find_spec("mlx") is None:
raise ImportError(
"MLX is not installed in the current environment. "
"Please install it with something like "
"'pip install mlx' "
"and restart your kernel in case you are in an interactive session."
)
import mlx.core as mx

if gradient_backend is None:
gradient_backend = "pytensor"
elif gradient_backend not in ["mlx", "pytensor"]:
raise ValueError(f"Unknown gradient backend: {gradient_backend}")

(
n_dim,
_,
logp_fn_pt,
expand_fn_pt,
initial_point_fn,
shape_info,
) = _make_functions(
model,
mode="MLX",
compute_grad=gradient_backend == "pytensor",
join_expanded=False,
pymc_initial_point_fn=pymc_initial_point_fn,
var_names=var_names,
)

logp_fn = logp_fn_pt.vm.jit_fn
expand_fn = expand_fn_pt.vm.jit_fn

logp_shared_names = [var.name for var in logp_fn_pt.get_shared()]
expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]

if gradient_backend == "mlx":
orig_logp_fn = logp_fn

def logp_fn_mlx_grad(x, *shared):
return mx.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x)

# JIT compile for performance, similar to jax.jit() in JAX backend
logp_fn_mlx_grad = mx.compile(logp_fn_mlx_grad)

logp_fn = logp_fn_mlx_grad
else:
orig_logp_fn = None

shared_data = {}
shared_vars = {}
seen = set()
for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]:
if val.name in shared_data and val not in seen:
raise ValueError(f"Shared variables must have unique names: {val.name}")
shared_data[val.name] = mx.array(val.get_value())
shared_vars[val.name] = val
seen.add(val)

def make_logp_func():
def logp(_x, **shared):
_x_mlx = mx.array(_x)
logp, grad = logp_fn(_x_mlx, *[shared[name] for name in logp_shared_names])
return float(logp), np.asarray(grad, dtype="float64", order="C")

return logp

names, slices, shapes = shape_info
# TODO do not cast to float64
dtypes = [np.dtype("float64")] * len(names)

def make_expand_func(seed1, seed2, chain):
# TODO handle seeds
def expand(_x, **shared):
values = expand_fn(_x, *[shared[name] for name in expand_shared_names])
return {
name: np.asarray(val, order="C", dtype=dtype).reshape(shape)
for name, val, dtype, shape in zip(
names, values, dtypes, shapes, strict=True
)
}

return expand

dims, coords = _prepare_dims_and_coords(model, shape_info)

return from_pyfunc(
ndim=n_dim,
make_logp_fn=make_logp_func,
make_expand_fn=make_expand_func,
make_initial_point_fn=initial_point_fn,
expanded_dtypes=dtypes,
expanded_shapes=shapes,
expanded_names=names,
shared_data=shared_data,
dims=dims,
coords=coords,
raw_logp_fn=orig_logp_fn,
force_single_core=(gradient_backend == "mlx"),
)


def compile_pymc_model(
model: "pm.Model",
*,
backend: Literal["numba", "jax"] = "numba",
gradient_backend: Literal["pytensor", "jax"] = "pytensor",
backend: Literal["numba", "jax", "mlx"] = "numba",
gradient_backend: Literal["pytensor", "jax", "mlx"] = "pytensor",
initial_points: dict[Union["Variable", str], np.ndarray | float | int]
| None = None,
jitter_rvs: set["TensorVariable"] | None = None,
@@ -491,11 +601,11 @@ def compile_pymc_model(
----------
model : pymc.Model
The model to compile.
backend : ["jax", "numba"]
backend : ["jax", "numba", "mlx"]
The pytensor backend that is used to compile the logp function.
gradient_backend: ["pytensor", "jax"]
gradient_backend: ["pytensor", "jax", "mlx"]
Which library is used to compute the gradients. This can only be changed
to "jax" if the jax backend is used.
to "jax" if the jax backend is used, or "mlx" if the mlx backend is used.
jitter_rvs : set
The set (or list or tuple) of random variables for which a U(-1, +1)
jitter should be added to the initial value. Only available for
@@ -534,7 +644,7 @@ def compile_pymc_model(
from pymc.initial_point import make_initial_point_fn

if freeze_model is None:
freeze_model = backend == "jax"
freeze_model = backend in ["jax", "mlx"]

if freeze_model:
model = freeze_dims_and_data(model)
@@ -553,8 +663,10 @@ def compile_pymc_model(
initial_point_fn = _wrap_with_lock(initial_point_fn)

if backend.lower() == "numba":
if gradient_backend == "jax":
raise ValueError("Gradient backend cannot be jax when using numba backend")
if gradient_backend in ["jax", "mlx"]:
raise ValueError(
f"Gradient backend cannot be {gradient_backend} when using numba backend"
)
return _compile_pymc_model_numba(
model=model,
pymc_initial_point_fn=initial_point_fn,
@@ -569,8 +681,16 @@ def compile_pymc_model(
var_names=var_names,
**kwargs,
)
elif backend.lower() == "mlx":
return _compile_pymc_model_mlx(
model=model,
gradient_backend=gradient_backend,
pymc_initial_point_fn=initial_point_fn,
var_names=var_names,
**kwargs,
)
else:
raise ValueError(f"Backend must be one of numba and jax. Got {backend}")
raise ValueError(f"Backend must be one of numba, jax, and mlx. Got {backend}")


def _wrap_with_lock(func: Callable) -> Callable:
@@ -616,7 +736,7 @@ def _compute_shapes(model) -> dict[str, tuple[int, ...]]:
def _make_functions(
model: "pm.Model",
*,
mode: Literal["JAX", "NUMBA"],
mode: Literal["JAX", "NUMBA", "MLX"],
compute_grad: bool,
join_expanded: bool,
pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
@@ -637,7 +757,7 @@ def _make_functions(
model: pymc.Model
The model to compile
mode: str
Pytensor compile mode. One of "NUMBA" or "JAX"
Pytensor compile mode. One of "NUMBA", "JAX", or "MLX"
compute_grad: bool
Whether to compute gradients using pytensor. Must be True if mode is
"NUMBA", otherwise False implies Jax will be used to compute gradients
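For context, a minimal usage sketch of the backend added in this file; the toy model and variable names are illustrative and not part of the diff, and running it assumes an Apple Silicon machine with mlx installed:

import numpy as np
import pymc as pm
import nutpie

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=np.array([0.1, -0.3, 0.2]))

# backend="mlx" routes compilation through _compile_pymc_model_mlx.
# gradient_backend defaults to "pytensor"; passing "mlx" instead uses
# mx.value_and_grad and forces single-core sampling via force_single_core.
compiled = nutpie.compile_pymc_model(model, backend="mlx", gradient_backend="mlx")
trace = nutpie.sample(compiled)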
17 changes: 16 additions & 1 deletion python/nutpie/compiled_pyfunc.py
@@ -23,6 +23,7 @@ class PyFuncModel(CompiledModel):
_coords: dict[str, Any]
_raw_logp_fn: Callable | None
_transform_adapt_args: dict | None = None
_force_single_core: bool = False

@property
def shapes(self) -> dict[str, tuple[int, ...]]:
@@ -42,13 +43,25 @@ def with_data(self, **updates):
raise ValueError(f"Unknown data variable: {name}")

updated = self._shared_data.copy()
updated.update(**updates)

# Convert to MLX arrays if using MLX backend (indicated by force_single_core)
if self._force_single_core:
[Review comment, Member]
We should not use that argument to detect mlx.
How about we add an attribute _convert_data_item or so to the dataclass, that contains a function that transforms data arrays? We could then also use that for jax.

import mlx.core as mx

for name, value in updates.items():
updated[name] = mx.array(value)
else:
updated.update(**updates)

return dataclasses.replace(self, _shared_data=updated)

def with_transform_adapt(self, **kwargs):
return dataclasses.replace(self, _transform_adapt_args=kwargs)

def _make_sampler(self, settings, init_mean, cores, progress_type, store):
# Force single-core execution if required (e.g., for MLX backend)
if self._force_single_core:
cores = 1
model = self._make_model(init_mean)
return _lib.PySampler.from_pyfunc(
settings,
@@ -108,6 +121,7 @@ def from_pyfunc(
make_initial_point_fn: Callable[[SeedType], np.ndarray] | None = None,
make_transform_adapter=None,
raw_logp_fn=None,
force_single_core: bool = False,
):
if coords is None:
coords = {}
@@ -139,4 +153,5 @@
_variables=variables,
_shared_data=shared_data,
_raw_logp_fn=raw_logp_fn,
_force_single_core=force_single_core,
)
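To illustrate the with_data and force_single_core changes above: updating shared data returns a new compiled model, and when the MLX gradient backend is used (which sets _force_single_core), the updates are wrapped in mx.array and sampling is capped at one core. A sketch continuing the earlier example; the data variable name "x" is hypothetical:

import numpy as np
import nutpie

# `compiled` is the MLX-compiled model from the earlier sketch, assumed to
# define a pm.Data variable named "x".
updated = compiled.with_data(x=np.linspace(0.0, 1.0, 10))

# _make_sampler forces cores=1 because _force_single_core is set, so the
# requested core count is effectively ignored here.
trace = nutpie.sample(updated, cores=4)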