Add Unified Sequence Parallel attention #12693
```diff
@@ -1025,6 +1025,68 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     x = _wait_tensor(x)
     return x
 
 
+def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
```
Member comment: Could we also add some basic docstrings to this function? It will help readability, and some commentary on what the function is doing would also be helpful.
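As a rough sketch of what such a docstring could say (the wording below is my reading of the code, not the author's):

```python
import torch


def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    """Redistribute a 4D attention tensor across ``group`` with a single all-to-all.

    With ``scatter_idx=2, gather_idx=1``: scatter the head dim and gather the sequence dim,
    i.e. ``(B, S_LOCAL, H, D) -> (B, S, H_LOCAL, D)``.
    With ``scatter_idx=1, gather_idx=2``: the inverse exchange,
    i.e. ``(B, S, H_LOCAL, D) -> (B, S_LOCAL, H, D)``.
    Any other combination raises ``ValueError``.
    """
```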
```diff
+    group_world_size = torch.distributed.get_world_size(group)
+
+    if scatter_idx == 2 and gather_idx == 1:
+        B, S_LOCAL, H, D = x.shape
+        S = S_LOCAL * group_world_size
+        H_LOCAL = H // group_world_size
```
Member comment on lines +1047 to +1050: (nit) Prefer using fully qualified variable names.
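A small sketch of what that could look like for the unpacking above; the specific names are my suggestion, not something settled in the PR:

```python
import torch

group_world_size = 4  # illustrative value; the real code queries torch.distributed
x = torch.randn(2, 8, 16, 64)  # (batch_size, seq_len_local, num_heads, head_dim)

# Spelled-out shape names instead of B, S_LOCAL, H, D:
batch_size, seq_len_local, num_heads, head_dim = x.shape
seq_len = seq_len_local * group_world_size
num_heads_local = num_heads // group_world_size
```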
```diff
+        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
+        x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()
+
+        if group_world_size > 1:
+            # maybe here need to use the _all_to_all_single helper to avoid contiguity issues
+            out = _all_to_all_single(x_temp, group=group)
+            # out = _wait_tensor(out)
+        else:
+            out = x_temp
+        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
+        out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
+        out = out.reshape(B, S, H_LOCAL, D)
+        return out
+    elif scatter_idx == 1 and gather_idx == 2:
+        B, S, H_LOCAL, D = x.shape
+        H = H_LOCAL * group_world_size
+        S_LOCAL = S // group_world_size
+
+        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+        x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)
+
+        if group_world_size > 1:
+            # maybe here need to use the _all_to_all_single helper to avoid contiguity issues
+            output = _all_to_all_single(x_temp, group)
+            # output = _wait_tensor(output)
+        else:
+            output = x_temp
+        output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
+        output = output.reshape(B, S_LOCAL, H, D)
+        return output
+    else:
+        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
+
+
+class SeqAllToAllDim(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, group, input, scatter_id=2, gather_id=1):
+        ctx.group = group
+        ctx.scatter_id = scatter_id
+        ctx.gather_id = gather_id
+        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_input = SeqAllToAllDim.apply(
+            ctx.group,
+            grad_outputs,
+            ctx.gather_id,  # reversed
+            ctx.scatter_id,  # reversed
+        )
+        return (None, grad_input, None, None)
+
+
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
```
```diff
@@ -1147,7 +1209,7 @@ def backward(
 
         grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
 
-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
```
Member comment: Why the change here?
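For background (this is general `torch.autograd.Function` behaviour, not an explanation of the author's intent): `backward` is expected to return one gradient, or `None`, per positional argument of `forward`, in the same order, so the number of trailing `None`s has to track the `forward` signature. A minimal, self-contained illustration:

```python
import torch


class Scale(torch.autograd.Function):
    # forward takes 3 inputs -> backward must return 3 values
    @staticmethod
    def forward(ctx, x, factor, flag):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # one entry per forward input: a gradient for x, None for the two non-tensor args
        return grad_out * ctx.factor, None, None


x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0, True).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```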
```diff
 
 
 class TemplatedUlyssesAttention(torch.autograd.Function):
```
```diff
@@ -1244,6 +1306,64 @@ def backward(
 
         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
+
+def TemplatedUnifiedAttention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor],
+    dropout_p: float,
+    is_causal: bool,
+    scale: Optional[float],
+    enable_gqa: bool,
+    return_lse: bool,
+    forward_op,
+    backward_op,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+    ulysses_group = ulysses_mesh.get_group()
+    ring_mesh = _parallel_config.context_parallel_config._ring_mesh
+    ring_group = ring_mesh.get_group()
+    # hardcoded for now
+    scatter_idx = 2
+    gather_idx = 1
```
Member comment on lines +1343 to +1344: Let's make this configurable.
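One possible shape for that, sketched as plain keyword arguments with the current values as defaults; the parameter names and the standalone stub below are illustrative assumptions, not part of the PR:

```python
import torch


def unified_attention_sketch(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    *,
    scatter_idx: int = 2,  # assumed default: scatter the head dim across the Ulysses group
    gather_idx: int = 1,   # assumed default: gather the full sequence dim
):
    # The real function would forward these indices to SeqAllToAllDim.apply(...) instead of
    # hardcoding 2 and 1; this stub only shows the configurable surface and validates it.
    if (scatter_idx, gather_idx) not in ((2, 1), (1, 2)):
        raise ValueError(f"Unsupported scatter/gather pair: ({scatter_idx}, {gather_idx})")
    return query, key, value


q = k = v = torch.randn(2, 8, 4, 16)
unified_attention_sketch(q, k, v, scatter_idx=2, gather_idx=1)
```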
```diff
+
+    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
+    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
+    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
```
Member comment on lines +1346 to +1348: Do you think it's better to have […]
```diff
+    out = TemplatedRingAttention.apply(
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        enable_gqa,
+        return_lse,
+        forward_op,
+        backward_op,
+        _parallel_config,
+    )
+    if return_lse:
+        context_layer, lse, *_ = out
+    else:
+        context_layer = out
+    # Assuming (based on forward ops implementations) context_layer is of shape (B, S, H_LOCAL, D)
+    output = SeqAllToAllDim.apply(
+        ulysses_group,
+        context_layer,
+        gather_idx,
+        scatter_idx,
+    )
+    if return_lse:
+        # not sure if this is correct: Assuming (based on forward ops in ringAttention)
+        # the lse is of shape (B, S, H_LOCAL)
+        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx=2, gather_idx=1)
+        lse = lse.squeeze(-1)
+        return (output, lse)
+    return output
+
+
 def _templated_context_parallel_attention(
     query: torch.Tensor,
```
```diff
@@ -1268,7 +1388,22 @@ def _templated_context_parallel_attention(
         raise ValueError("GQA is not yet supported for templated attention.")
 
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1:
+    if _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1:
+        return TemplatedUnifiedAttention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+    elif _parallel_config.context_parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
             query,
             key,
```
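To make the data flow of the unified path easier to follow: the Ulysses all-to-all turns sequence-sharded, full-head tensors into full-sequence, head-sharded tensors, ring attention then runs across the ring group, and the reverse exchange restores the sequence-sharded layout. Below is a single-process, shape-only simulation of that exchange; the group size and tensor sizes are illustrative and no distributed calls are involved:

```python
import torch

# Simulate the Ulysses exchange for a group of size W without torch.distributed.
W, B, S, H, D = 2, 1, 8, 4, 3
full = torch.arange(B * S * H * D, dtype=torch.float32).reshape(B, S, H, D)

# Before the exchange: rank r holds a contiguous sequence shard with all heads.
seq_shards = list(full.chunk(W, dim=1))             # W tensors of shape (B, S // W, H, D)

# All-to-all: each rank ends up with the full sequence for H // W of the heads.
head_shards = []
for r in range(W):                                  # r = receiving rank
    pieces = [shard.chunk(W, dim=2)[r] for shard in seq_shards]
    head_shards.append(torch.cat(pieces, dim=1))    # (B, S, H // W, D)

assert head_shards[0].shape == (B, S, H // W, D)

# The reverse exchange (gather heads, scatter sequence) recovers the original shards.
back = [
    torch.cat([hs.chunk(W, dim=1)[r] for hs in head_shards], dim=2)
    for r in range(W)
]
assert all(torch.equal(a, b) for a, b in zip(back, seq_shards))
print("round trip OK:", head_shards[0].shape)
```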
New file (@@ -0,0 +1,129 @@):

```python
import math
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from diffusers.models.attention_dispatch import TemplatedUnifiedAttention
import os


def run(rank, world_size):
    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

    torch.manual_seed(0)

    B, S, H, D = 2, 8, 4, 16  # small toy
    q = torch.randn(B, S, H, D)
    k = torch.randn(B, S, H, D)
    v = torch.randn(B, S, H, D)

    q.requires_grad_(True)

    from diffusers.models._modeling_parallel import (
        ParallelConfig,
        ContextParallelConfig,
    )

    pc = ParallelConfig(
        context_parallel_config=ContextParallelConfig(
            ring_degree=2,
            ulysses_degree=2,
        )
    )

    pc.context_parallel_config.setup(
        rank=rank,
        world_size=world_size,
        device=torch.device("cpu"),
        mesh=dist.device_mesh.init_device_mesh(
            "cpu",
            (2, 2),
            mesh_dim_names=["ring", "ulysses"],
        ),
    )

    def dummy_forward_op(
        ctx,
        q,
        k,
        v,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
        return_lse,
        *,
        _save_ctx=True,
        _parallel_config=None,
    ):
        head_scale = math.sqrt(D)
        attn = (q @ k.transpose(-1, -2)) / head_scale
        out = attn @ v
        lse = torch.logsumexp(attn, dim=-1)

        if _save_ctx:
            ctx.save_for_backward(q, k, v)
            ctx._cached_qkv = []
            ctx._cached_iter = 0

        if not hasattr(ctx, "_cached_qkv"):
            ctx._cached_qkv = []

        ctx._cached_qkv.append((q.detach(), k.detach(), v.detach()))

        return (out, lse) if return_lse else out

    def dummy_backward_op(ctx, grad_out, *args, **kwargs):
        if not hasattr(ctx, "_cached_qkv"):
            raise RuntimeError("No cached tensors for backward.")

        if not hasattr(ctx, "_cached_iter"):
            ctx._cached_iter = 0

        if ctx._cached_iter >= len(ctx._cached_qkv):
            raise RuntimeError("Backward called more times than cached forwards.")

        q, k, v = ctx._cached_qkv[ctx._cached_iter]
        ctx._cached_iter += 1

        head_scale = math.sqrt(D)
        attn = (q @ k.transpose(-1, -2)) / head_scale

        grad_v = attn.transpose(-1, -2) @ grad_out
        grad_attn = grad_out @ v.transpose(-1, -2)
        grad_q = (grad_attn @ k) / head_scale
        grad_k = (grad_attn.transpose(-1, -2) @ q) / head_scale

        return (
            grad_q,
            grad_k,
            grad_v,
        )

    out = TemplatedUnifiedAttention(
        q, k, v, None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
        return_lse=False,
        forward_op=dummy_forward_op,
        backward_op=dummy_backward_op,
        _parallel_config=pc,
    )

    print(f"[RANK {rank}] output:", out.shape)

    out.sum().backward()
    print(f"[RANK {rank}] grad:", q.grad.shape)

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 4
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```
Review comment on the new test file: Maybe this needs to be removed?