File tree Expand file tree Collapse file tree — 1 file changed: +14 −1 lines changed
lines changed Original file line number Diff line number Diff line change 11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
4+ from functools import cache
45from typing import TYPE_CHECKING , Literal
56
67import torch
@@ -366,6 +367,14 @@ def apply_repetition_penalties_cuda(
366367 )
367368
368369
@cache
def _should_use_cuda_repetition_penalties() -> bool:
    """Decide whether the CUDA repetition-penalty kernel should run here.

    Ada GPUs (compute capability 8.9) are excluded; callers fall back to
    the PyTorch implementation on that architecture. The result depends
    only on the platform, so it is computed once and memoized.
    """
    is_ada = current_platform.is_device_capability(89)
    return not is_ada
377+
369378def apply_repetition_penalties (
370379 logits : torch .Tensor ,
371380 prompt_mask : torch .Tensor ,
@@ -380,7 +389,11 @@ def apply_repetition_penalties(
380389 output_mask: A boolean tensor indicating which tokens appear in the output.
381390 repetition_penalties: The repetition penalties of shape (num_seqs, ).
382391 """
383- if logits .is_cuda and logits .is_contiguous ():
392+ if (
393+ logits .is_cuda
394+ and logits .is_contiguous ()
395+ and _should_use_cuda_repetition_penalties ()
396+ ):
384397 apply_repetition_penalties_cuda (
385398 logits , prompt_mask , output_mask , repetition_penalties
386399 )
You can’t perform that action at this time.
0 commit comments