@@ -1,7 +1,9 @@
 import random
 from typing import Optional
+import copy
 
 import jax
+import numpy as np
 import pytest
 import torch
 import torchax
@@ -27,11 +29,10 @@
     JaxCommonConfig, JaxCommonLinearConfig)
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
-from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora
+from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora, _shard_module_to_tpu
 
 # TODO(xiowei):
 # - add test for multi-chip.
-# - add equivalent test for ColumnParallelLinearWithShardedLoRA.
 
 P = PartitionSpec
 
@@ -56,7 +57,7 @@ def check_punica_wrapper(punica_wrapper) -> bool:
 def get_random_index_to_id(num_loras: int,
                            num_slots: int,
                            log: bool = True) -> list[Optional[int]]:
-    """Creates a random index_to_lora_id mapping.
+    """Creates a random index_to_lora_id mapping: slot[index] = lora_id.
 
     Args:
         num_loras: The number of active loras in the mapping.
@@ -76,7 +77,7 @@ def get_random_index_to_id(num_loras: int,
     slots: list[Optional[int]] = [None] * num_slots
     random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
     for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
-        # xw32: It seems the slot_idx start at 1.
+        # The lora_id starts at 1.
         slots[slot_idx] = lora_id
 
     if log:
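
Note: the mapping this helper returns is just a fixed-length list of slots with lora ids scattered into random positions. A minimal, self-contained sketch for orientation (illustrative only, not part of the change):

    # Illustrative sketch of the index_to_lora_id mapping built above:
    # lora ids 1..num_loras land in random slots, unused slots stay None.
    from typing import Optional

    import torch

    def sketch_index_to_id(num_loras: int, num_slots: int) -> list[Optional[int]]:
        slots: list[Optional[int]] = [None] * num_slots
        chosen = torch.randperm(num_slots)[:num_loras].tolist()
        for lora_id, slot_idx in enumerate(chosen, start=1):
            slots[slot_idx] = lora_id
        return slots

    # e.g. sketch_index_to_id(2, 4) could return [None, 2, 1, None]
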
@@ -92,7 +93,7 @@ def populate_loras(
     generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
-    """This method populates the lora layers (BaseLayerWithLoRA) with lora weights.
+    """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
 
     Args:
         index_to_id: a list of lora ids. The index of the lora id
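
Note: the lora_a/lora_b pairs populated here feed a standard low-rank update on top of the base linear layer's output. A hedged numpy sketch of that per-slice contribution; the [input_dim, rank] x [rank, output_dim] orientation and the scaling factor are assumptions for illustration, not taken from this diff:

    # Hedged sketch of a low-rank (LoRA) update; shapes and scaling are
    # assumed for illustration only.
    import numpy as np

    def lora_delta(x, lora_a, lora_b, scaling=1.0):
        # x: [tokens, input_dim], lora_a: [input_dim, rank], lora_b: [rank, output_dim]
        return (x @ lora_a @ lora_b) * scaling

    x = np.random.randn(1, 256).astype(np.float16)
    a = np.random.randn(256, 8).astype(np.float16)  # rank 8, matching max_lora_rank below
    b = np.random.randn(8, 256).astype(np.float16)
    delta = lora_delta(x, a, b)  # [1, 256], added to the base layer's output
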
@@ -140,8 +141,7 @@ def populate_loras(
             subloras) if repeats > 1 else subloras[0]
 
         # Some of the layer.lora is torchax tensor so it can only do math (slice op) in the torchax env.
-        with torchax.default_env(), jax.default_device(
-                jax.devices("tpu")[0]):
+        with torchax.default_env():
             lora_layer.set_lora(
                 slot_idx,
                 lora_a=lora.lora_a,
@@ -207,11 +207,10 @@ def create_random_inputs(
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 9])
 @pytest.mark.parametrize("repeats", [2])
-@pytest.mark.parametrize("fully_shard", [False])  # TODO(xiowei): add "True".
-@pytest.mark.parametrize("device", ["cpu"])
+@pytest.mark.parametrize("fully_shard", [False])
 @pytest.mark.parametrize("stage", [True, False])
 def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
-                                device, stage) -> None:
+                                stage) -> None:
     max_loras = 9
     max_lora_rank = 8
     lora_config = LoRAConfig(
@@ -228,9 +227,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
     mesh = jax.make_mesh(mesh_shape, axis_names, devices=jax.devices())
 
     def create_column_parallel_packed_layer():
-        # Step 1: create a base layer (e.g. MergedColumnParallelLinear) and a vLLM LoRA wrapper.
+        # We first create a base linear layer, then a lora layer to wrap it.
         if repeats == 2:
-            # In e2e, MergedColumnParallelLinear is created when we load th e model. The weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading.
+            # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading.
             linear = MergedColumnParallelLinear(
                 256,  # input_size
                 [256] * repeats,  # output_size
@@ -245,50 +244,42 @@ def create_column_parallel_packed_layer():
                 params_dtype=torch.float16)
             base_linear.weight.data = linear.weight.data
             vllm_config = dist_init
-            # self.jax_config.mesh.devices[0][0].platform
             jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
             linear_method = VllmUnquantizedLinearMethod(jax_config)
             base_linear.quant_method = linear_method
-            linear_method.process_weights_after_loading(base_linear)
-            # here base_linear.weight is on TPU and sharded.
+            linear_method.process_weights_after_loading(base_linear)  # here base_linear.weight is moved to TPU and sharded.
             assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.'
             assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.'
 
-            # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
             lora_linear = MergedColumnParallelLinearWithLoRA(
                 base_linear
             )
         elif repeats == 3:
-            # TODO(xiowei): add test for this case.
             raise NotImplementedError("NYI: for MergedQKVParallelLinear case")
         else:
-            # TODO(xiowei): add test for this case.
             raise NotImplementedError("NYI: for QKVParallelLinear case")
 
-        n_slices = repeats
-        #TODO(xw): check if we can enable torchax globally.
-        # TODO(xw): check if we can calculate both actual and expected output using torchax.
         with torchax.default_env():
-            # create_lora_weights creates global shape weight.
+            # create_lora_weights creates the lora weights in global (unsharded) shape.
             lora_linear.create_lora_weights(max_loras, lora_config)
-            _shard_merged_column_parallel_linear_lora(lora_linear, mesh)
-            # TODO: assert the lora_a_stacked is on TPU and sharded.
+            # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
+            _shard_module_to_tpu(lora_linear, mesh)
+
             assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.'
             assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
             assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.'
             assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
-
-            # TODO: assert the lora_b_stacked is on TPU and sharded.
+            n_slices = repeats
             assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
                 lora_linear.lora_b_stacked) == n_slices)
 
-        return linear, lora_linear, linear_method
+        return linear, lora_linear
 
     set_random_seed(6)
 
-    linear, lora_linear, linear_method = create_column_parallel_packed_layer()
+    linear, lora_linear = create_column_parallel_packed_layer()
     with torchax.default_env():
-        # linear.weight has type torch.nn.Parameter, lora_linear.weight has type torchax.tensor.Tensor
+        # lora_linear.weight has type torchax.tensor.Tensor
         # BaseLinearLayerWithLoRA.weight property guarantees this.
         assert torch.equal(linear.weight, lora_linear.weight.to('cpu'))
 
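
Note: the platform()/sharding assertions in this hunk follow one repeated pattern. A small helper in that spirit is sketched below; the helper name and the jax_view import path are our own assumptions, not part of tpu_inference:

    # Hypothetical helper mirroring the assertions above: view the torchax
    # tensor as a jax array, then check TPU placement and a non-trivial
    # (non-single-device) sharding. Name and import path are assumptions.
    import jax
    from torchax.interop import jax_view  # assumed location of jax_view

    def assert_on_tpu_and_sharded(t, name: str) -> None:
        arr = jax_view(t)
        assert arr.platform() == 'tpu', f'{name} should have been moved to TPU.'
        assert not isinstance(arr.sharding, jax.sharding.SingleDeviceSharding), (
            f'{name} should have been sharded.')
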
@@ -302,8 +293,7 @@ def create_column_parallel_packed_layer():
     assert check_punica_wrapper(punica_wrapper)
     lora_linear.set_mapping(punica_wrapper)
 
-    # load the lora weight, shard it, and send it to TPU.
-    # create a lora slot index to lora id mapping.
+    # Populate lora matrices (lora_a and lora_b) in the lora layer.
     index_to_id = get_random_index_to_id(num_loras, max_loras)
     # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights
     lora_dict, sublora_dict = populate_loras(
@@ -322,10 +312,11 @@ def create_column_parallel_packed_layer():
         input_size=(1, 256),
         input_range=(0, 1),
         input_type=torch.float16,
-        device=device)
+        device='cpu')
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
 
-    with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
+    with torchax.default_env():
+        # Here we move the metadata from CPU to TPU.
         punica_wrapper.update_metadata(
             lora_mapping,
             index_to_id,
@@ -337,17 +328,14 @@ def create_column_parallel_packed_layer():
     assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
 
     jax_inputs = []
-    with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
+    with torchax.default_env():
         for input in inputs:
             # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'`
             # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'`
             jax_input = torch_view(t2j(input))
             jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
             jax_inputs.append(jax_input)
     with torchax.default_env():
-        # lora_result = lora_linear(torch.cat(jax_inputs))[0]
-        # lora_result = j2t(lora_result)
-        # lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs))
         lora_result = lora_linear(torch.cat(jax_inputs))[0]
 
     expected_results: list[torch.Tensor] = []
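
Note: the input-preparation loop above amounts to a small recipe: convert the CPU torch tensor to a jax array (t2j), wrap it back as a torchax tensor (torch_view), then device_put it with a replicated NamedSharding. A hedged standalone sketch; the import locations of t2j and torch_view are assumptions:

    # Hedged sketch of the CPU-torch -> sharded-torchax conversion used above;
    # import locations are assumptions.
    import jax
    import torch
    import torchax
    from jax.sharding import NamedSharding, PartitionSpec
    from torchax.interop import torch_view  # assumed location
    from torchax.ops.mappings import t2j  # assumed location

    def to_replicated_tpu(x: torch.Tensor, mesh: jax.sharding.Mesh):
        with torchax.default_env():
            jx = torch_view(t2j(x))  # torchax tensor backed by a jax array
            jx.apply_jax_(jax.device_put,
                          NamedSharding(mesh, PartitionSpec(None, None)))  # replicate
            return jx
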
@@ -387,10 +375,10 @@ def create_column_parallel_packed_layer():
         input_size=(1, 256),
         input_range=(0, 1),
         input_type=torch.float16,
-        device=device)
+        device='cpu')
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
 
-    with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
+    with torchax.default_env():
         punica_wrapper.update_metadata(
             lora_mapping,
             index_to_id,
@@ -400,16 +388,14 @@ def create_column_parallel_packed_layer():
     )
 
     jax_inputs = []
-    with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
+    with torchax.default_env():
         for input in inputs:
             jax_input = torch_view(t2j(input))
             jax_input.apply_jax_(jax.device_put,
                                  NamedSharding(mesh, P(None, None)))
             jax_inputs.append(jax_input)
     with torchax.default_env():
-        lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs))
-        # lora_result = lora_linear(torch.cat(jax_inputs))[0]
-        # lora_result = j2t(lora_result)
+        lora_result = lora_linear(torch.cat(jax_inputs))[0]
     expected_result = linear(torch.cat(inputs))[0]
 
     rtol, atol = TOLERANCES[lora_result.dtype]