From e73674a8019bd6efb1c37c81fb40045af3a918e8 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Mon, 27 Oct 2025 18:24:05 +0000 Subject: [PATCH 01/14] Save an initial version. Not capable of running at all. Signed-off-by: Xiongfei Wei --- tests/lora/conftest.py | 32 ++++ tests/lora/test_layers.py | 385 ++++++++++++++++++++++++++++++++++++++ tests/lora/utils.py | 96 ++++++++++ 3 files changed, 513 insertions(+) create mode 100644 tests/lora/conftest.py create mode 100644 tests/lora/test_layers.py create mode 100644 tests/lora/utils.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py new file mode 100644 index 000000000..d573070de --- /dev/null +++ b/tests/lora/conftest.py @@ -0,0 +1,32 @@ +import tempfile + +import pytest +from vllm.config import set_current_vllm_config +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.engine.arg_utils import EngineArgs + + +@pytest.fixture +def dist_init(): + engine_args = EngineArgs( + model="Qwen/Qwen2-1.5B-Instruct", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + + vllm_config = engine_args.create_engine_config() + + with set_current_vllm_config(vllm_config): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + 1, + 0, + local_rank=0, + distributed_init_method=f"file://{temp_file}", + backend="gloo") + ensure_model_parallel_initialized(1, 1) + yield vllm_config + cleanup_dist_env_and_memory(shutdown_ray=True) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py new file mode 100644 index 000000000..a40afb532 --- /dev/null +++ b/tests/lora/test_layers.py @@ -0,0 +1,385 @@ +import random +from typing import Optional + +import jax +import pytest +import torch +import torchax +from jax.sharding import NamedSharding, PartitionSpec +from torchax.interop import torch_view +from torchax.ops.mappings import j2t, t2j +# from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu +from vllm.config import LoRAConfig +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, + MergedColumnParallelLinearWithLoRA) +# yapf: enable +from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.punica_wrapper import get_punica_wrapper +from vllm.model_executor.layers.linear import MergedColumnParallelLinear +from vllm.model_executor.utils import set_random_seed +from vllm.platforms import current_platform + +from .utils import DummyLoRAManager + +# TODO(xiowei): +# - add test for multi-chip. +# - add equivalent test for ColumnParallelLinearWithShardedLoRA. + +P = PartitionSpec + +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.float32: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + +pytestmark = pytest.mark.skipif(not current_platform.is_tpu(), + reason="This test is only for TPU platform.") + +# prefill stage(True) or decode stage(False) +STAGES = [True, False] + + +def check_punica_wrapper(punica_wrapper) -> bool: + from tpu_inference.lora.torch_punica_tpu import PunicaWrapperTPU + return type(punica_wrapper) is PunicaWrapperTPU + + +def get_random_index_to_id(num_loras: int, + num_slots: int, + log: bool = True) -> list[Optional[int]]: + """Creates a random index_to_lora_id mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. 
+ log: Whether to log the output. + + returns: + index_to_lora_id: a random index_to_lora_id mapping. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " + "num_loras must be less than or equal to num_slots.") + + slots: list[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + # xw32: It seems the slot_idx start at 1. + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + + +def populate_loras( + index_to_id: list[Optional[int]], + lora_layer: BaseLayerWithLoRA, + baselayer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, +) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]: + """This method populates the lora layers (BaseLayerWithLoRA) with lora weights. + + Args: + index_to_id: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + lora_layer: the LoRAlayer to populate. + baselayer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + + returns: + lora_dict: a dictionary dict[int, LoRALayerWeights] that maps the lora ID to the corresponding lora weights. + sublora_dict: a dictionary dict[int, list[LoRALayerWeights]] that maps the lora ID to the corresponding lora weights. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: dict[int, LoRALayerWeights] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. + sublora_dict: dict[int, list[LoRALayerWeights]] = dict() + + for slot_idx, lora_id in enumerate(index_to_id): + if lora_id is not None: + subloras: list[LoRALayerWeights] = [] + sublora_len = baselayer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager( + baselayer_weights.device).init_random_lora( + module_name=f"fake_{i}", + weight=baselayer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[(sublora_len * + i):(sublora_len * (i + 1)), :] + sublora.optimize() + subloras.append(sublora) + + lora = PackedLoRALayerWeights.pack( + subloras) if repeats > 1 else subloras[0] + + # Some of the layer.lora is torchax tensor so it can only do math (slice op) in the torchax env. + with torchax.default_env(), jax.default_device( + jax.devices("tpu")[0]): + lora_layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + + +def create_random_inputs( + active_lora_ids: list[int], + num_inputs: int, + input_size: tuple[int, ...], + input_range: tuple[float, float], + input_type: torch.dtype = torch.int, + device: torch.device = "cpu", +) -> tuple[list[torch.Tensor], list[int], list[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. Or the number of requests. + input_size: the size of each individual input. Or the number of tokens. 
+ input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. + + returns: + inputs: a list of torch tensors of size num_inputs. Each input has shape `input_size`. + index_mapping: maps each input token to a lora ID. + prompt_mapping: maps each request to a lora ID. + """ + + low, high = input_range + + inputs: list[torch.Tensor] = [] + index_mapping: list[int] = [] + prompt_mapping: list[int] = [] + + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device=device)) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device=device) * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) +@pytest.mark.parametrize("repeats", [2]) +@pytest.mark.parametrize("fully_shard", [False]) # TODO(xiowei): add "True". +@pytest.mark.parametrize("device", ["cpu"]) +@pytest.mark.parametrize("stage", [True, False]) +def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, + device, stage) -> None: + max_loras = 9 + max_lora_rank = 8 + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=max_lora_rank, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) + + axis_names = ("data", "model") + mesh_shape = ( + 1, 1 + ) # TODO(xiowei): support multi-chip: mesh_shape = (1, len(jax.devices())) + mesh = jax.make_mesh(mesh_shape, axis_names, devices=jax.devices()) + + def create_column_parallel_packed_layer(): + # Step 1: create a base layer (e.g. MergedColumnParallelLinear) and a vLLM LoRA wrapper. + if repeats == 2: + linear = MergedColumnParallelLinear( + 256, # input_size + [256] * repeats, # output_size + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = MergedColumnParallelLinearWithLoRA( + linear + ) # TODO(xiowei): add test for MergedColumnParallelLinearWithShardedLoRA (fully_shard == True) + elif repeats == 3: + # TODO(xiowei): add test for this case. + raise NotImplementedError("NYI: for MergedQKVParallelLinear case") + else: + # TODO(xiowei): add test for this case. + raise NotImplementedError("NYI: for QKVParallelLinear case") + + n_slices = repeats + # create_lora_weights creates global shape weight. + lora_linear.create_lora_weights(max_loras, lora_config) + assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( + lora_linear.lora_b_stacked) == n_slices) + + return linear, lora_linear + + set_random_seed(6) + + linear, lora_linear = create_column_parallel_packed_layer() + # linear.weight has type torch.nn.Parameter, lora_linear.weight has type torchax.tensor.Tensor + # BaseLinearLayerWithLoRA.weight property guarantees this. + assert torch.equal(linear.weight, lora_linear.weight) + with torchax.default_env(): + assert torch.equal(linear.weight.data, j2t(lora_linear.weight)) + + max_num_batched_tokens = 8192 + max_batches = 256 + punica_wrapper = get_punica_wrapper(max_num_batched_tokens, + max_batches, + device, + max_loras=max_loras) + assert check_punica_wrapper(punica_wrapper) + lora_linear.set_mapping(punica_wrapper) + + # load the lora weight, shard it, and send it to TPU. + # create a lora slot index to lora id mapping. 
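    # Illustrative example (values made up, not from the patch): with num_loras=2
    # and max_loras=9, get_random_index_to_id might return
    # [None, 2, None, None, 1, None, None, None, None], i.e. slot 1 holds
    # lora id 2, slot 4 holds lora id 1, and None marks a free slot.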
+ index_to_id = get_random_index_to_id(num_loras, max_loras) + # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights + lora_dict, sublora_dict = populate_loras( + index_to_id, + lora_layer=lora_linear, + baselayer_weights=linear.weight, + repeats=repeats, + ) + + # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 256]. + # index_mapping: list[int] + # prompt_mapping: list[int] + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32, + input_size=(1, 256), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + punica_wrapper.update_metadata( + lora_mapping, + index_to_id, + max_loras, + vocab_size=512, + extra_vocab_size=lora_config.lora_extra_vocab_size, + ) + punica_wrapper.move_to_device(mesh) + + jax_inputs = [] + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + for input in inputs: + # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` + # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` + jax_input = torch_view(t2j(input)) + jax_input.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + jax_inputs.append(jax_input) + with torchax.default_env(): + lora_result = lora_linear(torch.cat(jax_inputs))[0] + lora_result = j2t(lora_result) + + expected_results: list[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + # linear(input_) returns (output, output_bias) so we only need the first one. + result = linear(input_)[0] + subloras = sublora_dict[lora_id] + for i, sublora in enumerate(subloras): + result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * + (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * + sublora.scaling) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result))}' + ) + + # Check that resetting the lora weights succeeds + # Here we set all lora weight to be empty. 
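    # (reset_lora zeroes the stacked lora_a/lora_b buffers for every slot, so with
    # no active adapter the wrapped layer should reduce to the plain base layer,
    # which is what the comparison against linear() further down checks.)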
+ for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], # different from the above create_random_inputs + num_inputs=32, + input_size=(1, 256), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + punica_wrapper.update_metadata( + lora_mapping, + index_to_id, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + punica_wrapper.move_to_device(mesh) + + jax_inputs = [] + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + for input in inputs: + jax_input = torch_view(t2j(input)) + jax_input.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + jax_inputs.append(jax_input) + with torchax.default_env(): + lora_result = lora_linear(torch.cat(jax_inputs))[0] + lora_result = j2t(lora_result) + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result))}' + ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py new file mode 100644 index 000000000..41c5cf38d --- /dev/null +++ b/tests/lora/utils.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights + + +# https://github.com/vllm-project/vllm/blob/279a5f31b3faa6f40759516efa5c742f637ab8b7/tests/lora/utils.py +class DummyLoRAManager: + + def __init__(self, device: torch.device = "cuda:0"): + super().__init__() + self._loras: dict[str, LoRALayerWeights] = {} + self._device = device + + def set_module_lora(self, module_name: str, lora: LoRALayerWeights): + self._loras[module_name] = lora + + def get_module_lora(self, module_name: str) -> LoRALayerWeights: + return self._loras[module_name] + + def init_random_lora( + self, + module_name: str, + weight: torch.Tensor, + rank: int = 8, + generate_embeddings_tensor: int = 0, + ): + lora = LoRALayerWeights( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([rank, weight.shape[1]], + dtype=weight.dtype, + device=self._device), + lora_b=torch.rand([weight.shape[0], rank], + dtype=weight.dtype, + device=self._device), + ) + if generate_embeddings_tensor: + lora.embeddings_tensor = torch.rand( + 5, + generate_embeddings_tensor, + dtype=weight.dtype, + device=self._device, + ) + self.set_module_lora(module_name, lora) + + return lora + + def init_lora( + self, + module_name: str, + input_dim: int, + output_dim: int, + rank=8, + noop=False, + embeddings_tensor=None, + ): + lora = LoRALayerWeights( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([rank, input_dim], device="cuda"), + lora_b=torch.rand([output_dim, input_dim], device="cuda"), + embeddings_tensor=embeddings_tensor, + ) + self.set_module_lora(module_name, lora) + return lora + + def reset_lora(self): + self._loras = {} + + def init_packed_lora( + self, + module_name: str, + input_dim: int, + output_dims: list[int], + noop_lora_index: list[int] | None = None, + rank: int = 8, + ): + base_loras: list[LoRALayerWeights] = [] + noop_lora_index_set = 
set(noop_lora_index or []) + + for i, out_dim in enumerate(output_dims): + base_lora = self.init_lora( + module_name + "_000_" + str(i), + input_dim, + out_dim, + rank=rank, + noop=i in noop_lora_index_set, + ) + base_loras.append(base_lora) + packed_lora = PackedLoRALayerWeights.pack(base_loras) + self.set_module_lora(module_name, packed_lora) + return packed_lora From da64f463e1d0912d39640db42d38ef512f1f99a3 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 29 Oct 2025 17:00:40 +0000 Subject: [PATCH 02/14] now the test can run to completion. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 83 ++++++++++++++------- tpu_inference/runner/compilation_manager.py | 1 + 2 files changed, 55 insertions(+), 29 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index a40afb532..970666dd4 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -8,6 +8,7 @@ from jax.sharding import NamedSharding, PartitionSpec from torchax.interop import torch_view from torchax.ops.mappings import j2t, t2j +from torchax.interop import jax_view, torch_view # from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu from vllm.config import LoRAConfig # yapf conflicts with isort for this block @@ -22,6 +23,11 @@ from vllm.platforms import current_platform from .utils import DummyLoRAManager +from tpu_inference.layers.vllm.quantization.common import ( + JaxCommonConfig, JaxCommonLinearConfig) +from tpu_inference.layers.vllm.quantization.unquantized import \ + VllmUnquantizedLinearMethod +from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora # TODO(xiowei): # - add test for multi-chip. @@ -224,15 +230,31 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, def create_column_parallel_packed_layer(): # Step 1: create a base layer (e.g. MergedColumnParallelLinear) and a vLLM LoRA wrapper. if repeats == 2: + # In e2e, MergedColumnParallelLinear is created when we load th e model. The weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. linear = MergedColumnParallelLinear( 256, # input_size [256] * repeats, # output_size bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) + + base_linear = MergedColumnParallelLinear( + 256, # input_size + [256] * repeats, # output_size + bias=False, + params_dtype=torch.float16) + base_linear.weight.data = linear.weight.data + vllm_config = dist_init + # self.jax_config.mesh.devices[0][0].platform + jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) + linear_method = VllmUnquantizedLinearMethod(jax_config) + linear_method.process_weights_after_loading(base_linear) + # here base_linear.weight is on TPU and sharded. + + # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. lora_linear = MergedColumnParallelLinearWithLoRA( - linear - ) # TODO(xiowei): add test for MergedColumnParallelLinearWithShardedLoRA (fully_shard == True) + base_linear + ) elif repeats == 3: # TODO(xiowei): add test for this case. raise NotImplementedError("NYI: for MergedQKVParallelLinear case") @@ -241,21 +263,22 @@ def create_column_parallel_packed_layer(): raise NotImplementedError("NYI: for QKVParallelLinear case") n_slices = repeats - # create_lora_weights creates global shape weight. - lora_linear.create_lora_weights(max_loras, lora_config) + with torchax.default_env(): + # create_lora_weights creates global shape weight. 
+ lora_linear.create_lora_weights(max_loras, lora_config) + _shard_merged_column_parallel_linear_lora(lora_linear, mesh) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) - return linear, lora_linear + return linear, lora_linear, linear_method set_random_seed(6) - linear, lora_linear = create_column_parallel_packed_layer() - # linear.weight has type torch.nn.Parameter, lora_linear.weight has type torchax.tensor.Tensor - # BaseLinearLayerWithLoRA.weight property guarantees this. - assert torch.equal(linear.weight, lora_linear.weight) + linear, lora_linear, linear_method = create_column_parallel_packed_layer() with torchax.default_env(): - assert torch.equal(linear.weight.data, j2t(lora_linear.weight)) + # linear.weight has type torch.nn.Parameter, lora_linear.weight has type torchax.tensor.Tensor + # BaseLinearLayerWithLoRA.weight property guarantees this. + assert torch.equal(linear.weight, lora_linear.weight.to('cpu')) max_num_batched_tokens = 8192 max_batches = 256 @@ -297,7 +320,7 @@ def create_column_parallel_packed_layer(): vocab_size=512, extra_vocab_size=lora_config.lora_extra_vocab_size, ) - punica_wrapper.move_to_device(mesh) + # punica_wrapper.move_to_device(mesh) jax_inputs = [] with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): @@ -305,12 +328,12 @@ def create_column_parallel_packed_layer(): # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` jax_input = torch_view(t2j(input)) - jax_input.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) + jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) jax_inputs.append(jax_input) with torchax.default_env(): - lora_result = lora_linear(torch.cat(jax_inputs))[0] - lora_result = j2t(lora_result) + # lora_result = lora_linear(torch.cat(jax_inputs))[0] + # lora_result = j2t(lora_result) + lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): @@ -318,23 +341,24 @@ def create_column_parallel_packed_layer(): result = linear(input_)[0] subloras = sublora_dict[lora_id] for i, sublora in enumerate(subloras): - result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * - (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * + result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] * + (i + 1)] += (input_ @ sublora.lora_a.T @ sublora.lora_b.T * sublora.scaling) expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - print( - f'Output max diff: {torch.max(torch.abs(expected_result - lora_result))}' - ) - print( - f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result))}' - ) + with torchax.default_env(): + torch.testing.assert_close(lora_result.to('cpu'), + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result.to('cpu') - lora_result))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result.to('cpu') - lora_result))}' + ) # Check that resetting the lora weights succeeds # Here we set all lora weight to be empty. 
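For reference, a minimal self-contained sketch of the expected-output math the updated loop above builds (the helper name and the running offset are illustrative, not part of the patch). The dummy adapters from tests/lora/utils.py store lora_a as (rank, in_features) and lora_b as (out_features, rank), hence the transposes and the lora_b.shape[0]-wide column blocks:

    import torch

    def expected_merged_output(x, weight, subloras):
        # x: [num_tokens, in_features]; weight: [out_features_total, in_features]
        y = x @ weight.T
        offset = 0
        for sub in subloras:  # one (lora_a, lora_b, scaling) per merged slice
            width = sub.lora_b.shape[0]
            y[:, offset:offset + width] += (x @ sub.lora_a.T @ sub.lora_b.T) * sub.scaling
            offset += width
        return y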
@@ -368,8 +392,9 @@ def create_column_parallel_packed_layer(): NamedSharding(mesh, P(None, None))) jax_inputs.append(jax_input) with torchax.default_env(): - lora_result = lora_linear(torch.cat(jax_inputs))[0] - lora_result = j2t(lora_result) + lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) + # lora_result = lora_linear(torch.cat(jax_inputs))[0] + # lora_result = j2t(lora_result) expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index c50c4bc86..6225cb4d6 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -7,6 +7,7 @@ import numpy as np import vllm.envs as envs from jax.sharding import NamedSharding, PartitionSpec +from vllm.utils.math_utils import cdiv from tpu_inference.core.disagg_utils import is_disagg_enabled from tpu_inference.layers.common.attention_metadata import AttentionMetadata From 823fbc1545123010304d696a28e89df3db967f83 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 29 Oct 2025 17:11:08 +0000 Subject: [PATCH 03/14] ok. The case without lora passed. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 45 ++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 970666dd4..192a5464a 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -348,17 +348,17 @@ def create_column_parallel_packed_layer(): expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - with torchax.default_env(): - torch.testing.assert_close(lora_result.to('cpu'), - expected_result, - rtol=rtol, - atol=atol) - print( - f'Output max diff: {torch.max(torch.abs(expected_result.to('cpu') - lora_result))}' - ) - print( - f'Output mean diff: {torch.mean(torch.abs(expected_result.to('cpu') - lora_result))}' - ) + # with torchax.default_env(): + # torch.testing.assert_close(lora_result.to('cpu'), + # expected_result, + # rtol=rtol, + # atol=atol) + # print( + # f'Output max diff: {torch.max(torch.abs(expected_result.to('cpu') - lora_result))}' + # ) + # print( + # f'Output mean diff: {torch.mean(torch.abs(expected_result.to('cpu') - lora_result))}' + # ) # Check that resetting the lora weights succeeds # Here we set all lora weight to be empty. 
@@ -382,7 +382,6 @@ def create_column_parallel_packed_layer(): 512, lora_config.lora_extra_vocab_size, ) - punica_wrapper.move_to_device(mesh) jax_inputs = [] with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): @@ -398,13 +397,15 @@ def create_column_parallel_packed_layer(): expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - print( - f'Output max diff: {torch.max(torch.abs(expected_result - lora_result))}' - ) - print( - f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result))}' - ) + with torchax.default_env(): + lora_result_cpu = lora_result.to('cpu') + torch.testing.assert_close(lora_result_cpu, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}' + ) From e8718d807e3a46de2b3fa9a53aa1ca3f5c7dc5e1 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 29 Oct 2025 17:46:35 +0000 Subject: [PATCH 04/14] ok, the test passed. Need to make it simpler next. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 38 +++++++++++++++----------- tpu_inference/lora/torch_punica_tpu.py | 2 ++ 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 192a5464a..7967fb336 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -248,6 +248,7 @@ def create_column_parallel_packed_layer(): # self.jax_config.mesh.devices[0][0].platform jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) linear_method = VllmUnquantizedLinearMethod(jax_config) + base_linear.quant_method=linear_method linear_method.process_weights_after_loading(base_linear) # here base_linear.weight is on TPU and sharded. @@ -263,6 +264,8 @@ def create_column_parallel_packed_layer(): raise NotImplementedError("NYI: for QKVParallelLinear case") n_slices = repeats + #TODO(xw): check if we can enable torchax globally. + # TODO(xw): check if we can calculate both actual and expected output using torchax. with torchax.default_env(): # create_lora_weights creates global shape weight. 
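            # (create_lora_weights allocates the per-slot stacked buffers
            # lora_a_stacked / lora_b_stacked, one entry per slice, zero-initialized
            # and sized by max_loras and max_lora_rank; wrapping the call in
            # torchax.default_env() presumably makes those buffers torchax tensors
            # so they can be sharded onto the TPU mesh right below.)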
lora_linear.create_lora_weights(max_loras, lora_config) @@ -282,10 +285,11 @@ def create_column_parallel_packed_layer(): max_num_batched_tokens = 8192 max_batches = 256 - punica_wrapper = get_punica_wrapper(max_num_batched_tokens, - max_batches, - device, - max_loras=max_loras) + with torchax.default_env(): + punica_wrapper = get_punica_wrapper(max_num_batched_tokens, + max_batches, + 'jax', + max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) lora_linear.set_mapping(punica_wrapper) @@ -333,7 +337,8 @@ def create_column_parallel_packed_layer(): with torchax.default_env(): # lora_result = lora_linear(torch.cat(jax_inputs))[0] # lora_result = j2t(lora_result) - lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) + # lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) + lora_result = lora_linear(torch.cat(jax_inputs))[0] expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): @@ -348,17 +353,18 @@ def create_column_parallel_packed_layer(): expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - # with torchax.default_env(): - # torch.testing.assert_close(lora_result.to('cpu'), - # expected_result, - # rtol=rtol, - # atol=atol) - # print( - # f'Output max diff: {torch.max(torch.abs(expected_result.to('cpu') - lora_result))}' - # ) - # print( - # f'Output mean diff: {torch.mean(torch.abs(expected_result.to('cpu') - lora_result))}' - # ) + with torchax.default_env(): + lora_result_cpu = lora_result.to('cpu') + torch.testing.assert_close(lora_result_cpu, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}' + ) # Check that resetting the lora weights succeeds # Here we set all lora weight to be empty. diff --git a/tpu_inference/lora/torch_punica_tpu.py b/tpu_inference/lora/torch_punica_tpu.py index 7c2da5361..658e0f521 100644 --- a/tpu_inference/lora/torch_punica_tpu.py +++ b/tpu_inference/lora/torch_punica_tpu.py @@ -23,6 +23,8 @@ class PunicaWrapperTPU(PunicaWrapperBase): PunicaWrapperTPU is designed to manage and provide metadata for the punica kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the pytorch punica ops. + + It is created by get_punica_wrapper when we load_lora_model->create_lora_manager. Device is TPU. """ def __init__(self, max_num_batched_tokens: int, max_batches: int, From 539c639268e6b0986acd7737d722ef296fd2079d Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 29 Oct 2025 20:37:48 +0000 Subject: [PATCH 05/14] also check if the correct and the sharding is correct. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 7967fb336..9ecadc22d 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -251,6 +251,8 @@ def create_column_parallel_packed_layer(): base_linear.quant_method=linear_method linear_method.process_weights_after_loading(base_linear) # here base_linear.weight is on TPU and sharded. + assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.' 
+ assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.' # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. lora_linear = MergedColumnParallelLinearWithLoRA( @@ -270,6 +272,13 @@ def create_column_parallel_packed_layer(): # create_lora_weights creates global shape weight. lora_linear.create_lora_weights(max_loras, lora_config) _shard_merged_column_parallel_linear_lora(lora_linear, mesh) + # TODO: assert the lora_a_stacked is on TPU and sharded. + assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.' + assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.' + assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.' + assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.' + + # TODO: assert the lora_b_stacked is on TPU and sharded. assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) @@ -324,7 +333,8 @@ def create_column_parallel_packed_layer(): vocab_size=512, extra_vocab_size=lora_config.lora_extra_vocab_size, ) - # punica_wrapper.move_to_device(mesh) + assert jax_view(punica_wrapper._lora_indices_per_batch).platform() == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' + assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' jax_inputs = [] with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): From 90888cb5640bacdc02ae495dc57ef980faef3e38 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 29 Oct 2025 20:58:09 +0000 Subject: [PATCH 06/14] cleaned up Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 72 ++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 43 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 9ecadc22d..62f3ce202 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -1,7 +1,9 @@ import random from typing import Optional +import copy import jax +import numpy as np import pytest import torch import torchax @@ -27,11 +29,10 @@ JaxCommonConfig, JaxCommonLinearConfig) from tpu_inference.layers.vllm.quantization.unquantized import \ VllmUnquantizedLinearMethod -from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora +from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora, _shard_module_to_tpu # TODO(xiowei): # - add test for multi-chip. -# - add equivalent test for ColumnParallelLinearWithShardedLoRA. P = PartitionSpec @@ -56,7 +57,7 @@ def check_punica_wrapper(punica_wrapper) -> bool: def get_random_index_to_id(num_loras: int, num_slots: int, log: bool = True) -> list[Optional[int]]: - """Creates a random index_to_lora_id mapping. + """Creates a random index_to_lora_id mapping: slot[index] = lora_id. Args: num_loras: The number of active loras in the mapping. 
@@ -76,7 +77,7 @@ def get_random_index_to_id(num_loras: int, slots: list[Optional[int]] = [None] * num_slots random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() for lora_id, slot_idx in enumerate(random_slot_selections, start=1): - # xw32: It seems the slot_idx start at 1. + # The slot_idx start at 1. slots[slot_idx] = lora_id if log: @@ -92,7 +93,7 @@ def populate_loras( generate_embeddings_tensor: int = 0, repeats: int = 1, ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]: - """This method populates the lora layers (BaseLayerWithLoRA) with lora weights. + """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA). Args: index_to_id: a list of lora ids. The index of the lora id @@ -140,8 +141,7 @@ def populate_loras( subloras) if repeats > 1 else subloras[0] # Some of the layer.lora is torchax tensor so it can only do math (slice op) in the torchax env. - with torchax.default_env(), jax.default_device( - jax.devices("tpu")[0]): + with torchax.default_env(): lora_layer.set_lora( slot_idx, lora_a=lora.lora_a, @@ -207,11 +207,10 @@ def create_random_inputs( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) @pytest.mark.parametrize("repeats", [2]) -@pytest.mark.parametrize("fully_shard", [False]) # TODO(xiowei): add "True". -@pytest.mark.parametrize("device", ["cpu"]) +@pytest.mark.parametrize("fully_shard", [False]) @pytest.mark.parametrize("stage", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage) -> None: + stage) -> None: max_loras = 9 max_lora_rank = 8 lora_config = LoRAConfig( @@ -228,9 +227,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, mesh = jax.make_mesh(mesh_shape, axis_names, devices=jax.devices()) def create_column_parallel_packed_layer(): - # Step 1: create a base layer (e.g. MergedColumnParallelLinear) and a vLLM LoRA wrapper. + # We first create a base linear layer, then a lora layer to wrap it. if repeats == 2: - # In e2e, MergedColumnParallelLinear is created when we load th e model. The weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. + # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. linear = MergedColumnParallelLinear( 256, # input_size [256] * repeats, # output_size @@ -245,50 +244,42 @@ def create_column_parallel_packed_layer(): params_dtype=torch.float16) base_linear.weight.data = linear.weight.data vllm_config = dist_init - # self.jax_config.mesh.devices[0][0].platform jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) linear_method = VllmUnquantizedLinearMethod(jax_config) base_linear.quant_method=linear_method - linear_method.process_weights_after_loading(base_linear) - # here base_linear.weight is on TPU and sharded. + linear_method.process_weights_after_loading(base_linear) # here base_linear.weight is moved to TPU and sharded. assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.' assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.' - # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. 
lora_linear = MergedColumnParallelLinearWithLoRA( base_linear ) elif repeats == 3: - # TODO(xiowei): add test for this case. raise NotImplementedError("NYI: for MergedQKVParallelLinear case") else: - # TODO(xiowei): add test for this case. raise NotImplementedError("NYI: for QKVParallelLinear case") - n_slices = repeats - #TODO(xw): check if we can enable torchax globally. - # TODO(xw): check if we can calculate both actual and expected output using torchax. with torchax.default_env(): - # create_lora_weights creates global shape weight. + # create_lora_weights creates global shape lora weight. lora_linear.create_lora_weights(max_loras, lora_config) - _shard_merged_column_parallel_linear_lora(lora_linear, mesh) - # TODO: assert the lora_a_stacked is on TPU and sharded. + # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. + _shard_module_to_tpu(lora_linear, mesh) + assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.' assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.' assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.' assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.' - - # TODO: assert the lora_b_stacked is on TPU and sharded. + n_slices = repeats assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) - return linear, lora_linear, linear_method + return linear, lora_linear set_random_seed(6) - linear, lora_linear, linear_method = create_column_parallel_packed_layer() + linear, lora_linear = create_column_parallel_packed_layer() with torchax.default_env(): - # linear.weight has type torch.nn.Parameter, lora_linear.weight has type torchax.tensor.Tensor + # lora_linear.weight has type torchax.tensor.Tensor # BaseLinearLayerWithLoRA.weight property guarantees this. assert torch.equal(linear.weight, lora_linear.weight.to('cpu')) @@ -302,8 +293,7 @@ def create_column_parallel_packed_layer(): assert check_punica_wrapper(punica_wrapper) lora_linear.set_mapping(punica_wrapper) - # load the lora weight, shard it, and send it to TPU. - # create a lora slot index to lora id mapping. + # Populate lora matrices (lora_a and lora_b) in the lora layer. index_to_id = get_random_index_to_id(num_loras, max_loras) # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights lora_dict, sublora_dict = populate_loras( @@ -322,10 +312,11 @@ def create_column_parallel_packed_layer(): input_size=(1, 256), input_range=(0, 1), input_type=torch.float16, - device=device) + device='cpu') lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + with torchax.default_env(): + # Here we move the metadata from cpu to tpu. punica_wrapper.update_metadata( lora_mapping, index_to_id, @@ -337,7 +328,7 @@ def create_column_parallel_packed_layer(): assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' 
jax_inputs = [] - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + with torchax.default_env(): for input in inputs: # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` @@ -345,9 +336,6 @@ def create_column_parallel_packed_layer(): jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) jax_inputs.append(jax_input) with torchax.default_env(): - # lora_result = lora_linear(torch.cat(jax_inputs))[0] - # lora_result = j2t(lora_result) - # lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) lora_result = lora_linear(torch.cat(jax_inputs))[0] expected_results: list[torch.Tensor] = [] @@ -387,10 +375,10 @@ def create_column_parallel_packed_layer(): input_size=(1, 256), input_range=(0, 1), input_type=torch.float16, - device=device) + device='cpu') lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + with torchax.default_env(): punica_wrapper.update_metadata( lora_mapping, index_to_id, @@ -400,16 +388,14 @@ def create_column_parallel_packed_layer(): ) jax_inputs = [] - with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + with torchax.default_env(): for input in inputs: jax_input = torch_view(t2j(input)) jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) jax_inputs.append(jax_input) with torchax.default_env(): - lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs)) - # lora_result = lora_linear(torch.cat(jax_inputs))[0] - # lora_result = j2t(lora_result) + lora_result = lora_linear(torch.cat(jax_inputs))[0] expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] From 223cf150adc3aa194ed8b9864e1ad07ba3aefe88 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 30 Oct 2025 22:19:47 +0000 Subject: [PATCH 07/14] ok, fixed the torchax.view.item() issue. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 14 +++++++------- tpu_inference/lora/torch_punica_tpu.py | 4 +++- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 62f3ce202..eb0c0ce75 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -231,15 +231,15 @@ def create_column_parallel_packed_layer(): if repeats == 2: # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. linear = MergedColumnParallelLinear( - 256, # input_size - [256] * repeats, # output_size + 64, # input_size + [64] * repeats, # output_size bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) base_linear = MergedColumnParallelLinear( - 256, # input_size - [256] * repeats, # output_size + 64, # input_size + [64] * repeats, # output_size bias=False, params_dtype=torch.float16) base_linear.weight.data = linear.weight.data @@ -303,13 +303,13 @@ def create_column_parallel_packed_layer(): repeats=repeats, ) - # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 256]. + # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 64]. 
# index_mapping: list[int] # prompt_mapping: list[int] inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=list(lora_dict.keys()), num_inputs=32, - input_size=(1, 256), + input_size=(1, 64), input_range=(0, 1), input_type=torch.float16, device='cpu') @@ -372,7 +372,7 @@ def create_column_parallel_packed_layer(): inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=[0], # different from the above create_random_inputs num_inputs=32, - input_size=(1, 256), + input_size=(1, 64), input_range=(0, 1), input_type=torch.float16, device='cpu') diff --git a/tpu_inference/lora/torch_punica_tpu.py b/tpu_inference/lora/torch_punica_tpu.py index 658e0f521..ca8554b43 100644 --- a/tpu_inference/lora/torch_punica_tpu.py +++ b/tpu_inference/lora/torch_punica_tpu.py @@ -8,6 +8,8 @@ import torch.nn.functional as F import torchax from vllm.lora.punica_wrapper.utils import convert_mapping +from torchax.interop import jax_view, torch_view + if TYPE_CHECKING: # avoid circuit import @@ -283,7 +285,7 @@ def _update_prefill_metadata(self, self.batch_size = 1 self._lora_indices_per_batch[:self. batch_size] = token_lora_tensor[:self. - batch_size] + batch_size].torch() def _pad_prompt_mapping( self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: From c1f5841a466972f44f15e3a58f6f8d7faf8e0d0b Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 30 Oct 2025 23:18:36 +0000 Subject: [PATCH 08/14] add multi-chip test case Signed-off-by: Xiongfei Wei --- .buildkite/pipeline_jax.yml | 13 +++++++++++++ tests/lora/test_layers.py | 12 +++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/.buildkite/pipeline_jax.yml b/.buildkite/pipeline_jax.yml index ea648b173..2bdecf1be 100644 --- a/.buildkite/pipeline_jax.yml +++ b/.buildkite/pipeline_jax.yml @@ -158,6 +158,7 @@ steps: queue: tpu_v6e_queue commands: - | +<<<<<<< HEAD if [[ "$$NIGHTLY" == "1" ]]; then .buildkite/scripts/run_in_docker.sh \ bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ @@ -166,6 +167,12 @@ steps: echo "Skipping: NIGHTLY environment variable not set" exit 0 fi +======= + .buildkite/scripts/run_in_docker.sh \ + bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ + python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py && \ + python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' +>>>>>>> c17bacea (add multi-chip test case) - label: "E2E MLPerf tests for JAX + vLLM models on multiple chips" key: test_11 @@ -213,6 +220,7 @@ steps: queue: tpu_v6e_8_queue commands: - | +<<<<<<< HEAD if [[ "$$NIGHTLY" == "1" ]]; then .buildkite/scripts/run_in_docker.sh \ bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py' @@ -220,6 +228,11 @@ steps: echo "Skipping: NIGHTLY environment variable not set" exit 0 fi +======= + .buildkite/scripts/run_in_docker.sh \ + bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ + python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' +>>>>>>> c17bacea (add multi-chip test case) - label: "E2E data parallelism test" key: test_14 diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index eb0c0ce75..a6c65331d 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py 
@@ -221,10 +221,13 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ) axis_names = ("data", "model") + devices = jax.devices() mesh_shape = ( - 1, 1 + 1, len(devices) + # 1, 1 ) # TODO(xiowei): support multi-chip: mesh_shape = (1, len(jax.devices())) - mesh = jax.make_mesh(mesh_shape, axis_names, devices=jax.devices()) + print(f'xw32 mesh_shape: {mesh_shape}') + mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices) def create_column_parallel_packed_layer(): # We first create a base linear layer, then a lora layer to wrap it. @@ -281,7 +284,10 @@ def create_column_parallel_packed_layer(): with torchax.default_env(): # lora_linear.weight has type torchax.tensor.Tensor # BaseLinearLayerWithLoRA.weight property guarantees this. - assert torch.equal(linear.weight, lora_linear.weight.to('cpu')) + # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix. + # So the below check will fail. + if len(devices) == 1: + assert torch.equal(linear.weight.data, lora_linear.weight.to('cpu')) max_num_batched_tokens = 8192 max_batches = 256 From 70e1ac2365a5d364fc4ce59206f9841a627c0165 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 30 Oct 2025 23:24:07 +0000 Subject: [PATCH 09/14] fix the format Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 61 +++++++++++++-------- tpu_inference/lora/torch_punica_tpu.py | 5 +- tpu_inference/runner/compilation_manager.py | 1 - 3 files changed, 39 insertions(+), 28 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index a6c65331d..e6a33775b 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -1,16 +1,13 @@ import random from typing import Optional -import copy import jax -import numpy as np import pytest import torch import torchax from jax.sharding import NamedSharding, PartitionSpec -from torchax.interop import torch_view -from torchax.ops.mappings import j2t, t2j from torchax.interop import jax_view, torch_view +from torchax.ops.mappings import t2j # from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu from vllm.config import LoRAConfig # yapf conflicts with isort for this block @@ -24,12 +21,12 @@ from vllm.model_executor.utils import set_random_seed from vllm.platforms import current_platform -from .utils import DummyLoRAManager -from tpu_inference.layers.vllm.quantization.common import ( - JaxCommonConfig, JaxCommonLinearConfig) +from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig from tpu_inference.layers.vllm.quantization.unquantized import \ VllmUnquantizedLinearMethod -from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora, _shard_module_to_tpu +from tpu_inference.layers.vllm.sharding import _shard_module_to_tpu + +from .utils import DummyLoRAManager # TODO(xiowei): # - add test for multi-chip. 
@@ -239,7 +236,7 @@ def create_column_parallel_packed_layer(): bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - + base_linear = MergedColumnParallelLinear( 64, # input_size [64] * repeats, # output_size @@ -249,14 +246,18 @@ def create_column_parallel_packed_layer(): vllm_config = dist_init jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) linear_method = VllmUnquantizedLinearMethod(jax_config) - base_linear.quant_method=linear_method - linear_method.process_weights_after_loading(base_linear) # here base_linear.weight is moved to TPU and sharded. - assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.' - assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.' - - lora_linear = MergedColumnParallelLinearWithLoRA( + base_linear.quant_method = linear_method + linear_method.process_weights_after_loading( base_linear - ) + ) # here base_linear.weight is moved to TPU and sharded. + assert jax_view(base_linear.weight).platform( + ) == 'tpu', 'base_linear.weight should have been moved to TPU.' + assert not isinstance( + jax_view(base_linear.weight).sharding, + jax.sharding.SingleDeviceSharding + ), 'base_linear.weight should have been sharded.' + + lora_linear = MergedColumnParallelLinearWithLoRA(base_linear) elif repeats == 3: raise NotImplementedError("NYI: for MergedQKVParallelLinear case") else: @@ -268,10 +269,16 @@ def create_column_parallel_packed_layer(): # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. _shard_module_to_tpu(lora_linear, mesh) - assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.' - assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.' - assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.' - assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.' + assert jax_view(lora_linear.lora_a_stacked[0]).platform( + ) == 'tpu', 'lora_a_stacked should have been moved to TPU.' + assert not isinstance( + jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding. + SingleDeviceSharding), 'lora_a_stacked should have been sharded.' + assert jax_view(lora_linear.lora_b_stacked[0]).platform( + ) == 'tpu', 'lora_b_stacked should have been moved to TPU.' + assert not isinstance( + jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding. + SingleDeviceSharding), 'lora_b_stacked should have been sharded.' n_slices = repeats assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) @@ -287,7 +294,8 @@ def create_column_parallel_packed_layer(): # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix. # So the below check will fail. 
if len(devices) == 1: - assert torch.equal(linear.weight.data, lora_linear.weight.to('cpu')) + assert torch.equal(linear.weight.data, + lora_linear.weight.to('cpu')) max_num_batched_tokens = 8192 max_batches = 256 @@ -330,8 +338,12 @@ def create_column_parallel_packed_layer(): vocab_size=512, extra_vocab_size=lora_config.lora_extra_vocab_size, ) - assert jax_view(punica_wrapper._lora_indices_per_batch).platform() == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' - assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' + assert jax_view(punica_wrapper._lora_indices_per_batch).platform( + ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' + assert isinstance( + jax_view(punica_wrapper._lora_indices_per_batch).sharding, + jax.sharding.SingleDeviceSharding + ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' jax_inputs = [] with torchax.default_env(): @@ -339,7 +351,8 @@ def create_column_parallel_packed_layer(): # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` jax_input = torch_view(t2j(input)) - jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) + jax_input.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) jax_inputs.append(jax_input) with torchax.default_env(): lora_result = lora_linear(torch.cat(jax_inputs))[0] diff --git a/tpu_inference/lora/torch_punica_tpu.py b/tpu_inference/lora/torch_punica_tpu.py index ca8554b43..a3cc5bf3d 100644 --- a/tpu_inference/lora/torch_punica_tpu.py +++ b/tpu_inference/lora/torch_punica_tpu.py @@ -8,8 +8,6 @@ import torch.nn.functional as F import torchax from vllm.lora.punica_wrapper.utils import convert_mapping -from torchax.interop import jax_view, torch_view - if TYPE_CHECKING: # avoid circuit import @@ -285,7 +283,8 @@ def _update_prefill_metadata(self, self.batch_size = 1 self._lora_indices_per_batch[:self. batch_size] = token_lora_tensor[:self. 
- batch_size].torch() + batch_size].torch( + ) def _pad_prompt_mapping( self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 6225cb4d6..c50c4bc86 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -7,7 +7,6 @@ import numpy as np import vllm.envs as envs from jax.sharding import NamedSharding, PartitionSpec -from vllm.utils.math_utils import cdiv from tpu_inference.core.disagg_utils import is_disagg_enabled from tpu_inference.layers.common.attention_metadata import AttentionMetadata From bdb47e95ec6c22c5d440be8af697f2f7b22465fa Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 31 Oct 2025 18:36:21 +0000 Subject: [PATCH 10/14] clean up Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index e6a33775b..abd072b3f 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -8,7 +8,6 @@ from jax.sharding import NamedSharding, PartitionSpec from torchax.interop import jax_view, torch_view from torchax.ops.mappings import t2j -# from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu from vllm.config import LoRAConfig # yapf conflicts with isort for this block # yapf: disable @@ -28,9 +27,6 @@ from .utils import DummyLoRAManager -# TODO(xiowei): -# - add test for multi-chip. - P = PartitionSpec TOLERANCES = { @@ -204,26 +200,20 @@ def create_random_inputs( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) @pytest.mark.parametrize("repeats", [2]) -@pytest.mark.parametrize("fully_shard", [False]) @pytest.mark.parametrize("stage", [True, False]) -def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - stage) -> None: +def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: max_loras = 9 max_lora_rank = 8 lora_config = LoRAConfig( max_loras=max_loras, max_lora_rank=max_lora_rank, - fully_sharded_loras=fully_shard, + fully_sharded_loras=False, lora_dtype=torch.float16, ) axis_names = ("data", "model") devices = jax.devices() - mesh_shape = ( - 1, len(devices) - # 1, 1 - ) # TODO(xiowei): support multi-chip: mesh_shape = (1, len(jax.devices())) - print(f'xw32 mesh_shape: {mesh_shape}') + mesh_shape = (1, len(devices)) mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices) def create_column_parallel_packed_layer(): @@ -264,7 +254,6 @@ def create_column_parallel_packed_layer(): raise NotImplementedError("NYI: for QKVParallelLinear case") with torchax.default_env(): - # create_lora_weights creates global shape lora weight. lora_linear.create_lora_weights(max_loras, lora_config) # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. 
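For reference, a standalone sketch of the move-and-shard step this comment refers to, assuming `_shard_module_to_tpu` essentially applies something like the following to each parameter. Only calls that already appear in this test are used; the concrete partition spec and tensor shape are assumptions for illustration:

import jax
import torch
import torchax
from jax.sharding import NamedSharding, PartitionSpec as P
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import t2j

devices = jax.devices()
mesh = jax.make_mesh((1, len(devices)), ("data", "model"), devices=devices)

w = torch.rand(128, 64, dtype=torch.float16)  # plain CPU torch tensor;
                                              # 128 assumed divisible by len(devices)
with torchax.default_env():
    w_tx = torch_view(t2j(w))  # torchax view over a JAX array
    # Shard out_features across the "model" axis; P(None, None) would replicate.
    w_tx.apply_jax_(jax.device_put, NamedSharding(mesh, P("model", None)))
    assert not isinstance(jax_view(w_tx).sharding,
                          jax.sharding.SingleDeviceSharding)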
_shard_module_to_tpu(lora_linear, mesh) From 3022e6a44912d0ec1f3fd3124e5d1e0cbe79c00d Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 5 Nov 2025 21:18:51 +0000 Subject: [PATCH 11/14] Add lora unit tests to the CI Signed-off-by: Xiongfei Wei --- .buildkite/pipeline_jax.yml | 47 +++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/.buildkite/pipeline_jax.yml b/.buildkite/pipeline_jax.yml index 2bdecf1be..79787592f 100644 --- a/.buildkite/pipeline_jax.yml +++ b/.buildkite/pipeline_jax.yml @@ -151,28 +151,20 @@ steps: exit 0 fi - - label: "lora tests for JAX + vLLM models single chip" + - label: "lora e2e tests for JAX + vLLM models single chip" key: test_10 soft_fail: true agents: queue: tpu_v6e_queue commands: - | -<<<<<<< HEAD if [[ "$$NIGHTLY" == "1" ]]; then .buildkite/scripts/run_in_docker.sh \ - bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ - python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py' + bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py' else echo "Skipping: NIGHTLY environment variable not set" exit 0 fi -======= - .buildkite/scripts/run_in_docker.sh \ - bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ - python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py && \ - python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' ->>>>>>> c17bacea (add multi-chip test case) - label: "E2E MLPerf tests for JAX + vLLM models on multiple chips" key: test_11 @@ -210,7 +202,7 @@ steps: exit 0 fi - - label: "lora tests for JAX + vLLM models multi chips" + - label: "lora e2e tests for JAX + vLLM models multi chips" key: test_13 soft_fail: true env: @@ -220,7 +212,6 @@ steps: queue: tpu_v6e_8_queue commands: - | -<<<<<<< HEAD if [[ "$$NIGHTLY" == "1" ]]; then .buildkite/scripts/run_in_docker.sh \ bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py' @@ -228,11 +219,6 @@ steps: echo "Skipping: NIGHTLY environment variable not set" exit 0 fi -======= - .buildkite/scripts/run_in_docker.sh \ - bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \ - python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' ->>>>>>> c17bacea (add multi-chip test case) - label: "E2E data parallelism test" key: test_14 @@ -246,6 +232,29 @@ steps: .buildkite/scripts/run_in_docker.sh \ bash -c 'python3 -m pytest -s -v -x /workspace/tpu_inference/tests/e2e/test_data_parallel.py' + - label: "lora unit tests on single chip" + key: test_15 + soft_fail: true + agents: + queue: tpu_v6e_queue + commands: + - | + .buildkite/scripts/run_in_docker.sh \ + bash -c ' python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py && \ + python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' + + - label: "lora unit tests on multi chips" + key: test_16 + soft_fail: true + env: + USE_V6E8_QUEUE: "True" + VLLM_LOG_LEVEL: "INFO" + agents: + queue: tpu_v6e_8_queue + commands: + - | + .buildkite/scripts/run_in_docker.sh \ + bash -c 'python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py' # 
----------------------------------------------------------------- # NOTIFICATION STEP # ----------------------------------------------------------------- @@ -266,9 +275,11 @@ steps: - test_12 - test_13 - test_14 + - test_15 + - test_16 agents: queue: cpu commands: - | .buildkite/scripts/check_results.sh \ - "TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13 + "TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13 test_14 test_15 test_16 From 3f3355bec4c3e6c4c9f25718acca6936c7f36074 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 7 Nov 2025 17:47:17 +0000 Subject: [PATCH 12/14] extract the function _create_column_parallel_packed_layer out of the test. Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 123 ++++++++++++++++++++------------------ 1 file changed, 64 insertions(+), 59 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index abd072b3f..d1c55f3d1 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -210,73 +210,18 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: fully_sharded_loras=False, lora_dtype=torch.float16, ) + vllm_config = dist_init + vllm_config.lora_config = lora_config axis_names = ("data", "model") devices = jax.devices() mesh_shape = (1, len(devices)) mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices) - def create_column_parallel_packed_layer(): - # We first create a base linear layer, then a lora layer to wrap it. - if repeats == 2: - # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. - linear = MergedColumnParallelLinear( - 64, # input_size - [64] * repeats, # output_size - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - - base_linear = MergedColumnParallelLinear( - 64, # input_size - [64] * repeats, # output_size - bias=False, - params_dtype=torch.float16) - base_linear.weight.data = linear.weight.data - vllm_config = dist_init - jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) - linear_method = VllmUnquantizedLinearMethod(jax_config) - base_linear.quant_method = linear_method - linear_method.process_weights_after_loading( - base_linear - ) # here base_linear.weight is moved to TPU and sharded. - assert jax_view(base_linear.weight).platform( - ) == 'tpu', 'base_linear.weight should have been moved to TPU.' - assert not isinstance( - jax_view(base_linear.weight).sharding, - jax.sharding.SingleDeviceSharding - ), 'base_linear.weight should have been sharded.' - - lora_linear = MergedColumnParallelLinearWithLoRA(base_linear) - elif repeats == 3: - raise NotImplementedError("NYI: for MergedQKVParallelLinear case") - else: - raise NotImplementedError("NYI: for QKVParallelLinear case") - - with torchax.default_env(): - lora_linear.create_lora_weights(max_loras, lora_config) - # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. - _shard_module_to_tpu(lora_linear, mesh) - - assert jax_view(lora_linear.lora_a_stacked[0]).platform( - ) == 'tpu', 'lora_a_stacked should have been moved to TPU.' - assert not isinstance( - jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding. - SingleDeviceSharding), 'lora_a_stacked should have been sharded.' 
- assert jax_view(lora_linear.lora_b_stacked[0]).platform( - ) == 'tpu', 'lora_b_stacked should have been moved to TPU.' - assert not isinstance( - jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding. - SingleDeviceSharding), 'lora_b_stacked should have been sharded.' - n_slices = repeats - assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( - lora_linear.lora_b_stacked) == n_slices) - - return linear, lora_linear - set_random_seed(6) - linear, lora_linear = create_column_parallel_packed_layer() + linear, lora_linear = _create_column_parallel_packed_layer( + repeats, vllm_config, mesh) with torchax.default_env(): # lora_linear.weight has type torchax.tensor.Tensor # BaseLinearLayerWithLoRA.weight property guarantees this. @@ -419,3 +364,63 @@ def create_column_parallel_packed_layer(): print( f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}' ) + + +def _create_column_parallel_packed_layer(repeats, vllm_config, mesh): + # We first create a base linear layer, then a lora layer to wrap it. + if repeats == 2: + # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading. + linear = MergedColumnParallelLinear( + 64, # input_size + [64] * repeats, # output_size + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + + base_linear = MergedColumnParallelLinear( + 64, # input_size + [64] * repeats, # output_size + bias=False, + params_dtype=torch.float16) + base_linear.weight.data = linear.weight.data + jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear) + linear_method = VllmUnquantizedLinearMethod(jax_config) + base_linear.quant_method = linear_method + linear_method.process_weights_after_loading( + base_linear + ) # here base_linear.weight is moved to TPU and sharded. + assert jax_view(base_linear.weight).platform( + ) == 'tpu', 'base_linear.weight should have been moved to TPU.' + assert not isinstance( + jax_view( + base_linear.weight).sharding, jax.sharding.SingleDeviceSharding + ), 'base_linear.weight should have been sharded.' + + lora_linear = MergedColumnParallelLinearWithLoRA(base_linear) + elif repeats == 3: + raise NotImplementedError("NYI: for MergedQKVParallelLinear case") + else: + raise NotImplementedError("NYI: for QKVParallelLinear case") + + lora_config = vllm_config.lora_config + max_loras = lora_config.max_loras + with torchax.default_env(): + lora_linear.create_lora_weights(max_loras, lora_config) + # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. + _shard_module_to_tpu(lora_linear, mesh) + + assert jax_view(lora_linear.lora_a_stacked[0]).platform( + ) == 'tpu', 'lora_a_stacked should have been moved to TPU.' + assert not isinstance( + jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding. + SingleDeviceSharding), 'lora_a_stacked should have been sharded.' + assert jax_view(lora_linear.lora_b_stacked[0]).platform( + ) == 'tpu', 'lora_b_stacked should have been moved to TPU.' + assert not isinstance( + jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding. + SingleDeviceSharding), 'lora_b_stacked should have been sharded.' 
+ n_slices = repeats + assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( + lora_linear.lora_b_stacked) == n_slices) + + return linear, lora_linear From fde81d37cd8620330c56311d81b2546e679e64f7 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 7 Nov 2025 18:38:51 +0000 Subject: [PATCH 13/14] reduce the size of the test Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index d1c55f3d1..74c054afd 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -198,7 +198,7 @@ def create_random_inputs( @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) +@pytest.mark.parametrize("num_loras", [1, 4, 9]) @pytest.mark.parametrize("repeats", [2]) @pytest.mark.parametrize("stage", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: From e9a705b917a3e6daac8f9bb726458478b3cdcee4 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 7 Nov 2025 18:53:45 +0000 Subject: [PATCH 14/14] in the middle of refactoring Signed-off-by: Xiongfei Wei --- tests/lora/test_layers.py | 154 +++++++++++++++++++------------------- 1 file changed, 79 insertions(+), 75 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 74c054afd..26e376ef2 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -202,6 +202,8 @@ def create_random_inputs( @pytest.mark.parametrize("repeats", [2]) @pytest.mark.parametrize("stage", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: + set_random_seed(6) + max_loras = 9 max_lora_rank = 8 lora_config = LoRAConfig( @@ -213,24 +215,12 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: vllm_config = dist_init vllm_config.lora_config = lora_config - axis_names = ("data", "model") - devices = jax.devices() - mesh_shape = (1, len(devices)) - mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices) - - set_random_seed(6) - + mesh = _create_mesh() linear, lora_linear = _create_column_parallel_packed_layer( repeats, vllm_config, mesh) - with torchax.default_env(): - # lora_linear.weight has type torchax.tensor.Tensor - # BaseLinearLayerWithLoRA.weight property guarantees this. - # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix. - # So the below check will fail. - if len(devices) == 1: - assert torch.equal(linear.weight.data, - lora_linear.weight.to('cpu')) + _verify_lora_linear_layer(linear, lora_linear) + # Create a punica wrapper and associate it with the lora linear layer. max_num_batched_tokens = 8192 max_batches = 256 with torchax.default_env(): @@ -251,6 +241,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: repeats=repeats, ) + # Create inputs and lora mappings. # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 64]. # index_mapping: list[int] # prompt_mapping: list[int] @@ -261,35 +252,14 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: input_range=(0, 1), input_type=torch.float16, device='cpu') - lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - with torchax.default_env(): - # Here we move the metadata from cpu to tpu. 
- punica_wrapper.update_metadata( - lora_mapping, - index_to_id, - max_loras, - vocab_size=512, - extra_vocab_size=lora_config.lora_extra_vocab_size, - ) - assert jax_view(punica_wrapper._lora_indices_per_batch).platform( - ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' - assert isinstance( - jax_view(punica_wrapper._lora_indices_per_batch).sharding, - jax.sharding.SingleDeviceSharding - ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' + _update_punica_wrapper_metadata(punica_wrapper, index_mapping, + prompt_mapping, stage, index_to_id, + lora_config) - jax_inputs = [] - with torchax.default_env(): - for input in inputs: - # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` - # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` - jax_input = torch_view(t2j(input)) - jax_input.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) - jax_inputs.append(jax_input) with torchax.default_env(): - lora_result = lora_linear(torch.cat(jax_inputs))[0] + torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh) + actual_result = lora_linear(torchax_inputs)[0] expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): @@ -303,19 +273,19 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: expected_results.append(result) expected_result = torch.cat(expected_results) - rtol, atol = TOLERANCES[lora_result.dtype] + rtol, atol = TOLERANCES[actual_result.dtype] with torchax.default_env(): - lora_result_cpu = lora_result.to('cpu') - torch.testing.assert_close(lora_result_cpu, + actual_result_cpu = actual_result.to('cpu') + torch.testing.assert_close(actual_result_cpu, expected_result, rtol=rtol, atol=atol) - print( - f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}' - ) - print( - f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}' - ) + # print( + # f'Output max diff: {torch.max(torch.abs(expected_result - actual_result_cpu))}' + # ) + # print( + # f'Output mean diff: {torch.mean(torch.abs(expected_result - actual_result_cpu))}' + # ) # Check that resetting the lora weights succeeds # Here we set all lora weight to be empty. 
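The expected_results loop above (partly elided in this hunk) is standard LoRA math applied per packed output slice: the base layer output plus a scaled low-rank update x @ A_i^T @ B_i^T written into slice i. A minimal reference sketch, with shapes and the scaling factor as assumptions rather than values taken from the test:

import torch

def packed_lora_reference(x: torch.Tensor, base_weight: torch.Tensor,
                          loras: list[tuple[torch.Tensor, torch.Tensor]],
                          scaling: float = 1.0) -> torch.Tensor:
    # x: [num_tokens, in], base_weight: [num_slices * slice_out, in],
    # loras: one (A_i [rank, in], B_i [slice_out, rank]) pair per output slice.
    y = x @ base_weight.T
    slice_out = base_weight.shape[0] // len(loras)
    for i, (lora_a, lora_b) in enumerate(loras):
        # Low-rank update for slice i, added on top of the base output.
        y[:, i * slice_out:(i + 1) * slice_out] += scaling * (x @ lora_a.T) @ lora_b.T
    return y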
@@ -329,41 +299,75 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: input_range=(0, 1), input_type=torch.float16, device='cpu') - lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - with torchax.default_env(): - punica_wrapper.update_metadata( - lora_mapping, - index_to_id, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) + _update_punica_wrapper_metadata(punica_wrapper, index_mapping, + prompt_mapping, stage, index_to_id, + lora_config) - jax_inputs = [] - with torchax.default_env(): - for input in inputs: - jax_input = torch_view(t2j(input)) - jax_input.apply_jax_(jax.device_put, - NamedSharding(mesh, P(None, None))) - jax_inputs.append(jax_input) with torchax.default_env(): - lora_result = lora_linear(torch.cat(jax_inputs))[0] + torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh) + actual_result = lora_linear(torchax_inputs)[0] expected_result = linear(torch.cat(inputs))[0] - rtol, atol = TOLERANCES[lora_result.dtype] + rtol, atol = TOLERANCES[actual_result.dtype] with torchax.default_env(): - lora_result_cpu = lora_result.to('cpu') - torch.testing.assert_close(lora_result_cpu, + actual_result_cpu = actual_result.to('cpu') + torch.testing.assert_close(actual_result_cpu, expected_result, rtol=rtol, atol=atol) - print( - f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}' - ) - print( - f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}' + + +def _create_mesh(): + axis_names = ("data", "model") + devices = jax.devices() + mesh_shape = (1, len(devices)) + mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices) + return mesh + + +def _verify_lora_linear_layer(linear, lora_linear): + with torchax.default_env(): + # lora_linear.weight has type torchax.tensor.Tensor + # BaseLinearLayerWithLoRA.weight property guarantees this. + # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix. + # So the below check will fail. + if len(jax.devices()) == 1: + assert torch.equal(linear.weight.data, + lora_linear.weight.to('cpu')) + + +def _shard_and_move_inputs_to_tpu(inputs, mesh): + processed_inputs = [] + for input in inputs: + # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` + # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` + jax_input = torch_view(t2j(input)) + jax_input.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + processed_inputs.append(jax_input) + return torch.cat(processed_inputs) + + +def _update_punica_wrapper_metadata(punica_wrapper, index_mapping, + prompt_mapping, stage, index_to_id, + lora_config): + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + with torchax.default_env(): + # Here we move the metadata from cpu to tpu. + punica_wrapper.update_metadata( + lora_mapping, + index_to_id, + lora_config.max_loras, + vocab_size=512, + extra_vocab_size=lora_config.lora_extra_vocab_size, ) + assert jax_view(punica_wrapper._lora_indices_per_batch).platform( + ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' + assert isinstance( + jax_view(punica_wrapper._lora_indices_per_batch).sharding, + jax.sharding.SingleDeviceSharding + ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.' def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
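To close the loop on the metadata helpers above: the punica wrapper ultimately works in terms of slot indices, with `index_to_id` mapping slots to resident LoRA ids and `prompt_mapping` naming the id used by each request. An illustrative sketch of that bookkeeping (names are local to this example, not the wrapper's internals):

from typing import Optional

def lora_id_to_slot(index_to_id: list[Optional[int]], lora_id: int) -> int:
    # index_to_id[slot] == lora_id for resident LoRAs, None for free slots.
    return index_to_id.index(lora_id)

index_to_id = [None, 1, None, 2]  # 4 slots; LoRA ids 1 and 2 are resident
prompt_mapping = [2, 1, 2]        # one LoRA id per request
print([lora_id_to_slot(index_to_id, i) for i in prompt_mapping])  # [3, 1, 3]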