Skip to content

Commit 5abe4ff

Browse files
committed
updated dim handling in model_to_laplace_approx to not force dims on variables that did not have them originally
1 parent 1e41758 commit 5abe4ff

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,15 @@ def model_to_laplace_approx(
224224
elif name in model.named_vars_to_dims:
225225
dims = (*batch_dims, *model.named_vars_to_dims[name])
226226
else:
227-
dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
228227
initval = initial_point.get(name, None)
229-
dim_shapes = initval.shape if initval is not None else batched_rv.shape.eval()[2:]
230-
laplace_model.add_coords(
231-
{name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
232-
)
228+
dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
229+
if dim_shapes[0] is not None:
230+
dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
231+
laplace_model.add_coords(
232+
{name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
233+
)
234+
else:
235+
dims = None
233236

234237
pm.Deterministic(name, batched_rv, dims=dims)
235238

0 commit comments

Comments
 (0)