From bfab8cc3e2bd71072d8816a7d3cf934ea9ade65e Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Mon, 27 Oct 2025 18:24:05 +0000 Subject: [PATCH 01/18] Save an initial version. Not capable of running at all. Signed-off-by: Xiongfei Wei --- tests/lora/conftest.py | 32 ++++ tests/lora/test_layers.py | 385 ++++++++++++++++++++++++++++++++++++++ tests/lora/utils.py | 96 ++++++++++ 3 files changed, 513 insertions(+) create mode 100644 tests/lora/conftest.py create mode 100644 tests/lora/test_layers.py create mode 100644 tests/lora/utils.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py new file mode 100644 index 000000000..d573070de --- /dev/null +++ b/tests/lora/conftest.py @@ -0,0 +1,32 @@ +import tempfile + +import pytest +from vllm.config import set_current_vllm_config +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.engine.arg_utils import EngineArgs + + +@pytest.fixture +def dist_init(): + engine_args = EngineArgs( + model="Qwen/Qwen2-1.5B-Instruct", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + + vllm_config = engine_args.create_engine_config() + + with set_current_vllm_config(vllm_config): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + 1, + 0, + local_rank=0, + distributed_init_method=f"file://{temp_file}", + backend="gloo") + ensure_model_parallel_initialized(1, 1) + yield vllm_config + cleanup_dist_env_and_memory(shutdown_ray=True) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py new file mode 100644 index 000000000..a40afb532 --- /dev/null +++ b/tests/lora/test_layers.py @@ -0,0 +1,385 @@ +import random +from typing import Optional + +import jax +import pytest +import torch +import torchax +from jax.sharding import NamedSharding, PartitionSpec +from torchax.interop import torch_view +from torchax.ops.mappings import j2t, t2j +# from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu +from vllm.config import LoRAConfig +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, + MergedColumnParallelLinearWithLoRA) +# yapf: enable +from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.punica_wrapper import get_punica_wrapper +from vllm.model_executor.layers.linear import MergedColumnParallelLinear +from vllm.model_executor.utils import set_random_seed +from vllm.platforms import current_platform + +from .utils import DummyLoRAManager + +# TODO(xiowei): +# - add test for multi-chip. +# - add equivalent test for ColumnParallelLinearWithShardedLoRA. + +P = PartitionSpec + +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.float32: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + +pytestmark = pytest.mark.skipif(not current_platform.is_tpu(), + reason="This test is only for TPU platform.") + +# prefill stage(True) or decode stage(False) +STAGES = [True, False] + + +def check_punica_wrapper(punica_wrapper) -> bool: + from tpu_inference.lora.torch_punica_tpu import PunicaWrapperTPU + return type(punica_wrapper) is PunicaWrapperTPU + + +def get_random_index_to_id(num_loras: int, + num_slots: int, + log: bool = True) -> list[Optional[int]]: + """Creates a random index_to_lora_id mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. 
+ log: Whether to log the output. + + returns: + index_to_lora_id: a random index_to_lora_id mapping. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " + "num_loras must be less than or equal to num_slots.") + + slots: list[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + # xw32: It seems the slot_idx start at 1. + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + + +def populate_loras( + index_to_id: list[Optional[int]], + lora_layer: BaseLayerWithLoRA, + baselayer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, +) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]: + """This method populates the lora layers (BaseLayerWithLoRA) with lora weights. + + Args: + index_to_id: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + lora_layer: the LoRAlayer to populate. + baselayer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + + returns: + lora_dict: a dictionary dict[int, LoRALayerWeights] that maps the lora ID to the corresponding lora weights. + sublora_dict: a dictionary dict[int, list[LoRALayerWeights]] that maps the lora ID to the corresponding lora weights. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: dict[int, LoRALayerWeights] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. + sublora_dict: dict[int, list[LoRALayerWeights]] = dict() + + for slot_idx, lora_id in enumerate(index_to_id): + if lora_id is not None: + subloras: list[LoRALayerWeights] = [] + sublora_len = baselayer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager( + baselayer_weights.device).init_random_lora( + module_name=f"fake_{i}", + weight=baselayer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[(sublora_len * + i):(sublora_len * (i + 1)), :] + sublora.optimize() + subloras.append(sublora) + + lora = PackedLoRALayerWeights.pack( + subloras) if repeats > 1 else subloras[0] + + # Some of the layer.lora is torchax tensor so it can only do math (slice op) in the torchax env. + with torchax.default_env(), jax.default_device( + jax.devices("tpu")[0]): + lora_layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + + +def create_random_inputs( + active_lora_ids: list[int], + num_inputs: int, + input_size: tuple[int, ...], + input_range: tuple[float, float], + input_type: torch.dtype = torch.int, + device: torch.device = "cpu", +) -> tuple[list[torch.Tensor], list[int], list[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. Or the number of requests. + input_size: the size of each individual input. Or the number of tokens. 
+ input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. + + returns: + inputs: a list of torch tensors of size num_inputs. Each input has shape `input_size`. + index_mapping: maps each input token to a lora ID. + prompt_mapping: maps each request to a lora ID. + """ + + low, high = input_range + + inputs: list[torch.Tensor] = [] + index_mapping: list[int] = [] + prompt_mapping: list[int] = [] + + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device=device)) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device=device) * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) +@pytest.mark.parametrize("repeats", [2]) +@pytest.mark.parametrize("fully_shard", [False]) # TODO(xiowei): add "True". +@pytest.mark.parametrize("device", ["cpu"]) +@pytest.mark.parametrize("stage", [True, False]) +def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, + device, stage) -> None: + max_loras = 9 + max_lora_rank = 8 + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=max_lora_rank, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) + + axis_names = ("data", "model") + mesh_shape = ( + 1, 1 + ) # TODO(xiowei): support multi-chip: mesh_shape = (1, len(jax.devices())) + mesh = jax.make_mesh(mesh_shape, axis_names, devices=jax.devices()) + + def create_column_parallel_packed_layer(): + # Step 1: create a base layer (e.g. MergedColumnParallelLinear) and a vLLM LoRA wrapper. + if repeats == 2: + linear = MergedColumnParallelLinear( + 256, # input_size + [256] * repeats, # output_size + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = MergedColumnParallelLinearWithLoRA( + linear + ) # TODO(xiowei): add test for MergedColumnParallelLinearWithShardedLoRA (fully_shard == True) + elif repeats == 3: + # TODO(xiowei): add test for this case. + raise NotImplementedError("NYI: for MergedQKVParallelLinear case") + else: + # TODO(xiowei): add test for this case. + raise NotImplementedError("NYI: for QKVParallelLinear case") + + n_slices = repeats + # create_lora_weights creates global shape weight. + lora_linear.create_lora_weights(max_loras, lora_config) + assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( + lora_linear.lora_b_stacked) == n_slices) + + return linear, lora_linear + + set_random_seed(6) + + linear, lora_linear = create_column_parallel_packed_layer() + # linear.weight has type torch.nn.Parameter, lora_linear.weight has type torchax.tensor.Tensor + # BaseLinearLayerWithLoRA.weight property guarantees this. + assert torch.equal(linear.weight, lora_linear.weight) + with torchax.default_env(): + assert torch.equal(linear.weight.data, j2t(lora_linear.weight)) + + max_num_batched_tokens = 8192 + max_batches = 256 + punica_wrapper = get_punica_wrapper(max_num_batched_tokens, + max_batches, + device, + max_loras=max_loras) + assert check_punica_wrapper(punica_wrapper) + lora_linear.set_mapping(punica_wrapper) + + # load the lora weight, shard it, and send it to TPU. + # create a lora slot index to lora id mapping. 
+ index_to_id = get_random_index_to_id(num_loras, max_loras) + # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights + lora_dict, sublora_dict = populate_loras( + index_to_id, + lora_layer=lora_linear, + baselayer_weights=linear.weight, + repeats=repeats, + ) + + # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 256]. + # index_mapping: list[int] + # prompt_mapping: list[int] + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32, + input_size=(1, 256), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + punica_wrapper.update_metadata( + lora_mapping, + index_to_id, + max_loras, + vocab_size=512, + extra_vocab_size=lora_config.lora_extra_vocab_size, + ) + punica_wrapper.move_to_device(mesh) + + jax_inputs = [] + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + for input in inputs: + # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` + # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` + jax_input = torch_view(t2j(input)) + jax_input.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + jax_inputs.append(jax_input) + with torchax.default_env(): + lora_result = lora_linear(torch.cat(jax_inputs))[0] + lora_result = j2t(lora_result) + + expected_results: list[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + # linear(input_) returns (output, output_bias) so we only need the first one. + result = linear(input_)[0] + subloras = sublora_dict[lora_id] + for i, sublora in enumerate(subloras): + result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * + (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * + sublora.scaling) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result))}' + ) + + # Check that resetting the lora weights succeeds + # Here we set all lora weight to be empty. 
+ for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], # different from the above create_random_inputs + num_inputs=32, + input_size=(1, 256), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + punica_wrapper.update_metadata( + lora_mapping, + index_to_id, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + punica_wrapper.move_to_device(mesh) + + jax_inputs = [] + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + for input in inputs: + jax_input = torch_view(t2j(input)) + jax_input.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + jax_inputs.append(jax_input) + with torchax.default_env(): + lora_result = lora_linear(torch.cat(jax_inputs))[0] + lora_result = j2t(lora_result) + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result))}' + ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py new file mode 100644 index 000000000..41c5cf38d --- /dev/null +++ b/tests/lora/utils.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights + + +# https://github.com/vllm-project/vllm/blob/279a5f31b3faa6f40759516efa5c742f637ab8b7/tests/lora/utils.py +class DummyLoRAManager: + + def __init__(self, device: torch.device = "cuda:0"): + super().__init__() + self._loras: dict[str, LoRALayerWeights] = {} + self._device = device + + def set_module_lora(self, module_name: str, lora: LoRALayerWeights): + self._loras[module_name] = lora + + def get_module_lora(self, module_name: str) -> LoRALayerWeights: + return self._loras[module_name] + + def init_random_lora( + self, + module_name: str, + weight: torch.Tensor, + rank: int = 8, + generate_embeddings_tensor: int = 0, + ): + lora = LoRALayerWeights( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([rank, weight.shape[1]], + dtype=weight.dtype, + device=self._device), + lora_b=torch.rand([weight.shape[0], rank], + dtype=weight.dtype, + device=self._device), + ) + if generate_embeddings_tensor: + lora.embeddings_tensor = torch.rand( + 5, + generate_embeddings_tensor, + dtype=weight.dtype, + device=self._device, + ) + self.set_module_lora(module_name, lora) + + return lora + + def init_lora( + self, + module_name: str, + input_dim: int, + output_dim: int, + rank=8, + noop=False, + embeddings_tensor=None, + ): + lora = LoRALayerWeights( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([rank, input_dim], device="cuda"), + lora_b=torch.rand([output_dim, input_dim], device="cuda"), + embeddings_tensor=embeddings_tensor, + ) + self.set_module_lora(module_name, lora) + return lora + + def reset_lora(self): + self._loras = {} + + def init_packed_lora( + self, + module_name: str, + input_dim: int, + output_dims: list[int], + noop_lora_index: list[int] | None = None, + rank: int = 8, + ): + base_loras: list[LoRALayerWeights] = [] + noop_lora_index_set = 
set(noop_lora_index or []) + + for i, out_dim in enumerate(output_dims): + base_lora = self.init_lora( + module_name + "_000_" + str(i), + input_dim, + out_dim, + rank=rank, + noop=i in noop_lora_index_set, + ) + base_loras.append(base_lora) + packed_lora = PackedLoRALayerWeights.pack(base_loras) + self.set_module_lora(module_name, packed_lora) + return packed_lora From 492138682fb9e4e23f2c01d1d73a24cf20a9d4ea Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 29 Oct 2025 17:00:40 +0000 Subject: [PATCH 02/18] now the test can run to completion. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 83 ++++++++++++++------- tpu_inference/runner/compilation_manager.py | 1 + 2 files changed, 55 insertions(+), 29 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index a40afb532..970666dd4 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -8,6 +8,7 @@ from jax.sharding import NamedSharding, PartitionSpec from torchax.interop import torch_view from torchax.ops.mappings import j2t, t2j +from torchax.interop import jax_view, torch_view # from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu from vllm.config import LoRAConfig # yapf conflicts with isort for this block @@ -22,6 +23,11 @@ from vllm.platforms import current_platform from .utils import DummyLoRAManager +from tpu_inference.layers.vllm.quantization.common import ( + JaxCommonConfig, JaxCommonLinearConfig) +from tpu_inference.layers.vllm.quantization.unquantized import \ + VllmUnquantizedLinearMethod +from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora # TODO(xiowei): # - add test for multi-chip. @@ -224,15 +230,31 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, def create_column_parallel_packed_layer(): # Step 1: create a base layer (e.g. MergedColumnParallelLinear) and a vLLM LoRA wrapper. if repeats == 2: + # In e2e, MergedColumnParallelLinear is created when we load th e model. The weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. linear = MergedColumnParallelLinear( 256, # input_size [256] * repeats, # output_size bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) + + base_linear = MergedColumnParallelLinear( + 256, # input_size + [256] * repeats, # output_size + bias=False, + params_dtype=torch.float16) + base_linear.weight.data = linear.weight.data + vllm_config = dist_init + # self.jax_config.mesh.devices[0][0].platform + jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) + linear_method = VllmUnquantizedLinearMethod(jax_config) + linear_method.process_weights_after_loading(base_linear) + # here base_linear.weight is on TPU and sharded. + + # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. lora_linear = MergedColumnParallelLinearWithLoRA( - linear - ) # TODO(xiowei): add test for MergedColumnParallelLinearWithShardedLoRA (fully_shard == True) + base_linear + ) elif repeats == 3: # TODO(xiowei): add test for this case. raise NotImplementedError("NYI: for MergedQKVParallelLinear case") @@ -241,21 +263,22 @@ def create_column_parallel_packed_layer(): raise NotImplementedError("NYI: for QKVParallelLinear case") n_slices = repeats - # create_lora_weights creates global shape weight. - lora_linear.create_lora_weights(max_loras, lora_config) + with torchax.default_env(): + # create_lora_weights creates global shape weight. 
+ lora_linear.create_lora_weights(max_loras, lora_config) + _shard_merged_column_parallel_linear_lora(lora_linear, mesh) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) - return linear, lora_linear + return linear, lora_linear, linear_method set_random_seed(6) - linear, lora_linear = create_column_parallel_packed_layer() - # linear.weight has type torch.nn.Parameter, lora_linear.weight has type torchax.tensor.Tensor - # BaseLinearLayerWithLoRA.weight property guarantees this. - assert torch.equal(linear.weight, lora_linear.weight) + linear, lora_linear, linear_method = create_column_parallel_packed_layer() with torchax.default_env(): - assert torch.equal(linear.weight.data, j2t(lora_linear.weight)) + # linear.weight has type torch.nn.Parameter, lora_linear.weight has type torchax.tensor.Tensor + # BaseLinearLayerWithLoRA.weight property guarantees this. + assert torch.equal(linear.weight, lora_linear.weight.to('cpu')) max_num_batched_tokens = 8192 max_batches = 256 @@ -297,7 +320,7 @@ def create_column_parallel_packed_layer(): vocab_size=512, extra_vocab_size=lora_config.lora_extra_vocab_size, ) - punica_wrapper.move_to_device(mesh) + # punica_wrapper.move_to_device(mesh) jax_inputs = [] with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): @@ -305,12 +328,12 @@ def create_column_parallel_packed_layer(): # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` jax_input = torch_view(t2j(input)) - jax_input.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) + jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) jax_inputs.append(jax_input) with torchax.default_env(): - lora_result = lora_linear(torch.cat(jax_inputs))[0] - lora_result = j2t(lora_result) + # lora_result = lora_linear(torch.cat(jax_inputs))[0] + # lora_result = j2t(lora_result) + lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): @@ -318,23 +341,24 @@ def create_column_parallel_packed_layer(): result = linear(input_)[0] subloras = sublora_dict[lora_id] for i, sublora in enumerate(subloras): - result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * - (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * + result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] * + (i + 1)] += (input_ @ sublora.lora_a.T @ sublora.lora_b.T * sublora.scaling) expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - print( - f'Output max diff: {torch.max(torch.abs(expected_result - lora_result))}' - ) - print( - f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result))}' - ) + with torchax.default_env(): + torch.testing.assert_close(lora_result.to('cpu'), + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result.to('cpu') - lora_result))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result.to('cpu') - lora_result))}' + ) # Check that resetting the lora weights succeeds # Here we set all lora weight to be empty. 
@@ -368,8 +392,9 @@ def create_column_parallel_packed_layer(): NamedSharding(mesh, P(None, None))) jax_inputs.append(jax_input) with torchax.default_env(): - lora_result = lora_linear(torch.cat(jax_inputs))[0] - lora_result = j2t(lora_result) + lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) + # lora_result = lora_linear(torch.cat(jax_inputs))[0] + # lora_result = j2t(lora_result) expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 4f62afa1e..98828c379 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -7,6 +7,7 @@ import numpy as np import vllm.envs as envs from jax.sharding import NamedSharding, PartitionSpec +from vllm.utils.math_utils import cdiv from tpu_inference.core.disagg_utils import is_disagg_enabled from tpu_inference.layers.common.attention_metadata import AttentionMetadata From 8044fb7b41f0c3a5c3fcd5996cc6a9fa013ae070 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 29 Oct 2025 17:11:08 +0000 Subject: [PATCH 03/18] ok. The case without lora passed. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 45 ++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 970666dd4..192a5464a 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -348,17 +348,17 @@ def create_column_parallel_packed_layer(): expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - with torchax.default_env(): - torch.testing.assert_close(lora_result.to('cpu'), - expected_result, - rtol=rtol, - atol=atol) - print( - f'Output max diff: {torch.max(torch.abs(expected_result.to('cpu') - lora_result))}' - ) - print( - f'Output mean diff: {torch.mean(torch.abs(expected_result.to('cpu') - lora_result))}' - ) + # with torchax.default_env(): + # torch.testing.assert_close(lora_result.to('cpu'), + # expected_result, + # rtol=rtol, + # atol=atol) + # print( + # f'Output max diff: {torch.max(torch.abs(expected_result.to('cpu') - lora_result))}' + # ) + # print( + # f'Output mean diff: {torch.mean(torch.abs(expected_result.to('cpu') - lora_result))}' + # ) # Check that resetting the lora weights succeeds # Here we set all lora weight to be empty. 
@@ -382,7 +382,6 @@ def create_column_parallel_packed_layer(): 512, lora_config.lora_extra_vocab_size, ) - punica_wrapper.move_to_device(mesh) jax_inputs = [] with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): @@ -398,13 +397,15 @@ def create_column_parallel_packed_layer(): expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - print( - f'Output max diff: {torch.max(torch.abs(expected_result - lora_result))}' - ) - print( - f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result))}' - ) + with torchax.default_env(): + lora_result_cpu = lora_result.to('cpu') + torch.testing.assert_close(lora_result_cpu, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}' + ) From 11c7ea281bc5f3a03f32d6b883d381fc30db15e4 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 29 Oct 2025 17:46:35 +0000 Subject: [PATCH 04/18] ok, the test passed. Need to make it simpler next. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 38 +++++++++++++++----------- tpu_inference/lora/torch_punica_tpu.py | 2 ++ 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 192a5464a..7967fb336 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -248,6 +248,7 @@ def create_column_parallel_packed_layer(): # self.jax_config.mesh.devices[0][0].platform jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) linear_method = VllmUnquantizedLinearMethod(jax_config) + base_linear.quant_method=linear_method linear_method.process_weights_after_loading(base_linear) # here base_linear.weight is on TPU and sharded. @@ -263,6 +264,8 @@ def create_column_parallel_packed_layer(): raise NotImplementedError("NYI: for QKVParallelLinear case") n_slices = repeats + #TODO(xw): check if we can enable torchax globally. + # TODO(xw): check if we can calculate both actual and expected output using torchax. with torchax.default_env(): # create_lora_weights creates global shape weight. 
lora_linear.create_lora_weights(max_loras, lora_config) @@ -282,10 +285,11 @@ def create_column_parallel_packed_layer(): max_num_batched_tokens = 8192 max_batches = 256 - punica_wrapper = get_punica_wrapper(max_num_batched_tokens, - max_batches, - device, - max_loras=max_loras) + with torchax.default_env(): + punica_wrapper = get_punica_wrapper(max_num_batched_tokens, + max_batches, + 'jax', + max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) lora_linear.set_mapping(punica_wrapper) @@ -333,7 +337,8 @@ def create_column_parallel_packed_layer(): with torchax.default_env(): # lora_result = lora_linear(torch.cat(jax_inputs))[0] # lora_result = j2t(lora_result) - lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) + # lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) + lora_result = lora_linear(torch.cat(jax_inputs))[0] expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): @@ -348,17 +353,18 @@ def create_column_parallel_packed_layer(): expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - # with torchax.default_env(): - # torch.testing.assert_close(lora_result.to('cpu'), - # expected_result, - # rtol=rtol, - # atol=atol) - # print( - # f'Output max diff: {torch.max(torch.abs(expected_result.to('cpu') - lora_result))}' - # ) - # print( - # f'Output mean diff: {torch.mean(torch.abs(expected_result.to('cpu') - lora_result))}' - # ) + with torchax.default_env(): + lora_result_cpu = lora_result.to('cpu') + torch.testing.assert_close(lora_result_cpu, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}' + ) # Check that resetting the lora weights succeeds # Here we set all lora weight to be empty. diff --git a/tpu_inference/lora/torch_punica_tpu.py b/tpu_inference/lora/torch_punica_tpu.py index 7c2da5361..658e0f521 100644 --- a/tpu_inference/lora/torch_punica_tpu.py +++ b/tpu_inference/lora/torch_punica_tpu.py @@ -23,6 +23,8 @@ class PunicaWrapperTPU(PunicaWrapperBase): PunicaWrapperTPU is designed to manage and provide metadata for the punica kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the pytorch punica ops. + + It is created by get_punica_wrapper when we load_lora_model->create_lora_manager. Device is TPU. """ def __init__(self, max_num_batched_tokens: int, max_batches: int, From 2eae008de95fdc156fa46ebedeba05ac6948090d Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 29 Oct 2025 20:37:48 +0000 Subject: [PATCH 05/18] also check if the correct and the sharding is correct. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 7967fb336..9ecadc22d 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -251,6 +251,8 @@ def create_column_parallel_packed_layer(): base_linear.quant_method=linear_method linear_method.process_weights_after_loading(base_linear) # here base_linear.weight is on TPU and sharded. + assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.' 
+ assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.' # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. lora_linear = MergedColumnParallelLinearWithLoRA( @@ -270,6 +272,13 @@ def create_column_parallel_packed_layer(): # create_lora_weights creates global shape weight. lora_linear.create_lora_weights(max_loras, lora_config) _shard_merged_column_parallel_linear_lora(lora_linear, mesh) + # TODO: assert the lora_a_stacked is on TPU and sharded. + assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.' + assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.' + assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.' + assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.' + + # TODO: assert the lora_b_stacked is on TPU and sharded. assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) @@ -324,7 +333,8 @@ def create_column_parallel_packed_layer(): vocab_size=512, extra_vocab_size=lora_config.lora_extra_vocab_size, ) - # punica_wrapper.move_to_device(mesh) + assert jax_view(punica_wrapper._lora_indices_per_batch).platform() == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' + assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' jax_inputs = [] with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): From 2d8741805a7d9df8ddaec5f60c62e13c6425bd89 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 29 Oct 2025 20:58:09 +0000 Subject: [PATCH 06/18] cleaned up Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 72 ++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 43 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 9ecadc22d..62f3ce202 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -1,7 +1,9 @@ import random from typing import Optional +import copy import jax +import numpy as np import pytest import torch import torchax @@ -27,11 +29,10 @@ JaxCommonConfig, JaxCommonLinearConfig) from tpu_inference.layers.vllm.quantization.unquantized import \ VllmUnquantizedLinearMethod -from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora +from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora, _shard_module_to_tpu # TODO(xiowei): # - add test for multi-chip. -# - add equivalent test for ColumnParallelLinearWithShardedLoRA. P = PartitionSpec @@ -56,7 +57,7 @@ def check_punica_wrapper(punica_wrapper) -> bool: def get_random_index_to_id(num_loras: int, num_slots: int, log: bool = True) -> list[Optional[int]]: - """Creates a random index_to_lora_id mapping. + """Creates a random index_to_lora_id mapping: slot[index] = lora_id. Args: num_loras: The number of active loras in the mapping. 
@@ -76,7 +77,7 @@ def get_random_index_to_id(num_loras: int, slots: list[Optional[int]] = [None] * num_slots random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() for lora_id, slot_idx in enumerate(random_slot_selections, start=1): - # xw32: It seems the slot_idx start at 1. + # The slot_idx start at 1. slots[slot_idx] = lora_id if log: @@ -92,7 +93,7 @@ def populate_loras( generate_embeddings_tensor: int = 0, repeats: int = 1, ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]: - """This method populates the lora layers (BaseLayerWithLoRA) with lora weights. + """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA). Args: index_to_id: a list of lora ids. The index of the lora id @@ -140,8 +141,7 @@ def populate_loras( subloras) if repeats > 1 else subloras[0] # Some of the layer.lora is torchax tensor so it can only do math (slice op) in the torchax env. - with torchax.default_env(), jax.default_device( - jax.devices("tpu")[0]): + with torchax.default_env(): lora_layer.set_lora( slot_idx, lora_a=lora.lora_a, @@ -207,11 +207,10 @@ def create_random_inputs( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) @pytest.mark.parametrize("repeats", [2]) -@pytest.mark.parametrize("fully_shard", [False]) # TODO(xiowei): add "True". -@pytest.mark.parametrize("device", ["cpu"]) +@pytest.mark.parametrize("fully_shard", [False]) @pytest.mark.parametrize("stage", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage) -> None: + stage) -> None: max_loras = 9 max_lora_rank = 8 lora_config = LoRAConfig( @@ -228,9 +227,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, mesh = jax.make_mesh(mesh_shape, axis_names, devices=jax.devices()) def create_column_parallel_packed_layer(): - # Step 1: create a base layer (e.g. MergedColumnParallelLinear) and a vLLM LoRA wrapper. + # We first create a base linear layer, then a lora layer to wrap it. if repeats == 2: - # In e2e, MergedColumnParallelLinear is created when we load th e model. The weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. + # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. linear = MergedColumnParallelLinear( 256, # input_size [256] * repeats, # output_size @@ -245,50 +244,42 @@ def create_column_parallel_packed_layer(): params_dtype=torch.float16) base_linear.weight.data = linear.weight.data vllm_config = dist_init - # self.jax_config.mesh.devices[0][0].platform jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) linear_method = VllmUnquantizedLinearMethod(jax_config) base_linear.quant_method=linear_method - linear_method.process_weights_after_loading(base_linear) - # here base_linear.weight is on TPU and sharded. + linear_method.process_weights_after_loading(base_linear) # here base_linear.weight is moved to TPU and sharded. assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.' assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.' - # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. 
lora_linear = MergedColumnParallelLinearWithLoRA( base_linear ) elif repeats == 3: - # TODO(xiowei): add test for this case. raise NotImplementedError("NYI: for MergedQKVParallelLinear case") else: - # TODO(xiowei): add test for this case. raise NotImplementedError("NYI: for QKVParallelLinear case") - n_slices = repeats - #TODO(xw): check if we can enable torchax globally. - # TODO(xw): check if we can calculate both actual and expected output using torchax. with torchax.default_env(): - # create_lora_weights creates global shape weight. + # create_lora_weights creates global shape lora weight. lora_linear.create_lora_weights(max_loras, lora_config) - _shard_merged_column_parallel_linear_lora(lora_linear, mesh) - # TODO: assert the lora_a_stacked is on TPU and sharded. + # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. + _shard_module_to_tpu(lora_linear, mesh) + assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.' assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.' assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.' assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.' - - # TODO: assert the lora_b_stacked is on TPU and sharded. + n_slices = repeats assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) - return linear, lora_linear, linear_method + return linear, lora_linear set_random_seed(6) - linear, lora_linear, linear_method = create_column_parallel_packed_layer() + linear, lora_linear = create_column_parallel_packed_layer() with torchax.default_env(): - # linear.weight has type torch.nn.Parameter, lora_linear.weight has type torchax.tensor.Tensor + # lora_linear.weight has type torchax.tensor.Tensor # BaseLinearLayerWithLoRA.weight property guarantees this. assert torch.equal(linear.weight, lora_linear.weight.to('cpu')) @@ -302,8 +293,7 @@ def create_column_parallel_packed_layer(): assert check_punica_wrapper(punica_wrapper) lora_linear.set_mapping(punica_wrapper) - # load the lora weight, shard it, and send it to TPU. - # create a lora slot index to lora id mapping. + # Populate lora matrices (lora_a and lora_b) in the lora layer. index_to_id = get_random_index_to_id(num_loras, max_loras) # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights lora_dict, sublora_dict = populate_loras( @@ -322,10 +312,11 @@ def create_column_parallel_packed_layer(): input_size=(1, 256), input_range=(0, 1), input_type=torch.float16, - device=device) + device='cpu') lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + with torchax.default_env(): + # Here we move the metadata from cpu to tpu. punica_wrapper.update_metadata( lora_mapping, index_to_id, @@ -337,7 +328,7 @@ def create_column_parallel_packed_layer(): assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' 
jax_inputs = [] - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + with torchax.default_env(): for input in inputs: # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` @@ -345,9 +336,6 @@ def create_column_parallel_packed_layer(): jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) jax_inputs.append(jax_input) with torchax.default_env(): - # lora_result = lora_linear(torch.cat(jax_inputs))[0] - # lora_result = j2t(lora_result) - # lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) lora_result = lora_linear(torch.cat(jax_inputs))[0] expected_results: list[torch.Tensor] = [] @@ -387,10 +375,10 @@ def create_column_parallel_packed_layer(): input_size=(1, 256), input_range=(0, 1), input_type=torch.float16, - device=device) + device='cpu') lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + with torchax.default_env(): punica_wrapper.update_metadata( lora_mapping, index_to_id, @@ -400,16 +388,14 @@ def create_column_parallel_packed_layer(): ) jax_inputs = [] - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + with torchax.default_env(): for input in inputs: jax_input = torch_view(t2j(input)) jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) jax_inputs.append(jax_input) with torchax.default_env(): - lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) - # lora_result = lora_linear(torch.cat(jax_inputs))[0] - # lora_result = j2t(lora_result) + lora_result = lora_linear(torch.cat(jax_inputs))[0] expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] From c480658fe2c1c2070f57dc025e89b387d3928780 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 30 Oct 2025 22:19:47 +0000 Subject: [PATCH 07/18] ok, fixed the torchax.view.item() issue. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 14 +++++++------- tpu_inference/lora/torch_punica_tpu.py | 4 +++- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 62f3ce202..eb0c0ce75 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -231,15 +231,15 @@ def create_column_parallel_packed_layer(): if repeats == 2: # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. linear = MergedColumnParallelLinear( - 256, # input_size - [256] * repeats, # output_size + 64, # input_size + [64] * repeats, # output_size bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) base_linear = MergedColumnParallelLinear( - 256, # input_size - [256] * repeats, # output_size + 64, # input_size + [64] * repeats, # output_size bias=False, params_dtype=torch.float16) base_linear.weight.data = linear.weight.data @@ -303,13 +303,13 @@ def create_column_parallel_packed_layer(): repeats=repeats, ) - # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 256]. + # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 64]. 
# index_mapping: list[int] # prompt_mapping: list[int] inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=list(lora_dict.keys()), num_inputs=32, - input_size=(1, 256), + input_size=(1, 64), input_range=(0, 1), input_type=torch.float16, device='cpu') @@ -372,7 +372,7 @@ def create_column_parallel_packed_layer(): inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=[0], # different from the above create_random_inputs num_inputs=32, - input_size=(1, 256), + input_size=(1, 64), input_range=(0, 1), input_type=torch.float16, device='cpu') diff --git a/tpu_inference/lora/torch_punica_tpu.py b/tpu_inference/lora/torch_punica_tpu.py index 658e0f521..ca8554b43 100644 --- a/tpu_inference/lora/torch_punica_tpu.py +++ b/tpu_inference/lora/torch_punica_tpu.py @@ -8,6 +8,8 @@ import torch.nn.functional as F import torchax from vllm.lora.punica_wrapper.utils import convert_mapping +from torchax.interop import jax_view, torch_view + if TYPE_CHECKING: # avoid circuit import @@ -283,7 +285,7 @@ def _update_prefill_metadata(self, self.batch_size = 1 self._lora_indices_per_batch[:self. batch_size] = token_lora_tensor[:self. - batch_size] + batch_size].torch() def _pad_prompt_mapping( self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: From cc47babdf3c1b15933ea77750af817e09863e30e Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 30 Oct 2025 23:18:36 +0000 Subject: [PATCH 08/18] add multi-chip test case Signed-off-by: Xiongfei Wei --- .buildkite/pipeline_jax.yml | 13 +++++++++++++ tests/lora/test_layers.py | 12 +++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/.buildkite/pipeline_jax.yml b/.buildkite/pipeline_jax.yml index f997cf950..52cac4e7a 100644 --- a/.buildkite/pipeline_jax.yml +++ b/.buildkite/pipeline_jax.yml @@ -157,6 +157,7 @@ steps: queue: tpu_v6e_queue commands: - | +<<<<<<< HEAD if [[ "$$NIGHTLY" == "1" ]]; then .buildkite/scripts/run_in_docker.sh \ bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ @@ -165,6 +166,12 @@ steps: echo "Skipping: NIGHTLY environment variable not set" exit 0 fi +======= + .buildkite/scripts/run_in_docker.sh \ + bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ + python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py && \ + python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' +>>>>>>> c17bacea (add multi-chip test case) - label: "E2E MLPerf tests for JAX + vLLM models on multiple chips" key: test_11 @@ -212,6 +219,7 @@ steps: queue: tpu_v6e_8_queue commands: - | +<<<<<<< HEAD if [[ "$$NIGHTLY" == "1" ]]; then .buildkite/scripts/run_in_docker.sh \ bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py' @@ -219,6 +227,11 @@ steps: echo "Skipping: NIGHTLY environment variable not set" exit 0 fi +======= + .buildkite/scripts/run_in_docker.sh \ + bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ + python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' +>>>>>>> c17bacea (add multi-chip test case) # ----------------------------------------------------------------- diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index eb0c0ce75..a6c65331d 100644 --- a/tests/lora/test_layers.py +++ 
b/tests/lora/test_layers.py @@ -221,10 +221,13 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ) axis_names = ("data", "model") + devices = jax.devices() mesh_shape = ( - 1, 1 + 1, len(devices) + # 1, 1 ) # TODO(xiowei): support multi-chip: mesh_shape = (1, len(jax.devices())) - mesh = jax.make_mesh(mesh_shape, axis_names, devices=jax.devices()) + print(f'xw32 mesh_shape: {mesh_shape}') + mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices) def create_column_parallel_packed_layer(): # We first create a base linear layer, then a lora layer to wrap it. @@ -281,7 +284,10 @@ def create_column_parallel_packed_layer(): with torchax.default_env(): # lora_linear.weight has type torchax.tensor.Tensor # BaseLinearLayerWithLoRA.weight property guarantees this. - assert torch.equal(linear.weight, lora_linear.weight.to('cpu')) + # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix. + # So the below check will fail. + if len(devices) == 1: + assert torch.equal(linear.weight.data, lora_linear.weight.to('cpu')) max_num_batched_tokens = 8192 max_batches = 256 From c2082ffdc5d377cd9ddacab4013a401fb0f86f0e Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 30 Oct 2025 23:24:07 +0000 Subject: [PATCH 09/18] fix the format Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 61 +++++++++++++-------- tpu_inference/lora/torch_punica_tpu.py | 5 +- tpu_inference/runner/compilation_manager.py | 1 - 3 files changed, 39 insertions(+), 28 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index a6c65331d..e6a33775b 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -1,16 +1,13 @@ import random from typing import Optional -import copy import jax -import numpy as np import pytest import torch import torchax from jax.sharding import NamedSharding, PartitionSpec -from torchax.interop import torch_view -from torchax.ops.mappings import j2t, t2j from torchax.interop import jax_view, torch_view +from torchax.ops.mappings import t2j # from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu from vllm.config import LoRAConfig # yapf conflicts with isort for this block @@ -24,12 +21,12 @@ from vllm.model_executor.utils import set_random_seed from vllm.platforms import current_platform -from .utils import DummyLoRAManager -from tpu_inference.layers.vllm.quantization.common import ( - JaxCommonConfig, JaxCommonLinearConfig) +from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig from tpu_inference.layers.vllm.quantization.unquantized import \ VllmUnquantizedLinearMethod -from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora, _shard_module_to_tpu +from tpu_inference.layers.vllm.sharding import _shard_module_to_tpu + +from .utils import DummyLoRAManager # TODO(xiowei): # - add test for multi-chip. 
@@ -239,7 +236,7 @@ def create_column_parallel_packed_layer(): bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - + base_linear = MergedColumnParallelLinear( 64, # input_size [64] * repeats, # output_size @@ -249,14 +246,18 @@ def create_column_parallel_packed_layer(): vllm_config = dist_init jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) linear_method = VllmUnquantizedLinearMethod(jax_config) - base_linear.quant_method=linear_method - linear_method.process_weights_after_loading(base_linear) # here base_linear.weight is moved to TPU and sharded. - assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.' - assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.' - - lora_linear = MergedColumnParallelLinearWithLoRA( + base_linear.quant_method = linear_method + linear_method.process_weights_after_loading( base_linear - ) + ) # here base_linear.weight is moved to TPU and sharded. + assert jax_view(base_linear.weight).platform( + ) == 'tpu', 'base_linear.weight should have been moved to TPU.' + assert not isinstance( + jax_view(base_linear.weight).sharding, + jax.sharding.SingleDeviceSharding + ), 'base_linear.weight should have been sharded.' + + lora_linear = MergedColumnParallelLinearWithLoRA(base_linear) elif repeats == 3: raise NotImplementedError("NYI: for MergedQKVParallelLinear case") else: @@ -268,10 +269,16 @@ def create_column_parallel_packed_layer(): # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. _shard_module_to_tpu(lora_linear, mesh) - assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.' - assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.' - assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.' - assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.' + assert jax_view(lora_linear.lora_a_stacked[0]).platform( + ) == 'tpu', 'lora_a_stacked should have been moved to TPU.' + assert not isinstance( + jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding. + SingleDeviceSharding), 'lora_a_stacked should have been sharded.' + assert jax_view(lora_linear.lora_b_stacked[0]).platform( + ) == 'tpu', 'lora_b_stacked should have been moved to TPU.' + assert not isinstance( + jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding. + SingleDeviceSharding), 'lora_b_stacked should have been sharded.' n_slices = repeats assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) @@ -287,7 +294,8 @@ def create_column_parallel_packed_layer(): # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix. # So the below check will fail. 
if len(devices) == 1: - assert torch.equal(linear.weight.data, lora_linear.weight.to('cpu')) + assert torch.equal(linear.weight.data, + lora_linear.weight.to('cpu')) max_num_batched_tokens = 8192 max_batches = 256 @@ -330,8 +338,12 @@ def create_column_parallel_packed_layer(): vocab_size=512, extra_vocab_size=lora_config.lora_extra_vocab_size, ) - assert jax_view(punica_wrapper._lora_indices_per_batch).platform() == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' - assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' + assert jax_view(punica_wrapper._lora_indices_per_batch).platform( + ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' + assert isinstance( + jax_view(punica_wrapper._lora_indices_per_batch).sharding, + jax.sharding.SingleDeviceSharding + ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' jax_inputs = [] with torchax.default_env(): @@ -339,7 +351,8 @@ def create_column_parallel_packed_layer(): # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` jax_input = torch_view(t2j(input)) - jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) + jax_input.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) jax_inputs.append(jax_input) with torchax.default_env(): lora_result = lora_linear(torch.cat(jax_inputs))[0] diff --git a/tpu_inference/lora/torch_punica_tpu.py b/tpu_inference/lora/torch_punica_tpu.py index ca8554b43..a3cc5bf3d 100644 --- a/tpu_inference/lora/torch_punica_tpu.py +++ b/tpu_inference/lora/torch_punica_tpu.py @@ -8,8 +8,6 @@ import torch.nn.functional as F import torchax from vllm.lora.punica_wrapper.utils import convert_mapping -from torchax.interop import jax_view, torch_view - if TYPE_CHECKING: # avoid circuit import @@ -285,7 +283,8 @@ def _update_prefill_metadata(self, self.batch_size = 1 self._lora_indices_per_batch[:self. batch_size] = token_lora_tensor[:self. 
- batch_size].torch() + batch_size].torch( + ) def _pad_prompt_mapping( self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 98828c379..4f62afa1e 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -7,7 +7,6 @@ import numpy as np import vllm.envs as envs from jax.sharding import NamedSharding, PartitionSpec -from vllm.utils.math_utils import cdiv from tpu_inference.core.disagg_utils import is_disagg_enabled from tpu_inference.layers.common.attention_metadata import AttentionMetadata From d409a357893a543651f8b78a78b5721e164e003a Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 31 Oct 2025 18:36:21 +0000 Subject: [PATCH 10/18] clean up Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index e6a33775b..abd072b3f 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -8,7 +8,6 @@ from jax.sharding import NamedSharding, PartitionSpec from torchax.interop import jax_view, torch_view from torchax.ops.mappings import t2j -# from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu from vllm.config import LoRAConfig # yapf conflicts with isort for this block # yapf: disable @@ -28,9 +27,6 @@ from .utils import DummyLoRAManager -# TODO(xiowei): -# - add test for multi-chip. - P = PartitionSpec TOLERANCES = { @@ -204,26 +200,20 @@ def create_random_inputs( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) @pytest.mark.parametrize("repeats", [2]) -@pytest.mark.parametrize("fully_shard", [False]) @pytest.mark.parametrize("stage", [True, False]) -def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - stage) -> None: +def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: max_loras = 9 max_lora_rank = 8 lora_config = LoRAConfig( max_loras=max_loras, max_lora_rank=max_lora_rank, - fully_sharded_loras=fully_shard, + fully_sharded_loras=False, lora_dtype=torch.float16, ) axis_names = ("data", "model") devices = jax.devices() - mesh_shape = ( - 1, len(devices) - # 1, 1 - ) # TODO(xiowei): support multi-chip: mesh_shape = (1, len(jax.devices())) - print(f'xw32 mesh_shape: {mesh_shape}') + mesh_shape = (1, len(devices)) mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices) def create_column_parallel_packed_layer(): @@ -264,7 +254,6 @@ def create_column_parallel_packed_layer(): raise NotImplementedError("NYI: for QKVParallelLinear case") with torchax.default_env(): - # create_lora_weights creates global shape lora weight. lora_linear.create_lora_weights(max_loras, lora_config) # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. 
_shard_module_to_tpu(lora_linear, mesh) From 7ea2939a3207e3a4ad8c6b7d7103ba0a6636f208 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 5 Nov 2025 21:18:51 +0000 Subject: [PATCH 11/18] Add lora unit tests to the CI Signed-off-by: Xiongfei Wei --- .buildkite/pipeline_jax.yml | 47 +++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/.buildkite/pipeline_jax.yml b/.buildkite/pipeline_jax.yml index 52cac4e7a..44f69df9c 100644 --- a/.buildkite/pipeline_jax.yml +++ b/.buildkite/pipeline_jax.yml @@ -150,28 +150,20 @@ steps: exit 0 fi - - label: "lora tests for JAX + vLLM models single chip" + - label: "lora e2e tests for JAX + vLLM models single chip" key: test_10 soft_fail: true agents: queue: tpu_v6e_queue commands: - | -<<<<<<< HEAD if [[ "$$NIGHTLY" == "1" ]]; then .buildkite/scripts/run_in_docker.sh \ - bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ - python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py' + bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py' else echo "Skipping: NIGHTLY environment variable not set" exit 0 fi -======= - .buildkite/scripts/run_in_docker.sh \ - bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ - python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py && \ - python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' ->>>>>>> c17bacea (add multi-chip test case) - label: "E2E MLPerf tests for JAX + vLLM models on multiple chips" key: test_11 @@ -209,7 +201,7 @@ steps: exit 0 fi - - label: "lora tests for JAX + vLLM models multi chips" + - label: "lora e2e tests for JAX + vLLM models multi chips" key: test_13 soft_fail: true env: @@ -219,7 +211,6 @@ steps: queue: tpu_v6e_8_queue commands: - | -<<<<<<< HEAD if [[ "$$NIGHTLY" == "1" ]]; then .buildkite/scripts/run_in_docker.sh \ bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py' @@ -227,11 +218,31 @@ steps: echo "Skipping: NIGHTLY environment variable not set" exit 0 fi -======= + + - label: "lora unit tests on single chip" + key: test_14 + soft_fail: true + agents: + queue: tpu_v6e_queue + commands: + - | .buildkite/scripts/run_in_docker.sh \ - bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ - python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' ->>>>>>> c17bacea (add multi-chip test case) + bash -c ' python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py && \ + python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' + + - label: "lora unit tests on multi chips" + key: test_15 + soft_fail: true + env: + USE_V6E8_QUEUE: "True" + VLLM_LOG_LEVEL: "INFO" + agents: + queue: tpu_v6e_8_queue + commands: + - | + .buildkite/scripts/run_in_docker.sh \ + bash -c 'python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' + # ----------------------------------------------------------------- @@ -253,9 +264,11 @@ steps: - test_11 - test_12 - test_13 + - test_14 + - test_15 agents: queue: cpu commands: - | .buildkite/scripts/check_results.sh \ - "TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 
test_9 test_10 test_11 test_12 test_13 + "TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13 test_14 test_15 From ceaf1b2ff8fef16530b8600d0ce54c3c9ff96911 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 5 Nov 2025 23:00:01 +0000 Subject: [PATCH 12/18] refactored. The test still passed. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 77 ++++++++++++++++++++++++++------------- 1 file changed, 51 insertions(+), 26 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index abd072b3f..a50ae9b21 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -197,6 +197,24 @@ def create_random_inputs( return inputs, index_mapping, prompt_mapping +def _create_linear_and_lora_wrapper(linear, base_linear, lora_cls, vllm_config, + mesh): + base_linear.weight.data = linear.weight.data + jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) + linear_method = VllmUnquantizedLinearMethod(jax_config) + base_linear.quant_method = linear_method + linear_method.process_weights_after_loading( + base_linear) # here base_linear.weight is moved to TPU and sharded. + assert jax_view(base_linear.weight).platform( + ) == 'tpu', 'base_linear.weight should have been moved to TPU.' + assert not isinstance( + jax_view(base_linear.weight).sharding, jax.sharding. + SingleDeviceSharding), 'base_linear.weight should have been sharded.' + + lora_linear = lora_cls(base_linear) + return lora_linear + + @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) @pytest.mark.parametrize("repeats", [2]) @@ -219,37 +237,44 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: def create_column_parallel_packed_layer(): # We first create a base linear layer, then a lora layer to wrap it. if repeats == 2: + + def _create_merged_column_linear(): + return MergedColumnParallelLinear( + 64, # input_size + [64] * repeats, # output_size + bias=False, + params_dtype=torch.float16) + # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. - linear = MergedColumnParallelLinear( - 64, # input_size - [64] * repeats, # output_size - bias=False, - params_dtype=torch.float16) + linear = _create_merged_column_linear() linear.weight.data = torch.rand_like(linear.weight.data) - base_linear = MergedColumnParallelLinear( - 64, # input_size - [64] * repeats, # output_size - bias=False, - params_dtype=torch.float16) - base_linear.weight.data = linear.weight.data - vllm_config = dist_init - jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) - linear_method = VllmUnquantizedLinearMethod(jax_config) - base_linear.quant_method = linear_method - linear_method.process_weights_after_loading( - base_linear - ) # here base_linear.weight is moved to TPU and sharded. - assert jax_view(base_linear.weight).platform( - ) == 'tpu', 'base_linear.weight should have been moved to TPU.' - assert not isinstance( - jax_view(base_linear.weight).sharding, - jax.sharding.SingleDeviceSharding - ), 'base_linear.weight should have been sharded.' 
- - lora_linear = MergedColumnParallelLinearWithLoRA(base_linear) + base_linear = _create_merged_column_linear() + # base_linear.weight.data = linear.weight.data + # vllm_config = dist_init + # jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) + # linear_method = VllmUnquantizedLinearMethod(jax_config) + # base_linear.quant_method = linear_method + # linear_method.process_weights_after_loading( + # base_linear + # ) # here base_linear.weight is moved to TPU and sharded. + # assert jax_view(base_linear.weight).platform( + # ) == 'tpu', 'base_linear.weight should have been moved to TPU.' + # assert not isinstance( + # jax_view(base_linear.weight).sharding, + # jax.sharding.SingleDeviceSharding + # ), 'base_linear.weight should have been sharded.' + + # lora_linear = MergedColumnParallelLinearWithLoRA(base_linear) + lora_linear = _create_linear_and_lora_wrapper( + linear, + base_linear, + MergedColumnParallelLinearWithLoRA, + vllm_config=dist_init, + mesh=mesh) elif repeats == 3: raise NotImplementedError("NYI: for MergedQKVParallelLinear case") + else: raise NotImplementedError("NYI: for QKVParallelLinear case") From 2af4daef58bd764f4d4dc4b20b686b69de3c2da1 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 6 Nov 2025 00:43:34 +0000 Subject: [PATCH 13/18] added test for MergedQKVParallelLinearWithLoRA Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index a50ae9b21..6e8ec906d 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -12,11 +12,13 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, - MergedColumnParallelLinearWithLoRA) + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA) # yapf: enable from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.punica_wrapper import get_punica_wrapper -from vllm.model_executor.layers.linear import MergedColumnParallelLinear +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear) from vllm.model_executor.utils import set_random_seed from vllm.platforms import current_platform @@ -217,7 +219,7 @@ def _create_linear_and_lora_wrapper(linear, base_linear, lora_cls, vllm_config, @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) -@pytest.mark.parametrize("repeats", [2]) +@pytest.mark.parametrize("repeats", [2, 3]) @pytest.mark.parametrize("stage", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: max_loras = 9 @@ -273,12 +275,30 @@ def _create_merged_column_linear(): vllm_config=dist_init, mesh=mesh) elif repeats == 3: - raise NotImplementedError("NYI: for MergedQKVParallelLinear case") + + def _create_qkv_column_linear(): + return QKVParallelLinear(64, + 64, + 32, + bias=False, + params_dtype=torch.float16) + + linear = _create_qkv_column_linear() + linear.weight.data = torch.rand_like(linear.weight.data) + + base_linear = _create_qkv_column_linear() + lora_linear = _create_linear_and_lora_wrapper( + linear, + base_linear, + MergedQKVParallelLinearWithLoRA, + vllm_config=dist_init, + mesh=mesh) else: raise NotImplementedError("NYI: for QKVParallelLinear case") with torchax.default_env(): + # In the e2e, this is done when we create the lora model (`load_lora_model`). 
lora_linear.create_lora_weights(max_loras, lora_config) # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. _shard_module_to_tpu(lora_linear, mesh) From 540e9dbccbc4a1f3e4589c8a1a6f44bda6dd5f49 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 6 Nov 2025 00:59:10 +0000 Subject: [PATCH 14/18] Added the test for QKVParallelLinearWithLoRA Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 30 +++++++++++++++++++++------ tpu_inference/layers/vllm/sharding.py | 23 +++++++++++++++----- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 6e8ec906d..0416538b6 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -13,7 +13,8 @@ # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLoRA) + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA) # yapf: enable from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.punica_wrapper import get_punica_wrapper @@ -219,7 +220,7 @@ def _create_linear_and_lora_wrapper(linear, base_linear, lora_cls, vllm_config, @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) -@pytest.mark.parametrize("repeats", [2, 3]) +@pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("stage", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: max_loras = 9 @@ -276,17 +277,17 @@ def _create_merged_column_linear(): mesh=mesh) elif repeats == 3: - def _create_qkv_column_linear(): + def _create_qkv_linear(): return QKVParallelLinear(64, 64, 32, bias=False, params_dtype=torch.float16) - linear = _create_qkv_column_linear() + linear = _create_qkv_linear() linear.weight.data = torch.rand_like(linear.weight.data) - base_linear = _create_qkv_column_linear() + base_linear = _create_qkv_linear() lora_linear = _create_linear_and_lora_wrapper( linear, base_linear, @@ -295,7 +296,24 @@ def _create_qkv_column_linear(): mesh=mesh) else: - raise NotImplementedError("NYI: for QKVParallelLinear case") + + def _create_qkv_linear(): + return QKVParallelLinear(64, + 64, + 32, + bias=False, + params_dtype=torch.float16) + + linear = _create_qkv_linear() + linear.weight.data = torch.rand_like(linear.weight.data) + + base_linear = _create_qkv_linear() + lora_linear = _create_linear_and_lora_wrapper( + linear, + base_linear, + QKVParallelLinearWithLoRA, + vllm_config=dist_init, + mesh=mesh) with torchax.default_env(): # In the e2e, this is done when we create the lora model (`load_lora_model`). 
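Throughout these tests the same torchax flow is used to get a host torch.Tensor onto the TPU mesh: convert it with t2j, wrap it back into a torch view, then device_put it with an explicit NamedSharding. A minimal sketch of that flow, assuming illustrative shapes and a made-up helper name (move_to_tpu is not part of the patch):

    import jax
    import torch
    import torchax
    from jax.sharding import NamedSharding, PartitionSpec as P
    from torchax.interop import torch_view
    from torchax.ops.mappings import t2j


    def move_to_tpu(x: torch.Tensor, mesh: jax.sharding.Mesh) -> torch.Tensor:
        # t2j converts the torch tensor into a jax array; torch_view wraps it
        # back so torch code (the vLLM layers) can still operate on it.
        x_tpu = torch_view(t2j(x))
        # apply_jax_ applies a jax function to the underlying array in place;
        # here it places the array on the mesh, replicated on both axes.
        x_tpu.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
        return x_tpu


    if __name__ == "__main__":
        mesh = jax.make_mesh((1, len(jax.devices())), ("data", "model"))
        with torchax.default_env():
            x = move_to_tpu(torch.rand(4, 64, dtype=torch.bfloat16), mesh)

Without torch_view the result is a raw jax array with no apply_jax_ method, and without t2j it stays a plain torch.Tensor; the test comments above call out both of those AttributeErrors.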
diff --git a/tpu_inference/layers/vllm/sharding.py b/tpu_inference/layers/vllm/sharding.py index 8ab1247c9..1fd43b859 100644 --- a/tpu_inference/layers/vllm/sharding.py +++ b/tpu_inference/layers/vllm/sharding.py @@ -6,8 +6,10 @@ from torch.utils import _pytree as pytree from torchax.interop import jax_view, torch_view from torchax.ops.mappings import t2j -from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA, +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA) from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -123,9 +125,8 @@ def _shard_base_linear_lora_replicated(layer: BaseLinearLayerWithLoRA, layer.lora_b_stacked = sharded_lora_b_tpu -# TODO: Add custom sharding logic for following lora layers -def _shard_merged_column_parallel_linear_lora( - layer: MergedColumnParallelLinearWithLoRA, mesh: Mesh) -> None: +def _shard_column_linear_lora(layer: ColumnParallelLinearWithLoRA, + mesh: Mesh) -> None: assert layer.n_slices > 0, "layer.n_slices should be greater than 0" # lora_a_stacked[i] has shape [max_loras, 1, max_lora_rank, in_features] sharded_lora_a_tpu = torch.nn.ParameterList() @@ -146,9 +147,20 @@ def _shard_merged_column_parallel_linear_lora( layer.lora_b_stacked = sharded_lora_b_tpu +# TODO: Add custom sharding logic for following lora layers +def _shard_qkv_linear_lora(layer: ColumnParallelLinearWithLoRA, + mesh: Mesh) -> None: + _shard_column_linear_lora(layer, mesh) + + +def _shard_merged_column_parallel_linear_lora( + layer: MergedColumnParallelLinearWithLoRA, mesh: Mesh) -> None: + _shard_column_linear_lora(layer, mesh) + + def _shard_merged_qkv_parallel_linear_lora( layer: MergedQKVParallelLinearWithLoRA, mesh: Mesh) -> None: - _shard_merged_column_parallel_linear_lora(layer, mesh) + _shard_column_linear_lora(layer, mesh) def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA, @@ -162,6 +174,7 @@ def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA, (ParallelLMHead, _shard_lm_head), (VocabParallelEmbedding, _shard_vocab_parallel_embedding), # Shard LoRA layers + (QKVParallelLinearWithLoRA, _shard_qkv_linear_lora), (MergedColumnParallelLinearWithLoRA, _shard_merged_column_parallel_linear_lora), (MergedQKVParallelLinearWithLoRA, _shard_merged_qkv_parallel_linear_lora), From 26794f2081c5add824e31506350dd588cd0f505d Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 6 Nov 2025 18:01:10 +0000 Subject: [PATCH 15/18] The ColumnParallelLinear test passed Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 228 +++++++++++++++++++++++++- tpu_inference/layers/vllm/sharding.py | 1 + 2 files changed, 224 insertions(+), 5 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 0416538b6..238879cb0 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -11,15 +11,18 @@ from vllm.config import LoRAConfig # yapf conflicts with isort for this block # yapf: disable -from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, - MergedColumnParallelLinearWithLoRA, +from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LoRAMapping, MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA, - QKVParallelLinearWithLoRA) + QKVParallelLinearWithLoRA, + RowParallelLinearWithLoRA) # yapf: enable from vllm.lora.models import LoRALayerWeights, 
PackedLoRALayerWeights from vllm.lora.punica_wrapper import get_punica_wrapper -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.utils import set_random_seed from vllm.platforms import current_platform @@ -482,3 +485,218 @@ def _create_qkv_linear(): print( f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}' ) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) +@pytest.mark.parametrize("orientation", ["row", "column"]) +@pytest.mark.parametrize("stage", [True, False]) +def test_linear_parallel(dist_init, num_loras, orientation, stage) -> None: + max_loras = 9 + max_lora_rank = 8 + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=max_lora_rank, + fully_sharded_loras=False, + lora_dtype=torch.float16, + ) + + axis_names = ("data", "model") + devices = jax.devices() + mesh_shape = (1, len(devices)) + mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices) + + def create_column_parallel_packed_layer(): + # We first create a base linear layer, then a lora layer to wrap it. + if orientation == "row": + + def _create_row_linear(): + return RowParallelLinear( + 64, # input_size + 64, # output_size + bias=False, + params_dtype=torch.float16) + + linear = _create_row_linear() + linear.weight.data = torch.rand_like(linear.weight.data) + + base_linear = _create_row_linear() + lora_linear = _create_linear_and_lora_wrapper( + linear, + base_linear, + RowParallelLinearWithLoRA, + vllm_config=dist_init, + mesh=mesh) + else: + + def _create_column_linear(): + return ColumnParallelLinear(64, + 64, + bias=False, + params_dtype=torch.float16) + + linear = _create_column_linear() + linear.weight.data = torch.rand_like(linear.weight.data) + + base_linear = _create_column_linear() + lora_linear = _create_linear_and_lora_wrapper( + linear, + base_linear, + ColumnParallelLinearWithLoRA, + vllm_config=dist_init, + mesh=mesh) + + with torchax.default_env(): + lora_linear.create_lora_weights(max_loras, lora_config) + _shard_module_to_tpu(lora_linear, mesh) + + assert jax_view(lora_linear.lora_a_stacked[0]).platform( + ) == 'tpu', 'lora_a_stacked should have been moved to TPU.' + assert not isinstance( + jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding. + SingleDeviceSharding), 'lora_a_stacked should have been sharded.' + assert jax_view(lora_linear.lora_b_stacked[0]).platform( + ) == 'tpu', 'lora_b_stacked should have been moved to TPU.' + assert not isinstance( + jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding. + SingleDeviceSharding), 'lora_b_stacked should have been sharded.' + assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( + lora_linear.lora_b_stacked) == 1) + + return linear, lora_linear + + set_random_seed(6) + + linear, lora_linear = create_column_parallel_packed_layer() + with torchax.default_env(): + if len(devices) == 1: + assert torch.equal(linear.weight.data, + lora_linear.weight.to('cpu')) + + max_num_batched_tokens = 8192 + max_batches = 256 + with torchax.default_env(): + punica_wrapper = get_punica_wrapper(max_num_batched_tokens, + max_batches, + 'jax', + max_loras=max_loras) + assert check_punica_wrapper(punica_wrapper) + lora_linear.set_mapping(punica_wrapper) + + # Populate lora matrices (lora_a and lora_b) in the lora layer. 
+ index_to_id = get_random_index_to_id(num_loras, max_loras) + # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights + lora_dict, sublora_dict = populate_loras( + index_to_id, + lora_layer=lora_linear, + baselayer_weights=linear.weight, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32, + input_size=(1, 64), + input_range=(0, 1), + input_type=torch.float16, + device='cpu') + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + + with torchax.default_env(): + # Here we move the metadata from cpu to tpu. + punica_wrapper.update_metadata( + lora_mapping, + index_to_id, + max_loras, + vocab_size=512, + extra_vocab_size=lora_config.lora_extra_vocab_size, + ) + assert jax_view(punica_wrapper._lora_indices_per_batch).platform( + ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' + assert isinstance( + jax_view(punica_wrapper._lora_indices_per_batch).sharding, + jax.sharding.SingleDeviceSharding + ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' + + jax_inputs = [] + with torchax.default_env(): + for input in inputs: + # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` + # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` + jax_input = torch_view(t2j(input)) + jax_input.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + jax_inputs.append(jax_input) + with torchax.default_env(): + lora_result = lora_linear(torch.cat(jax_inputs))[0] + + expected_results: list[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + # What does input_ and lora_id looks like? + # linear(input_) returns (output, output_bias) so we only need the first one. + result = linear(input_)[0] + lora = lora_dict[lora_id] + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + with torchax.default_env(): + lora_result_cpu = lora_result.to('cpu') + torch.testing.assert_close(lora_result_cpu, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}' + ) + + # Check that resetting the lora weights succeeds + # Here we set all lora weight to be empty. 
+ for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], # different from the above create_random_inputs + num_inputs=32, + input_size=(1, 64), + input_range=(0, 1), + input_type=torch.float16, + device='cpu') + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + + with torchax.default_env(): + punica_wrapper.update_metadata( + lora_mapping, + index_to_id, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + + jax_inputs = [] + with torchax.default_env(): + for input in inputs: + jax_input = torch_view(t2j(input)) + jax_input.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + jax_inputs.append(jax_input) + with torchax.default_env(): + lora_result = lora_linear(torch.cat(jax_inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + with torchax.default_env(): + lora_result_cpu = lora_result.to('cpu') + torch.testing.assert_close(lora_result_cpu, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}' + ) diff --git a/tpu_inference/layers/vllm/sharding.py b/tpu_inference/layers/vllm/sharding.py index 1fd43b859..89dd1c8ba 100644 --- a/tpu_inference/layers/vllm/sharding.py +++ b/tpu_inference/layers/vllm/sharding.py @@ -174,6 +174,7 @@ def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA, (ParallelLMHead, _shard_lm_head), (VocabParallelEmbedding, _shard_vocab_parallel_embedding), # Shard LoRA layers + (ColumnParallelLinearWithLoRA, _shard_column_linear_lora), (QKVParallelLinearWithLoRA, _shard_qkv_linear_lora), (MergedColumnParallelLinearWithLoRA, _shard_merged_column_parallel_linear_lora), From 2be10fd09d0cf4d6a59abb77cb99ca72261870bd Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 6 Nov 2025 21:05:27 +0000 Subject: [PATCH 16/18] Finally fix the test. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 23 ++++++++++--------- .../layers/vllm/quantization/unquantized.py | 3 ++- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 238879cb0..0bc85551b 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -413,7 +413,9 @@ def _create_qkv_linear(): lora_result = lora_linear(torch.cat(jax_inputs))[0] expected_results: list[torch.Tensor] = [] + # len(inputs_)=32, prompt_mapping: prompt->lora_id mapping. for input_, lora_id in zip(inputs, prompt_mapping): + # eg. input_: [1, 64]. lora_id: 1 # linear(input_) returns (output, output_bias) so we only need the first one. result = linear(input_)[0] subloras = sublora_dict[lora_id] @@ -506,7 +508,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, stage) -> None: mesh_shape = (1, len(devices)) mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices) - def create_column_parallel_packed_layer(): + def create_random_linear_parallel_layer(): # We first create a base linear layer, then a lora layer to wrap it. 
if orientation == "row": @@ -567,7 +569,7 @@ def _create_column_linear(): set_random_seed(6) - linear, lora_linear = create_column_parallel_packed_layer() + linear, lora_linear = create_random_linear_parallel_layer() with torchax.default_env(): if len(devices) == 1: assert torch.equal(linear.weight.data, @@ -602,7 +604,6 @@ def _create_column_linear(): lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) with torchax.default_env(): - # Here we move the metadata from cpu to tpu. punica_wrapper.update_metadata( lora_mapping, index_to_id, @@ -620,22 +621,22 @@ def _create_column_linear(): jax_inputs = [] with torchax.default_env(): for input in inputs: - # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` - # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` jax_input = torch_view(t2j(input)) - jax_input.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) + sharding = NamedSharding(mesh, P( + None, 'model')) if orientation == "row" else NamedSharding( + mesh, P(None, None)) + jax_input.apply_jax_(jax.device_put, sharding) jax_inputs.append(jax_input) with torchax.default_env(): - lora_result = lora_linear(torch.cat(jax_inputs))[0] + jax_inputs = torch.cat(jax_inputs) + lora_result = lora_linear(jax_inputs)[0] expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): - # What does input_ and lora_id looks like? - # linear(input_) returns (output, output_bias) so we only need the first one. result = linear(input_)[0] lora = lora_dict[lora_id] - result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling + temp_lora_result = input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling + result += temp_lora_result expected_results.append(result) expected_result = torch.cat(expected_results) diff --git a/tpu_inference/layers/vllm/quantization/unquantized.py b/tpu_inference/layers/vllm/quantization/unquantized.py index 29ad4f64e..c84ed7eec 100644 --- a/tpu_inference/layers/vllm/quantization/unquantized.py +++ b/tpu_inference/layers/vllm/quantization/unquantized.py @@ -128,7 +128,8 @@ def _apply_fused(self, x_jax = jax_view(x) weight_jax = jax_view(layer.weight) - outs = jnp.einsum("mn,pn->mp", x_jax, weight_jax) + outs = jnp.einsum("mn,pn->mp", x_jax.astype(jnp.float32), + weight_jax.astype(jnp.float32)).astype(x_jax.dtype) if bias is not None and not layer.skip_bias_add: outs += bias.jax() From 83dd99bb67322320aef7baadf6375f8b837b5e79 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Mon, 10 Nov 2025 22:09:33 +0000 Subject: [PATCH 17/18] fixed the test Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 309 +++++++++++++++++--------------------- 1 file changed, 135 insertions(+), 174 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index f068e8d5b..d14b59168 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -13,12 +13,15 @@ # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, LoRAMapping, MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA) # yapf: enable from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from 
vllm.model_executor.utils import set_random_seed from vllm.platforms import current_platform @@ -202,7 +205,7 @@ def create_random_inputs( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 4, 9]) -@pytest.mark.parametrize("repeats", [2]) +@pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("stage", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: set_random_seed(6) @@ -322,10 +325,12 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) +@pytest.mark.parametrize("num_loras", [1, 4, 9]) @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("stage", [True, False]) def test_linear_parallel(dist_init, num_loras, orientation, stage) -> None: + set_random_seed(6) + max_loras = 9 max_lora_rank = 8 lora_config = LoRAConfig( @@ -334,78 +339,13 @@ def test_linear_parallel(dist_init, num_loras, orientation, stage) -> None: fully_sharded_loras=False, lora_dtype=torch.float16, ) + vllm_config = dist_init + vllm_config.lora_config = lora_config - axis_names = ("data", "model") - devices = jax.devices() - mesh_shape = (1, len(devices)) - mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices) - - def create_random_linear_parallel_layer(): - # We first create a base linear layer, then a lora layer to wrap it. - if orientation == "row": - - def _create_row_linear(): - return RowParallelLinear( - 64, # input_size - 64, # output_size - bias=False, - params_dtype=torch.float16) - - linear = _create_row_linear() - linear.weight.data = torch.rand_like(linear.weight.data) - - base_linear = _create_row_linear() - lora_linear = _create_linear_and_lora_wrapper( - linear, - base_linear, - RowParallelLinearWithLoRA, - vllm_config=dist_init, - mesh=mesh) - else: - - def _create_column_linear(): - return ColumnParallelLinear(64, - 64, - bias=False, - params_dtype=torch.float16) - - linear = _create_column_linear() - linear.weight.data = torch.rand_like(linear.weight.data) - - base_linear = _create_column_linear() - lora_linear = _create_linear_and_lora_wrapper( - linear, - base_linear, - ColumnParallelLinearWithLoRA, - vllm_config=dist_init, - mesh=mesh) - - with torchax.default_env(): - lora_linear.create_lora_weights(max_loras, lora_config) - _shard_module_to_tpu(lora_linear, mesh) - - assert jax_view(lora_linear.lora_a_stacked[0]).platform( - ) == 'tpu', 'lora_a_stacked should have been moved to TPU.' - assert not isinstance( - jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding. - SingleDeviceSharding), 'lora_a_stacked should have been sharded.' - assert jax_view(lora_linear.lora_b_stacked[0]).platform( - ) == 'tpu', 'lora_b_stacked should have been moved to TPU.' - assert not isinstance( - jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding. - SingleDeviceSharding), 'lora_b_stacked should have been sharded.' 
- assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( - lora_linear.lora_b_stacked) == 1) - - return linear, lora_linear - - set_random_seed(6) - - linear, lora_linear = create_random_linear_parallel_layer() - with torchax.default_env(): - if len(devices) == 1: - assert torch.equal(linear.weight.data, - lora_linear.weight.to('cpu')) + mesh = _create_mesh() + linear, lora_linear = _create_random_linear_parallel_layer( + orientation, vllm_config, mesh) + _verify_lora_linear_layer(linear, lora_linear) max_num_batched_tokens = 8192 max_batches = 256 @@ -433,58 +373,31 @@ def _create_column_linear(): input_range=(0, 1), input_type=torch.float16, device='cpu') - lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - with torchax.default_env(): - punica_wrapper.update_metadata( - lora_mapping, - index_to_id, - max_loras, - vocab_size=512, - extra_vocab_size=lora_config.lora_extra_vocab_size, - ) - assert jax_view(punica_wrapper._lora_indices_per_batch).platform( - ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' - assert isinstance( - jax_view(punica_wrapper._lora_indices_per_batch).sharding, - jax.sharding.SingleDeviceSharding - ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' + _update_punica_wrapper_metadata(punica_wrapper, index_mapping, + prompt_mapping, stage, index_to_id, + lora_config) - jax_inputs = [] with torchax.default_env(): - for input in inputs: - jax_input = torch_view(t2j(input)) - sharding = NamedSharding(mesh, P( - None, 'model')) if orientation == "row" else NamedSharding( - mesh, P(None, None)) - jax_input.apply_jax_(jax.device_put, sharding) - jax_inputs.append(jax_input) - with torchax.default_env(): - jax_inputs = torch.cat(jax_inputs) - lora_result = lora_linear(jax_inputs)[0] + torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh) + actual_result = lora_linear(torchax_inputs)[0] expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): result = linear(input_)[0] lora = lora_dict[lora_id] - temp_lora_result = input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling - result += temp_lora_result + lora_result = input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling + result += lora_result expected_results.append(result) expected_result = torch.cat(expected_results) - rtol, atol = TOLERANCES[lora_result.dtype] + rtol, atol = TOLERANCES[actual_result.dtype] with torchax.default_env(): - lora_result_cpu = lora_result.to('cpu') - torch.testing.assert_close(lora_result_cpu, + actual_result_cpu = actual_result.to('cpu') + torch.testing.assert_close(actual_result_cpu, expected_result, rtol=rtol, atol=atol) - print( - f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}' - ) - print( - f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}' - ) # Check that resetting the lora weights succeeds # Here we set all lora weight to be empty. 
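The expected values in these parallel-linear tests are computed on the CPU with the plain LoRA formula: a base matmul plus a scaled low-rank update. A reference-only sketch, with the rank, feature sizes, and random weights chosen purely for illustration:

    import torch

    def lora_reference(x: torch.Tensor, weight: torch.Tensor,
                       lora_a: torch.Tensor, lora_b: torch.Tensor,
                       scaling: float) -> torch.Tensor:
        base = x @ weight.T                # [tokens, out_features]
        update = x @ lora_a.T @ lora_b.T   # low-rank path through A then B
        return base + scaling * update

    x = torch.rand(4, 64, dtype=torch.bfloat16)
    weight = torch.rand(64, 64, dtype=torch.bfloat16)  # [out_features, in_features]
    lora_a = torch.rand(8, 64, dtype=torch.bfloat16)   # [rank, in_features]
    lora_b = torch.rand(64, 8, dtype=torch.bfloat16)   # [out_features, rank]
    y = lora_reference(x, weight, lora_a, lora_b, scaling=1.0)
    assert y.shape == (4, 64)

The TPU result is then compared against this CPU reference with the dtype-specific tolerances from TOLERANCES via torch.testing.assert_close.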
@@ -498,53 +411,63 @@ def _create_column_linear(): input_range=(0, 1), input_type=torch.float16, device='cpu') - lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - - with torchax.default_env(): - punica_wrapper.update_metadata( - lora_mapping, - index_to_id, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) + _update_punica_wrapper_metadata(punica_wrapper, index_mapping, + prompt_mapping, stage, index_to_id, + lora_config) - jax_inputs = [] - with torchax.default_env(): - for input in inputs: - jax_input = torch_view(t2j(input)) - jax_input.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) - jax_inputs.append(jax_input) with torchax.default_env(): - lora_result = lora_linear(torch.cat(jax_inputs))[0] + torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh) + actual_result = lora_linear(torchax_inputs)[0] expected_result = linear(torch.cat(inputs))[0] - rtol, atol = TOLERANCES[lora_result.dtype] + rtol, atol = TOLERANCES[actual_result.dtype] with torchax.default_env(): - lora_result_cpu = lora_result.to('cpu') - torch.testing.assert_close(lora_result_cpu, + actual_result_cpu = actual_result.to('cpu') + torch.testing.assert_close(actual_result_cpu, expected_result, rtol=rtol, atol=atol) -def _create_linear_and_lora_wrapper(linear, base_linear, lora_cls, vllm_config, - mesh): - base_linear.weight.data = linear.weight.data - jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) - linear_method = VllmUnquantizedLinearMethod(jax_config) - base_linear.quant_method = linear_method - linear_method.process_weights_after_loading( - base_linear) # here base_linear.weight is moved to TPU and sharded. - assert jax_view(base_linear.weight).platform( - ) == 'tpu', 'base_linear.weight should have been moved to TPU.' - assert not isinstance( - jax_view(base_linear.weight).sharding, jax.sharding. - SingleDeviceSharding), 'base_linear.weight should have been sharded.' +def _create_random_linear_parallel_layer(orientation, vllm_config, mesh): + # We first create a base linear layer, then a lora layer to wrap it. + if orientation == "row": - lora_linear = lora_cls(base_linear) - return lora_linear + def _create_row_linear(): + return RowParallelLinear( + 64, # input_size + 64, # output_size + bias=False, + params_dtype=torch.float16) + + linear = _create_row_linear() + linear.weight.data = torch.rand_like(linear.weight.data) + + base_linear = _create_row_linear() + lora_linear = _create_lora_wrapper(linear, + base_linear, + RowParallelLinearWithLoRA, + vllm_config=vllm_config, + mesh=mesh) + else: + + def _create_column_linear(): + return ColumnParallelLinear(64, + 64, + bias=False, + params_dtype=torch.float16) + + linear = _create_column_linear() + linear.weight.data = torch.rand_like(linear.weight.data) + + base_linear = _create_column_linear() + lora_linear = _create_lora_wrapper(linear, + base_linear, + ColumnParallelLinearWithLoRA, + vllm_config=vllm_config, + mesh=mesh) + + return linear, lora_linear def _create_mesh(): @@ -603,37 +526,75 @@ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh): # We first create a base linear layer, then a lora layer to wrap it. if repeats == 2: # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. 
- linear = MergedColumnParallelLinear( - 64, # input_size - [64] * repeats, # output_size - bias=False, - params_dtype=torch.float16) + def _create_merged_column_linear(): + return MergedColumnParallelLinear( + 64, # input_size + [64] * repeats, # output_size + bias=False, + params_dtype=torch.float16) + + linear = _create_merged_column_linear() linear.weight.data = torch.rand_like(linear.weight.data) - base_linear = MergedColumnParallelLinear( - 64, # input_size - [64] * repeats, # output_size - bias=False, - params_dtype=torch.float16) - base_linear.weight.data = linear.weight.data - jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) - linear_method = VllmUnquantizedLinearMethod(jax_config) - base_linear.quant_method = linear_method - linear_method.process_weights_after_loading( - base_linear - ) # here base_linear.weight is moved to TPU and sharded. - assert jax_view(base_linear.weight).platform( - ) == 'tpu', 'base_linear.weight should have been moved to TPU.' - assert not isinstance( - jax_view( - base_linear.weight).sharding, jax.sharding.SingleDeviceSharding - ), 'base_linear.weight should have been sharded.' - - lora_linear = MergedColumnParallelLinearWithLoRA(base_linear) + base_linear = _create_merged_column_linear() + lora_linear = _create_lora_wrapper(linear, base_linear, + MergedColumnParallelLinearWithLoRA, + vllm_config, mesh, repeats) elif repeats == 3: - raise NotImplementedError("NYI: for MergedQKVParallelLinear case") + + def _create_qkv_linear(): + return QKVParallelLinear(64, + 64, + 32, + bias=False, + params_dtype=torch.float16) + + linear = _create_qkv_linear() + linear.weight.data = torch.rand_like(linear.weight.data) + + base_linear = _create_qkv_linear() + lora_linear = _create_lora_wrapper(linear, base_linear, + MergedQKVParallelLinearWithLoRA, + vllm_config, mesh, repeats) else: - raise NotImplementedError("NYI: for QKVParallelLinear case") + + def _create_qkv_linear(): + return QKVParallelLinear(64, + 64, + 32, + bias=False, + params_dtype=torch.float16) + + linear = _create_qkv_linear() + linear.weight.data = torch.rand_like(linear.weight.data) + + base_linear = _create_qkv_linear() + lora_linear = _create_lora_wrapper(linear, base_linear, + QKVParallelLinearWithLoRA, + vllm_config, mesh, repeats) + + return linear, lora_linear + + +def _create_lora_wrapper(linear, + base_linear, + lora_cls, + vllm_config, + mesh, + repeats=1): + base_linear.weight.data = linear.weight.data + jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) + linear_method = VllmUnquantizedLinearMethod(jax_config) + base_linear.quant_method = linear_method + linear_method.process_weights_after_loading( + base_linear) # here base_linear.weight is moved to TPU and sharded. + assert jax_view(base_linear.weight).platform( + ) == 'tpu', 'base_linear.weight should have been moved to TPU.' + assert not isinstance( + jax_view(base_linear.weight).sharding, jax.sharding. + SingleDeviceSharding), 'base_linear.weight should have been sharded.' 
+ + lora_linear = lora_cls(base_linear) lora_config = vllm_config.lora_config max_loras = lora_config.max_loras @@ -656,4 +617,4 @@ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh): assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) - return linear, lora_linear + return lora_linear From 01ddba746485a4c7b981f6ae6054c44face922cd Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Mon, 10 Nov 2025 23:03:23 +0000 Subject: [PATCH 18/18] Start using bf16 Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 22 +++++++++---------- .../layers/vllm/quantization/unquantized.py | 3 +-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index d14b59168..3f45d08c1 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -216,7 +216,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: max_loras=max_loras, max_lora_rank=max_lora_rank, fully_sharded_loras=False, - lora_dtype=torch.float16, + lora_dtype=torch.bfloat16, ) vllm_config = dist_init vllm_config.lora_config = lora_config @@ -256,7 +256,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: num_inputs=32, input_size=(1, 64), input_range=(0, 1), - input_type=torch.float16, + input_type=torch.bfloat16, device='cpu') _update_punica_wrapper_metadata(punica_wrapper, index_mapping, @@ -303,7 +303,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: num_inputs=32, input_size=(1, 64), input_range=(0, 1), - input_type=torch.float16, + input_type=torch.bfloat16, device='cpu') _update_punica_wrapper_metadata(punica_wrapper, index_mapping, @@ -337,7 +337,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, stage) -> None: max_loras=max_loras, max_lora_rank=max_lora_rank, fully_sharded_loras=False, - lora_dtype=torch.float16, + lora_dtype=torch.bfloat16, ) vllm_config = dist_init vllm_config.lora_config = lora_config @@ -371,7 +371,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, stage) -> None: num_inputs=32, input_size=(1, 64), input_range=(0, 1), - input_type=torch.float16, + input_type=torch.bfloat16, device='cpu') _update_punica_wrapper_metadata(punica_wrapper, index_mapping, @@ -409,7 +409,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, stage) -> None: num_inputs=32, input_size=(1, 64), input_range=(0, 1), - input_type=torch.float16, + input_type=torch.bfloat16, device='cpu') _update_punica_wrapper_metadata(punica_wrapper, index_mapping, prompt_mapping, stage, index_to_id, @@ -438,7 +438,7 @@ def _create_row_linear(): 64, # input_size 64, # output_size bias=False, - params_dtype=torch.float16) + params_dtype=torch.bfloat16) linear = _create_row_linear() linear.weight.data = torch.rand_like(linear.weight.data) @@ -455,7 +455,7 @@ def _create_column_linear(): return ColumnParallelLinear(64, 64, bias=False, - params_dtype=torch.float16) + params_dtype=torch.bfloat16) linear = _create_column_linear() linear.weight.data = torch.rand_like(linear.weight.data) @@ -531,7 +531,7 @@ def _create_merged_column_linear(): 64, # input_size [64] * repeats, # output_size bias=False, - params_dtype=torch.float16) + params_dtype=torch.bfloat16) linear = _create_merged_column_linear() linear.weight.data = torch.rand_like(linear.weight.data) @@ -547,7 +547,7 @@ def _create_qkv_linear(): 64, 32, bias=False, - params_dtype=torch.float16) + params_dtype=torch.bfloat16) linear = 
_create_qkv_linear() linear.weight.data = torch.rand_like(linear.weight.data) @@ -563,7 +563,7 @@ def _create_qkv_linear(): 64, 32, bias=False, - params_dtype=torch.float16) + params_dtype=torch.bfloat16) linear = _create_qkv_linear() linear.weight.data = torch.rand_like(linear.weight.data) diff --git a/tpu_inference/layers/vllm/quantization/unquantized.py b/tpu_inference/layers/vllm/quantization/unquantized.py index 255838312..7881332f7 100644 --- a/tpu_inference/layers/vllm/quantization/unquantized.py +++ b/tpu_inference/layers/vllm/quantization/unquantized.py @@ -127,8 +127,7 @@ def _apply_fused(self, x_jax = jax_view(x) weight_jax = jax_view(layer.weight) - outs = jnp.einsum("mn,pn->mp", x_jax.astype(jnp.float32), - weight_jax.astype(jnp.float32)).astype(x_jax.dtype) + outs = jnp.einsum("mn,pn->mp", x_jax, weight_jax) if bias is not None and not layer.skip_bias_add: outs += bias.jax()
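The final patch moves the tests to bfloat16 and restores the plain einsum in _apply_fused. The "mn,pn->mp" contraction is just a matmul against a weight stored as [out_features, in_features], i.e. x @ W.T. A small sketch with made-up shapes (the tolerances are only loose illustrative values for bfloat16):

    import jax.numpy as jnp
    import numpy as np

    x = jnp.asarray(np.random.rand(4, 64), dtype=jnp.bfloat16)    # [m, n] activations
    w = jnp.asarray(np.random.rand(128, 64), dtype=jnp.bfloat16)  # [p, n] weight

    out = jnp.einsum("mn,pn->mp", x, w)   # same result as x @ w.T
    assert out.shape == (4, 128)
    assert jnp.allclose(out, x @ w.T, rtol=3e-2, atol=2e-2)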