From abe35dc16e3ce90c8d8186c961330b6cbde54301 Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Tue, 4 Nov 2025 09:52:07 +0000 Subject: [PATCH 01/13] merge main --- tpu_inference/layers/jax/pool/pooler.py | 228 ++++++++++++++++++ tpu_inference/layers/jax/pool/pooling.py | 14 ++ .../layers/jax/pool/pooling_metadata.py | 118 +++++++++ tpu_inference/models/common/model_loader.py | 24 +- tpu_inference/models/jax/adapters.py | 81 +++++++ tpu_inference/runner/input_batch.py | 19 ++ .../runner/persistent_batch_manager.py | 2 +- tpu_inference/runner/tpu_runner.py | 92 +++++++ 8 files changed, 571 insertions(+), 7 deletions(-) create mode 100644 tpu_inference/layers/jax/pool/pooler.py create mode 100644 tpu_inference/layers/jax/pool/pooling.py create mode 100644 tpu_inference/layers/jax/pool/pooling_metadata.py create mode 100644 tpu_inference/models/jax/adapters.py diff --git a/tpu_inference/layers/jax/pool/pooler.py b/tpu_inference/layers/jax/pool/pooler.py new file mode 100644 index 000000000..e2e95f878 --- /dev/null +++ b/tpu_inference/layers/jax/pool/pooler.py @@ -0,0 +1,228 @@ +import enum +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +from flax import nnx +from tpu_inference.layers.jax.pool.pooling_metadata import TPUSupportedPoolingMetadata + +from vllm.config.pooler import PoolerConfig + + +@jax.tree_util.register_dataclass +@dataclass +class PoolingResult: + """Outputs produced by pooling kernels.""" + + num_reqs: int + pooler_output: jax.Array # [padded_num_reqs, dim] + # or [padded_num_reqs, padded_max_num_batchec_token_per_req, dim] for allpool + + +class PoolingType(enum.Enum): + LAST = "LAST" + MEAN = "MEAN" + CLS = "CLS" + ALL = "ALL" + + +@dataclass(frozen=True) +class ResolvedPoolingConfig: + task: str + pooling_type: PoolingType + normalize: bool + + @classmethod + def from_config( + cls, + task: str, + pooler_config: PoolerConfig | None, + ) -> "ResolvedPoolingConfig": + pooler_config = pooler_config or PoolerConfig() + + # The encode functionality is currently disabled because we cannot use DispatchPooler + # as intended. (It was part of ModelForEmbedding, and in newer versions it was renamed to token_embed.) + # This is because TPU does not support alternating requests between these two tasks, and it is + # out of scope to change the vllm request handler/API server to separate these requests. + # Therefore, this is disabled by default—users cannot use token_embed/encode functionality for now. 
+ + if task == "embed": + default_pooling_type = PoolingType.LAST + default_normalize = True + elif task == "encode": + raise ValueError(f"Unsupported pooling task: {task}") + else: + raise ValueError(f"Unsupported pooling task: {task}") + + pooling_type_str = pooler_config.pooling_type or default_pooling_type.name + pooling_type = PoolingType(pooling_type_str.upper()) + normalize = ( + pooler_config.normalize + if pooler_config.normalize is not None + else default_normalize + ) + + return cls(task=task, pooling_type=pooling_type, normalize=normalize) + + +class PoolingMethod(nnx.Module): + @staticmethod + def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod": + if pooling_type is PoolingType.ALL: + raise NotImplementedError("ALL pooling is not implemented yet.") + # return AllPoolingMethod() + if pooling_type is PoolingType.MEAN: + return MeanPoolingMethod() + if pooling_type is PoolingType.LAST: + return LastPoolingMethod() + if pooling_type is PoolingType.CLS: + return CLSPoolingMethod() + raise NotImplementedError(f"Unsupported pooling type: {pooling_type}") + + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> jax.Array: + raise NotImplementedError + + +class AllPoolingMethod(PoolingMethod): + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> jax.Array: + pass + + +class MeanPoolingMethod(PoolingMethod): + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> jax.Array: + padded_prompt_lens = pooling_metadata.prompt_lens + padded_start_indices = pooling_metadata.first_token_indices + padded_end_indices = pooling_metadata.last_token_indices + cumsum = jnp.cumsum(hidden_states, dtype=jnp.float32) + + return ( + cumsum[padded_end_indices] + - cumsum[padded_start_indices] + + hidden_states[padded_start_indices] + ) / padded_prompt_lens.unsqueeze(1) + + +class LastPoolingMethod(PoolingMethod): + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> jax.Array: + return hidden_states[pooling_metadata.last_token_indices] + + +class CLSPoolingMethod(PoolingMethod): + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> jax.Array: + return hidden_states[pooling_metadata.first_token_indices] + + +class PoolerHead(nnx.Module): + def __call__( + self, + pooled: jax.Array, + token_embeddings: jax.Array, + token_mask: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> PoolingResult: + raise NotImplementedError + + +class EmbeddingPoolerHead(PoolerHead): + def __init__(self, default_normalize: bool) -> None: + super().__init__() + self.default_normalize = default_normalize + + def __call__( + self, + pooled: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> PoolingResult: + + # In the torch version, this part should handle other computations related to pooling_params, such as + # normalization and truncating the embedding dimensions (for matryoshka models). + # The problem with TPU is that we want a consistent output shape, and I feel like + # the best we can do is to handle this outside JIT, on the CPU. + # While you can actually do normalization with jnp.where, or maybe jax.lax.cond for branching in jax.jit, + # for the sake of simplicity, we either normalize all requests or none of them based on pooling_config. 
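+        # A minimal per-request sketch of the jnp.where alternative mentioned above
+        # (kept as a comment; this patch normalizes all-or-nothing). It assumes the
+        # per-request -1/0/1 flags in pooling_metadata.normalize built by
+        # from_input_batch, and it stays jit-friendly:
+        #
+        #     should_normalize = (pooling_metadata.normalize == 1)[:, None]
+        #     pooled = jnp.where(should_normalize, normalize(pooled), pooled)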
+ # Pooler output: [padded_num_reqs, dim] + + + if self.default_normalize: + pooled = normalize(pooled) + + return PoolingResult( + num_reqs=pooling_metadata.num_reqs, + pooler_output=pooled, + ) + + +class Pooler(nnx.Module): + @staticmethod + def for_encode(pooler_config: PoolerConfig | None) -> "Pooler": + resolved = ResolvedPoolingConfig.from_config("encode", pooler_config) + raise NotImplementedError("EncodePooler is currently disabled.") + + @staticmethod + def for_embed(pooler_config: PoolerConfig | None) -> "Pooler": + resolved = ResolvedPoolingConfig.from_config("embed", pooler_config) + return EmbeddingPooler.from_config(resolved) + + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> PoolingResult: + raise NotImplementedError + + def get_supported_tasks(self) -> set[str]: + raise NotImplementedError + + +class EmbeddingPooler(Pooler): + def __init__( + self, + pooling: PoolingMethod, + head: EmbeddingPoolerHead, + ) -> None: + self.pooling = pooling + self.head = head + + @classmethod + def from_config(cls, config: ResolvedPoolingConfig) -> None: + pooling = PoolingMethod.from_pooling_type(config.pooling_type) + head = EmbeddingPoolerHead(config.normalize) + return cls(pooling, head) + + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> PoolingResult: + pooled = self.pooling(hidden_states, pooling_metadata) + return self.head(pooled, pooling_metadata) + + def get_supported_tasks(self) -> set[str]: + return {"embed"} + + +def normalize(embeddings: jax.Array) -> jax.Array: + norms = jnp.linalg.norm(embeddings, axis=-1, keepdims=True) + norms = jnp.maximum(norms, 1e-12) + normalized = embeddings / norms + return normalized diff --git a/tpu_inference/layers/jax/pool/pooling.py b/tpu_inference/layers/jax/pool/pooling.py new file mode 100644 index 000000000..8fb45b932 --- /dev/null +++ b/tpu_inference/layers/jax/pool/pooling.py @@ -0,0 +1,14 @@ +import functools + +import jax + +from .pooler import Pooler, PoolerOutput +from .pooling_metadata import TPUSupportedPoolingMetadata + +@jax.jit +def pool( + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + pooler: Pooler, +) -> PoolerOutput: + return pooler(hidden_states, pooling_metadata) diff --git a/tpu_inference/layers/jax/pool/pooling_metadata.py b/tpu_inference/layers/jax/pool/pooling_metadata.py new file mode 100644 index 000000000..4922da239 --- /dev/null +++ b/tpu_inference/layers/jax/pool/pooling_metadata.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import functools +import logging +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh +from tpu_inference.runner.input_batch_jax import InputBatch +from tpu_inference.utils import device_array + +logger = logging.getLogger(__name__) + + +SUPPORTED_POOLING_TASKS = {"embed"} + + +def build_pooling_cursor( + num_scheduled_tokens: list[int], + padded_num_seqs: int, + prompt_lens: jax.Array, +): + assert len(prompt_lens) == len(num_scheduled_tokens) + + n_seq = len(num_scheduled_tokens) + num_sched_tokens_padded = jnp.zeros(padded_num_seqs) + num_sched_tokens_padded = num_sched_tokens_padded.at[:n_seq].set( + jnp.asarary(num_scheduled_tokens, dtype=jnp.int32) + ) + cumsum = jnp.cumsum(num_scheduled_tokens) + first_token_indices = jnp.concatenate((jnp.asarray(0), cumsum[:-1])) + last_token_indices = first_token_indices + num_sched_tokens_padded - 1 + last_token_indices = 
jnp.where( + num_sched_tokens_padded > 0, last_token_indices, first_token_indices + ) + return first_token_indices, last_token_indices + + +@functools.partial( + jax.tree_util.register_dataclass, + data_fields=( + "prompt_lens", + "normalize", + "num_reqs", + "padded_num_reqs", + ), + meta_fields=("task_id",), +) +@dataclass +class TPUSupportedPoolingMetadata: + """Device metadata required for pooling computations.""" + + prompt_lens: jax.Array + first_token_indices: jax.Array + last_token_indices: jax.Array + normalize: jax.Array + num_reqs: int + padded_num_reqs: int + task: str + + @classmethod + def from_input_batch( + cls, + mesh: Mesh, + input_batch: InputBatch, + num_scheduled_tokens: list[int], + padded_num_reqs: int, + ) -> TPUSupportedPoolingMetadata: + pooling_params_list = input_batch.get_pooling_params() + + num_reqs = input_batch.num_reqs + assert len(pooling_params_list) == num_reqs + + padded_prompt_lens_np = np.zeros(padded_num_reqs, dtype=np.int32) + padded_prompt_lens_np[:num_reqs] = input_batch.num_prompt_tokens[:num_reqs] + + normalize = np.full(padded_num_reqs, -1, dtype=np.int8) + + # Instead of shutting down the whole program, we should just ignore it and make it return 'embed' by default, + # but provide a warning. + for idx, params in enumerate(pooling_params_list): + if params.normalize is True: + normalize[idx] = 1 + elif params.normalize is False: + normalize[idx] = 0 + + if (task := params.task) not in SUPPORTED_POOLING_TASKS: + logger.warning( + f"Unsupported pooling task '{task}'. Supported tasks: {sorted(SUPPORTED_POOLING_TASKS)}. Defaulting to 'embed'." + ) + + # maybe in the future if we need to support multiple tasks in one batch, we need to make sure each batch has only one task + # if not task_values: + # raise ValueError("Pooling metadata requires at least one request") + # if any(task != task_values[0] for task in task_values): + # raise ValueError("Mixed pooling tasks within the same batch are not supported yet") + + task = "embed" + first_token_indices, last_token_indices = build_pooling_cursor( + num_scheduled_tokens, padded_num_reqs, padded_prompt_lens_np[:num_reqs] + ) + + prompt_lens, normalize, first_token_indices, last_token_indices = device_array( + mesh, + (padded_prompt_lens_np, normalize, first_token_indices, last_token_indices), + ) + + return cls( + prompt_lens=prompt_lens, + first_token_indices=first_token_indices, + last_token_indices=last_token_indices, + normalize=normalize, + task=task, + num_reqs=num_reqs, + padded_num_reqs=padded_num_reqs, + ) diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py index dd4082709..bfe935821 100644 --- a/tpu_inference/models/common/model_loader.py +++ b/tpu_inference/models/common/model_loader.py @@ -13,6 +13,7 @@ from tpu_inference import envs from tpu_inference.layers.common.sharding import ShardingAxisName from tpu_inference.logger import init_logger +from tpu_inference.models.jax.adapters import as_embedding_model from tpu_inference.models.jax.utils.quantization.quantization_utils import ( apply_qwix_on_abstract_model, apply_qwix_quantization, load_random_weights_into_qwix_abstract_model) @@ -184,6 +185,7 @@ def create_sharded_model(): return jit_model + # TODO(pooyam): We need to refactor this. This is returning a bunch of functions that do not work with all models and this is not very easy to see from the code. 
def get_flax_model( vllm_config: VllmConfig, @@ -191,12 +193,22 @@ def get_flax_model( mesh: Mesh, is_draft_model: bool = False, ) -> nnx.Module: - if is_draft_model: - model_class = _get_model_architecture( - vllm_config.speculative_config.draft_model_config.hf_config) - else: - model_class = _get_model_architecture( - vllm_config.model_config.hf_config) + + model_config = ( + vllm_config.speculative_config.draft_model_config + if is_draft_model + else vllm_config.model_config + ) + + model_class = _get_model_architecture(model_config.hf_config) + + convert_type = getattr(model_config, "convert_type", "none") + if convert_type == "embed": + logger.debug_once( "Converting %s to embedding model", model_class.__name__) + model_class = as_embedding_model(model_class) + elif convert_type not in ("none", "generate"): + raise NotImplementedError( f"TPU JAX backend does not support convert_type={convert_type!r} yet") + jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh) kv_cache_sharding = NamedSharding( mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, "model")) diff --git a/tpu_inference/models/jax/adapters.py b/tpu_inference/models/jax/adapters.py new file mode 100644 index 000000000..8cd97d9ec --- /dev/null +++ b/tpu_inference/models/jax/adapters.py @@ -0,0 +1,81 @@ +import typing as tp + +import jax +from flax import nnx +from jax.sharding import Mesh + +from vllm.config import VllmConfig +from vllm.model_executor.models.interfaces_base import ( + VllmModelForPooling, + is_pooling_model, +) +from tpu_inference.layers.jax.pool.pooler import Pooler + +_T = tp.TypeVar("_T", bound=type[nnx.Module]) + +_GENERATE_SUFFIXES = ( + "ForCausalLM", + "ForConditionalGeneration", +) + + +def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: + model_name = orig_model_name + for suffix in _GENERATE_SUFFIXES: + model_name = model_name.removesuffix(suffix) + return model_name + pooling_suffix + + +def _create_pooling_model_cls(orig_cls: _T) -> _T: + class ModelForPooling(orig_cls, VllmModelForPooling): + is_pooling_model = True + + def __init__( + self, + vllm_config: VllmConfig, + rng_key: jax.Array, + mesh: Mesh, + ) -> None: + super().__init__( + vllm_config=vllm_config, + rng_key=rng_key, + mesh=mesh, + ) + + + # Pooling models do not require language modeling heads. + # However, there is a problem: since the pattern for loading weights in nnx + # is abstract_module -> module, removing the lm_head attribute or leaves from the abstract_module + # results in an error, I think. + # This is because, during hf_load_weights, we need to match between the hf_key and nnx_key. + + # for attr in ("model.lm_head"): + # if hasattr(self, attr): + # delattr(self, attr) + + if getattr(self, "pooler", None) is None: + self._init_pooler(vllm_config=vllm_config) + + def _init_pooler(self, vllm_config: VllmConfig) -> None: + raise NotImplementedError + + return ModelForPooling + + +def as_embedding_model(cls: _T) -> _T: + + class ModelForEmbedding(_create_pooling_model_cls(cls)): + def _init_pooler(self, vllm_config: VllmConfig) -> None: + pooler_config = vllm_config.model_config.pooler_config + if pooler_config is None: + raise ValueError( + "Embedding models require `pooler_config` to be set in the model configuration." 
+ ) + + self.pooler = Pooler.for_embed(pooler_config) + + ModelForEmbedding.__name__ = _get_pooling_model_name( + cls.__name__, + "ForEmbedding", + ) + return ModelForEmbedding # type: ignore[return-value] diff --git a/tpu_inference/runner/input_batch.py b/tpu_inference/runner/input_batch.py index 79f28ebb0..70c26e6c0 100644 --- a/tpu_inference/runner/input_batch.py +++ b/tpu_inference/runner/input_batch.py @@ -8,6 +8,7 @@ import jax.numpy as jnp import numpy as np from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.utils.collection_utils import swap_dict_values from vllm.v1.core.sched.output import NewRequestData @@ -131,6 +132,7 @@ def __init__( self.req_output_token_ids: list[Optional[list[int]]] = [] self.request_distribution: list[int] = [0, 0, 0] + self.pooling_params: dict[str, PoolingParams] = {} @property def req_ids(self) -> list[str]: @@ -198,6 +200,9 @@ def add_request( self.min_tokens[req_index] = (sampling_params.min_tokens, sampling_params.all_stop_token_ids) + if request.pooling_params is not None: + self.pooling_params[req_id] = request.pooling_params + # NOTE(woosuk): self.generators should not include the requests that # do not have their own generator. if request.generator is not None: @@ -272,6 +277,7 @@ def remove_request(self, req_id: str) -> Optional[int]: # False means we don't fill with -inf. self.allowed_token_ids_mask_cpu[req_index].fill_(False) self.bad_words_token_ids.pop(req_index, None) + self.pooling_params.pop(req_id, None) return req_index def swap_states(self, i1: int, i2: int) -> None: @@ -411,6 +417,19 @@ def all_greedy(self) -> bool: def max_num_logprobs(self) -> Optional[int]: return max(self.num_logprobs.values()) if self.num_logprobs else None + def get_pooling_params(self) -> list[PoolingParams]: + if not self.pooling_params: + return [] + + if len(self.pooling_params) != self.num_reqs: + missing = set(self.req_ids) - set(self.pooling_params) + raise ValueError( + "Pooling params are missing for requests: " + f"{sorted(missing)}" + ) + + return [self.pooling_params[req_id] for req_id in self.req_ids] + def make_lora_inputs( self, num_scheduled_tokens: np.ndarray ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: diff --git a/tpu_inference/runner/persistent_batch_manager.py b/tpu_inference/runner/persistent_batch_manager.py index c8c093315..68f68b60d 100644 --- a/tpu_inference/runner/persistent_batch_manager.py +++ b/tpu_inference/runner/persistent_batch_manager.py @@ -122,7 +122,7 @@ def update_states(self, scheduler_output: "VllmSchedulerOutput", prompt_token_ids=new_req_data.prompt_token_ids, mm_features=new_req_data.mm_features, sampling_params=sampling_params, - pooling_params=None, + pooling_params=new_req_data.pooling_params, generator=None, block_ids=new_req_data.block_ids, num_computed_tokens=new_req_data.num_computed_tokens, diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index 09e563928..8d89bdce1 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -46,6 +46,11 @@ gather_logprobs, sample) from tpu_inference.layers.jax.sample.sampling_metadata import \ TPUSupportedSamplingMetadata +from tpu_inference.layers.jax.pool.pooling_metadata import ( + SUPPORTED_POOLING_TASKS, + TPUSupportedPoolingMetadata, +) +from tpu_inference.layers.jax.sharding import build_mesh from tpu_inference.logger import init_logger from tpu_inference.models.common.model_loader import 
get_model from tpu_inference.models.jax.utils.weight_utils import ( @@ -248,6 +253,7 @@ def __init__( self.uses_mrope, self.model_config) self.lora_utils = LoraUtils(self) + self.is_pooling_model = False cache_config = self.cache_config if cache_config.cache_dtype == "auto": model_dtype = self.dtype @@ -490,6 +496,7 @@ def load_model(self): ) multimodal_fns = multimodal_fns or {} + self.is_pooling_model = hasttr(self.model, "pooler") self.precompile_vision_encoder_fn = multimodal_fns.get( "precompile_vision_encoder_fn", None) self.get_multimodal_embeddings_fn = multimodal_fns.get( @@ -684,6 +691,7 @@ def _execute_model( # "Should not schedule a request that does nothing!") return DUMMY_METADATA, EMPTY_MODEL_RUNNER_OUTPUT +<<<<<<< HEAD:tpu_inference/runner/tpu_runner.py # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b ( input_ids, @@ -693,6 +701,17 @@ def _execute_model( spec_decode_metadata, logits_indices_selector, ) = self._prepare_inputs(scheduler_output) +======= + ( + input_ids, + attn_metadata, + sampling_metadata, + pooling_metadata, + logits_indices, + spec_decode_metadata, + ) = self._prepare_inputs(scheduler_output) + +>>>>>>> 8a07e5ea (merge main):tpu_inference/runner/tpu_jax_runner.py # multi-modal support if self.is_multimodal_model: @@ -737,6 +756,59 @@ def _execute_model( lora_metadata, ) + if self.is_pooling_model: + assert pooling_metadata is not None + pooling_result = pool(self.model.pooler, hidden_states, pooling_metadata ) + + num_reqs = self.input_batch.num_reqs + req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) + prompt_logprobs_dict = {req_id: None for req_id in req_ids} + prompt_lens = np.asarray( + jax.device_get(pooling_metadata.prompt_lens) + )[:num_reqs] + seq_lens_cpu = self.seq_lens_cpu[:num_reqs] + + task_id = pooling_metadata.primary_task_id + pooler_output: list[torch.Tensor | None] = [] + + if task_id == SUPPORTED_POOLING_TASKS["embed"]: + embeddings = np.asarray( + jax.device_get(pooling_result.embeddings) + )[:num_reqs] + dimensions = np.asarray( + jax.device_get(pooling_metadata.dimensions) + )[:num_reqs] + + for idx in range(num_reqs): + if seq_lens_cpu[idx] != prompt_lens[idx]: + pooler_output.append(None) + continue + + embedding = embeddings[idx] + dim_override = int(dimensions[idx]) + if dim_override > 0: + embedding = embedding[:dim_override] + embedding_np = embedding.astype(np.float32, copy=False) + pooler_output.append( + torch.tensor(embedding_np, dtype=torch.float32) + ) + + else: + raise NotImplementedError( + f"Unsupported pooling task id: {task_id}" + ) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=pooler_output, + kv_connector_output=kv_connector_output, + ) + return attn_metadata, model_runner_output + hidden_states = self._select_from_array_fn(hidden_states, logits_indices) logits = self.compute_logits_fn( @@ -1399,6 +1471,7 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"): assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 + pooling_metadata = None # Get the number of scheduled tokens for each request. 
num_scheduled_tokens_per_req = [] @@ -1436,6 +1509,13 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"): arange = np.concatenate( [self.arange_cpu[:n] for n in num_scheduled_tokens_per_req]) + if self.is_pooling_model: + pooling_metadata = TPUSupportedPoolingMetadata.from_input_batch( + self.mesh, + self.input_batch, + padded_num_reqs, + ) + # Get positions. positions_np = self.positions_cpu[:total_num_scheduled_tokens] np.add(self.input_batch.num_computed_tokens_cpu[req_indices], @@ -1556,9 +1636,21 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"): # This is for making these cpu buffers hidden during tracing attention_metadata.query_start_loc_cpu = query_start_loc_cpu attention_metadata.seq_lens_cpu = seq_lens_cpu +<<<<<<< HEAD:tpu_inference/runner/tpu_runner.py logits_indices_selector = None return (input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector) +======= + + return ( + input_ids, + attention_metadata, + sampling_metadata, + pooling_metadata, + logits_indices, + spec_decode_metadata, + ) +>>>>>>> 8a07e5ea (merge main):tpu_inference/runner/tpu_jax_runner.py def _get_input_ids_embeds(self, input_ids: jax.Array, mm_embeds: list[jax.Array]): From 2baaaaa0f234be7e08408544530349a10ee80364 Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Wed, 5 Nov 2025 07:19:20 +0000 Subject: [PATCH 02/13] it's up and runing well (atleast for qwen-embed and lastpooling) --- tpu_inference/layers/jax/pool/pooler.py | 26 ++- tpu_inference/layers/jax/pool/pooling.py | 3 + .../layers/jax/pool/pooling_metadata.py | 18 ++- tpu_inference/models/common/model_loader.py | 2 +- tpu_inference/models/jax/adapters.py | 17 +- .../models/jax/utils/weight_utils.py | 3 + tpu_inference/runner/compilation_manager.py | 74 ++++++++- tpu_inference/runner/input_batch.py | 150 +++++++++--------- tpu_inference/runner/tpu_runner.py | 59 +++---- 9 files changed, 206 insertions(+), 146 deletions(-) diff --git a/tpu_inference/layers/jax/pool/pooler.py b/tpu_inference/layers/jax/pool/pooler.py index e2e95f878..0f4f6e0f2 100644 --- a/tpu_inference/layers/jax/pool/pooler.py +++ b/tpu_inference/layers/jax/pool/pooler.py @@ -9,14 +9,9 @@ from vllm.config.pooler import PoolerConfig -@jax.tree_util.register_dataclass -@dataclass -class PoolingResult: - """Outputs produced by pooling kernels.""" - - num_reqs: int - pooler_output: jax.Array # [padded_num_reqs, dim] - # or [padded_num_reqs, padded_max_num_batchec_token_per_req, dim] for allpool +# [padded_num_reqs, dim] +# or [padded_num_reqs, padded_max_num_batchec_token_per_req, dim] for allpool +PoolerOutput = jax.Array class PoolingType(enum.Enum): @@ -139,7 +134,7 @@ def __call__( token_embeddings: jax.Array, token_mask: jax.Array, pooling_metadata: TPUSupportedPoolingMetadata, - ) -> PoolingResult: + ) -> PoolerOutput: raise NotImplementedError @@ -152,7 +147,7 @@ def __call__( self, pooled: jax.Array, pooling_metadata: TPUSupportedPoolingMetadata, - ) -> PoolingResult: + ) -> PoolerOutput: # In the torch version, this part should handle other computations related to pooling_params, such as # normalization and truncating the embedding dimensions (for matryoshka models). 
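+        # A rough host-side sketch (outside jit, after jax.device_get) for those
+        # per-request variants; `dimensions` and `normalize` are assumed to be the
+        # per-request PoolingParams fields, mirroring the runner's post-processing:
+        #
+        #     out = np.asarray(jax.device_get(pooled))[:num_reqs]
+        #     for row, params in zip(out, pooling_params_list):
+        #         if params.dimensions:
+        #             row = row[:params.dimensions]
+        #         if params.normalize:
+        #             row = row / max(float(np.linalg.norm(row)), 1e-12)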
@@ -166,10 +161,7 @@ def __call__( if self.default_normalize: pooled = normalize(pooled) - return PoolingResult( - num_reqs=pooling_metadata.num_reqs, - pooler_output=pooled, - ) + return pooled class Pooler(nnx.Module): @@ -187,7 +179,7 @@ def __call__( self, hidden_states: jax.Array, pooling_metadata: TPUSupportedPoolingMetadata, - ) -> PoolingResult: + ) -> PoolerOutput: raise NotImplementedError def get_supported_tasks(self) -> set[str]: @@ -213,7 +205,9 @@ def __call__( self, hidden_states: jax.Array, pooling_metadata: TPUSupportedPoolingMetadata, - ) -> PoolingResult: + ) -> PoolerOutput: + hidden_states = hidden_states.astype(jnp.float32) + # the output mus be of type torch.tensor, but we cannot convert numpy to torch if the dtype is bf16 pooled = self.pooling(hidden_states, pooling_metadata) return self.head(pooled, pooling_metadata) diff --git a/tpu_inference/layers/jax/pool/pooling.py b/tpu_inference/layers/jax/pool/pooling.py index 8fb45b932..8137eaeb6 100644 --- a/tpu_inference/layers/jax/pool/pooling.py +++ b/tpu_inference/layers/jax/pool/pooling.py @@ -5,6 +5,9 @@ from .pooler import Pooler, PoolerOutput from .pooling_metadata import TPUSupportedPoolingMetadata + +# actually my idea is not to jist this function but the model.pooler, +# we can put some postprocesing here. @jax.jit def pool( hidden_states: jax.Array, diff --git a/tpu_inference/layers/jax/pool/pooling_metadata.py b/tpu_inference/layers/jax/pool/pooling_metadata.py index 4922da239..a2f34df25 100644 --- a/tpu_inference/layers/jax/pool/pooling_metadata.py +++ b/tpu_inference/layers/jax/pool/pooling_metadata.py @@ -25,15 +25,15 @@ def build_pooling_cursor( assert len(prompt_lens) == len(num_scheduled_tokens) n_seq = len(num_scheduled_tokens) - num_sched_tokens_padded = jnp.zeros(padded_num_seqs) - num_sched_tokens_padded = num_sched_tokens_padded.at[:n_seq].set( - jnp.asarary(num_scheduled_tokens, dtype=jnp.int32) + num_scheduled_tokens_padded = jnp.zeros(padded_num_seqs) + num_scheduled_tokens_padded = num_scheduled_tokens_padded.at[:n_seq].set( + jnp.asarray(num_scheduled_tokens, dtype=jnp.int32) ) - cumsum = jnp.cumsum(num_scheduled_tokens) - first_token_indices = jnp.concatenate((jnp.asarray(0), cumsum[:-1])) - last_token_indices = first_token_indices + num_sched_tokens_padded - 1 + cumsum = jnp.cumsum(num_scheduled_tokens_padded, dtype = jnp.int64) + first_token_indices = jnp.concatenate((jnp.asarray((0,)), cumsum[:-1])) + last_token_indices = (first_token_indices + num_scheduled_tokens_padded - 1).astype(jnp.int64) last_token_indices = jnp.where( - num_sched_tokens_padded > 0, last_token_indices, first_token_indices + num_scheduled_tokens_padded > 0, last_token_indices, first_token_indices ) return first_token_indices, last_token_indices @@ -42,11 +42,13 @@ def build_pooling_cursor( jax.tree_util.register_dataclass, data_fields=( "prompt_lens", + "first_token_indices", + "last_token_indices", "normalize", "num_reqs", "padded_num_reqs", ), - meta_fields=("task_id",), + meta_fields=("task",), ) @dataclass class TPUSupportedPoolingMetadata: diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py index bfe935821..c486dba8a 100644 --- a/tpu_inference/models/common/model_loader.py +++ b/tpu_inference/models/common/model_loader.py @@ -281,7 +281,7 @@ def combine_hidden_states(graphdef, state, hidden_states): run_get_multimodal_embeddings, graphdef) get_input_embeddings_fn = functools.partial(run_get_input_embeddings, graphdef) - lora_manager, model = None, None + 
lora_manager, _ = None, None combine_hidden_states_fn = functools.partial(combine_hidden_states, graphdef) diff --git a/tpu_inference/models/jax/adapters.py b/tpu_inference/models/jax/adapters.py index 8cd97d9ec..33143f5c7 100644 --- a/tpu_inference/models/jax/adapters.py +++ b/tpu_inference/models/jax/adapters.py @@ -4,12 +4,8 @@ from flax import nnx from jax.sharding import Mesh -from vllm.config import VllmConfig -from vllm.model_executor.models.interfaces_base import ( - VllmModelForPooling, - is_pooling_model, -) from tpu_inference.layers.jax.pool.pooler import Pooler +from vllm.config import VllmConfig _T = tp.TypeVar("_T", bound=type[nnx.Module]) @@ -18,6 +14,15 @@ "ForConditionalGeneration", ) +class PoolingMixin: + """ + same as VllmModelForPooling + """ + is_pooling_model: tp.ClassVar[tp.Literal[True]] = True + + default_pooling_type: tp.ClassVar[str] = "LAST" + pooler: Pooler + def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: model_name = orig_model_name @@ -27,7 +32,7 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: def _create_pooling_model_cls(orig_cls: _T) -> _T: - class ModelForPooling(orig_cls, VllmModelForPooling): + class ModelForPooling(orig_cls, PoolingMixin): is_pooling_model = True def __init__( diff --git a/tpu_inference/models/jax/utils/weight_utils.py b/tpu_inference/models/jax/utils/weight_utils.py index 64f026dae..9cfe74e39 100644 --- a/tpu_inference/models/jax/utils/weight_utils.py +++ b/tpu_inference/models/jax/utils/weight_utils.py @@ -316,6 +316,9 @@ def _load_hf_weights_on_thread(vllm_config, if hf_key.endswith(".weight"): hf_key = hf_key.removesuffix(".weight") + if not hf_key.startswith('models.'): + hf_key = 'model.' + hf_key + # Find the corresponding model key using the HF key if "layers" in hf_key: layer_num = re.search(r"layers\.(\d+)", hf_key).group(1) diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 86c55adc1..5276d7b7e 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -11,9 +11,14 @@ from tpu_inference.core.disagg_utils import is_disagg_enabled from tpu_inference.layers.common.attention_metadata import AttentionMetadata from tpu_inference.layers.common.sharding import ShardingAxisName +from tpu_inference.layers.jax.pool.pooling import pool +from tpu_inference.layers.jax.pool.pooling_metadata import ( + TPUSupportedPoolingMetadata, +) from tpu_inference.layers.jax.sample.sampling import sample -from tpu_inference.layers.jax.sample.sampling_metadata import \ - TPUSupportedSamplingMetadata +from tpu_inference.layers.jax.sample.sampling_metadata import ( + TPUSupportedSamplingMetadata, +) from tpu_inference.logger import init_logger from tpu_inference.utils import device_array @@ -79,6 +84,9 @@ def capture_model(self) -> None: self._run_compilation, ) self._precompile_input_embeddings_merger() self._precompile_backbone_with_inputs_embeds() + if self.runner.is_pooling_model: + self._precompile_pooling() + return if self.runner.scheduler_config.async_scheduling: self._precompile_substitute_placeholder_token() self._precompile_select_from_array() @@ -90,6 +98,68 @@ def capture_model(self) -> None: if self.runner.speculative_config: self._precompile_speculative_decoding() + def _precompile_pooling(self) -> None: + pooler = getattr(self.runner, "pooler", None) + if pooler is None: + logger.warning( + "Pooling precompile skipped because model has no pooler attribute.") + return 
+ + logger.info("Precompile pooling kernels for pooling models.") + + hidden_size = self.runner.model_config.get_hidden_size() + dtype = self.runner.model_config.dtype + hidden_sharding = NamedSharding( + self.runner.mesh, PartitionSpec(None, None)) + + for num_tokens in self.runner.num_tokens_paddings: + hidden_states = self._create_dummy_tensor( + (num_tokens, hidden_size), dtype, sharding=hidden_sharding) + + for num_reqs in self.runner.num_reqs_paddings: + if num_reqs == 0 or num_reqs > num_tokens: + continue + + prompt_lens = np.ones(num_reqs, dtype=np.int32) + first_token_indices = np.arange(num_reqs, dtype=np.int32) + last_token_indices = first_token_indices.copy() + normalize = np.ones(num_reqs, dtype=np.int8) + + ( + prompt_lens, + normalize, + first_token_indices, + last_token_indices, + ) = device_array( + self.runner.mesh, + ( + prompt_lens, + normalize, + first_token_indices, + last_token_indices, + ), + ) + + pooling_metadata = TPUSupportedPoolingMetadata( + prompt_lens=prompt_lens, + first_token_indices=first_token_indices, + last_token_indices=last_token_indices, + normalize=normalize, + num_reqs=num_reqs, + padded_num_reqs=num_reqs, + task="embed", + ) + + self._run_compilation( + "pool", + pool, + hidden_states, + pooling_metadata, + pooler, + num_tokens=num_tokens, + num_reqs=num_reqs, + ) + def _precompile_input_embeddings_merger(self) -> None: for num_tokens in self.runner.num_tokens_paddings: hidden_size = self.runner.vllm_config.model_config.get_hidden_size( diff --git a/tpu_inference/runner/input_batch.py b/tpu_inference/runner/input_batch.py index 70c26e6c0..5d5e41a8d 100644 --- a/tpu_inference/runner/input_batch.py +++ b/tpu_inference/runner/input_batch.py @@ -53,9 +53,11 @@ def __init__( pin_memory: bool, vocab_size: int, block_sizes: list[int], + is_pooling_model: bool = False, is_spec_decode: bool = False, ): self.is_spec_decode = is_spec_decode + self.is_pooling_model = is_pooling_model self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_batched_tokens = max_num_batched_tokens @@ -177,74 +179,73 @@ def add_request( self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens self.block_table.add_row(request.block_ids, req_index) - sampling_params = request.sampling_params - - if (self.is_spec_decode - and is_spec_decode_unsupported(sampling_params)): - self.spec_decode_unsupported_reqs.add(req_id) - - if sampling_params.sampling_type == SamplingType.GREEDY: - # Avoid later division by zero. - self.temperature_cpu[req_index] = -1.0 - self.greedy_reqs.add(req_id) - else: - self.temperature_cpu[req_index] = sampling_params.temperature - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - top_k = sampling_params.top_k - if top_k <= 0 or top_k >= self.vocab_size: - top_k = 1 - self.top_k_cpu[req_index] = top_k - if sampling_params.min_tokens: - self.min_tokens[req_index] = (sampling_params.min_tokens, - sampling_params.all_stop_token_ids) - - if request.pooling_params is not None: - self.pooling_params[req_id] = request.pooling_params - - # NOTE(woosuk): self.generators should not include the requests that - # do not have their own generator. 
- if request.generator is not None: - self.generators[req_index] = request.generator - - if sampling_params.logprobs is not None: - self.num_logprobs[req_id] = sampling_params.logprobs - if sampling_params.logit_bias is not None: - self.logit_bias[req_index] = sampling_params.logit_bias - - if sampling_params.allowed_token_ids: - self.has_allowed_token_ids.add(req_id) - if self.allowed_token_ids_mask_cpu is None: - # Lazy allocation for this tensor, which can be large. + if sampling_params := request.sampling_params: + if (self.is_spec_decode + and is_spec_decode_unsupported(sampling_params)): + self.spec_decode_unsupported_reqs.add(req_id) + + if sampling_params.sampling_type == SamplingType.GREEDY: + # Avoid later division by zero. + self.temperature_cpu[req_index] = -1.0 + self.greedy_reqs.add(req_id) + else: + self.temperature_cpu[req_index] = sampling_params.temperature + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + top_k = sampling_params.top_k + if top_k <= 0 or top_k >= self.vocab_size: + top_k = 1 + self.top_k_cpu[req_index] = top_k + if sampling_params.min_tokens: + self.min_tokens[req_index] = (sampling_params.min_tokens, + sampling_params.all_stop_token_ids) + + + # NOTE(woosuk): self.generators should not include the requests that + # do not have their own generator. + if request.generator is not None: + self.generators[req_index] = request.generator + + if sampling_params.logprobs is not None: + self.num_logprobs[req_id] = sampling_params.logprobs + if sampling_params.logit_bias is not None: + self.logit_bias[req_index] = sampling_params.logit_bias + + if sampling_params.allowed_token_ids: + self.has_allowed_token_ids.add(req_id) + if self.allowed_token_ids_mask_cpu is None: + # Lazy allocation for this tensor, which can be large. + # False means we don't fill with -inf. + self.allowed_token_ids_mask = jnp.zeros(self.max_num_reqs, + self.vocab_size, + dtype=jnp.bool) + self.allowed_token_ids_mask_cpu = np.zeros(self.max_num_reqs, + self.vocab_size, + dtype=np.bool) + self.allowed_token_ids_mask_cpu[req_index] = True # False means we don't fill with -inf. - self.allowed_token_ids_mask = jnp.zeros(self.max_num_reqs, - self.vocab_size, - dtype=jnp.bool) - self.allowed_token_ids_mask_cpu = np.zeros(self.max_num_reqs, - self.vocab_size, - dtype=np.bool) - self.allowed_token_ids_mask_cpu[req_index] = True - # False means we don't fill with -inf. 
- self.allowed_token_ids_mask_cpu[req_index][ - sampling_params.allowed_token_ids] = False - - if sampling_params.bad_words_token_ids: - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids - - # Add request lora ID - if request.lora_request: - lora_id = request.lora_request.lora_int_id - if lora_id not in self.lora_id_to_request_ids: - self.lora_id_to_request_ids[lora_id] = set() - - self.request_lora_mapping[req_index] = lora_id - self.lora_id_to_request_ids[lora_id].add(request.req_id) - self.lora_id_to_lora_request[lora_id] = request.lora_request - else: - # No LoRA - self.request_lora_mapping[req_index] = 0 + self.allowed_token_ids_mask_cpu[req_index][ + sampling_params.allowed_token_ids] = False + + if sampling_params.bad_words_token_ids: + self.bad_words_token_ids[ + req_index] = sampling_params.bad_words_token_ids + + # Add request lora ID + if request.lora_request: + lora_id = request.lora_request.lora_int_id + if lora_id not in self.lora_id_to_request_ids: + self.lora_id_to_request_ids[lora_id] = set() + + self.request_lora_mapping[req_index] = lora_id + self.lora_id_to_request_ids[lora_id].add(request.req_id) + self.lora_id_to_lora_request[lora_id] = request.lora_request + else: + # No LoRA + self.request_lora_mapping[req_index] = 0 + elif pooling_params := request.pooling_params: + self.pooling_params[req_id] = pooling_params def remove_request(self, req_id: str) -> Optional[int]: """This method must always be followed by a call to condense().""" @@ -261,6 +262,9 @@ def remove_request(self, req_id: str) -> Optional[int]: self.min_tokens.pop(req_index, None) self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) + if self.is_pooling_model: + self.pooling_params.pop(req_id, None) + return req_index # LoRA lora_id = self.request_lora_mapping[req_index] @@ -277,7 +281,6 @@ def remove_request(self, req_id: str) -> Optional[int]: # False means we don't fill with -inf. 
self.allowed_token_ids_mask_cpu[req_index].fill_(False) self.bad_words_token_ids.pop(req_index, None) - self.pooling_params.pop(req_id, None) return req_index def swap_states(self, i1: int, i2: int) -> None: @@ -418,16 +421,7 @@ def max_num_logprobs(self) -> Optional[int]: return max(self.num_logprobs.values()) if self.num_logprobs else None def get_pooling_params(self) -> list[PoolingParams]: - if not self.pooling_params: - return [] - - if len(self.pooling_params) != self.num_reqs: - missing = set(self.req_ids) - set(self.pooling_params) - raise ValueError( - "Pooling params are missing for requests: " - f"{sorted(missing)}" - ) - + assert len(self.pooling_params) == len(self.req_ids) return [self.pooling_params[req_id] for req_id in self.req_ids] def make_lora_inputs( diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index 8d89bdce1..7860a068a 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -46,6 +46,7 @@ gather_logprobs, sample) from tpu_inference.layers.jax.sample.sampling_metadata import \ TPUSupportedSamplingMetadata +from tpu_inference.layers.jax.pool.pooling import pool from tpu_inference.layers.jax.pool.pooling_metadata import ( SUPPORTED_POOLING_TASKS, TPUSupportedPoolingMetadata, @@ -235,6 +236,9 @@ def __init__( ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext() self.dp_size = self.vllm_config.sharding_config.total_dp_size + self.is_pooling_model = self.model_config.runner_type == "pooling" + self.pooler = None + self._init_random() self._init_mesh() self._init_phased_profiling() @@ -253,7 +257,7 @@ def __init__( self.uses_mrope, self.model_config) self.lora_utils = LoraUtils(self) - self.is_pooling_model = False + cache_config = self.cache_config if cache_config.cache_dtype == "auto": model_dtype = self.dtype @@ -424,6 +428,7 @@ def _init_inputs(self) -> None: pin_memory=False, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.block_size], + is_pooling_model = self.is_pooling_model, is_spec_decode=bool(self.vllm_config.speculative_config), ) @@ -496,7 +501,11 @@ def load_model(self): ) multimodal_fns = multimodal_fns or {} - self.is_pooling_model = hasttr(self.model, "pooler") + + if self.is_pooling_model: + self.pooler = self.model.pooler + + print(f"DEBUGPRINT[96]: tpu_jax_runner.py:396: self.is_pooling_model={self.is_pooling_model}") self.precompile_vision_encoder_fn = multimodal_fns.get( "precompile_vision_encoder_fn", None) self.get_multimodal_embeddings_fn = multimodal_fns.get( @@ -520,6 +529,8 @@ def load_model(self): f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB") def get_supported_tasks(self) -> tuple[SupportedTask, ...]: + if self.is_pooling_model: + return ("embed", ) return ("generate", ) def get_kv_cache_spec(self): @@ -758,52 +769,29 @@ def _execute_model( if self.is_pooling_model: assert pooling_metadata is not None - pooling_result = pool(self.model.pooler, hidden_states, pooling_metadata ) - num_reqs = self.input_batch.num_reqs req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) - prompt_logprobs_dict = {req_id: None for req_id in req_ids} + + raw_pooler_output = pool(hidden_states, pooling_metadata, self.pooler) + raw_pooler_output = np.asarray(jax.device_get(raw_pooler_output))[:num_reqs] prompt_lens = np.asarray( jax.device_get(pooling_metadata.prompt_lens) )[:num_reqs] seq_lens_cpu = self.seq_lens_cpu[:num_reqs] - task_id = pooling_metadata.primary_task_id - pooler_output: list[torch.Tensor | None] = [] - - if task_id == 
SUPPORTED_POOLING_TASKS["embed"]: - embeddings = np.asarray( - jax.device_get(pooling_result.embeddings) - )[:num_reqs] - dimensions = np.asarray( - jax.device_get(pooling_metadata.dimensions) - )[:num_reqs] - - for idx in range(num_reqs): - if seq_lens_cpu[idx] != prompt_lens[idx]: - pooler_output.append(None) - continue - - embedding = embeddings[idx] - dim_override = int(dimensions[idx]) - if dim_override > 0: - embedding = embedding[:dim_override] - embedding_np = embedding.astype(np.float32, copy=False) - pooler_output.append( - torch.tensor(embedding_np, dtype=torch.float32) - ) - - else: - raise NotImplementedError( - f"Unsupported pooling task id: {task_id}" - ) + + + pooler_output = [] + for raw_output, seq_len, prompt_len in zip(raw_pooler_output, seq_lens_cpu, prompt_lens): + output = raw_output if seq_len == prompt_len else None + pooler_output.append(torch.from_numpy(raw_output)) model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=[], logprobs=None, - prompt_logprobs_dict=prompt_logprobs_dict, + prompt_logprobs_dict={}, pooler_output=pooler_output, kv_connector_output=kv_connector_output, ) @@ -1513,6 +1501,7 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"): pooling_metadata = TPUSupportedPoolingMetadata.from_input_batch( self.mesh, self.input_batch, + num_scheduled_tokens_per_req, padded_num_reqs, ) From cc78a3a03837e31d5fac587ca5eb24f583f885fe Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Wed, 5 Nov 2025 09:33:51 +0000 Subject: [PATCH 03/13] add support torchax embedding --- tpu_inference/models/jax/adapters.py | 30 +++++++++++++++++++ .../models/vllm/vllm_model_wrapper.py | 7 +++++ tpu_inference/runner/compilation_manager.py | 27 +++++------------ 3 files changed, 44 insertions(+), 20 deletions(-) diff --git a/tpu_inference/models/jax/adapters.py b/tpu_inference/models/jax/adapters.py index 33143f5c7..29c36bdfb 100644 --- a/tpu_inference/models/jax/adapters.py +++ b/tpu_inference/models/jax/adapters.py @@ -1,7 +1,9 @@ import typing as tp +import torch import jax from flax import nnx +from flax.typing import PRNGKey from jax.sharding import Mesh from tpu_inference.layers.jax.pool.pooler import Pooler @@ -84,3 +86,31 @@ def _init_pooler(self, vllm_config: VllmConfig) -> None: "ForEmbedding", ) return ModelForEmbedding # type: ignore[return-value] + + + +def init_pooler_from_vllm_model( + vllm_model: torch.nn.Module, + vllm_config: VllmConfig, + rng_key: PRNGKey, + mesh: Mesh, +): + class DummyModule: + def __init__(self, vllm_config, rng_key, mesh): + pass + + for suffix in _GENERATE_SUFFIXES: + if suffix in vllm_model.__class__.__name__: + return None + + if "ForEmbedding" in vllm_model.__class__.__name__: + EmbedModel = as_embedding_model(DummyModule) + + embed_model = EmbedModel(vllm_config=vllm_config, rng_key=rng_key, mesh=mesh,) + embed_model._init_pooler(vllm_config) + return embed_model.pooler + else: + raise NotImplementedError( + f"Pooling initialization for {vllm_model.__class__.__name__} is not implemented." 
+ ) + diff --git a/tpu_inference/models/vllm/vllm_model_wrapper.py b/tpu_inference/models/vllm/vllm_model_wrapper.py index 79c059b72..e429de77e 100644 --- a/tpu_inference/models/vllm/vllm_model_wrapper.py +++ b/tpu_inference/models/vllm/vllm_model_wrapper.py @@ -30,6 +30,8 @@ from tpu_inference.models.vllm.vllm_model_wrapper_context import ( get_vllm_model_wrapper_context, set_vllm_model_wrapper_context) from tpu_inference.runner.lora_utils import replace_lora_metadata +from tpu_inference.layers.jax.pool.pooler import Pooler +from tpu_inference.models.jax.adapters import init_pooler_from_vllm_model logger = init_logger(__name__) @@ -72,6 +74,7 @@ class VllmModelWrapper: rng: PRNGKey mesh: Mesh model: _VllmRunner + pooler: Pooler def __init__(self, vllm_config: VllmConfig, rng: PRNGKey, mesh: Mesh): self.vllm_config = vllm_config @@ -137,6 +140,10 @@ def load_weights(self): self.model = _VllmRunner(vllm_model) params_and_buffers = shard_model_to_tpu(self.model, self.mesh) + + # TODO: need to seperate this params_and_buffer for pooler (some pooler is not stateless) + self.pooler = init_pooler_from_vllm_model(vllm_model, self.vllm_config, self.rng, self.mesh) + # Returning to the jax land, so we need to wrap it into a JaxValue. return jax_view(params_and_buffers), lora_manager diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 5276d7b7e..011ecdfb0 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -21,6 +21,7 @@ ) from tpu_inference.logger import init_logger from tpu_inference.utils import device_array +from torchax.ops.mappings import t2j_dtype if TYPE_CHECKING: from tpu_inference.runner.tpu_runner import TPUModelRunner @@ -114,31 +115,17 @@ def _precompile_pooling(self) -> None: for num_tokens in self.runner.num_tokens_paddings: hidden_states = self._create_dummy_tensor( - (num_tokens, hidden_size), dtype, sharding=hidden_sharding) + (num_tokens, hidden_size), t2j_dtype(dtype), sharding=hidden_sharding) for num_reqs in self.runner.num_reqs_paddings: if num_reqs == 0 or num_reqs > num_tokens: continue - prompt_lens = np.ones(num_reqs, dtype=np.int32) - first_token_indices = np.arange(num_reqs, dtype=np.int32) - last_token_indices = first_token_indices.copy() - normalize = np.ones(num_reqs, dtype=np.int8) - - ( - prompt_lens, - normalize, - first_token_indices, - last_token_indices, - ) = device_array( - self.runner.mesh, - ( - prompt_lens, - normalize, - first_token_indices, - last_token_indices, - ), - ) + prompt_lens = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) + first_token_indices = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) + last_token_indices = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) + normalize = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) + pooling_metadata = TPUSupportedPoolingMetadata( prompt_lens=prompt_lens, From 198d49f4a719a47cd6f0e1c4fa07329ea27883db Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Wed, 5 Nov 2025 10:14:12 +0000 Subject: [PATCH 04/13] change to jnp.bfloat16 --- tpu_inference/runner/compilation_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 011ecdfb0..73e2a5d5b 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -115,7 +115,7 @@ def _precompile_pooling(self) -> None: for num_tokens in 
self.runner.num_tokens_paddings: hidden_states = self._create_dummy_tensor( - (num_tokens, hidden_size), t2j_dtype(dtype), sharding=hidden_sharding) + (num_tokens, hidden_size), jnp.bfloat16, sharding=hidden_sharding) for num_reqs in self.runner.num_reqs_paddings: if num_reqs == 0 or num_reqs > num_tokens: From 34cb3a08ad8a6f9d59e5b3f3902a9cbc9d84527e Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Wed, 5 Nov 2025 12:38:23 +0000 Subject: [PATCH 05/13] simplify --- tpu_inference/layers/jax/pool/pooler.py | 4 +- .../layers/jax/pool/pooling_metadata.py | 76 ++++++------------- tpu_inference/runner/compilation_manager.py | 8 +- tpu_inference/runner/tpu_runner.py | 4 +- 4 files changed, 31 insertions(+), 61 deletions(-) diff --git a/tpu_inference/layers/jax/pool/pooler.py b/tpu_inference/layers/jax/pool/pooler.py index 0f4f6e0f2..f8332cb3f 100644 --- a/tpu_inference/layers/jax/pool/pooler.py +++ b/tpu_inference/layers/jax/pool/pooler.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp from flax import nnx -from tpu_inference.layers.jax.pool.pooling_metadata import TPUSupportedPoolingMetadata +from tpu_inference.layers.jax.pool.pooling_metadata import TPUSupportedPoolingMetadata, is_partial_prefill from vllm.config.pooler import PoolerConfig @@ -212,7 +212,7 @@ def __call__( return self.head(pooled, pooling_metadata) def get_supported_tasks(self) -> set[str]: - return {"embed"} + return ("embed",) def normalize(embeddings: jax.Array) -> jax.Array: diff --git a/tpu_inference/layers/jax/pool/pooling_metadata.py b/tpu_inference/layers/jax/pool/pooling_metadata.py index a2f34df25..284fecc07 100644 --- a/tpu_inference/layers/jax/pool/pooling_metadata.py +++ b/tpu_inference/layers/jax/pool/pooling_metadata.py @@ -19,23 +19,21 @@ def build_pooling_cursor( num_scheduled_tokens: list[int], - padded_num_seqs: int, - prompt_lens: jax.Array, + padded_num_reqs: int, ): - assert len(prompt_lens) == len(num_scheduled_tokens) n_seq = len(num_scheduled_tokens) - num_scheduled_tokens_padded = jnp.zeros(padded_num_seqs) - num_scheduled_tokens_padded = num_scheduled_tokens_padded.at[:n_seq].set( + padded_num_scheduled_tokens = jnp.zeros(padded_num_reqs) + padded_num_scheduled_tokens = padded_num_scheduled_tokens.at[:n_seq].set( jnp.asarray(num_scheduled_tokens, dtype=jnp.int32) ) - cumsum = jnp.cumsum(num_scheduled_tokens_padded, dtype = jnp.int64) + cumsum = jnp.cumsum(padded_num_scheduled_tokens, dtype = jnp.int64) first_token_indices = jnp.concatenate((jnp.asarray((0,)), cumsum[:-1])) - last_token_indices = (first_token_indices + num_scheduled_tokens_padded - 1).astype(jnp.int64) + last_token_indices = (first_token_indices + padded_num_scheduled_tokens - 1).astype(jnp.int64) last_token_indices = jnp.where( - num_scheduled_tokens_padded > 0, last_token_indices, first_token_indices + padded_num_scheduled_tokens > 0, last_token_indices, first_token_indices ) - return first_token_indices, last_token_indices + return first_token_indices, last_token_indices, padded_num_scheduled_tokens @functools.partial( @@ -44,11 +42,9 @@ def build_pooling_cursor( "prompt_lens", "first_token_indices", "last_token_indices", - "normalize", - "num_reqs", - "padded_num_reqs", + "num_scheduled_tokens", ), - meta_fields=("task",), + meta_fields = (), ) @dataclass class TPUSupportedPoolingMetadata: @@ -57,64 +53,42 @@ class TPUSupportedPoolingMetadata: prompt_lens: jax.Array first_token_indices: jax.Array last_token_indices: jax.Array - normalize: jax.Array - num_reqs: int - padded_num_reqs: int - task: str + num_scheduled_tokens: 
jax.Array @classmethod def from_input_batch( cls, mesh: Mesh, input_batch: InputBatch, - num_scheduled_tokens: list[int], + padded_num_scheduled_tokens: list[int], padded_num_reqs: int, ) -> TPUSupportedPoolingMetadata: pooling_params_list = input_batch.get_pooling_params() num_reqs = input_batch.num_reqs assert len(pooling_params_list) == num_reqs + assert len(input_batch.num_prompt_tokens[:num_reqs]) == len(padded_num_scheduled_tokens) - padded_prompt_lens_np = np.zeros(padded_num_reqs, dtype=np.int32) - padded_prompt_lens_np[:num_reqs] = input_batch.num_prompt_tokens[:num_reqs] - - normalize = np.full(padded_num_reqs, -1, dtype=np.int8) - - # Instead of shutting down the whole program, we should just ignore it and make it return 'embed' by default, - # but provide a warning. - for idx, params in enumerate(pooling_params_list): - if params.normalize is True: - normalize[idx] = 1 - elif params.normalize is False: - normalize[idx] = 0 - - if (task := params.task) not in SUPPORTED_POOLING_TASKS: - logger.warning( - f"Unsupported pooling task '{task}'. Supported tasks: {sorted(SUPPORTED_POOLING_TASKS)}. Defaulting to 'embed'." - ) - - # maybe in the future if we need to support multiple tasks in one batch, we need to make sure each batch has only one task - # if not task_values: - # raise ValueError("Pooling metadata requires at least one request") - # if any(task != task_values[0] for task in task_values): - # raise ValueError("Mixed pooling tasks within the same batch are not supported yet") - - task = "embed" - first_token_indices, last_token_indices = build_pooling_cursor( - num_scheduled_tokens, padded_num_reqs, padded_prompt_lens_np[:num_reqs] + padded_prompt_lens= jnp.zeros(padded_num_reqs, dtype=np.int32) + padded_prompt_lens= padded_prompt_lens.at[:num_reqs].set(input_batch.num_prompt_tokens[:num_reqs]) + + first_token_indices, last_token_indices, padded_num_scheduled_tokens = build_pooling_cursor( + padded_num_scheduled_tokens, padded_num_reqs ) - prompt_lens, normalize, first_token_indices, last_token_indices = device_array( + prompt_lens, first_token_indices, last_token_indices, num_scheduled_tokens = device_array( mesh, - (padded_prompt_lens_np, normalize, first_token_indices, last_token_indices), + (padded_prompt_lens, first_token_indices, last_token_indices, padded_num_scheduled_tokens), ) + #everything in pooling_metadata is padded. 
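+        # Worked example of the padding: with num_reqs=2, num_scheduled_tokens=[3, 2]
+        # and padded_num_reqs=4, build_pooling_cursor yields
+        #     first_token_indices = [0, 3, 5, 5]
+        #     last_token_indices  = [2, 4, 5, 5]
+        # and the padded slots are sliced away on the host ([:num_reqs]) after pooling.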
return cls( prompt_lens=prompt_lens, first_token_indices=first_token_indices, last_token_indices=last_token_indices, - normalize=normalize, - task=task, - num_reqs=num_reqs, - padded_num_reqs=padded_num_reqs, + num_scheduled_tokens = num_scheduled_tokens, ) + + +def is_partial_prefill(pooling_metadata: TPUSupportedPoolingMetadata): + return not jnp.all(pooling_metadata.prompt_lens == pooling_metadata.num_scheduled_tokens) diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 73e2a5d5b..293e55d39 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -121,20 +121,18 @@ def _precompile_pooling(self) -> None: if num_reqs == 0 or num_reqs > num_tokens: continue + # can we just use (one array here) prompt_lens = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) first_token_indices = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) last_token_indices = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) - normalize = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) + num_scheduled_tokens = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) pooling_metadata = TPUSupportedPoolingMetadata( prompt_lens=prompt_lens, first_token_indices=first_token_indices, last_token_indices=last_token_indices, - normalize=normalize, - num_reqs=num_reqs, - padded_num_reqs=num_reqs, - task="embed", + num_scheduled_tokens = num_scheduled_tokens, ) self._run_compilation( diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index 7860a068a..42f9f070e 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -505,7 +505,6 @@ def load_model(self): if self.is_pooling_model: self.pooler = self.model.pooler - print(f"DEBUGPRINT[96]: tpu_jax_runner.py:396: self.is_pooling_model={self.is_pooling_model}") self.precompile_vision_encoder_fn = multimodal_fns.get( "precompile_vision_encoder_fn", None) self.get_multimodal_embeddings_fn = multimodal_fns.get( @@ -530,7 +529,7 @@ def load_model(self): def get_supported_tasks(self) -> tuple[SupportedTask, ...]: if self.is_pooling_model: - return ("embed", ) + return self.pooler.get_supported_tasks() return ("generate", ) def get_kv_cache_spec(self): @@ -780,7 +779,6 @@ def _execute_model( seq_lens_cpu = self.seq_lens_cpu[:num_reqs] - pooler_output = [] for raw_output, seq_len, prompt_len in zip(raw_pooler_output, seq_lens_cpu, prompt_lens): output = raw_output if seq_len == prompt_len else None From 89c877ba149dd8583f46db2694daabe9e1c35209 Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Sat, 15 Nov 2025 01:50:21 +0000 Subject: [PATCH 06/13] fix --- tpu_inference/runner/tpu_runner.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index 42f9f070e..1be5f86b4 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -701,27 +701,17 @@ def _execute_model( # "Should not schedule a request that does nothing!") return DUMMY_METADATA, EMPTY_MODEL_RUNNER_OUTPUT -<<<<<<< HEAD:tpu_inference/runner/tpu_runner.py # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b ( input_ids, attn_metadata, _, - logits_indices, - spec_decode_metadata, - logits_indices_selector, - ) = self._prepare_inputs(scheduler_output) -======= - ( - input_ids, - 
attn_metadata, - sampling_metadata, pooling_metadata, logits_indices, spec_decode_metadata, + logits_indices_selector, ) = self._prepare_inputs(scheduler_output) ->>>>>>> 8a07e5ea (merge main):tpu_inference/runner/tpu_jax_runner.py # multi-modal support if self.is_multimodal_model: @@ -1623,11 +1613,7 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"): # This is for making these cpu buffers hidden during tracing attention_metadata.query_start_loc_cpu = query_start_loc_cpu attention_metadata.seq_lens_cpu = seq_lens_cpu -<<<<<<< HEAD:tpu_inference/runner/tpu_runner.py logits_indices_selector = None - return (input_ids, attention_metadata, sampling_metadata, - logits_indices, spec_decode_metadata, logits_indices_selector) -======= return ( input_ids, @@ -1636,8 +1622,8 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"): pooling_metadata, logits_indices, spec_decode_metadata, + logits_indices_selector ) ->>>>>>> 8a07e5ea (merge main):tpu_inference/runner/tpu_jax_runner.py def _get_input_ids_embeds(self, input_ids: jax.Array, mm_embeds: list[jax.Array]): From fed5a1e9a94f8350d497145139ad025fed9126b9 Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Sat, 15 Nov 2025 02:00:21 +0000 Subject: [PATCH 07/13] done --- tpu_inference/layers/jax/pool/pooling_metadata.py | 2 +- tpu_inference/runner/tpu_runner.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tpu_inference/layers/jax/pool/pooling_metadata.py b/tpu_inference/layers/jax/pool/pooling_metadata.py index 284fecc07..d07e305d9 100644 --- a/tpu_inference/layers/jax/pool/pooling_metadata.py +++ b/tpu_inference/layers/jax/pool/pooling_metadata.py @@ -8,7 +8,7 @@ import jax.numpy as jnp import numpy as np from jax.sharding import Mesh -from tpu_inference.runner.input_batch_jax import InputBatch +from tpu_inference.runner.input_batch import InputBatch from tpu_inference.utils import device_array logger = logging.getLogger(__name__) diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index 1be5f86b4..a526a7e6f 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -51,7 +51,6 @@ SUPPORTED_POOLING_TASKS, TPUSupportedPoolingMetadata, ) -from tpu_inference.layers.jax.sharding import build_mesh from tpu_inference.logger import init_logger from tpu_inference.models.common.model_loader import get_model from tpu_inference.models.jax.utils.weight_utils import ( From 1bf9ef0be6f57d5312063db280580697bdbb0783 Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Sat, 15 Nov 2025 02:50:42 +0000 Subject: [PATCH 08/13] fix weight loader --- tpu_inference/models/jax/utils/weight_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tpu_inference/models/jax/utils/weight_utils.py b/tpu_inference/models/jax/utils/weight_utils.py index 9cfe74e39..36f6287f1 100644 --- a/tpu_inference/models/jax/utils/weight_utils.py +++ b/tpu_inference/models/jax/utils/weight_utils.py @@ -316,8 +316,10 @@ def _load_hf_weights_on_thread(vllm_config, if hf_key.endswith(".weight"): hf_key = hf_key.removesuffix(".weight") - if not hf_key.startswith('models.'): - hf_key = 'model.' 
+ hf_key + base_hf_key = hf_key + + if not hf_key.startswith("model."): + hf_key = f"model.{hf_key}" # Find the corresponding model key using the HF key if "layers" in hf_key: @@ -331,7 +333,7 @@ def _load_hf_weights_on_thread(vllm_config, model_key = name_map[layer_key] model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key) else: - if hf_key not in name_map and hf_key == "lm_head": + if hf_key not in name_map and base_hf_key == "lm_head": logger.warning( f"Skip loading {hf_key} due to tie_word_embeddings") continue From e28497c90aa249a1dee79106bbcd0ea15c266df7 Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Sat, 15 Nov 2025 04:55:28 +0000 Subject: [PATCH 09/13] add tests --- tests/e2e/test_pooling.py | 107 ++++++++++++++++++++++ tests/models/jax/test_adapters.py | 147 ++++++++++++++++++++++++++++++ tests/runner/test_input_batch.py | 116 ++++++++++++++++++++++- 3 files changed, 369 insertions(+), 1 deletion(-) create mode 100644 tests/e2e/test_pooling.py create mode 100644 tests/models/jax/test_adapters.py diff --git a/tests/e2e/test_pooling.py b/tests/e2e/test_pooling.py new file mode 100644 index 000000000..c4a107adb --- /dev/null +++ b/tests/e2e/test_pooling.py @@ -0,0 +1,107 @@ +import time +from typing import List + +import numpy as np +import pytest +import torch +import torch.nn.functional as F +from transformers import AutoModel, AutoTokenizer +from vllm import LLM + +MODEL_ID = "Qwen/Qwen3-Embedding-0.6B" +MAX_NUM_BATCHED_TOKENS = 128 +MAX_NUM_SEQS = 8 +RTOL = 5e-3 +ATOL = 5e-3 + + +def last_token_pool(last_hidden_states: torch.Tensor, + attention_mask: torch.Tensor) -> torch.Tensor: + """Reference pooling implementation from Qwen3 embedding docs.""" + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] + if left_padding: + return last_hidden_states[:, -1] + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden_states.shape[0] + return last_hidden_states[torch.arange(batch_size, + device=last_hidden_states.device), + sequence_lengths] + + +def hf_embeddings(texts: List[str], model: AutoModel, + tokenizer: AutoTokenizer) -> np.ndarray: + """Get reference embeddings using HF Transformers""" + batch_dict = tokenizer(texts, + padding=True, + truncation=True, + max_length=MAX_NUM_BATCHED_TOKENS, + return_tensors="pt") + with torch.no_grad(): + outputs = model(**batch_dict) + embeddings = last_token_pool(outputs.last_hidden_state, + batch_dict["attention_mask"]) + embeddings = F.normalize(embeddings, p=2, dim=1) + return embeddings.cpu().numpy() + + +def vllm_embeddings(texts: List[str]) -> np.ndarray: + """Get embeddings via vLLM """ + llm = LLM(model=MODEL_ID, + runner="pooling", + convert="embed", + max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, + max_num_seqs=MAX_NUM_SEQS, + max_model_len=MAX_NUM_BATCHED_TOKENS) + outputs = llm.embed(texts) + embeddings = np.asarray( + [np.array(output.outputs.embedding, dtype=np.float32) for output in outputs]) + del llm + # Wait for TPU runtime tear down before next test. 
+ time.sleep(10) + return embeddings + + +def compare_embeddings(vllm_emb: np.ndarray, + hf_emb: np.ndarray) -> List[tuple[bool, float, float]]: + """Compare embeddings with diagnostics.""" + results = [] + for v_emb, h_emb in zip(vllm_emb, hf_emb): + is_close = np.allclose(v_emb, h_emb, rtol=RTOL, atol=ATOL) + max_diff = float(np.max(np.abs(v_emb - h_emb))) + cos_sim = float(np.dot(v_emb, h_emb) / + (np.linalg.norm(v_emb) * np.linalg.norm(h_emb))) + results.append((is_close, max_diff, cos_sim)) + return results + + +@pytest.mark.tpu +def test_last_token_embedding_pooling(monkeypatch: pytest.MonkeyPatch): + prompts = [ + "The quick brown fox jumps over the lazy dog near the river bank.", + "Machine learning systems process large datasets to extract useful information.", + "Neural networks learn hierarchical representations from raw data automatically.", + "Transformer architectures power modern language models used in production today.", + "Vector embeddings capture semantic meaning in high dimensional spaces for retrieval.", + "Artificial intelligence continues to transform industries across the global economy.", + "Gradient descent iteratively updates parameters to minimize model loss functions.", + "Attention mechanisms allow models to focus on the most relevant parts of input." + ] + + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, + padding_side="left", + trust_remote_code=True) + hf_model = AutoModel.from_pretrained(MODEL_ID, + trust_remote_code=True, + torch_dtype=torch.float32) + hf_model.eval() + + with monkeypatch.context(): + vllm_embeds = vllm_embeddings(prompts) + hf_embeds = hf_embeddings(prompts, hf_model, tokenizer) + + assert vllm_embeds.shape == hf_embeds.shape == (len(prompts), hf_embeds.shape[1]) + + comparisons = compare_embeddings(vllm_embeds, hf_embeds) + for idx, (is_close, max_diff, cos_sim) in enumerate(comparisons): + assert is_close, ( + f"Embedding {idx} mismatch (max_diff={max_diff:.2e}, cos_sim={cos_sim:.6f})") diff --git a/tests/models/jax/test_adapters.py b/tests/models/jax/test_adapters.py new file mode 100644 index 000000000..7e74c0b28 --- /dev/null +++ b/tests/models/jax/test_adapters.py @@ -0,0 +1,147 @@ +from unittest.mock import MagicMock + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from flax import nnx +from flax.typing import PRNGKey +from jax.sharding import Mesh +from vllm.config import ModelConfig +from vllm.config.pooler import PoolerConfig + +from tpu_inference.layers.common.attention_metadata import AttentionMetadata +from tpu_inference.layers.jax.pool.pooler import (CLSPoolingMethod, + LastPoolingMethod, + MeanPoolingMethod) +from tpu_inference.layers.jax.pool.pooling_metadata import ( + TPUSupportedPoolingMetadata, ) +from tpu_inference.models.jax.adapters import as_embedding_model +from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM +from tpu_inference.runner.kv_cache import create_kv_caches + + +class MockVllmConfig: + + def __init__(self, model: str, pooling_type: str): + self.model_config = ModelConfig(model=model) + self.model_config.dtype = jnp.bfloat16 + self.model_config.pooler_config = PoolerConfig( + pooling_type=pooling_type, normalize=False) + self.cache_config = MagicMock(cache_dtype="auto") + self.load_config = MagicMock() + self.load_config.download_dir = None + + +@pytest.fixture(scope="module") +def mesh(): + if not jax.devices(): + pytest.skip("No JAX devices available for mesh creation.") + + devices = np.array(jax.local_devices()[:1]) + device_mesh = 
devices.reshape((len(devices), 1, 1, 1)) + + with Mesh(device_mesh, + axis_names=('data', 'attn_dp', 'expert', 'model')) as m: + yield m + + +@pytest.fixture +def rng() -> PRNGKey: + return jax.random.PRNGKey(0) + + +@pytest.fixture +def mock_model_inputs(): + num_tokens = 6 + num_reqs = 1 + max_num_blocks_per_req = 4 + input_ids = jnp.arange(num_tokens, dtype=jnp.int32) + positions = jnp.arange(num_tokens, dtype=jnp.int32) + block_tables = jnp.zeros((num_reqs, max_num_blocks_per_req), + dtype=jnp.int32).reshape(-1) + seq_lens = jnp.ones((num_reqs, ), dtype=jnp.int32) + query_start_loc = jnp.arange(num_reqs + 1, dtype=jnp.int32) + request_distribution = jnp.array([0, 0, 0], dtype=jnp.int32) + + attention_metadata = AttentionMetadata( + input_positions=positions, + block_tables=block_tables, + seq_lens=seq_lens, + query_start_loc=query_start_loc, + request_distribution=request_distribution, + ) + + return input_ids, attention_metadata + + +TEST_MODELS = [ + ("Qwen/Qwen3-0.6B", Qwen3ForCausalLM), +] + + +@pytest.mark.parametrize( + ("model_id", "model_cls", "pooling_type", "pooling_cls"), + [ + (model_id, model_cls, pooling_type, pooling_cls) + for model_id, model_cls in TEST_MODELS + for pooling_type, pooling_cls in [ + ("LAST", LastPoolingMethod), + ("CLS", CLSPoolingMethod), + ("MEAN", MeanPoolingMethod), + ] + ], +) +def test_embedding_adapter(model_id, model_cls, pooling_type, pooling_cls, rng, + mesh, mock_model_inputs): + EmbeddingModel = as_embedding_model(model_cls) + vllm_config = MockVllmConfig(model_id, pooling_type) + model = EmbeddingModel(vllm_config, rng, mesh) + + assert isinstance(model.pooler.pooling, pooling_cls) + assert model.is_pooling_model + assert isinstance(model.pooler.head, nnx.Module) + + model.load_weights(rng) + + hf_config = vllm_config.model_config.hf_config + head_dim = 128 + kv_caches = create_kv_caches( + num_blocks=4, + block_size=32, + num_kv_heads=hf_config.num_key_value_heads, + head_size=head_dim, + mesh=mesh, + layer_names=["layer"] * hf_config.num_hidden_layers, + cache_dtype=jnp.bfloat16, + ) + + input_ids, attention_metadata = mock_model_inputs + kv_caches, hidden_states, _ = model(kv_caches, input_ids, + attention_metadata) + + num_tokens = input_ids.shape[0] + pooling_metadata = TPUSupportedPoolingMetadata( + prompt_lens=jnp.array([num_tokens], dtype=jnp.int32), + first_token_indices=jnp.array([0], dtype=jnp.int32), + last_token_indices=jnp.array([num_tokens - 1], dtype=jnp.int32), + num_scheduled_tokens=jnp.array([num_tokens], dtype=jnp.int32), + ) + + embeddings = model.pooler(hidden_states, pooling_metadata) + assert embeddings.shape == (1, hf_config.hidden_size) + + hidden_np = np.array(hidden_states, dtype=np.float32) + last_index = int(pooling_metadata.last_token_indices[0]) + first_index = int(pooling_metadata.first_token_indices[0]) + if pooling_type == "LAST": + expected = hidden_np[last_index] + elif pooling_type == "CLS": + expected = hidden_np[first_index] + else: + start = first_index + end = last_index + 1 + expected = hidden_np[start:end].mean(axis=0) + + np.testing.assert_allclose(np.array(embeddings[0]), expected, rtol=1e-5, + atol=1e-5) diff --git a/tests/runner/test_input_batch.py b/tests/runner/test_input_batch.py index d3391c4e0..5b5e41bc5 100644 --- a/tests/runner/test_input_batch.py +++ b/tests/runner/test_input_batch.py @@ -1,6 +1,7 @@ import numpy as np import pytest from vllm.sampling_params import SamplingParams +from vllm.pooling_params import PoolingParams from tpu_inference.runner.input_batch import 
CachedRequestState, InputBatch @@ -26,15 +27,36 @@ def input_batch(): ) +@pytest.fixture +def input_batch_for_pooling(): + return InputBatch( + max_num_reqs=MAX_NUM_REQS, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, + pin_memory=False, + vocab_size=VOCAB_SIZE, + block_sizes=BLOCK_SIZES, + is_pooling_model = True, + is_spec_decode=False, + ) + + + def create_dummy_request(req_id: str, prompt_len: int = 10, output_len: int = 5, sampling_params: SamplingParams = None, + pooling_params: PoolingParams = None, block_ids=None) -> CachedRequestState: """Helper function to create a CachedRequestState instance.""" + if sampling_params is None: sampling_params = SamplingParams(temperature=0.8, top_p=0.9, top_k=50) + if pooling_params: + sampling_params = None + + prompt_token_ids = list(range(prompt_len)) output_token_ids = list(range(prompt_len, prompt_len + output_len)) @@ -49,7 +71,7 @@ def create_dummy_request(req_id: str, prompt_token_ids=prompt_token_ids, mm_features=[], sampling_params=sampling_params, - pooling_params=None, + pooling_params=pooling_params, block_ids=block_ids, num_computed_tokens=0, lora_request=None, @@ -210,3 +232,95 @@ def test_all_greedy_property(input_batch: InputBatch): # Remove it, should be true again input_batch.random_reqs.remove("req-r") assert input_batch.all_greedy + + + +def test_add_pooling_request(input_batch_for_pooling: InputBatch): + pooling_params = PoolingParams(dimensions = 768, normalize = True, use_activation = True) + req = create_dummy_request("req-1", prompt_len = 20, output_len = 4, pooling_params = pooling_params) + input_batch_for_pooling.add_request(req) + + assert input_batch_for_pooling.num_reqs == 1 + assert "req-1" in input_batch_for_pooling.req_id_to_index + assert input_batch_for_pooling.req_id_to_index["req-1"] == 0 + assert input_batch_for_pooling.req_ids == ["req-1"] + assert len(input_batch_for_pooling.spec_decode_unsupported_reqs) == 0 + + assert input_batch_for_pooling.num_prompt_tokens[0] == 20 + assert input_batch_for_pooling.num_tokens[0] == 24 + assert input_batch_for_pooling.num_tokens_no_spec[0] == 24 + expected_tokens = np.array(req.prompt_token_ids + req.output_token_ids) + np.testing.assert_array_equal(input_batch_for_pooling.token_ids_cpu[0, :24], + expected_tokens) + + assert input_batch_for_pooling.get_pooling_params() == [pooling_params] + + +def test_add_multiple_pooling_requests(input_batch_for_pooling: InputBatch): + pooling_params_1 = PoolingParams(dimensions=512, normalize=True) + pooling_params_2 = PoolingParams(dimensions=1024, normalize=False) + + req1 = create_dummy_request("req-1", + prompt_len=8, + output_len=2, + pooling_params=pooling_params_1) + req2 = create_dummy_request("req-2", + prompt_len=6, + output_len=3, + pooling_params=pooling_params_2) + + input_batch_for_pooling.add_request(req1) + input_batch_for_pooling.add_request(req2) + + assert input_batch_for_pooling.num_reqs == 2 + assert input_batch_for_pooling.req_ids == ["req-1", "req-2"] + + pooling_values = input_batch_for_pooling.get_pooling_params() + assert pooling_values[0] == pooling_params_1 + assert pooling_values[1] == pooling_params_2 + + +def test_remove_single_pooling_request(input_batch_for_pooling: InputBatch): + pooling_params_1 = PoolingParams(dimensions=256) + pooling_params_2 = PoolingParams(dimensions=768) + + req1 = create_dummy_request("req-1", pooling_params=pooling_params_1) + req2 = create_dummy_request("req-2", pooling_params=pooling_params_2) + + input_batch_for_pooling.add_request(req1) + 
input_batch_for_pooling.add_request(req2) + + removed_index = input_batch_for_pooling.remove_request("req-1") + assert removed_index == 0 + assert "req-1" not in input_batch_for_pooling.pooling_params + + input_batch_for_pooling.condense([removed_index]) + + assert input_batch_for_pooling.req_ids == ["req-2"] + pooling_values = input_batch_for_pooling.get_pooling_params() + assert pooling_values == [pooling_params_2] + + +def test_remove_multiple_pooling_requests(input_batch_for_pooling: InputBatch): + pooling_params = [ + PoolingParams(dimensions=128 + i * 64) for i in range(3) + ] + reqs = [ + create_dummy_request(f"req-{i}", pooling_params=pooling_params[i]) + for i in range(3) + ] + + for req in reqs: + input_batch_for_pooling.add_request(req) + + removed_indices = [] + removed_indices.append(input_batch_for_pooling.remove_request("req-0")) + removed_indices.append(input_batch_for_pooling.remove_request("req-2")) + + removed_indices = sorted( + [idx for idx in removed_indices if idx is not None], reverse=True) + input_batch_for_pooling.condense(removed_indices) + + assert input_batch_for_pooling.req_ids == ["req-1"] + pooling_values = input_batch_for_pooling.get_pooling_params() + assert pooling_values == [pooling_params[1]] From 713d4bd651ee6a36308829e042e76197ef053afc Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Sat, 15 Nov 2025 04:55:51 +0000 Subject: [PATCH 10/13] docstring --- tpu_inference/layers/jax/pool/pooler.py | 48 +++++++++++++++++++++---- tpu_inference/models/jax/adapters.py | 20 ++++------- 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/tpu_inference/layers/jax/pool/pooler.py b/tpu_inference/layers/jax/pool/pooler.py index f8332cb3f..c67a05886 100644 --- a/tpu_inference/layers/jax/pool/pooler.py +++ b/tpu_inference/layers/jax/pool/pooler.py @@ -61,6 +61,10 @@ def from_config( class PoolingMethod(nnx.Module): + """ + Base class for pooling methods. Factory method `from_pooling_type` creates + specific pooling method instances based on the provided `PoolingType`. + """ @staticmethod def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod": if pooling_type is PoolingType.ALL: @@ -83,15 +87,23 @@ def __call__( class AllPoolingMethod(PoolingMethod): + """ + Pools all token embeddings for each request. + Most of the time, this is for encoder models; hence, it requires full sequences during the prefill because of bidirectional attention. + """ def __call__( self, hidden_states: jax.Array, pooling_metadata: TPUSupportedPoolingMetadata, ) -> jax.Array: - pass + raise NotImplementedError("ALL pooling is not implemented yet.") class MeanPoolingMethod(PoolingMethod): + """ + Pools by computing the mean of token embeddings for each request. + Most of the time, this is for encoder models; hence, it requires full sequences during the prefill because of bidirectional attention. 
+ """ def __call__( self, hidden_states: jax.Array, @@ -100,16 +112,20 @@ def __call__( padded_prompt_lens = pooling_metadata.prompt_lens padded_start_indices = pooling_metadata.first_token_indices padded_end_indices = pooling_metadata.last_token_indices - cumsum = jnp.cumsum(hidden_states, dtype=jnp.float32) - return ( - cumsum[padded_end_indices] - - cumsum[padded_start_indices] - + hidden_states[padded_start_indices] - ) / padded_prompt_lens.unsqueeze(1) + def pool_fn(start, end, length): + seq = hidden_states[start:end + 1] + return jnp.sum(seq, axis=0) / length + + return jax.vmap(pool_fn)(padded_start_indices, padded_end_indices, + padded_prompt_lens) class LastPoolingMethod(PoolingMethod): + """ + Pools by selecting the last token embedding for each request. + Most of the time, this is used for causal/decoder models and is also supported with partial prefill. + """ def __call__( self, hidden_states: jax.Array, @@ -119,6 +135,10 @@ def __call__( class CLSPoolingMethod(PoolingMethod): + """ + Pools by selecting the [CLS] token (first token) embedding for each request. + Most of the time, this is for encoder models; hence, it requires full sequences during the prefill because of bidirectional attention. + """ def __call__( self, hidden_states: jax.Array, @@ -128,6 +148,9 @@ def __call__( class PoolerHead(nnx.Module): + """ + Base class for Pooler Heads that process the pooled output. + """ def __call__( self, pooled: jax.Array, @@ -139,6 +162,9 @@ def __call__( class EmbeddingPoolerHead(PoolerHead): + """ + Pooler head for embedding tasks, primarily responsible for normalization. + """ def __init__(self, default_normalize: bool) -> None: super().__init__() self.default_normalize = default_normalize @@ -165,6 +191,9 @@ def __call__( class Pooler(nnx.Module): + """ + Base class for Poolers with factory methods to create Poolers for different tasks. + """ @staticmethod def for_encode(pooler_config: PoolerConfig | None) -> "Pooler": resolved = ResolvedPoolingConfig.from_config("encode", pooler_config) @@ -187,6 +216,11 @@ def get_supported_tasks(self) -> set[str]: class EmbeddingPooler(Pooler): + """ + Pooler for embedding tasks. + pooling: PoolingMethod instance that performs the pooling operation. + head: EmbeddingPoolerHead instance that processes the pooled output, primarily for normalization. + """ def __init__( self, pooling: PoolingMethod, diff --git a/tpu_inference/models/jax/adapters.py b/tpu_inference/models/jax/adapters.py index 29c36bdfb..c978f4136 100644 --- a/tpu_inference/models/jax/adapters.py +++ b/tpu_inference/models/jax/adapters.py @@ -18,7 +18,8 @@ class PoolingMixin: """ - same as VllmModelForPooling + VllmModelForPooling + The reason for creating this Mixin instead of VllmModelForPooling is due to a conflict in the metaclass between nnx.Module and VllmModelForPooling. """ is_pooling_model: tp.ClassVar[tp.Literal[True]] = True @@ -49,17 +50,6 @@ def __init__( mesh=mesh, ) - - # Pooling models do not require language modeling heads. - # However, there is a problem: since the pattern for loading weights in nnx - # is abstract_module -> module, removing the lm_head attribute or leaves from the abstract_module - # results in an error, I think. - # This is because, during hf_load_weights, we need to match between the hf_key and nnx_key. 
- - # for attr in ("model.lm_head"): - # if hasattr(self, attr): - # delattr(self, attr) - if getattr(self, "pooler", None) is None: self._init_pooler(vllm_config=vllm_config) @@ -70,6 +60,9 @@ def _init_pooler(self, vllm_config: VllmConfig) -> None: def as_embedding_model(cls: _T) -> _T: + """ + convert a `CausalModel` to an embedding model by adding a Pooler for embedding + """ class ModelForEmbedding(_create_pooling_model_cls(cls)): def _init_pooler(self, vllm_config: VllmConfig) -> None: @@ -85,7 +78,7 @@ def _init_pooler(self, vllm_config: VllmConfig) -> None: cls.__name__, "ForEmbedding", ) - return ModelForEmbedding # type: ignore[return-value] + return ModelForEmbedding @@ -113,4 +106,3 @@ def __init__(self, vllm_config, rng_key, mesh): raise NotImplementedError( f"Pooling initialization for {vllm_model.__class__.__name__} is not implemented." ) - From 635901a491db3108db663685efc4b1f1607b46a7 Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Sat, 15 Nov 2025 05:36:49 +0000 Subject: [PATCH 11/13] fix mean pooling --- tpu_inference/layers/jax/pool/pooler.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tpu_inference/layers/jax/pool/pooler.py b/tpu_inference/layers/jax/pool/pooler.py index c67a05886..1e1577935 100644 --- a/tpu_inference/layers/jax/pool/pooler.py +++ b/tpu_inference/layers/jax/pool/pooler.py @@ -112,13 +112,13 @@ def __call__( padded_prompt_lens = pooling_metadata.prompt_lens padded_start_indices = pooling_metadata.first_token_indices padded_end_indices = pooling_metadata.last_token_indices + cumsum = jnp.cumsum(hidden_states, axis = 0, dtype=jnp.float32) - def pool_fn(start, end, length): - seq = hidden_states[start:end + 1] - return jnp.sum(seq, axis=0) / length - - return jax.vmap(pool_fn)(padded_start_indices, padded_end_indices, - padded_prompt_lens) + return ( + cumsum[padded_end_indices] + - cumsum[padded_start_indices] + + hidden_states[padded_start_indices] + ) / padded_prompt_lens[:, None] class LastPoolingMethod(PoolingMethod): From ed6a0f6a2f84fcfc8bf2c2ca965b8871fca61ce2 Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Sat, 15 Nov 2025 06:39:48 +0000 Subject: [PATCH 12/13] add more tests --- tests/runner/test_tpu_runner.py | 96 +++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/tests/runner/test_tpu_runner.py b/tests/runner/test_tpu_runner.py index 1fbbb99a2..e07fb6581 100644 --- a/tests/runner/test_tpu_runner.py +++ b/tests/runner/test_tpu_runner.py @@ -5,7 +5,11 @@ import numpy as np from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) +from vllm.config.pooler import PoolerConfig +from tpu_inference.layers.jax.pool.pooler import Pooler +from tpu_inference.layers.jax.pool.pooling_metadata import ( + TPUSupportedPoolingMetadata, ) from tpu_inference.runner.tpu_runner import TPUModelRunner @@ -182,3 +186,95 @@ def test_is_multimodal_model(self): dummy_mm_embeds = jnp.ones((10, 128)) _ = self.runner._get_input_ids_embeds(dummy_input_ids, dummy_mm_embeds) self.runner.get_input_embeddings_fn.assert_not_called() + + +class TestTPUJaxRunnerEmbeddingModel: + + def setup_method(self): + self.mock_devices = [MagicMock(coords=i) for i in range(1)] + self.mock_rng_key = MagicMock() + device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, -1) + self.mock_mesh = jax.make_mesh(device_array.shape, + ('data', 'attn_dp', 'expert', 'model')) + + pooler_config = PoolerConfig(pooling_type="LAST", normalize=False) + pooler = 
Pooler.for_embed(pooler_config) + + + + with patch('jax.devices', return_value=self.mock_devices), \ + patch('jax.make_mesh', return_value=self.mock_mesh), \ + patch('jax.random.key', return_value=jax.random.PRNGKey(0)), \ + patch('tpu_inference.runner.tpu_runner.nnx.Rngs', return_value=self.mock_rng_key), \ + patch('tpu_inference.runner.tpu_runner.get_model', + return_value=self._model_get_model(pooler)), \ + patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', + return_value=self.mock_mesh): + + model_config = ModelConfig(tokenizer_mode="auto", + trust_remote_code=False, + convert = "embed", + runner = "pooling", + seed=0, + dtype='bfloat16') + model_config.pooler_config = pooler_config + cache_config = CacheConfig(block_size=16, + gpu_memory_utilization=0.9, + swap_space=4, + cache_dtype="auto") + scheduler_config = SchedulerConfig(max_num_seqs=4) + parallel_config = ParallelConfig(pipeline_parallel_size=1, + tensor_parallel_size=1, + worker_use_ray=False) + vllm_config = VllmConfig(model_config=model_config, + cache_config=cache_config, + scheduler_config=scheduler_config, + parallel_config=parallel_config, + speculative_config=None, + observability_config={}, + additional_config={}) + self.runner = TPUModelRunner(vllm_config, + devices=self.mock_devices) + self.runner.load_model() + + def _model_get_model(self, pooler): + class DummyEmbeddingModel: + + def __init__(self, pooler): + self.is_pooling_model = True + self.pooler = pooler + + mock_multimodal_fns = { + "precompile_vision_encoder_fn": None, + "get_multimodal_embeddings_fn": None, + "get_input_embeddings_fn": None, + "get_mrope_input_positions_fn": None + } + return ( + MagicMock(), # TPUModelRunner.model_fn + MagicMock(), # TPUModelRunner.compute_logits_fn + MagicMock(), # TPUModelRunner.combine_hidden_states_fn + mock_multimodal_fns, # TPUModelRunner.multimodal_fns + MagicMock(), # TPUModelRunner.state (model params) + None, # TPUModelRunner.lora_manager + DummyEmbeddingModel(pooler), # TPUModelRunner.model + ) + + def test_get_supported_tasks_pooling(self): + assert self.runner.is_pooling_model + assert self.runner.get_supported_tasks() == ("embed", ) + + def test_pooler_forward(self): + hidden_states = jnp.arange(6, dtype=jnp.float32).reshape(3, 2) + + metadata = TPUSupportedPoolingMetadata( + prompt_lens=jnp.array([3], dtype=jnp.int32), + first_token_indices=jnp.array([0], dtype=jnp.int32), + last_token_indices=jnp.array([2], dtype=jnp.int32), + num_scheduled_tokens=jnp.array([3], dtype=jnp.int32), + ) + mock_pooler = MagicMock(return_value=hidden_states[-1]) + self.runner.pooler = mock_pooler + outputs = self.runner.pooler(hidden_states, metadata) + np.testing.assert_array_equal(np.asarray(outputs), + np.asarray(hidden_states[-1])) From 690ce1eea843bcb8f7b46769f8d0f3261d1877fb Mon Sep 17 00:00:00 2001 From: carlesoctav Date: Sat, 15 Nov 2025 06:41:13 +0000 Subject: [PATCH 13/13] rename --- tests/e2e/{test_pooling.py => test_embeding_inference.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/e2e/{test_pooling.py => test_embeding_inference.py} (100%) diff --git a/tests/e2e/test_pooling.py b/tests/e2e/test_embeding_inference.py similarity index 100% rename from tests/e2e/test_pooling.py rename to tests/e2e/test_embeding_inference.py
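
Editor's note (not part of the patch series): a minimal standalone sanity check of the cumulative-sum mean pooling restored in the "fix mean pooling" commit. Assuming first/last token indices are inclusive and prompt_lens holds per-request token counts (as in TPUSupportedPoolingMetadata), the per-request sum over [first, last] is cumsum[last] - cumsum[first] + hidden_states[first]; the array names below are illustrative only, and only jax.numpy and numpy are assumed.

import jax.numpy as jnp
import numpy as np

# Two requests packed into one token stream: tokens 0-2 and tokens 3-5.
hidden_states = jnp.arange(12, dtype=jnp.float32).reshape(6, 2)
first_token_indices = jnp.array([0, 3])
last_token_indices = jnp.array([2, 5])
prompt_lens = jnp.array([3.0, 3.0])

# Same formula as MeanPoolingMethod.__call__ after the fix: an inclusive
# range sum recovered from an exclusive cumulative sum.
cumsum = jnp.cumsum(hidden_states, axis=0, dtype=jnp.float32)
pooled = (cumsum[last_token_indices]
          - cumsum[first_token_indices]
          + hidden_states[first_token_indices]) / prompt_lens[:, None]

# Compare against a naive per-request mean.
expected = np.stack([np.asarray(hidden_states[0:3]).mean(axis=0),
                     np.asarray(hidden_states[3:6]).mean(axis=0)])
np.testing.assert_allclose(np.asarray(pooled), expected, rtol=1e-6)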