From a8494d52bc47759f2deb063255d2270d80b70be3 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Fri, 7 Nov 2025 13:27:24 +0000 Subject: [PATCH] [Model][Qwen3VL] Simplify `get_mrope_input_positions` using numpy Signed-off-by: Lukas Geiger --- vllm/model_executor/models/qwen3_vl.py | 44 +++++++------------------- 1 file changed, 11 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index d880e6015e5d..87494c6735cd 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1432,13 +1432,11 @@ def get_mrope_input_positions( vision_start_token_id = hf_config.vision_start_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size - input_tokens_tensor = torch.tensor(input_tokens) - vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id - ).squeeze(1) - vision_tokens = input_tokens_tensor[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() + input_tokens_array = np.array(input_tokens) + vision_start_mask = input_tokens_array == vision_start_token_id + vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1] + image_nums = np.count_nonzero(vision_tokens == image_token_id) + video_nums = np.count_nonzero(vision_tokens == video_token_id) llm_pos_ids_list: list = [] st = 0 @@ -1474,43 +1472,23 @@ def get_mrope_input_positions( st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - t_index = ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) + grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)) + llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - - return llm_positions, mrope_position_delta + return torch.from_numpy(llm_positions), mrope_position_delta def get_language_model(self) -> torch.nn.Module: return self.language_model