File tree Expand file tree Collapse file tree 2 files changed +12
-4
lines changed Expand file tree Collapse file tree 2 files changed +12
-4
lines changed Original file line number Diff line number Diff line change @@ -877,13 +877,17 @@ def custom_op_log_check(self):
877877 )
878878
879879 def adjust_cudagraph_sizes_to_be_multipe_of (self , multiple_of : int ):
880- if not self .cudagraph_capture_sizes :
880+ if not self .cudagraph_capture_sizes or multiple_of <= 1 :
881881 return
882882
883+ assert self .max_cudagraph_capture_size is not None
884+
883885 rounded_sizes = sorted (
884- round_up (size , multiple_of )
885- for size in self .cudagraph_capture_sizes
886- if round_up (size , multiple_of ) <= self .max_cudagraph_capture_size
886+ set (
887+ round_up (size , multiple_of )
888+ for size in self .cudagraph_capture_sizes
889+ if round_up (size , multiple_of ) <= self .max_cudagraph_capture_size
890+ )
887891 )
888892
889893 if len (rounded_sizes ) == 0 :
Original file line number Diff line number Diff line change @@ -4240,6 +4240,8 @@ def _check_and_update_cudagraph_mode(
42404240 # we need to adjust the cudagraph sizes to be a multiple of the uniform
42414241 # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
42424242 # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
4243+ # Will be removed in the near future when we have seperate cudagraph capture
4244+ # sizes for decode and mixed prefill-decode.
42434245 if (
42444246 cudagraph_mode .decode_mode () == CUDAGraphMode .FULL
42454247 and cudagraph_mode .separate_routine ()
@@ -4248,6 +4250,8 @@ def _check_and_update_cudagraph_mode(
42484250 self .compilation_config .adjust_cudagraph_sizes_to_be_multipe_of (
42494251 self .uniform_decode_query_len
42504252 )
4253+ self .cudagraph_batch_sizes = self .compilation_config .cudagraph_capture_sizes
4254+
42514255 self .compilation_config .compute_bs_to_padded_graph_size ()
42524256
42534257 # Trigger cudagraph dispatching keys initialization after
You can’t perform that action at this time.
0 commit comments