
Commit 4eacf50

[Model][Qwen3VL] Simplify get_mrope_input_positions using numpy
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Parent: 68a72a5

File tree: 1 file changed (+11 additions, -32 deletions)

vllm/model_executor/models/qwen3_vl.py

Lines changed: 11 additions & 32 deletions
@@ -1434,13 +1434,11 @@ def get_mrope_input_positions(
         vision_start_token_id = hf_config.vision_start_token_id
         spatial_merge_size = hf_config.vision_config.spatial_merge_size
 
-        input_tokens_tensor = torch.tensor(input_tokens)
-        vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
+        input_tokens_array = np.array(input_tokens)
+        vision_start_mask = input_tokens_array == vision_start_token_id
+        vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1]
+        image_nums = np.count_nonzero(vision_tokens == image_token_id)
+        video_nums = np.count_nonzero(vision_tokens == video_token_id)
         llm_pos_ids_list: list = []
 
         st = 0
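
A note for readers comparing the two counting paths above: the commit swaps torch.argwhere/squeeze for a boolean mask plus nonzero, and tensor .sum() for np.count_nonzero. A minimal standalone sketch (hypothetical token ids, not the real Qwen3VL vocabulary values) checking that both paths report the same image/video counts:

import numpy as np
import torch

# Hypothetical ids for this sketch only; the real values come from hf_config.
vision_start_token_id, image_token_id, video_token_id = 100, 101, 102
input_tokens = [1, 2, 100, 101, 3, 100, 102, 4, 100, 101, 5]

# Old path (torch): locate vision-start tokens, inspect the token right after each.
t = torch.tensor(input_tokens)
starts = torch.argwhere(t == vision_start_token_id).squeeze(1)
following = t[starts + 1]
old_counts = (
    (following == image_token_id).sum().item(),
    (following == video_token_id).sum().item(),
)

# New path (numpy, as in this commit): boolean mask + nonzero, counted without tensors.
a = np.array(input_tokens)
mask = a == vision_start_token_id
following_np = a[mask.nonzero()[0] + 1]
new_counts = (
    np.count_nonzero(following_np == image_token_id),
    np.count_nonzero(following_np == video_token_id),
)

assert old_counts == new_counts == (2, 1)  # 2 images, 1 video in the toy sequence
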
@@ -1484,43 +1482,24 @@ def get_mrope_input_positions(
 
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
             llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
             )
 
-            t_index = (
-                torch.arange(llm_grid_t)
-                .view(-1, 1)
-                .expand(-1, llm_grid_h * llm_grid_w)
-                .flatten()
-            )
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
-            )
+            grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
+            llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
             st = ed + llm_grid_t * llm_grid_h * llm_grid_w
 
         if st < len(input_tokens):
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
             text_len = len(input_tokens) - st
             llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
             )
 
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
         mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
         llm_positions = llm_positions[:, context_len:seq_len]
-        return llm_positions, mrope_position_delta
+        return torch.from_numpy(llm_positions), mrope_position_delta
 
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
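
The bulk of the deletion replaces the three torch arange/view/expand/flatten index tensors with a single np.indices call. A small sketch (arbitrary example grid sizes) verifying that np.indices((t, h, w)).reshape(3, -1) enumerates the same flattened (t, h, w) positions as the old torch construction:

import numpy as np
import torch

# Arbitrary example grid; in the model these come from the merged image/video grid.
llm_grid_t, llm_grid_h, llm_grid_w = 2, 3, 4

# Old path: build t/h/w index vectors separately and stack them.
t_index = (
    torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
)
h_index = (
    torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
)
w_index = (
    torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
)
old = torch.stack([t_index, h_index, w_index]).numpy()

# New path: one call that enumerates the (t, h, w) grid, flattened in C order.
new = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)).reshape(3, -1)

assert np.array_equal(old, new)  # t varies slowest, w fastest, in both versions

Since both enumerate the grid in the same order, the appended position ids are unchanged after the + text_len + st_idx offset. Two smaller details of the numpy version: np.broadcast_to returns a read-only view, but the subsequent + st_idx allocates a fresh writable array, and the accumulated arrays are only converted back to a tensor once, via torch.from_numpy at the return.
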
