@@ -202,6 +202,8 @@ def create_random_inputs(
 @pytest.mark.parametrize("repeats", [2])
 @pytest.mark.parametrize("stage", [True, False])
 def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
+    set_random_seed(6)
+
     max_loras = 9
     max_lora_rank = 8
     lora_config = LoRAConfig(
@@ -213,24 +215,12 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
     vllm_config = dist_init
     vllm_config.lora_config = lora_config

-    axis_names = ("data", "model")
-    devices = jax.devices()
-    mesh_shape = (1, len(devices))
-    mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
-
-    set_random_seed(6)
-
+    mesh = _create_mesh()
     linear, lora_linear = _create_column_parallel_packed_layer(
         repeats, vllm_config, mesh)
-    with torchax.default_env():
-        # lora_linear.weight has type torchax.tensor.Tensor
-        # BaseLinearLayerWithLoRA.weight property guarantees this.
-        # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix.
-        # So the below check will fail.
-        if len(devices) == 1:
-            assert torch.equal(linear.weight.data,
-                               lora_linear.weight.to('cpu'))
+    _verify_lora_linear_layer(linear, lora_linear)

+    # Create a punica wrapper and associate it with the lora linear layer.
     max_num_batched_tokens = 8192
     max_batches = 256
     with torchax.default_env():
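Aside (not part of the diff): the new `_create_mesh()` helper and the sharding done later in the test rely on standard JAX mesh machinery. The sketch below shows how a (1, num_devices) mesh with ("data", "model") axes is built and how `P(None, None)` replicates an array across it; the array `x` and its shape are illustrative, not taken from the test.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Build a 2-D device mesh: one "data" slice, all devices along "model".
devices = jax.devices()
mesh = jax.make_mesh((1, len(devices)), ("data", "model"), devices=devices)

# P(None, None) shards neither array dimension, so the array ends up
# replicated on every device of the mesh.
x = jnp.zeros((8, 64), dtype=jnp.bfloat16)
x = jax.device_put(x, NamedSharding(mesh, P(None, None)))
assert x.sharding.is_fully_replicated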
@@ -251,6 +241,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         repeats=repeats,
     )

+    # Create inputs and lora mappings.
     # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request with several tokens and has shape=[num_tokens, 64].
     # index_mapping: list[int]
     # prompt_mapping: list[int]
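Aside (not part of the diff): a hand-built miniature of the structures described in the comments above, to make the shapes concrete. The sizes and LoRA ids here are made up; in the test, `create_random_inputs` generates them randomly.

import torch

num_tokens = 3
# One request with num_tokens tokens of hidden size 64.
inputs = [torch.rand(num_tokens, 64, dtype=torch.float16)]
# Roughly: index_mapping carries one LoRA id per token,
# prompt_mapping carries one LoRA id per request.
index_mapping = [1] * num_tokens
prompt_mapping = [1]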
@@ -261,35 +252,14 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         input_range=(0, 1),
         input_type=torch.float16,
         device='cpu')
-    lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

-    with torchax.default_env():
-        # Here we move the metadata from cpu to tpu.
-        punica_wrapper.update_metadata(
-            lora_mapping,
-            index_to_id,
-            max_loras,
-            vocab_size=512,
-            extra_vocab_size=lora_config.lora_extra_vocab_size,
-        )
-    assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
-    ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
-    assert isinstance(
-        jax_view(punica_wrapper._lora_indices_per_batch).sharding,
-        jax.sharding.SingleDeviceSharding
-    ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)

-    jax_inputs = []
-    with torchax.default_env():
-        for input in inputs:
-            # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'`
-            # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'`
-            jax_input = torch_view(t2j(input))
-            jax_input.apply_jax_(jax.device_put,
-                                 NamedSharding(mesh, P(None, None)))
-            jax_inputs.append(jax_input)
     with torchax.default_env():
-        lora_result = lora_linear(torch.cat(jax_inputs))[0]
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]

     expected_results: list[torch.Tensor] = []
     for input_, lora_id in zip(inputs, prompt_mapping):
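Aside (not part of the diff): the assertions performed by `_update_punica_wrapper_metadata` boil down to checking where a JAX array lives and how it is sharded. A minimal plain-JAX version of those checks, using an illustrative array `arr` instead of the punica wrapper's metadata:

import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

# Place a small array on the first available device (a TPU on a TPU host).
arr = jax.device_put(jnp.arange(8), jax.devices()[0])

# device_put onto a single device yields a SingleDeviceSharding, and every
# jax.Device reports its platform, e.g. 'tpu' or 'cpu'.
assert isinstance(arr.sharding, SingleDeviceSharding)
assert next(iter(arr.devices())).platform in ('tpu', 'cpu')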
@@ -303,19 +273,19 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         expected_results.append(result)
     expected_result = torch.cat(expected_results)

-    rtol, atol = TOLERANCES[lora_result.dtype]
+    rtol, atol = TOLERANCES[actual_result.dtype]
     with torchax.default_env():
-        lora_result_cpu = lora_result.to('cpu')
-        torch.testing.assert_close(lora_result_cpu,
+        actual_result_cpu = actual_result.to('cpu')
+        torch.testing.assert_close(actual_result_cpu,
                                    expected_result,
                                    rtol=rtol,
                                    atol=atol)
-        print(
-            f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}'
-        )
-        print(
-            f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}'
-        )
+        # print(
+        #     f'Output max diff: {torch.max(torch.abs(expected_result - actual_result_cpu))}'
+        # )
+        # print(
+        #     f'Output mean diff: {torch.mean(torch.abs(expected_result - actual_result_cpu))}'
+        # )

     # Check that resetting the lora weights succeeds.
     # Here we set all LoRA weights to be empty.
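Aside (not part of the diff): the comparison above looks up per-dtype tolerances in a `TOLERANCES` table defined elsewhere in the test module. A self-contained sketch of that pattern, with made-up tolerance values:

import torch

# Illustrative values only; the test module defines its own TOLERANCES.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),  # (rtol, atol)
    torch.float32: (1e-5, 1e-5),
}

actual = torch.ones(4, dtype=torch.float16)
expected = actual.clone()
rtol, atol = TOLERANCES[actual.dtype]
torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)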
@@ -329,41 +299,75 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         input_range=(0, 1),
         input_type=torch.float16,
         device='cpu')
-    lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

-    with torchax.default_env():
-        punica_wrapper.update_metadata(
-            lora_mapping,
-            index_to_id,
-            max_loras,
-            512,
-            lora_config.lora_extra_vocab_size,
-        )
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)

-    jax_inputs = []
-    with torchax.default_env():
-        for input in inputs:
-            jax_input = torch_view(t2j(input))
-            jax_input.apply_jax_(jax.device_put,
-                                 NamedSharding(mesh, P(None, None)))
-            jax_inputs.append(jax_input)
     with torchax.default_env():
-        lora_result = lora_linear(torch.cat(jax_inputs))[0]
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]
         expected_result = linear(torch.cat(inputs))[0]

-    rtol, atol = TOLERANCES[lora_result.dtype]
+    rtol, atol = TOLERANCES[actual_result.dtype]
     with torchax.default_env():
-        lora_result_cpu = lora_result.to('cpu')
-        torch.testing.assert_close(lora_result_cpu,
+        actual_result_cpu = actual_result.to('cpu')
+        torch.testing.assert_close(actual_result_cpu,
                                    expected_result,
                                    rtol=rtol,
                                    atol=atol)
-        print(
-            f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}'
-        )
-        print(
-            f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}'
+
+
+def _create_mesh():
+    axis_names = ("data", "model")
+    devices = jax.devices()
+    mesh_shape = (1, len(devices))
+    mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
+    return mesh
+
+
+def _verify_lora_linear_layer(linear, lora_linear):
+    with torchax.default_env():
+        # lora_linear.weight has type torchax.tensor.Tensor;
+        # the BaseLinearLayerWithLoRA.weight property guarantees this.
+        # If len(devices) != 1, the `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix,
+        # so the check below would fail.
+        if len(jax.devices()) == 1:
+            assert torch.equal(linear.weight.data,
+                               lora_linear.weight.to('cpu'))
+
+
+def _shard_and_move_inputs_to_tpu(inputs, mesh):
+    processed_inputs = []
+    for input in inputs:
+        # Without `torch_view`, you get `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'`.
+        # Without `t2j`, you get `AttributeError: 'Tensor' object has no attribute 'apply_jax_'`.
+        jax_input = torch_view(t2j(input))
+        jax_input.apply_jax_(jax.device_put,
+                             NamedSharding(mesh, P(None, None)))
+        processed_inputs.append(jax_input)
+    return torch.cat(processed_inputs)
+
+
+def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config):
+    lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+    with torchax.default_env():
+        # Here we move the metadata from CPU to TPU.
+        punica_wrapper.update_metadata(
+            lora_mapping,
+            index_to_id,
+            lora_config.max_loras,
+            vocab_size=512,
+            extra_vocab_size=lora_config.lora_extra_vocab_size,
         )
+    assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
+    ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+    assert isinstance(
+        jax_view(punica_wrapper._lora_indices_per_batch).sharding,
+        jax.sharding.SingleDeviceSharding
+    ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'


 def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):