
Commit 82085eb

Fix preemption handling (#524)
This PR fixes a multitude of bugs we had in preemption handling:

- Fixed the output token update of `CachedRequestState`: it was updated twice per iteration, resulting in doubled tokens, which broke preemption when a request was being re-added to the input batch.
- Batch preparation now uses input+output tokens in prefill for preempted sequences (both non-unified and unified attention).
- Preempted sequences now get correctly recognized as prefills after they exceed their original prefill length (e.g. if the prompt was 3 tokens and the sequence generated 1024 tokens before preemption, it would previously get treated as a decode after the first 3 tokens).
- Removed some incorrect assumptions about prefills (namely, that they can have no pre-existing output tokens).

Scenarios with preemptions yield proper accuracy, as can be tested with very low `gpu_memory_utilization` and relatively high `max_num_seqs`:

```
PT_HPU_LAZY_MODE=1 VLLM_SKIP_WARMUP=true lm_eval --model vllm --model_args pretrained=/mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-8B-Instruct/,enforce_eager=False,dtype=bfloat16,max_num_seqs=128,gpu_memory_utilization=0.05,max_model_len=4096,enable_prefix_caching=True,add_bos_token=false,tensor_parallel_size=1,max_gen_toks=2048 --tasks gsm8k_cot_llama --batch_size auto --trust_remote_code --apply_chat_template --fewshot_as_multiturn --num_fewshot 8

|     Tasks     |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|---------------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_cot_llama|      3|flexible-extract|     8|exact_match|↑  |0.8408|±  |0.0101|
|               |       |strict-match    |     8|exact_match|↑  |0.8415|±  |0.0101|
```

---------

Signed-off-by: Konrad Zawora <kzawora@habana.ai>
1 parent f4aeae8 commit 82085eb
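To make the changes easier to follow, here is a minimal editorial sketch of the core bookkeeping idea behind the fix (not the vllm_gaudi code itself; `PreemptedRequest` and `effective_prefill_len` are hypothetical names): once a request has been fully preempted, the tokens it generated before preemption are folded into its prefill length, so the whole sequence is recomputed as a prefill when it is re-added.

```python
# Editorial sketch only: names are hypothetical, not the actual vllm_gaudi APIs.
from dataclasses import dataclass, field


@dataclass
class PreemptedRequest:
    prompt_token_ids: list[int]
    output_token_ids: list[int] = field(default_factory=list)
    num_computed_tokens: int = 0  # reset to 0 by a full preemption


def effective_prefill_len(req: PreemptedRequest) -> int:
    """Length of the prefill that must be recomputed when the request is re-added.

    In the non-preemption case output_token_ids is empty, so this equals the
    prompt length. After a full preemption (num_computed_tokens == 0), the
    previously emitted tokens are part of the recomputed prefill as well.
    """
    if req.num_computed_tokens == 0:
        return len(req.prompt_token_ids) + len(req.output_token_ids)
    return len(req.prompt_token_ids)


# Example from the commit message: a 3-token prompt that generated 1024 tokens
# before being preempted must be recomputed as a 1027-token prefill, not
# treated as a decode after the first 3 tokens.
req = PreemptedRequest(prompt_token_ids=[1, 2, 3], output_token_ids=list(range(1024)))
assert effective_prefill_len(req) == 1027
```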

2 files changed, +25 -18 lines changed


vllm_gaudi/v1/worker/hpu_input_batch.py

Lines changed: 9 additions & 0 deletions
@@ -256,6 +256,15 @@ def add_request(
         start_idx = num_prompt_tokens
         end_idx = start_idx + len(request.output_token_ids)
         self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
+        #NOTE(kzawora): In non-preemption scenario,
+        # self.input_batch.num_prompt_tokens[batch_idx] == self.input_batch.num_tokens[batch_idx].
+        # In preemption scenario, we want num_prompt_tokens to also include the tokens emitted before preemption,
+        # as that is used as basis for recomputing prefill.
+        # This also assumes that preemption is complete and reduces num_computed_tokens to 0 and preempted sequences
+        # don't retain any originally used cache blocks.
+        if request.num_computed_tokens == 0:
+            self.num_prompt_tokens[req_index] = num_prompt_tokens + len(request.output_token_ids)
+
         # Number of token ids in token_ids_cpu.
         # NOTE(woosuk): This may include spec decode tokens.
         self.num_tokens[req_index] = request.num_tokens
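As an illustration of the `add_request` change above, this editorial sketch uses plain Python lists in place of the numpy-backed token buffer and per-request counters (an assumption for readability) to show how a re-added, fully preempted request is laid out and where the prompt boundary must point for the prefill to be recomputed correctly.

```python
# Editorial sketch only: plain lists stand in for input_batch.token_ids_cpu
# and the per-request counters used by the real code.

prompt = [101, 102, 103]          # 3 prompt tokens
generated = [7] * 5               # tokens emitted before preemption
num_computed_tokens = 0           # full preemption: no cache blocks retained

# add_request writes the prompt first, then the previously generated outputs.
token_buffer = prompt + generated

# Old behaviour: the prompt boundary stopped at the original prompt, so only
# 3 tokens would be recomputed as prefill.
num_prompt_tokens_old = len(prompt)                     # 3

# Fixed behaviour: with num_computed_tokens == 0, the pre-preemption outputs
# are counted into the prefill that has to be recomputed.
num_prompt_tokens_new = len(prompt) + len(generated)    # 8

assert token_buffer[:num_prompt_tokens_new] == prompt + generated
```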

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 16 additions & 18 deletions
@@ -1489,7 +1489,6 @@ def _get_prompts_and_decodes(
             num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
             num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
-
             if num_computed_tokens < num_prompt_tokens and \
                     not self.is_decoder_only(req_id):
                 # This is prompt
@@ -1518,11 +1517,7 @@ def _get_prompts_and_decodes(

             # Must be prompt
             assert num_computed_tokens < num_prompt_tokens
-            num_output_tokens = len(self.requests[req_id].output_token_ids)
-            if not has_kv_transfer_group():
-                #P case num_output_tokens has non 0
-                assert num_output_tokens == 0, \
-                    f'req_id: {req_id}, {num_output_tokens}'
+            # NOTE(kzawora): In preempted sequences, num_output_tokens can be > 0, and still be a valid prefill

             prompt_req_ids.append(req_id)
             prompt_scheduled_tokens.append(num_scheduled_tokens)
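A short editorial sketch of the prompt/decode split that the two hunks above rely on, using plain integers instead of the input batch arrays (the helper name `is_prefill` is an assumption): a request keeps being scheduled as a prefill while its computed tokens have not caught up with its, possibly preemption-extended, prompt length.

```python
# Editorial sketch: the prefill/decode classification criterion, with plain
# ints instead of the model runner's batch arrays.

def is_prefill(num_computed_tokens: int, num_prompt_tokens: int) -> bool:
    # A request is still a prompt (prefill) while its computed tokens haven't
    # caught up with its (possibly preemption-extended) prompt length.
    return num_computed_tokens < num_prompt_tokens

# Non-preempted request: 3-token prompt, first step is a prefill, then decode.
assert is_prefill(num_computed_tokens=0, num_prompt_tokens=3)
assert not is_prefill(num_computed_tokens=3, num_prompt_tokens=3)

# Preempted request: 3-token prompt plus 1024 pre-preemption outputs.
# With num_prompt_tokens extended to 1027, it stays a prefill past token 3
# instead of being mis-classified as a decode.
assert is_prefill(num_computed_tokens=512, num_prompt_tokens=1027)
```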
@@ -1678,26 +1673,29 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, num_schedul

         for batch_idx in range(num_decodes, num_reqs):
             req_id = self.input_batch.req_ids[batch_idx]
-            context_len = self.input_batch.num_computed_tokens_cpu[batch_idx]
-            query_len = num_scheduled_tokens[batch_idx]
+            seq_num_computed_tokens = self.input_batch.num_computed_tokens_cpu[batch_idx]
+            seq_num_scheduled_tokens = num_scheduled_tokens[batch_idx]

-            token_ids = self.input_batch.token_ids_cpu[batch_idx, context_len:context_len + query_len].tolist()
+            token_ids = self.input_batch.token_ids_cpu[batch_idx, seq_num_computed_tokens:seq_num_computed_tokens +
+                                                       seq_num_scheduled_tokens].tolist()

-            num_blocks = round_up(context_len + query_len, self.block_size) // self.block_size
+            num_blocks = round_up(seq_num_computed_tokens + seq_num_scheduled_tokens,
+                                  self.block_size) // self.block_size
             blocks = block_table_cpu_tensor[batch_idx, :num_blocks].tolist()
             if not warmup:
                 blocks = [self.defragmenter.resolve(b) for b in blocks]
-
-            prompt_tokens = self.input_batch.num_prompt_tokens[batch_idx]
-            # TODO: Fix non-prompt case
-            num_output_logits = max(0, context_len + query_len - prompt_tokens + 1)
-            logits_positions = list(range(query_len - num_output_logits, query_len))
+            #NOTE(kzawora): In non-preemption scenario,
+            # self.input_batch.num_prompt_tokens[batch_idx] == self.input_batch.num_tokens[batch_idx].
+            # In preemption scenario num_tokens will also include the tokens emitted before preemption
+            num_prompt_tokens = self.input_batch.num_prompt_tokens[batch_idx]
+            num_output_logits = max(0, seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1)
+            logits_positions = list(range(seq_num_scheduled_tokens - num_output_logits, seq_num_scheduled_tokens))

             new_batch_contents = BatchContents(
                 req_ids=[req_id],
                 token_ids=[token_ids],
-                context_lens=[context_len],
-                prompt_lens=[prompt_tokens],
+                context_lens=[seq_num_computed_tokens],
+                prompt_lens=[num_prompt_tokens],
                 blocks=[blocks],
                 logits_positions=[logits_positions],
             )
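A hedged worked example of the `num_output_logits` / `logits_positions` arithmetic introduced in the hunk above, with made-up numbers and a hypothetical helper (`prefill_logits_positions`); the real code reads these values from the input batch and the scheduler output.

```python
# Editorial worked example of the logits-position arithmetic above.

def prefill_logits_positions(num_computed: int, num_scheduled: int, num_prompt: int) -> list[int]:
    # Number of positions in this chunk that fall at or beyond the last prompt
    # token and therefore need logits (sampling) outputs.
    num_output_logits = max(0, num_computed + num_scheduled - num_prompt + 1)
    return list(range(num_scheduled - num_output_logits, num_scheduled))

# Plain prefill: an 8-token prompt processed in one 8-token chunk produces a
# logit only at the last position.
assert prefill_logits_positions(num_computed=0, num_scheduled=8, num_prompt=8) == [7]

# First chunk of a chunked prefill: no position reaches the prompt end yet.
assert prefill_logits_positions(num_computed=0, num_scheduled=4, num_prompt=8) == []

# Preempted request recomputed as prefill: the prompt length now includes the
# pre-preemption outputs (e.g. 3 + 1024 = 1027), so the final 3-token chunk
# again yields exactly one logit position.
assert prefill_logits_positions(num_computed=1024, num_scheduled=3, num_prompt=1027) == [2]
```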
@@ -3331,7 +3329,7 @@ def execute_model(
                 num_tokens = len(token_ids)
                 self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids
                 self.input_batch.num_tokens[i] += len(token_ids)
-                req_state.output_token_ids.extend(token_ids)
+
                 # NOTE(chendi): enable cache based on PR(#20291)
                 # Cache the sampled tokens in the model runner, so that the scheduler
                 # doesn't need to send them back.
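Finally, an editorial sketch of the double-update bug that the removed `req_state.output_token_ids.extend(...)` line caused (the class and variable names here are illustrative, not the runner's actual structures): the same step's sampled tokens were appended to the cached request state twice, so a re-added preempted request replayed twice as many output tokens.

```python
# Editorial sketch of the double-update bug fixed above: two code paths both
# appending the same step's sampled tokens to the cached request state.

class ReqState:
    def __init__(self) -> None:
        self.output_token_ids: list[int] = []

req_state = ReqState()
sampled = [42]                        # one token sampled this iteration

# Path A (kept elsewhere in the runner) records the token once.
req_state.output_token_ids.extend(sampled)

# Path B (the line removed in this hunk) recorded it a second time.
req_state.output_token_ids.extend(sampled)

# The doubled history is what broke preemption: when the request was re-added
# to the input batch, twice as many "output" tokens were replayed as prefill.
assert req_state.output_token_ids == [42, 42]   # should have been [42]
```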
