Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,26 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap

return dag

def _check_projected_mem(self, dag) -> None:
op_name = None
max_projected_mem_op = None
for n, d in dag.nodes(data=True):
if "primitive_op" in d:
op = d["primitive_op"]
if (
max_projected_mem_op is None
or op.projected_mem > max_projected_mem_op.projected_mem
):
op_name = n
max_projected_mem_op = op
if max_projected_mem_op is not None:
op = max_projected_mem_op
if op.projected_mem > op.allowed_mem:
raise ValueError(
f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), "
f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}"
)

@lru_cache # noqa: B019
def _finalize(
self,
Expand All @@ -281,6 +301,7 @@ def _finalize(
if callable(compile_function):
dag = self._compile_blockwise(dag, compile_function)
dag = self._create_lazy_zarr_arrays(dag)
self._check_projected_mem(dag)
return FinalizedPlan(nx.freeze(dag), self.array_names, optimize_graph)


Expand Down
5 changes: 0 additions & 5 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,11 +413,6 @@ def general_blockwise(
buffer_copies=buffer_copies,
)

if projected_mem > allowed_mem:
raise ValueError(
f"Projected blockwise memory ({projected_mem}) exceeds allowed_mem ({allowed_mem}), including reserved_mem ({reserved_mem})"
)

# this must be an iterator of lists, not of tuples, otherwise lithops breaks
if output_blocks is None:
output_blocks = map(
Expand Down
36 changes: 17 additions & 19 deletions cubed/tests/primitive/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,25 +149,23 @@ def test_blockwise_allowed_mem_exceeded(tmp_path, reserved_mem):
allowed_mem = 100
target_store = tmp_path / "target.zarr"

with pytest.raises(
ValueError,
match=r"Projected blockwise memory \(\d+\) exceeds allowed_mem \(100\), including reserved_mem \(\d+\)",
):
blockwise(
nxp.linalg.outer,
"ij",
source1,
"i",
source2,
"j",
allowed_mem=allowed_mem,
reserved_mem=reserved_mem,
target_store=target_store,
target_name="target",
shape=(3, 3),
dtype=np.int64,
chunks=(2, 2),
)
op = blockwise(
nxp.linalg.outer,
"ij",
source1,
"i",
source2,
"j",
allowed_mem=allowed_mem,
reserved_mem=reserved_mem,
target_store=target_store,
target_name="target",
shape=(3, 3),
dtype=np.int64,
chunks=(2, 2),
)

assert op.projected_mem > op.allowed_mem


def test_general_blockwise(tmp_path, executor):
Expand Down
5 changes: 3 additions & 2 deletions cubed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,11 +491,12 @@ def test_default_spec(executor):
def test_default_spec_allowed_mem_exceeded():
# default spec fails for large computations
a = xp.ones((20000, 10000), chunks=(10000, 10000))
b = xp.negative(a)
with pytest.raises(
ValueError,
match=r"Projected blockwise memory \(\d+\) exceeds allowed_mem \(\d+\), including reserved_mem \(\d+\)",
match=r"Projected blockwise memory \(.+\) exceeds allowed_mem \(.+\), including reserved_mem \(.+\) for op-\d+",
):
xp.negative(a)
b.plan()


def test_default_spec_config_override():
Expand Down