Add Unified Sequence Parallel attention #12693
base: main
Conversation
It would be nice to get a testing script so that we can quickly check things.
I added a basic test script with a simple forward and backward op. Is it better to have a test script with flash_attention_backward and forward?
Force-pushed a244006 to 9dee8f8
Bug fixes: LSE calculation; switched to the _all_to_all_single helper in _all_to_all_dim_exchange due to contiguity issues; further bug fixes.
Force-pushed 9dee8f8 to 9ebcff5
Let us know if this is ready for a review!
Yep, ready for review! I tested it with a 4-process setup (2×2 mesh, on CPU) and everything checks out: shapes look good and gradients flow correctly. Looking forward to feedback and happy to address any issues.
sayakpaul left a comment
Thanks for getting started on this!
# if self.ring_degree > 1 and self.ulysses_degree > 1:
#     raise ValueError(
#         "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
#     )
Maybe this needs to be removed?
    x = _wait_tensor(x)
    return x


def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
Could we also add some basic docstrings to this function? It would help readability. Some commentary on what the function is doing would also be helpful.
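As a starting point, something along these lines could work; the layout description below is inferred from the surrounding shapes rather than taken from the PR:

def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    """Redistribute `x` across `group` by splitting it along `scatter_idx`,
    performing an all-to-all, and concatenating the received chunks along `gather_idx`.

    With the defaults and a (batch, seq_local, heads, head_dim) input, this turns a
    sequence-sharded / full-head tensor into a full-sequence / head-sharded one;
    swapping the two indices performs the inverse exchange.
    """
    ...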
    B, S_LOCAL, H, D = x.shape
    S = S_LOCAL * group_world_size
    H_LOCAL = H // group_world_size
(nit): prefer using fully qualified variable names.
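For illustration, and assuming the (batch, sequence, heads, head_dim) layout implied by the shapes above, the unpacking could read:

    batch_size, seq_len_local, num_heads, head_dim = x.shape
    seq_len = seq_len_local * group_world_size
    num_heads_local = num_heads // group_world_size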
    grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))

-   return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+   return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
Why the change here?
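For context (not from the PR): torch.autograd.Function.backward must return one value per argument of forward, so adding an argument to forward requires an extra trailing None. A minimal standalone example:

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, some_flag):  # three forward inputs
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # one return value per forward input: a gradient for x, None for the non-tensor args
        return grad_out * ctx.factor, None, None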
    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
Do you think it's better to have SeqAllToAllDim accept the QKV tensors as inputs rather than calling it like this?
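One possible reading of that suggestion, sketched as a thin wrapper around the PR's SeqAllToAllDim (the helper name here is made up, and this assumes SeqAllToAllDim is importable from the PR's module):

def _ulysses_exchange_qkv(group, query, key, value, scatter_idx, gather_idx):
    # Apply the same head/sequence all-to-all exchange to each projection.
    return tuple(
        SeqAllToAllDim.apply(group, t, scatter_idx, gather_idx)
        for t in (query, key, value)
    )

query, key, value = _ulysses_exchange_qkv(ulysses_group, query, key, value, scatter_idx, gather_idx)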
    scatter_idx = 2
    gather_idx = 1
Let's make this configurable.
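Purely as an illustration (not the PR's API), the two indices could be surfaced through a small config with the current values as defaults:

from dataclasses import dataclass

@dataclass
class UlyssesExchangeConfig:
    scatter_idx: int = 2  # dimension split before the all-to-all (heads)
    gather_idx: int = 1   # dimension the received chunks are concatenated along (sequence)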
I am trying with the following code:

import torch
from torch import distributed as dist

from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device


device = setup_distributed()

# Need to add parallel support for this.
# pipeline.transformer.set_attention_backend("flash_hub")

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16,
).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)
)

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
generator = torch.Generator().manual_seed(42)
image = pipeline(prompt, guidance_scale=3.5, num_inference_steps=50, generator=generator).images[0]

if dist.get_rank() == 0:
    image.save("output_ua.png")

if dist.is_initialized():
    dist.destroy_process_group()

Run the above with
And it leads to:
What does this PR do?
This is a draft implementation of the Unified SP attention approach.
- `_all_to_all_dim_exchange` with custom scatter and gather indices
- `TemplatedUnifiedAttention`: core implementation complete, needs:
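For context, a minimal sketch of how a 2×2 device mesh splits ranks into the two groups used by unified Ulysses + Ring attention; the mesh dimension names are illustrative, and this assumes it is launched under torchrun with 4 processes:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="gloo")  # CPU-friendly backend for a quick check
mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("ring", "ulysses"))

ring_group = mesh.get_group("ring")        # ranks holding different sequence chunks (ring attention)
ulysses_group = mesh.get_group("ulysses")  # ranks exchanging heads/sequence via all-to-all

dist.destroy_process_group()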