Skip to content

Commit be9503d

Browse files
Unified Attention - batch preparation rewrite (#400)
This PR introduces a persistent batch optimization for unified attention, as well as a full rewrite of batch preparation from PyTorch to NumPy. Effectively, it reduces the batch preparation time by 6-15x across the board. --------- Signed-off-by: Konrad Zawora <kzawora@habana.ai> Co-authored-by: Artur Fierka <artur.fierka@intel.com>
1 parent 724f8c1 commit be9503d

File tree

4 files changed

+353
-250
lines changed

4 files changed

+353
-250
lines changed

vllm_gaudi/extension/features.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,5 +93,9 @@ def get_features():
9393
Value('unified_attn', False),
9494
Value('scale_adjustment', True, env_var='VLLM_SCALE_ADJUSTMENT', env_var_type=boolean),
9595
Value('flatten_input', Any(ModelType('qwen3_moe'), ModelType('granitemoe'), ModelType('glm4_moe'))),
96+
Value('unified_attn_shared_cache_ratio',
97+
1.,
98+
env_var='VLLM_UNIFIED_ATTENTION_SHARED_CACHE_RATIO',
99+
env_var_type=float),
96100
]
97101
return split_values_and_flags(features)

vllm_gaudi/extension/unified.py

Lines changed: 0 additions & 247 deletions
Original file line numberDiff line numberDiff line change
@@ -269,250 +269,3 @@ def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, ke
269269
if attn is None:
270270
return query
271271
return attn
272-
273-
274-
def to_hpu(data: Optional[Union[torch.tensor, list]], dtype: Optional[torch.dtype] = None) -> torch.tensor:
    """Return ``data`` as a tensor on the 'hpu' device.

    ``None`` is passed through unchanged. A tensor is transferred with a
    non-blocking copy. Any other value is first turned into a CPU tensor
    (using ``dtype`` when given) and then transferred the same way.
    """
    if data is None:
        return None
    if not torch.is_tensor(data):
        # ``dtype`` only affects this conversion; an existing tensor keeps its own dtype.
        data = torch.tensor(data, dtype=dtype, device='cpu')
    return data.to('hpu', non_blocking=True)
282-
283-
284-
def mask_to_bias(mask: torch.tensor, dtype: torch.dtype) -> torch.tensor:
    """Turn an attention mask into an additive attention bias.

    The result has the shape of ``mask`` and the requested ``dtype``: positions
    where ``mask`` is set become ``-inf``, all other positions are ``0``.
    """
    bias = torch.zeros_like(mask, dtype=dtype)
    bias.masked_fill_(mask, -math.inf)
    return bias
287-
288-
289-
def create_causal_bias(groups: torch.tensor, positions: torch.tensor, dtype: torch.dtype) -> torch.tensor:
    """Build a causal attention bias from per-token group ids and positions.

    Entry ``(i, j)`` is masked (``-inf``) when tokens ``i`` and ``j`` belong to
    different groups, or when ``positions[i] < positions[j]``; every other
    entry is ``0``.
    """
    different_group = groups[..., None] != groups[None]
    later_position = positions[..., None] < positions[None]
    return mask_to_bias(different_group | later_position, dtype)
295-
296-
297-
def indices_and_offsets(counts: torch.tensor) -> tuple[torch.tensor, torch.tensor]:
    """Expand group sizes into a per-element group index and in-group offset.

    Example:
        counts([1, 2, 3]) -> group_indices=[0, 1, 1, 2, 2, 2] group_offsets=[0, 0, 1, 0, 1, 2]
    """
    ends = torch.cumsum(counts, dim=0, dtype=counts.dtype)
    starts = ends - counts
    # One extra trailing slot keeps a boundary that equals the grand total
    # (e.g. a zero-sized last group) inside the buffer; it is sliced off below.
    length = ends[-1] + 1
    boundaries = ends[:-1]

    # Mark the first element of every group but the first, then accumulate the
    # marks so each position carries the index of the group it belongs to.
    markers = torch.zeros(length, dtype=counts.dtype, device=counts.device)
    markers.scatter_add_(0, boundaries.to(torch.int64), torch.ones_like(boundaries))
    group_indices = torch.cumsum(markers, dim=0)

    flat_positions = torch.arange(length, dtype=counts.dtype, device=counts.device)
    group_offsets = flat_positions - starts.index_select(0, group_indices)
    return group_indices[:-1], group_offsets[:-1]
308-
309-
310-
def fetch_2d(table: torch.tensor, indices: torch.tensor, offsets: torch.tensor) -> torch.tensor:
    """Gather ``table[indices[k], offsets[k]]`` for every ``k`` from a 2D table.

    The lookup is done on the flattened table, using
    ``indices * row_width + offsets`` as the linear index.
    """
    assert table.dim() == 2, 'Only 2D tables are supported!'
    row_width = table.size(-1)
    linear = indices * row_width + offsets
    return table.flatten().index_select(0, linear)
315-
316-
317-
def group_sum(groups: torch.tensor, values: torch.tensor):
    """Replace each value by the sum of all values that share its group id.

    The result has one entry per input element: element ``k`` holds the total
    of ``values`` over every position whose group equals ``groups[k]``.
    """
    # Reading the largest group id back to the host sizes the accumulator.
    num_buckets = groups.amax().item() + 1
    totals = torch.zeros((num_buckets, ), dtype=values.dtype, device=values.device)
    totals.scatter_add_(0, groups.to(torch.int64), values)
    return totals.index_select(0, groups)
323-
324-
325-
def generate_bias(block_usages: torch.tensor, block_size: torch.tensor, dtype: torch.dtype) -> torch.tensor:
    """Build a per-block bias that hides the unused tail of every block.

    The result has one trailing axis of length ``block_size``. Slot ``k``
    (1-based) of a block is masked with ``-inf`` when ``k`` exceeds that block's
    usage, and is ``0`` otherwise.

    NOTE(review): ``block_size`` is annotated as a tensor, but
    ``create_unified_batch`` passes a plain ``int``.
    """
    slot_numbers = torch.arange(1, block_size + 1, dtype=block_usages.dtype, device=block_usages.device)
    unused = slot_numbers[None] > block_usages[..., None]
    return mask_to_bias(unused, dtype=dtype)
330-
331-
332-
@dataclass
class UnifiedBatch:
    """Bundle of the inputs prepared for one unified-attention step.

    Field order is part of the constructor signature and must not change.
    Fields with a ``_cpu`` suffix are stored by ``create_unified_batch``
    exactly as computed; the other tensor fields go through ``hpu_tensor``
    (padding plus device transfer) first.
    """
    # Request ids, as handed to create_unified_batch.
    req_ids_cpu: list[str]
    # Ids of the scheduled tokens, padded with -1 to the bucketed query length.
    token_ids: torch.tensor
    # Position of every scheduled token, padded with -1 in the same way.
    token_positions: torch.tensor
    # Total token count (computed + scheduled) of the owning request, one entry per logits row.
    new_token_positions_cpu: torch.tensor
    # Indices into the flattened scheduled tokens selecting the rows that yield logits.
    logits_indices: torch.tensor
    # Index of the owning request for every logits row.
    logits_groups_cpu: torch.tensor
    # Attention metadata (slot mapping, biases, shared/unique block info).
    attn_metadata: HPUUnifiedAttentionMetadata
341-
342-
343-
@dataclass
class Context:
    """Past-context description, one entry per context block.

    The four tensors are parallel: entry ``k`` of each one describes the same
    block. The same structure is used for the shared-block and the
    unique-block part of the context (see ``split``).
    """
    # Index of the group (row of the block table) the block belongs to.
    group_ids: torch.tensor
    # Position of the block inside its group (column of the block table).
    group_offsets: torch.tensor
    # Block id read from the block table at (group_ids, group_offsets).
    block_ids: torch.tensor
    # Number of used slots in the block, clamped to [1, block_size].
    block_usages: torch.tensor

    @staticmethod
    def create(total_tokens: torch.tensor, block_table: torch.tensor, block_size: int) -> 'Context':
        """Build a Context covering ``total_tokens`` tokens per group.

        Returns ``None`` when no group needs any context block.
        """
        blocks_per_group = (total_tokens + block_size - 1) // block_size
        if blocks_per_group.sum() <= 0:
            return None

        group_ids, group_offsets = indices_and_offsets(blocks_per_group)
        block_ids = fetch_2d(block_table, group_ids, group_offsets)
        # NOTE(kzawora): Originally, we were clamping
        # total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1
        # I'm not sure why +1 was there originally, but in non-block-aligned prefix-prefill
        # scenarios it made the causal mask not cover the first unused token
        # (e.g. with context 28, the 28th slot was unmasked, causing the effective
        # context length to be 29).
        tokens_from_block_start = total_tokens.index_select(0, group_ids) - group_offsets * block_size
        block_usages = torch.clamp(tokens_from_block_start, 1, block_size)

        ctx = Context(group_ids, group_offsets, block_ids, block_usages)
        # Sanity check: all parallel tensors must have the same shape.
        shapes = [field.shape for field in ctx._values() if torch.is_tensor(field)]
        for other in shapes[1:]:
            assert shapes[0] == other
        return ctx

    def _values(self) -> tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
        """Return the four fields as a tuple, in declaration order."""
        return (self.group_ids, self.group_offsets, self.block_ids, self.block_usages)

    def index_select(self, indices: torch.tensor) -> 'Context':
        """Return a new Context holding only the entries at ``indices``.

        Returns ``None`` when ``indices`` is empty.
        """
        if indices.size(0) <= 0:
            return None
        return Context(*(field.index_select(0, indices) for field in self._values()))

    def split(self, num_scheduled_tokens: torch.tensor) -> tuple['Context', 'Context']:
        """Split into ``(shared, unique)`` contexts.

        For every entry, the scheduled-token counts of all entries that refer
        to the same block id are summed. Entries whose sum is greater than 1 go
        to the shared context, entries whose sum is exactly 1 go to the unique
        context. Either part may be ``None`` when it is empty.
        """
        tokens_per_entry = num_scheduled_tokens.index_select(0, self.group_ids)
        tokens_per_block = group_sum(self.block_ids, tokens_per_entry)
        shared_idx = torch.argwhere(tokens_per_block > 1).flatten()
        unique_idx = torch.argwhere(tokens_per_block == 1).flatten()
        # Every entry has to land in exactly one of the two parts.
        assert shared_idx.size(0) + unique_idx.size(0) == self.group_ids.size(0)
        return self.index_select(shared_idx), self.index_select(unique_idx)
392-
393-
394-
def hpu_tensor(tensor: torch.tensor, shape: tuple, pad_value: Union[int, float]) -> torch.tensor:
    """Pad ``tensor`` up to ``shape`` (if needed) and move it to HPU.

    Args:
        tensor: Tensor to transfer, or ``None``.
        shape: Target shape; must have as many dimensions as ``tensor`` and be
            at least as large along every dimension.
        pad_value: Fill value for the padded region.

    Returns:
        ``None`` when ``tensor`` is ``None``; otherwise the result of
        ``to_hpu`` applied to the (possibly padded) tensor. Padding is only
        appended at the end of each dimension.

    Raises:
        AssertionError: if the ranks differ or ``tensor`` exceeds ``shape``
            along some dimension.
    """
    if tensor is None:
        return None
    assert len(tensor.shape) == len(shape)
    # torch.nn.functional.pad takes (before, after) pairs starting from the
    # LAST dimension, hence the reversed walk over the dimensions.
    dims_last_first = reversed(tuple(zip(tensor.shape, shape)))
    padding = tuple(itertools.chain.from_iterable((0, target - cur) for cur, target in dims_last_first))
    assert all(p >= 0 for p in padding)
    if sum(padding) > 0:
        tensor = torch.nn.functional.pad(tensor, padding, value=pad_value)
    return to_hpu(tensor)
405-
406-
407-
def create_unified_batch(req_ids: list[str], all_token_ids: torch.tensor, num_computed_tokens: torch.tensor,
                         num_scheduled_tokens: torch.tensor, num_prompt_tokens: torch.tensor, block_table: torch.tensor,
                         block_size: int, dtype: torch.dtype, bucketing_fn: Callable[[bool, int, int, int, int],
                                                                                   tuple[int, int, int, int]],
                         get_dp_padding_fn: Callable[[int], int]) -> UnifiedBatch:
    """Compute every tensor needed to schedule one batch.

    The per-request inputs (``num_computed_tokens``, ``num_scheduled_tokens``,
    ``num_prompt_tokens``, rows of ``all_token_ids`` and ``block_table``) are
    expanded into flat per-token tensors, the cached context is split into
    shared and unique blocks with their biases, and everything is padded to
    the sizes returned by ``bucketing_fn`` (plus ``get_dp_padding_fn``) before
    being moved to HPU.
    """
    # ---- per-request bookkeeping ----
    total_tokens = num_computed_tokens + num_scheduled_tokens
    query_len = num_scheduled_tokens.sum().item()
    is_prompt = total_tokens <= num_prompt_tokens
    contains_prompts = torch.any(is_prompt).item()
    # Prompt requests only see already-computed tokens as cached context;
    # the others additionally count the tokens scheduled in this step.
    cached_tokens = num_computed_tokens + torch.where(is_prompt, 0, num_scheduled_tokens)
    # Number of scheduled tokens per request that produce logits, limited to [0, scheduled].
    raw_output_tokens = total_tokens - num_prompt_tokens + 1
    num_output_tokens = torch.clamp(raw_output_tokens, torch.zeros_like(num_scheduled_tokens), num_scheduled_tokens)
    # Offset of each request's first token in the flattened query.
    group_starts = torch.cumsum(num_scheduled_tokens, dim=0) - num_scheduled_tokens

    # ---- per-token ids, positions and KV-cache slots ----
    token_groups, token_offsets = indices_and_offsets(num_scheduled_tokens)
    token_positions = token_offsets + num_computed_tokens.index_select(0, token_groups)
    token_ids = fetch_2d(all_token_ids, token_groups, token_positions)

    token_blocks = fetch_2d(block_table, token_groups, token_positions.floor_divide(block_size))
    token_slots = token_blocks * block_size + token_positions.fmod(block_size)

    # ---- rows of the flattened query for which logits are gathered ----
    # These are the last ``num_output_tokens`` scheduled tokens of every request.
    logits_groups, logits_offsets = indices_and_offsets(num_output_tokens)
    scheduled_ends = torch.cumsum(num_scheduled_tokens, dim=0, dtype=num_scheduled_tokens.dtype)
    start_logits_indices = scheduled_ends - num_output_tokens
    logits_indices = logits_offsets + start_logits_indices.index_select(0, logits_groups)
    new_token_positions = total_tokens.index_select(0, logits_groups)

    def _leading_size(t: Optional[torch.tensor]) -> int:
        """Size of the first dimension, or 0 when the tensor is None."""
        return 0 if t is None else t.size(0)

    # ---- attention biases (each stays None when not applicable) ----
    causal_bias = None
    shared_blocks = None
    shared_bias = None
    unique_blocks = 0
    unique_block_mapping = None
    unique_bias = None

    if contains_prompts:
        causal_bias = create_causal_bias(token_groups, token_positions, dtype)

    ctx = Context.create(cached_tokens, block_table, block_size)
    if ctx:
        shared_ctx, unique_ctx = ctx.split(num_scheduled_tokens)

        if shared_ctx:
            # Deduplicate block ids; the inverse maps every context entry to its deduplicated slot.
            shared_blocks, entry_to_shared_block = torch.unique(shared_ctx.block_ids, return_inverse=True)

            entry_group_starts = group_starts.index_select(0, shared_ctx.group_ids)
            entry_num_tokens = num_scheduled_tokens.index_select(0, shared_ctx.group_ids)
            # Expand every context entry into one row per query token of its request.
            pair_entry, pair_token_offset = indices_and_offsets(entry_num_tokens)

            shared_token_idx = entry_group_starts.index_select(0, pair_entry) + pair_token_offset
            shared_block_idx = entry_to_shared_block.index_select(0, pair_entry)
            pair_block_usage = shared_ctx.block_usages.index_select(0, pair_entry)
            pair_bias = generate_bias(pair_block_usage, block_size, dtype)

            # Everything is masked by default; only (token, block) pairs that belong together are filled in.
            shared_bias = torch.full((query_len, shared_blocks.size(0), block_size),
                                     -math.inf,
                                     dtype=dtype,
                                     device=shared_blocks.device)
            shared_bias.index_put_((shared_token_idx, shared_block_idx), pair_bias)

        if unique_ctx:
            # Unique-block tensors are indexed directly by block id, so they span 0..max id.
            unique_blocks = torch.amax(unique_ctx.block_ids).item() + 1
            unique_device = unique_ctx.block_ids.device
            unique_rows = unique_ctx.block_ids.to(torch.int64)

            unique_bias = torch.full((unique_blocks, block_size), -math.inf, dtype=dtype, device=unique_device)
            unique_bias.index_copy_(0, unique_rows, generate_bias(unique_ctx.block_usages, block_size, dtype))

            # Map every unique block to the first query row of its request; -1 marks unused block ids.
            unique_group_starts = group_starts.index_select(0, unique_ctx.group_ids)
            unique_block_mapping = torch.full((unique_blocks, ), -1, dtype=torch.int64, device=unique_device)
            unique_block_mapping.index_copy_(0, unique_rows, unique_group_starts)

    # ---- bucketed (padded) target sizes ----
    bucket = bucketing_fn(contains_prompts, _leading_size(token_ids), _leading_size(shared_blocks), unique_blocks,
                          _leading_size(logits_indices))
    target_qlen, target_shared_blocks, target_unique_blocks, target_logits = bucket

    target_qlen += get_dp_padding_fn(target_qlen)
    target_shared_blocks += get_dp_padding_fn(target_shared_blocks)
    target_unique_blocks += get_dp_padding_fn(target_unique_blocks)
    target_logits += get_dp_padding_fn(target_logits)

    default_causal_width = 512
    dtype_info = torch.finfo(dtype)
    fmin = dtype_info.min
    feps = dtype_info.tiny

    # Keyword arguments are evaluated in order, so device transfers are issued in the order written here.
    return UnifiedBatch(
        req_ids_cpu=req_ids,
        token_ids=hpu_tensor(token_ids, (target_qlen, ), -1),
        token_positions=hpu_tensor(token_positions, (target_qlen, ), -1),
        new_token_positions_cpu=new_token_positions,
        logits_indices=hpu_tensor(logits_indices, (target_logits, ), -1),
        logits_groups_cpu=logits_groups,
        attn_metadata=HPUUnifiedAttentionMetadata(
            block_size=block_size,
            slot_mapping=hpu_tensor(token_slots, (target_qlen, ), -1),
            causal_bias=hpu_tensor(causal_bias, (target_qlen, target_qlen), -math.inf),
            causal_width=default_causal_width,
            shared_blocks=hpu_tensor(shared_blocks, (target_shared_blocks, ), -1),
            shared_bias=hpu_tensor(shared_bias, (target_qlen, target_shared_blocks, block_size), -math.inf),
            unique_blocks=target_unique_blocks,
            unique_block_mapping=hpu_tensor(unique_block_mapping, (target_unique_blocks, ), -1),
            unique_bias=hpu_tensor(unique_bias, (target_unique_blocks, block_size), -math.inf),
            fmin=to_hpu(fmin),
            feps=to_hpu(feps),
        ))

0 commit comments

Comments
 (0)