 If you are only interested in performant attention score modifications, please
 head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
 contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.
-
 If you are wondering about what building blocks the ``torch`` library provides
 for writing your own transformer layers and best practices, you are in the
 right place, please keep reading!
 
 
-Introducing the Building Blocks
-===============================
-First, we will briefly introduce the 4 technologies mentioned in the introduction
+"""
+
+################################################################################
+# Introducing the Building Blocks
+# ===============================
+# First, we will briefly introduce the 4 technologies mentioned in the introduction
 
-* `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
+# * `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
 
-Nested tensors generalize the shape of regular dense tensors, allowing for
-representation of ragged-sized data with the same tensor UX. In the context of
-transformers, we can think of nested tensors as a tool for representing variable
-sequence lengths. They eliminate the need for the bug-prone practices of explicit
-padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
+# Nested tensors generalize the shape of regular dense tensors, allowing for
+# representation of ragged-sized data with the same tensor UX. In the context of
+# transformers, we can think of nested tensors as a tool for representing variable
+# sequence lengths. They eliminate the need for the bug-prone practices of explicit
+# padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
 
-* `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
+# * `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
 
-``scaled_dot_product_attention`` is a primitive for
-:math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
-implementations of the operator or a fallback implementation. It works out of
-the box in eager mode (i.e. the default mode of using PyTorch where operations
-are executed on the fly as they are encountered) and also integrates seamlessly
-with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
-natively.
+# ``scaled_dot_product_attention`` is a primitive for
+# :math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
+# implementations of the operator or a fallback implementation. It works out of
+# the box in eager mode (i.e. the default mode of using PyTorch where operations
+# are executed on the fly as they are encountered) and also integrates seamlessly
+# with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
+# natively.
 
-* `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
+# * `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
 
-``torch.compile()`` is a compiler introduced in version 2.0 that is able to
-capture a graph of PyTorch code and perform various optimizations on it, such as
-fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
-and ``scaled_dot_product_attention`` work seamlessly with compile. In the
-context of transformers, the value add of using compile with nested tensor
-and SDPA is that compile can remove framework overhead ones sees in eager mode
-and fuse sequences of ops in transformers together (e.g. projection and
-activation).
+# ``torch.compile()`` is a compiler introduced in version 2.0 that is able to
+# capture a graph of PyTorch code and perform various optimizations on it, such as
+# fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
+# and ``scaled_dot_product_attention`` work seamlessly with compile. In the
+# context of transformers, the value add of using compile with nested tensors
+# and SDPA is that compile can remove the framework overhead one sees in eager mode
+# and fuse sequences of ops in transformers together (e.g. projection and
+# activation).
 
-* `FlexAttention <https://pytorch.org/blog/flexattention/>`_
+# * `FlexAttention <https://pytorch.org/blog/flexattention/>`_
 
-``FlexAttention`` is a primitive that allows users to modify attention scores
-prior to the softmax operation. It generalizes the additive ``B`` term above
-for `scaled_dot_product_attention`, allowing for arbitrary calculation. It
-requires compile to achieve good performance.
+# ``FlexAttention`` is a primitive that allows users to modify attention scores
+# prior to the softmax operation. It generalizes the additive ``B`` term above
+# for ``scaled_dot_product_attention``, allowing for arbitrary calculation. It
+# requires compile to achieve good performance.
 
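To make the list above concrete, here is a minimal sketch of how these pieces compose. It assumes a recent PyTorch (2.5 or later) and a CUDA device for the fused nested-tensor and FlexAttention paths; the sizes, sequence lengths, and the toy ``rel_bias`` score modification are made up for illustration.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

device = "cuda"  # the fused NJT / FlexAttention paths here assume a CUDA device
batch, n_heads, head_dim = 4, 8, 64
seq_lens = [37, 128, 256, 62]  # hypothetical ragged batch

# torch.nested: one jagged tensor instead of padding plus key_padding_mask.
qkv = torch.nested.nested_tensor(
    [torch.randn(s, n_heads * head_dim, device=device) for s in seq_lens],
    layout=torch.jagged,
)

def to_heads(t):
    # (B, S*, n_heads * head_dim) -> (B, n_heads, S*, head_dim), with S* ragged
    return t.unflatten(-1, [n_heads, head_dim]).transpose(1, 2)

q = k = v = to_heads(qkv)

# scaled_dot_product_attention dispatches to a fused kernel where possible and
# composes with torch.compile, which strips away eager-mode framework overhead.
sdpa = torch.compile(F.scaled_dot_product_attention)
out = sdpa(q, k, v)  # output keeps the same ragged structure as the inputs

# FlexAttention: a user-defined score_mod runs on the scores before softmax,
# generalizing the additive B term of SDPA to arbitrary calculations.
def rel_bias(score, b, h, q_idx, kv_idx):
    return score + 0.01 * (kv_idx - q_idx)  # toy additive bias

dense_q = torch.randn(batch, n_heads, 128, head_dim, device=device)
flex = torch.compile(flex_attention)  # compile is what makes FlexAttention fast
flex_out = flex(dense_q, dense_q, dense_q, score_mod=rel_bias)
```

FlexAttention also accepts nested-tensor inputs, which is how the ALiBi example later in the tutorial combines the two.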
-The above building blocks are "All You Need" (as of October 2024)
-==================================================================
+# The above building blocks are "All You Need" (as of October 2024)
+# ==================================================================
 
-The main premise in this section is that most transformer variations are
-GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
-Blocks and Feed Forward networks. If we were to try to classify the differences
-in this space, we might land on something like:
+# The main premise in this section is that most transformer variations are
+# GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
+# Blocks and Feed Forward networks. If we were to try to classify the differences
+# in this space, we might land on something like:
 
-1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
-e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
-2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
-3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)
+# 1. Layer type (activation functions, e.g. ``SwiGLU``; normalization functions,
+#    e.g. ``RMSNorm``; positional encodings, e.g. Sinusoidal, Rotary, etc.)
+# 2. Layer ordering (where to apply norms, where to apply positional encoding, etc.)
+# 3. Modifications to attention score (``ALiBi``, Relative Positional Bias, etc.)
 
 
-In a pre-compiler world, one might write their custom transformer and observe
-that it works but is slow. Then, one might write a custom fused kernel for
-the specific series of ops. In a compiler world, one can do the former, compile
-and profit.
+# In a pre-compiler world, one might write their custom transformer and observe
+# that it works but is slow. Then, one might write a custom fused kernel for
+# the specific series of ops. In a compiler world, one can do the former, compile,
+# and profit.
 
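As a hypothetical illustration of "do the former, compile, and profit": a feed-forward block written with plain ``torch`` ops (here a SwiGLU-style MLP, one of the layer-type choices listed above) can be handed to ``torch.compile`` as-is, letting the compiler fuse the projection-and-activation sequence instead of requiring a hand-written kernel. The sizes below are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """A plain SwiGLU-style MLP written with ordinary torch ops."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_up = nn.Linear(dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A projection followed by an activation: exactly the kind of op
        # sequence torch.compile can fuse without any custom kernel work.
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

ff = torch.compile(FeedForward(dim=512, hidden_dim=2048))
y = ff(torch.randn(2, 128, 512))
```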
114- """
115116
116117###############################################################################
117118# MultiheadAttention
@@ -399,13 +400,12 @@ def benchmark(func, *args, **kwargs):
 ######################################################################################
 # For reference some sample outputs on A100:
 #
-# ```
-# padded_time=0.03454, padded_peak_memory=4.14 GB
-# nested_time=0.00612, nested_peak_memory=0.76 GB
-# Difference between vanilla and nested result 0.0
-# Nested speedup: 5.65
-# Nested peak memory reduction 3.39 GB
-# ````
+# .. code::
+#
+#    padded_time=0.03454, padded_peak_memory=4.14 GB
+#    nested_time=0.00612, nested_peak_memory=0.76 GB
+#    Difference between vanilla and nested result 0.0
+#    Nested speedup: 5.65
+#    Nested peak memory reduction 3.39 GB
 #
 # We can also see the same for backward pass
 
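The timings and peak-memory figures above come from the tutorial's ``benchmark`` helper (its signature appears in the hunk headers, though its body is not shown here). A minimal sketch of one way such a helper could be written, assuming a CUDA device and not necessarily matching the tutorial's exact implementation:

```python
import timeit
import torch

def benchmark(func, *args, **kwargs):
    # Wall-clock time and peak CUDA memory for a single call of func.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, end - begin, torch.cuda.max_memory_allocated()
```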
@@ -429,16 +429,16 @@ def benchmark(func, *args, **kwargs):
 ##################################################################################
 # Sample outputs on A100:
 #
-# ```
-# padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
-# nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
-# Nested backward speedup: 144.13
-# Nested backward peak memory reduction 1.86 GB
-# Difference in out_proj.weight.grad 0.000244140625
-# Difference in packed_proj.weight.grad 0.001556396484375
-# Difference in out_proj.bias.grad 0.0
-# Difference in packed_proj.bias.grad 0.001953125
-# ```
+# .. code::
+#
+#    padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
+#    nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
+#    Nested backward speedup: 144.13
+#    Nested backward peak memory reduction 1.86 GB
+#    Difference in out_proj.weight.grad 0.000244140625
+#    Difference in packed_proj.weight.grad 0.001556396484375
+#    Difference in out_proj.bias.grad 0.0
+#    Difference in packed_proj.bias.grad 0.001953125
+#
 
 ##################################################################################
 # GPT-style layer
@@ -462,13 +462,13 @@ def benchmark(func, *args, **kwargs):
 # classification of modifications to the transformer architecture, recall that we
 # classified the modifications into layer type, layer ordering, and modifications
 # to the attention score. We trust that changing layer type and layer ordering
-# (e.g. swapping``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
+# (e.g. swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
 #
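For example, the norm swap really is a drop-in change in a hypothetical layer, since ``nn.RMSNorm`` (available in recent PyTorch releases) follows the same construction and call pattern as ``nn.LayerNorm``; the dimensions below are arbitrary.

```python
import torch
import torch.nn as nn

embed_dim = 512  # illustrative size
x = torch.randn(2, 128, embed_dim)

norm = nn.LayerNorm(embed_dim)  # original layer type
norm = nn.RMSNorm(embed_dim)    # drop-in replacement, same usage
y = norm(x)
```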
 # In this section, we will discuss various functionalities using the
 # aforementioned building blocks. In particular,
 #
 # * Cross Attention
-# * Fully masked rows no longer cause ``NaN``s
+# * Fully masked rows no longer cause NaNs
 # * Modifying attention score: ALiBi with FlexAttention and NJT
 # * Packed Projection
 