[Attention Metadata Overhaul 2/N] Move metadata processing outside HPUModelAdapter, prepare biases on CPU #530
base: main
Conversation
Signed-off-by: Konrad Zawora <kzawora@habana.ai>
Pull Request Overview
This PR moves HPU attention metadata processing from the HpuModelAdapter into a dedicated HPUAttentionMetadataProcessor class, allowing metadata biases to be computed on CPU and copied asynchronously to HPU. This refactoring removes metadata processing logic from the model forward path and handles it at input preparation time instead.
Key Changes:
- Extracted metadata processing into a standalone HPUAttentionMetadataProcessor class
- Moved metadata processing to occur during input preparation (prefill/decode batch formation) rather than in the model forward pass
- Added support for processing metadata on CPU with async copy to HPU device
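The CPU-side bias preparation with an asynchronous copy to the device can be sketched as follows. This is a minimal illustration, not the PR's actual implementation: the helper names and the causal-bias layout are assumptions, and the async-copy path simply falls back to a no-op on a CPU-only device.

```python
import torch


def build_causal_bias_on_cpu(seq_lens: list[int], max_len: int) -> torch.Tensor:
    """Build a padded, per-sequence causal attention bias on CPU (hypothetical helper).

    Positions outside the valid prefix and future positions get -inf;
    valid causal positions get 0.
    """
    bias = torch.full((len(seq_lens), max_len, max_len), float("-inf"))
    for i, n in enumerate(seq_lens):
        # Lower-triangular (causal) mask over the valid prefix of length n.
        bias[i, :n, :n] = torch.triu(
            torch.full((n, n), float("-inf")), diagonal=1)
    return bias


def copy_async(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Copy a host tensor to the device without blocking the host (sketch)."""
    if device.type == "cpu":
        return t
    # Pinning host memory lets the host-to-device copy overlap with compute;
    # non_blocking=True makes the copy asynchronous on CUDA/HPU-like devices.
    return t.pin_memory().to(device, non_blocking=True)
```

Preparing the bias on CPU at input-preparation time means the model forward pass no longer has to build it on the device, and the copy can overlap with other work.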
Pull Request Overview
Copilot reviewed 1 out of 1 changed files in this pull request and generated 2 comments.
Pull Request Overview
Copilot reviewed 1 out of 1 changed files in this pull request and generated 5 comments.
def metadata_update_with_trim(obj: object, typename: str, trim: bool, **to_override):
    if trim:
        return custom_tuple_replace(obj, typename, **to_override)
    for key in to_override:
        assert hasattr(obj, key), f"Field {key} must exist in untrimmed metadata."
        setattr(obj, key, to_override[key])
    return obj
Copilot AI · Nov 5, 2025
The function metadata_update_with_trim lacks a docstring explaining its purpose, parameters, return value, and the distinction between trimmed and untrimmed metadata handling. This is especially important given the conditional logic and the use of setattr for dynamic attribute modification.
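A hedged sketch of what the requested docstring could look like, based only on the behavior visible in the diff (custom_tuple_replace is assumed to be defined in the surrounding module and is only reached on the trimmed path):

```python
def metadata_update_with_trim(obj: object, typename: str, trim: bool, **to_override):
    """Return metadata with the given fields overridden.

    Trimmed metadata is treated as an immutable namedtuple-like object, so a
    new instance is produced via custom_tuple_replace. Untrimmed metadata is
    mutated in place with setattr; every overridden field must already exist
    on the object, otherwise an AssertionError is raised.
    """
    if trim:
        # Trimmed path: delegate to the module's tuple-replacement helper.
        return custom_tuple_replace(obj, typename, **to_override)
    for key in to_override:
        assert hasattr(obj, key), f"Field {key} must exist in untrimmed metadata."
        setattr(obj, key, to_override[key])
    return obj
```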
assert seq_lens_t is not None, "seq_lens_tensor is required to build attn_bias"
context_lens_t = prefill_metadata.context_lens_tensor
assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
Copilot AI · Nov 5, 2025
The error message should be more specific by indicating which phase (prefill) or operation is being performed when this assertion fails, to help with debugging.
Suggested change:
- assert seq_lens_t is not None, "seq_lens_tensor is required to build attn_bias"
+ assert seq_lens_t is not None, "seq_lens_tensor is required to build attn_bias during prefill (prompt) phase"
  context_lens_t = prefill_metadata.context_lens_tensor
- assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
+ assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias during prefill (prompt) phase"
seq_lens_t = prefill_metadata.seq_lens_tensor
assert seq_lens_t is not None, "seq_lens_tensor is required to build attn_bias"
context_lens_t = prefill_metadata.context_lens_tensor
assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
Copilot AI · Nov 5, 2025
The error message should be more specific by indicating which phase (prefill) or operation is being performed when this assertion fails, to help with debugging.
Suggested change:
- assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
+ assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias during prefill phase"
if self.prefill_use_fusedsdpa and attn_metadata.block_list is not None:
    context_lens_t = prefill_metadata.context_lens_tensor
    assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
Copilot AI · Nov 5, 2025
The error message should be more specific by indicating this is for sliding window attention to aid debugging when this assertion fails.
Suggested change:
- assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias"
+ assert context_lens_t is not None, "context_lens_tensor is required to build attn_bias for sliding window attention"
# NOTE(kzawora): I'm not sure why we set block mapping twice for sliding window
# - we should check if that can be reduced to a single call.
Copilot AI · Nov 5, 2025
This TODO-style comment expresses uncertainty about the implementation. Either investigate and resolve this concern, or rephrase as a clearer explanation if the double call is intentional (e.g., for separate window and non-window blocks).
Suggested change:
- # NOTE(kzawora): I'm not sure why we set block mapping twice for sliding window
- # - we should check if that can be reduced to a single call.
+ # For sliding window, we set block mapping twice: once for the base mapping and once for the sliding window mapping.
+ # This ensures both standard and sliding window block mappings are correctly applied.
Requires #526. This is the next logical step: we remove usage of the metadata postprocessor inside HpuModelAdapter and run it at input preparation time instead, on CPU, copying the data asynchronously to HPU. I also needed to rework the processor to accept untrimmed metadata. This works as-is, but unfortunately I've noticed a fairly significant end-to-end performance drop on small models.