
Commit 3e3f039

Add more lora wrapper unit tests (#1036)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
1 parent 22db072 commit 3e3f039

2 files changed: +279 -40 lines changed

tests/lora/test_layers.py

Lines changed: 259 additions & 35 deletions
@@ -11,12 +11,20 @@
 from vllm.config import LoRAConfig
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping,
-                              MergedColumnParallelLinearWithLoRA)
+from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
+                              LoRAMapping, MergedColumnParallelLinearWithLoRA,
+                              MergedQKVParallelLinearWithLoRA,
+                              QKVParallelLinearWithLoRA,
+                              ReplicatedLinearWithLoRA,
+                              RowParallelLinearWithLoRA)
 # yapf: enable
 from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.punica_wrapper import get_punica_wrapper
-from vllm.model_executor.layers.linear import MergedColumnParallelLinear
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.utils import set_random_seed
 from vllm.platforms import current_platform
 
@@ -199,7 +207,7 @@ def create_random_inputs(
 
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 4, 9])
-@pytest.mark.parametrize("repeats", [2])
+@pytest.mark.parametrize("repeats", [1, 2, 3])
 @pytest.mark.parametrize("stage", [True, False])
 def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
     set_random_seed(6)
@@ -210,7 +218,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         max_loras=max_loras,
         max_lora_rank=max_lora_rank,
         fully_sharded_loras=False,
-        lora_dtype=torch.float16,
+        lora_dtype=torch.bfloat16,
     )
     vllm_config = dist_init
     vllm_config.lora_config = lora_config
@@ -220,6 +228,17 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         repeats, vllm_config, mesh)
     _verify_lora_linear_layer(linear, lora_linear)
 
+    # After we create the lora_config, the linear layer and the lora layer,
+    # here are the steps to do next:
+    # - create a punica wrapper.
+    # - associate the punica wrapper with the lora layer.
+    # - populate the lora matrices in the lora layer: use non-zero values for testing lora and zero values for testing the case where the layer doesn't have lora.
+    # - create inputs and lora_mapping.
+    # - update the metadata of the punica wrapper.
+    # - convert the inputs to be torchax tensors.
+    # - then run a forward on the lora layer to get the actual output.
+    # - then run a reference implementation as the expected output.
+
     # Create a punica wrapper and associate it with the lora linear layer.
     max_num_batched_tokens = 8192
     max_batches = 256
@@ -250,7 +269,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         num_inputs=32,
         input_size=(1, 64),
         input_range=(0, 1),
-        input_type=torch.float16,
+        input_type=torch.bfloat16,
         device='cpu')
 
     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
@@ -297,7 +316,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         num_inputs=32,
         input_size=(1, 64),
         input_range=(0, 1),
-        input_type=torch.float16,
+        input_type=torch.bfloat16,
         device='cpu')
 
     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
@@ -318,6 +337,173 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
                                atol=atol)
 
 
+@torch.inference_mode()
+@pytest.mark.parametrize("num_loras", [1, 4, 9])
+@pytest.mark.parametrize("layer_type", ["row", "column", "replicated"])
+@pytest.mark.parametrize("stage", [True, False])
+def test_linear_parallel(dist_init, num_loras, layer_type, stage) -> None:
+    set_random_seed(6)
+
+    max_loras = 9
+    max_lora_rank = 8
+    lora_config = LoRAConfig(
+        max_loras=max_loras,
+        max_lora_rank=max_lora_rank,
+        fully_sharded_loras=False,
+        lora_dtype=torch.bfloat16,
+    )
+    vllm_config = dist_init
+    vllm_config.lora_config = lora_config
+
+    mesh = _create_mesh()
+    linear, lora_linear = _create_random_linear_parallel_layer(
+        layer_type, vllm_config, mesh)
+    _verify_lora_linear_layer(linear, lora_linear)
+
+    max_num_batched_tokens = 8192
+    max_batches = 256
+    with torchax.default_env():
+        punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
+                                            max_batches,
+                                            'jax',
+                                            max_loras=max_loras)
+    assert check_punica_wrapper(punica_wrapper)
+    lora_linear.set_mapping(punica_wrapper)
+
+    # Populate lora matrices (lora_a and lora_b) in the lora layer.
+    index_to_id = get_random_index_to_id(num_loras, max_loras)
+    # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights
+    lora_dict, sublora_dict = populate_loras(
+        index_to_id,
+        lora_layer=lora_linear,
+        baselayer_weights=linear.weight,
+    )
+
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=list(lora_dict.keys()),
+        num_inputs=32,
+        input_size=(1, 64),
+        input_range=(0, 1),
+        input_type=torch.bfloat16,
+        device='cpu')
+
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)
+
+    with torchax.default_env():
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]
+
+    expected_results: list[torch.Tensor] = []
+    for input_, lora_id in zip(inputs, prompt_mapping):
+        result = linear(input_)[0]
+        lora = lora_dict[lora_id]
+        lora_result = input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
+        result += lora_result
+        expected_results.append(result)
+    expected_result = torch.cat(expected_results)
+
+    rtol, atol = TOLERANCES[actual_result.dtype]
+    with torchax.default_env():
+        actual_result_cpu = actual_result.to('cpu')
+    torch.testing.assert_close(actual_result_cpu,
+                               expected_result,
+                               rtol=rtol,
+                               atol=atol)
+
+    # Check that resetting the lora weights succeeds
+    # Here we set all lora weight to be empty.
+    for slot_idx in range(max_loras):
+        lora_linear.reset_lora(slot_idx)
+
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=[0],  # different from the above create_random_inputs
+        num_inputs=32,
+        input_size=(1, 64),
+        input_range=(0, 1),
+        input_type=torch.bfloat16,
+        device='cpu')
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)
+
+    with torchax.default_env():
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]
+    expected_result = linear(torch.cat(inputs))[0]
+
+    rtol, atol = TOLERANCES[actual_result.dtype]
+    with torchax.default_env():
+        actual_result_cpu = actual_result.to('cpu')
+    torch.testing.assert_close(actual_result_cpu,
+                               expected_result,
+                               rtol=rtol,
+                               atol=atol)
+
+
+def _create_random_linear_parallel_layer(layer_type, vllm_config, mesh):
+    # We first create a base linear layer, then a lora layer to wrap it.
+    if layer_type == "row":
+
+        def _create_row_linear():
+            return RowParallelLinear(
+                64,  # input_size
+                64,  # output_size
+                bias=False,
+                params_dtype=torch.bfloat16)
+
+        linear = _create_row_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_row_linear()
+        lora_linear = _create_lora_wrapper(linear,
+                                           base_linear,
+                                           RowParallelLinearWithLoRA,
+                                           vllm_config=vllm_config,
+                                           mesh=mesh)
+    elif layer_type == "column":
+
+        def _create_column_linear():
+            return ColumnParallelLinear(64,
+                                        64,
+                                        bias=False,
+                                        params_dtype=torch.bfloat16)
+
+        linear = _create_column_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_column_linear()
+        lora_linear = _create_lora_wrapper(linear,
+                                           base_linear,
+                                           ColumnParallelLinearWithLoRA,
+                                           vllm_config=vllm_config,
+                                           mesh=mesh)
+
+    elif layer_type == "replicated":
+
+        def _create_replicated_linear():
+            return ReplicatedLinear(64,
+                                    64,
+                                    bias=False,
+                                    params_dtype=torch.bfloat16)
+
+        linear = _create_replicated_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_replicated_linear()
+        lora_linear = _create_lora_wrapper(linear,
+                                           base_linear,
+                                           ReplicatedLinearWithLoRA,
+                                           vllm_config=vllm_config,
+                                           mesh=mesh)
+
+    else:
+        raise NotImplementedError("Unknown layer type: {}".format(layer_type))
+
+    return linear, lora_linear
+
+
 def _create_mesh():
     axis_names = ("data", "model")
     devices = jax.devices()
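Note: the expected_result in test_linear_parallel above is the standard LoRA reference computation, base(x) + x @ lora_a.T @ lora_b.T * scaling, evaluated on CPU. A minimal standalone sketch of that math in plain torch (the shapes, rank, and scaling values here are illustrative, not taken from the test harness):

    import torch

    torch.manual_seed(0)
    in_features, out_features, rank, scaling = 64, 64, 8, 1.0

    # Stand-in for the base linear layer; the test uses vLLM's parallel layers.
    base = torch.nn.Linear(in_features, out_features, bias=False, dtype=torch.bfloat16)
    lora_a = torch.randn(rank, in_features, dtype=torch.bfloat16)   # A: rank x in
    lora_b = torch.randn(out_features, rank, dtype=torch.bfloat16)  # B: out x rank

    x = torch.rand(1, in_features, dtype=torch.bfloat16)
    # Same formula as the reference loop: base output plus the scaled low-rank update.
    expected = base(x) + x @ lora_a.T @ lora_b.T * scaling
    print(expected.shape)  # torch.Size([1, 64])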
@@ -374,37 +560,75 @@ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
     # We first create a base linear layer, then a lora layer to wrap it.
     if repeats == 2:
         # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading.
-        linear = MergedColumnParallelLinear(
-            64,  # input_size
-            [64] * repeats,  # output_size
-            bias=False,
-            params_dtype=torch.float16)
+        def _create_merged_column_linear():
+            return MergedColumnParallelLinear(
+                64,  # input_size
+                [64] * repeats,  # output_size
+                bias=False,
+                params_dtype=torch.bfloat16)
+
+        linear = _create_merged_column_linear()
         linear.weight.data = torch.rand_like(linear.weight.data)
 
-        base_linear = MergedColumnParallelLinear(
-            64,  # input_size
-            [64] * repeats,  # output_size
-            bias=False,
-            params_dtype=torch.float16)
-        base_linear.weight.data = linear.weight.data
-        jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
-        linear_method = VllmUnquantizedLinearMethod(jax_config)
-        base_linear.quant_method = linear_method
-        linear_method.process_weights_after_loading(
-            base_linear
-        )  # here base_linear.weight is moved to TPU and sharded.
-        assert jax_view(base_linear.weight).platform(
-        ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
-        assert not isinstance(
-            jax_view(
-                base_linear.weight).sharding, jax.sharding.SingleDeviceSharding
-        ), 'base_linear.weight should have been sharded.'
-
-        lora_linear = MergedColumnParallelLinearWithLoRA(base_linear)
+        base_linear = _create_merged_column_linear()
+        lora_linear = _create_lora_wrapper(linear, base_linear,
+                                           MergedColumnParallelLinearWithLoRA,
+                                           vllm_config, mesh, repeats)
     elif repeats == 3:
-        raise NotImplementedError("NYI: for MergedQKVParallelLinear case")
+
+        def _create_qkv_linear():
+            return QKVParallelLinear(64,
+                                     64,
+                                     32,
+                                     bias=False,
+                                     params_dtype=torch.bfloat16)
+
+        linear = _create_qkv_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_qkv_linear()
+        lora_linear = _create_lora_wrapper(linear, base_linear,
+                                           MergedQKVParallelLinearWithLoRA,
+                                           vllm_config, mesh, repeats)
     else:
-        raise NotImplementedError("NYI: for QKVParallelLinear case")
+
+        def _create_qkv_linear():
+            return QKVParallelLinear(64,
+                                     64,
+                                     32,
+                                     bias=False,
+                                     params_dtype=torch.bfloat16)
+
+        linear = _create_qkv_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_qkv_linear()
+        lora_linear = _create_lora_wrapper(linear, base_linear,
+                                           QKVParallelLinearWithLoRA,
+                                           vllm_config, mesh, repeats)
+
+    return linear, lora_linear
+
+
+def _create_lora_wrapper(linear,
+                         base_linear,
+                         lora_cls,
+                         vllm_config,
+                         mesh,
+                         repeats=1):
+    base_linear.weight.data = linear.weight.data
+    jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
+    linear_method = VllmUnquantizedLinearMethod(jax_config)
+    base_linear.quant_method = linear_method
+    linear_method.process_weights_after_loading(
+        base_linear)  # here base_linear.weight is moved to TPU and sharded.
+    assert jax_view(base_linear.weight).platform(
+    ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
+    assert not isinstance(
+        jax_view(base_linear.weight).sharding, jax.sharding.
+        SingleDeviceSharding), 'base_linear.weight should have been sharded.'
+
+    lora_linear = lora_cls(base_linear)
 
     lora_config = vllm_config.lora_config
     max_loras = lora_config.max_loras
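Note: the assertions in _create_lora_wrapper above verify that process_weights_after_loading really moved the base weight to TPU and sharded it across the mesh. A small sketch of that sharding check on a plain JAX array (illustrative only, not the vLLM helper; the mesh axis name is made up):

    import jax
    import jax.numpy as jnp
    import numpy as np

    w = jnp.ones((64, 64))
    # An array left on a single device reports a SingleDeviceSharding ...
    print(isinstance(w.sharding, jax.sharding.SingleDeviceSharding))  # True

    devices = jax.devices()
    if len(devices) > 1 and w.shape[0] % len(devices) == 0:
        # ... while an array distributed over a mesh does not, which is what the
        # "should have been sharded" assertion relies on.
        mesh = jax.sharding.Mesh(np.array(devices), ("model",))
        sharded = jax.device_put(
            w, jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("model")))
        print(isinstance(sharded.sharding, jax.sharding.SingleDeviceSharding))  # False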
@@ -427,4 +651,4 @@ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
     assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
         lora_linear.lora_b_stacked) == n_slices)
 
-    return linear, lora_linear
+    return lora_linear
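To run just the tests touched by this commit, one possible invocation (the file path and test names come from the diff; the command itself is only an example, not part of the commit) uses pytest's -k filter:

    import pytest

    # Select the packed-layer test and the new per-layer-type test by name.
    pytest.main([
        "tests/lora/test_layers.py",
        "-k", "test_column_parallel_packed or test_linear_parallel",
    ])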
