@@ -1,16 +1,13 @@
 import random
 from typing import Optional
-import copy
 
 import jax
-import numpy as np
 import pytest
 import torch
 import torchax
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.interop import torch_view
-from torchax.ops.mappings import j2t, t2j
 from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
 # from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu
 from vllm.config import LoRAConfig
 # yapf conflicts with isort for this block
@@ -24,12 +21,12 @@
 from vllm.model_executor.utils import set_random_seed
 from vllm.platforms import current_platform
 
-from .utils import DummyLoRAManager
-from tpu_inference.layers.vllm.quantization.common import (
-    JaxCommonConfig, JaxCommonLinearConfig)
+from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
-from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora, _shard_module_to_tpu
+from tpu_inference.layers.vllm.sharding import _shard_module_to_tpu
+
+from .utils import DummyLoRAManager
 
 # TODO(xiowei):
 # - add test for multi-chip.
@@ -239,7 +236,7 @@ def create_column_parallel_packed_layer():
                 bias=False,
                 params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
-
+
             base_linear = MergedColumnParallelLinear(
                 64,  # input_size
                 [64] * repeats,  # output_size
@@ -249,14 +246,18 @@ def create_column_parallel_packed_layer():
             vllm_config = dist_init
             jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
             linear_method = VllmUnquantizedLinearMethod(jax_config)
-            base_linear.quant_method = linear_method
-            linear_method.process_weights_after_loading(base_linear)  # here base_linear.weight is moved to TPU and sharded.
-            assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.'
-            assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.'
-
-            lora_linear = MergedColumnParallelLinearWithLoRA(
+            base_linear.quant_method = linear_method
+            linear_method.process_weights_after_loading(
                 base_linear
-            )
+            )  # here base_linear.weight is moved to TPU and sharded.
+            assert jax_view(base_linear.weight).platform(
+            ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
+            assert not isinstance(
+                jax_view(base_linear.weight).sharding,
+                jax.sharding.SingleDeviceSharding
+            ), 'base_linear.weight should have been sharded.'
+
+            lora_linear = MergedColumnParallelLinearWithLoRA(base_linear)
         elif repeats == 3:
             raise NotImplementedError("NYI: for MergedQKVParallelLinear case")
         else:
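
The asserts above reduce to a plain JAX property: once an array has been placed with jax.device_put and a NamedSharding, its .sharding is no longer a SingleDeviceSharding. A minimal sketch of that check, independent of torchax/vLLM (the axis name "x" and the array shape are illustrative only):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec, SingleDeviceSharding

devices = np.array(jax.devices())      # 1 device on CPU, N on a TPU slice
mesh = Mesh(devices, ("x",))           # 1-D mesh over all local devices
w = jnp.ones((64, 128))
# Shard the second dimension across the "x" axis (replicate the first).
w = jax.device_put(w, NamedSharding(mesh, PartitionSpec(None, "x")))
print(type(w.sharding))                              # NamedSharding
print(isinstance(w.sharding, SingleDeviceSharding))  # expected: False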
@@ -268,10 +269,16 @@ def create_column_parallel_packed_layer():
         # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
         _shard_module_to_tpu(lora_linear, mesh)
 
-        assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.'
-        assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
-        assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.'
-        assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
+        assert jax_view(lora_linear.lora_a_stacked[0]).platform(
+        ) == 'tpu', 'lora_a_stacked should have been moved to TPU.'
+        assert not isinstance(
+            jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.
+            SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
+        assert jax_view(lora_linear.lora_b_stacked[0]).platform(
+        ) == 'tpu', 'lora_b_stacked should have been moved to TPU.'
+        assert not isinstance(
+            jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.
+            SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
         n_slices = repeats
         assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
             lora_linear.lora_b_stacked) == n_slices)
@@ -287,7 +294,8 @@ def create_column_parallel_packed_layer():
     # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix.
     # So the below check will fail.
     if len(devices) == 1:
-        assert torch.equal(linear.weight.data, lora_linear.weight.to('cpu'))
+        assert torch.equal(linear.weight.data,
+                           lora_linear.weight.to('cpu'))
 
     max_num_batched_tokens = 8192
     max_batches = 256
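
The len(devices) == 1 guard exists because a sharding-driven reorder of a merged weight permutes rows, which breaks elementwise equality with the original tensor even though no values change. A toy illustration of that effect; the interleaved layout below is an assumption for illustration, and the actual reorder_concatenated_tensor_for_sharding logic may differ:

import torch

# Two merged outputs of 4 rows each (think [64] * 2 in the test, scaled down).
a = torch.arange(8.0).reshape(4, 2)
b = torch.arange(8.0, 16.0).reshape(4, 2)
merged = torch.cat([a, b])                           # original layout [A; B]
# Hypothetical 2-shard layout: each shard holds its slice of A and of B.
reordered = torch.cat([a[:2], b[:2], a[2:], b[2:]])  # [A0; B0; A1; B1]
print(torch.equal(merged, reordered))                # False: same rows, new order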
@@ -330,16 +338,21 @@ def create_column_parallel_packed_layer():
         vocab_size=512,
         extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
-    assert jax_view(punica_wrapper._lora_indices_per_batch).platform() == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
-    assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+    assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
+    ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+    assert isinstance(
+        jax_view(punica_wrapper._lora_indices_per_batch).sharding,
+        jax.sharding.SingleDeviceSharding
+    ), 'punica_wrapper._lora_indices_per_batch should not have been sharded.'
 
     jax_inputs = []
     with torchax.default_env():
         for input in inputs:
             # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'`
             # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'`
             jax_input = torch_view(t2j(input))
-            jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+            jax_input.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
             jax_inputs.append(jax_input)
     with torchax.default_env():
         lora_result = lora_linear(torch.cat(jax_inputs))[0]
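
For reference, the conversion pattern in the loop above, condensed into a standalone sketch. It assumes torchax is installed and uses only the calls exercised by the test: t2j turns a torch tensor into a jax.Array, torch_view wraps it so torch code can hold it, and apply_jax_ with P(None, None) replicates it over the mesh because no dimension is mapped to a mesh axis. Shapes and the mesh axis name are illustrative.

import jax
import numpy as np
import torch
import torchax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from torchax.interop import torch_view
from torchax.ops.mappings import t2j

mesh = Mesh(np.array(jax.devices()), ("x",))
x = torch.randn(8, 64)                   # plain CPU torch tensor
with torchax.default_env():
    jx = torch_view(t2j(x))              # torchax tensor backed by a jax.Array
    # P(None, None): neither dim is bound to a mesh axis, i.e. fully replicated.
    jx.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))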