Skip to content

Commit fa009ae

Browse files
authored
Check memory when building plan (#836)
1 parent ed4d82e commit fa009ae

File tree

4 files changed

+41
-26
lines changed

4 files changed

+41
-26
lines changed

cubed/core/plan.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,26 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
271271

272272
return dag
273273

274+
def _check_projected_mem(self, dag) -> None:
275+
op_name = None
276+
max_projected_mem_op = None
277+
for n, d in dag.nodes(data=True):
278+
if "primitive_op" in d:
279+
op = d["primitive_op"]
280+
if (
281+
max_projected_mem_op is None
282+
or op.projected_mem > max_projected_mem_op.projected_mem
283+
):
284+
op_name = n
285+
max_projected_mem_op = op
286+
if max_projected_mem_op is not None:
287+
op = max_projected_mem_op
288+
if op.projected_mem > op.allowed_mem:
289+
raise ValueError(
290+
f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), "
291+
f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}"
292+
)
293+
274294
@lru_cache # noqa: B019
275295
def _finalize(
276296
self,
@@ -284,6 +304,7 @@ def _finalize(
284304
if callable(compile_function):
285305
dag = self._compile_blockwise(dag, compile_function)
286306
dag = self._create_lazy_zarr_arrays(dag)
307+
self._check_projected_mem(dag)
287308
return FinalizedPlan(nx.freeze(dag), self.array_names, optimize_graph)
288309

289310

cubed/primitive/blockwise.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -413,11 +413,6 @@ def general_blockwise(
413413
buffer_copies=buffer_copies,
414414
)
415415

416-
if projected_mem > allowed_mem:
417-
raise ValueError(
418-
f"Projected blockwise memory ({projected_mem}) exceeds allowed_mem ({allowed_mem}), including reserved_mem ({reserved_mem})"
419-
)
420-
421416
# this must be an iterator of lists, not of tuples, otherwise lithops breaks
422417
if output_blocks is None:
423418
output_blocks = map(

cubed/tests/primitive/test_blockwise.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -149,25 +149,23 @@ def test_blockwise_allowed_mem_exceeded(tmp_path, reserved_mem):
149149
allowed_mem = 100
150150
target_store = tmp_path / "target.zarr"
151151

152-
with pytest.raises(
153-
ValueError,
154-
match=r"Projected blockwise memory \(\d+\) exceeds allowed_mem \(100\), including reserved_mem \(\d+\)",
155-
):
156-
blockwise(
157-
nxp.linalg.outer,
158-
"ij",
159-
source1,
160-
"i",
161-
source2,
162-
"j",
163-
allowed_mem=allowed_mem,
164-
reserved_mem=reserved_mem,
165-
target_store=target_store,
166-
target_name="target",
167-
shape=(3, 3),
168-
dtype=np.int64,
169-
chunks=(2, 2),
170-
)
152+
op = blockwise(
153+
nxp.linalg.outer,
154+
"ij",
155+
source1,
156+
"i",
157+
source2,
158+
"j",
159+
allowed_mem=allowed_mem,
160+
reserved_mem=reserved_mem,
161+
target_store=target_store,
162+
target_name="target",
163+
shape=(3, 3),
164+
dtype=np.int64,
165+
chunks=(2, 2),
166+
)
167+
168+
assert op.projected_mem > op.allowed_mem
171169

172170

173171
def test_general_blockwise(tmp_path, executor):

cubed/tests/test_core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,11 +491,12 @@ def test_default_spec(executor):
491491
def test_default_spec_allowed_mem_exceeded():
492492
# default spec fails for large computations
493493
a = xp.ones((20000, 10000), chunks=(10000, 10000))
494+
b = xp.negative(a)
494495
with pytest.raises(
495496
ValueError,
496-
match=r"Projected blockwise memory \(\d+\) exceeds allowed_mem \(\d+\), including reserved_mem \(\d+\)",
497+
match=r"Projected blockwise memory \(.+\) exceeds allowed_mem \(.+\), including reserved_mem \(.+\) for op-\d+",
497498
):
498-
xp.negative(a)
499+
b.plan()
499500

500501

501502
def test_default_spec_config_override():

0 commit comments

Comments
 (0)