diff --git a/vllm/envs.py b/vllm/envs.py
index 9cdb7ea974b8..078e5c38f0f4 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -113,6 +113,7 @@
     VLLM_ROCM_USE_AITER_FP8BMM: bool = True
     VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
     VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
+    VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -944,6 +945,11 @@ def get_vllm_port() -> int | None:
         os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower()
         in ("true", "1")
     ),
+    # Whether to use aiter triton kernels for gemm ops.
+    # Enabled by default.
+    "VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: (
+        os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "True").lower() in ("true", "1")
+    ),
     # use rocm skinny gemms
     "VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
         os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")
@@ -1586,6 +1592,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_USE_TRITON_ROPE",
         "VLLM_ROCM_USE_AITER_FP8BMM",
         "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
+        "VLLM_ROCM_USE_AITER_TRITON_GEMM",
         "VLLM_ROCM_USE_SKINNY_GEMM",
         "VLLM_ROCM_FP8_PADDING",
         "VLLM_ROCM_MOE_PADDING",
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index 3d90c9513683..b17bdd0b7207 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -106,6 +106,7 @@ def default_unquantized_gemm(
 def use_aiter_triton_gemm(n, m, k, dtype):
     if (
         envs.VLLM_ROCM_USE_AITER == 0
+        or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0
         # MI300's - fp8nuz=True
         or current_platform.is_fp8_fnuz()
         or dtype not in [torch.float16, torch.bfloat16]
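
For context, a minimal sketch (not part of the patch) of how the added env var is parsed and toggled. The standalone helper name below is illustrative only; the parsing mirrors the `os.getenv(...).lower() in ("true", "1")` lambda added to `vllm/envs.py`:

```python
import os

def rocm_use_aiter_triton_gemm_enabled() -> bool:
    # Illustrative restatement of the lambda added to environment_variables:
    # enabled by default; only "true"/"1" (case-insensitive) count as enabled,
    # so any other value (e.g. "0", "false") disables the flag.
    return os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "True").lower() in ("true", "1")

# Example: launching vLLM with VLLM_ROCM_USE_AITER_TRITON_GEMM=0 trips the new
# `or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0` condition in
# use_aiter_triton_gemm() (vllm/model_executor/layers/utils.py); the exact
# fallback behavior depends on the rest of that function, which is outside
# this diff.
```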