@@ -2,6 +2,7 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
 ###############################################################################

+import contextlib
 import gc
 import gzip
 import json
@@ -19,6 +20,7 @@

 from vllm_gaudi.extension.utils import is_fake_hpu
 from .logger import logger
+from vllm_gaudi.extension.runtime import get_config


 class FileWriter(threading.Thread):
@@ -142,8 +144,7 @@ class HabanaHighLevelProfiler:
     event_cache: List[Any] = []

     def __init__(self, vllm_instance_id=None):
-        self.enabled = os.getenv('VLLM_PROFILER_ENABLED', 'false').lower() == 'true' and int(os.getenv('RANK',
-                                                                                                        '0')) == 0
+        self.enabled = get_config().high_level_profiler_enabled and int(os.getenv('RANK', '0')) == 0
         self.pid = os.getpid()
         if self.enabled:
             self.vllm_instance_id = vllm_instance_id if vllm_instance_id is not None \
@@ -158,6 +159,8 @@ def __init__(self, vllm_instance_id=None):
             file_writer.start()
         if os.getenv('VLLM_PROFILER_ENABLED') == 'full':
             self.enabled = True  # don't save separate high-level traces
+        self.gc_track_recompiles = get_config().track_graph_compilation
+        self.num_graph_compilations = 0

     def _dump_with_sep(self, entry):
         entry = json.dumps(entry) + ','
@@ -256,11 +259,45 @@ def handler_fn(prof) -> None:
     def record_event(self, type, name, args=None):
         if self.enabled:
             self.start(type, name, args)
-            yield
+            with self.track_graph_compile(type, args) \
+                    if self.gc_track_recompiles \
+                    else contextlib.nullcontext():
+                yield
             self.end()
         else:
             yield

+    def record_block(self, type, name, ts, dur, args=None):
+        if self.enabled:
+            event = {
+                'pid': self.pid,
+                'tid': self.event_tid[type],
+                'ph': 'X',
+                'name': name,
+                'ts': ts,
+                'dur': dur,
+                'args': args
+            }
+            self._dump_with_sep(event)
+
+    @contextmanager
+    def track_graph_compile(self, type, args=None):
+        start = self.get_timestamp_us()
+        import habana_frameworks.torch as htorch
+        from habana_frameworks.torch.hpu.metrics import metric_localcontext
+        with metric_localcontext("graph_compilation") as gc:
+            yield
+            htorch.hpu.synchronize()
+        if gc.stats()[0][1] != 0:
+            compile_start_time = start
+            for recipe in gc.stats()[3][1]:
+                recipe_name = recipe[0]
+                compile_time = recipe[1]
+                self.num_graph_compilations += 1
+                self.record_counter(compile_start_time, {'cumulative_graph_compilations': self.num_graph_compilations})
+                self.record_block(type, 'GRAPH COMPILE: ' + recipe_name, compile_start_time, compile_time, args)
+                compile_start_time += compile_time
+

 # Adapted from https://stackoverflow.com/a/49361727
 def format_bytes(size):
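For context, a minimal usage sketch of the new graph-compilation tracking follows (it is not part of the diff above). It assumes the profiler is enabled through the runtime config (e.g. VLLM_PROFILER_ENABLED=true) together with the track_graph_compilation option, that this module is importable as vllm_gaudi.extension.profiler, and that habana_frameworks is available on a Gaudi host; the 'internal' event type and the tensor workload are illustrative only.

import torch
import habana_frameworks.torch as htorch
from vllm_gaudi.extension.profiler import HabanaHighLevelProfiler  # assumed module path

profiler = HabanaHighLevelProfiler()

# record_event is a @contextmanager; with gc_track_recompiles set, every graph
# recompilation inside the block is written to the trace as a
# 'GRAPH COMPILE: <recipe>' block plus a 'cumulative_graph_compilations' counter.
with profiler.record_event('internal', 'model_forward'):
    x = torch.randn(16, 16).to('hpu')
    y = torch.matmul(x, x).relu()
    htorch.hpu.synchronize()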