diff --git a/tests/e2e/test_embeding_inference.py b/tests/e2e/test_embeding_inference.py new file mode 100644 index 000000000..c4a107adb --- /dev/null +++ b/tests/e2e/test_embeding_inference.py @@ -0,0 +1,107 @@ +import time +from typing import List + +import numpy as np +import pytest +import torch +import torch.nn.functional as F +from transformers import AutoModel, AutoTokenizer +from vllm import LLM + +MODEL_ID = "Qwen/Qwen3-Embedding-0.6B" +MAX_NUM_BATCHED_TOKENS = 128 +MAX_NUM_SEQS = 8 +RTOL = 5e-3 +ATOL = 5e-3 + + +def last_token_pool(last_hidden_states: torch.Tensor, + attention_mask: torch.Tensor) -> torch.Tensor: + """Reference pooling implementation from Qwen3 embedding docs.""" + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] + if left_padding: + return last_hidden_states[:, -1] + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden_states.shape[0] + return last_hidden_states[torch.arange(batch_size, + device=last_hidden_states.device), + sequence_lengths] + + +def hf_embeddings(texts: List[str], model: AutoModel, + tokenizer: AutoTokenizer) -> np.ndarray: + """Get reference embeddings using HF Transformers""" + batch_dict = tokenizer(texts, + padding=True, + truncation=True, + max_length=MAX_NUM_BATCHED_TOKENS, + return_tensors="pt") + with torch.no_grad(): + outputs = model(**batch_dict) + embeddings = last_token_pool(outputs.last_hidden_state, + batch_dict["attention_mask"]) + embeddings = F.normalize(embeddings, p=2, dim=1) + return embeddings.cpu().numpy() + + +def vllm_embeddings(texts: List[str]) -> np.ndarray: + """Get embeddings via vLLM """ + llm = LLM(model=MODEL_ID, + runner="pooling", + convert="embed", + max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, + max_num_seqs=MAX_NUM_SEQS, + max_model_len=MAX_NUM_BATCHED_TOKENS) + outputs = llm.embed(texts) + embeddings = np.asarray( + [np.array(output.outputs.embedding, dtype=np.float32) for output in outputs]) + del llm + # Wait for TPU runtime tear down before next test. + time.sleep(10) + return embeddings + + +def compare_embeddings(vllm_emb: np.ndarray, + hf_emb: np.ndarray) -> List[tuple[bool, float, float]]: + """Compare embeddings with diagnostics.""" + results = [] + for v_emb, h_emb in zip(vllm_emb, hf_emb): + is_close = np.allclose(v_emb, h_emb, rtol=RTOL, atol=ATOL) + max_diff = float(np.max(np.abs(v_emb - h_emb))) + cos_sim = float(np.dot(v_emb, h_emb) / + (np.linalg.norm(v_emb) * np.linalg.norm(h_emb))) + results.append((is_close, max_diff, cos_sim)) + return results + + +@pytest.mark.tpu +def test_last_token_embedding_pooling(monkeypatch: pytest.MonkeyPatch): + prompts = [ + "The quick brown fox jumps over the lazy dog near the river bank.", + "Machine learning systems process large datasets to extract useful information.", + "Neural networks learn hierarchical representations from raw data automatically.", + "Transformer architectures power modern language models used in production today.", + "Vector embeddings capture semantic meaning in high dimensional spaces for retrieval.", + "Artificial intelligence continues to transform industries across the global economy.", + "Gradient descent iteratively updates parameters to minimize model loss functions.", + "Attention mechanisms allow models to focus on the most relevant parts of input." 
+ ] + + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, + padding_side="left", + trust_remote_code=True) + hf_model = AutoModel.from_pretrained(MODEL_ID, + trust_remote_code=True, + torch_dtype=torch.float32) + hf_model.eval() + + with monkeypatch.context(): + vllm_embeds = vllm_embeddings(prompts) + hf_embeds = hf_embeddings(prompts, hf_model, tokenizer) + + assert vllm_embeds.shape == hf_embeds.shape == (len(prompts), hf_embeds.shape[1]) + + comparisons = compare_embeddings(vllm_embeds, hf_embeds) + for idx, (is_close, max_diff, cos_sim) in enumerate(comparisons): + assert is_close, ( + f"Embedding {idx} mismatch (max_diff={max_diff:.2e}, cos_sim={cos_sim:.6f})") diff --git a/tests/models/jax/test_adapters.py b/tests/models/jax/test_adapters.py new file mode 100644 index 000000000..7e74c0b28 --- /dev/null +++ b/tests/models/jax/test_adapters.py @@ -0,0 +1,147 @@ +from unittest.mock import MagicMock + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from flax import nnx +from flax.typing import PRNGKey +from jax.sharding import Mesh +from vllm.config import ModelConfig +from vllm.config.pooler import PoolerConfig + +from tpu_inference.layers.common.attention_metadata import AttentionMetadata +from tpu_inference.layers.jax.pool.pooler import (CLSPoolingMethod, + LastPoolingMethod, + MeanPoolingMethod) +from tpu_inference.layers.jax.pool.pooling_metadata import ( + TPUSupportedPoolingMetadata, ) +from tpu_inference.models.jax.adapters import as_embedding_model +from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM +from tpu_inference.runner.kv_cache import create_kv_caches + + +class MockVllmConfig: + + def __init__(self, model: str, pooling_type: str): + self.model_config = ModelConfig(model=model) + self.model_config.dtype = jnp.bfloat16 + self.model_config.pooler_config = PoolerConfig( + pooling_type=pooling_type, normalize=False) + self.cache_config = MagicMock(cache_dtype="auto") + self.load_config = MagicMock() + self.load_config.download_dir = None + + +@pytest.fixture(scope="module") +def mesh(): + if not jax.devices(): + pytest.skip("No JAX devices available for mesh creation.") + + devices = np.array(jax.local_devices()[:1]) + device_mesh = devices.reshape((len(devices), 1, 1, 1)) + + with Mesh(device_mesh, + axis_names=('data', 'attn_dp', 'expert', 'model')) as m: + yield m + + +@pytest.fixture +def rng() -> PRNGKey: + return jax.random.PRNGKey(0) + + +@pytest.fixture +def mock_model_inputs(): + num_tokens = 6 + num_reqs = 1 + max_num_blocks_per_req = 4 + input_ids = jnp.arange(num_tokens, dtype=jnp.int32) + positions = jnp.arange(num_tokens, dtype=jnp.int32) + block_tables = jnp.zeros((num_reqs, max_num_blocks_per_req), + dtype=jnp.int32).reshape(-1) + seq_lens = jnp.ones((num_reqs, ), dtype=jnp.int32) + query_start_loc = jnp.arange(num_reqs + 1, dtype=jnp.int32) + request_distribution = jnp.array([0, 0, 0], dtype=jnp.int32) + + attention_metadata = AttentionMetadata( + input_positions=positions, + block_tables=block_tables, + seq_lens=seq_lens, + query_start_loc=query_start_loc, + request_distribution=request_distribution, + ) + + return input_ids, attention_metadata + + +TEST_MODELS = [ + ("Qwen/Qwen3-0.6B", Qwen3ForCausalLM), +] + + +@pytest.mark.parametrize( + ("model_id", "model_cls", "pooling_type", "pooling_cls"), + [ + (model_id, model_cls, pooling_type, pooling_cls) + for model_id, model_cls in TEST_MODELS + for pooling_type, pooling_cls in [ + ("LAST", LastPoolingMethod), + ("CLS", CLSPoolingMethod), + ("MEAN", MeanPoolingMethod), + ] 
+ ], +) +def test_embedding_adapter(model_id, model_cls, pooling_type, pooling_cls, rng, + mesh, mock_model_inputs): + EmbeddingModel = as_embedding_model(model_cls) + vllm_config = MockVllmConfig(model_id, pooling_type) + model = EmbeddingModel(vllm_config, rng, mesh) + + assert isinstance(model.pooler.pooling, pooling_cls) + assert model.is_pooling_model + assert isinstance(model.pooler.head, nnx.Module) + + model.load_weights(rng) + + hf_config = vllm_config.model_config.hf_config + head_dim = 128 + kv_caches = create_kv_caches( + num_blocks=4, + block_size=32, + num_kv_heads=hf_config.num_key_value_heads, + head_size=head_dim, + mesh=mesh, + layer_names=["layer"] * hf_config.num_hidden_layers, + cache_dtype=jnp.bfloat16, + ) + + input_ids, attention_metadata = mock_model_inputs + kv_caches, hidden_states, _ = model(kv_caches, input_ids, + attention_metadata) + + num_tokens = input_ids.shape[0] + pooling_metadata = TPUSupportedPoolingMetadata( + prompt_lens=jnp.array([num_tokens], dtype=jnp.int32), + first_token_indices=jnp.array([0], dtype=jnp.int32), + last_token_indices=jnp.array([num_tokens - 1], dtype=jnp.int32), + num_scheduled_tokens=jnp.array([num_tokens], dtype=jnp.int32), + ) + + embeddings = model.pooler(hidden_states, pooling_metadata) + assert embeddings.shape == (1, hf_config.hidden_size) + + hidden_np = np.array(hidden_states, dtype=np.float32) + last_index = int(pooling_metadata.last_token_indices[0]) + first_index = int(pooling_metadata.first_token_indices[0]) + if pooling_type == "LAST": + expected = hidden_np[last_index] + elif pooling_type == "CLS": + expected = hidden_np[first_index] + else: + start = first_index + end = last_index + 1 + expected = hidden_np[start:end].mean(axis=0) + + np.testing.assert_allclose(np.array(embeddings[0]), expected, rtol=1e-5, + atol=1e-5) diff --git a/tests/runner/test_input_batch.py b/tests/runner/test_input_batch.py index d3391c4e0..5b5e41bc5 100644 --- a/tests/runner/test_input_batch.py +++ b/tests/runner/test_input_batch.py @@ -1,6 +1,7 @@ import numpy as np import pytest from vllm.sampling_params import SamplingParams +from vllm.pooling_params import PoolingParams from tpu_inference.runner.input_batch import CachedRequestState, InputBatch @@ -26,15 +27,36 @@ def input_batch(): ) +@pytest.fixture +def input_batch_for_pooling(): + return InputBatch( + max_num_reqs=MAX_NUM_REQS, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, + pin_memory=False, + vocab_size=VOCAB_SIZE, + block_sizes=BLOCK_SIZES, + is_pooling_model = True, + is_spec_decode=False, + ) + + + def create_dummy_request(req_id: str, prompt_len: int = 10, output_len: int = 5, sampling_params: SamplingParams = None, + pooling_params: PoolingParams = None, block_ids=None) -> CachedRequestState: """Helper function to create a CachedRequestState instance.""" + if sampling_params is None: sampling_params = SamplingParams(temperature=0.8, top_p=0.9, top_k=50) + if pooling_params: + sampling_params = None + + prompt_token_ids = list(range(prompt_len)) output_token_ids = list(range(prompt_len, prompt_len + output_len)) @@ -49,7 +71,7 @@ def create_dummy_request(req_id: str, prompt_token_ids=prompt_token_ids, mm_features=[], sampling_params=sampling_params, - pooling_params=None, + pooling_params=pooling_params, block_ids=block_ids, num_computed_tokens=0, lora_request=None, @@ -210,3 +232,95 @@ def test_all_greedy_property(input_batch: InputBatch): # Remove it, should be true again input_batch.random_reqs.remove("req-r") assert input_batch.all_greedy 
+ + + +def test_add_pooling_request(input_batch_for_pooling: InputBatch): + pooling_params = PoolingParams(dimensions = 768, normalize = True, use_activation = True) + req = create_dummy_request("req-1", prompt_len = 20, output_len = 4, pooling_params = pooling_params) + input_batch_for_pooling.add_request(req) + + assert input_batch_for_pooling.num_reqs == 1 + assert "req-1" in input_batch_for_pooling.req_id_to_index + assert input_batch_for_pooling.req_id_to_index["req-1"] == 0 + assert input_batch_for_pooling.req_ids == ["req-1"] + assert len(input_batch_for_pooling.spec_decode_unsupported_reqs) == 0 + + assert input_batch_for_pooling.num_prompt_tokens[0] == 20 + assert input_batch_for_pooling.num_tokens[0] == 24 + assert input_batch_for_pooling.num_tokens_no_spec[0] == 24 + expected_tokens = np.array(req.prompt_token_ids + req.output_token_ids) + np.testing.assert_array_equal(input_batch_for_pooling.token_ids_cpu[0, :24], + expected_tokens) + + assert input_batch_for_pooling.get_pooling_params() == [pooling_params] + + +def test_add_multiple_pooling_requests(input_batch_for_pooling: InputBatch): + pooling_params_1 = PoolingParams(dimensions=512, normalize=True) + pooling_params_2 = PoolingParams(dimensions=1024, normalize=False) + + req1 = create_dummy_request("req-1", + prompt_len=8, + output_len=2, + pooling_params=pooling_params_1) + req2 = create_dummy_request("req-2", + prompt_len=6, + output_len=3, + pooling_params=pooling_params_2) + + input_batch_for_pooling.add_request(req1) + input_batch_for_pooling.add_request(req2) + + assert input_batch_for_pooling.num_reqs == 2 + assert input_batch_for_pooling.req_ids == ["req-1", "req-2"] + + pooling_values = input_batch_for_pooling.get_pooling_params() + assert pooling_values[0] == pooling_params_1 + assert pooling_values[1] == pooling_params_2 + + +def test_remove_single_pooling_request(input_batch_for_pooling: InputBatch): + pooling_params_1 = PoolingParams(dimensions=256) + pooling_params_2 = PoolingParams(dimensions=768) + + req1 = create_dummy_request("req-1", pooling_params=pooling_params_1) + req2 = create_dummy_request("req-2", pooling_params=pooling_params_2) + + input_batch_for_pooling.add_request(req1) + input_batch_for_pooling.add_request(req2) + + removed_index = input_batch_for_pooling.remove_request("req-1") + assert removed_index == 0 + assert "req-1" not in input_batch_for_pooling.pooling_params + + input_batch_for_pooling.condense([removed_index]) + + assert input_batch_for_pooling.req_ids == ["req-2"] + pooling_values = input_batch_for_pooling.get_pooling_params() + assert pooling_values == [pooling_params_2] + + +def test_remove_multiple_pooling_requests(input_batch_for_pooling: InputBatch): + pooling_params = [ + PoolingParams(dimensions=128 + i * 64) for i in range(3) + ] + reqs = [ + create_dummy_request(f"req-{i}", pooling_params=pooling_params[i]) + for i in range(3) + ] + + for req in reqs: + input_batch_for_pooling.add_request(req) + + removed_indices = [] + removed_indices.append(input_batch_for_pooling.remove_request("req-0")) + removed_indices.append(input_batch_for_pooling.remove_request("req-2")) + + removed_indices = sorted( + [idx for idx in removed_indices if idx is not None], reverse=True) + input_batch_for_pooling.condense(removed_indices) + + assert input_batch_for_pooling.req_ids == ["req-1"] + pooling_values = input_batch_for_pooling.get_pooling_params() + assert pooling_values == [pooling_params[1]] diff --git a/tests/runner/test_tpu_runner.py b/tests/runner/test_tpu_runner.py index 
1fbbb99a2..e07fb6581 100644 --- a/tests/runner/test_tpu_runner.py +++ b/tests/runner/test_tpu_runner.py @@ -5,7 +5,11 @@ import numpy as np from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) +from vllm.config.pooler import PoolerConfig +from tpu_inference.layers.jax.pool.pooler import Pooler +from tpu_inference.layers.jax.pool.pooling_metadata import ( + TPUSupportedPoolingMetadata, ) from tpu_inference.runner.tpu_runner import TPUModelRunner @@ -182,3 +186,95 @@ def test_is_multimodal_model(self): dummy_mm_embeds = jnp.ones((10, 128)) _ = self.runner._get_input_ids_embeds(dummy_input_ids, dummy_mm_embeds) self.runner.get_input_embeddings_fn.assert_not_called() + + +class TestTPUJaxRunnerEmbeddingModel: + + def setup_method(self): + self.mock_devices = [MagicMock(coords=i) for i in range(1)] + self.mock_rng_key = MagicMock() + device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, -1) + self.mock_mesh = jax.make_mesh(device_array.shape, + ('data', 'attn_dp', 'expert', 'model')) + + pooler_config = PoolerConfig(pooling_type="LAST", normalize=False) + pooler = Pooler.for_embed(pooler_config) + + + + with patch('jax.devices', return_value=self.mock_devices), \ + patch('jax.make_mesh', return_value=self.mock_mesh), \ + patch('jax.random.key', return_value=jax.random.PRNGKey(0)), \ + patch('tpu_inference.runner.tpu_runner.nnx.Rngs', return_value=self.mock_rng_key), \ + patch('tpu_inference.runner.tpu_runner.get_model', + return_value=self._model_get_model(pooler)), \ + patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', + return_value=self.mock_mesh): + + model_config = ModelConfig(tokenizer_mode="auto", + trust_remote_code=False, + convert = "embed", + runner = "pooling", + seed=0, + dtype='bfloat16') + model_config.pooler_config = pooler_config + cache_config = CacheConfig(block_size=16, + gpu_memory_utilization=0.9, + swap_space=4, + cache_dtype="auto") + scheduler_config = SchedulerConfig(max_num_seqs=4) + parallel_config = ParallelConfig(pipeline_parallel_size=1, + tensor_parallel_size=1, + worker_use_ray=False) + vllm_config = VllmConfig(model_config=model_config, + cache_config=cache_config, + scheduler_config=scheduler_config, + parallel_config=parallel_config, + speculative_config=None, + observability_config={}, + additional_config={}) + self.runner = TPUModelRunner(vllm_config, + devices=self.mock_devices) + self.runner.load_model() + + def _model_get_model(self, pooler): + class DummyEmbeddingModel: + + def __init__(self, pooler): + self.is_pooling_model = True + self.pooler = pooler + + mock_multimodal_fns = { + "precompile_vision_encoder_fn": None, + "get_multimodal_embeddings_fn": None, + "get_input_embeddings_fn": None, + "get_mrope_input_positions_fn": None + } + return ( + MagicMock(), # TPUModelRunner.model_fn + MagicMock(), # TPUModelRunner.compute_logits_fn + MagicMock(), # TPUModelRunner.combine_hidden_states_fn + mock_multimodal_fns, # TPUModelRunner.multimodal_fns + MagicMock(), # TPUModelRunner.state (model params) + None, # TPUModelRunner.lora_manager + DummyEmbeddingModel(pooler), # TPUModelRunner.model + ) + + def test_get_supported_tasks_pooling(self): + assert self.runner.is_pooling_model + assert self.runner.get_supported_tasks() == ("embed", ) + + def test_pooler_forward(self): + hidden_states = jnp.arange(6, dtype=jnp.float32).reshape(3, 2) + + metadata = TPUSupportedPoolingMetadata( + prompt_lens=jnp.array([3], dtype=jnp.int32), + first_token_indices=jnp.array([0], dtype=jnp.int32), + 
last_token_indices=jnp.array([2], dtype=jnp.int32), + num_scheduled_tokens=jnp.array([3], dtype=jnp.int32), + ) + mock_pooler = MagicMock(return_value=hidden_states[-1]) + self.runner.pooler = mock_pooler + outputs = self.runner.pooler(hidden_states, metadata) + np.testing.assert_array_equal(np.asarray(outputs), + np.asarray(hidden_states[-1])) diff --git a/tpu_inference/layers/jax/pool/pooler.py b/tpu_inference/layers/jax/pool/pooler.py new file mode 100644 index 000000000..1e1577935 --- /dev/null +++ b/tpu_inference/layers/jax/pool/pooler.py @@ -0,0 +1,256 @@ +import enum +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +from flax import nnx +from tpu_inference.layers.jax.pool.pooling_metadata import TPUSupportedPoolingMetadata, is_partial_prefill + +from vllm.config.pooler import PoolerConfig + + +# [padded_num_reqs, dim] +# or [padded_num_reqs, padded_max_num_batchec_token_per_req, dim] for allpool +PoolerOutput = jax.Array + + +class PoolingType(enum.Enum): + LAST = "LAST" + MEAN = "MEAN" + CLS = "CLS" + ALL = "ALL" + + +@dataclass(frozen=True) +class ResolvedPoolingConfig: + task: str + pooling_type: PoolingType + normalize: bool + + @classmethod + def from_config( + cls, + task: str, + pooler_config: PoolerConfig | None, + ) -> "ResolvedPoolingConfig": + pooler_config = pooler_config or PoolerConfig() + + # The encode functionality is currently disabled because we cannot use DispatchPooler + # as intended. (It was part of ModelForEmbedding, and in newer versions it was renamed to token_embed.) + # This is because TPU does not support alternating requests between these two tasks, and it is + # out of scope to change the vllm request handler/API server to separate these requests. + # Therefore, this is disabled by default—users cannot use token_embed/encode functionality for now. + + if task == "embed": + default_pooling_type = PoolingType.LAST + default_normalize = True + elif task == "encode": + raise ValueError(f"Unsupported pooling task: {task}") + else: + raise ValueError(f"Unsupported pooling task: {task}") + + pooling_type_str = pooler_config.pooling_type or default_pooling_type.name + pooling_type = PoolingType(pooling_type_str.upper()) + normalize = ( + pooler_config.normalize + if pooler_config.normalize is not None + else default_normalize + ) + + return cls(task=task, pooling_type=pooling_type, normalize=normalize) + + +class PoolingMethod(nnx.Module): + """ + Base class for pooling methods. Factory method `from_pooling_type` creates + specific pooling method instances based on the provided `PoolingType`. + """ + @staticmethod + def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod": + if pooling_type is PoolingType.ALL: + raise NotImplementedError("ALL pooling is not implemented yet.") + # return AllPoolingMethod() + if pooling_type is PoolingType.MEAN: + return MeanPoolingMethod() + if pooling_type is PoolingType.LAST: + return LastPoolingMethod() + if pooling_type is PoolingType.CLS: + return CLSPoolingMethod() + raise NotImplementedError(f"Unsupported pooling type: {pooling_type}") + + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> jax.Array: + raise NotImplementedError + + +class AllPoolingMethod(PoolingMethod): + """ + Pools all token embeddings for each request. + Most of the time, this is for encoder models; hence, it requires full sequences during the prefill because of bidirectional attention. 
+ """ + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> jax.Array: + raise NotImplementedError("ALL pooling is not implemented yet.") + + +class MeanPoolingMethod(PoolingMethod): + """ + Pools by computing the mean of token embeddings for each request. + Most of the time, this is for encoder models; hence, it requires full sequences during the prefill because of bidirectional attention. + """ + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> jax.Array: + padded_prompt_lens = pooling_metadata.prompt_lens + padded_start_indices = pooling_metadata.first_token_indices + padded_end_indices = pooling_metadata.last_token_indices + cumsum = jnp.cumsum(hidden_states, axis = 0, dtype=jnp.float32) + + return ( + cumsum[padded_end_indices] + - cumsum[padded_start_indices] + + hidden_states[padded_start_indices] + ) / padded_prompt_lens[:, None] + + +class LastPoolingMethod(PoolingMethod): + """ + Pools by selecting the last token embedding for each request. + Most of the time, this is used for causal/decoder models and is also supported with partial prefill. + """ + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> jax.Array: + return hidden_states[pooling_metadata.last_token_indices] + + +class CLSPoolingMethod(PoolingMethod): + """ + Pools by selecting the [CLS] token (first token) embedding for each request. + Most of the time, this is for encoder models; hence, it requires full sequences during the prefill because of bidirectional attention. + """ + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> jax.Array: + return hidden_states[pooling_metadata.first_token_indices] + + +class PoolerHead(nnx.Module): + """ + Base class for Pooler Heads that process the pooled output. + """ + def __call__( + self, + pooled: jax.Array, + token_embeddings: jax.Array, + token_mask: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> PoolerOutput: + raise NotImplementedError + + +class EmbeddingPoolerHead(PoolerHead): + """ + Pooler head for embedding tasks, primarily responsible for normalization. + """ + def __init__(self, default_normalize: bool) -> None: + super().__init__() + self.default_normalize = default_normalize + + def __call__( + self, + pooled: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> PoolerOutput: + + # In the torch version, this part should handle other computations related to pooling_params, such as + # normalization and truncating the embedding dimensions (for matryoshka models). + # The problem with TPU is that we want a consistent output shape, and I feel like + # the best we can do is to handle this outside JIT, on the CPU. + # While you can actually do normalization with jnp.where, or maybe jax.lax.cond for branching in jax.jit, + # for the sake of simplicity, we either normalize all requests or none of them based on pooling_config. + # Pooler output: [padded_num_reqs, dim] + + + if self.default_normalize: + pooled = normalize(pooled) + + return pooled + + +class Pooler(nnx.Module): + """ + Base class for Poolers with factory methods to create Poolers for different tasks. 
+ """ + @staticmethod + def for_encode(pooler_config: PoolerConfig | None) -> "Pooler": + resolved = ResolvedPoolingConfig.from_config("encode", pooler_config) + raise NotImplementedError("EncodePooler is currently disabled.") + + @staticmethod + def for_embed(pooler_config: PoolerConfig | None) -> "Pooler": + resolved = ResolvedPoolingConfig.from_config("embed", pooler_config) + return EmbeddingPooler.from_config(resolved) + + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> PoolerOutput: + raise NotImplementedError + + def get_supported_tasks(self) -> set[str]: + raise NotImplementedError + + +class EmbeddingPooler(Pooler): + """ + Pooler for embedding tasks. + pooling: PoolingMethod instance that performs the pooling operation. + head: EmbeddingPoolerHead instance that processes the pooled output, primarily for normalization. + """ + def __init__( + self, + pooling: PoolingMethod, + head: EmbeddingPoolerHead, + ) -> None: + self.pooling = pooling + self.head = head + + @classmethod + def from_config(cls, config: ResolvedPoolingConfig) -> None: + pooling = PoolingMethod.from_pooling_type(config.pooling_type) + head = EmbeddingPoolerHead(config.normalize) + return cls(pooling, head) + + def __call__( + self, + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + ) -> PoolerOutput: + hidden_states = hidden_states.astype(jnp.float32) + # the output mus be of type torch.tensor, but we cannot convert numpy to torch if the dtype is bf16 + pooled = self.pooling(hidden_states, pooling_metadata) + return self.head(pooled, pooling_metadata) + + def get_supported_tasks(self) -> set[str]: + return ("embed",) + + +def normalize(embeddings: jax.Array) -> jax.Array: + norms = jnp.linalg.norm(embeddings, axis=-1, keepdims=True) + norms = jnp.maximum(norms, 1e-12) + normalized = embeddings / norms + return normalized diff --git a/tpu_inference/layers/jax/pool/pooling.py b/tpu_inference/layers/jax/pool/pooling.py new file mode 100644 index 000000000..8137eaeb6 --- /dev/null +++ b/tpu_inference/layers/jax/pool/pooling.py @@ -0,0 +1,17 @@ +import functools + +import jax + +from .pooler import Pooler, PoolerOutput +from .pooling_metadata import TPUSupportedPoolingMetadata + + +# actually my idea is not to jist this function but the model.pooler, +# we can put some postprocesing here. 
+@jax.jit +def pool( + hidden_states: jax.Array, + pooling_metadata: TPUSupportedPoolingMetadata, + pooler: Pooler, +) -> PoolerOutput: + return pooler(hidden_states, pooling_metadata) diff --git a/tpu_inference/layers/jax/pool/pooling_metadata.py b/tpu_inference/layers/jax/pool/pooling_metadata.py new file mode 100644 index 000000000..d07e305d9 --- /dev/null +++ b/tpu_inference/layers/jax/pool/pooling_metadata.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import functools +import logging +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh +from tpu_inference.runner.input_batch import InputBatch +from tpu_inference.utils import device_array + +logger = logging.getLogger(__name__) + + +SUPPORTED_POOLING_TASKS = {"embed"} + + +def build_pooling_cursor( + num_scheduled_tokens: list[int], + padded_num_reqs: int, +): + + n_seq = len(num_scheduled_tokens) + padded_num_scheduled_tokens = jnp.zeros(padded_num_reqs) + padded_num_scheduled_tokens = padded_num_scheduled_tokens.at[:n_seq].set( + jnp.asarray(num_scheduled_tokens, dtype=jnp.int32) + ) + cumsum = jnp.cumsum(padded_num_scheduled_tokens, dtype = jnp.int64) + first_token_indices = jnp.concatenate((jnp.asarray((0,)), cumsum[:-1])) + last_token_indices = (first_token_indices + padded_num_scheduled_tokens - 1).astype(jnp.int64) + last_token_indices = jnp.where( + padded_num_scheduled_tokens > 0, last_token_indices, first_token_indices + ) + return first_token_indices, last_token_indices, padded_num_scheduled_tokens + + +@functools.partial( + jax.tree_util.register_dataclass, + data_fields=( + "prompt_lens", + "first_token_indices", + "last_token_indices", + "num_scheduled_tokens", + ), + meta_fields = (), +) +@dataclass +class TPUSupportedPoolingMetadata: + """Device metadata required for pooling computations.""" + + prompt_lens: jax.Array + first_token_indices: jax.Array + last_token_indices: jax.Array + num_scheduled_tokens: jax.Array + + @classmethod + def from_input_batch( + cls, + mesh: Mesh, + input_batch: InputBatch, + padded_num_scheduled_tokens: list[int], + padded_num_reqs: int, + ) -> TPUSupportedPoolingMetadata: + pooling_params_list = input_batch.get_pooling_params() + + num_reqs = input_batch.num_reqs + assert len(pooling_params_list) == num_reqs + assert len(input_batch.num_prompt_tokens[:num_reqs]) == len(padded_num_scheduled_tokens) + + padded_prompt_lens= jnp.zeros(padded_num_reqs, dtype=np.int32) + padded_prompt_lens= padded_prompt_lens.at[:num_reqs].set(input_batch.num_prompt_tokens[:num_reqs]) + + first_token_indices, last_token_indices, padded_num_scheduled_tokens = build_pooling_cursor( + padded_num_scheduled_tokens, padded_num_reqs + ) + + prompt_lens, first_token_indices, last_token_indices, num_scheduled_tokens = device_array( + mesh, + (padded_prompt_lens, first_token_indices, last_token_indices, padded_num_scheduled_tokens), + ) + + #everything in pooling_metadata is padded. 
+ return cls( + prompt_lens=prompt_lens, + first_token_indices=first_token_indices, + last_token_indices=last_token_indices, + num_scheduled_tokens = num_scheduled_tokens, + ) + + +def is_partial_prefill(pooling_metadata: TPUSupportedPoolingMetadata): + return not jnp.all(pooling_metadata.prompt_lens == pooling_metadata.num_scheduled_tokens) diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py index dd4082709..c486dba8a 100644 --- a/tpu_inference/models/common/model_loader.py +++ b/tpu_inference/models/common/model_loader.py @@ -13,6 +13,7 @@ from tpu_inference import envs from tpu_inference.layers.common.sharding import ShardingAxisName from tpu_inference.logger import init_logger +from tpu_inference.models.jax.adapters import as_embedding_model from tpu_inference.models.jax.utils.quantization.quantization_utils import ( apply_qwix_on_abstract_model, apply_qwix_quantization, load_random_weights_into_qwix_abstract_model) @@ -184,6 +185,7 @@ def create_sharded_model(): return jit_model + # TODO(pooyam): We need to refactor this. This is returning a bunch of functions that do not work with all models and this is not very easy to see from the code. def get_flax_model( vllm_config: VllmConfig, @@ -191,12 +193,22 @@ def get_flax_model( mesh: Mesh, is_draft_model: bool = False, ) -> nnx.Module: - if is_draft_model: - model_class = _get_model_architecture( - vllm_config.speculative_config.draft_model_config.hf_config) - else: - model_class = _get_model_architecture( - vllm_config.model_config.hf_config) + + model_config = ( + vllm_config.speculative_config.draft_model_config + if is_draft_model + else vllm_config.model_config + ) + + model_class = _get_model_architecture(model_config.hf_config) + + convert_type = getattr(model_config, "convert_type", "none") + if convert_type == "embed": + logger.debug_once( "Converting %s to embedding model", model_class.__name__) + model_class = as_embedding_model(model_class) + elif convert_type not in ("none", "generate"): + raise NotImplementedError( f"TPU JAX backend does not support convert_type={convert_type!r} yet") + jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh) kv_cache_sharding = NamedSharding( mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, "model")) @@ -269,7 +281,7 @@ def combine_hidden_states(graphdef, state, hidden_states): run_get_multimodal_embeddings, graphdef) get_input_embeddings_fn = functools.partial(run_get_input_embeddings, graphdef) - lora_manager, model = None, None + lora_manager, _ = None, None combine_hidden_states_fn = functools.partial(combine_hidden_states, graphdef) diff --git a/tpu_inference/models/jax/adapters.py b/tpu_inference/models/jax/adapters.py new file mode 100644 index 000000000..c978f4136 --- /dev/null +++ b/tpu_inference/models/jax/adapters.py @@ -0,0 +1,108 @@ +import typing as tp + +import torch +import jax +from flax import nnx +from flax.typing import PRNGKey +from jax.sharding import Mesh + +from tpu_inference.layers.jax.pool.pooler import Pooler +from vllm.config import VllmConfig + +_T = tp.TypeVar("_T", bound=type[nnx.Module]) + +_GENERATE_SUFFIXES = ( + "ForCausalLM", + "ForConditionalGeneration", +) + +class PoolingMixin: + """ + VllmModelForPooling + The reason for creating this Mixin instead of VllmModelForPooling is due to a conflict in the metaclass between nnx.Module and VllmModelForPooling. 
+ """ + is_pooling_model: tp.ClassVar[tp.Literal[True]] = True + + default_pooling_type: tp.ClassVar[str] = "LAST" + pooler: Pooler + + +def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: + model_name = orig_model_name + for suffix in _GENERATE_SUFFIXES: + model_name = model_name.removesuffix(suffix) + return model_name + pooling_suffix + + +def _create_pooling_model_cls(orig_cls: _T) -> _T: + class ModelForPooling(orig_cls, PoolingMixin): + is_pooling_model = True + + def __init__( + self, + vllm_config: VllmConfig, + rng_key: jax.Array, + mesh: Mesh, + ) -> None: + super().__init__( + vllm_config=vllm_config, + rng_key=rng_key, + mesh=mesh, + ) + + if getattr(self, "pooler", None) is None: + self._init_pooler(vllm_config=vllm_config) + + def _init_pooler(self, vllm_config: VllmConfig) -> None: + raise NotImplementedError + + return ModelForPooling + + +def as_embedding_model(cls: _T) -> _T: + """ + convert a `CausalModel` to an embedding model by adding a Pooler for embedding + """ + + class ModelForEmbedding(_create_pooling_model_cls(cls)): + def _init_pooler(self, vllm_config: VllmConfig) -> None: + pooler_config = vllm_config.model_config.pooler_config + if pooler_config is None: + raise ValueError( + "Embedding models require `pooler_config` to be set in the model configuration." + ) + + self.pooler = Pooler.for_embed(pooler_config) + + ModelForEmbedding.__name__ = _get_pooling_model_name( + cls.__name__, + "ForEmbedding", + ) + return ModelForEmbedding + + + +def init_pooler_from_vllm_model( + vllm_model: torch.nn.Module, + vllm_config: VllmConfig, + rng_key: PRNGKey, + mesh: Mesh, +): + class DummyModule: + def __init__(self, vllm_config, rng_key, mesh): + pass + + for suffix in _GENERATE_SUFFIXES: + if suffix in vllm_model.__class__.__name__: + return None + + if "ForEmbedding" in vllm_model.__class__.__name__: + EmbedModel = as_embedding_model(DummyModule) + + embed_model = EmbedModel(vllm_config=vllm_config, rng_key=rng_key, mesh=mesh,) + embed_model._init_pooler(vllm_config) + return embed_model.pooler + else: + raise NotImplementedError( + f"Pooling initialization for {vllm_model.__class__.__name__} is not implemented." 
+ ) diff --git a/tpu_inference/models/jax/utils/weight_utils.py b/tpu_inference/models/jax/utils/weight_utils.py index 64f026dae..36f6287f1 100644 --- a/tpu_inference/models/jax/utils/weight_utils.py +++ b/tpu_inference/models/jax/utils/weight_utils.py @@ -316,6 +316,11 @@ def _load_hf_weights_on_thread(vllm_config, if hf_key.endswith(".weight"): hf_key = hf_key.removesuffix(".weight") + base_hf_key = hf_key + + if not hf_key.startswith("model."): + hf_key = f"model.{hf_key}" + # Find the corresponding model key using the HF key if "layers" in hf_key: layer_num = re.search(r"layers\.(\d+)", hf_key).group(1) @@ -328,7 +333,7 @@ def _load_hf_weights_on_thread(vllm_config, model_key = name_map[layer_key] model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key) else: - if hf_key not in name_map and hf_key == "lm_head": + if hf_key not in name_map and base_hf_key == "lm_head": logger.warning( f"Skip loading {hf_key} due to tie_word_embeddings") continue diff --git a/tpu_inference/models/vllm/vllm_model_wrapper.py b/tpu_inference/models/vllm/vllm_model_wrapper.py index 79c059b72..e429de77e 100644 --- a/tpu_inference/models/vllm/vllm_model_wrapper.py +++ b/tpu_inference/models/vllm/vllm_model_wrapper.py @@ -30,6 +30,8 @@ from tpu_inference.models.vllm.vllm_model_wrapper_context import ( get_vllm_model_wrapper_context, set_vllm_model_wrapper_context) from tpu_inference.runner.lora_utils import replace_lora_metadata +from tpu_inference.layers.jax.pool.pooler import Pooler +from tpu_inference.models.jax.adapters import init_pooler_from_vllm_model logger = init_logger(__name__) @@ -72,6 +74,7 @@ class VllmModelWrapper: rng: PRNGKey mesh: Mesh model: _VllmRunner + pooler: Pooler def __init__(self, vllm_config: VllmConfig, rng: PRNGKey, mesh: Mesh): self.vllm_config = vllm_config @@ -137,6 +140,10 @@ def load_weights(self): self.model = _VllmRunner(vllm_model) params_and_buffers = shard_model_to_tpu(self.model, self.mesh) + + # TODO: separate params_and_buffers for the pooler (some poolers are not stateless). + self.pooler = init_pooler_from_vllm_model(vllm_model, self.vllm_config, self.rng, self.mesh) + # Returning to the jax land, so we need to wrap it into a JaxValue. 
return jax_view(params_and_buffers), lora_manager diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 86c55adc1..293e55d39 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -11,11 +11,17 @@ from tpu_inference.core.disagg_utils import is_disagg_enabled from tpu_inference.layers.common.attention_metadata import AttentionMetadata from tpu_inference.layers.common.sharding import ShardingAxisName +from tpu_inference.layers.jax.pool.pooling import pool +from tpu_inference.layers.jax.pool.pooling_metadata import ( + TPUSupportedPoolingMetadata, +) from tpu_inference.layers.jax.sample.sampling import sample -from tpu_inference.layers.jax.sample.sampling_metadata import \ - TPUSupportedSamplingMetadata +from tpu_inference.layers.jax.sample.sampling_metadata import ( + TPUSupportedSamplingMetadata, +) from tpu_inference.logger import init_logger from tpu_inference.utils import device_array +from torchax.ops.mappings import t2j_dtype if TYPE_CHECKING: from tpu_inference.runner.tpu_runner import TPUModelRunner @@ -79,6 +85,9 @@ def capture_model(self) -> None: self._run_compilation, ) self._precompile_input_embeddings_merger() self._precompile_backbone_with_inputs_embeds() + if self.runner.is_pooling_model: + self._precompile_pooling() + return if self.runner.scheduler_config.async_scheduling: self._precompile_substitute_placeholder_token() self._precompile_select_from_array() @@ -90,6 +99,52 @@ def capture_model(self) -> None: if self.runner.speculative_config: self._precompile_speculative_decoding() + def _precompile_pooling(self) -> None: + pooler = getattr(self.runner, "pooler", None) + if pooler is None: + logger.warning( + "Pooling precompile skipped because model has no pooler attribute.") + return + + logger.info("Precompile pooling kernels for pooling models.") + + hidden_size = self.runner.model_config.get_hidden_size() + dtype = self.runner.model_config.dtype + hidden_sharding = NamedSharding( + self.runner.mesh, PartitionSpec(None, None)) + + for num_tokens in self.runner.num_tokens_paddings: + hidden_states = self._create_dummy_tensor( + (num_tokens, hidden_size), jnp.bfloat16, sharding=hidden_sharding) + + for num_reqs in self.runner.num_reqs_paddings: + if num_reqs == 0 or num_reqs > num_tokens: + continue + + # can we just use (one array here) + prompt_lens = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) + first_token_indices = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) + last_token_indices = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) + num_scheduled_tokens = self._create_dummy_tensor(num_reqs, dtype = jnp.int32) + + + pooling_metadata = TPUSupportedPoolingMetadata( + prompt_lens=prompt_lens, + first_token_indices=first_token_indices, + last_token_indices=last_token_indices, + num_scheduled_tokens = num_scheduled_tokens, + ) + + self._run_compilation( + "pool", + pool, + hidden_states, + pooling_metadata, + pooler, + num_tokens=num_tokens, + num_reqs=num_reqs, + ) + def _precompile_input_embeddings_merger(self) -> None: for num_tokens in self.runner.num_tokens_paddings: hidden_size = self.runner.vllm_config.model_config.get_hidden_size( diff --git a/tpu_inference/runner/input_batch.py b/tpu_inference/runner/input_batch.py index 79f28ebb0..5d5e41a8d 100644 --- a/tpu_inference/runner/input_batch.py +++ b/tpu_inference/runner/input_batch.py @@ -8,6 +8,7 @@ import jax.numpy as jnp import numpy as np from vllm.lora.request import LoRARequest 
+from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.utils.collection_utils import swap_dict_values from vllm.v1.core.sched.output import NewRequestData @@ -52,9 +53,11 @@ def __init__( pin_memory: bool, vocab_size: int, block_sizes: list[int], + is_pooling_model: bool = False, is_spec_decode: bool = False, ): self.is_spec_decode = is_spec_decode + self.is_pooling_model = is_pooling_model self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_batched_tokens = max_num_batched_tokens @@ -131,6 +134,7 @@ def __init__( self.req_output_token_ids: list[Optional[list[int]]] = [] self.request_distribution: list[int] = [0, 0, 0] + self.pooling_params: dict[str, PoolingParams] = {} @property def req_ids(self) -> list[str]: @@ -175,71 +179,73 @@ def add_request( self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens self.block_table.add_row(request.block_ids, req_index) - sampling_params = request.sampling_params - - if (self.is_spec_decode - and is_spec_decode_unsupported(sampling_params)): - self.spec_decode_unsupported_reqs.add(req_id) - - if sampling_params.sampling_type == SamplingType.GREEDY: - # Avoid later division by zero. - self.temperature_cpu[req_index] = -1.0 - self.greedy_reqs.add(req_id) - else: - self.temperature_cpu[req_index] = sampling_params.temperature - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - top_k = sampling_params.top_k - if top_k <= 0 or top_k >= self.vocab_size: - top_k = 1 - self.top_k_cpu[req_index] = top_k - if sampling_params.min_tokens: - self.min_tokens[req_index] = (sampling_params.min_tokens, - sampling_params.all_stop_token_ids) - - # NOTE(woosuk): self.generators should not include the requests that - # do not have their own generator. - if request.generator is not None: - self.generators[req_index] = request.generator - - if sampling_params.logprobs is not None: - self.num_logprobs[req_id] = sampling_params.logprobs - if sampling_params.logit_bias is not None: - self.logit_bias[req_index] = sampling_params.logit_bias - - if sampling_params.allowed_token_ids: - self.has_allowed_token_ids.add(req_id) - if self.allowed_token_ids_mask_cpu is None: - # Lazy allocation for this tensor, which can be large. + if sampling_params := request.sampling_params: + if (self.is_spec_decode + and is_spec_decode_unsupported(sampling_params)): + self.spec_decode_unsupported_reqs.add(req_id) + + if sampling_params.sampling_type == SamplingType.GREEDY: + # Avoid later division by zero. + self.temperature_cpu[req_index] = -1.0 + self.greedy_reqs.add(req_id) + else: + self.temperature_cpu[req_index] = sampling_params.temperature + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + top_k = sampling_params.top_k + if top_k <= 0 or top_k >= self.vocab_size: + top_k = 1 + self.top_k_cpu[req_index] = top_k + if sampling_params.min_tokens: + self.min_tokens[req_index] = (sampling_params.min_tokens, + sampling_params.all_stop_token_ids) + + + # NOTE(woosuk): self.generators should not include the requests that + # do not have their own generator. 
+ if request.generator is not None: + self.generators[req_index] = request.generator + + if sampling_params.logprobs is not None: + self.num_logprobs[req_id] = sampling_params.logprobs + if sampling_params.logit_bias is not None: + self.logit_bias[req_index] = sampling_params.logit_bias + + if sampling_params.allowed_token_ids: + self.has_allowed_token_ids.add(req_id) + if self.allowed_token_ids_mask_cpu is None: + # Lazy allocation for this tensor, which can be large. + # False means we don't fill with -inf. + self.allowed_token_ids_mask = jnp.zeros(self.max_num_reqs, + self.vocab_size, + dtype=jnp.bool) + self.allowed_token_ids_mask_cpu = np.zeros(self.max_num_reqs, + self.vocab_size, + dtype=np.bool) + self.allowed_token_ids_mask_cpu[req_index] = True # False means we don't fill with -inf. - self.allowed_token_ids_mask = jnp.zeros(self.max_num_reqs, - self.vocab_size, - dtype=jnp.bool) - self.allowed_token_ids_mask_cpu = np.zeros(self.max_num_reqs, - self.vocab_size, - dtype=np.bool) - self.allowed_token_ids_mask_cpu[req_index] = True - # False means we don't fill with -inf. - self.allowed_token_ids_mask_cpu[req_index][ - sampling_params.allowed_token_ids] = False - - if sampling_params.bad_words_token_ids: - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids - - # Add request lora ID - if request.lora_request: - lora_id = request.lora_request.lora_int_id - if lora_id not in self.lora_id_to_request_ids: - self.lora_id_to_request_ids[lora_id] = set() - - self.request_lora_mapping[req_index] = lora_id - self.lora_id_to_request_ids[lora_id].add(request.req_id) - self.lora_id_to_lora_request[lora_id] = request.lora_request - else: - # No LoRA - self.request_lora_mapping[req_index] = 0 + self.allowed_token_ids_mask_cpu[req_index][ + sampling_params.allowed_token_ids] = False + + if sampling_params.bad_words_token_ids: + self.bad_words_token_ids[ + req_index] = sampling_params.bad_words_token_ids + + # Add request lora ID + if request.lora_request: + lora_id = request.lora_request.lora_int_id + if lora_id not in self.lora_id_to_request_ids: + self.lora_id_to_request_ids[lora_id] = set() + + self.request_lora_mapping[req_index] = lora_id + self.lora_id_to_request_ids[lora_id].add(request.req_id) + self.lora_id_to_lora_request[lora_id] = request.lora_request + else: + # No LoRA + self.request_lora_mapping[req_index] = 0 + elif pooling_params := request.pooling_params: + self.pooling_params[req_id] = pooling_params def remove_request(self, req_id: str) -> Optional[int]: """This method must always be followed by a call to condense().""" @@ -256,6 +262,9 @@ def remove_request(self, req_id: str) -> Optional[int]: self.min_tokens.pop(req_index, None) self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) + if self.is_pooling_model: + self.pooling_params.pop(req_id, None) + return req_index # LoRA lora_id = self.request_lora_mapping[req_index] @@ -411,6 +420,10 @@ def all_greedy(self) -> bool: def max_num_logprobs(self) -> Optional[int]: return max(self.num_logprobs.values()) if self.num_logprobs else None + def get_pooling_params(self) -> list[PoolingParams]: + assert len(self.pooling_params) == len(self.req_ids) + return [self.pooling_params[req_id] for req_id in self.req_ids] + def make_lora_inputs( self, num_scheduled_tokens: np.ndarray ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: diff --git a/tpu_inference/runner/persistent_batch_manager.py b/tpu_inference/runner/persistent_batch_manager.py index c8c093315..68f68b60d 100644 --- 
a/tpu_inference/runner/persistent_batch_manager.py +++ b/tpu_inference/runner/persistent_batch_manager.py @@ -122,7 +122,7 @@ def update_states(self, scheduler_output: "VllmSchedulerOutput", prompt_token_ids=new_req_data.prompt_token_ids, mm_features=new_req_data.mm_features, sampling_params=sampling_params, - pooling_params=None, + pooling_params=new_req_data.pooling_params, generator=None, block_ids=new_req_data.block_ids, num_computed_tokens=new_req_data.num_computed_tokens, diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index 09e563928..a526a7e6f 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -46,6 +46,11 @@ gather_logprobs, sample) from tpu_inference.layers.jax.sample.sampling_metadata import \ TPUSupportedSamplingMetadata +from tpu_inference.layers.jax.pool.pooling import pool +from tpu_inference.layers.jax.pool.pooling_metadata import ( + SUPPORTED_POOLING_TASKS, + TPUSupportedPoolingMetadata, +) from tpu_inference.logger import init_logger from tpu_inference.models.common.model_loader import get_model from tpu_inference.models.jax.utils.weight_utils import ( @@ -230,6 +235,9 @@ def __init__( ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext() self.dp_size = self.vllm_config.sharding_config.total_dp_size + self.is_pooling_model = self.model_config.runner_type == "pooling" + self.pooler = None + self._init_random() self._init_mesh() self._init_phased_profiling() @@ -248,6 +256,7 @@ def __init__( self.uses_mrope, self.model_config) self.lora_utils = LoraUtils(self) + cache_config = self.cache_config if cache_config.cache_dtype == "auto": model_dtype = self.dtype @@ -418,6 +427,7 @@ def _init_inputs(self) -> None: pin_memory=False, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.block_size], + is_pooling_model = self.is_pooling_model, is_spec_decode=bool(self.vllm_config.speculative_config), ) @@ -490,6 +500,10 @@ def load_model(self): ) multimodal_fns = multimodal_fns or {} + + if self.is_pooling_model: + self.pooler = self.model.pooler + self.precompile_vision_encoder_fn = multimodal_fns.get( "precompile_vision_encoder_fn", None) self.get_multimodal_embeddings_fn = multimodal_fns.get( @@ -513,6 +527,8 @@ def load_model(self): f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB") def get_supported_tasks(self) -> tuple[SupportedTask, ...]: + if self.is_pooling_model: + return self.pooler.get_supported_tasks() return ("generate", ) def get_kv_cache_spec(self): @@ -689,11 +705,13 @@ def _execute_model( input_ids, attn_metadata, _, + pooling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, ) = self._prepare_inputs(scheduler_output) + # multi-modal support if self.is_multimodal_model: # Run the multimodal encoder if any. 
@@ -737,6 +755,35 @@ def _execute_model( lora_metadata, ) + if self.is_pooling_model: + assert pooling_metadata is not None + num_reqs = self.input_batch.num_reqs + req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) + + raw_pooler_output = pool(hidden_states, pooling_metadata, self.pooler) + raw_pooler_output = np.asarray(jax.device_get(raw_pooler_output))[:num_reqs] + prompt_lens = np.asarray( + jax.device_get(pooling_metadata.prompt_lens) + )[:num_reqs] + seq_lens_cpu = self.seq_lens_cpu[:num_reqs] + + + pooler_output = [] + for raw_output, seq_len, prompt_len in zip(raw_pooler_output, seq_lens_cpu, prompt_lens): + # Requests whose prefill is not yet complete have no pooled output. + output = torch.from_numpy(raw_output) if seq_len == prompt_len else None + pooler_output.append(output) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=pooler_output, + kv_connector_output=kv_connector_output, + ) + return attn_metadata, model_runner_output + hidden_states = self._select_from_array_fn(hidden_states, logits_indices) logits = self.compute_logits_fn( @@ -1399,6 +1446,7 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"): assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 + pooling_metadata = None # Get the number of scheduled tokens for each request. num_scheduled_tokens_per_req = [] @@ -1436,6 +1484,14 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"): arange = np.concatenate( [self.arange_cpu[:n] for n in num_scheduled_tokens_per_req]) + if self.is_pooling_model: + pooling_metadata = TPUSupportedPoolingMetadata.from_input_batch( + self.mesh, + self.input_batch, + num_scheduled_tokens_per_req, + padded_num_reqs, + ) + # Get positions. positions_np = self.positions_cpu[:total_num_scheduled_tokens] np.add(self.input_batch.num_computed_tokens_cpu[req_indices], @@ -1557,8 +1613,16 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"): attention_metadata.query_start_loc_cpu = query_start_loc_cpu attention_metadata.seq_lens_cpu = seq_lens_cpu logits_indices_selector = None - return (input_ids, attention_metadata, sampling_metadata, - logits_indices, spec_decode_metadata, logits_indices_selector) + + return ( + input_ids, + attention_metadata, + sampling_metadata, + pooling_metadata, + logits_indices, + spec_decode_metadata, + logits_indices_selector + ) def _get_input_ids_embeds(self, input_ids: jax.Array, mm_embeds: list[jax.Array]):
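
Reviewer note (not part of the patch): MeanPoolingMethod above derives each request's mean from a single cumulative sum over the packed token axis, relying on the identity sum(h[first:last+1]) == cumsum[last] - cumsum[first] + h[first]. The standalone sketch below, with made-up lengths and indices, checks that identity and the LAST/CLS gathers against a naive per-request loop:

    import jax.numpy as jnp
    import numpy as np

    # Two requests packed back to back: lengths 3 and 2, hidden size 4.
    prompt_lens = jnp.array([3, 2], dtype=jnp.int32)
    first = jnp.array([0, 3], dtype=jnp.int32)  # first_token_indices
    last = jnp.array([2, 4], dtype=jnp.int32)   # last_token_indices
    hidden = jnp.arange(5 * 4, dtype=jnp.float32).reshape(5, 4)

    # Mean pooling via one cumulative sum, mirroring MeanPoolingMethod.__call__.
    cumsum = jnp.cumsum(hidden, axis=0, dtype=jnp.float32)
    mean_pooled = (cumsum[last] - cumsum[first] + hidden[first]) / prompt_lens[:, None]

    # Naive reference: slice each request out of the packed batch and average it.
    expected = jnp.stack([hidden[s:e + 1].mean(axis=0)
                          for s, e in zip(first.tolist(), last.tolist())])
    np.testing.assert_allclose(np.asarray(mean_pooled), np.asarray(expected), rtol=1e-6)

    # LAST and CLS pooling are plain gathers on the same index vectors.
    last_pooled = hidden[last]    # shape (2, 4)
    cls_pooled = hidden[first]    # shape (2, 4)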