Skip to content

Conversation

@lgeiger
Copy link
Contributor

@lgeiger lgeiger commented Nov 7, 2025

Purpose

This PR simplifies the Qwen3VL get_mrope_input_positions computation by using np.indices to make the code more readable.

In a profile I also noticed that the torch CPUs ops, especially torch.tensor(list[int]) are slower than their numpy equivalents so this PR also changes the computation to numpy.

Before:

Screenshot 2025-11-07 at 15 14 40 Screenshot 2025-11-07 at 15 15 10

After:

Screenshot 2025-11-07 at 15 15 24 Screenshot 2025-11-07 at 15 14 53

Test Plan

VLLM_WORKER_MULTIPROC_METHOD=spawn lm_eval --model vllm-vlm --model_args "pretrained=Qwen/Qwen3-VL-30B-A3B-Instruct-FP8,max_model_len=10000" --tasks chartqa --batch_size auto --apply_chat_template

Test Result

Before:

Tasks Version Filter n-shot Metric Value Stderr
chartqa 0 none 0 anywhere_accuracy 0.8680 ± 0.0068
none 0 exact_match 0.6340 ± 0.0096
none 0 relaxed_accuracy 0.8572 ± 0.0070

After:

Tasks Version Filter n-shot Metric Value Stderr
chartqa 0 none 0 anywhere_accuracy 0.8672 ± 0.0068
none 0 exact_match 0.6324 ± 0.0096
none 0 relaxed_accuracy 0.8576 ± 0.0070

@lgeiger lgeiger requested a review from sighingnow as a code owner November 7, 2025 15:24
@mergify mergify bot added the qwen Related to Qwen models label Nov 7, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the get_mrope_input_positions function to use NumPy instead of PyTorch for improved performance and readability. The changes are logical and well-implemented, replacing complex PyTorch operations with more concise NumPy equivalents like np.indices.

I've identified one high-severity issue related to an edge case where empty input_tokens would cause a crash. I've provided a code suggestion to handle this case gracefully. Other than that, the changes look good.

Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I'm thinking of directly using the positions from mm_features instead of having to calculate the mask again, WDYT?

@lgeiger
Copy link
Contributor Author

lgeiger commented Nov 10, 2025

Actually I'm thinking of directly using the positions from mm_features instead of having to calculate the mask again, WDYT?

I'm not entirely sure how you plan to do this, but not having to compute the mask again does sound sensible.

@DarkLight1337
Copy link
Member

We can pass req_state.mm_features which includes mm_position (PlaceholderRange) into get_mrope_input_positions.

@DarkLight1337
Copy link
Member

I will open another PR to update the argument list, then we can migrate the models to actually make use of mm_position one by one.

@DarkLight1337
Copy link
Member

Opened #28399

@mergify
Copy link

mergify bot commented Nov 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @lgeiger.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 11, 2025
@DarkLight1337
Copy link
Member

Feel free to update your PR now

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) November 12, 2025 00:48
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 12, 2025
@lgeiger
Copy link
Contributor Author

lgeiger commented Nov 12, 2025

@DarkLight1337 Just to make sure there's no confusion: I only rebased this PR so far but haven't made use of mm_position yet. The PR is still valid, though.

I'm not sure when I'll have time to migrate the logic to use mm_position but I'll try to have a look early next week. If you want to make this change soon feel free to make a PR otherwise I'll have a look at it on the weekend.

@DarkLight1337
Copy link
Member

DarkLight1337 commented Nov 12, 2025

Feel free to work on this yourself!

@DarkLight1337 DarkLight1337 merged commit cbb799e into vllm-project:main Nov 12, 2025
55 checks passed
@lgeiger lgeiger deleted the qwen3vl-mrope branch November 14, 2025 15:21
geodavic pushed a commit to geodavic/vllm that referenced this pull request Nov 16, 2025
…lm-project#28302)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: George D. Torres <gdavtor@gmail.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
…lm-project#28302)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants