1 change: 0 additions & 1 deletion vllm_gaudi/extension/features.py
@@ -12,7 +12,6 @@

 def get_user_flags():
     flags = [
-        Env('VLLM_USE_V1', boolean),
         Env('VLLM_ENABLE_EXPERIMENTAL_FLAGS', boolean),
         Env('VLLM_EXPONENTIAL_BUCKETING', boolean),
         Env('VLLM_PROMPT_BS_BUCKET_MIN', int),
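For context, here is a minimal standalone sketch of how a boolean flag such as VLLM_ENABLE_EXPERIMENTAL_FLAGS could be read from the environment. The helper below is hypothetical and is not the Env/get_user_flags machinery used by vllm_gaudi:

import os

def read_bool_flag(name: str, default: bool = False) -> bool:
    # Hypothetical helper, not the repo's Env() reader: common truthy spellings map to True.
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "t", "yes", "y")

if read_bool_flag("VLLM_ENABLE_EXPERIMENTAL_FLAGS"):
    print("experimental flags enabled")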
1 change: 1 addition & 0 deletions vllm_gaudi/extension/runtime.py
@@ -63,6 +63,7 @@ def finalize_config():

     user_flags = filter_defined(detected, USER_FLAGS)
     experimental_flags = filter_defined(detected, EXPERIMENTAL_FLAGS)
+    experimental_flags = {k: v for k, v in experimental_flags.items() if k not in user_flags}
     environment_values = filter_defined(detected, ENVIRONMENT_VALUES)
     feature_values = filter_defined(detected, FEATURE_VALUES)

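The net effect of the added line is to drop every flag already reported as a user flag from the experimental set. A small self-contained sketch with made-up flag values (plain dicts standing in for the filter_defined results):

user_flags = {"VLLM_EXPONENTIAL_BUCKETING": True}
experimental_flags = {"VLLM_EXPONENTIAL_BUCKETING": True,
                      "VLLM_SOME_EXPERIMENTAL_KNOB": 4}  # hypothetical key, for illustration only

# Keep only experimental flags that were not already captured as user flags.
experimental_flags = {k: v for k, v in experimental_flags.items() if k not in user_flags}
print(experimental_flags)  # -> {'VLLM_SOME_EXPERIMENTAL_KNOB': 4}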
77 changes: 45 additions & 32 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -8,6 +8,7 @@
 import os
 import sys
 import time
+from tqdm import tqdm
 from dataclasses import dataclass, field, fields
 from typing import (TYPE_CHECKING, Any, Callable, Optional, TypeAlias, Union, cast)

@@ -3599,7 +3600,7 @@ def log_warmup(self, phase, i, max_i, first_dim, second_dim, third_dim, causal=F
f"query_len:{second_dim} "
f"num_blocks:{third_dim} "
f"free_mem:{free_mem}")
logger.info(msg)
tqdm.write(msg)

def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, img_args):
free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
@@ -3749,45 +3750,57 @@ def warmup_graphs(self, buckets, is_prompt, kv_caches, starting_mem=0, total_bat
         idx = 0
         num_candidates = len(buckets)
         captured_all = True
-        for idx, (batch_size, seq_len, num_blocks) in enumerate(reversed(buckets)):
-            if seq_len > self.max_num_tokens:
-                continue
-            # Graph memory usage is proportional to seq dimension in a batch
-            phase = f"Graph/{'prompt' if is_prompt else 'decode'}"
-            if is_prompt:
-                batch_seq = batch_size * seq_len * num_blocks if num_blocks else batch_size * seq_len
-            else:
-                batch_seq = batch_size
-
-            graphed_bucket = (batch_size, seq_len, num_blocks, is_prompt)
-            if graphed_bucket in self.graphed_buckets:
-                continue
-            self.graphed_buckets.add(graphed_bucket)
-            self.log_warmup(phase, idx, num_candidates, batch_size, seq_len, num_blocks)
-            prompt_cfg, decode_cfg = None, None
-            with HabanaMemoryProfiler() as mem_prof:
-                if is_prompt:
-                    prompt_cfg = (batch_size, seq_len, num_blocks)
-                else:
-                    decode_cfg = (batch_size, 1, num_blocks)
-                self._prepare_dummy_scenario(prompt_cfg, decode_cfg)
-            # TODO(kzawora): align_workers
-            used_mem = mem_prof.consumed_device_memory
-            total_mem += used_mem
-            total_batch_seq += batch_seq
+        developer_settings = get_config().VLLM_ENABLE_EXPERIMENTAL_FLAGS
+        phase = 'Prompt' if is_prompt else 'Decode'
+        desc = f'{phase} warmup processing: '
+        with tqdm(total=num_candidates, desc=desc, unit="item") as pbar:
+            for idx, (batch_size, seq_len, num_blocks) in enumerate(reversed(buckets)):
+                if seq_len > self.max_num_tokens:
+                    continue
+                # Graph memory usage is proportional to seq dimension in a batch
+                if is_prompt:
+                    batch_seq = batch_size * seq_len * num_blocks if num_blocks else batch_size * seq_len
+                else:
+                    batch_seq = batch_size
+
+                graphed_bucket = (batch_size, seq_len, num_blocks, is_prompt)
+                if graphed_bucket in self.graphed_buckets:
+                    continue
+                self.graphed_buckets.add(graphed_bucket)
+                if developer_settings:
+                    self.log_warmup(phase, idx, num_candidates, batch_size, seq_len, num_blocks)
+                prompt_cfg, decode_cfg = None, None
+                with HabanaMemoryProfiler() as mem_prof:
+                    if is_prompt:
+                        prompt_cfg = (batch_size, seq_len, num_blocks)
+                    else:
+                        decode_cfg = (batch_size, 1, num_blocks)
+                    self._prepare_dummy_scenario(prompt_cfg, decode_cfg)
+                # TODO(kzawora): align_workers
+                used_mem = mem_prof.consumed_device_memory
+                total_mem += used_mem
+                total_batch_seq += batch_seq
+
+                pbar.set_postfix_str(f"{idx}/{num_candidates}")
+                pbar.update(1)

         return total_mem, total_batch_seq, captured_all

     def warmup_unified_graphs(self, buckets, kv_cache):
         idx = 0
         num_candidates = len(buckets)
-        for idx, (query, shared_ctx, unique_ctx, is_causal) in enumerate(reversed(buckets)):
-            unified_cfg = (query, shared_ctx, unique_ctx, is_causal)
-            if unified_cfg in self.graphed_buckets:
-                continue
-            self.graphed_buckets.add(unified_cfg)
-            self.log_warmup("Unified CFG", idx, num_candidates, query, shared_ctx, unique_ctx, is_causal)
-            self._prepare_dummy_unified_scenario(unified_cfg)
+        developer_settings = get_config().VLLM_ENABLE_EXPERIMENTAL_FLAGS
+        with tqdm(total=num_candidates, desc="Unified Attention warmup", unit="item") as pbar:
+            for idx, (query, shared_ctx, unique_ctx, is_causal) in enumerate(reversed(buckets)):
+                unified_cfg = (query, shared_ctx, unique_ctx, is_causal)
+                if unified_cfg in self.graphed_buckets:
+                    continue
+                self.graphed_buckets.add(unified_cfg)
+                if developer_settings:
+                    self.log_warmup("Unified CFG", idx, num_candidates, query, shared_ctx, unique_ctx, is_causal)
+                self._prepare_dummy_unified_scenario(unified_cfg)
+                pbar.set_postfix_str(f"{idx}/{num_candidates}")
+                pbar.update(1)

     def _add_dummy_request(self,
                            requests,
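The warmup changes above all follow one pattern: wrap the bucket loop in a tqdm progress bar, route per-bucket log lines through tqdm.write so they do not corrupt the bar, and only emit them when the developer/experimental flag is set. Below is a minimal, self-contained sketch of that pattern using hypothetical bucket data (not the runner's actual bucket format or capture logic):

import time
from tqdm import tqdm

buckets = [(1, 128), (2, 128), (4, 256)]  # hypothetical (batch_size, seq_len) buckets
developer_settings = True                 # stands in for VLLM_ENABLE_EXPERIMENTAL_FLAGS
captured = set()

with tqdm(total=len(buckets), desc="Prompt warmup processing: ", unit="item") as pbar:
    for idx, cfg in enumerate(reversed(buckets)):
        if cfg in captured:
            continue  # already-captured buckets are skipped without ticking the bar
        captured.add(cfg)
        if developer_settings:
            # tqdm.write prints above the live bar instead of breaking it,
            # which is why the diff replaces logger.info(msg) with tqdm.write(msg).
            tqdm.write(f"[Warmup][{idx + 1}/{len(buckets)}] cfg={cfg}")
        time.sleep(0.1)  # stand-in for graph capture / dummy scenario execution
        pbar.set_postfix_str(f"{idx}/{len(buckets)}")
        pbar.update(1)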