Skip to content

Commit 2d2551b

Browse files
wip
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent f8dd09a commit 2d2551b

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

vllm/config/compilation.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -877,13 +877,17 @@ def custom_op_log_check(self):
877877
)
878878

879879
def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
880-
if not self.cudagraph_capture_sizes:
880+
if not self.cudagraph_capture_sizes or multiple_of <= 1:
881881
return
882882

883+
assert self.max_cudagraph_capture_size is not None
884+
883885
rounded_sizes = sorted(
884-
round_up(size, multiple_of)
885-
for size in self.cudagraph_capture_sizes
886-
if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
886+
set(
887+
round_up(size, multiple_of)
888+
for size in self.cudagraph_capture_sizes
889+
if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
890+
)
887891
)
888892

889893
if len(rounded_sizes) == 0:

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4240,6 +4240,8 @@ def _check_and_update_cudagraph_mode(
42404240
# we need to adjust the cudagraph sizes to be a multiple of the uniform
42414241
# decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
42424242
# temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
4243+
# Will be removed in the near future when we have separate cudagraph capture
4244+
# sizes for decode and mixed prefill-decode.
42434245
if (
42444246
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
42454247
and cudagraph_mode.separate_routine()
@@ -4248,6 +4250,8 @@ def _check_and_update_cudagraph_mode(
42484250
self.compilation_config.adjust_cudagraph_sizes_to_be_multipe_of(
42494251
self.uniform_decode_query_len
42504252
)
4253+
self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes
4254+
42514255
self.compilation_config.compute_bs_to_padded_graph_size()
42524256

42534257
# Trigger cudagraph dispatching keys initialization after

0 commit comments

Comments
 (0)