=================================================================
**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_

.. note::
    This tutorial should be run with the latest PyTorch nightly or, when available, PyTorch 2.6.

The ``torch.nn`` module currently provides various ``Transformer``-related layers,
in particular ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``,
``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family
@@ -253,73 +256,72 @@ def forward(self,

        return attn_output


###############################################################################
# Utilities
# =========
# In this section, we include a utility to generate semi-realistic data using
# Zipf distribution for sentence lengths. This is used to generate the nested
# query, key and value tensors. We also include a benchmark utility.


import numpy as np

def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)

# Generate a batch of semi-realistic data using Zipf distribution for sentence lengths
# in the form of nested tensors with the jagged layout.
def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    if query_seq_len_1:
        query = torch.nested.nested_tensor([
            torch.randn(1, E_q, dtype=dtype, device=device)
            for l in sentence_lengths
        ], layout=torch.jagged)
    else:
        query = torch.nested.nested_tensor([
            torch.randn(l.item(), E_q, dtype=dtype, device=device)
            for l in sentence_lengths
        ], layout=torch.jagged)

    key = torch.nested.nested_tensor([
        torch.randn(s.item(), E_k, dtype=dtype, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    value = torch.nested.nested_tensor([
        torch.randn(s.item(), E_v, dtype=dtype, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    return query, key, value, sentence_lengths

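###############################################################################
# A quick, editorial aside (not part of the original utilities): we can build a
# tiny batch with made-up sizes and peek at the jagged layout. ``offsets()``
# gives the per-item boundaries into the packed buffer returned by ``values()``.
# A CUDA device is assumed here, matching the rest of the tutorial.

q_demo, k_demo, v_demo, len_demo = gen_batch(4, 64, 64, 64, device="cuda")
print(f"q_demo.is_nested={q_demo.is_nested}, q_demo.shape={q_demo.shape}")
print(f"offsets: {q_demo.offsets()}")
print(f"packed values shape: {q_demo.values().shape}")
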
import timeit
import math

def benchmark(func, *args, **kwargs):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin), torch.cuda.max_memory_allocated()

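##############################################################################
# A quick usage example of the helper above (an editorial addition, assuming a
# CUDA device is available, as in the rest of this tutorial): ``benchmark``
# returns the callable's output, the wall-clock time in seconds, and the peak
# CUDA memory allocated in bytes.

demo_mat = torch.randn(1024, 1024, device="cuda")
_, demo_secs, demo_peak = benchmark(torch.mm, demo_mat, demo_mat)
print(f"matmul: {demo_secs:.4f} s, peak memory: {demo_peak / 1e9:.2f} GB")
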
##############################################################################
# We will now demonstrate the performance improvements of using nested tensors
@@ -395,6 +397,16 @@ def benchmark(func, *args, **kwargs):
print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")

######################################################################################
# For reference, here are some sample outputs on an A100:
#
# ```
# padded_time=0.03454, padded_peak_memory=4.14 GB
# nested_time=0.00612, nested_peak_memory=0.76 GB
# Difference between vanilla and nested result 0.0
# Nested speedup: 5.65
# Nested peak memory reduction 3.39 GB
# ```
#
# We can also see the same improvements for the backward pass

for i, entry_length in enumerate(sentence_lengths):
@@ -414,6 +426,20 @@ def benchmark(func, *args, **kwargs):
print("Difference in out_proj.bias.grad", (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad).abs().max().item())
print("Difference in packed_proj.bias.grad", (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad).abs().max().item())

##################################################################################
# Sample outputs on A100:
#
# ```
# padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
# nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
# Nested backward speedup: 144.13
# Nested backward peak memory reduction 1.86 GB
# Difference in out_proj.weight.grad 0.000244140625
# Difference in packed_proj.weight.grad 0.001556396484375
# Difference in out_proj.bias.grad 0.0
# Difference in packed_proj.bias.grad 0.001953125
# ```

##################################################################################
# GPT-style layer
# ---------------
@@ -424,8 +450,9 @@ def benchmark(func, *args, **kwargs):
# ``is_causal=True``.
#
# We demonstrate examples of implementing the rest of the ``nn`` layers
# `here <https://github.com/mikaylagawarecki/transformer_tutorial_accompaniment>`_
# but omit that from this tutorial for brevity.

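###############################################################################
# As a rough, editorial sketch (not code from the tutorial or the accompanying
# repository): a GPT-style block can simply wrap the custom ``MultiheadAttention``
# defined above with pre-LayerNorm causal self-attention plus an MLP, both with
# residual connections. The only assumption about the wrapped attention module
# is that it accepts ``(query, key, value, is_causal=...)``, as used elsewhere
# in this tutorial; the layer sizes here are placeholders.

import torch.nn as nn

class GPTBlockSketch(nn.Module):
    def __init__(self, attn_module, E_total, dim_feedforward):
        super().__init__()
        self.attn = attn_module  # e.g. an instance of the MultiheadAttention above
        self.norm1 = nn.LayerNorm(E_total)
        self.norm2 = nn.LayerNorm(E_total)
        self.mlp = nn.Sequential(
            nn.Linear(E_total, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, E_total),
        )

    def forward(self, x):
        # causal self-attention: query, key and value all come from x
        h = self.norm1(x)
        x = x + self.attn(h, h, h, is_causal=True)
        x = x + self.mlp(self.norm2(x))
        return x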

###############################################################################
# Going one step further
@@ -440,10 +467,85 @@ def benchmark(func, *args, **kwargs):
# In this section, we will discuss various functionalities using the
# aforementioned building blocks. In particular,
#
# * Cross Attention
# * Fully masked rows no longer cause ``NaN``s
# * Modifying attention score: ALiBi with FlexAttention and NJT
# * Packed Projection

###############################################################################
# Cross Attention
# ---------------
# Cross attention is a form of attention where the query and key/value tensors
# are from different sequences.
#
# One example of this is in ``nn.TransformerDecoderLayer`` where the query comes
# from the decoder and the key/value come from the encoder.
#
# The above MultiheadAttention layer nicely generalizes to this case with nested
# tensors for both query and key/value.

query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)

print(f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}")
print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}")
out = new_mha_layer(query, key, value, is_causal=False)

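################################################################################
# A quick check (an editorial addition): the output is itself a jagged nested
# tensor whose raggedness matches the query, with one output row per query token.

print(f"out is nested: {out.is_nested}, out shape: {out.shape}")
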
################################################################################
# Fully masked rows no longer cause NaNs
# --------------------------------------
#
# There has been a long-standing issue with ``nn.MultiheadAttention`` and
# ``scaled_dot_product_attention`` where if a row was fully masked out, the output
# of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_.
# This is because the softmax over an empty set is undefined.
#
# Thanks to `this PR <https://github.com/pytorch/pytorch/pull/133882>`_
# this is no longer the case. Instead, the output corresponding to fully masked
# rows in ``scaled_dot_product_attention`` is now 0. For cases where ``nn.MHA``
# does not employ the "fast-path", this will also apply.
#
# Using a custom MHA layer with NJTs is strongly recommended over the
# existing "fast-path" in ``nn.MultiheadAttention``, as NJT's ability to model
# raggedness appropriately makes it possible to properly express empty sequences.

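################################################################################
# A small illustration (an editorial addition, assuming a recent enough PyTorch
# build where the fix above has landed): the second query row below attends to
# nothing, and its output is zeros rather than NaN.

import torch.nn.functional as F

q_sdpa = torch.randn(1, 1, 2, 8, device=device)
k_sdpa = torch.randn(1, 1, 3, 8, device=device)
v_sdpa = torch.randn(1, 1, 3, 8, device=device)
# boolean mask: True means "attend"; the second query row is fully masked out
mask_sdpa = torch.tensor([[True, True, True],
                          [False, False, False]], device=device)
out_sdpa = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask_sdpa)
print(out_sdpa[0, 0, 1])  # expected: all zeros, not NaN
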
################################################################################
# FlexAttention + NJT
# -------------------
# NJT also composes with the ``FlexAttention`` module. This is a generalization
# of the ``MultiheadAttention`` layer that allows for arbitrary modifications
# to the attention score. The example below takes the ``alibi_mod``
# that implements `ALiBi <https://arxiv.org/abs/2108.12409>`_ from
# `attention gym <https://github.com/pytorch-labs/attention-gym>`_ and uses it
# with nested input tensors.

from torch.nn.attention.flex_attention import flex_attention

def generate_alibi_bias(H: int):
    """Returns an alibi bias score_mod given the number of heads H
    Args:
        H: number of heads
    Returns:
        alibi_bias: alibi bias score_mod
    """
    def alibi_mod(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (q_idx - kv_idx) * scale
        return score + bias
    return alibi_mod

query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
n_heads, D = 8, E_q // 8
alibi_score_mod = generate_alibi_bias(n_heads)
query = (
    query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = (
    value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
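
################################################################################
# As a quick check (an editorial addition, not part of the original example):
# the per-head scales used by ``alibi_mod`` above follow the geometric sequence
# of slopes 2^(-8/H), 2^(-16/H), ... from the ALiBi paper, which we can print
# directly.

alibi_slopes = torch.exp2(-torch.arange(1, n_heads + 1) * 8.0 / n_heads)
print(alibi_slopes)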

###############################################################################
# Packed Projection
@@ -567,80 +669,6 @@ def forward(self, x):
_, time_packed, _ = benchmark(packed_swigluffn, q)
print(f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x")

################################################################################
# Extended examples
# -----------------