Skip to content

Commit 257dada

Browse files
Add graph compilation tracking to high level profiler (#50)
Signed-off-by: Konrad Zawora <kzawora@habana.ai>
1 parent a2c77c6 commit 257dada

File tree

3 files changed

+44

-6

lines changed

vllm_gaudi/extension/features.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,7 @@ def get_features():
9797
1.,
9898
env_var='VLLM_UNIFIED_ATTENTION_SHARED_CACHE_RATIO',
9999
env_var_type=float),
100+
Value('high_level_profiler_enabled', False, env_var='VLLM_PROFILER_ENABLED', env_var_type=boolean),
101+
Value('track_graph_compilation', False, env_var='PT_HPU_METRICS_GC_DETAILS', env_var_type=boolean),
100102
]
101103
return split_values_and_flags(features)

vllm_gaudi/extension/profiler.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
33
###############################################################################
44

5+
import contextlib
56
import gc
67
import gzip
78
import json
@@ -19,6 +20,7 @@
1920

2021
from vllm_gaudi.extension.utils import is_fake_hpu
2122
from .logger import logger
23+
from vllm_gaudi.extension.runtime import get_config
2224

2325

2426
class FileWriter(threading.Thread):
@@ -142,8 +144,7 @@ class HabanaHighLevelProfiler:
142144
event_cache: List[Any] = []
143145

144146
def __init__(self, vllm_instance_id=None):
145-
self.enabled = os.getenv('VLLM_PROFILER_ENABLED', 'false').lower() == 'true' and int(os.getenv('RANK',
146-
'0')) == 0
147+
self.enabled = get_config().high_level_profiler_enabled and int(os.getenv('RANK', '0')) == 0
147148
self.pid = os.getpid()
148149
if self.enabled:
149150
self.vllm_instance_id = vllm_instance_id if vllm_instance_id is not None \
@@ -158,6 +159,8 @@ def __init__(self, vllm_instance_id=None):
158159
file_writer.start()
159160
if os.getenv('VLLM_PROFILER_ENABLED') == 'full':
160161
self.enabled = True # don't save separate high-level traces
162+
self.gc_track_recompiles = get_config().track_graph_compilation
163+
self.num_graph_compilations = 0
161164

162165
def _dump_with_sep(self, entry):
163166
entry = json.dumps(entry) + ','
@@ -256,11 +259,45 @@ def handler_fn(prof) -> None:
256259
def record_event(self, type, name, args=None):
257260
if self.enabled:
258261
self.start(type, name, args)
259-
yield
262+
with self.track_graph_compile(type, args) \
263+
if self.gc_track_recompiles \
264+
else contextlib.nullcontext():
265+
yield
260266
self.end()
261267
else:
262268
yield
263269

270+
def record_block(self, type, name, ts, dur, args=None):
271+
if self.enabled:
272+
event = {
273+
'pid': self.pid,
274+
'tid': self.event_tid[type],
275+
'ph': 'X',
276+
'name': name,
277+
'ts': ts,
278+
'dur': dur,
279+
'args': args
280+
}
281+
self._dump_with_sep(event)
282+
283+
@contextmanager
284+
def track_graph_compile(self, type, args=None):
285+
start = self.get_timestamp_us()
286+
import habana_frameworks.torch as htorch
287+
from habana_frameworks.torch.hpu.metrics import metric_localcontext
288+
with metric_localcontext("graph_compilation") as gc:
289+
yield
290+
htorch.hpu.synchronize()
291+
if gc.stats()[0][1] != 0:
292+
compile_start_time = start
293+
for recipe in gc.stats()[3][1]:
294+
recipe_name = recipe[0]
295+
compile_time = recipe[1]
296+
self.num_graph_compilations += 1
297+
self.record_counter(compile_start_time, {'cumulative_graph_compilations': self.num_graph_compilations})
298+
self.record_block(type, 'GRAPH COMPILE: ' + recipe_name, compile_start_time, compile_time, args)
299+
compile_start_time += compile_time
300+
264301

265302
# Adapted from https://stackoverflow.com/a/49361727
266303
def format_bytes(size):

vllm_gaudi/v1/worker/hpu_worker.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from vllm.v1.outputs import (DraftTokenIds, AsyncModelRunnerOutput, ModelRunnerOutput)
2626
from vllm.v1.worker.utils import bind_kv_cache
2727
from vllm_gaudi.utils import is_fake_hpu
28-
from vllm_gaudi.v1.worker.hpu_model_runner import HPUModelRunner, bool_helper
28+
from vllm_gaudi.v1.worker.hpu_model_runner import HPUModelRunner
2929
from vllm.v1.worker.worker_base import WorkerBase
3030

3131
from vllm_gaudi.extension.logger import logger as init_logger
@@ -82,8 +82,7 @@ def __init__(
8282
from vllm.utils.import_utils import init_cached_hf_modules
8383
init_cached_hf_modules()
8484

85-
self.gc_track_recompiles = bool("PT_HPU_METRICS_GC_DETAILS" in os.environ
86-
and bool_helper(os.getenv("PT_HPU_METRICS_GC_DETAILS")))
85+
self.gc_track_recompiles = get_config().track_graph_compilation and not get_config().high_level_profiler_enabled
8786
self.step = 0
8887
self.profile_steps = get_config().VLLM_PROFILE_STEPS
8988
self.step_profiler = setup_step_profiler(self.profile_steps)

0 commit comments

Comments (0)