Skip to content

Commit a8494d5

Browse files
committed
[Model][Qwen3VL] Simplify get_mrope_input_positions using numpy
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
1 parent df4d3a4 commit a8494d5

File tree

1 file changed

+11
-33
lines changed

1 file changed

+11
-33
lines changed

vllm/model_executor/models/qwen3_vl.py

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,13 +1432,11 @@ def get_mrope_input_positions(
14321432
vision_start_token_id = hf_config.vision_start_token_id
14331433
spatial_merge_size = hf_config.vision_config.spatial_merge_size
14341434

1435-
input_tokens_tensor = torch.tensor(input_tokens)
1436-
vision_start_indices = torch.argwhere(
1437-
input_tokens_tensor == vision_start_token_id
1438-
).squeeze(1)
1439-
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
1440-
image_nums = (vision_tokens == image_token_id).sum()
1441-
video_nums = (vision_tokens == video_token_id).sum()
1435+
input_tokens_array = np.array(input_tokens)
1436+
vision_start_mask = input_tokens_array == vision_start_token_id
1437+
vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1]
1438+
image_nums = np.count_nonzero(vision_tokens == image_token_id)
1439+
video_nums = np.count_nonzero(vision_tokens == video_token_id)
14421440
llm_pos_ids_list: list = []
14431441

14441442
st = 0
@@ -1474,43 +1472,23 @@ def get_mrope_input_positions(
14741472

14751473
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
14761474
llm_pos_ids_list.append(
1477-
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
1475+
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
14781476
)
14791477

1480-
t_index = (
1481-
torch.arange(llm_grid_t)
1482-
.view(-1, 1)
1483-
.expand(-1, llm_grid_h * llm_grid_w)
1484-
.flatten()
1485-
)
1486-
h_index = (
1487-
torch.arange(llm_grid_h)
1488-
.view(1, -1, 1)
1489-
.expand(llm_grid_t, -1, llm_grid_w)
1490-
.flatten()
1491-
)
1492-
w_index = (
1493-
torch.arange(llm_grid_w)
1494-
.view(1, 1, -1)
1495-
.expand(llm_grid_t, llm_grid_h, -1)
1496-
.flatten()
1497-
)
1498-
llm_pos_ids_list.append(
1499-
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
1500-
)
1478+
grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
1479+
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
15011480
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
15021481

15031482
if st < len(input_tokens):
15041483
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
15051484
text_len = len(input_tokens) - st
15061485
llm_pos_ids_list.append(
1507-
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
1486+
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
15081487
)
15091488

1510-
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1489+
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
15111490
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
1512-
1513-
return llm_positions, mrope_position_delta
1491+
return torch.from_numpy(llm_positions), mrope_position_delta
15141492

15151493
def get_language_model(self) -> torch.nn.Module:
15161494
return self.language_model

0 commit comments

Comments
 (0)