diff --git a/vllm_gaudi/extension/config.py b/vllm_gaudi/extension/config.py index f627fcd16..ee8d8323f 100644 --- a/vllm_gaudi/extension/config.py +++ b/vllm_gaudi/extension/config.py @@ -143,8 +143,12 @@ def __init__(self, name: str, value_type: Constructor, check: Checker = skip_val self.value_type = value_type self.check = check + @cache + def get_from_env(self, n): + return os.environ.get(n) + def __call__(self, _): - value = os.environ.get(self.name) + value = self.get_from_env(self.name) if value is not None: try: value = self.value_type(value) @@ -157,6 +161,13 @@ def __call__(self, _): return None +class ExperimentalEnv(Env): + def __call__(self, _): + if Env('VLLM_ENABLE_EXPERIMENTAL_FLAGS', boolean)(): + return super(ExperimentalEnv, self).__call__() + return None + + class Value: """A callable that returns the value calculated through its dependencies or overriden by an associated experimental flag""" diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py index b5c161dd7..f2d67d8db 100644 --- a/vllm_gaudi/extension/features.py +++ b/vllm_gaudi/extension/features.py @@ -15,29 +15,29 @@ def get_user_flags(): Env('VLLM_USE_V1', boolean), Env('VLLM_ENABLE_EXPERIMENTAL_FLAGS', boolean), Env('VLLM_EXPONENTIAL_BUCKETING', boolean), - Env('VLLM_PROMPT_BS_BUCKET_MIN', int), - Env('VLLM_PROMPT_BS_BUCKET_STEP', int), - Env('VLLM_PROMPT_BS_BUCKET_MAX', int), - Env('VLLM_PROMPT_QUERY_BUCKET_MIN', int), - Env('VLLM_PROMPT_QUERY_BUCKET_STEP', int), - Env('VLLM_PROMPT_QUERY_BUCKET_MAX', int), - Env('VLLM_PROMPT_SEQ_BUCKET_MIN', int), - Env('VLLM_PROMPT_SEQ_BUCKET_STEP', int), - Env('VLLM_PROMPT_SEQ_BUCKET_MAX', int), - Env('VLLM_PROMPT_CTX_BUCKET_MIN', int), - Env('VLLM_PROMPT_CTX_BUCKET_STEP', int), - Env('VLLM_PROMPT_CTX_BUCKET_MAX', int), - Env('VLLM_DECODE_BS_BUCKET_MIN', int), - Env('VLLM_DECODE_BS_BUCKET_STEP', int), - Env('VLLM_DECODE_BS_BUCKET_MAX', int), - Env('VLLM_DECODE_BLOCK_BUCKET_MIN', int), - Env('VLLM_DECODE_BLOCK_BUCKET_STEP', int), - Env('VLLM_DECODE_BLOCK_BUCKET_MAX', int), - Env('VLLM_DECODE_BLOCK_BUCKET_LIMIT', int), - Env('VLLM_BUCKETING_FROM_FILE', str), + ExperimentalEnv('VLLM_PROMPT_BS_BUCKET_MIN', int), + ExperimentalEnv('VLLM_PROMPT_BS_BUCKET_STEP', int), + ExperimentalEnv('VLLM_PROMPT_BS_BUCKET_MAX', int), + ExperimentalEnv('VLLM_PROMPT_QUERY_BUCKET_MIN', int), + ExperimentalEnv('VLLM_PROMPT_QUERY_BUCKET_STEP', int), + ExperimentalEnv('VLLM_PROMPT_QUERY_BUCKET_MAX', int), + ExperimentalEnv('VLLM_PROMPT_SEQ_BUCKET_MIN', int), + ExperimentalEnv('VLLM_PROMPT_SEQ_BUCKET_STEP', int), + ExperimentalEnv('VLLM_PROMPT_SEQ_BUCKET_MAX', int), + ExperimentalEnv('VLLM_PROMPT_CTX_BUCKET_MIN', int), + ExperimentalEnv('VLLM_PROMPT_CTX_BUCKET_STEP', int), + ExperimentalEnv('VLLM_PROMPT_CTX_BUCKET_MAX', int), + ExperimentalEnv('VLLM_DECODE_BS_BUCKET_MIN', int), + ExperimentalEnv('VLLM_DECODE_BS_BUCKET_STEP', int), + ExperimentalEnv('VLLM_DECODE_BS_BUCKET_MAX', int), + ExperimentalEnv('VLLM_DECODE_BLOCK_BUCKET_MIN', int), + ExperimentalEnv('VLLM_DECODE_BLOCK_BUCKET_STEP', int), + ExperimentalEnv('VLLM_DECODE_BLOCK_BUCKET_MAX', int), + ExperimentalEnv('VLLM_DECODE_BLOCK_BUCKET_LIMIT', int), + ExperimentalEnv('VLLM_BUCKETING_FROM_FILE', str), # Non-vllm flags that are also important to print - Env('EXPERIMENTAL_WEIGHT_SHARING', str), + ExperimentalEnv('EXPERIMENTAL_WEIGHT_SHARING', str), Env('PT_HPU_WEIGHT_SHARING', str), Env('RUNTIME_SCALE_PATCHING', str), @@ -51,13 +51,13 @@ def get_user_flags(): def get_experimental_flags(): flags = [ - Env('VLLM_PT_PROFILE', str), - Env('VLLM_PROFILE_PROMPT', str), - Env('VLLM_PROFILE_DECODE', str), - Env('VLLM_PROFILE_STEPS', list_of(int)), - Env('VLLM_DEFRAG_THRESHOLD', int), - Env('VLLM_DEFRAG_WITH_GRAPHS', boolean), - Env('VLLM_DEBUG', list_of(str), check=for_all(choice('steps', 'defrag', 'fwd'))), + ExperimentalEnv('VLLM_PT_PROFILE', str), + ExperimentalEnv('VLLM_PROFILE_PROMPT', str), + ExperimentalEnv('VLLM_PROFILE_DECODE', str), + ExperimentalEnv('VLLM_PROFILE_STEPS', list_of(int)), + ExperimentalEnv('VLLM_DEFRAG_THRESHOLD', int), + ExperimentalEnv('VLLM_DEFRAG_WITH_GRAPHS', boolean), + ExperimentalEnv('VLLM_DEBUG', list_of(str), check=for_all(choice('steps', 'defrag', 'fwd'))), ] return to_dict(flags)