diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index d93aae129..839e772e4 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -333,7 +333,6 @@ class HpuModelAdapter(torch.nn.Module, KVConnectorModelRunnerMixin): def __init__(self, model, vllm_config): super().__init__() self.model = model - self.prefill_use_fusedsdpa = get_config().prompt_attn_impl == 'fsdpa_impl' self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE', 'false').lower() in ['1', 'true'] self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size @@ -344,12 +343,6 @@ def __init__(self, model, vllm_config): self.unified_attn_persistent_ctx = None self.flatten_input = get_config().flatten_input self.is_mm_optimized = is_mm_optimized(self.model) - self.sliding_window = vllm_config.model_config.get_sliding_window() - self.interleaved_sliding_window = is_interleaved(vllm_config.model_config.hf_text_config) - if self.interleaved_sliding_window: - self.use_window_sdpa = os.getenv("PT_HPU_SDPA_QKV_SLICE_MODE_FWD", "false").strip().lower() in ("1", "true") - self.slice_size = int(os.getenv("PT_HPU_SDPA_BC_FACTOR", "1024")) - self.slice_thld = int(os.environ.get('VLLM_FUSEDSDPA_SLIDE_THLD', '8192')) # for DP self.dummy_num_input_tokens = -1 @@ -400,141 +393,6 @@ def _reset_rotary_cos_sin(self): if hasattr(self._rotary_embed_module, "sin"): delattr(self._rotary_embed_module, "sin") - def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): - if (attn_metadata is None or (self.prefill_use_fusedsdpa and attn_metadata.block_list is None) - or not attn_metadata.is_prompt): - return attn_metadata - - if attn_metadata.attn_bias is not None: - return attn_metadata - - prefill_metadata = attn_metadata - - seq_lens_t = prefill_metadata.seq_lens_tensor - context_lens_t = prefill_metadata.context_lens_tensor - - block_list = attn_metadata.block_list - max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) - max_context_len = max_context_len * self.block_size - past_mask = torch.arange(0, max_context_len, dtype=torch.int32, device=device) - past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(context_lens_t.view(-1, 1)).view( - batch_size, 1, -1).expand(batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1)) - - len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32).view(1, seq_len).ge( - seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len)) - causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool), - diagonal=1) - mask = causal_mask.logical_or(len_mask) - mask = torch.concat((past_mask, mask), dim=-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) - attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", attn_bias=attn_bias) - return attn_metadata - - def _set_attn_bias_for_sliding_window(self, attn_metadata, batch_size, seq_len, window_size, device, dtype): - - if (attn_metadata is None or not attn_metadata.is_prompt): - return attn_metadata - - prefill_metadata = attn_metadata - shift = 0 - - # FusedSDPA with window_size is only supported when the seq_len is multiple of the slice_size - if self.prefill_use_fusedsdpa and self.use_window_sdpa and \ - seq_len >= self.slice_thld and self.slice_size != 0 and \ - seq_len % self.slice_size == 0 and attn_metadata.block_list is None: - # no need to set sliding window mask, just use 
built-in window-sdpa - return attn_metadata - - if self.prefill_use_fusedsdpa and attn_metadata.block_list is not None: - context_lens_t = prefill_metadata.context_lens_tensor - - block_list = attn_metadata.block_list - max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) - max_context_len = max_context_len * self.block_size - - invalid_lens_t = context_lens_t - window_size + torch.arange(seq_len, device=device) - 1 - past_indices = torch.arange(max_context_len, device=device) - past_mask = ((past_indices.unsqueeze(0) > invalid_lens_t.unsqueeze(-1)) & - (past_indices.unsqueeze(0) < context_lens_t.unsqueeze(-1).unsqueeze(0))).unsqueeze(1) - - # Create boolean sliding window mask - causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift) - causal_mask = torch.triu(causal_mask, diagonal=shift - window_size + 1) - causal_mask = causal_mask.view(batch_size, 1, seq_len, seq_len) - - # TODO: Investigate further - Removing Padding cause accuracy issue - # seq_lens_t = prefill_metadata.seq_lens_tensor - # len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32).view(1, seq_len).lt( - # seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len)) - # causal_mask = causal_mask.logical_and(len_mask) - - mask = torch.concat((past_mask, causal_mask), dim=-1) - attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device), - torch.tensor(float('-inf'), dtype=dtype, device=device)) - else: - # CAUSAL MASK without removing padding (CAUSAL+sliding window) - # removing padding cause accuracy issue for images input - tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1) - mask = torch.tril(tensor, diagonal=shift) - mask = torch.triu(mask, diagonal=shift - window_size + 1) - attn_bias = torch.log(mask) - - attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias) - return attn_metadata - - def _set_block_mapping(self, metadata, batch_size, device, dtype, is_window_block=False): - if is_window_block: - block_usage = metadata.window_block_usage - block_groups = metadata.window_block_groups - else: - block_usage = metadata.block_usage - block_groups = metadata.block_groups - - mask = torch.arange(0, self.block_size, device=device, dtype=torch.int32).unsqueeze(0) - mask = mask >= block_usage.unsqueeze(-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)) - - if not is_fake_hpu(): - block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size) - else: - # Unfortunately one_hot on CPU - # doesn't handle out of bounds classes so we need to convert - # all negative values to 0 (block_mapping) or bs (block_groups) - block_groups = block_groups.to(torch.long) - block_mapping = torch.nn.functional.relu(block_groups) - block_mapping = torch.nn.functional.one_hot(block_mapping, num_classes=batch_size) - oob_values = block_groups.lt(0) - block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) - block_groups.masked_fill_(oob_values, batch_size) - if is_window_block: - metadata = custom_tuple_replace(metadata, "TrimmedAttentionMetadata", window_block_groups=block_groups) - else: - metadata = custom_tuple_replace(metadata, "TrimmedAttentionMetadata", block_groups=block_groups) - block_mapping = block_mapping.to(dtype) - if is_window_block: - metadata = custom_tuple_replace(metadata, - "TrimmedAttentionMetadata", - window_block_mapping=block_mapping, - window_attn_bias=attn_bias) - else: - metadata = 
custom_tuple_replace(metadata, - "TrimmedAttentionMetadata", - block_mapping=block_mapping, - attn_bias=attn_bias) - return metadata - - def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): - if attn_metadata.is_prompt: - attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype) - if self.interleaved_sliding_window: - attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len, - self.sliding_window, device, dtype) - else: - attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype) - if self.interleaved_sliding_window: - attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True) - return attn_metadata - def forward(self, *args, **kwargs): # TODO(kzawora): something goes VERY WRONG when operating on # kwargs['attn_metadata'].slot_mapping, compared to untrimmed metadata @@ -547,9 +405,6 @@ def forward(self, *args, **kwargs): if 'warmup_mode' in kwargs: kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] - if not self.unified_attn: - kwargs['attn_metadata'] = self._update_metadata(kwargs['attn_metadata'], input_ids.size(0), - input_ids.size(1), input_ids.device, self.dtype) if self._rotary_prepare_cos_sin is not None: self._rotary_prepare_cos_sin(kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin) attn_meta = kwargs.pop('attn_metadata') @@ -622,6 +477,16 @@ def subtuple(obj: object, typename: str, to_copy: list[str], to_override: Option return _TYPE_CACHE[typename]['type'](**values) # type: ignore +def metadata_update_with_trim(obj: object, typename: str, trim: bool, **to_override): + if trim: + return custom_tuple_replace(obj, typename, **to_override) + + for key in to_override: + assert hasattr(obj, key), f"Field {key} must exist in untrimmed metadata." + setattr(obj, key, to_override[key]) + return obj + + def custom_tuple_replace(obj: object, typename: str, **to_override): # Torch compile dynamo doesn't support calling any named tuple # dynamic methods other than len and get_attr. This function is @@ -914,6 +779,8 @@ def __init__( assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!' assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!' 
+ self.metadata_processor = HPUAttentionMetadataProcessor(vllm_config) + def _make_buffer(self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True) -> CpuGpuBuffer: return CpuGpuBuffer(*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory, with_numpy=numpy) @@ -1828,6 +1695,15 @@ def _form_prefill_batch(self, contents): block_list=context_blocks_t, attn_bias=attn_bias, block_size=self.block_size) + + attn_metadata = self.metadata_processor.process_metadata(attn_metadata, + token_ids.size(0), + token_ids.size(1), + torch.device('cpu'), + token_ids.device, + self.dtype, + trim=False) + return PrefillInputData(request_ids=[req_ids], prompt_lens=[query_lens], token_ids=[token_ids], @@ -2118,21 +1994,31 @@ def _create_decode_input_data(self, spec_decode_metadata = None logits_indices_device = async_h2d_copy(logits_indices, device=self.device) + attn_metadata = HPUAttentionMetadataV1.make_decode_metadata( + block_list=block_list_device, + block_usage=block_usage_device, + block_groups=block_groups_device, + input_positions=None, + slot_mapping=slot_mapping_device, + block_size=self.block_size, + window_block_list=window_block_list_device, + window_block_usage=window_block_usage_device, + window_block_groups=window_block_groups_device, + ) + + attn_metadata = self.metadata_processor.process_metadata(attn_metadata, + token_ids_device.size(0), + token_ids_device.size(1), + torch.device('cpu'), + token_ids_device.device, + self.dtype, + trim=False) + return DecodeInputData(num_decodes=num_decodes, token_ids=token_ids_device, position_ids=positions_device, logits_indices=logits_indices_device, - attn_metadata=HPUAttentionMetadataV1.make_decode_metadata( - block_list=block_list_device, - block_usage=block_usage_device, - block_groups=block_groups_device, - input_positions=None, - slot_mapping=slot_mapping_device, - block_size=self.block_size, - window_block_list=window_block_list_device, - window_block_usage=window_block_usage_device, - window_block_groups=window_block_groups_device, - ), + attn_metadata=attn_metadata, spec_decode_metadata=spec_decode_metadata) def _prepare_decode_inputs(self, @@ -2469,7 +2355,7 @@ def _execute_model_generic(self, else: # no hpu graphs for t.compile? use_graphs = False - trimmed_attn_metadata = attn_metadata if self.unified_attn else trim_attn_metadata(attn_metadata) + if self.is_driver_worker: model_event_name = ("model_forward_" f"bs{batch_size}_" @@ -2478,6 +2364,7 @@ def _execute_model_generic(self, f"graphs{'T' if use_graphs else 'F'}") else: model_event_name = 'model_executable' + trimmed_attn_metadata = trim_attn_metadata(attn_metadata) with self.profiler.record_event('internal', model_event_name): hidden_states = self.model.forward(input_ids=token_ids, positions=position_ids, @@ -4735,3 +4622,311 @@ def device(self): def dtype(self): """Returns the torch.dtype of the tensors within the tuple.""" return self._dtype + + +class HPUAttentionMetadataProcessor: + """ + Processor class for post-processing HPU attention metadata. + + This class takes already-built attention metadata and augments it with + additional tensors such as attention bias masks and block mappings that + are required for efficient attention computation on HPU. It does NOT build + the metadata from scratch - it post-processes existing metadata structures. + """ + + def __init__( + self, + vllm_config: VllmConfig, + ): + """ + Initialize the attention metadata processor. 
+        """
+        self.prefill_use_fusedsdpa = get_config().prompt_attn_impl == 'fsdpa_impl'
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+        self.dtype = vllm_config.model_config.dtype
+        self.sliding_window = vllm_config.model_config.get_sliding_window()
+        self.interleaved_sliding_window = is_interleaved(vllm_config.model_config.hf_text_config)
+        self.unified_attn = get_config().unified_attn
+
+        if self.interleaved_sliding_window:
+            self.use_window_sdpa = os.getenv("PT_HPU_SDPA_QKV_SLICE_MODE_FWD",
+                                             "false").strip().lower() in ("1", "true")
+            self.slice_size = int(os.getenv("PT_HPU_SDPA_BC_FACTOR", "1024"))
+            self.slice_thld = int(os.environ.get('VLLM_FUSEDSDPA_SLIDE_THLD', '8192'))
+
+    def _set_attn_bias(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int,
+                       src_device: torch.device, dst_device: torch.device, dtype: torch.dtype,
+                       trim: bool) -> HPUAttentionMetadataV1:
+        """
+        Set attention bias for the prompt phase.
+
+        Creates causal attention masks with proper handling of padding and context.
+
+        Args:
+            attn_metadata: Input attention metadata
+            batch_size: Batch size
+            seq_len: Sequence length
+            src_device: Device to create tensors on
+            dst_device: Device to move tensors to
+            dtype: Data type for the bias tensor
+            trim: Whether to update via custom_tuple_replace (trimmed metadata) or in place
+
+        Returns:
+            Updated attention metadata with attn_bias set
+        """
+        if (attn_metadata is None or (self.prefill_use_fusedsdpa and attn_metadata.block_list is None)
+                or not attn_metadata.is_prompt):
+            return attn_metadata
+
+        if attn_metadata.attn_bias is not None:
+            return attn_metadata
+
+        prefill_metadata = attn_metadata
+
+        seq_lens_t = prefill_metadata.seq_lens_tensor
+        assert seq_lens_t is not None, "seq_lens_tensor is required to build attn_bias"
+        context_lens_t = prefill_metadata.context_lens_tensor
+        assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
+
+        block_list = attn_metadata.block_list
+        max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0)
+        max_context_len = max_context_len * self.block_size
+        past_mask = torch.arange(0, max_context_len, dtype=torch.int32, device=src_device)
+        past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(context_lens_t.view(-1, 1)).view(
+            batch_size, 1, -1).expand(batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1))
+
+        len_mask = (torch.arange(0, seq_len, device=src_device, dtype=torch.int32).view(1, seq_len).ge(
+            seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len))
+        causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), device=src_device, dtype=torch.bool),
+                                 diagonal=1)
+        mask = causal_mask.logical_or(len_mask)
+        mask = torch.concat((past_mask, mask), dim=-1)
+        attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf))
+        if src_device != dst_device:
+            attn_bias = attn_bias.to(dst_device, non_blocking=True)
+        attn_metadata = metadata_update_with_trim(prefill_metadata,
+                                                  "TrimmedAttentionMetadata",
+                                                  trim=trim,
+                                                  attn_bias=attn_bias)
+        return attn_metadata
+
+    def _set_attn_bias_for_sliding_window(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int,
+                                          window_size: int, src_device: torch.device, dst_device: torch.device,
+                                          dtype: torch.dtype, trim: bool) -> HPUAttentionMetadataV1:
+        """
+        Set attention bias for sliding window attention in the prompt phase.
+
+        Args:
+            attn_metadata: Input attention metadata
+            batch_size: Batch size
+            seq_len: Sequence length
+            window_size: Sliding window size
+            src_device: Device to create tensors on
+            dst_device: Device to move tensors to
+            dtype: Data type for the bias tensor
+            trim: Whether to update via custom_tuple_replace (trimmed metadata) or in place
+
+        Returns:
+            Updated attention metadata with window_attn_bias set
+        """
+        if (attn_metadata is None or not attn_metadata.is_prompt):
+            return attn_metadata
+
+        prefill_metadata = attn_metadata
+        shift = 0
+
+        # FusedSDPA with window_size is only supported when seq_len is a multiple of the slice_size
+        if self.prefill_use_fusedsdpa and self.use_window_sdpa and \
+           seq_len >= self.slice_thld and self.slice_size != 0 and \
+           seq_len % self.slice_size == 0 and attn_metadata.block_list is None:
+            # no need to set a sliding window mask, just use the built-in window-sdpa
+            return attn_metadata
+
+        if self.prefill_use_fusedsdpa and attn_metadata.block_list is not None:
+            context_lens_t = prefill_metadata.context_lens_tensor
+            assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
+
+            block_list = attn_metadata.block_list
+            max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0)
+            max_context_len = max_context_len * self.block_size
+
+            invalid_lens_t = context_lens_t - window_size + torch.arange(seq_len, device=src_device) - 1
+            past_indices = torch.arange(max_context_len, device=src_device)
+            past_mask = ((past_indices.unsqueeze(0) > invalid_lens_t.unsqueeze(-1)) &
+                         (past_indices.unsqueeze(0) < context_lens_t.unsqueeze(-1).unsqueeze(0))).unsqueeze(1)
+
+            # Create boolean sliding window mask
+            causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=src_device), diagonal=shift)
+            causal_mask = torch.triu(causal_mask, diagonal=shift - window_size + 1)
+            causal_mask = causal_mask.view(batch_size, 1, seq_len, seq_len)
+
+            # TODO: Investigate further - removing padding causes an accuracy issue
+            # seq_lens_t = prefill_metadata.seq_lens_tensor
+            # len_mask = (torch.arange(0, seq_len, device=src_device, dtype=torch.int32).view(1, seq_len).lt(
+            #     seq_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len))
+            # causal_mask = causal_mask.logical_and(len_mask)
+
+            mask = torch.concat((past_mask, causal_mask), dim=-1)
+            attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=src_device),
+                                    torch.tensor(float('-inf'), dtype=dtype, device=src_device))
+        else:
+            # CAUSAL MASK without removing padding (CAUSAL + sliding window)
+            # removing padding causes accuracy issues for image inputs
+            tensor = torch.full((batch_size, 1, seq_len, seq_len), device=src_device, dtype=dtype, fill_value=1)
+            mask = torch.tril(tensor, diagonal=shift)
+            mask = torch.triu(mask, diagonal=shift - window_size + 1)
+            attn_bias = torch.log(mask)
+
+        if src_device != dst_device:
+            attn_bias = attn_bias.to(dst_device, non_blocking=True)
+        attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias)
+        return attn_metadata
+
+    def _set_block_mapping(self,
+                           metadata: HPUAttentionMetadataV1,
+                           batch_size: int,
+                           src_device: torch.device,
+                           dst_device: torch.device,
+                           dtype: torch.dtype,
+                           is_window_block: bool = False,
+                           trim: bool = False) -> HPUAttentionMetadataV1:
+        """
+        Set block mapping for the decode phase.
+
+        Creates block mapping and attention bias for paged attention during decode.
+
+        Args:
+            metadata: Input attention metadata
+            batch_size: Batch size
+            src_device: Device to create tensors on
+            dst_device: Device to move tensors to
+            dtype: Data type for tensors
+            is_window_block: Whether this is for window blocks
+            trim: Whether to update via custom_tuple_replace (trimmed metadata) or in place
+
+        Returns:
+            Updated attention metadata with block_mapping and attn_bias set
+        """
+        if is_window_block:
+            block_usage = metadata.window_block_usage
+            block_groups = metadata.window_block_groups
+        else:
+            block_usage = metadata.block_usage
+            block_groups = metadata.block_groups
+
+        mask = torch.arange(0, self.block_size, device=src_device, dtype=torch.int32).unsqueeze(0)
+        mask = mask >= block_usage.unsqueeze(-1)
+        attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf))
+
+        if not is_fake_hpu():
+            block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)
+        else:
+            # Unfortunately, one_hot on CPU doesn't handle out-of-bounds classes,
+            # so we need to convert all negative values to 0 (block_mapping)
+            # or bs (block_groups).
+            block_groups = block_groups.to(torch.long)
+            block_mapping = torch.nn.functional.relu(block_groups)
+            block_mapping = torch.nn.functional.one_hot(block_mapping, num_classes=batch_size)
+            oob_values = block_groups.lt(0)
+            block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
+            block_groups.masked_fill_(oob_values, batch_size)
+        if src_device != dst_device:
+            block_groups = block_groups.to(dst_device, non_blocking=True)
+        if is_window_block:
+            metadata = metadata_update_with_trim(metadata,
+                                                 "TrimmedAttentionMetadata",
+                                                 trim=trim,
+                                                 window_block_groups=block_groups)
+        else:
+            metadata = metadata_update_with_trim(metadata,
+                                                 "TrimmedAttentionMetadata",
+                                                 trim=trim,
+                                                 block_groups=block_groups)
+        block_mapping = block_mapping.to(dtype)
+        if src_device != dst_device:
+            block_mapping = block_mapping.to(dst_device, non_blocking=True)
+            attn_bias = attn_bias.to(dst_device, non_blocking=True)
+        if is_window_block:
+            metadata = metadata_update_with_trim(metadata,
+                                                 "TrimmedAttentionMetadata",
+                                                 trim=trim,
+                                                 window_block_mapping=block_mapping,
+                                                 window_attn_bias=attn_bias)
+        else:
+            metadata = metadata_update_with_trim(metadata,
+                                                 "TrimmedAttentionMetadata",
+                                                 trim=trim,
+                                                 block_mapping=block_mapping,
+                                                 attn_bias=attn_bias)
+        return metadata
+
+    def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: int, seq_len: int,
+                         src_device: torch.device, dst_device: torch.device, dtype: torch.dtype,
+                         trim: bool) -> HPUAttentionMetadataV1:
+        """
+        Post-process attention metadata with appropriate masks and mappings.
+
+        This is the main entry point for processing attention metadata. It augments
+        the metadata with attention bias masks (for the prompt phase) or block mappings
+        (for the decode phase), with support for sliding window attention.
+
+        Args:
+            attn_metadata: Input attention metadata (already built)
+            batch_size: Batch size
+            seq_len: Sequence length (for prompt phase)
+            src_device: Device to create tensors on
+            dst_device: Device to move tensors to
+            dtype: Data type for tensors
+            trim: Whether to update via custom_tuple_replace (trimmed metadata) or in place
+
+        Returns:
+            Post-processed attention metadata with additional tensors
+        """
+        if self.unified_attn:
+            return attn_metadata
+
+        if attn_metadata.is_prompt:
+            attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, src_device, dst_device, dtype,
+                                                trim)
+            if self.interleaved_sliding_window:
+                attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len,
+                                                                       self.sliding_window, src_device, dst_device,
+                                                                       dtype, trim)
+        else:
+            attn_metadata = self._set_block_mapping(attn_metadata, batch_size, src_device, dst_device, dtype, False,
+                                                    trim)
+            # NOTE(kzawora): I'm not sure why we set block mapping twice for sliding window
+            # - we should check if that can be reduced to a single call.
+            if self.interleaved_sliding_window:
+                attn_metadata = self._set_block_mapping(attn_metadata, batch_size, src_device, dst_device, dtype, True,
+                                                        trim)
+
+        return attn_metadata
+
+    def process_metadata_dict(self, attn_metadata: dict, batch_size: int, seq_len: int, src_device: torch.device,
+                              dst_device: torch.device, dtype: torch.dtype, trim: bool) -> dict:
+        """
+        Post-process a dictionary of attention metadata (for multi-layer models).
+
+        This method avoids redundant work by checking whether all metadata objects in the
+        dictionary are the same instance; if so, the metadata is processed only once.
+
+        Args:
+            attn_metadata: Dictionary mapping layer names to attention metadata
+            batch_size: Batch size
+            seq_len: Sequence length (for prompt phase)
+            src_device: Device to create tensors on
+            dst_device: Device to move tensors to
+            dtype: Data type for tensors
+            trim: Whether to update via custom_tuple_replace (trimmed metadata) or in place
+
+        Returns:
+            Dictionary with post-processed attention metadata
+        """
+        from vllm_gaudi.extension.logger import logger
+
+        first_attn_metadata = next(iter(attn_metadata.values()))
+        updated_attn_metadata = self.process_metadata(first_attn_metadata, batch_size, seq_len, src_device, dst_device,
+                                                      dtype, trim)
+
+        for key in attn_metadata:
+            if attn_metadata[key] is first_attn_metadata:
+                attn_metadata[key] = updated_attn_metadata
+            else:
+                msg = f"Different attn_metadata encountered on layer {key}. Processing it individually."
+                logger.warning(msg)
+                attn_metadata[key] = self.process_metadata(attn_metadata[key], batch_size, seq_len, src_device,
+                                                           dst_device, dtype, trim)
+        return attn_metadata
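
Usage sketch (editor's note, not part of the patch): the call pattern introduced above in
_form_prefill_batch and _create_decode_input_data is to build the HPU attention metadata first
and then post-process it once per batch. The names `vllm_config`, `token_ids`, `attn_metadata`,
and `runner` below are placeholders standing in for the model runner's own attributes; the
snippet only illustrates the new API surface, assuming those objects already exist.

    import torch

    # Assumed to exist already: a populated vllm_config, a 2-D token_ids tensor, and an
    # HPUAttentionMetadataV1 instance built by the runner (e.g. via make_decode_metadata).
    processor = HPUAttentionMetadataProcessor(vllm_config)

    attn_metadata = processor.process_metadata(
        attn_metadata,
        token_ids.size(0),      # batch_size
        token_ids.size(1),      # seq_len (padded prompt length for prefill)
        torch.device('cpu'),    # src_device: masks and mappings are materialized here first
        token_ids.device,       # dst_device: then copied (non_blocking) to the HPU
        runner.dtype,
        trim=False)             # trim=False mutates the untrimmed metadata in place;
                                # trim=True routes updates through custom_tuple_replace.

With unified attention enabled, process_metadata is a no-op and returns the metadata unchanged,
which matches the removal of the _update_metadata call from HpuModelAdapter.forward.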