diff --git a/cubed/core/plan.py b/cubed/core/plan.py index 6eda28059..6c816c594 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -268,6 +268,26 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap return dag + def _check_projected_mem(self, dag) -> None: + op_name = None + max_projected_mem_op = None + for n, d in dag.nodes(data=True): + if "primitive_op" in d: + op = d["primitive_op"] + if ( + max_projected_mem_op is None + or op.projected_mem > max_projected_mem_op.projected_mem + ): + op_name = n + max_projected_mem_op = op + if max_projected_mem_op is not None: + op = max_projected_mem_op + if op.projected_mem > op.allowed_mem: + raise ValueError( + f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), " + f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}" + ) + @lru_cache # noqa: B019 def _finalize( self, @@ -281,6 +301,7 @@ def _finalize( if callable(compile_function): dag = self._compile_blockwise(dag, compile_function) dag = self._create_lazy_zarr_arrays(dag) + self._check_projected_mem(dag) return FinalizedPlan(nx.freeze(dag), self.array_names, optimize_graph) diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index 905e32e03..ed2f2dcfe 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -413,11 +413,6 @@ def general_blockwise( buffer_copies=buffer_copies, ) - if projected_mem > allowed_mem: - raise ValueError( - f"Projected blockwise memory ({projected_mem}) exceeds allowed_mem ({allowed_mem}), including reserved_mem ({reserved_mem})" - ) - # this must be an iterator of lists, not of tuples, otherwise lithops breaks if output_blocks is None: output_blocks = map( diff --git a/cubed/tests/primitive/test_blockwise.py b/cubed/tests/primitive/test_blockwise.py index b4e9a8837..337295c61 100644 --- a/cubed/tests/primitive/test_blockwise.py +++ b/cubed/tests/primitive/test_blockwise.py @@ -149,25 +149,23 @@ def test_blockwise_allowed_mem_exceeded(tmp_path, reserved_mem): allowed_mem = 100 target_store = tmp_path / "target.zarr" - with pytest.raises( - ValueError, - match=r"Projected blockwise memory \(\d+\) exceeds allowed_mem \(100\), including reserved_mem \(\d+\)", - ): - blockwise( - nxp.linalg.outer, - "ij", - source1, - "i", - source2, - "j", - allowed_mem=allowed_mem, - reserved_mem=reserved_mem, - target_store=target_store, - target_name="target", - shape=(3, 3), - dtype=np.int64, - chunks=(2, 2), - ) + op = blockwise( + nxp.linalg.outer, + "ij", + source1, + "i", + source2, + "j", + allowed_mem=allowed_mem, + reserved_mem=reserved_mem, + target_store=target_store, + target_name="target", + shape=(3, 3), + dtype=np.int64, + chunks=(2, 2), + ) + + assert op.projected_mem > op.allowed_mem def test_general_blockwise(tmp_path, executor): diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index 2ae72ab8f..89ab0a6d3 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -491,11 +491,12 @@ def test_default_spec(executor): def test_default_spec_allowed_mem_exceeded(): # default spec fails for large computations a = xp.ones((20000, 10000), chunks=(10000, 10000)) + b = xp.negative(a) with pytest.raises( ValueError, - match=r"Projected blockwise memory \(\d+\) exceeds allowed_mem \(\d+\), including reserved_mem \(\d+\)", + match=r"Projected blockwise memory \(.+\) exceeds allowed_mem \(.+\), including reserved_mem \(.+\) for op-\d+", ): - xp.negative(a) + b.plan() def test_default_spec_config_override():