Support Embedding Model/Task #1015
Conversation
Some benchmarks that I ran; not really sure what to benchmark, though.

torchax:
MODEL_IMPL_TYPE=vllm vllm serve Qwen/Qwen3-Embedding-0.6B --convert embed --max-num-seqs 200 --max-model-len 1024

jax:
(vllm) carlesoctav@t1v-n-42336699-w-0:/mnt/carles/vllm$ vllm bench serve --model Qwen/Qwen3-Embedding-0.6B --dataset-name random --backend openai-embeddings --endpoint /v1/embeddings
INFO 11-05 13:20:18 [__init__.py:22] TPU info: node_name=carant-8 | tpu_type=v4-8 | worker_id=0 | num_chips=4 | num_cores_per_chip=2
INFO 11-05 13:20:19 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-05 13:20:19 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-05 13:20:19 [interface.py:171] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
Namespace(subparser='bench', bench_type='serve', dispatch_function=<function BenchmarkServingSubcommand.cmd at 0x7f95d5f38a40>, seed=0, num_prompts=1000, dataset_name='random', no_stream=False, dataset_path=None, no_oversample=False, skip_chat_template=False, disable_shuffle=False, custom_output_len=256, spec_bench_output_len=256, spec_bench_category=None, sonnet_input_len=550, sonnet_output_len=150, sonnet_prefix_len=200, sharegpt_output_len=None, blazedit_min_distance=0.0, blazedit_max_distance=1.0, random_input_len=1024, random_output_len=128, random_range_ratio=0.0, random_prefix_len=0, random_batch_size=1, no_reranker=False, random_mm_base_items_per_request=1, random_mm_num_mm_items_range_ratio=0.0, random_mm_limit_mm_per_prompt={'image': 255, 'video': 1}, random_mm_bucket_config={(256, 256, 1): 0.5, (720, 1280, 1): 0.5, (720, 1280, 16): 0.0}, hf_subset=None, hf_split=None, hf_name=None, hf_output_len=None, prefix_repetition_prefix_len=256, prefix_repetition_suffix_len=256, prefix_repetition_num_prefixes=10, prefix_repetition_output_len=128, label=None, backend='openai-embeddings', base_url=None, host='127.0.0.1', port=8000, endpoint='/v1/embeddings', header=None, max_concurrency=None, model='Qwen/Qwen3-Embedding-0.6B', tokenizer=None, use_beam_search=False, logprobs=None, request_rate=inf, burstiness=1.0, trust_remote_code=False, disable_tqdm=False, num_warmups=0, profile=False, save_result=False, save_detailed=False, append_result=False, metadata=None, result_dir=None, result_filename=None, ignore_eos=False, percentile_metrics=None, metric_percentiles='99', goodput=None, request_id_prefix='bench-7799e524-', top_p=None, top_k=None, min_p=None, temperature=None, frequency_penalty=None, presence_penalty=None, repetition_penalty=None, tokenizer_mode='auto', served_model_name=None, lora_modules=None, ramp_up_strategy=None, ramp_up_start_rps=None, ramp_up_end_rps=None, ready_check_timeout_sec=600, extra_body=None)
INFO 11-05 13:20:22 [datasets.py:614] Sampling input_len from [1023, 1023] and output_len from [128, 128]
Starting initial single prompt test run...
Waiting for endpoint to become up in 600 seconds
Initial test run completed.
Starting main benchmark run...
Traffic request rate: inf
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:21<00:00, 46.85it/s]
============ Serving Benchmark Result ============
Successful requests: 1000
Failed requests: 0
Benchmark duration (s): 21.34
Total input tokens: 1024000
Request throughput (req/s): 46.85
Total Token throughput (tok/s): 47976.64
----------------End-to-end Latency----------------
Mean E2EL (ms): 11642.71
Median E2EL (ms): 11663.70
P99 E2EL (ms): 21013.45
==================================================
torchax:
This is awesome @carlesoctav - thank you for the contribution! @py4 is on call this week. I think the only thing I would do is add tests (check out the contributor guide for how to add CI). Cc @jcyang43 who can help with adding them once they are in :)

Hi @carlesoctav, do you know if other models are supported with this PR? I tried different embedding models, and it seems the fallback logic to the vLLM implementation could be broken: it attempts to find a matching architecture in the flax model registry, then tries again before crashing. I was able to reproduce your Qwen embedding run and that works as you've reported; just the other models don't seem to work. Below are two different embedding models (Gemma and granite) FYI: Gemma embedding model, granite embedding model.

Hmm, I feel like this happens because when we initialize the model with the vLLM implementation, we're using the old vLLM TPU path, which doesn't support non-causal attention yet?

Actually, it tries to use tpu_inference.layers.vllm.attention, but tpu_inference doesn't support encoder attention, and the PallasAttentionBackend class doesn't have a get_builder_cls method.

Thanks for all this @carlesoctav! @QiliangCui just FYI, not sure who the right person to review is.

@py4, please take a look.
py4 left a comment:
Thanks for the contribution! I left a few comments, and also:
please add unit tests and an e2e test for your changes (one possible e2e shape is sketched after this list):
- unit tests for https://github.com/vllm-project/tpu-inference/tree/main/tests/layers/jax
- unit tests for the input_batch_jx.py changes
- unit tests for the tpu_jax_runner.py changes
- an e2e test that runs a small embedding model here: https://github.com/vllm-project/tpu-inference/tree/main/tests/e2e
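As a rough illustration only (not from the repo; the test name, URL, and model choice are assumptions, and the real tests/e2e harness may expose its own server fixture), such an e2e test could start the server as in the PR description and then check that /v1/embeddings returns one fixed-size, finite vector per input:

# Hypothetical e2e sketch, assuming a server is already running locally.
import numpy as np
import pytest
import requests

EMBEDDINGS_URL = "http://localhost:8000/v1/embeddings"   # assumed address
MODEL = "Qwen/Qwen3-Embedding-0.6B"


@pytest.mark.parametrize("prompts", [["hello world", "embedding models on TPU"]])
def test_embeddings_endpoint_smoke(prompts):
    resp = requests.post(EMBEDDINGS_URL, json={"model": MODEL, "input": prompts})
    resp.raise_for_status()
    data = resp.json()["data"]
    # One embedding per prompt, all with the same nonzero dimension.
    assert len(data) == len(prompts)
    dims = {len(item["embedding"]) for item in data}
    assert len(dims) == 1 and dims.pop() > 0
    # Values should at least be finite numbers.
    assert np.isfinite(np.asarray(data[0]["embedding"])).all()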
def is_partial_prefill(pooling_metadata: TPUSupportedPoolingMetadata):
This function doesn't seem to be used? Should we delete it? Also what happens if the prefill is partial?
Yeah, I can remove it.
This is an assertion check for MeanPooling and CLS pooling, and it applies mostly to encoder-type models (which aren't supported on TPU yet): since those models use bidirectional attention and don't support the KV cache, all prompts must be processed within a single prefill step.
I should put this check inside mean pooling and CLS pooling, but I'm not really sure whether an assertion that depends on argument values is possible inside jit (see the sketch below for one option).
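As a rough sketch of one option (not the PR's code; the shapes and argument names are illustrative), jax.experimental.checkify lets a value-dependent check survive jit and be raised later on the host:

# Minimal sketch: pooled_states is [num_reqs, max_len, hidden] zero-padded per
# request; prompt_lens and scheduled_lens are per-request int arrays.
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def mean_pool_with_check(pooled_states, prompt_lens, scheduled_lens):
    # Value-dependent "assert" that works under jit: every pooled request
    # must have its whole prompt scheduled in this prefill step.
    checkify.check(jnp.all(scheduled_lens >= prompt_lens),
                   "mean/CLS pooling needs the full prompt in one prefill")
    return pooled_states.sum(axis=1) / prompt_lens[:, None]


checked = jax.jit(checkify.checkify(mean_pool_with_check))
err, out = checked(jnp.ones((2, 8, 4)),
                   jnp.array([8, 6]), jnp.array([8, 6]))
err.throw()  # raises on the host if the check failed; no-op otherwise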
) / padded_prompt_lens.unsqueeze(1)

class LastPoolingMethod(PoolingMethod):
Please add docstrings for your classes and functions
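For illustration only (not the PR's code; the class name, argument names, and shapes are assumed), a docstring-ed last-token pooling helper could look roughly like this:

import jax.numpy as jnp


class LastPool:
    """Pools each request to the hidden state of its last prompt token.

    Suitable for decoder-only embedding models (e.g. Qwen3-Embedding),
    where the final token attends to the whole prompt.
    """

    def __call__(self, hidden_states, first_token_idx, prompt_lens):
        # hidden_states: [num_tokens, hidden], flattened across requests.
        # first_token_idx: [num_reqs] start offset of each request.
        # prompt_lens: [num_reqs] prompt length of each request.
        last_token_idx = first_token_idx + prompt_lens - 1
        return jnp.take(hidden_states, last_token_idx, axis=0)


pool = LastPool()
emb = pool(jnp.ones((10, 4)), jnp.array([0, 6]), jnp.array([6, 4]))  # shape (2, 4)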
# actually my idea is not to jit this function but the model.pooler,
# we can put some postprocessing here.
I don't understand the comment here. It seems the function is actually jitted?
return jit_model
nit: extra blank lines
convert_type = getattr(model_config, "convert_type", "none")
if convert_type == "embed":
    logger.debug_once("Converting %s to embedding model", model_class.__name__)
should we keep this?
def init_pooler_from_vllm_model(
It seems to me this is not native support for vLLM models; here we are still using the JAX implementation. However, I think we should use torchax (https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/models/vllm/vllm_model_wrapper.py) to automatically wrap vLLM models.
@hfan, @kyuyeunk and @vanbasten23 for commenting on vLLM model support.
You could also remove this adapter and the vLLM support, send this PR as JAX support only, and do vLLM support as a follow-up.
Sure, but it has the same issue as why we do sampling in JAX, no? It requires arguments that are "jit friendly." The PyTorch version takes a list of PoolingParams from the request input, which isn't jit-friendly (see the sketch below for the kind of flattening I mean).
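As a rough sketch of what "jit friendly" means here (illustrative names only, not vLLM's actual PoolingParams API or the PR's code), the per-request Python-level metadata has to be flattened into fixed-shape arrays before it can be passed to a jitted pooler:

from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class JitPoolingMetadata:
    prompt_lens: jnp.ndarray      # [max_num_reqs] int32, zero-padded
    first_token_idx: jnp.ndarray  # [max_num_reqs] int32, zero-padded
    num_reqs: int                 # static for this compiled shape


def build_pooling_metadata(prompt_lens, max_num_reqs):
    """Pads host-side per-request lengths to a fixed size so the shapes
    seen by jit do not change from step to step."""
    padded = [0] * max_num_reqs
    padded[:len(prompt_lens)] = prompt_lens
    lens = jnp.asarray(padded, dtype=jnp.int32)
    starts = jnp.cumsum(lens) - lens  # exclusive cumsum = per-request start offsets
    return JitPoolingMetadata(prompt_lens=lens,
                              first_token_idx=starts,
                              num_reqs=len(prompt_lens))


meta = build_pooling_metadata([5, 3], max_num_reqs=4)
# meta.prompt_lens -> [5, 3, 0, 0]; meta.first_token_idx -> [0, 5, 8, 8]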
Using the vLLM model implementation goes far beyond just using torchax to convert torch tensors to JAX tensors. It means incorporating code in a way that uses the APIs exposed by vLLM, with minimal coding from scratch.
Also, any vLLM-model-related code should live in either the models/vllm or layers/vllm folder.
And functionality shared by both vLLM models and JAX models should live in models/common or layers/common.
Please refer to these PRs:
if hf_key.endswith(".weight"):
    hf_key = hf_key.removesuffix(".weight")

if not hf_key.startswith('models.'):
Doesn't this mess with non-embedding models?
So, based on my observation, some model checkpoints use the base model (e.g., QwenModel) rather than the CausalLM variant. Because of that, the hf_key starts without the "model" prefix (yeah, there was a typo: I should check for "model" instead of "models").
The idea here is simply to prepend the "model" prefix to the hf_key. This shouldn't affect non-embedding models, since their hf_key already contains "model" in those cases (see the sketch below).
I ran all of tests/models/ and everything loads correctly for the standard models.
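For illustration, a sketch of the intended remapping (not the PR's exact code; the lm_head caveat in the comment is an assumption on my part):

def remap_hf_key(hf_key: str) -> str:
    # Base-model checkpoints (e.g. Qwen3Model) ship keys like "layers.0...",
    # while the runner's parameter tree is rooted at "model.", so the prefix
    # is added only when it is missing.
    if hf_key.endswith(".weight"):
        hf_key = hf_key.removesuffix(".weight")
    if not hf_key.startswith("model."):
        # CausalLM checkpoints already start with "model." and are untouched;
        # keys like lm_head or a pooler head would need their own handling.
        hf_key = "model." + hf_key
    return hf_key


assert remap_hf_key("layers.0.self_attn.q_proj.weight") == "model.layers.0.self_attn.q_proj"
assert remap_hf_key("model.layers.0.mlp.gate_proj.weight") == "model.layers.0.mlp.gate_proj"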
self._precompile_backbone_with_inputs_embeds()
if self.runner.is_pooling_model:
    self._precompile_pooling()
    return
Should we return here for embedding models and skip the rest of precompilation? Do we also need the previous precompilations?
I'm following the approach used in the PyTorch version. In that implementation, the lm_head is removed, so we also need to skip any compilation steps related to logits creation and sampling (for the torchax implementation).
@@ -213,6 +221,7 @@ def __init__(
self.uses_mrope, self.model_config)
self.lora_utils = LoraUtils(self)
nit: extra blank lines
@@ -588,9 +604,15 @@ def _execute_model(
# "Should not schedule a request that does nothing!")
return DUMMY_METADATA, EMPTY_MODEL_RUNNER_OUTPUT

# TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
please keep this comment
@carlesoctav can you resolve the branch conflict first before I can give a proper review?

@kyuyeunk done.
get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
                                            graphdef)
- lora_manager, model = None, None
+ lora_manager, _ = None, None
Reason for removing this?
Description
Initial patch to support embedding models (from CausalLM model architectures). The initial plan is similar to #899: keep the pooling implementation in JAX. Right now, I've only tested it with decoder models (taking the last token as the embedding representation). For non-decoder models (models that take the CLS token and use full attention), I'll just need to ensure the prefill phase of that request contains all the tokens, and return an error if that's not the case.
Also, this one is just for the nnx model_impl. For torchax, I think I just need to add a way to initialize the Pooler without actually initializing the BaseModel/CausalLMModel.
FIXES: #899
Tests
Run the server, and use this script to see the comparison with the transformers output.
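The script itself isn't shown above; as a hypothetical sketch of that kind of check (the endpoint address, cosine threshold, and last-token pooling are assumptions, not the author's actual script):

# Compare the served embedding against a transformers last-token hidden state.
import numpy as np
import requests
import torch
from transformers import AutoModel, AutoTokenizer

MODEL = "Qwen/Qwen3-Embedding-0.6B"
PROMPT = "What is the capital of France?"

# Embedding from the vLLM OpenAI-compatible endpoint (server assumed running).
resp = requests.post("http://localhost:8000/v1/embeddings",
                     json={"model": MODEL, "input": [PROMPT]})
served = np.asarray(resp.json()["data"][0]["embedding"])

# Reference: last real token's hidden state from transformers.
tok = AutoTokenizer.from_pretrained(MODEL)
model = AutoModel.from_pretrained(MODEL)
with torch.no_grad():
    batch = tok([PROMPT], return_tensors="pt")
    hidden = model(**batch).last_hidden_state[0]          # [seq_len, hidden]
    ref = hidden[batch["attention_mask"][0].sum() - 1]    # last non-pad token
ref = ref.numpy()

cos = np.dot(served, ref) / (np.linalg.norm(served) * np.linalg.norm(ref))
print("cosine similarity:", cos)
assert cos > 0.99  # arbitrary sanity bound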
I tried to make a proper test, but I think we can't use the LLM class from vLLM for now, since the tpu_inference path/check isn't there(?).
Commands to reproduce:
Checklist
Before submitting this PR, please make sure:
[x] I have performed a self-review of my code
[x] I have necessary comments in my code, particularly in hard-to-understand areas.
[x] I have made or will make corresponding changes to any relevant documentation.