@@ -269,250 +269,3 @@ def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, ke
269269 if attn is None :
270270 return query
271271 return attn
272-
273-
def to_hpu(data: Optional[Union[torch.tensor, list]], dtype: Optional[torch.dtype] = None) -> torch.tensor:
    """Move `data` to the 'hpu' device with a non-blocking copy.

    Args:
        data: An existing tensor, something `torch.tensor` can convert
            (e.g. a list), or None.
        dtype: Used only when `data` is not already a tensor; it is forwarded
            to `torch.tensor` when building the intermediate CPU tensor.
            Tensors passed in keep their own dtype.

    Returns:
        The tensor on 'hpu', or None when `data` is None.
    """
    if data is None:
        return None
    if not torch.is_tensor(data):
        # Materialize on CPU first, then transfer like any other tensor.
        data = torch.tensor(data, dtype=dtype, device='cpu')
    return data.to('hpu', non_blocking=True)
282-
283-
def mask_to_bias(mask: torch.tensor, dtype: torch.dtype) -> torch.tensor:
    """Turn a boolean attention mask into an additive attention bias.

    Args:
        mask: Boolean tensor; True marks entries that must be masked out.
        dtype: dtype of the returned bias.

    Returns:
        A tensor shaped like `mask` holding -inf where `mask` is True and
        0 everywhere else.
    """
    bias = torch.zeros_like(mask, dtype=dtype)
    bias.masked_fill_(mask, -math.inf)
    return bias
287-
288-
def create_causal_bias(groups: torch.tensor, positions: torch.tensor, dtype: torch.dtype) -> torch.tensor:
    """Build a causal attention bias from per-token group ids and positions.

    Entry [i, j] of the result is -inf (masked) when token i and token j
    belong to different groups, or when positions[i] < positions[j]; otherwise
    it is 0.

    Args:
        groups: Group id of every token.
        positions: Position of every token.
        dtype: dtype of the returned bias.
    """
    different_group = groups.unsqueeze(-1) != groups.unsqueeze(0)
    later_position = positions.unsqueeze(-1) < positions.unsqueeze(0)
    return mask_to_bias(different_group | later_position, dtype)
295-
296-
def indices_and_offsets(counts: torch.tensor) -> tuple[torch.tensor, torch.tensor]:
    """Expand group sizes into a per-element group index and in-group offset.

    Example:
        counts = [1, 2, 3] -> group_indices = [0, 1, 1, 2, 2, 2]
                              group_offsets = [0, 0, 1, 0, 1, 2]

    Groups with a count of zero contribute no elements. An empty `counts`
    yields two empty tensors instead of raising.

    Args:
        counts: 1-D integer tensor with the size of each group.

    Returns:
        A `(group_indices, group_offsets)` pair, each with `counts.sum()`
        elements.
    """
    if counts.numel() == 0:
        # `cum_end[-1]` below cannot be evaluated on an empty tensor. Return
        # empty results with the same dtypes the regular path would produce.
        return torch.cumsum(counts, dim=0), counts.clone()

    cum_end = torch.cumsum(counts, dim=0, dtype=counts.dtype)
    cum_start = cum_end - counts
    # One extra trailing slot so that a boundary equal to the total element
    # count (trailing zero-sized groups) is still a valid scatter index.
    total = cum_end[-1] + 1
    indices = torch.zeros(total, dtype=counts.dtype, device=counts.device)
    # Put a 1 at the first element of every group but the first; the running
    # sum then gives each element's group index.
    indices.scatter_add_(0, cum_end[:-1].to(torch.int64), torch.ones_like(cum_end[:-1]))
    indices = torch.cumsum(indices, dim=0)
    offsets = torch.arange(total, dtype=counts.dtype, device=counts.device) - cum_start.index_select(0, indices)
    # Drop the extra trailing slot again.
    return indices[:-1], offsets[:-1]
308-
309-
def fetch_2d(table: torch.tensor, indices: torch.tensor, offsets: torch.tensor) -> torch.tensor:
    """Gather `table[indices[k], offsets[k]]` for every k from a 2-D table.

    Args:
        table: 2-D tensor to read from.
        indices: Row index of each element to fetch.
        offsets: Column index of each element to fetch.

    Returns:
        1-D tensor with one fetched value per (index, offset) pair.
    """
    assert table.dim() == 2, 'Only 2D tables are supported!'
    row_width = table.size(-1)
    # Row-major linear position of every requested element.
    flat_positions = indices * row_width + offsets
    flat_table = table.flatten()
    return flat_table.index_select(0, flat_positions)
315-
316-
def group_sum(groups: torch.tensor, values: torch.tensor):
    """Sum the values that share a group id and broadcast the sums back.

    Args:
        groups: Integer group id of every element.
        values: Value of every element, aligned with `groups`.

    Returns:
        A tensor aligned with `groups` in which every element holds the total
        of all `values` belonging to its group.
    """
    num_buckets = groups.amax().item() + 1
    totals = torch.zeros((num_buckets, ), dtype=values.dtype, device=values.device)
    # Accumulate per-group totals, then read each element's group total back.
    totals.scatter_add_(0, groups.to(torch.int64), values)
    return totals.index_select(0, groups)
323-
324-
def generate_bias(block_usages: torch.tensor, block_size: torch.tensor, dtype: torch.dtype) -> torch.tensor:
    """Build a per-block attention bias that masks the unused slots of a block.

    For every block the first `block_usages[i]` slots get a bias of 0 and the
    remaining slots up to `block_size` get -inf.

    Args:
        block_usages: Number of used slots in each block.
        block_size: Number of slots per block. Annotated as a tensor, but the
            callers visible in this file pass a plain int.
        dtype: dtype of the returned bias.

    Returns:
        Bias with one row per block and `block_size` columns.
    """
    slot_numbers = torch.arange(1, block_size + 1, dtype=block_usages.dtype, device=block_usages.device)
    # Slot k (1-based) is unused when k exceeds the block's usage.
    unused_slots = slot_numbers.unsqueeze(0) > block_usages.unsqueeze(-1)
    return mask_to_bias(unused_slots, dtype=dtype)
330-
331-
@dataclass
class UnifiedBatch:
    """Result of `create_unified_batch`: device inputs plus CPU-side bookkeeping."""
    # Request ids, passed through unchanged from the caller.
    req_ids_cpu: list[str]
    # Token id of every scheduled token; padded with -1 and moved to HPU.
    token_ids: torch.tensor
    # Position of every scheduled token; padded with -1 and moved to HPU.
    token_positions: torch.tensor
    # Total token count (computed + scheduled) of the owning request for every
    # logits entry; not moved to HPU by create_unified_batch.
    new_token_positions_cpu: torch.tensor
    # Indices of the query tokens that produce logits; padded with -1, on HPU.
    logits_indices: torch.tensor
    # Request (group) index of every logits entry; not moved to HPU by
    # create_unified_batch.
    logits_groups_cpu: torch.tensor
    # Attention metadata (slot mapping, biases, block lists).
    attn_metadata: HPUUnifiedAttentionMetadata
341-
342-
@dataclass
class Context:
    """Per-block description of the past context, for shared or unique blocks.

    All four tensors are aligned: entry k describes one context block.
    """
    # Index of the group each block entry was expanded from.
    group_ids: torch.tensor
    # Index of the block within its group.
    group_offsets: torch.tensor
    # Block id fetched from the block table for (group id, group offset).
    block_ids: torch.tensor
    # Number of used slots in the block, clamped to [1, block_size].
    block_usages: torch.tensor

    @staticmethod
    def create(total_tokens: torch.tensor, block_table: torch.tensor, block_size: int) -> 'Context':
        """Build a Context covering `total_tokens` tokens per group.

        Returns None when no group needs any context block.
        """
        blocks_per_group = (total_tokens + block_size - 1) // block_size
        if blocks_per_group.sum() <= 0:
            return None

        group_ids, group_offsets = indices_and_offsets(blocks_per_group)
        block_ids = fetch_2d(block_table, group_ids, group_offsets)
        # NOTE(kzawora): Originally, we were clamping
        # total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1
        # I'm not sure why +1 was there originally, but in non-block-aligned prefix-prefill scenarios
        # it made causal mask not cover the first unused token.
        # (e.g. with context 28, the 28th slot was unmasked, causing the effective context length to be 29)
        tokens_from_block_start = total_tokens.index_select(0, group_ids) - group_offsets * block_size
        block_usages = torch.clamp(tokens_from_block_start, 1, block_size)

        ctx = Context(group_ids, group_offsets, block_ids, block_usages)
        # Sanity check: every tensor field must have the same shape.
        shapes = [value.shape for value in ctx._values() if torch.is_tensor(value)]
        assert all(shape == shapes[0] for shape in shapes[1:])
        return ctx

    def _values(self) -> tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
        """Return the fields as a tuple, in declaration order."""
        return (self.group_ids, self.group_offsets, self.block_ids, self.block_usages)

    def index_select(self, indices: torch.tensor) -> 'Context':
        """Return a new Context holding only the entries at `indices`.

        Returns None when `indices` is empty.
        """
        if indices.size(0) <= 0:
            return None
        return Context(*(value.index_select(0, indices) for value in self._values()))

    def split(self, num_scheduled_tokens: torch.tensor) -> tuple['Context', 'Context']:
        """Split into a (shared, unique) pair of Contexts.

        A block entry is shared when the scheduled tokens of all entries that
        reference its block id sum to more than one, and unique when that sum
        is exactly one. Either element of the pair is None when it would be
        empty.
        """
        tokens_per_entry = num_scheduled_tokens.index_select(0, self.group_ids)
        tokens_per_block = group_sum(self.block_ids, tokens_per_entry)
        shared_idx = torch.argwhere(tokens_per_block > 1).flatten()
        unique_idx = torch.argwhere(tokens_per_block == 1).flatten()
        # Every entry must land in exactly one of the two partitions.
        assert shared_idx.size(0) + unique_idx.size(0) == self.group_ids.size(0)
        return self.index_select(shared_idx), self.index_select(unique_idx)
392-
393-
def hpu_tensor(tensor: Optional[torch.tensor], shape: tuple, pad_value: Union[int, float]) -> Optional[torch.tensor]:
    """Pad `tensor` up to `shape` if necessary and move it to HPU.

    Args:
        tensor: Tensor to pad and transfer, or None.
        shape: Target shape; it must have the same rank as `tensor` and be at
            least as large in every dimension.
        pad_value: Fill value for the padded region.

    Returns:
        The padded tensor after `to_hpu`, or None when `tensor` is None.

    Raises:
        AssertionError: If the ranks differ or `tensor` exceeds `shape` in
            any dimension.
    """
    if tensor is None:
        return None
    assert len(tensor.shape) == len(shape)
    # torch.nn.functional.pad takes (before, after) pairs starting from the
    # LAST dimension, hence the reversed traversal. Padding is appended only.
    padding = tuple(
        itertools.chain.from_iterable((0, target - cur) for cur, target in zip(reversed(tensor.shape), reversed(shape))))
    assert all(p >= 0 for p in padding)
    if sum(padding) > 0:
        tensor = torch.nn.functional.pad(tensor, padding, value=pad_value)
    return to_hpu(tensor)
405-
406-
def create_unified_batch(req_ids: list[str], all_token_ids: torch.tensor, num_computed_tokens: torch.tensor,
                         num_scheduled_tokens: torch.tensor, num_prompt_tokens: torch.tensor, block_table: torch.tensor,
                         block_size: int, dtype: torch.dtype, bucketing_fn: Callable[[bool, int, int, int, int],
                                                                                     tuple[int, int, int, int]],
                         get_dp_padding_fn: Callable[[int], int]) -> UnifiedBatch:
    """Compute every tensor needed to schedule one unified batch.

    Args:
        req_ids: Request ids; stored unchanged in the result.
        all_token_ids: 2-D table of token ids, read at (request, position).
        num_computed_tokens: Per-request number of already computed tokens.
        num_scheduled_tokens: Per-request number of tokens scheduled now.
        num_prompt_tokens: Per-request prompt length.
        block_table: 2-D table of block ids, read at (request, block index).
        block_size: Number of slots per block.
        dtype: dtype of the generated attention biases.
        bucketing_fn: Called with (contains_prompts, query tokens, shared
            blocks, unique blocks, logits entries); returns the four target
            sizes in the same order.
        get_dp_padding_fn: Called once per target size; its result is added
            to that size.

    Returns:
        A UnifiedBatch whose device tensors are padded to the target sizes.
    """
    # ---- per-request bookkeeping -------------------------------------------
    total_tokens = num_computed_tokens + num_scheduled_tokens
    query_len = num_scheduled_tokens.sum().item()
    # A request counts as a prompt while it has not gone past its prompt length.
    is_prompt = total_tokens <= num_prompt_tokens
    cached_tokens = num_computed_tokens + torch.where(is_prompt, 0, num_scheduled_tokens)
    contains_prompts = torch.any(is_prompt).item()
    num_output_tokens = torch.clamp(total_tokens - num_prompt_tokens + 1, torch.zeros_like(num_scheduled_tokens),
                                    num_scheduled_tokens)
    # Index of each request's first token in the flattened query.
    group_starts = torch.cumsum(num_scheduled_tokens, dim=0) - num_scheduled_tokens

    # ---- per-token ids, positions and cache slots --------------------------
    token_groups, token_offsets = indices_and_offsets(num_scheduled_tokens)
    token_positions = token_offsets + num_computed_tokens.index_select(0, token_groups)
    token_ids = fetch_2d(all_token_ids, token_groups, token_positions)

    token_blocks = fetch_2d(block_table, token_groups, token_positions.floor_divide(block_size))
    token_slots = token_blocks * block_size + token_positions.fmod(block_size)

    # ---- logits selection: the last num_output_tokens tokens per request ---
    logits_groups, logits_offsets = indices_and_offsets(num_output_tokens)
    scheduled_ends = torch.cumsum(num_scheduled_tokens, dim=0, dtype=num_scheduled_tokens.dtype)
    start_logits_indices = scheduled_ends - num_output_tokens
    logits_indices = logits_offsets + start_logits_indices.index_select(0, logits_groups)
    new_token_positions = total_tokens.index_select(0, logits_groups)

    def first_dim(t: Optional[torch.tensor]) -> int:
        """Size of the first dimension, or 0 for None."""
        return 0 if t is None else t.size(0)

    # ---- attention biases --------------------------------------------------
    causal_bias = None
    shared_blocks = None
    shared_bias = None
    unique_blocks = 0
    unique_block_mapping = None
    unique_bias = None

    if contains_prompts:
        causal_bias = create_causal_bias(token_groups, token_positions, dtype)

    ctx = Context.create(cached_tokens, block_table, block_size)
    if ctx is not None:
        shared_ctx, unique_ctx = ctx.split(num_scheduled_tokens)

        if shared_ctx is not None:
            # Deduplicate block ids; inverse_idx maps each context entry to
            # its position in shared_blocks.
            shared_blocks, inverse_idx = torch.unique(shared_ctx.block_ids, return_inverse=True)

            entry_group_starts = group_starts.index_select(0, shared_ctx.group_ids)
            entry_tokens = num_scheduled_tokens.index_select(0, shared_ctx.group_ids)
            # Expand every context entry once per query token of its request.
            entry_of_pair, token_in_entry = indices_and_offsets(entry_tokens)

            pair_token_idx = entry_group_starts.index_select(0, entry_of_pair) + token_in_entry
            pair_block_idx = inverse_idx.index_select(0, entry_of_pair)
            pair_block_usage = shared_ctx.block_usages.index_select(0, entry_of_pair)
            pair_bias = generate_bias(pair_block_usage, block_size, dtype)

            # Everything starts masked; only (token, block) pairs that belong
            # together receive their block bias.
            shared_bias = torch.full((query_len, shared_blocks.size(0), block_size),
                                     -math.inf,
                                     dtype=dtype,
                                     device=shared_blocks.device)
            shared_bias.index_put_((pair_token_idx, pair_block_idx), pair_bias)

        if unique_ctx is not None:
            unique_block_ids = unique_ctx.block_ids.to(torch.int64)
            bias_device = unique_ctx.block_ids.device
            # Tensors below are indexed directly by block id, so they span
            # [0, max block id].
            unique_blocks = torch.amax(unique_ctx.block_ids).item() + 1
            unique_bias = torch.full((unique_blocks, block_size), -math.inf, dtype=dtype, device=bias_device)
            per_block_bias = generate_bias(unique_ctx.block_usages, block_size, dtype)
            unique_bias.index_copy_(0, unique_block_ids, per_block_bias)
            owner_starts = group_starts.index_select(0, unique_ctx.group_ids)
            # -1 marks block ids that are not present in the unique context.
            unique_block_mapping = torch.full((unique_blocks, ), -1, dtype=torch.int64, device=bias_device)
            unique_block_mapping.index_copy_(0, unique_block_ids, owner_starts)

    # ---- bucketing and padding targets -------------------------------------
    target_qlen, target_shared_blocks, target_unique_blocks, target_logits = bucketing_fn(
        contains_prompts, first_dim(token_ids), first_dim(shared_blocks), unique_blocks, first_dim(logits_indices))

    # The padding callback is invoked in the same fixed order as the targets.
    target_qlen, target_shared_blocks, target_unique_blocks, target_logits = (
        size + get_dp_padding_fn(size)
        for size in (target_qlen, target_shared_blocks, target_unique_blocks, target_logits))

    default_causal_width = 512
    dtype_info = torch.finfo(dtype)
    fmin = dtype_info.min
    feps = dtype_info.tiny

    # ---- pad, transfer and assemble ----------------------------------------
    token_ids_hpu = hpu_tensor(token_ids, (target_qlen, ), -1)
    token_positions_hpu = hpu_tensor(token_positions, (target_qlen, ), -1)
    logits_indices_hpu = hpu_tensor(logits_indices, (target_logits, ), -1)

    attn_metadata = HPUUnifiedAttentionMetadata(
        block_size=block_size,
        slot_mapping=hpu_tensor(token_slots, (target_qlen, ), -1),
        causal_bias=hpu_tensor(causal_bias, (target_qlen, target_qlen), -math.inf),
        causal_width=default_causal_width,
        shared_blocks=hpu_tensor(shared_blocks, (target_shared_blocks, ), -1),
        shared_bias=hpu_tensor(shared_bias, (target_qlen, target_shared_blocks, block_size), -math.inf),
        unique_blocks=target_unique_blocks,
        unique_block_mapping=hpu_tensor(unique_block_mapping, (target_unique_blocks, ), -1),
        unique_bias=hpu_tensor(unique_bias, (target_unique_blocks, block_size), -math.inf),
        fmin=to_hpu(fmin),
        feps=to_hpu(feps),
    )

    return UnifiedBatch(req_ids_cpu=req_ids,
                        token_ids=token_ids_hpu,
                        token_positions=token_positions_hpu,
                        new_token_positions_cpu=new_token_positions,
                        logits_indices=logits_indices_hpu,
                        logits_groups_cpu=logits_groups,
                        attn_metadata=attn_metadata)
0 commit comments