Conversation

@carlesoctav

Description

Initial patch to support embedding models (built from the CausalLM model architecture). The initial plan is similar to #899: keep the pooling implementation in JAX. Right now I've only tested it with decoder models (taking the last token as the embedding representation). For other, non-decoder models that take the CLS token and use full (bidirectional) attention, I'll just need to ensure the prefill phase of the request contains all the tokens, and return an error if that's not the case.
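For reference, a minimal JAX sketch of last-token pooling over the runner's flattened token layout (the seq_starts/seq_lens names are illustrative, not the actual metadata fields in this patch):

import jax.numpy as jnp

def last_token_pool(hidden_states: jnp.ndarray,  # [num_tokens, hidden_size]
                    seq_starts: jnp.ndarray,     # [num_seqs] first-token offset per request
                    seq_lens: jnp.ndarray) -> jnp.ndarray:  # [num_seqs] prompt length per request
    """Gather each request's final prompt-token hidden state and L2-normalize it."""
    last_token_idx = seq_starts + seq_lens - 1
    pooled = hidden_states[last_token_idx]  # [num_seqs, hidden_size]
    norms = jnp.linalg.norm(pooled, axis=-1, keepdims=True)
    return pooled / jnp.maximum(norms, 1e-12)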

Also, this one only covers the nnx model_impl; for torchax, I think I just need to add a way to initialize the Pooler without actually initializing the BaseModel/CausalLMModel.

FIXES: #899

Tests

run the server

vllm serve Qwen/Qwen3-Embedding-0.6B --convert embed --max-num-batched-token 128 --max-num-seqs 8

and use this script to see the comparison with the Transformers output.

import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import numpy as np
from openai import OpenAI


def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Reference implementation from Qwen3-Embedding documentation"""
    left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths
        ]


def get_transformers_embeddings(texts, model, tokenizer, max_length=8192):
    """Get embeddings using pure PyTorch + Transformers implementation"""
    # Tokenize the input texts
    batch_dict = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    batch_dict = batch_dict.to(model.device)

    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = last_token_pool(
            outputs.last_hidden_state, batch_dict["attention_mask"]
        )
        # Normalize embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1)

    return embeddings.cpu().numpy()


def get_vllm_embeddings(texts, client, model_name):
    """Get embeddings using vLLM OpenAI API"""
    response = client.embeddings.create(model=model_name, input=texts)
    embeddings = []
    for item in response.data:
        embeddings.append(item.embedding)
    return np.array(embeddings)


def cosine_similarity(a, b):
    """Compute cosine similarity between two vectors"""
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


def compare_embeddings(vllm_emb, transformers_emb, rtol=5e-3, atol=5e-3):
    """Compare vLLM and Transformers embeddings using np.allclose"""
    results = []
    for i in range(len(vllm_emb)):
        is_close = np.allclose(vllm_emb[i], transformers_emb[i], rtol=rtol, atol=atol)
        # Compute max absolute difference and cosine similarity for diagnostics
        max_diff = np.max(np.abs(vllm_emb[i] - transformers_emb[i]))
        cos_sim = cosine_similarity(vllm_emb[i], transformers_emb[i])
        results.append((is_close, max_diff, cos_sim))
    return results


# Initialize models
print("Loading models...")
print("-" * 60)

# vLLM OpenAI API client
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
model_name = "Qwen/Qwen3-Embedding-0.6B"

# Transformers model
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-0.6B", padding_side="left"
)
model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
model.eval()

print("✓ Models loaded successfully\n")

# Test data - each text contains at least 10 words
test_texts = [
    "The quick brown fox jumps over the lazy dog in the morning",
    "Machine learning algorithms can process vast amounts of data efficiently today",
    "Natural language processing enables computers to understand human communication patterns",
    "Deep learning networks require substantial computational resources for training models",
    "Artificial intelligence continues to transform industries across the global economy",
    "Vector embeddings represent semantic meaning in high dimensional mathematical spaces",
    "Transformer architecture revolutionized natural language understanding and generation tasks",
    "Large language models demonstrate impressive capabilities in various complex applications",
    "Neural networks learn hierarchical representations from raw input data automatically",
    "Context windows determine how much information models can process simultaneously",
    "Attention mechanisms allow models to focus on relevant parts of input",
    "Pre-training and fine-tuning are essential steps in modern model development",
    "Gradient descent optimization helps neural networks minimize their loss functions",
    "Backpropagation efficiently computes gradients for updating network weights during training",
    "Batch normalization improves training stability and convergence speed significantly",
]

# Test with 1, 2, 3, 4, 5, 7, 8, 15, 16 embeddings
test_cases = [1, 2, 3, 4, 5, 7, 8, 15, 16]

for num_texts in test_cases:
    print(f"{'=' * 60}")
    print(f"Testing with {num_texts} embedding(s)")
    print(f"{'=' * 60}")

    inputs = test_texts[:num_texts]

    try:
        # Get vLLM embeddings via OpenAI API
        print("Running vLLM (OpenAI API)...")
        vllm_embeddings = get_vllm_embeddings(inputs, client, model_name)
        print(f"✓ vLLM: Generated {len(vllm_embeddings)} embedding(s)")
        print(f"  Shape: {vllm_embeddings.shape}")

        # Get Transformers embeddings
        print("Running Transformers...")
        transformers_embeddings = get_transformers_embeddings(inputs, model, tokenizer)
        print(f"✓ Transformers: Generated {len(transformers_embeddings)} embedding(s)")
        print(f"  Shape: {transformers_embeddings.shape}")

        # Compare embeddings
        print("\nComparing embeddings (rtol=5e-3, atol=5e-3):")
        results = compare_embeddings(vllm_embeddings, transformers_embeddings)

        all_pass = True
        for i, (is_close, max_diff, cos_sim) in enumerate(results):
            status = "✓" if is_close else "✗"
            print(
                f"  {status} Embedding {i}: max_diff={max_diff:.2e}, cos_sim={cos_sim:.6f}, allclose={is_close}"
            )
            if not is_close:
                all_pass = False

        if all_pass:
            print("  ✓ PASS: All embeddings match (np.allclose)")
        else:
            print("  ✗ FAIL: Some embeddings differ")

    except Exception as e:
        print(f"✗ Failed with error: {e}")
        import traceback

        traceback.print_exc()

    print()

print(f"{'=' * 60}")
print("All tests completed!")
print(f"{'=' * 60}")

I tried to make a proper test, but I think we can't use the LLM class from vLLM for now, since the tpu_inference path/check isn't there(?).
Commands to reproduce are above.

Checklist

Before submitting this PR, please make sure:
[x] I have performed a self-review of my code
[x] I have necessary comments in my code, particularly in hard-to-understand areas.
[x] I have made or will make corresponding changes to any relevant documentation.

@carlesoctav
Author

torchax:
MODEL_IMPL_TYPE=vllm vllm serve Qwen/Qwen3-Embedding-0.6B --convert embed --max-num-batched-token 128 --max-num-seqs 8

@carlesoctav carlesoctav changed the title from "Support Embedding Model/Task For NNX model_impl" to "Support Embedding Model/Task" on Nov 5, 2025
@carlesoctav
Author

Some benchmarks that I ran; not really sure what to benchmark, though.

 MODEL_IMPL_TYPE=vllm vllm serve Qwen/Qwen3-Embedding-0.6B --convert embed --max-num-seqs 200  --max-model-len 1024

jax

(vllm) carlesoctav@t1v-n-42336699-w-0:/mnt/carles/vllm$ vllm bench serve --model Qwen/Qwen3-Embedding-0.6B --dataset-name random --backend openai-embeddings --endpoint /v1/embeddings
INFO 11-05 13:20:18 [__init__.py:22] TPU info: node_name=carant-8 | tpu_type=v4-8 | worker_id=0 | num_chips=4 | num_cores_per_chip=2
INFO 11-05 13:20:19 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-05 13:20:19 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-05 13:20:19 [interface.py:171] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
Namespace(subparser='bench', bench_type='serve', dispatch_function=<function BenchmarkServingSubcommand.cmd at 0x7f95d5f38a40>, seed=0, num_prompts=1000, dataset_name='random', no_stream=False, dataset_path=None, no_oversample=False, skip_chat_template=False, disable_shuffle=False, custom_output_len=256, spec_bench_output_len=256, spec_bench_category=None, sonnet_input_len=550, sonnet_output_len=150, sonnet_prefix_len=200, sharegpt_output_len=None, blazedit_min_distance=0.0, blazedit_max_distance=1.0, random_input_len=1024, random_output_len=128, random_range_ratio=0.0, random_prefix_len=0, random_batch_size=1, no_reranker=False, random_mm_base_items_per_request=1, random_mm_num_mm_items_range_ratio=0.0, random_mm_limit_mm_per_prompt={'image': 255, 'video': 1}, random_mm_bucket_config={(256, 256, 1): 0.5, (720, 1280, 1): 0.5, (720, 1280, 16): 0.0}, hf_subset=None, hf_split=None, hf_name=None, hf_output_len=None, prefix_repetition_prefix_len=256, prefix_repetition_suffix_len=256, prefix_repetition_num_prefixes=10, prefix_repetition_output_len=128, label=None, backend='openai-embeddings', base_url=None, host='127.0.0.1', port=8000, endpoint='/v1/embeddings', header=None, max_concurrency=None, model='Qwen/Qwen3-Embedding-0.6B', tokenizer=None, use_beam_search=False, logprobs=None, request_rate=inf, burstiness=1.0, trust_remote_code=False, disable_tqdm=False, num_warmups=0, profile=False, save_result=False, save_detailed=False, append_result=False, metadata=None, result_dir=None, result_filename=None, ignore_eos=False, percentile_metrics=None, metric_percentiles='99', goodput=None, request_id_prefix='bench-7799e524-', top_p=None, top_k=None, min_p=None, temperature=None, frequency_penalty=None, presence_penalty=None, repetition_penalty=None, tokenizer_mode='auto', served_model_name=None, lora_modules=None, ramp_up_strategy=None, ramp_up_start_rps=None, ramp_up_end_rps=None, ready_check_timeout_sec=600, extra_body=None)
INFO 11-05 13:20:22 [datasets.py:614] Sampling input_len from [1023, 1023] and output_len from [128, 128]
Starting initial single prompt test run...
Waiting for endpoint to become up in 600 seconds
 |                                                                                                                                                                         | 00:00 elapsed, 132:57:55 remaining
Initial test run completed.
Starting main benchmark run...
Traffic request rate: inf
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:21<00:00, 46.85it/s]
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  21.34
Total input tokens:                      1024000
Request throughput (req/s):              46.85
Total Token throughput (tok/s):          47976.64
----------------End-to-end Latency----------------
Mean E2EL (ms):                          11642.71
Median E2EL (ms):                        11663.70
P99 E2EL (ms):                           21013.45
==================================================

torchax

Waiting for endpoint to become up in 600 seconds
 |                                                                                                                                                                         | 00:00 elapsed, 142:03:29 remaining
Initial test run completed.
Starting main benchmark run...
Traffic request rate: inf
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:21<00:00, 45.46it/s]
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  22.00
Total input tokens:                      1024000
Request throughput (req/s):              45.46
Total Token throughput (tok/s):          46549.37
----------------End-to-end Latency----------------
Mean E2EL (ms):                          11915.31
Median E2EL (ms):                        11932.26
P99 E2EL (ms):                           21543.47
==================================================

@bvrockwell
Collaborator

This is awesome @carlesoctav - thank you for the contribution! @py4 is on call this week.

I think the only thing I would do is add tests (check out the contributor guide for how to add CI). Cc @jcyang43 who can help with adding them once they are in :)

@saltysoup
Collaborator

Hi @carlesoctav, do you know whether other models are supported with this PR?

I tried using different embedding models, and it seems like the fallback logic to the vLLM implementation could be broken: it attempts to find a matching architecture in the flax model registry, then tries again before crashing.

I was able to reproduce your Qwen embedding model run, and that works as you've reported. It's just the other models that don't seem to be working.

Below are two different embedding models (Gemma and Granite), FYI.

Gemma embedding model

vllm serve google/embeddinggemma-300m --convert embed
INFO 11-06 09:33:48 [__init__.py:22] TPU info: node_name=ikwak-1chip | tpu_type=v6e-1 | worker_id=0 | num_chips=1 | num_cores_per_chip=1
INFO 11-06 09:33:49 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-06 09:33:49 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-06 09:33:49 [interface.py:171] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
(APIServer pid=412265) INFO 11-06 09:33:50 [api_server.py:1961] vLLM API server version 0.11.1rc6.dev158+gc3ee80a01
(APIServer pid=412265) INFO 11-06 09:33:50 [utils.py:253] non-default args: {'model_tag': 'google/embeddinggemma-300m', 'model': 'google/embeddinggemma-300m', 'convert': 'embed'}
(APIServer pid=412265) INFO 11-06 09:33:51 [config.py:896] Found sentence-transformers tokenize configuration.
(APIServer pid=412265) INFO 11-06 09:33:51 [config.py:784] Found sentence-transformers modules configuration.
(APIServer pid=412265) INFO 11-06 09:33:51 [config.py:815] Found pooling configuration.
(APIServer pid=412265) INFO 11-06 09:33:51 [model.py:871] Resolved `--runner auto` to `--runner pooling`. Pass the value explicitly to silence this message.
(APIServer pid=412265) INFO 11-06 09:33:51 [model.py:630] Resolved architecture: Gemma3TextModel
(APIServer pid=412265) INFO 11-06 09:33:51 [model.py:1951] Downcasting torch.float32 to torch.bfloat16.
(APIServer pid=412265) INFO 11-06 09:33:51 [model.py:1728] Using max model len 2048
(APIServer pid=412265) INFO 11-06 09:33:51 [arg_utils.py:1767] (Disabling) chunked prefill by default
(APIServer pid=412265) INFO 11-06 09:33:51 [arg_utils.py:1770] (Disabling) prefix caching by default
(APIServer pid=412265) INFO 11-06 09:33:51 [vllm.py:535] Only "last" pooling supports chunked prefill and prefix caching; disabling both.
(APIServer pid=412265) INFO 11-06 09:33:51 [vllm.py:535] Only models using causal attention supports chunked prefill and prefix caching; disabling both.
(APIServer pid=412265) WARNING 11-06 09:33:51 [tpu_jax.py:149] The model dtype is not properly set for JAX backend. Overwriting it to jnp.bfloat16
(APIServer pid=412265) INFO 11-06 09:33:51 [tpu_jax.py:185] Force using UniProcExecutor for JAX on single host.
INFO 11-06 09:33:55 [__init__.py:22] TPU info: node_name=ikwak-1chip | tpu_type=v6e-1 | worker_id=0 | num_chips=1 | num_cores_per_chip=1
INFO 11-06 09:33:56 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-06 09:33:56 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-06 09:33:56 [interface.py:171] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
(EngineCore_DP0 pid=412362) INFO 11-06 09:33:57 [core.py:93] Initializing a V1 LLM engine (v0.11.1rc6.dev158+gc3ee80a01) with config: model='google/embeddinggemma-300m', speculative_config=None, tokenizer='google/embeddinggemma-300m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=<class 'jax.numpy.bfloat16'>, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=None, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=google/embeddinggemma-300m, enable_prefix_caching=False, chunked_prefill_enabled=False, pooler_config=PoolerConfig(pooling_type='MEAN', normalize=True, dimensions=None, enable_chunked_processing=None, max_embed_len=None, softmax=None, activation=None, use_activation=None, logit_bias=None, step_tag_id=None, returned_token_ids=None), compilation_config={'level': None, 'mode': 2, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'openxla', 'custom_ops': ['all'], 'splitting_ops': None, 'use_inductor': None, 'compile_sizes': None, 'inductor_compile_config': {'enable_auto_functionalized_v2': False}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'use_cudagraph': True, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': None, 'cudagraph_copy_inputs': False, 'full_cuda_graph': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {}, 'max_cudagraph_capture_size': None, 'local_cache_dir': None}
(EngineCore_DP0 pid=412362) WARNING 11-06 09:33:57 [tpu_jax.py:214] Pin memory is not supported on TPU.
(EngineCore_DP0 pid=412362) WARNING 11-06 09:33:57 [tpu_worker_jax.py:57] The model dtype is not properly set for JAX backend. Overwriting it to jnp.bfloat16
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore_DP0 pid=412362) INFO 11-06 09:33:59 [parallel_state.py:1325] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(EngineCore_DP0 pid=412362) INFO 11-06 09:33:59 [tpu_jax_runner.py:268] Device sequence enforced: False
(EngineCore_DP0 pid=412362) INFO 11-06 09:33:59 [tpu_jax_runner.py:293] Init mesh | mesh=Mesh('data': 1, 'model': 1, axis_types=(Auto, Auto))
(EngineCore_DP0 pid=412362) INFO 11-06 09:33:59 [utils.py:93] Prepared token paddings: [16, 32, 64, 128, 256, 512, 1024, 2048]
(EngineCore_DP0 pid=412362) INFO 11-06 09:33:59 [utils.py:59] Prepared request paddings: [8, 16, 32, 64, 128, 256]
(EngineCore_DP0 pid=412362) INFO 11-06 09:33:59 [compilation_manager.py:39] Enabling JAX compile cache.
(EngineCore_DP0 pid=412362) INFO 11-06 09:33:59 [tpu_worker_jax.py:154] Init worker | rank=0 | node_id=0 | is_driver_worker=True | hbm=[(0.0, 31.25)]GiB
(EngineCore_DP0 pid=412362) INFO 11-06 09:33:59 [model_loader.py:325] Loading model with MODEL_IMPL_TYPE=flax_nnx
(EngineCore_DP0 pid=412362) WARNING 11-06 09:33:59 [model_loader.py:335] Flax model failed with: 'Model architectures ['Gemma3TextModel'] are not supported for now. Supported architectures: ['Llama4ForCausalLM', 'DeepseekV3ForCausalLM', 'LlamaForCausalLM', 'Qwen2ForCausalLM', 'Qwen3ForCausalLM', 'Qwen2_5_VLForConditionalGeneration', 'Phi3ForCausalLM', 'Eagle3LlamaForCausalLM', 'GptOssForCausalLM']'. Falling back to vLLM implementation.
(EngineCore_DP0 pid=412362) INFO 11-06 09:34:00 [tpu_jax.py:64] Cannot use None backend on TPU.
(EngineCore_DP0 pid=412362) INFO 11-06 09:34:00 [tpu_jax.py:67] Using Pallas V1 backend.
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843] EngineCore failed to start.
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843] Traceback (most recent call last):
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/common/model_loader.py", line 330, in get_model
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     return get_flax_model(vllm_config, rng, mesh, is_draft_model)
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/common/model_loader.py", line 202, in get_flax_model
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     model_class = _get_model_architecture(model_config.hf_config)
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/common/model_loader.py", line 59, in _get_model_architecture
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     raise UnsupportedArchitectureError(
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843] tpu_inference.models.common.model_loader.UnsupportedArchitectureError: Model architectures ['Gemma3TextModel'] are not supported for now. Supported architectures: ['Llama4ForCausalLM', 'DeepseekV3ForCausalLM', 'LlamaForCausalLM', 'Qwen2ForCausalLM', 'Qwen3ForCausalLM', 'Qwen2_5_VLForConditionalGeneration', 'Phi3ForCausalLM', 'Eagle3LlamaForCausalLM', 'GptOssForCausalLM']
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843] 
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843] During handling of the above exception, another exception occurred:
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843] 
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843] Traceback (most recent call last):
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/v1/engine/core.py", line 834, in run_engine_core
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/v1/engine/core.py", line 602, in __init__
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     super().__init__(
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/v1/engine/core.py", line 102, in __init__
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     self.model_executor = executor_class(vllm_config)
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/v1/executor/abstract.py", line 101, in __init__
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     self._init_executor()
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/v1/executor/uniproc_executor.py", line 48, in _init_executor
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     self.driver_worker.load_model()
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/worker/tpu_worker_jax.py", line 235, in load_model
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     self.model_runner.load_model()
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/runner/tpu_jax_runner.py", line 412, in load_model
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model = get_model(
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                                                                                                                                       ^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/common/model_loader.py", line 340, in get_model
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     return get_vllm_model(vllm_config, rng, mesh)
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/common/model_loader.py", line 309, in get_vllm_model
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     params, lora_manager = model.load_weights()
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                            ^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/vllm/vllm_model_wrapper.py", line 112, in load_weights
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/model_loader/__init__.py", line 130, in get_model
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     return loader.load_model(vllm_config=vllm_config, model_config=model_config)
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/model_loader/base_loader.py", line 49, in load_model
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     model = initialize_model(
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]             ^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/model_loader/utils.py", line 55, in initialize_model
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     return model_class(vllm_config=vllm_config, prefix=prefix)
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/adapters.py", line 173, in __init__
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/compilation/decorators.py", line 276, in __init__
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     old_init(self, **kwargs)
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/gemma3.py", line 377, in __init__
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     self.start_layer, self.end_layer, self.layers = make_layers(
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                                                     ^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/utils.py", line 646, in make_layers
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/gemma3.py", line 379, in <lambda>
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     lambda prefix: Gemma3DecoderLayer(
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                    ^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/gemma3.py", line 303, in __init__
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     self.self_attn = Gemma3Attention(
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                      ^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/gemma3.py", line 190, in __init__
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     self.attn = attn_cls(
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                 ^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/attention/layers/encoder_only_attention.py", line 80, in __init__
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/attention/layers/encoder_only_attention.py", line 29, in create_encoder_only_attention_backend
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     underlying_builder = underlying_attn_backend.get_builder_cls()
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/attention/backends/abstract.py", line 70, in get_builder_cls
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843]     raise NotImplementedError
(EngineCore_DP0 pid=412362) ERROR 11-06 09:34:00 [core.py:843] NotImplementedError

granite embedding model

vllm serve ibm-granite/granite-embedding-278m-multilingual --convert embed
INFO 11-06 09:29:28 [__init__.py:22] TPU info: node_name=ikwak-1chip | tpu_type=v6e-1 | worker_id=0 | num_chips=1 | num_cores_per_chip=1
INFO 11-06 09:29:28 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-06 09:29:28 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-06 09:29:28 [interface.py:171] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
(APIServer pid=411421) INFO 11-06 09:29:30 [api_server.py:1961] vLLM API server version 0.11.1rc6.dev158+gc3ee80a01
(APIServer pid=411421) INFO 11-06 09:29:30 [utils.py:253] non-default args: {'model_tag': 'ibm-granite/granite-embedding-278m-multilingual', 'model': 'ibm-granite/granite-embedding-278m-multilingual', 'convert': 'embed'}
config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 698/698 [00:00<00:00, 7.98MB/s]
sentence_bert_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 54.0/54.0 [00:00<00:00, 606kB/s]
(APIServer pid=411421) INFO 11-06 09:29:30 [config.py:896] Found sentence-transformers tokenize configuration.
modules.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 350/350 [00:00<00:00, 4.11MB/s]
(APIServer pid=411421) INFO 11-06 09:29:36 [config.py:784] Found sentence-transformers modules configuration.
config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 191/191 [00:00<00:00, 2.62MB/s]
(APIServer pid=411421) INFO 11-06 09:29:36 [config.py:815] Found pooling configuration.
(APIServer pid=411421) INFO 11-06 09:29:36 [model.py:871] Resolved `--runner auto` to `--runner pooling`. Pass the value explicitly to silence this message.
(APIServer pid=411421) INFO 11-06 09:29:36 [model.py:630] Resolved architecture: XLMRobertaModel
tokenizer_config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:00<00:00, 5.67MB/s]
(APIServer pid=411421) INFO 11-06 09:29:36 [model.py:1728] Using max model len 512
(APIServer pid=411421) INFO 11-06 09:29:36 [arg_utils.py:1767] (Disabling) chunked prefill by default
(APIServer pid=411421) INFO 11-06 09:29:36 [arg_utils.py:1770] (Disabling) prefix caching by default
(APIServer pid=411421) INFO 11-06 09:29:36 [vllm.py:535] Only "last" pooling supports chunked prefill and prefix caching; disabling both.
(APIServer pid=411421) WARNING 11-06 09:29:36 [tpu_jax.py:149] The model dtype is not properly set for JAX backend. Overwriting it to jnp.bfloat16
(APIServer pid=411421) INFO 11-06 09:29:36 [tpu_jax.py:185] Force using UniProcExecutor for JAX on single host.
sentencepiece.bpe.model: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5.07M/5.07M [00:00<00:00, 15.0MB/s]
tokenizer.json: 9.08MB [00:00, 18.2MB/s]
special_tokens_map.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 239/239 [00:00<00:00, 3.27MB/s]
INFO 11-06 09:29:42 [__init__.py:22] TPU info: node_name=ikwak-1chip | tpu_type=v6e-1 | worker_id=0 | num_chips=1 | num_cores_per_chip=1
INFO 11-06 09:29:43 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-06 09:29:43 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-06 09:29:43 [interface.py:171] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
(EngineCore_DP0 pid=411644) INFO 11-06 09:29:44 [core.py:93] Initializing a V1 LLM engine (v0.11.1rc6.dev158+gc3ee80a01) with config: model='ibm-granite/granite-embedding-278m-multilingual', speculative_config=None, tokenizer='ibm-granite/granite-embedding-278m-multilingual', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=<class 'jax.numpy.bfloat16'>, max_seq_len=512, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=None, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=ibm-granite/granite-embedding-278m-multilingual, enable_prefix_caching=False, chunked_prefill_enabled=False, pooler_config=PoolerConfig(pooling_type='CLS', normalize=True, dimensions=None, enable_chunked_processing=None, max_embed_len=None, softmax=None, activation=None, use_activation=None, logit_bias=None, step_tag_id=None, returned_token_ids=None), compilation_config={'level': None, 'mode': 2, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'openxla', 'custom_ops': ['all'], 'splitting_ops': None, 'use_inductor': None, 'compile_sizes': None, 'inductor_compile_config': {'enable_auto_functionalized_v2': False}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'use_cudagraph': True, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': None, 'cudagraph_copy_inputs': False, 'full_cuda_graph': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {}, 'max_cudagraph_capture_size': None, 'local_cache_dir': None}
(EngineCore_DP0 pid=411644) WARNING 11-06 09:29:44 [tpu_jax.py:214] Pin memory is not supported on TPU.
(EngineCore_DP0 pid=411644) WARNING 11-06 09:29:44 [tpu_worker_jax.py:57] The model dtype is not properly set for JAX backend. Overwriting it to jnp.bfloat16
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore_DP0 pid=411644) INFO 11-06 09:29:46 [parallel_state.py:1325] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(EngineCore_DP0 pid=411644) INFO 11-06 09:29:46 [tpu_jax_runner.py:268] Device sequence enforced: False
(EngineCore_DP0 pid=411644) INFO 11-06 09:29:46 [tpu_jax_runner.py:293] Init mesh | mesh=Mesh('data': 1, 'model': 1, axis_types=(Auto, Auto))
(EngineCore_DP0 pid=411644) INFO 11-06 09:29:46 [utils.py:93] Prepared token paddings: [16, 32, 64, 128, 256, 512, 1024, 2048]
(EngineCore_DP0 pid=411644) INFO 11-06 09:29:46 [utils.py:59] Prepared request paddings: [8, 16, 32, 64, 128, 256]
(EngineCore_DP0 pid=411644) INFO 11-06 09:29:46 [compilation_manager.py:39] Enabling JAX compile cache.
(EngineCore_DP0 pid=411644) INFO 11-06 09:29:46 [tpu_worker_jax.py:154] Init worker | rank=0 | node_id=0 | is_driver_worker=True | hbm=[(0.0, 31.25)]GiB
(EngineCore_DP0 pid=411644) INFO 11-06 09:29:46 [model_loader.py:325] Loading model with MODEL_IMPL_TYPE=flax_nnx
(EngineCore_DP0 pid=411644) WARNING 11-06 09:29:46 [model_loader.py:335] Flax model failed with: 'Model architectures ['XLMRobertaModel'] are not supported for now. Supported architectures: ['Llama4ForCausalLM', 'DeepseekV3ForCausalLM', 'LlamaForCausalLM', 'Qwen2ForCausalLM', 'Qwen3ForCausalLM', 'Qwen2_5_VLForConditionalGeneration', 'Phi3ForCausalLM', 'Eagle3LlamaForCausalLM', 'GptOssForCausalLM']'. Falling back to vLLM implementation.
(EngineCore_DP0 pid=411644) INFO 11-06 09:29:47 [tpu_jax.py:64] Cannot use None backend on TPU.
(EngineCore_DP0 pid=411644) INFO 11-06 09:29:47 [tpu_jax.py:67] Using Pallas V1 backend.
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843] EngineCore failed to start.
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843] Traceback (most recent call last):
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/common/model_loader.py", line 330, in get_model
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     return get_flax_model(vllm_config, rng, mesh, is_draft_model)
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/common/model_loader.py", line 202, in get_flax_model
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     model_class = _get_model_architecture(model_config.hf_config)
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/common/model_loader.py", line 59, in _get_model_architecture
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     raise UnsupportedArchitectureError(
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843] tpu_inference.models.common.model_loader.UnsupportedArchitectureError: Model architectures ['XLMRobertaModel'] are not supported for now. Supported architectures: ['Llama4ForCausalLM', 'DeepseekV3ForCausalLM', 'LlamaForCausalLM', 'Qwen2ForCausalLM', 'Qwen3ForCausalLM', 'Qwen2_5_VLForConditionalGeneration', 'Phi3ForCausalLM', 'Eagle3LlamaForCausalLM', 'GptOssForCausalLM']
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843] 
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843] During handling of the above exception, another exception occurred:
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843] 
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843] Traceback (most recent call last):
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/v1/engine/core.py", line 834, in run_engine_core
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/v1/engine/core.py", line 602, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     super().__init__(
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/v1/engine/core.py", line 102, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     self.model_executor = executor_class(vllm_config)
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/v1/executor/abstract.py", line 101, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     self._init_executor()
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/v1/executor/uniproc_executor.py", line 48, in _init_executor
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     self.driver_worker.load_model()
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/worker/tpu_worker_jax.py", line 235, in load_model
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     self.model_runner.load_model()
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/runner/tpu_jax_runner.py", line 412, in load_model
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model = get_model(
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                                                                                                                                       ^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/common/model_loader.py", line 340, in get_model
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     return get_vllm_model(vllm_config, rng, mesh)
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/common/model_loader.py", line 309, in get_vllm_model
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     params, lora_manager = model.load_weights()
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                            ^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/tpu-inference/tpu_inference/models/vllm/vllm_model_wrapper.py", line 112, in load_weights
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/model_loader/__init__.py", line 130, in get_model
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     return loader.load_model(vllm_config=vllm_config, model_config=model_config)
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/model_loader/base_loader.py", line 49, in load_model
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     model = initialize_model(
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]             ^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/model_loader/utils.py", line 55, in initialize_model
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     return model_class(vllm_config=vllm_config, prefix=prefix)
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/roberta.py", line 111, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     super().__init__(vllm_config=vllm_config, prefix=prefix)
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/bert.py", line 484, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     self.model = self._build_model(
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                  ^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/roberta.py", line 141, in _build_model
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     return BertModel(
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]            ^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/compilation/decorators.py", line 276, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     old_init(self, **kwargs)
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/bert.py", line 376, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder")
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/bert.py", line 126, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     BertLayer(
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/bert.py", line 155, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     self.attention = BertAttention(
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                      ^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/bert.py", line 199, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     self.self = BertSelfAttention(
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                 ^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/model_executor/models/bert.py", line 258, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     self.attn = EncoderOnlyAttention(
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                 ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/attention/layers/encoder_only_attention.py", line 80, in __init__
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/attention/layers/encoder_only_attention.py", line 29, in create_encoder_only_attention_backend
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     underlying_builder = underlying_attn_backend.get_builder_cls()
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]   File "/home/ikwak_google_com/vllm/vllm/attention/backends/abstract.py", line 70, in get_builder_cls
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843]     raise NotImplementedError
(EngineCore_DP0 pid=411644) ERROR 11-06 09:29:47 [core.py:843] NotImplementedError

@carlesoctav
Author

carlesoctav commented Nov 6, 2025

Hmm, I feel like this happens because when we initialize the model with the vLLM implementation, we're using the old vLLM TPU path, which doesn't support non-causal attention yet?

@carlesoctav
Author

carlesoctav commented Nov 6, 2025

Actually, it tries to use tpu_inference.layers.vllm.attention, but tpu_inference doesn’t support encoder attention, and the PallasAttentionBackend class doesn’t have a get_builder_cls method.

@bvrockwell
Collaborator

Thanks for all this @carlesoctav! @QiliangCui just FYI, not sure who the right person to review is.

@vipannalla vipannalla requested a review from py4 November 10, 2025 04:33
@vipannalla
Collaborator

@py4 , please take a look.

Collaborator

@py4 py4 left a comment

Thanks for the contribution! I left a few comments, and also:

please add unit tests and e2e tests for your changes.

  1. unit tests for https://github.com/vllm-project/tpu-inference/tree/main/tests/layers/jax
  2. unit tests for the input_batch_jx.py changes
  3. unit tests for the tpu_jax_runner.py changes
  4. an e2e test that runs a small embedding model here: https://github.com/vllm-project/tpu-inference/tree/main/tests/e2e

)


def is_partial_prefill(pooling_metadata: TPUSupportedPoolingMetadata):
Collaborator

This function doesn't seem to be used? Should we delete it? Also what happens if the prefill is partial?

Author

Yeah, I can remove it.
This is an assertion check for mean pooling and CLS pooling, and it applies mostly to encoder-type models (which aren't supported on TPU yet): since these models use bidirectional attention and don't support a KV cache, all prompt tokens must be processed within the same prefill step.

I should put this check in mean pooling and CLS pooling, but I'm not really sure whether an assertion that depends on argument values is possible inside jit.
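A rough host-side sketch of such a check (done outside jit, before calling the jitted pooler); the field names here are illustrative, not the actual TPUSupportedPoolingMetadata attributes:

import numpy as np

def assert_full_prefill(prompt_lens: np.ndarray,
                        num_scheduled_tokens: np.ndarray) -> None:
    """Mean/CLS pooling needs the whole prompt in a single prefill pass, since
    encoder-style models use bidirectional attention and have no KV cache."""
    if np.any(num_scheduled_tokens < prompt_lens):
        raise ValueError(
            "Mean/CLS pooling requires the full prompt in one prefill step; "
            "got a partially prefilled request.")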

) / padded_prompt_lens.unsqueeze(1)


class LastPoolingMethod(PoolingMethod):
Collaborator

Please add docstrings for your classes and functions



# actually my idea is not to jit this function but the model.pooler,
# we can put some postprocessing here.
Collaborator

I don't understand the comment here. It seems the function is actually jitted?

return jit_model



Collaborator

nit: extra


convert_type = getattr(model_config, "convert_type", "none")
if convert_type == "embed":
    logger.debug_once("Converting %s to embedding model", model_class.__name__)
Collaborator

should we keep this?




def init_pooler_from_vllm_model(
Collaborator

It seems to me this is not native support for vLLM models; here we are still using the JAX implementation. However, I think we should use torchax (https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/models/vllm/vllm_model_wrapper.py) to automatically wrap vLLM models.

@hfan, @kyuyeunk and @vanbasten23 for commenting on vLLM model support.

You could also remove this adapter and the vLLM support, send this PR as JAX support only, and do vLLM support as a follow-up.

Author

Sure, but it has the same issue that motivates doing sampling in JAX, no? It requires arguments that are "jit friendly". The PyTorch version takes a list of PoolingParams from the request input, which isn't "jit friendly".
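For illustration, something like this is what I mean by flattening per-request pooling options into fixed-shape arrays on the host before the jitted pooler runs (the attribute names used on the params objects here are assumptions, not the actual PoolingParams API):

import numpy as np
import jax.numpy as jnp

def build_pooling_metadata(pooling_params: list, max_num_seqs: int):
    """Encode per-request pooling options as padded, fixed-shape arrays."""
    # 0 = last-token, 1 = mean, 2 = CLS (illustrative encoding)
    task_ids = np.zeros((max_num_seqs,), dtype=np.int32)
    normalize = np.zeros((max_num_seqs,), dtype=np.bool_)
    for i, params in enumerate(pooling_params):
        # params.pooling_type / params.normalize are illustrative attribute names
        task_ids[i] = {"LAST": 0, "MEAN": 1, "CLS": 2}[params.pooling_type]
        normalize[i] = bool(params.normalize)
    return jnp.asarray(task_ids), jnp.asarray(normalize)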

Collaborator

Using the vLLM model implementation goes far beyond just using torchax to convert torch tensors to JAX tensors. It means incorporating code in a way that uses the APIs exposed by vLLM, with minimal coding from scratch.

Also, any vLLM-model-related code should reside in either the models/vllm or layers/vllm folder.

And any functionality shared by both vLLM models and JAX models should reside in models/common or layers/common.

Please refer to these PRs:

if hf_key.endswith(".weight"):
    hf_key = hf_key.removesuffix(".weight")

if not hf_key.startswith('models.'):
Collaborator

Doesn't this mess with non-embedding models?

Author

So, based on my observation, some model checkpoints use the base model (e.g., QwenModel) and not the CausalLM variant. Because of that, the hf_key starts without the "model" prefix (yeah, there was a typo there: I should check for "model", not "models").
The idea here is simply to prepend the "model" prefix to the hf_key. This shouldn't affect non-embedding models, since their hf_key already contains "model".

I ran all of tests/models/ and everything loads correctly for the standard models.
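A rough sketch of the intended key normalization (with the typo fixed, i.e. checking for "model." rather than "models."; the surrounding loader context is omitted):

def normalize_hf_key(hf_key: str) -> str:
    """Map base-model checkpoint keys onto the 'model.'-prefixed layout the
    CausalLM-style loader expects."""
    if hf_key.endswith(".weight"):
        hf_key = hf_key.removesuffix(".weight")
    # Base-model checkpoints (e.g. QwenModel instead of Qwen3ForCausalLM)
    # ship keys without the leading 'model.' segment, so add it back.
    if not hf_key.startswith("model."):
        hf_key = f"model.{hf_key}"
    return hf_key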

self._precompile_backbone_with_inputs_embeds()
if self.runner.is_pooling_model:
    self._precompile_pooling()
    return
Collaborator

Should we return here for embedding models and skip the rest of precompilation? do we also need previous precompilations?

Author

I’m following the approach used in the PyTorch version. In that implementation, the lm_head is removed, so we also need to skip any compilation steps related to logits creation and sampling (for torchax implementation).
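As a sketch of the precompile idea for the pooling path (names here are assumed, not the actual hook): warm up the jitted pooling function for each padded shape bucket so nothing compiles at serving time.

import jax
import jax.numpy as jnp

def precompile_pooling(pooling_fn, hidden_size, token_paddings, request_paddings):
    """Ahead-of-time compile the pooling function for every padded shape bucket."""
    jitted = jax.jit(pooling_fn)
    for num_tokens in token_paddings:      # e.g. [16, 32, ..., 2048]
        for num_reqs in request_paddings:  # e.g. [8, 16, ..., 256]
            hidden = jnp.zeros((num_tokens, hidden_size), dtype=jnp.bfloat16)
            last_token_idx = jnp.zeros((num_reqs,), dtype=jnp.int32)
            jitted.lower(hidden, last_token_idx).compile()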

@@ -213,6 +221,7 @@ def __init__(
self.uses_mrope, self.model_config)
self.lora_utils = LoraUtils(self)


Collaborator

nit: extra

@@ -588,9 +604,15 @@ def _execute_model(
# "Should not schedule a request that does nothing!")
return DUMMY_METADATA, EMPTY_MODEL_RUNNER_OUTPUT

# TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
Collaborator

please keep this comment

@py4 py4 requested review from hfan, kyuyeunk and vanbasten23 November 15, 2025 00:15
@kyuyeunk
Collaborator

@carlesoctav can you resolve the branch conflict first so that I can give a proper review?

@carlesoctav
Author

@kyuyeunk done.

get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
graphdef)
lora_manager, model = None, None
lora_manager, _ = None, None
Collaborator

Reason for removing this?






Development

Successfully merging this pull request may close these issues.

[Feature]: Add embedding model functionality to tpu-inference

6 participants