@@ -344,7 +344,7 @@ def benchmark(func, *args, **kwargs):
 torch.manual_seed(6)
 vanilla_mha_layer = nn.MultiheadAttention(E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device='cuda')

-# nn.MultiheadAttention uses a non conventional initialization for layers, so do this for exact parity :(
+# ``nn.MultiheadAttention`` uses a non-conventional initialization for layers, so do this for exact parity :(
 mha_layer.out_proj.weight = nn.Parameter(vanilla_mha_layer.out_proj.weight.clone().detach())
 mha_layer.packed_proj.weight = nn.Parameter(vanilla_mha_layer.in_proj_weight.clone().detach())
 mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
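As an aside (illustrative, not part of the diff): the reason ``in_proj_weight`` can be copied straight into ``packed_proj.weight`` is that ``nn.MultiheadAttention`` stores the query, key and value projection weights stacked into a single ``(3 * E_q, E_q)`` tensor. A minimal sketch of splitting it back apart, reusing the variables defined above:

    # in_proj_weight stacks [W_q; W_k; W_v] along dim 0
    w_q, w_k, w_v = torch.chunk(vanilla_mha_layer.in_proj_weight, 3, dim=0)
    assert w_q.shape == w_k.shape == w_v.shape == (E_q, E_q)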
@@ -421,7 +421,7 @@ def benchmark(func, *args, **kwargs):
 # gives equivalent results to an ``nn.TransformerEncoderLayer`` with
 # ``is_causal=True``.
 #
-# We demonstrate examples of implementing the rest of the nn layers
+# We demonstrate examples of implementing the rest of the ``nn`` layers
 # `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
 # tutorial for brevity.

@@ -457,7 +457,7 @@ def benchmark(func, *args, **kwargs):
 # * SwiGLU activation in feed-forward network of Transformer Layer
 #
 # Input projection for MultiheadAttention
-# ----------------------------------------
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # Recall that when doing self-attention, the ``query``, ``key`` and ``value``
 # are the same tensor. Each of these tensors is projected with a
 # ``Linear(E_q, E_total)`` layer. Instead, we can pack this into one layer,
@@ -502,8 +502,9 @@ def forward(self, query):
 (q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
 print(f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x")

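For readers skimming the diff, a minimal sketch of the packed-projection idea benchmarked above (an illustrative module, not the tutorial's actual ``PackedInputProjection``):

    import torch
    import torch.nn as nn

    class NaivePackedProjection(nn.Module):
        """Project query, key and value with one matmul instead of three."""
        def __init__(self, E_q, E_total, bias=False, device=None):
            super().__init__()
            self.packed_proj = nn.Linear(E_q, 3 * E_total, bias=bias, device=device)

        def forward(self, query):
            # Split the packed result back into the (q, k, v) projections
            return torch.chunk(self.packed_proj(query), 3, dim=-1)

Fusing the three projections into one larger GEMM is where the speedup printed above comes from.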
+##################################################
 # SwiGLU feed forward network of Transformer Layer
-# ------------------------------------------------
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # SwiGLU is a non-linear activation function that is increasingly popular in the feed-forward
 # network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as

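The definition itself is elided from this diff; it matches the ``forward`` shown in the next hunk, i.e. the conventional SwiGLU feed-forward formulation:

    # FFN_SwiGLU(x) = W2(SiLU(W1 x) * W3 x),  where SiLU(z) = z * sigmoid(z)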
@@ -524,6 +525,7 @@ def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))

+########################################################################
 # An alternative way of implementing this that uses packed projection is

 class PackedSwiGLUFFN(nn.Module):
@@ -543,6 +545,7 @@ def forward(self, x):
         x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
         return self.w2(F.silu(x1) * x3)

+################################################################################
 # We can compare the performance of the two implementations as follows.
 # Depending on your hardware, you might see different results. On an A100 I see
 # 1.12x speedup for D=128.
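The comparison itself is elided from this diff; a rough sketch of what it could look like, reusing the tutorial's ``benchmark`` helper (constructor arguments and input shape below are illustrative guesses, not the tutorial's exact configuration):

    D = 128
    swiglu = SwiGLUFFN(D, 4 * D, 256, device='cuda')
    packed_swiglu = PackedSwiGLUFFN(D, 4 * D, 256, device='cuda')
    x = torch.randn(64, 1024, D, device='cuda')
    _, time_ffn, _ = benchmark(swiglu, x)
    _, time_packed_ffn, _ = benchmark(packed_swiglu, x)
    print(f"SwiGLUFFN: {time_ffn:.5f} s, PackedSwiGLUFFN: {time_packed_ffn:.5f} s, "
          f"speedup: {time_ffn / time_packed_ffn:.2f}x")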
@@ -635,20 +638,14 @@ def alibi_mod(score, b, h, q_idx, kv_idx):
 )
 out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)

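As an aside (not part of this diff): ``score_mod`` callbacks passed to ``flex_attention`` take ``(score, batch, head, q_idx, kv_idx)`` and return a modified score, so other biases follow the same pattern as the ALiBi example. A hypothetical variant that penalizes attention to distant positions:

    # Illustrative score_mod; the 0.1 slope is arbitrary.
    def distance_penalty_mod(score, b, h, q_idx, kv_idx):
        return score - 0.1 * (q_idx - kv_idx).abs()

    out_penalized = flex_attention(query, key, value, score_mod=distance_penalty_mod)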
-################################################################################
-# And more
-# --------
-#
-# We intend to update this tutorial to demonstrate more examples of how to use
-# the various performant building blocks such as KV-Caching, Grouped Query Attention
-# etc.
-

 ################################################################################
 # Extended examples
 # -----------------
 #
-# There are several good examples of using various performant building blocks to
+# We intend to update this tutorial to demonstrate more examples of how to use
+# the various performant building blocks, such as KV-Caching and Grouped Query
+# Attention. In addition, there are several good examples of using these building blocks to
 # implement various transformer architectures. Some examples include
 #
 # * `gpt-fast <https://github.com/pytorch-labs/gpt-fast>`_