
Commit 277237c

attention: use flag based OOM fallback (#11038)
The exception keeps references to all local variables for the lifetime of the exception context. Instead, just set a flag in the except block and branch on it afterwards, so the exception context is dropped before falling back.
1 parent daaceac · commit 277237c
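
The change, sketched in minimal form (the wrapper and names below are illustrative, not the actual ComfyUI code): record the failure in a flag inside the handler, let the except block exit so Python drops the exception and its traceback, and only then run the memory-hungry fallback.

import logging

def attention_with_fallback(q, k, v, fast_attention, slow_attention):
    # Illustrative flag-based fallback; fast_attention/slow_attention are
    # placeholder callables standing in for e.g. sage and pytorch attention.
    needs_fallback = False
    out = None
    try:
        out = fast_attention(q, k, v)
    except MemoryError as e:
        logging.warning("fast attention failed (%s), falling back", e)
        # Only record the failure here. Running the fallback inside this
        # block would keep the exception, its traceback, and every local in
        # the failed call's frames alive while the fallback allocates memory.
        needs_fallback = True
    if needs_fallback:
        # The except block has exited, so the exception context (and the
        # intermediates it pinned) can be freed before the retry.
        out = slow_attention(q, k, v)
    return out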

2 files changed: 6 additions, 0 deletions

comfy/ldm/modules/attention.py

Lines changed: 3 additions & 0 deletions
@@ -517,6 +517,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 
 @wrap_attn
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    exception_fallback = False
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
@@ -541,6 +542,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
         logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+        exception_fallback = True
+    if exception_fallback:
         if tensor_layout == "NHD":
             q, k, v = map(
                 lambda t: t.transpose(1, 2),

comfy/ldm/modules/diffusionmodules/model.py

Lines changed: 3 additions & 0 deletions
@@ -279,6 +279,7 @@ def pytorch_attention(q, k, v):
     orig_shape = q.shape
     B = orig_shape[0]
     C = orig_shape[1]
+    oom_fallback = False
     q, k, v = map(
         lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
         (q, k, v),
@@ -289,6 +290,8 @@ def pytorch_attention(q, k, v):
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+        oom_fallback = True
+    if oom_fallback:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
     return out
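
Why the flag matters, demonstrated in plain Python (no GPU; the class and function below are made up for the demo): an object allocated in the failing call stays reachable through the exception's traceback for as long as the handler body runs, and only becomes collectible once the except block ends. In CPython this runs as written; the precise moment of collection is an implementation detail, hence the explicit gc.collect().

import gc
import weakref

class BigIntermediate:
    # Stand-in for a large tensor allocated inside the failing kernel.
    pass

probe = None  # weak reference used to observe the object's lifetime

def failing_kernel():
    global probe
    scratch = BigIntermediate()         # pretend this pins gigabytes of VRAM
    probe = weakref.ref(scratch)
    raise MemoryError("simulated OOM")  # the traceback now references this frame

try:
    failing_kernel()
except MemoryError:
    # Inside the handler the traceback still references failing_kernel's
    # frame, so `scratch` cannot be freed yet -- this is where the fallback
    # used to run before this commit.
    assert probe() is not None
# Once the except block ends, the exception and its traceback are dropped,
# so the big intermediate is collectible before any fallback work starts.
gc.collect()
assert probe() is None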
