@@ -255,69 +255,70 @@ def forward(self,
255255
256256
257257###############################################################################
258- # .. dropdown:: Utilities
259- # ========================
258+ # Utilities
259+ # =========
260260# In this section, we include a utility to generate semi-realistic data using
261261# Zipf distribution for sentence lengths. This is used to generate the nested
262262# query, key and value tensors. We also include a benchmark utility.
263263
264- import numpy as np
265-
def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    """Sample ``batch_size`` sentence lengths from a unigram Zipf model.

    Words are drawn by rank from a Zipf distribution; a sentence ends when
    a sentence-final token is drawn. From the wikitext-2 corpus those
    ranks are "." = 3, "!" = 386 and "?" = 858.
    """
    terminators = (3, 386, 858)
    lengths = np.empty(batch_size, dtype=int)
    for idx in range(batch_size):
        # Count words until a terminator rank is sampled; the first word
        # always exists, so start at one.
        n_words = 1
        while np.random.zipf(alpha) not in terminators:
            n_words += 1
        lengths[idx] = n_words
    return torch.tensor(lengths)
278+
# Build a batch of semi-realistic data (Zipf-distributed sentence lengths)
# as query/key/value nested tensors with the jagged layout.
def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
    # Draw per-item sequence lengths from a Zipf distribution (alpha picked
    # to resemble natural-language length statistics).
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: torch.jagged is a nested-tensor layout supporting a single
    # ragged dimension that works with torch.compile. Each batch item has
    # shape (S*, E) where S* is the per-item (ragged) sequence length.
    def _jagged(lengths, embed_dim):
        return torch.nested.nested_tensor(
            [torch.randn(n, embed_dim, dtype=dtype, device=device) for n in lengths],
            layout=torch.jagged,
        )

    if query_seq_len_1:
        # Decoder-style case: every query has sequence length one.
        query = _jagged([1] * N, E_q)
    else:
        query = _jagged([n.item() for n in sentence_lengths], E_q)

    key = _jagged([s.item() for s in sentence_lengths], E_k)
    value = _jagged([s.item() for s in sentence_lengths], E_v)

    return query, key, value, sentence_lengths
310+
311+ import timeit
312+ import math
313+
def benchmark(func, *args, **kwargs):
    """Time one CUDA call of ``func``; return (output, seconds, peak bytes).

    Synchronizes before starting and before stopping the clock so the
    wall-clock time covers all queued GPU work, and resets the peak
    memory statistics first so the reported maximum belongs to this
    call alone.
    """
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = timeit.default_timer()
    result = func(*args, **kwargs)
    torch.cuda.synchronize()
    elapsed = timeit.default_timer() - start
    return result, elapsed, torch.cuda.max_memory_allocated()
321322
322323##############################################################################
323324# We will now demonstrate the performance improvements of using nested tensors
0 commit comments