vllm/_custom_ops.py (14 additions, 1 deletion)
```diff
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from functools import cache
 from typing import TYPE_CHECKING, Literal
 
 import torch
```
```diff
@@ -366,6 +367,14 @@ def apply_repetition_penalties_cuda(
     )
 
 
+@cache
+def _should_use_cuda_repetition_penalties() -> bool:
+    """Whether to run the CUDA repetition-penalty kernel on this platform.
+    Disable on Ada (SM 8.9); fall back to PyTorch there.
+    """
+    return not current_platform.is_device_capability(89)
+
+
 def apply_repetition_penalties(
     logits: torch.Tensor,
     prompt_mask: torch.Tensor,
```
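For context, `@cache` (imported from `functools` in the first hunk) memoizes the zero-argument predicate, so the platform query runs once per process and every later call in the sampling hot path is a dictionary lookup. `current_platform.is_device_capability(89)` is vLLM's helper for asking whether the active GPU reports compute capability 8.9 (Ada, e.g. L4/L40/RTX 40-series). A minimal standalone sketch of the same check using only public PyTorch APIs (`_is_ada_gpu` is a hypothetical name, not part of the diff):

```python
from functools import cache

import torch


@cache
def _is_ada_gpu() -> bool:
    # Hypothetical standalone equivalent of
    # current_platform.is_device_capability(89): true only when the
    # active CUDA device reports compute capability 8.9 (Ada).
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) == (8, 9)
```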
```diff
@@ -380,7 +389,11 @@ def apply_repetition_penalties(
         output_mask: A boolean tensor indicating which tokens appear in the output.
         repetition_penalties: The repetition penalties of shape (num_seqs, ).
     """
-    if logits.is_cuda and logits.is_contiguous():
+    if (
+        logits.is_cuda
+        and logits.is_contiguous()
+        and _should_use_cuda_repetition_penalties()
+    ):
         apply_repetition_penalties_cuda(
             logits, prompt_mask, output_mask, repetition_penalties
         )
```
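With the extra guard, Ada GPUs now take the same eager-PyTorch fallback that CPU or non-contiguous tensors already used. `apply_repetition_penalties_torch` is not shown in this diff; the sketch below is an illustrative reimplementation of the standard repetition-penalty rule (divide positive logits by the penalty, multiply non-positive ones), not necessarily vLLM's exact code:

```python
import torch


def apply_repetition_penalties_torch(
    logits: torch.Tensor,                # [num_seqs, vocab_size], modified in-place
    prompt_mask: torch.Tensor,           # bool, [num_seqs, vocab_size]
    output_mask: torch.Tensor,           # bool, [num_seqs, vocab_size]
    repetition_penalties: torch.Tensor,  # [num_seqs]
) -> None:
    # Penalize every token that appeared in the prompt or the output so far.
    seen = prompt_mask | output_mask
    penalties = repetition_penalties.unsqueeze(1).expand_as(logits)
    # Standard rule: shrink positive logits by dividing, push non-positive
    # logits further down by multiplying.
    penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
    logits[:] = torch.where(seen, penalized, logits)
```

Because the dispatch happens in Python, the CUDA kernel itself is untouched; the cached gate simply routes Ada devices to the reference path.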