From eee43ccb53be5b87dd7e3f1b23c1269df47e2d8b Mon Sep 17 00:00:00 2001
From: Huamin Li <3ericli@gmail.com>
Date: Wed, 5 Nov 2025 19:36:10 -0800
Subject: [PATCH] use current_platform.is_device_capability

Signed-off-by: Huamin Li <3ericli@gmail.com>
---
 vllm/_custom_ops.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index cfcf534c613f..2fbc54419004 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from functools import cache
 from typing import TYPE_CHECKING, Literal
 
 import torch
@@ -366,6 +367,14 @@ def apply_repetition_penalties_cuda(
     )
 
 
+@cache
+def _should_use_cuda_repetition_penalties() -> bool:
+    """Whether to run the CUDA repetition-penalty kernel on this platform.
+    Disable on Ada (SM 8.9); fall back to PyTorch there.
+    """
+    return not current_platform.is_device_capability(89)
+
+
 def apply_repetition_penalties(
     logits: torch.Tensor,
     prompt_mask: torch.Tensor,
@@ -380,7 +389,11 @@ def apply_repetition_penalties(
         output_mask: A boolean tensor indicating which tokens appear in the output.
         repetition_penalties: The repetition penalties of shape (num_seqs, ).
     """
-    if logits.is_cuda and logits.is_contiguous():
+    if (
+        logits.is_cuda
+        and logits.is_contiguous()
+        and _should_use_cuda_repetition_penalties()
+    ):
         apply_repetition_penalties_cuda(
             logits, prompt_mask, output_mask, repetition_penalties
         )