@@ -210,73 +210,18 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         fully_sharded_loras=False,
         lora_dtype=torch.float16,
     )
+    vllm_config = dist_init
+    vllm_config.lora_config = lora_config

     axis_names = ("data", "model")
     devices = jax.devices()
     mesh_shape = (1, len(devices))
     mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)

-    def create_column_parallel_packed_layer():
-        # We first create a base linear layer, then a lora layer to wrap it.
-        if repeats == 2:
-            # In e2e, MergedColumnParallelLinear is created when we load the
-            # model. The base_layer weights are sharded and moved to TPU in
-            # VllmUnquantizedLinearMethod.process_weights_after_loading.
-            linear = MergedColumnParallelLinear(
-                64,  # input_size
-                [64] * repeats,  # output_size
-                bias=False,
-                params_dtype=torch.float16)
-            linear.weight.data = torch.rand_like(linear.weight.data)
-
-            base_linear = MergedColumnParallelLinear(
-                64,  # input_size
-                [64] * repeats,  # output_size
-                bias=False,
-                params_dtype=torch.float16)
-            base_linear.weight.data = linear.weight.data
-            vllm_config = dist_init
-            jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
-            linear_method = VllmUnquantizedLinearMethod(jax_config)
-            base_linear.quant_method = linear_method
-            linear_method.process_weights_after_loading(
-                base_linear
-            )  # here base_linear.weight is moved to TPU and sharded.
-            assert jax_view(base_linear.weight).platform(
-            ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
-            assert not isinstance(
-                jax_view(base_linear.weight).sharding,
-                jax.sharding.SingleDeviceSharding
-            ), 'base_linear.weight should have been sharded.'
-
-            lora_linear = MergedColumnParallelLinearWithLoRA(base_linear)
-        elif repeats == 3:
-            raise NotImplementedError("NYI: for MergedQKVParallelLinear case")
-        else:
-            raise NotImplementedError("NYI: for QKVParallelLinear case")
-
-        with torchax.default_env():
-            lora_linear.create_lora_weights(max_loras, lora_config)
-            # In the e2e, the lora_layer's weight is moved to TPU in
-            # _shard_module_to_tpu.
-            _shard_module_to_tpu(lora_linear, mesh)
-
-        assert jax_view(lora_linear.lora_a_stacked[0]).platform(
-        ) == 'tpu', 'lora_a_stacked should have been moved to TPU.'
-        assert not isinstance(
-            jax_view(lora_linear.lora_a_stacked[0]).sharding,
-            jax.sharding.SingleDeviceSharding
-        ), 'lora_a_stacked should have been sharded.'
-        assert jax_view(lora_linear.lora_b_stacked[0]).platform(
-        ) == 'tpu', 'lora_b_stacked should have been moved to TPU.'
-        assert not isinstance(
-            jax_view(lora_linear.lora_b_stacked[0]).sharding,
-            jax.sharding.SingleDeviceSharding
-        ), 'lora_b_stacked should have been sharded.'
-        n_slices = repeats
-        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
-            lora_linear.lora_b_stacked) == n_slices)
-
-        return linear, lora_linear
-
     set_random_seed(6)

-    linear, lora_linear = create_column_parallel_packed_layer()
+    linear, lora_linear = _create_column_parallel_packed_layer(
+        repeats, vllm_config, mesh)
     with torchax.default_env():
         # lora_linear.weight has type torchax.tensor.Tensor
         # BaseLinearLayerWithLoRA.weight property guarantees this.
@@ -419,3 +364,63 @@ def create_column_parallel_packed_layer():
     print(
         f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}'
     )
+
+
+def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
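+    """Build a packed column-parallel base layer plus its LoRA wrapper.
+
+    Only the ``repeats == 2`` (MergedColumnParallelLinear) case is implemented.
+    Returns ``(linear, lora_linear)``: a reference layer that keeps the original
+    float16 torch weights, and a LoRA-wrapped copy whose base and LoRA weights
+    have been sharded and moved to TPU, mirroring the e2e model-loading path.
+    """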
+    # We first create a base linear layer, then a lora layer to wrap it.
+    if repeats == 2:
+        # In e2e, MergedColumnParallelLinear is created when we load the
+        # model. The base_layer weights are sharded and moved to TPU in
+        # VllmUnquantizedLinearMethod.process_weights_after_loading.
+        linear = MergedColumnParallelLinear(
+            64,  # input_size
+            [64] * repeats,  # output_size
+            bias=False,
+            params_dtype=torch.float16)
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = MergedColumnParallelLinear(
+            64,  # input_size
+            [64] * repeats,  # output_size
+            bias=False,
+            params_dtype=torch.float16)
+        base_linear.weight.data = linear.weight.data
+        jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
+        linear_method = VllmUnquantizedLinearMethod(jax_config)
+        base_linear.quant_method = linear_method
+        linear_method.process_weights_after_loading(
+            base_linear
+        )  # here base_linear.weight is moved to TPU and sharded.
+        assert jax_view(base_linear.weight).platform(
+        ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
+        assert not isinstance(
+            jax_view(base_linear.weight).sharding,
+            jax.sharding.SingleDeviceSharding
+        ), 'base_linear.weight should have been sharded.'
+
+        lora_linear = MergedColumnParallelLinearWithLoRA(base_linear)
+    elif repeats == 3:
+        raise NotImplementedError("NYI: for MergedQKVParallelLinear case")
+    else:
+        raise NotImplementedError("NYI: for QKVParallelLinear case")
+
+    lora_config = vllm_config.lora_config
+    max_loras = lora_config.max_loras
+    with torchax.default_env():
+        lora_linear.create_lora_weights(max_loras, lora_config)
+        # In the e2e, the lora_layer's weight is moved to TPU in
+        # _shard_module_to_tpu.
+        _shard_module_to_tpu(lora_linear, mesh)
+
+    assert jax_view(lora_linear.lora_a_stacked[0]).platform(
+    ) == 'tpu', 'lora_a_stacked should have been moved to TPU.'
+    assert not isinstance(
+        jax_view(lora_linear.lora_a_stacked[0]).sharding,
+        jax.sharding.SingleDeviceSharding
+    ), 'lora_a_stacked should have been sharded.'
+    assert jax_view(lora_linear.lora_b_stacked[0]).platform(
+    ) == 'tpu', 'lora_b_stacked should have been moved to TPU.'
+    assert not isinstance(
+        jax_view(lora_linear.lora_b_stacked[0]).sharding,
+        jax.sharding.SingleDeviceSharding
+    ), 'lora_b_stacked should have been sharded.'
+    n_slices = repeats
+    assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
+        lora_linear.lora_b_stacked) == n_slices)
+
+    return linear, lora_linear
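
For reference, a minimal sketch of how the newly extracted helper could be reused from another parametrized test. The dist_init fixture, the mesh construction, and the helper signature come from the diff above; the test name, the concrete LoRAConfig values, and the `from vllm.config import LoRAConfig` import path are illustrative assumptions, not part of this commit.

import jax
import pytest
import torch

from vllm.config import LoRAConfig


@pytest.mark.parametrize("repeats", [2])
def test_packed_layer_helper_builds_lora_wrapper(dist_init, repeats):
    # Attach a LoRA config to the shared vllm_config fixture, as the test
    # above does; the specific values here are placeholders.
    vllm_config = dist_init
    vllm_config.lora_config = LoRAConfig(max_loras=2,
                                         fully_sharded_loras=False,
                                         lora_dtype=torch.float16)

    # Build the same (1, num_devices) mesh the existing test uses.
    devices = jax.devices()
    mesh = jax.make_mesh((1, len(devices)), ("data", "model"),
                         devices=devices)

    linear, lora_linear = _create_column_parallel_packed_layer(
        repeats, vllm_config, mesh)

    # One LoRA slice per packed sub-layer; the helper itself asserts that the
    # base and LoRA weights were sharded onto TPU.
    assert lora_linear.n_slices == repeats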