From 8ab19f22157644834e765e73059b8d27f3c8a7f6 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 13:44:16 +0200 Subject: [PATCH 01/12] WA for preemptions Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_input_batch.py | 9 +++++ vllm_gaudi/v1/worker/hpu_model_runner.py | 46 ++++++++++++++---------- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_input_batch.py b/vllm_gaudi/v1/worker/hpu_input_batch.py index 40845bc1e..57f89d0f6 100644 --- a/vllm_gaudi/v1/worker/hpu_input_batch.py +++ b/vllm_gaudi/v1/worker/hpu_input_batch.py @@ -256,6 +256,15 @@ def add_request( start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids + #NOTE(kzawora): In non-preemption scenario, + # self.input_batch.num_prompt_tokens[batch_idx] == self.input_batch.num_tokens[batch_idx]. + # In preemption scenario, we want num_prompt_tokens to also include the tokens emitted before preemption, + # as that is used as basis for recomputing prefill. + # This also assumes that preemption is complete and reduces num_computed_tokens to 0 and preempted sequences + # don't retain any originally used cache blocks. + #if request.num_computed_tokens == 0: + # self.num_prompt_tokens[req_index] = num_prompt_tokens + len(request.output_token_ids) + # Number of token ids in token_ids_cpu. # NOTE(woosuk): This may include spec decode tokens. self.num_tokens[req_index] = request.num_tokens diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 4ea0fdd4e..31f0eee59 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1488,13 +1488,18 @@ def _get_prompts_and_decodes( num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] num_prompt_tokens = self.input_batch.num_prompt_tokens[i] + num_all_tokens = self.input_batch.num_tokens[i] num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] - if num_computed_tokens < num_prompt_tokens and \ not self.is_decoder_only(req_id): # This is prompt break + if num_computed_tokens < num_all_tokens and num_scheduled_tokens != 1 and \ + not self.is_decoder_only(req_id): + break + #from fpdb import ForkedPdb; ForkedPdb().set_trace() + # This is decode # NOTE(chendi): To support spec decode, # we don't assume num_scheduled_tokens == 1. 
@@ -1517,12 +1522,12 @@ def _get_prompts_and_decodes( num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] # Must be prompt - assert num_computed_tokens < num_prompt_tokens - num_output_tokens = len(self.requests[req_id].output_token_ids) - if not has_kv_transfer_group(): - #P case num_output_tokens has non 0 - assert num_output_tokens == 0, \ - f'req_id: {req_id}, {num_output_tokens}' + # assert num_computed_tokens < num_prompt_tokens + # num_output_tokens = len(self.requests[req_id].output_token_ids) + #if not has_kv_transfer_group(): + # #P case num_output_tokens has non 0 + # assert num_output_tokens == 0, \ + # f'req_id: {req_id}, {num_output_tokens}' prompt_req_ids.append(req_id) prompt_scheduled_tokens.append(num_scheduled_tokens) @@ -1678,26 +1683,29 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, num_schedul for batch_idx in range(num_decodes, num_reqs): req_id = self.input_batch.req_ids[batch_idx] - context_len = self.input_batch.num_computed_tokens_cpu[batch_idx] - query_len = num_scheduled_tokens[batch_idx] + seq_num_computed_tokens = self.input_batch.num_computed_tokens_cpu[batch_idx] + seq_num_scheduled_tokens = num_scheduled_tokens[batch_idx] - token_ids = self.input_batch.token_ids_cpu[batch_idx, context_len:context_len + query_len].tolist() + token_ids = self.input_batch.token_ids_cpu[batch_idx, seq_num_computed_tokens:seq_num_computed_tokens + + seq_num_scheduled_tokens].tolist() - num_blocks = round_up(context_len + query_len, self.block_size) // self.block_size + num_blocks = round_up(seq_num_computed_tokens + seq_num_scheduled_tokens, + self.block_size) // self.block_size blocks = block_table_cpu_tensor[batch_idx, :num_blocks].tolist() if not warmup: blocks = [self.defragmenter.resolve(b) for b in blocks] - - prompt_tokens = self.input_batch.num_prompt_tokens[batch_idx] - # TODO: Fix non-prompt case - num_output_logits = max(0, context_len + query_len - prompt_tokens + 1) - logits_positions = list(range(query_len - num_output_logits, query_len)) + #NOTE(kzawora): In non-preemption scenario, + # self.input_batch.num_prompt_tokens[batch_idx] == self.input_batch.num_tokens[batch_idx]. + # In preemption scenario num_tokens will also include the tokens emitted before preemption + num_all_tokens = self.input_batch.num_tokens[batch_idx] + num_output_logits = max(0, seq_num_computed_tokens + seq_num_scheduled_tokens - num_all_tokens + 1) + logits_positions = list(range(seq_num_scheduled_tokens - num_output_logits, seq_num_scheduled_tokens)) new_batch_contents = BatchContents( req_ids=[req_id], token_ids=[token_ids], - context_lens=[context_len], - prompt_lens=[prompt_tokens], + context_lens=[seq_num_computed_tokens], + prompt_lens=[num_all_tokens], blocks=[blocks], logits_positions=[logits_positions], ) @@ -3331,7 +3339,7 @@ def execute_model( num_tokens = len(token_ids) self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids self.input_batch.num_tokens[i] += len(token_ids) - req_state.output_token_ids.extend(token_ids) + #req_state.output_token_ids.extend(token_ids) # NOTE(chendi): enable cache based on PR(#20291) # Cache the sampled tokens in the model runner, so that the scheduler # doesn't need to send them back. 
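The heart of this workaround is the logits bookkeeping for recomputed prefills: after a complete preemption the tokens a request had already generated are replayed as part of the prompt, so the number of logits to emit is derived from the total token count (prompt plus pre-preemption outputs) rather than the original prompt length alone. Below is a minimal sketch of that arithmetic; the formula mirrors the num_output_logits computation in the _extract_prefill_batch_contents hunk above, while the standalone helper name and the example values are chosen purely for illustration.

    def prefill_logits_positions(num_computed_tokens: int,
                                 num_scheduled_tokens: int,
                                 num_all_tokens: int) -> list[int]:
        # Only scheduled positions that reach or pass the last already-known
        # token of the sequence produce a logit; the "+ 1" covers the next
        # token to be sampled at the end of the chunk.
        num_output_logits = max(
            0, num_computed_tokens + num_scheduled_tokens - num_all_tokens + 1)
        # Positions are relative to the scheduled chunk (query indices).
        return list(range(num_scheduled_tokens - num_output_logits,
                          num_scheduled_tokens))

    # Fresh 128-token prompt, fully scheduled: one logit at the last position.
    assert prefill_logits_positions(0, 128, 128) == [127]
    # First chunk of a replayed preempted sequence (128 prompt + 8 outputs): no logits yet.
    assert prefill_logits_positions(0, 120, 136) == []
    # Final chunk of the replay: exactly one new token is sampled.
    assert prefill_logits_positions(120, 16, 136) == [15]

In other words, a fully replayed preempted sequence samples exactly one new token and mid-prompt chunks sample none, which is what lets the recomputation be treated as an ordinary chunked prefill.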
From 602ac39c82bde21352be6021bbb27ac7d57a0c0d Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 13:54:50 +0200 Subject: [PATCH 02/12] Fix spec decode & unified attn preemptions Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_input_batch.py | 4 ++-- vllm_gaudi/v1/worker/hpu_model_runner.py | 12 +++--------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_input_batch.py b/vllm_gaudi/v1/worker/hpu_input_batch.py index 57f89d0f6..b46c5de64 100644 --- a/vllm_gaudi/v1/worker/hpu_input_batch.py +++ b/vllm_gaudi/v1/worker/hpu_input_batch.py @@ -262,8 +262,8 @@ def add_request( # as that is used as basis for recomputing prefill. # This also assumes that preemption is complete and reduces num_computed_tokens to 0 and preempted sequences # don't retain any originally used cache blocks. - #if request.num_computed_tokens == 0: - # self.num_prompt_tokens[req_index] = num_prompt_tokens + len(request.output_token_ids) + if request.num_computed_tokens == 0: + self.num_prompt_tokens[req_index] = num_prompt_tokens + len(request.output_token_ids) # Number of token ids in token_ids_cpu. # NOTE(woosuk): This may include spec decode tokens. diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 31f0eee59..03a30aade 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1488,18 +1488,12 @@ def _get_prompts_and_decodes( num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] num_prompt_tokens = self.input_batch.num_prompt_tokens[i] - num_all_tokens = self.input_batch.num_tokens[i] num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] if num_computed_tokens < num_prompt_tokens and \ not self.is_decoder_only(req_id): # This is prompt break - if num_computed_tokens < num_all_tokens and num_scheduled_tokens != 1 and \ - not self.is_decoder_only(req_id): - break - #from fpdb import ForkedPdb; ForkedPdb().set_trace() - # This is decode # NOTE(chendi): To support spec decode, # we don't assume num_scheduled_tokens == 1. @@ -1697,15 +1691,15 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, num_schedul #NOTE(kzawora): In non-preemption scenario, # self.input_batch.num_prompt_tokens[batch_idx] == self.input_batch.num_tokens[batch_idx]. 
# In preemption scenario num_tokens will also include the tokens emitted before preemption - num_all_tokens = self.input_batch.num_tokens[batch_idx] - num_output_logits = max(0, seq_num_computed_tokens + seq_num_scheduled_tokens - num_all_tokens + 1) + num_prompt_tokens = self.input_batch.num_prompt_tokens[batch_idx] + num_output_logits = max(0, seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1) logits_positions = list(range(seq_num_scheduled_tokens - num_output_logits, seq_num_scheduled_tokens)) new_batch_contents = BatchContents( req_ids=[req_id], token_ids=[token_ids], context_lens=[seq_num_computed_tokens], - prompt_lens=[num_all_tokens], + prompt_lens=[num_prompt_tokens], blocks=[blocks], logits_positions=[logits_positions], ) From 1391286e73171d3f9d4525723aad4fccbdff8123 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 13:58:54 +0200 Subject: [PATCH 03/12] code cleanup Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_model_runner.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 03a30aade..d93aae129 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1516,12 +1516,8 @@ def _get_prompts_and_decodes( num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] # Must be prompt - # assert num_computed_tokens < num_prompt_tokens - # num_output_tokens = len(self.requests[req_id].output_token_ids) - #if not has_kv_transfer_group(): - # #P case num_output_tokens has non 0 - # assert num_output_tokens == 0, \ - # f'req_id: {req_id}, {num_output_tokens}' + assert num_computed_tokens < num_prompt_tokens + # NOTE(kzawora): In preempted sequences, num_output_tokens can be > 0, and still be a valid prefill prompt_req_ids.append(req_id) prompt_scheduled_tokens.append(num_scheduled_tokens) @@ -3333,7 +3329,7 @@ def execute_model( num_tokens = len(token_ids) self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids self.input_batch.num_tokens[i] += len(token_ids) - #req_state.output_token_ids.extend(token_ids) + # NOTE(chendi): enable cache based on PR(#20291) # Cache the sampled tokens in the model runner, so that the scheduler # doesn't need to send them back. 
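Patches 02 and 03 move that accounting into the input batch itself: when a request is (re-)added with num_computed_tokens == 0, its previously emitted outputs are folded into num_prompt_tokens, which lets the original "must be prompt" assert and the num_prompt_tokens-based logits formula stand unchanged. The following is a small self-contained model of that bookkeeping under the same assumption the patch comment states (preemption is complete and frees all cache blocks); FakeRequest and the helper are illustrative stand-ins, not vLLM classes.

    from dataclasses import dataclass, field

    @dataclass
    class FakeRequest:
        # Illustrative stand-in for the cached request state seen by add_request.
        prompt_token_ids: list[int]
        output_token_ids: list[int] = field(default_factory=list)
        num_computed_tokens: int = 0

    def effective_num_prompt_tokens(req: FakeRequest) -> int:
        num_prompt_tokens = len(req.prompt_token_ids)
        # A complete preemption resets num_computed_tokens to 0 and drops all
        # cache blocks, so tokens emitted before preemption must be recomputed
        # as if they were prompt.
        if req.num_computed_tokens == 0:
            return num_prompt_tokens + len(req.output_token_ids)
        return num_prompt_tokens

    fresh = FakeRequest(prompt_token_ids=list(range(128)))
    assert effective_num_prompt_tokens(fresh) == 128      # new requests are unaffected

    preempted = FakeRequest(prompt_token_ids=list(range(128)),
                            output_token_ids=[7] * 8,
                            num_computed_tokens=0)
    assert effective_num_prompt_tokens(preempted) == 136  # pre-preemption outputs count as prompt

With num_prompt_tokens adjusted this way, the extra num_tokens-based branch that patch 01 added to _get_prompts_and_decodes is no longer needed, which is what restores the spec-decode and unified-attention paths.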
From 0f8b45fd54dd8ed991777d079d21508c378f21bb Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 16:22:18 +0200 Subject: [PATCH 04/12] Extract metadata update to HPUAttentionMetadataProcessor Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_model_runner.py | 419 +++++++++++++++-------- 1 file changed, 282 insertions(+), 137 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index d93aae129..dc9de6490 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -346,6 +346,8 @@ def __init__(self, model, vllm_config): self.is_mm_optimized = is_mm_optimized(self.model) self.sliding_window = vllm_config.model_config.get_sliding_window() self.interleaved_sliding_window = is_interleaved(vllm_config.model_config.hf_text_config) + self.metadata_processor = HPUAttentionMetadataProcessor(self.block_size, self.dtype, self.prefill_use_fusedsdpa, + self.sliding_window, self.interleaved_sliding_window) if self.interleaved_sliding_window: self.use_window_sdpa = os.getenv("PT_HPU_SDPA_QKV_SLICE_MODE_FWD", "false").strip().lower() in ("1", "true") self.slice_size = int(os.getenv("PT_HPU_SDPA_BC_FACTOR", "1024")) @@ -400,141 +402,6 @@ def _reset_rotary_cos_sin(self): if hasattr(self._rotary_embed_module, "sin"): delattr(self._rotary_embed_module, "sin") - def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): - if (attn_metadata is None or (self.prefill_use_fusedsdpa and attn_metadata.block_list is None) - or not attn_metadata.is_prompt): - return attn_metadata - - if attn_metadata.attn_bias is not None: - return attn_metadata - - prefill_metadata = attn_metadata - - seq_lens_t = prefill_metadata.seq_lens_tensor - context_lens_t = prefill_metadata.context_lens_tensor - - block_list = attn_metadata.block_list - max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) - max_context_len = max_context_len * self.block_size - past_mask = torch.arange(0, max_context_len, dtype=torch.int32, device=device) - past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(context_lens_t.view(-1, 1)).view( - batch_size, 1, -1).expand(batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1)) - - len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32).view(1, seq_len).ge( - seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len)) - causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool), - diagonal=1) - mask = causal_mask.logical_or(len_mask) - mask = torch.concat((past_mask, mask), dim=-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) - attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", attn_bias=attn_bias) - return attn_metadata - - def _set_attn_bias_for_sliding_window(self, attn_metadata, batch_size, seq_len, window_size, device, dtype): - - if (attn_metadata is None or not attn_metadata.is_prompt): - return attn_metadata - - prefill_metadata = attn_metadata - shift = 0 - - # FusedSDPA with window_size is only supported when the seq_len is multiple of the slice_size - if self.prefill_use_fusedsdpa and self.use_window_sdpa and \ - seq_len >= self.slice_thld and self.slice_size != 0 and \ - seq_len % self.slice_size == 0 and attn_metadata.block_list is None: - # no need to set sliding window mask, just use built-in window-sdpa - return attn_metadata - - if self.prefill_use_fusedsdpa and attn_metadata.block_list 
is not None: - context_lens_t = prefill_metadata.context_lens_tensor - - block_list = attn_metadata.block_list - max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) - max_context_len = max_context_len * self.block_size - - invalid_lens_t = context_lens_t - window_size + torch.arange(seq_len, device=device) - 1 - past_indices = torch.arange(max_context_len, device=device) - past_mask = ((past_indices.unsqueeze(0) > invalid_lens_t.unsqueeze(-1)) & - (past_indices.unsqueeze(0) < context_lens_t.unsqueeze(-1).unsqueeze(0))).unsqueeze(1) - - # Create boolean sliding window mask - causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift) - causal_mask = torch.triu(causal_mask, diagonal=shift - window_size + 1) - causal_mask = causal_mask.view(batch_size, 1, seq_len, seq_len) - - # TODO: Investigate further - Removing Padding cause accuracy issue - # seq_lens_t = prefill_metadata.seq_lens_tensor - # len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32).view(1, seq_len).lt( - # seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len)) - # causal_mask = causal_mask.logical_and(len_mask) - - mask = torch.concat((past_mask, causal_mask), dim=-1) - attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), - torch.tensor(float('-inf'), dtype=dtype, device=device)) - else: - # CAUSAL MASK without removing padding (CAUSAL+sliding window) - # removing padding cause accuracy issue for images input - tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) - mask = torch.tril(tensor, diagonal=shift) - mask = torch.triu(mask, diagonal=shift - window_size + 1) - attn_bias = torch.log(mask) - - attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias) - return attn_metadata - - def _set_block_mapping(self, metadata, batch_size, device, dtype, is_window_block=False): - if is_window_block: - block_usage = metadata.window_block_usage - block_groups = metadata.window_block_groups - else: - block_usage = metadata.block_usage - block_groups = metadata.block_groups - - mask = torch.arange(0, self.block_size, device=device, dtype=torch.int32).unsqueeze(0) - mask = mask >= block_usage.unsqueeze(-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) - - if not is_fake_hpu(): - block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size) - else: - # Unfortunately one_hot on CPU - # doesn't handle out of bounds classes so we need to convert - # all negative values to 0 (block_mapping) or bs (block_groups) - block_groups = block_groups.to(torch.long) - block_mapping = torch.nn.functional.relu(block_groups) - block_mapping = torch.nn.functional.one_hot(block_mapping, num_classes=batch_size) - oob_values = block_groups.lt(0) - block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) - block_groups.masked_fill_(oob_values, batch_size) - if is_window_block: - metadata = custom_tuple_replace(metadata, "TrimmedAttentionMetadata", window_block_groups=block_groups) - else: - metadata = custom_tuple_replace(metadata, "TrimmedAttentionMetadata", block_groups=block_groups) - block_mapping = block_mapping.to(dtype) - if is_window_block: - metadata = custom_tuple_replace(metadata, - "TrimmedAttentionMetadata", - window_block_mapping=block_mapping, - window_attn_bias=attn_bias) - else: - metadata = custom_tuple_replace(metadata, - "TrimmedAttentionMetadata", - block_mapping=block_mapping, - attn_bias=attn_bias) - return 
metadata - - def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): - if attn_metadata.is_prompt: - attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype) - if self.interleaved_sliding_window: - attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len, - self.sliding_window, device, dtype) - else: - attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype) - if self.interleaved_sliding_window: - attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True) - return attn_metadata - def forward(self, *args, **kwargs): # TODO(kzawora): something goes VERY WRONG when operating on # kwargs['attn_metadata'].slot_mapping, compared to untrimmed metadata @@ -548,8 +415,9 @@ def forward(self, *args, **kwargs): kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] if not self.unified_attn: - kwargs['attn_metadata'] = self._update_metadata(kwargs['attn_metadata'], input_ids.size(0), - input_ids.size(1), input_ids.device, self.dtype) + kwargs['attn_metadata'] = self.metadata_processor.process_metadata(kwargs['attn_metadata'], + input_ids.size(0), input_ids.size(1), + input_ids.device, self.dtype) if self._rotary_prepare_cos_sin is not None: self._rotary_prepare_cos_sin(kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin) attn_meta = kwargs.pop('attn_metadata') @@ -4735,3 +4603,280 @@ def device(self): def dtype(self): """Returns the torch.dtype of the tensors within the tuple.""" return self._dtype + + +class HPUAttentionMetadataProcessor: + """ + Processor class for post-processing HPU attention metadata. + + This class takes already-built attention metadata and augments it with + additional tensors such as attention bias masks and block mappings that + are required for efficient attention computation on HPU. It does NOT build + the metadata from scratch - it post-processes existing metadata structures. + """ + + def __init__( + self, + block_size: int, + dtype: torch.dtype, + prefill_use_fusedsdpa: Optional[bool] = None, + sliding_window: Optional[int] = None, + interleaved_sliding_window: bool = False, + ): + """ + Initialize the attention metadata processor. + + Args: + block_size: Size of KV cache blocks + dtype: Data type for attention operations + prefill_use_fusedsdpa: Whether to use fused SDPA for prefill + sliding_window: Sliding window size (None if not using sliding window) + interleaved_sliding_window: Whether to use interleaved sliding window + """ + self.block_size = block_size + self.dtype = dtype + self.prefill_use_fusedsdpa = prefill_use_fusedsdpa if prefill_use_fusedsdpa is not None \ + else (get_config().prompt_attn_impl == 'fsdpa_impl') + self.sliding_window = sliding_window + self.interleaved_sliding_window = interleaved_sliding_window + + if self.interleaved_sliding_window: + self.use_window_sdpa = os.getenv("PT_HPU_SDPA_QKV_SLICE_MODE_FWD", "false").strip().lower() in ("1", "true") + self.slice_size = int(os.getenv("PT_HPU_SDPA_BC_FACTOR", "1024")) + self.slice_thld = int(os.environ.get('VLLM_FUSEDSDPA_SLIDE_THLD', '8192')) + + def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, device: torch.device, + dtype: torch.dtype) -> HPUAttentionMetadataV1: + """ + Set attention bias for prompt phase. + + Creates causal attention masks with proper handling of padding and context. 
+ + Args: + attn_metadata: Input attention metadata + batch_size: Batch size + seq_len: Sequence length + device: Device to create tensors on + dtype: Data type for the bias tensor + + Returns: + Updated attention metadata with attn_bias set + """ + if (attn_metadata is None or (self.prefill_use_fusedsdpa and attn_metadata.block_list is None) + or not attn_metadata.is_prompt): + return attn_metadata + + if attn_metadata.attn_bias is not None: + return attn_metadata + + prefill_metadata = attn_metadata + + seq_lens_t = prefill_metadata.seq_lens_tensor + context_lens_t = prefill_metadata.context_lens_tensor + + block_list = attn_metadata.block_list + max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) + max_context_len = max_context_len * self.block_size + past_mask = torch.arange(0, max_context_len, dtype=torch.int32, device=device) + past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(context_lens_t.view(-1, 1)).view( + batch_size, 1, -1).expand(batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1)) + + len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32).view(1, seq_len).ge( + seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len)) + causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool), + diagonal=1) + mask = causal_mask.logical_or(len_mask) + mask = torch.concat((past_mask, mask), dim=-1) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) + attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", attn_bias=attn_bias) + return attn_metadata + + def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, + window_size: int, device: torch.device, + dtype: torch.dtype) -> HPUAttentionMetadataV1: + """ + Set attention bias for sliding window attention in prompt phase. 
+ + Args: + attn_metadata: Input attention metadata + batch_size: Batch size + seq_len: Sequence length + window_size: Sliding window size + device: Device to create tensors on + dtype: Data type for the bias tensor + + Returns: + Updated attention metadata with window_attn_bias set + """ + if (attn_metadata is None or not attn_metadata.is_prompt): + return attn_metadata + + prefill_metadata = attn_metadata + shift = 0 + + # FusedSDPA with window_size is only supported when the seq_len is multiple of the slice_size + if self.prefill_use_fusedsdpa and self.use_window_sdpa and \ + seq_len >= self.slice_thld and self.slice_size != 0 and \ + seq_len % self.slice_size == 0 and attn_metadata.block_list is None: + # no need to set sliding window mask, just use built-in window-sdpa + return attn_metadata + + if self.prefill_use_fusedsdpa and attn_metadata.block_list is not None: + context_lens_t = prefill_metadata.context_lens_tensor + + block_list = attn_metadata.block_list + max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) + max_context_len = max_context_len * self.block_size + + invalid_lens_t = context_lens_t - window_size + torch.arange(seq_len, device=device) - 1 + past_indices = torch.arange(max_context_len, device=device) + past_mask = ((past_indices.unsqueeze(0) > invalid_lens_t.unsqueeze(-1)) & + (past_indices.unsqueeze(0) < context_lens_t.unsqueeze(-1).unsqueeze(0))).unsqueeze(1) + + # Create boolean sliding window mask + causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift) + causal_mask = torch.triu(causal_mask, diagonal=shift - window_size + 1) + causal_mask = causal_mask.view(batch_size, 1, seq_len, seq_len) + + # TODO: Investigate further - Removing Padding cause accuracy issue + # seq_lens_t = prefill_metadata.seq_lens_tensor + # len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32).view(1, seq_len).lt( + # seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len)) + # causal_mask = causal_mask.logical_and(len_mask) + + mask = torch.concat((past_mask, causal_mask), dim=-1) + attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), + torch.tensor(float('-inf'), dtype=dtype, device=device)) + else: + # CAUSAL MASK without removing padding (CAUSAL+sliding window) + # removing padding cause accuracy issue for images input + tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) + mask = torch.tril(tensor, diagonal=shift) + mask = torch.triu(mask, diagonal=shift - window_size + 1) + attn_bias = torch.log(mask) + + attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias) + return attn_metadata + + def _set_block_mapping(self, + metadata: HPUAttentionMetadataV1, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + is_window_block: bool = False) -> HPUAttentionMetadataV1: + """ + Set block mapping for decode phase. + + Creates block mapping and attention bias for paged attention during decode. 
+ + Args: + metadata: Input attention metadata + batch_size: Batch size + device: Device to create tensors on + dtype: Data type for tensors + is_window_block: Whether this is for window blocks + + Returns: + Updated attention metadata with block_mapping and attn_bias set + """ + if is_window_block: + block_usage = metadata.window_block_usage + block_groups = metadata.window_block_groups + else: + block_usage = metadata.block_usage + block_groups = metadata.block_groups + + mask = torch.arange(0, self.block_size, device=device, dtype=torch.int32).unsqueeze(0) + mask = mask >= block_usage.unsqueeze(-1) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) + + if not is_fake_hpu(): + block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size) + else: + # Unfortunately one_hot on CPU + # doesn't handle out of bounds classes so we need to convert + # all negative values to 0 (block_mapping) or bs (block_groups) + block_groups = block_groups.to(torch.long) + block_mapping = torch.nn.functional.relu(block_groups) + block_mapping = torch.nn.functional.one_hot(block_mapping, num_classes=batch_size) + oob_values = block_groups.lt(0) + block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) + block_groups.masked_fill_(oob_values, batch_size) + if is_window_block: + metadata = custom_tuple_replace(metadata, "TrimmedAttentionMetadata", window_block_groups=block_groups) + else: + metadata = custom_tuple_replace(metadata, "TrimmedAttentionMetadata", block_groups=block_groups) + block_mapping = block_mapping.to(dtype) + if is_window_block: + metadata = custom_tuple_replace(metadata, + "TrimmedAttentionMetadata", + window_block_mapping=block_mapping, + window_attn_bias=attn_bias) + else: + metadata = custom_tuple_replace(metadata, + "TrimmedAttentionMetadata", + block_mapping=block_mapping, + attn_bias=attn_bias) + return metadata + + def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, + device: torch.device, dtype: torch.dtype) -> HPUAttentionMetadataV1: + """ + Post-process attention metadata with appropriate masks and mappings. + + This is the main entry point for processing attention metadata. It augments + the metadata with attention bias masks (for prompt phase) or block mappings + (for decode phase), with support for sliding window attention. + + Args: + attn_metadata: Input attention metadata (already built) + batch_size: Batch size + seq_len: Sequence length (for prompt phase) + device: Device to create tensors on + dtype: Data type for tensors + + Returns: + Post-processed attention metadata with additional tensors + """ + if attn_metadata.is_prompt: + attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype) + if self.interleaved_sliding_window and self.sliding_window is not None: + attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len, + self.sliding_window, device, dtype) + else: + attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype) + if self.interleaved_sliding_window: + attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True) + return attn_metadata + + def process_metadata_dict(self, attn_metadata: dict, batch_size: int, seq_len: int, device: torch.device, + dtype: torch.dtype) -> dict: + """ + Post-process a dictionary of attention metadata (for multi-layer models). 
+ + This method optimizes the processing by checking if all metadata objects in the + dictionary are the same instance, and if so, processes only once. + + Args: + attn_metadata: Dictionary mapping layer names to attention metadata + batch_size: Batch size + seq_len: Sequence length (for prompt phase) + device: Device to create tensors on + dtype: Data type for tensors + + Returns: + Dictionary with post-processed attention metadata + """ + from vllm_gaudi.extension.logger import logger + + first_attn_metadata = next(iter(attn_metadata.values())) + updated_attn_metadata = self.process_metadata(first_attn_metadata, batch_size, seq_len, device, dtype) + + for key in attn_metadata: + if attn_metadata[key] is first_attn_metadata: + attn_metadata[key] = updated_attn_metadata + else: + msg = f"Different attn_metadata encountered on layer {key}. Processing it individually." + logger.warning(msg) + attn_metadata[key] = self.process_metadata(attn_metadata[key], batch_size, seq_len, device, dtype) + return attn_metadata From 437b0c325db5b54b996f6fd22f247ef3c2e13219 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 16:34:40 +0200 Subject: [PATCH 05/12] use vllm_config Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_model_runner.py | 29 ++++++++---------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index dc9de6490..c873dd564 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -333,7 +333,6 @@ class HpuModelAdapter(torch.nn.Module, KVConnectorModelRunnerMixin): def __init__(self, model, vllm_config): super().__init__() self.model = model - self.prefill_use_fusedsdpa = get_config().prompt_attn_impl == 'fsdpa_impl' self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE', 'false').lower() in ['1', 'true'] self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size @@ -344,14 +343,8 @@ def __init__(self, model, vllm_config): self.unified_attn_persistent_ctx = None self.flatten_input = get_config().flatten_input self.is_mm_optimized = is_mm_optimized(self.model) - self.sliding_window = vllm_config.model_config.get_sliding_window() self.interleaved_sliding_window = is_interleaved(vllm_config.model_config.hf_text_config) - self.metadata_processor = HPUAttentionMetadataProcessor(self.block_size, self.dtype, self.prefill_use_fusedsdpa, - self.sliding_window, self.interleaved_sliding_window) - if self.interleaved_sliding_window: - self.use_window_sdpa = os.getenv("PT_HPU_SDPA_QKV_SLICE_MODE_FWD", "false").strip().lower() in ("1", "true") - self.slice_size = int(os.getenv("PT_HPU_SDPA_BC_FACTOR", "1024")) - self.slice_thld = int(os.environ.get('VLLM_FUSEDSDPA_SLIDE_THLD', '8192')) + self.metadata_processor = HPUAttentionMetadataProcessor(vllm_config) # for DP self.dummy_num_input_tokens = -1 @@ -4617,11 +4610,7 @@ class HPUAttentionMetadataProcessor: def __init__( self, - block_size: int, - dtype: torch.dtype, - prefill_use_fusedsdpa: Optional[bool] = None, - sliding_window: Optional[int] = None, - interleaved_sliding_window: bool = False, + vllm_config: VllmConfig, ): """ Initialize the attention metadata processor. 
@@ -4633,12 +4622,14 @@ def __init__( sliding_window: Sliding window size (None if not using sliding window) interleaved_sliding_window: Whether to use interleaved sliding window """ - self.block_size = block_size - self.dtype = dtype - self.prefill_use_fusedsdpa = prefill_use_fusedsdpa if prefill_use_fusedsdpa is not None \ - else (get_config().prompt_attn_impl == 'fsdpa_impl') - self.sliding_window = sliding_window - self.interleaved_sliding_window = interleaved_sliding_window + self.prefill_use_fusedsdpa = get_config().prompt_attn_impl == 'fsdpa_impl' + self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE', 'false').lower() in ['1', 'true'] + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.dtype = vllm_config.model_config.dtype + self.flatten_input = get_config().flatten_input + self.sliding_window = vllm_config.model_config.get_sliding_window() + self.interleaved_sliding_window = is_interleaved(vllm_config.model_config.hf_text_config) if self.interleaved_sliding_window: self.use_window_sdpa = os.getenv("PT_HPU_SDPA_QKV_SLICE_MODE_FWD", "false").strip().lower() in ("1", "true") From 5b79e4cdde5f9f449a4afe94323d4751b4f16356 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 17:22:35 +0200 Subject: [PATCH 06/12] Move metadata processing outside HPUModelAdapter, process biases on CPU Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_model_runner.py | 169 +++++++++++++++-------- 1 file changed, 109 insertions(+), 60 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index c873dd564..cc08c291b 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -343,8 +343,6 @@ def __init__(self, model, vllm_config): self.unified_attn_persistent_ctx = None self.flatten_input = get_config().flatten_input self.is_mm_optimized = is_mm_optimized(self.model) - self.interleaved_sliding_window = is_interleaved(vllm_config.model_config.hf_text_config) - self.metadata_processor = HPUAttentionMetadataProcessor(vllm_config) # for DP self.dummy_num_input_tokens = -1 @@ -407,10 +405,6 @@ def forward(self, *args, **kwargs): if 'warmup_mode' in kwargs: kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] - if not self.unified_attn: - kwargs['attn_metadata'] = self.metadata_processor.process_metadata(kwargs['attn_metadata'], - input_ids.size(0), input_ids.size(1), - input_ids.device, self.dtype) if self._rotary_prepare_cos_sin is not None: self._rotary_prepare_cos_sin(kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin) attn_meta = kwargs.pop('attn_metadata') @@ -483,6 +477,16 @@ def subtuple(obj: object, typename: str, to_copy: list[str], to_override: Option return _TYPE_CACHE[typename]['type'](**values) # type: ignore +def metadata_update_with_trim(obj: object, typename: str, trim: bool, **to_override): + if trim: + return custom_tuple_replace(obj, typename, **to_override) + + for key in to_override: + assert hasattr(obj, key), f"Field {key} must exist in untrimmed metadata." + setattr(obj, key, to_override[key]) + return obj + + def custom_tuple_replace(obj: object, typename: str, **to_override): # Torch compile dynamo doesn't support calling any named tuple # dynamic methods other than len and get_attr. This function is @@ -775,6 +779,8 @@ def __init__( assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!' 
assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!' + self.metadata_processor = HPUAttentionMetadataProcessor(vllm_config) + def _make_buffer(self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True) -> CpuGpuBuffer: return CpuGpuBuffer(*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory, with_numpy=numpy) @@ -1689,6 +1695,15 @@ def _form_prefill_batch(self, contents): block_list=context_blocks_t, attn_bias=attn_bias, block_size=self.block_size) + + attn_metadata = self.metadata_processor.process_metadata(attn_metadata, + token_ids.size(0), + token_ids.size(1), + 'cpu', + token_ids.device, + self.dtype, + trim=False) + return PrefillInputData(request_ids=[req_ids], prompt_lens=[query_lens], token_ids=[token_ids], @@ -1979,21 +1994,31 @@ def _create_decode_input_data(self, spec_decode_metadata = None logits_indices_device = async_h2d_copy(logits_indices, device=self.device) + attn_metadata = HPUAttentionMetadataV1.make_decode_metadata( + block_list=block_list_device, + block_usage=block_usage_device, + block_groups=block_groups_device, + input_positions=None, + slot_mapping=slot_mapping_device, + block_size=self.block_size, + window_block_list=window_block_list_device, + window_block_usage=window_block_usage_device, + window_block_groups=window_block_groups_device, + ), + + attn_metadata = self.metadata_processor.process_metadata(attn_metadata, + token_ids_device.size(0), + token_ids_device.size(1), + 'cpu', + token_ids.device, + self.dtype, + trim=False) + return DecodeInputData(num_decodes=num_decodes, token_ids=token_ids_device, position_ids=positions_device, logits_indices=logits_indices_device, - attn_metadata=HPUAttentionMetadataV1.make_decode_metadata( - block_list=block_list_device, - block_usage=block_usage_device, - block_groups=block_groups_device, - input_positions=None, - slot_mapping=slot_mapping_device, - block_size=self.block_size, - window_block_list=window_block_list_device, - window_block_usage=window_block_usage_device, - window_block_groups=window_block_groups_device, - ), + attn_metadata=attn_metadata, spec_decode_metadata=spec_decode_metadata) def _prepare_decode_inputs(self, @@ -2330,7 +2355,7 @@ def _execute_model_generic(self, else: # no hpu graphs for t.compile? use_graphs = False - trimmed_attn_metadata = attn_metadata if self.unified_attn else trim_attn_metadata(attn_metadata) + if self.is_driver_worker: model_event_name = ("model_forward_" f"bs{batch_size}_" @@ -2339,6 +2364,7 @@ def _execute_model_generic(self, f"graphs{'T' if use_graphs else 'F'}") else: model_event_name = 'model_executable' + trimmed_attn_metadata = trim_attn_metadata(attn_metadata) with self.profiler.record_event('internal', model_event_name): hidden_states = self.model.forward(input_ids=token_ids, positions=position_ids, @@ -4614,13 +4640,6 @@ def __init__( ): """ Initialize the attention metadata processor. 
- - Args: - block_size: Size of KV cache blocks - dtype: Data type for attention operations - prefill_use_fusedsdpa: Whether to use fused SDPA for prefill - sliding_window: Sliding window size (None if not using sliding window) - interleaved_sliding_window: Whether to use interleaved sliding window """ self.prefill_use_fusedsdpa = get_config().prompt_attn_impl == 'fsdpa_impl' self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE', 'false').lower() in ['1', 'true'] @@ -4630,14 +4649,16 @@ def __init__( self.flatten_input = get_config().flatten_input self.sliding_window = vllm_config.model_config.get_sliding_window() self.interleaved_sliding_window = is_interleaved(vllm_config.model_config.hf_text_config) + self.unified_attn = get_config().unified_attn if self.interleaved_sliding_window: self.use_window_sdpa = os.getenv("PT_HPU_SDPA_QKV_SLICE_MODE_FWD", "false").strip().lower() in ("1", "true") self.slice_size = int(os.getenv("PT_HPU_SDPA_BC_FACTOR", "1024")) self.slice_thld = int(os.environ.get('VLLM_FUSEDSDPA_SLIDE_THLD', '8192')) - def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, device: torch.device, - dtype: torch.dtype) -> HPUAttentionMetadataV1: + def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, + src_device: torch.device, dst_device: torch.device, dtype: torch.dtype, + trim: bool) -> HPUAttentionMetadataV1: """ Set attention bias for prompt phase. @@ -4668,23 +4689,28 @@ def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, block_list = attn_metadata.block_list max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) max_context_len = max_context_len * self.block_size - past_mask = torch.arange(0, max_context_len, dtype=torch.int32, device=device) + past_mask = torch.arange(0, max_context_len, dtype=torch.int32, device=src_device) past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(context_lens_t.view(-1, 1)).view( batch_size, 1, -1).expand(batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1)) - len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32).view(1, seq_len).ge( + len_mask = (torch.arange(0, seq_len, device=src_device, dtype=torch.int32).view(1, seq_len).ge( seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len)) - causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool), + causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), device=src_device, dtype=torch.bool), diagonal=1) mask = causal_mask.logical_or(len_mask) mask = torch.concat((past_mask, mask), dim=-1) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) - attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", attn_bias=attn_bias) + if src_device != dst_device: + attn_bias = attn_bias.to(dst_device, non_blocking=True) + attn_metadata = metadata_update_with_trim(prefill_metadata, + "TrimmedAttentionMetadata", + trim=trim, + attn_bias=attn_bias) return attn_metadata def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, - window_size: int, device: torch.device, - dtype: torch.dtype) -> HPUAttentionMetadataV1: + window_size: int, src_device: torch.device, dst_device: torch.device, + dtype: torch.dtype, trim: bool) -> HPUAttentionMetadataV1: """ Set attention bias for sliding window attention in prompt phase. 
@@ -4719,13 +4745,13 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) max_context_len = max_context_len * self.block_size - invalid_lens_t = context_lens_t - window_size + torch.arange(seq_len, device=device) - 1 - past_indices = torch.arange(max_context_len, device=device) + invalid_lens_t = context_lens_t - window_size + torch.arange(seq_len, device=src_device) - 1 + past_indices = torch.arange(max_context_len, device=src_device) past_mask = ((past_indices.unsqueeze(0) > invalid_lens_t.unsqueeze(-1)) & (past_indices.unsqueeze(0) < context_lens_t.unsqueeze(-1).unsqueeze(0))).unsqueeze(1) # Create boolean sliding window mask - causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift) + causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=src_device), diagonal=shift) causal_mask = torch.triu(causal_mask, diagonal=shift - window_size + 1) causal_mask = causal_mask.view(batch_size, 1, seq_len, seq_len) @@ -4736,25 +4762,29 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV # causal_mask = causal_mask.logical_and(len_mask) mask = torch.concat((past_mask, causal_mask), dim=-1) - attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), - torch.tensor(float('-inf'), dtype=dtype, device=device)) + attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=src_device), + torch.tensor(float('-inf'), dtype=dtype, device=src_device)) else: # CAUSAL MASK without removing padding (CAUSAL+sliding window) # removing padding cause accuracy issue for images input - tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) + tensor = torch.full((batch_size, 1, seq_len, seq_len), device=src_device, dtype=dtype, fill_value=1) mask = torch.tril(tensor, diagonal=shift) mask = torch.triu(mask, diagonal=shift - window_size + 1) attn_bias = torch.log(mask) + if src_device != dst_device: + attn_bias = attn_bias.to(dst_device, non_blocking=True) attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias) return attn_metadata def _set_block_mapping(self, metadata: HPUAttentionMetadataV1, batch_size: int, - device: torch.device, + src_device: torch.device, + dst_device: torch.device, dtype: torch.dtype, - is_window_block: bool = False) -> HPUAttentionMetadataV1: + is_window_block: bool = False, + trim: bool = False) -> HPUAttentionMetadataV1: """ Set block mapping for decode phase. 
@@ -4777,7 +4807,7 @@ def _set_block_mapping(self, block_usage = metadata.block_usage block_groups = metadata.block_groups - mask = torch.arange(0, self.block_size, device=device, dtype=torch.int32).unsqueeze(0) + mask = torch.arange(0, self.block_size, device=src_device, dtype=torch.int32).unsqueeze(0) mask = mask >= block_usage.unsqueeze(-1) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) @@ -4793,25 +4823,39 @@ def _set_block_mapping(self, oob_values = block_groups.lt(0) block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) block_groups.masked_fill_(oob_values, batch_size) + if src_device != dst_device: + block_groups = block_groups.to(dst_device, non_blocking=True) if is_window_block: - metadata = custom_tuple_replace(metadata, "TrimmedAttentionMetadata", window_block_groups=block_groups) + metadata = metadata_update_with_trim(metadata, + "TrimmedAttentionMetadata", + trim=trim, + window_block_groups=block_groups) else: - metadata = custom_tuple_replace(metadata, "TrimmedAttentionMetadata", block_groups=block_groups) + metadata = metadata_update_with_trim(metadata, + "TrimmedAttentionMetadata", + trim=trim, + block_groups=block_groups) block_mapping = block_mapping.to(dtype) + if src_device != dst_device: + block_mapping = block_mapping.to(dst_device, non_blocking=True) + attn_bias = attn_bias.to(dst_device, non_blocking=True) if is_window_block: - metadata = custom_tuple_replace(metadata, - "TrimmedAttentionMetadata", - window_block_mapping=block_mapping, - window_attn_bias=attn_bias) + metadata = metadata_update_with_trim(metadata, + "TrimmedAttentionMetadata", + trim=trim, + window_block_mapping=block_mapping, + window_attn_bias=attn_bias) else: - metadata = custom_tuple_replace(metadata, - "TrimmedAttentionMetadata", - block_mapping=block_mapping, - attn_bias=attn_bias) + metadata = metadata_update_with_trim(metadata, + "TrimmedAttentionMetadata", + trim=trim, + block_mapping=block_mapping, + attn_bias=attn_bias) return metadata def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int, - device: torch.device, dtype: torch.dtype) -> HPUAttentionMetadataV1: + src_device: torch.device, dst_device: torch.device, dtype: torch.dtype, + trim: bool) -> HPUAttentionMetadataV1: """ Post-process attention metadata with appropriate masks and mappings. 
@@ -4829,19 +4873,24 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in Returns: Post-processed attention metadata with additional tensors """ + if self.unified_attn: + return attn_metadata + if attn_metadata.is_prompt: - attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype) + attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, src_device, dst_device, dtype, trim) if self.interleaved_sliding_window and self.sliding_window is not None: attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len, - self.sliding_window, device, dtype) + self.sliding_window, src_device, dst_device, + dtype, trim) else: - attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype) + attn_metadata = self._set_block_mapping(attn_metadata, batch_size, src_device, dst_device, dtype, trim) if self.interleaved_sliding_window: - attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True) + attn_metadata = self._set_block_mapping(attn_metadata, batch_size, src_device, dst_device, dtype, True, + trim) return attn_metadata def process_metadata_dict(self, attn_metadata: dict, batch_size: int, seq_len: int, device: torch.device, - dtype: torch.dtype) -> dict: + dtype: torch.dtype, trim: bool) -> dict: """ Post-process a dictionary of attention metadata (for multi-layer models). @@ -4861,7 +4910,7 @@ def process_metadata_dict(self, attn_metadata: dict, batch_size: int, seq_len: i from vllm_gaudi.extension.logger import logger first_attn_metadata = next(iter(attn_metadata.values())) - updated_attn_metadata = self.process_metadata(first_attn_metadata, batch_size, seq_len, device, dtype) + updated_attn_metadata = self.process_metadata(first_attn_metadata, batch_size, seq_len, device, dtype, trim) for key in attn_metadata: if attn_metadata[key] is first_attn_metadata: @@ -4869,5 +4918,5 @@ def process_metadata_dict(self, attn_metadata: dict, batch_size: int, seq_len: i else: msg = f"Different attn_metadata encountered on layer {key}. Processing it individually." logger.warning(msg) - attn_metadata[key] = self.process_metadata(attn_metadata[key], batch_size, seq_len, device, dtype) + attn_metadata[key] = self.process_metadata(attn_metadata[key], batch_size, seq_len, device, dtype, trim) return attn_metadata From 00c22bca427a3f9e8f3793254cc28bb914ebe607 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 17:32:49 +0200 Subject: [PATCH 07/12] fix precommit Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_model_runner.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index c873dd564..71f6e76a8 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -4614,13 +4614,6 @@ def __init__( ): """ Initialize the attention metadata processor. 
- - Args: - block_size: Size of KV cache blocks - dtype: Data type for attention operations - prefill_use_fusedsdpa: Whether to use fused SDPA for prefill - sliding_window: Sliding window size (None if not using sliding window) - interleaved_sliding_window: Whether to use interleaved sliding window """ self.prefill_use_fusedsdpa = get_config().prompt_attn_impl == 'fsdpa_impl' self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE', 'false').lower() in ['1', 'true'] @@ -4663,7 +4656,9 @@ def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, prefill_metadata = attn_metadata seq_lens_t = prefill_metadata.seq_lens_tensor + assert seq_lens_t is not None, "seq_lens_tensor is required to build attn_bias" context_lens_t = prefill_metadata.context_lens_tensor + assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias" block_list = attn_metadata.block_list max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) @@ -4714,6 +4709,7 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV if self.prefill_use_fusedsdpa and attn_metadata.block_list is not None: context_lens_t = prefill_metadata.context_lens_tensor + assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias" block_list = attn_metadata.block_list max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) From 1a5a3e11ecdec321a95cfbff2e340774ee7d1991 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 17:37:14 +0200 Subject: [PATCH 08/12] make copilot happy Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_model_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 71f6e76a8..1160b37b2 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -4616,11 +4616,9 @@ def __init__( Initialize the attention metadata processor. 
""" self.prefill_use_fusedsdpa = get_config().prompt_attn_impl == 'fsdpa_impl' - self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE', 'false').lower() in ['1', 'true'] self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size self.dtype = vllm_config.model_config.dtype - self.flatten_input = get_config().flatten_input self.sliding_window = vllm_config.model_config.get_sliding_window() self.interleaved_sliding_window = is_interleaved(vllm_config.model_config.hf_text_config) @@ -4827,7 +4825,7 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in """ if attn_metadata.is_prompt: attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype) - if self.interleaved_sliding_window and self.sliding_window is not None: + if self.interleaved_sliding_window: attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len, self.sliding_window, device, dtype) else: From 09d71303dbd3f78c8cd4dc1388ef0c2db9f13e12 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 17:44:01 +0200 Subject: [PATCH 09/12] make copilot happy Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_model_runner.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 3dd977c9a..4192fa084 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1699,7 +1699,7 @@ def _form_prefill_batch(self, contents): attn_metadata = self.metadata_processor.process_metadata(attn_metadata, token_ids.size(0), token_ids.size(1), - 'cpu', + torch.device('cpu'), token_ids.device, self.dtype, trim=False) @@ -2009,7 +2009,7 @@ def _create_decode_input_data(self, attn_metadata = self.metadata_processor.process_metadata(attn_metadata, token_ids_device.size(0), token_ids_device.size(1), - 'cpu', + torch.device('cpu'), token_ids.device, self.dtype, trim=False) @@ -4884,10 +4884,9 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in self.sliding_window, src_device, dst_device, dtype, trim) else: - attn_metadata = self._set_block_mapping(attn_metadata, batch_size, src_device, dst_device, dtype, trim) - if self.interleaved_sliding_window: - attn_metadata = self._set_block_mapping(attn_metadata, batch_size, src_device, dst_device, dtype, True, - trim) + attn_metadata = self._set_block_mapping(attn_metadata, batch_size, src_device, dst_device, dtype, + self.interleaved_sliding_window, trim) + return attn_metadata def process_metadata_dict(self, attn_metadata: dict, batch_size: int, seq_len: int, device: torch.device, From e35abfb243e031de44510fe5ecd3604aba804f7e Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 17:47:16 +0200 Subject: [PATCH 10/12] make copilot happy again Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_model_runner.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 4192fa084..6de9d20c2 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -4666,7 +4666,8 @@ def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, attn_metadata: Input attention metadata batch_size: Batch size seq_len: Sequence length - device: Device to create tensors on + src_device: Device to create tensors on + dst_device: Device 
to move tensors to dtype: Data type for the bias tensor Returns: @@ -4719,7 +4720,8 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV batch_size: Batch size seq_len: Sequence length window_size: Sliding window size - device: Device to create tensors on + src_device: Device to create tensors on + dst_device: Device to move tensors to dtype: Data type for the bias tensor Returns: @@ -4889,8 +4891,8 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in return attn_metadata - def process_metadata_dict(self, attn_metadata: dict, batch_size: int, seq_len: int, device: torch.device, - dtype: torch.dtype, trim: bool) -> dict: + def process_metadata_dict(self, attn_metadata: dict, batch_size: int, seq_len: int, src_device: torch.device, + dst_device: torch.device, dtype: torch.dtype, trim: bool) -> dict: """ Post-process a dictionary of attention metadata (for multi-layer models). @@ -4901,7 +4903,8 @@ def process_metadata_dict(self, attn_metadata: dict, batch_size: int, seq_len: i attn_metadata: Dictionary mapping layer names to attention metadata batch_size: Batch size seq_len: Sequence length (for prompt phase) - device: Device to create tensors on + src_device: Device to create tensors on + dst_device: Device to move tensors to dtype: Data type for tensors Returns: @@ -4910,7 +4913,8 @@ def process_metadata_dict(self, attn_metadata: dict, batch_size: int, seq_len: i from vllm_gaudi.extension.logger import logger first_attn_metadata = next(iter(attn_metadata.values())) - updated_attn_metadata = self.process_metadata(first_attn_metadata, batch_size, seq_len, device, dtype, trim) + updated_attn_metadata = self.process_metadata(first_attn_metadata, batch_size, seq_len, src_device, dst_device, + dtype, trim) for key in attn_metadata: if attn_metadata[key] is first_attn_metadata: @@ -4918,5 +4922,6 @@ def process_metadata_dict(self, attn_metadata: dict, batch_size: int, seq_len: i else: msg = f"Different attn_metadata encountered on layer {key}. Processing it individually." 
logger.warning(msg) - attn_metadata[key] = self.process_metadata(attn_metadata[key], batch_size, seq_len, device, dtype, trim) + attn_metadata[key] = self.process_metadata(attn_metadata[key], batch_size, seq_len, src_device, + dst_device, dtype, trim) return attn_metadata From 36aef30fda137d98fc79b52f33cba87a2a2f8d23 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 17:52:55 +0200 Subject: [PATCH 11/12] make copilot happy again Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_model_runner.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 6de9d20c2..4599edb9d 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -2010,7 +2010,7 @@ def _create_decode_input_data(self, token_ids_device.size(0), token_ids_device.size(1), torch.device('cpu'), - token_ids.device, + token_ids_device.device, self.dtype, trim=False) @@ -4886,8 +4886,13 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in self.sliding_window, src_device, dst_device, dtype, trim) else: - attn_metadata = self._set_block_mapping(attn_metadata, batch_size, src_device, dst_device, dtype, - self.interleaved_sliding_window, trim) + attn_metadata = self._set_block_mapping(attn_metadata, batch_size, src_device, dst_device, dtype, False, + trim) + # NOTE(kzawora): I'm not sure why we set block mapping twice for sliding window + # - we should check if that can be reduced to a single call. + if self.interleaved_sliding_window: + attn_metadata = self._set_block_mapping(attn_metadata, batch_size, src_device, dst_device, dtype, True, + trim) return attn_metadata From 138ecaf7c3420917aa7efae4c8ceffe511a62bd1 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 5 Nov 2025 17:56:43 +0200 Subject: [PATCH 12/12] GODDAMN COMMA Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 4599edb9d..839e772e4 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -2004,7 +2004,7 @@ def _create_decode_input_data(self, window_block_list=window_block_list_device, window_block_usage=window_block_usage_device, window_block_groups=window_block_groups_device, - ), + ) attn_metadata = self.metadata_processor.process_metadata(attn_metadata, token_ids_device.size(0),
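Taken together, patches 04 through 12 move mask and block-mapping construction out of HpuModelAdapter.forward and into an HPUAttentionMetadataProcessor owned by the model runner, so attention biases can be materialized on the CPU and copied to the device alongside the rest of the batch. The snippet below is a hedged usage sketch that mirrors the prefill-path call in the final diff; the surrounding runner code is condensed and illustrative rather than a verbatim excerpt.

    # Once, next to the other model-runner state:
    metadata_processor = HPUAttentionMetadataProcessor(vllm_config)

    # Per prefill batch: masks are built on src_device and the finished tensors
    # are copied non-blockingly to dst_device when the two differ.
    attn_metadata = metadata_processor.process_metadata(
        attn_metadata,
        token_ids.size(0),       # batch_size
        token_ids.size(1),       # seq_len
        torch.device('cpu'),     # src_device
        token_ids.device,        # dst_device
        dtype,
        trim=False)              # untrimmed metadata: fields updated in place via setattr

With trim=False the processor mutates the untrimmed metadata object directly (see metadata_update_with_trim), while trim=True keeps the previous custom_tuple_replace behavior for trimmed metadata; process_metadata_dict applies the same processing once per distinct metadata instance for multi-layer models.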