@@ -40,7 +40,7 @@
 Please head there instead!
 
 If you are only interested in performant attention score modifications, please
-head to the `FlexAttention blog <https://flexattention.com/blog/>`_ that
+head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
 contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.
 
 If you are wondering about what building blocks the ``torch`` library provides
@@ -63,7 +63,7 @@
 * `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
 
 ``scaled_dot_product_attention`` is a primitive for
-$\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V$ that dispatches into either fused
+:math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
 implementations of the operator or a fallback implementation. It works out of
 the box in eager mode (i.e. the default mode of using PyTorch where operations
 are executed on the fly as they are encountered) and also integrates seamlessly
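To make the primitive concrete, here is a minimal editorial sketch of calling ``scaled_dot_product_attention`` directly; the tensor shapes below are illustrative assumptions, not values taken from this diff::

    import torch
    import torch.nn.functional as F

    # Assumed illustrative shapes: (batch, num_heads, seq_len, head_dim).
    N, H, L, E = 2, 8, 128, 64
    query = torch.randn(N, H, L, E)
    key = torch.randn(N, H, L, E)
    value = torch.randn(N, H, L, E)

    # Computes softmax(Q K^T / sqrt(E) + bias) V, dispatching to a fused kernel
    # when the inputs and device support one, and to the unfused math
    # implementation otherwise.
    out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
    print(out.shape)  # torch.Size([2, 8, 128, 64])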
@@ -118,7 +118,7 @@
 # The improvements are threefold:
 #
 # * User Experience
-#   Recall that `nn.MultiheadAttention` requires ``query```, ``key`` and
+#   Recall that ``nn.MultiheadAttention`` requires ``query``, ``key`` and
 #   ``value`` to be dense ``torch.Tensor``s. It also provides a
 #   ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
 #   that arise due to different sequence lengths within a batch. Since there is
@@ -202,10 +202,10 @@ def forward(self,
         4. Apply output projection
 
         Args:
-            query (torch.Tensor): query of shape (N, L_q, E_qk)
-            key (torch.Tensor): key of shape (N, L_kv, E_qk)
-            value (torch.Tensor): value of shape (N, L_kv, E_v)
-            attn_mask (torch.Tensor, optional): attention mask of shape (N, L_q, L_kv) to pass to sdpa. Default: None
+            query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
+            key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
+            value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
+            attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
             is_causal (bool, optional): Whether to apply causal mask. Default: False
 
         Returns:
@@ -251,11 +251,10 @@ def forward(self,
 
         return attn_output
 
-# TODO: Check whether there is a way to collapse this section by default
-# sphinx.collapse?
+
 ###############################################################################
 # Utilities
-# =========
+# ========================
 # In this section, we include a utility to generate semi-realistic data using
 # Zipf distribution for sentence lengths. This is used to generate the nested
 # query, key and value tensors. We also include a benchmark utility.
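As an aside, here is a minimal editorial sketch of what Zipf-distributed sentence lengths look like; the function name, ``alpha`` value, and clamp are illustrative assumptions, and the tutorial's own ``gen_batch`` and benchmark utilities are the ones actually used in the code below::

    import numpy as np
    import torch

    def sample_sentence_lengths(alpha: float, batch_size: int, max_len: int = 512) -> torch.Tensor:
        # numpy.random.zipf draws integers >= 1 with P(k) proportional to k**(-alpha),
        # i.e. many short sentences and a long tail of long ones; clamping keeps a
        # single extreme draw from dominating the batch.
        lengths = np.random.zipf(alpha, size=batch_size)
        return torch.as_tensor(np.clip(lengths, 1, max_len), dtype=torch.int64)

    lengths = sample_sentence_lengths(alpha=1.3, batch_size=32)
    print(lengths.float().mean(), lengths.max())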
@@ -343,7 +342,7 @@ def benchmark(func, *args, **kwargs):
 torch.manual_seed(6)
 vanilla_mha_layer = nn.MultiheadAttention(E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device='cuda')
 
-# nn.MultiheadAttention uses a non conventional init for layers, so do this for exact parity :(
+# nn.MultiheadAttention uses a non-conventional initialization for layers, so do this for exact parity :(
 mha_layer.out_proj.weight = nn.Parameter(vanilla_mha_layer.out_proj.weight.clone().detach())
 mha_layer.packed_proj.weight = nn.Parameter(vanilla_mha_layer.in_proj_weight.clone().detach())
 mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
@@ -357,8 +356,8 @@ def benchmark(func, *args, **kwargs):
 nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, query, query, is_causal=True)
 padded_nested_result = nested_result.to_padded_tensor(0.0)
 
-# For the vanilla nn.MHA, we need to construct the key_padding_mask
-# Further, nn.MultiheadAttention forces one to materialize the attn_mask even if using is_causal
+# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
+# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
 src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
 attn_mask = torch.empty((N, S, S), device=device).fill_(float('-inf'))
 for i, s in enumerate(sentence_lengths):
@@ -431,14 +430,14 @@ def benchmark(func, *args, **kwargs):
 # classification of modifications to the transformer architecture, recall that we
 # classified the modifications into layer type, layer ordering, and modifications
 # to the attention score. We trust that changing layer type and layer ordering
-# (e.g. swapping LayerNorm for RMSNorm) is fairly straightforward.
+# (e.g. swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward.
 #
 # In this section, we will discuss various functionalities using the
 # aforementioned building blocks. In particular,
 #
 # * Packed Projection
 # * Cross Attention
-# * Fully masked rows no longer cause NaNs
+# * Fully masked rows no longer cause ``NaN``s
 # * [TODO] Modifying attention score: Relative Positional Embedding with NJT
 # * [TODO] KV-Caching with NJT
 # * [TODO] Grouped Query Attention with NJT
@@ -448,13 +447,13 @@ def benchmark(func, *args, **kwargs):
 # -----------------
 #
 # Packed projection is a technique that makes use of the fact that when the inputs
-# for projection (matmul) are the same (self-attention), we can pack the projection
+# for projection (matrix multiplications) are the same, as in self-attention, we can pack the projection
 # weights and biases into single tensors. It is especially useful when the individual
-# projections (matmuls) are memory bound rather than compute bound. There are
+# projections are memory bound rather than compute bound. There are
 # two examples that we will demonstrate here:
 #
 # * Input projection for MultiheadAttention
-# * SwiGLU activation in FFN of Transformer Layer
+# * SwiGLU activation in feed-forward network of Transformer Layer
 #
 # Input projection for MultiheadAttention
 # ----------------------------------------
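Before the tutorial's own module (its ``forward`` appears in the next hunk), here is a minimal standalone sketch of the packed-projection idea; the embedding size, batch shape, and variable names are illustrative assumptions::

    import torch
    import torch.nn as nn

    E = 512                      # embedding dimension (assumed for illustration)
    x = torch.randn(4, 100, E)   # (batch, seq_len, embed_dim)

    # Unpacked: three separate projections mean three smaller matmuls that each
    # re-read the same input from memory.
    q_proj, k_proj, v_proj = (nn.Linear(E, E, bias=False) for _ in range(3))
    q, k, v = q_proj(x), k_proj(x), v_proj(x)

    # Packed: since q, k, and v are projected from the same input in self-attention,
    # the three weight matrices can be stacked into one Linear and the result split
    # with a cheap chunk; one larger matmul reads x only once.
    packed_proj = nn.Linear(E, 3 * E, bias=False)
    q2, k2, v2 = packed_proj(x).chunk(3, dim=-1)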
@@ -505,7 +504,7 @@ def forward(self, query):
 # SwiGLU feed forward network of Transformer Layer
 # ------------------------------------------------
 # SwiGLU is a non-linear activation function that is increasingly popular in the feed-forward
-# network of the transformer layer (e.g. Llama). A FFN with SwiGLU activation is defined as
+# network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as
 
 class SwiGLUFFN(nn.Module):
     def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
@@ -601,45 +600,47 @@ def forward(self, x):
 
 
 ################################################################################
-# [PENDING] KV-Caching with NJT
-# ----------------------------
-# During decoding in inference, the query comprises of the current token. However,
-# the key and value comprises of all the previous keys and values in addition to
-# the current token.
-#
-# When we do batched inference, each batch item will be at a different stage of
-# decoding, so we expect the keys and values to have different sequence lengths.
-# The query is a dense tensor of shape ``[B, 1, E_qk]`` and the keys and values
-# will be of shapes ``[B, *, E_qk]`` and ``[B, *, E_v]`` where ``B`` represents
-# batch size, ``*`` represents varying sequence lengths and ``E_qk`` and ``E_v``
-# are embedding dimensions for query/key and value respectively.
-
-# Directly related to the above point is the idea of KV-Caching. This is a technique
-# that is used in inference to reduce the latency of decoding. The idea is to cache
-# the key and value tensors for the previous tokens and use them for the current
-# token. This is especially useful when the sequence length is long.
-
-# FIXME: Pending https://github.com/pytorch/pytorch/pull/135722
-
-
-################################################################################
-# [PENDING] Relative Positional Embedding with NJT (FlexAttention + NJT)
+# ALiBi with NJT (FlexAttention + NJT)
 # ---------------------------------------------------------------------
-#
-# FIXME: Pending https://github.com/pytorch/pytorch/pull/136792
+# NJT also composes with the ``FlexAttention`` module. This is a generalization
+# of the ``MultiheadAttention`` layer that allows for arbitrary modifications
+# to the attention score. The example below takes the ``alibi_mod`` from
+# attention gym and uses it with nested input tensors.
 
+from torch.nn.attention.flex_attention import flex_attention
+
+def generate_alibi_bias(H: int):
+    """Returns an alibi bias score_mod given the number of heads H
+    Args:
+        H: number of heads
+    Returns:
+        alibi_bias: alibi bias score_mod
+    """
+    def alibi_mod(score, b, h, q_idx, kv_idx):
+        scale = torch.exp2(-((h + 1) * 8.0 / H))
+        bias = (q_idx - kv_idx) * scale
+        return score + bias
+    return alibi_mod
+
+query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
+n_heads, D = 8, E_q // 8
+alibi_score_mod = generate_alibi_bias(n_heads)
+query = (
+    query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+)
+key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+value = (
+    value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+)
+out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
 
 ################################################################################
-# [PENDING] Grouped Query Attention with NJT
-# ------------------------------------------
-#
-# Grouped Query Attention refers to using a number of key/value heads that is
-# less than the number of query heads. Compared to MultiheadAttention, this
-# decreases the size of the kv-cache during inference.
-#
-# We can implement this using nested tensors as follows
+# And more
+# --------
 #
-# FIXME: Pending FlexAttention/testing for NJT with grouped query attention
+# We intend to update this tutorial to demonstrate more examples of how to use
+# the various performant building blocks such as KV-Caching, Grouped Query Attention,
+# etc.
 
 
 ################################################################################
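Until those updates land, here is a small hedged sketch of the grouped-query direction mentioned in the "And more" section above, written against the ``enable_gqa`` flag that recent PyTorch releases expose on ``scaled_dot_product_attention``; the head counts and shapes are illustrative assumptions::

    import torch
    import torch.nn.functional as F

    # Assumed illustrative sizes: 8 query heads share 2 key/value heads (4 per group).
    N, L, E_head = 2, 64, 32
    n_q_heads, n_kv_heads = 8, 2

    query = torch.randn(N, n_q_heads, L, E_head)
    key = torch.randn(N, n_kv_heads, L, E_head)
    value = torch.randn(N, n_kv_heads, L, E_head)

    # With enable_gqa=True, SDPA broadcasts each key/value head across its group of
    # query heads instead of requiring the caller to repeat them, which is what
    # shrinks the KV-cache relative to full multi-head attention.
    out = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
    print(out.shape)  # torch.Size([2, 8, 64, 32])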
@@ -649,7 +650,7 @@ def forward(self, x):
 # There are several good examples of using various performant building blocks to
 # implement various transformer architectures. Some examples include
 #
-# * `gpt_fast <https://github.com/pytorch-labs/gpt-fast>`_
-# * `sam_fast <https://github.com/pytorch-labs/sam-fast>`_
-# * `lucidrains implementation of ViT with nested tensors <https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/nested_tensor.py>`_
+# * `gpt-fast <https://github.com/pytorch-labs/gpt-fast>`_
+# * `segment-anything-fast <https://github.com/pytorch-labs/segment-anything-fast>`_
+# * `lucidrains implementation of NaViT with nested tensors <https://github.com/lucidrains/vit-pytorch/blob/73199ab486e0fad9eced2e3350a11681db08b61b/vit_pytorch/na_vit_nested_tensor.py>`_
 # * `torchtune's implementation of VisionTransformer <https://github.com/pytorch/torchtune/blob/a8a64ec6a99a6ea2be4fdaf0cd5797b03a2567cf/torchtune/modules/vision_transformer.py#L16>`_