In this post, we’ll show how to wire up Ahead-of-Time (AoT) compilation in ZeroGPU Spaces:
- [Dynamic shapes](#dynamic-shapes)
- [Multi-compile / shared weights](#multi-compile--shared-weights)
- [FlashAttention-3](#flashattention-3)
- [Regional compilation](#regional-compilation)
- [AoT compiled ZeroGPU Spaces demos](#aot-compiled-zerogpu-spaces-demos)
- [Conclusion](#conclusion)
- [Resources](#resources)
It tries to load a kernel from the `kernels-community/vllm-flash-attn3` repository on the Hugging Face Hub.

Here is a [fully working example of an FA3 attention processor](https://gist.github.com/sayakpaul/ff715f979793d4d44beb68e5e08ee067#file-fa3_qwen-py) for the Qwen-Image model.

### Regional compilation

> [!TIP]
> We suggest using regional compilation instead of full model compilation, especially when the speed benefits are similar.

So far, we have been compiling the full model. Depending on the model, full model compilation can lead to long cold start times, which make the development experience unpleasant.

We can also choose to compile _regions_ within a model, significantly reducing cold start times while retaining almost all the benefits of full model compilation. Regional
compilation is particularly promising when a model has repeated blocks of computation. A standard
language model, for example, is made up of a number of identically structured Transformer blocks.
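
To make this concrete with the Flux transformer we compile below, here is a minimal sketch that counts the block classes inside it (assuming `pipe` is the Flux pipeline loaded earlier in this post; the printed counts are illustrative):

```py
from collections import Counter

# Count how many instances of each block class make up the transformer.
double_types = Counter(type(b).__name__ for b in pipe.transformer.transformer_blocks)
single_types = Counter(type(b).__name__ for b in pipe.transformer.single_transformer_blocks)

print(double_types)  # e.g. Counter({'FluxTransformerBlock': 19})
print(single_types)  # e.g. Counter({'FluxSingleTransformerBlock': 38})
```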

In our example, we can compile the repeated blocks of the Flux transformer ahead of time like so. The [Flux Transformer](https://github.com/huggingface/diffusers/blob/c2e5ece08bf22d249c62e964f91bc326cf9e3759/src/diffusers/models/transformers/transformer_flux.py) has two kinds of repeated blocks: `FluxTransformerBlock` and `FluxSingleTransformerBlock`. We start by capturing inputs for the first block of each kind:

```py
# Capturing inputs just for the first block is enough as the input structure remains
# the same for others.

with spaces.aoti_capture(pipe.transformer.transformer_blocks[0]) as call_double_blocks:
    pipe("arbitrary example prompt")

with spaces.aoti_capture(
    pipe.transformer.single_transformer_blocks[0]
) as call_single_blocks:
    pipe("arbitrary example prompt")
```
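
If you are curious how a capture step like this can work, here is a rough, hypothetical sketch (this is not the actual `spaces.aoti_capture` implementation): a context manager that temporarily wraps the module's `forward` and records the arguments of the first call.

```py
import contextlib

import torch

@contextlib.contextmanager
def capture_inputs(module: torch.nn.Module):
    """Hypothetical helper: records the args/kwargs of the first call to `module`."""
    captured = {}
    original_forward = module.forward

    def wrapper(*args, **kwargs):
        if not captured:  # only record the first call
            captured["args"], captured["kwargs"] = args, kwargs
        return original_forward(*args, **kwargs)

    module.forward = wrapper
    try:
        yield captured
    finally:
        module.forward = original_forward
```

Usage would mirror the snippet above: run the pipeline once inside the context, then read `captured["args"]` and `captured["kwargs"]` afterwards.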

We then export each block with `torch.export.export` and compile the resulting `torch.export.ExportedProgram`:

```py
exported_double = torch.export.export(
    mod=pipe.transformer.transformer_blocks[0],
    args=call_double_blocks.args,
    kwargs=call_double_blocks.kwargs,
)
exported_single = torch.export.export(
    mod=pipe.transformer.single_transformer_blocks[0],
    args=call_single_blocks.args,
    kwargs=call_single_blocks.kwargs,
)

compiled_double = spaces.aoti_compile(exported_double)
compiled_single = spaces.aoti_compile(exported_single)
```
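
As an optional sanity check (an aside, not from the original post), the exported programs can be inspected before compiling:

```py
# Summarized view of the exported program: input/output signature plus the captured graph.
print(exported_double)

# The underlying FX graph of the single-block export, rendered as Python-like code.
exported_single.graph_module.print_readable()
```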

Note that we only compile the first block within `transformer.transformer_blocks` and within
`transformer.single_transformer_blocks`, because every block in each of these lists shares the same
configuration, so the compilation-optimized graph can be reused. During loading, we
reuse this optimized graph while supplying each block's own parameters, as shown below:

```py
from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights

for block in pipe.transformer.transformer_blocks:
    # Reuse the compiled double-block graph, binding it to this block's own weights.
    weights = ZeroGPUWeights(block.state_dict())
    compiled_block = ZeroGPUCompiledModel(compiled_double.archive_file, weights)
    block.forward = compiled_block

for block in pipe.transformer.single_transformer_blocks:
    # Same for the single blocks, using the compiled single-block graph.
    weights = ZeroGPUWeights(block.state_dict())
    compiled_block = ZeroGPUCompiledModel(compiled_single.archive_file, weights)
    block.forward = compiled_block
```
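
Since both block lists go through the same capture → export → compile → rebind steps, they can also be folded into a small helper. This is only a convenience sketch built from the calls shown above; `compile_repeated_blocks` is our own name, not part of the `spaces` API:

```py
import spaces
import torch
from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights

def compile_repeated_blocks(blocks, captured_call):
    """Compile the first block once, then bind the compiled graph to every block in the list."""
    exported = torch.export.export(
        mod=blocks[0],
        args=captured_call.args,
        kwargs=captured_call.kwargs,
    )
    compiled = spaces.aoti_compile(exported)
    for block in blocks:
        weights = ZeroGPUWeights(block.state_dict())
        block.forward = ZeroGPUCompiledModel(compiled.archive_file, weights)

compile_repeated_blocks(pipe.transformer.transformer_blocks, call_double_blocks)
compile_repeated_blocks(pipe.transformer.single_transformer_blocks, call_single_blocks)
```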

And we should be ready to go 🚀 You can check out [this Space](https://huggingface.co/spaces/zerogpu-aoti/Qwen-Image-Edit-AoT-Regional) for a more complete example.

> [!TIP]
> 💡 For Flux.1-Dev, compiling regions ahead of time like this reduces the cold start time from _103 seconds to 23 seconds_, while delivering almost identical speedups.

## AoT compiled ZeroGPU Spaces demos

### Speedup comparison
- [FLUX.1 Kontext](https://huggingface.co/spaces/zerogpu-aoti/FLUX.1-Kontext-Dev)
- [QwenImage Edit](https://huggingface.co/spaces/multimodalart/Qwen-Image-Edit-Fast)
- [Wan 2.2](https://huggingface.co/spaces/zerogpu-aoti/wan2-2-fp8da-aoti-faster)
- [LTX Video](https://huggingface.co/spaces/zerogpu-aoti/ltx-dev-fast)

### Regional compilation
- [Regional compilation recipe](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)
- [Native integration in Diffusers](https://huggingface.co/docs/diffusers/main/en/optimization/fp16)
- [More performance numbers](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/)

## Conclusion
