11"""
Dismantling the ``nn.Transformer`` modules for gains and profits
=================================================================
**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_

The ``torch.nn`` module currently provides various ``Transformer``-related layers.
1. People want to add slight customizations to their transformer layers
2. Writing these layers and customizations is not hard

Supporting all transformer variants via a small number of out-of-the-box layers would
yield too many keyword arguments. This tutorial will describe how to build your
own performant transformer layers following our recommended best practices.

If you are only interested in performant attention score modifications, please
head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.

If you are wondering about what building blocks the ``torch`` library provides
for writing your own transformer layers and best practices, you are in the
right place!
sequence lengths. They eliminate the need for the bug-prone practices of explicit
padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).

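For example, a batch of three sequences with different lengths can be represented
directly, without padding (a minimal sketch; shapes and values are illustrative)::

    import torch

    # three "sentences" of lengths 2, 5 and 3, each with embedding dimension 8
    seqs = [torch.randn(l, 8) for l in (2, 5, 3)]

    # jagged-layout nested tensor; no dense [3, 5, 8] padded buffer is materialized
    nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
    print(nt.is_nested)  # True

Many ``torch`` operations, including ``scaled_dot_product_attention``, can consume
this jagged layout directly.
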
* `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_

``scaled_dot_product_attention`` is a primitive for
:math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
implementations of the operator or a fallback implementation.
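A minimal call looks like the following (a self-contained sketch with illustrative
shapes)::

    import torch
    import torch.nn.functional as F

    B, H, S, E = 2, 4, 16, 32            # batch, heads, sequence length, head dim
    q, k, v = (torch.randn(B, H, S, E) for _ in range(3))

    # computes softmax(q @ k.transpose(-2, -1) / sqrt(E)) @ v, dispatching to a
    # fused kernel when one is available
    out = F.scaled_dot_product_attention(q, k, v)   # (B, H, S, E)

The same call also accepts the nested tensors described above, in which case each
sequence attends only within its own length.
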
2. Layer ordering (where to apply norms, where to apply positional encoding, etc.)
3. Modifications to attention score (``ALiBi``, Relative Positional Bias, etc.)

In a pre-compiler world, one might write their custom transformer and observe
that it works but is slow. Then, one might write a custom fused kernel for
the specific series of ops. In a compiler world, one can do the former, compile,
and expect improved performance.
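For example (a sketch; ``Block`` below is a stand-in for your own custom layer,
not an API provided by ``torch``)::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Block(nn.Module):
        def __init__(self, dim, nheads):
            super().__init__()
            self.nheads = nheads
            self.qkv = nn.Linear(dim, 3 * dim)
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            B, S, D = x.shape
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            # (B, S, D) -> (B, nheads, S, D // nheads)
            q, k, v = (t.view(B, S, self.nheads, -1).transpose(1, 2) for t in (q, k, v))
            out = F.scaled_dot_product_attention(q, k, v)
            return self.proj(out.transpose(1, 2).reshape(B, S, D))

    block = torch.compile(Block(dim=64, nheads=4))
    y = block(torch.randn(2, 10, 64))

``torch.compile`` traces the whole ``forward`` and can fuse the surrounding ops into
generated kernels, so a hand-written fused kernel for this specific sequence of ops
is typically unnecessary.
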
# The improvements are threefold:
#
# * User Experience
#   Recall that ``nn.MultiheadAttention`` requires ``query``, ``key`` and
#   ``value`` to be dense ``torch.Tensors``. It also provides a
#   ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
#   that arise due to different sequence lengths within a batch. Since there is
#   no ``query_padding_mask`` in ``nn.MHA``, users have to take care to mask/slice
#   the outputs appropriately to account for query sequence lengths. Nested tensors
#   cleanly remove the need for this sort of error-prone padding mask (a short
#   sketch of the difference follows this list).
#
# * Memory
#   Instead of materializing a dense ``[B, S, D]`` tensor with a ``[B, S]``
#   padding mask (where ``B`` is batch size, ``S`` is max sequence length in the
#   batch and ``D`` is embedding size), nested tensors allow you to cleanly
#   represent the batch of varying sequence lengths. As a result, the input and
#   intermediate activations will use less memory.
#
# * Performance
#   Since padding is not materialized and unnecessary computation on padding is
#   skipped, performance and memory usage improve.
#
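# As a rough sketch of the difference (illustrative shapes; this is not the
# benchmark code used below)::
#
#     import torch
#     import torch.nn as nn
#
#     B, S, D = 2, 5, 8
#     lengths = [3, 5]
#
#     # dense path: pad every sequence to S and flag the padding keys
#     padded = torch.zeros(B, S, D)
#     key_padding_mask = torch.ones(B, S, dtype=torch.bool)   # True = ignore
#     for i, l in enumerate(lengths):
#         padded[i, :l] = torch.randn(l, D)
#         key_padding_mask[i, :l] = False
#     mha = nn.MultiheadAttention(D, num_heads=2, batch_first=True)
#     out, _ = mha(padded, padded, padded, key_padding_mask=key_padding_mask)
#     # ...and the caller must still remember to ignore out[i, lengths[i]:, :]
#
#     # nested-tensor path: the ragged batch is represented as-is, no masks needed
#     nt = torch.nested.nested_tensor(
#         [torch.randn(l, D) for l in lengths], layout=torch.jagged
#     )
#
# The custom ``MultiheadAttention`` layer built below consumes such a nested
# tensor directly.
#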
# We'll demonstrate the above by building off the ``MultiheadAttention`` layer in the
# `Nested Tensor tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_

##############################################################################
# We will now demonstrate the performance improvements of using nested tensors
# in the ``MultiheadAttention`` layer + compile for self-attention. We compare this
# against the traditional ``nn.MultiheadAttention`` + compile with padding and masking.

# batch size (N) and embedding dimensions used for the benchmark
N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512
E_out = E_q
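
# One way to construct the variable-length batch used in this benchmark
# (a sketch; the tutorial's actual data-generation code is omitted from this excerpt):
import random

import torch

torch.manual_seed(6)
random.seed(6)

# one random sentence length per batch element
sentence_lengths = [random.randint(2, 256) for _ in range(N)]

# jagged-layout nested tensor of queries: one (length, E_q) entry per sentence
query = torch.nested.nested_tensor(
    [torch.randn(l, E_q) for l in sentence_lengths], layout=torch.jagged
)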

######################################################################################
# We can also see the same benefit for the backward pass.

for i, entry_length in enumerate(sentence_lengths):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

_, padded_bw_time, padded_bw_peak_mem = benchmark(lambda: padded_result.sum().backward())
# this is fairly straightforward using the ``MultiheadAttention`` layer above and
# gives equivalent results to an ``nn.TransformerEncoderLayer`` with
# ``is_causal=True``.
#
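# Under the hood, causal masking can be expressed with the ``is_causal`` flag of
# ``scaled_dot_product_attention`` (a minimal, self-contained sketch)::
#
#     import torch
#     import torch.nn.functional as F
#
#     q = k = v = torch.randn(2, 8, 32, 64)   # (batch, heads, seq_len, head_dim)
#
#     # each position attends only to itself and to earlier positions
#     out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
#
# In a layer like the one above, this typically amounts to forwarding an
# ``is_causal`` flag down to this call.
#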
# We demonstrate examples of implementing the rest of the ``nn`` layers
# `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
# tutorial for brevity.
# * Packed Projection
# * Cross Attention
# * Fully masked rows no longer cause ``NaN``s
# * Modifying attention score: ALiBi with FlexAttention and NJT

###############################################################################
# Packed Projection
# -----------------
# Cross Attention
# ---------------
# Cross attention is a form of attention where the query and key/value tensors
# are from different sequences.
#
# One example of this is in ``nn.TransformerDecoderLayer`` where the query comes
# from the decoder and the key/value come from the encoder.
#
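# At the ``scaled_dot_product_attention`` level, cross attention simply means the
# query and the key/value have different sequence lengths (an illustrative sketch)::
#
#     import torch
#     import torch.nn.functional as F
#
#     B, H, E = 2, 4, 64
#     decoder_q = torch.randn(B, H, 7, E)     # 7 target (decoder) positions
#     encoder_kv = torch.randn(B, H, 12, E)   # 12 source (encoder) positions
#
#     # output keeps the query's sequence length: (B, H, 7, E)
#     out = F.scaled_dot_product_attention(decoder_q, encoder_kv, encoder_kv)
#
# With nested tensors, the decoder and encoder sides can each carry their own
# ragged sequence lengths.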