Commit eee43cc

use current_platform.is_device_capability
Signed-off-by: Huamin Li <3ericli@gmail.com>
1 parent 07d6145

1 file changed: +14 additions, -1 deletion

vllm/_custom_ops.py

Lines changed: 14 additions & 1 deletion
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from functools import cache
 from typing import TYPE_CHECKING, Literal
 
 import torch
@@ -366,6 +367,14 @@ def apply_repetition_penalties_cuda(
     )
 
 
+@cache
+def _should_use_cuda_repetition_penalties() -> bool:
+    """Whether to run the CUDA repetition-penalty kernel on this platform.
+    Disable on Ada (SM 8.9); fall back to PyTorch there.
+    """
+    return not current_platform.is_device_capability(89)
+
+
 def apply_repetition_penalties(
     logits: torch.Tensor,
     prompt_mask: torch.Tensor,
@@ -380,7 +389,11 @@ def apply_repetition_penalties(
         output_mask: A boolean tensor indicating which tokens appear in the output.
         repetition_penalties: The repetition penalties of shape (num_seqs, ).
     """
-    if logits.is_cuda and logits.is_contiguous():
+    if (
+        logits.is_cuda
+        and logits.is_contiguous()
+        and _should_use_cuda_repetition_penalties()
+    ):
         apply_repetition_penalties_cuda(
             logits, prompt_mask, output_mask, repetition_penalties
         )
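For readers outside the diff: current_platform.is_device_capability(89) asks whether the active GPU reports CUDA compute capability 8.9 (Ada). Below is a hedged sketch of such a cached probe, built only on public PyTorch APIs; the real helper lives in vllm.platforms and its internals may differ.

from functools import cache

import torch

@cache
def is_device_capability_sketch(capability: int) -> bool:
    # Compare the active GPU's (major, minor) compute capability,
    # packed as major * 10 + minor (e.g. (8, 9) -> 89).
    # Illustrative stand-in for vLLM's current_platform.is_device_capability.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor == capability

With the gate in place, Ada GPUs fall through to the eager PyTorch path instead of the custom CUDA kernel. A minimal sketch of what that fallback computes, assuming the standard repetition-penalty rule (positive logits divided by the penalty, negative logits multiplied); the function name and shapes here are illustrative, not vLLM's actual implementation.

def repetition_penalties_torch_sketch(
    logits: torch.Tensor,                # (num_seqs, vocab_size), updated in place
    prompt_mask: torch.Tensor,           # bool, (num_seqs, vocab_size)
    output_mask: torch.Tensor,           # bool, (num_seqs, vocab_size)
    repetition_penalties: torch.Tensor,  # (num_seqs,)
) -> None:
    # Tokens seen in either the prompt or the generated output are penalized.
    seen = prompt_mask | output_mask
    penalties = repetition_penalties.unsqueeze(1)  # broadcast over the vocab axis
    penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
    logits[:] = torch.where(seen, penalized, logits)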
