 from vllm.config import LoRAConfig
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping,
-                              MergedColumnParallelLinearWithLoRA)
+from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
+                              LoRAMapping, MergedColumnParallelLinearWithLoRA,
+                              MergedQKVParallelLinearWithLoRA,
+                              QKVParallelLinearWithLoRA,
+                              ReplicatedLinearWithLoRA,
+                              RowParallelLinearWithLoRA)
 # yapf: enable
 from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.punica_wrapper import get_punica_wrapper
-from vllm.model_executor.layers.linear import MergedColumnParallelLinear
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.utils import set_random_seed
 from vllm.platforms import current_platform
 
@@ -199,7 +207,7 @@ def create_random_inputs(
 
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 4, 9])
-@pytest.mark.parametrize("repeats", [2])
+@pytest.mark.parametrize("repeats", [1, 2, 3])
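+# repeats=2 covers MergedColumnParallelLinearWithLoRA, repeats=3 covers
+# MergedQKVParallelLinearWithLoRA, and repeats=1 covers
+# QKVParallelLinearWithLoRA (see _create_column_parallel_packed_layer below).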
 @pytest.mark.parametrize("stage", [True, False])
 def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
     set_random_seed(6)
@@ -210,7 +218,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         max_loras=max_loras,
         max_lora_rank=max_lora_rank,
         fully_sharded_loras=False,
-        lora_dtype=torch.float16,
+        lora_dtype=torch.bfloat16,
     )
     vllm_config = dist_init
     vllm_config.lora_config = lora_config
@@ -220,6 +228,17 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         repeats, vllm_config, mesh)
     _verify_lora_linear_layer(linear, lora_linear)
 
+    # Having created the lora_config, the base linear layer and the lora
+    # layer, the remaining steps are:
+    # - create a punica wrapper.
+    # - associate the punica wrapper with the lora layer.
+    # - populate the lora matrices in the lora layer: non-zero values to test
+    #   the lora path, zero values to test the case where a layer has no lora.
+    # - create the inputs and the lora_mapping.
+    # - update the metadata of the punica wrapper.
+    # - convert the inputs to torchax tensors.
+    # - run a forward pass on the lora layer to get the actual output.
+    # - run a reference implementation to get the expected output.
+
     # Create a punica wrapper and associate it with the lora linear layer.
     max_num_batched_tokens = 8192
     max_batches = 256
@@ -250,7 +269,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         num_inputs=32,
         input_size=(1, 64),
         input_range=(0, 1),
-        input_type=torch.float16,
+        input_type=torch.bfloat16,
         device='cpu')
 
     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
@@ -297,7 +316,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         num_inputs=32,
         input_size=(1, 64),
         input_range=(0, 1),
-        input_type=torch.float16,
+        input_type=torch.bfloat16,
         device='cpu')
 
     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
@@ -318,6 +337,173 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
                                    atol=atol)
 
 
+@torch.inference_mode()
+@pytest.mark.parametrize("num_loras", [1, 4, 9])
+@pytest.mark.parametrize("layer_type", ["row", "column", "replicated"])
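+# layer_type selects the base layer under test: RowParallelLinear,
+# ColumnParallelLinear or ReplicatedLinear, wrapped by the matching LoRA
+# layer (see _create_random_linear_parallel_layer below).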
+@pytest.mark.parametrize("stage", [True, False])
+def test_linear_parallel(dist_init, num_loras, layer_type, stage) -> None:
+    set_random_seed(6)
+
+    max_loras = 9
+    max_lora_rank = 8
+    lora_config = LoRAConfig(
+        max_loras=max_loras,
+        max_lora_rank=max_lora_rank,
+        fully_sharded_loras=False,
+        lora_dtype=torch.bfloat16,
+    )
+    vllm_config = dist_init
+    vllm_config.lora_config = lora_config
+
+    mesh = _create_mesh()
+    linear, lora_linear = _create_random_linear_parallel_layer(
+        layer_type, vllm_config, mesh)
+    _verify_lora_linear_layer(linear, lora_linear)
+
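+    # Create a punica wrapper and associate it with the lora linear layer.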
+    max_num_batched_tokens = 8192
+    max_batches = 256
+    with torchax.default_env():
+        punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
+                                            max_batches,
+                                            'jax',
+                                            max_loras=max_loras)
+    assert check_punica_wrapper(punica_wrapper)
+    lora_linear.set_mapping(punica_wrapper)
+
+    # Populate lora matrices (lora_a and lora_b) in the lora layer.
+    index_to_id = get_random_index_to_id(num_loras, max_loras)
+    # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights
+    lora_dict, sublora_dict = populate_loras(
+        index_to_id,
+        lora_layer=lora_linear,
+        baselayer_weights=linear.weight,
+    )
+
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=list(lora_dict.keys()),
+        num_inputs=32,
+        input_size=(1, 64),
+        input_range=(0, 1),
+        input_type=torch.bfloat16,
+        device='cpu')
+
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)
+
+    with torchax.default_env():
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]
+
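+    # Reference implementation on CPU: apply the base linear layer, then add
+    # each input's lora delta (x @ lora_a.T @ lora_b.T * scaling).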
+    expected_results: list[torch.Tensor] = []
+    for input_, lora_id in zip(inputs, prompt_mapping):
+        result = linear(input_)[0]
+        lora = lora_dict[lora_id]
+        lora_result = input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
+        result += lora_result
+        expected_results.append(result)
+    expected_result = torch.cat(expected_results)
+
+    rtol, atol = TOLERANCES[actual_result.dtype]
+    with torchax.default_env():
+        actual_result_cpu = actual_result.to('cpu')
+        torch.testing.assert_close(actual_result_cpu,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)
+
+    # Check that resetting the lora weights succeeds.
+    # Here we set all lora weights to be empty.
+    for slot_idx in range(max_loras):
+        lora_linear.reset_lora(slot_idx)
+
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=[0],  # different from the above create_random_inputs
+        num_inputs=32,
+        input_size=(1, 64),
+        input_range=(0, 1),
+        input_type=torch.bfloat16,
+        device='cpu')
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)
+
+    with torchax.default_env():
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]
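+    # With every lora slot reset, the lora layer should match the plain base
+    # linear layer.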
+    expected_result = linear(torch.cat(inputs))[0]
+
+    rtol, atol = TOLERANCES[actual_result.dtype]
+    with torchax.default_env():
+        actual_result_cpu = actual_result.to('cpu')
+        torch.testing.assert_close(actual_result_cpu,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)
+
+
+def _create_random_linear_parallel_layer(layer_type, vllm_config, mesh):
+    # We first create a base linear layer, then a lora layer to wrap it.
+    if layer_type == "row":
+
+        def _create_row_linear():
+            return RowParallelLinear(
+                64,  # input_size
+                64,  # output_size
+                bias=False,
+                params_dtype=torch.bfloat16)
+
+        linear = _create_row_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_row_linear()
+        lora_linear = _create_lora_wrapper(linear,
+                                           base_linear,
+                                           RowParallelLinearWithLoRA,
+                                           vllm_config=vllm_config,
+                                           mesh=mesh)
+    elif layer_type == "column":
+
+        def _create_column_linear():
+            return ColumnParallelLinear(64,
+                                        64,
+                                        bias=False,
+                                        params_dtype=torch.bfloat16)
+
+        linear = _create_column_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_column_linear()
+        lora_linear = _create_lora_wrapper(linear,
+                                           base_linear,
+                                           ColumnParallelLinearWithLoRA,
+                                           vllm_config=vllm_config,
+                                           mesh=mesh)
+
+    elif layer_type == "replicated":
+
+        def _create_replicated_linear():
+            return ReplicatedLinear(64,
+                                    64,
+                                    bias=False,
+                                    params_dtype=torch.bfloat16)
+
+        linear = _create_replicated_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_replicated_linear()
+        lora_linear = _create_lora_wrapper(linear,
+                                           base_linear,
+                                           ReplicatedLinearWithLoRA,
+                                           vllm_config=vllm_config,
+                                           mesh=mesh)
+
+    else:
+        raise NotImplementedError("Unknown layer type: {}".format(layer_type))
+
+    return linear, lora_linear
+
+
 def _create_mesh():
     axis_names = ("data", "model")
     devices = jax.devices()
@@ -374,37 +560,75 @@ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
     # We first create a base linear layer, then a lora layer to wrap it.
     if repeats == 2:
         # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading.
-        linear = MergedColumnParallelLinear(
-            64,  # input_size
-            [64] * repeats,  # output_size
-            bias=False,
-            params_dtype=torch.float16)
+        def _create_merged_column_linear():
+            return MergedColumnParallelLinear(
+                64,  # input_size
+                [64] * repeats,  # output_size
+                bias=False,
+                params_dtype=torch.bfloat16)
+
+        linear = _create_merged_column_linear()
         linear.weight.data = torch.rand_like(linear.weight.data)
 
-        base_linear = MergedColumnParallelLinear(
-            64,  # input_size
-            [64] * repeats,  # output_size
-            bias=False,
-            params_dtype=torch.float16)
-        base_linear.weight.data = linear.weight.data
-        jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
-        linear_method = VllmUnquantizedLinearMethod(jax_config)
-        base_linear.quant_method = linear_method
-        linear_method.process_weights_after_loading(
-            base_linear
-        )  # here base_linear.weight is moved to TPU and sharded.
-        assert jax_view(base_linear.weight).platform(
-        ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
-        assert not isinstance(
-            jax_view(
-                base_linear.weight).sharding, jax.sharding.SingleDeviceSharding
-        ), 'base_linear.weight should have been sharded.'
-
-        lora_linear = MergedColumnParallelLinearWithLoRA(base_linear)
+        base_linear = _create_merged_column_linear()
+        lora_linear = _create_lora_wrapper(linear, base_linear,
+                                           MergedColumnParallelLinearWithLoRA,
+                                           vllm_config, mesh, repeats)
     elif repeats == 3:
-        raise NotImplementedError("NYI: for MergedQKVParallelLinear case")
+
+        def _create_qkv_linear():
+            return QKVParallelLinear(64,
+                                     64,
+                                     32,
+                                     bias=False,
+                                     params_dtype=torch.bfloat16)
+
+        linear = _create_qkv_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_qkv_linear()
+        lora_linear = _create_lora_wrapper(linear, base_linear,
+                                           MergedQKVParallelLinearWithLoRA,
+                                           vllm_config, mesh, repeats)
     else:
-        raise NotImplementedError("NYI: for QKVParallelLinear case")
+
+        def _create_qkv_linear():
+            return QKVParallelLinear(64,
+                                     64,
+                                     32,
+                                     bias=False,
+                                     params_dtype=torch.bfloat16)
+
+        linear = _create_qkv_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_qkv_linear()
+        lora_linear = _create_lora_wrapper(linear, base_linear,
+                                           QKVParallelLinearWithLoRA,
+                                           vllm_config, mesh, repeats)
+
+    return linear, lora_linear
+
+
+def _create_lora_wrapper(linear,
+                         base_linear,
+                         lora_cls,
+                         vllm_config,
+                         mesh,
+                         repeats=1):
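+    """Wrap base_linear in the given lora_cls.
+
+    base_linear receives the same weights as linear; those weights are then
+    sharded and moved to TPU by the VllmUnquantizedLinearMethod before the
+    LoRA wrapper is created around base_linear.
+    """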
+    base_linear.weight.data = linear.weight.data
+    jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
+    linear_method = VllmUnquantizedLinearMethod(jax_config)
+    base_linear.quant_method = linear_method
+    linear_method.process_weights_after_loading(
+        base_linear)  # here base_linear.weight is moved to TPU and sharded.
+    assert jax_view(base_linear.weight).platform(
+    ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
+    assert not isinstance(
+        jax_view(base_linear.weight).sharding, jax.sharding.
+        SingleDeviceSharding), 'base_linear.weight should have been sharded.'
+
+    lora_linear = lora_cls(base_linear)
 
     lora_config = vllm_config.lora_config
     max_loras = lora_config.max_loras
@@ -427,4 +651,4 @@ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
     assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
         lora_linear.lora_b_stacked) == n_slices)
 
-    return linear, lora_linear
+    return lora_linear