260 changes: 225 additions & 35 deletions tests/lora/test_layers.py
@@ -11,12 +11,18 @@
from vllm.config import LoRAConfig
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping,
MergedColumnParallelLinearWithLoRA)
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LoRAMapping, MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA)
# yapf: enable
from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import current_platform

@@ -199,7 +205,7 @@ def create_random_inputs(

@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 4, 9])
@pytest.mark.parametrize("repeats", [2])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("stage", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
set_random_seed(6)
@@ -210,7 +216,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
max_loras=max_loras,
max_lora_rank=max_lora_rank,
fully_sharded_loras=False,
lora_dtype=torch.float16,
lora_dtype=torch.bfloat16,
)
vllm_config = dist_init
vllm_config.lora_config = lora_config
@@ -250,7 +256,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
num_inputs=32,
input_size=(1, 64),
input_range=(0, 1),
input_type=torch.float16,
input_type=torch.bfloat16,
device='cpu')

_update_punica_wrapper_metadata(punica_wrapper, index_mapping,
@@ -297,7 +303,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
num_inputs=32,
input_size=(1, 64),
input_range=(0, 1),
input_type=torch.float16,
input_type=torch.bfloat16,
device='cpu')

_update_punica_wrapper_metadata(punica_wrapper, index_mapping,
@@ -318,6 +324,152 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 4, 9])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("stage", [True, False])
def test_linear_parallel(dist_init, num_loras, orientation, stage) -> None:
set_random_seed(6)

max_loras = 9
max_lora_rank = 8
lora_config = LoRAConfig(
max_loras=max_loras,
max_lora_rank=max_lora_rank,
fully_sharded_loras=False,
lora_dtype=torch.bfloat16,
)
vllm_config = dist_init
vllm_config.lora_config = lora_config

mesh = _create_mesh()
linear, lora_linear = _create_random_linear_parallel_layer(
orientation, vllm_config, mesh)
_verify_lora_linear_layer(linear, lora_linear)

max_num_batched_tokens = 8192
max_batches = 256
with torchax.default_env():
punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
max_batches,
'jax',
max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_linear.set_mapping(punica_wrapper)

# Populate lora matrices (lora_a and lora_b) in the lora layer.
index_to_id = get_random_index_to_id(num_loras, max_loras)
# lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights
lora_dict, sublora_dict = populate_loras(
index_to_id,
lora_layer=lora_linear,
baselayer_weights=linear.weight,
)

inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=32,
input_size=(1, 64),
input_range=(0, 1),
input_type=torch.bfloat16,
device='cpu')

_update_punica_wrapper_metadata(punica_wrapper, index_mapping,
prompt_mapping, stage, index_to_id,
lora_config)

with torchax.default_env():
torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
actual_result = lora_linear(torchax_inputs)[0]

expected_results: list[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
result = linear(input_)[0]
lora = lora_dict[lora_id]
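        # Reference result: base output plus the dense LoRA delta
        # (x @ lora_a^T @ lora_b^T), scaled by lora.scaling.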
lora_result = input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
result += lora_result
expected_results.append(result)
expected_result = torch.cat(expected_results)

rtol, atol = TOLERANCES[actual_result.dtype]
with torchax.default_env():
actual_result_cpu = actual_result.to('cpu')
torch.testing.assert_close(actual_result_cpu,
expected_result,
rtol=rtol,
atol=atol)

    # Check that resetting the LoRA weights succeeds.
    # Here we reset every LoRA slot to zero, so the layer should fall back
    # to the plain base-layer output.
for slot_idx in range(max_loras):
lora_linear.reset_lora(slot_idx)

inputs, index_mapping, prompt_mapping = create_random_inputs(
        active_lora_ids=[0],  # id 0 = no LoRA, unlike the call above
num_inputs=32,
input_size=(1, 64),
input_range=(0, 1),
input_type=torch.bfloat16,
device='cpu')
_update_punica_wrapper_metadata(punica_wrapper, index_mapping,
prompt_mapping, stage, index_to_id,
lora_config)

with torchax.default_env():
torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
actual_result = lora_linear(torchax_inputs)[0]
expected_result = linear(torch.cat(inputs))[0]

rtol, atol = TOLERANCES[actual_result.dtype]
with torchax.default_env():
actual_result_cpu = actual_result.to('cpu')
torch.testing.assert_close(actual_result_cpu,
expected_result,
rtol=rtol,
atol=atol)


def _create_random_linear_parallel_layer(orientation, vllm_config, mesh):
# We first create a base linear layer, then a lora layer to wrap it.
if orientation == "row":

def _create_row_linear():
return RowParallelLinear(
64, # input_size
64, # output_size
bias=False,
params_dtype=torch.bfloat16)

linear = _create_row_linear()
linear.weight.data = torch.rand_like(linear.weight.data)

base_linear = _create_row_linear()
lora_linear = _create_lora_wrapper(linear,
base_linear,
RowParallelLinearWithLoRA,
vllm_config=vllm_config,
mesh=mesh)
else:

def _create_column_linear():
            return ColumnParallelLinear(64,  # input_size
                                        64,  # output_size
                                        bias=False,
                                        params_dtype=torch.bfloat16)

linear = _create_column_linear()
linear.weight.data = torch.rand_like(linear.weight.data)

base_linear = _create_column_linear()
lora_linear = _create_lora_wrapper(linear,
base_linear,
ColumnParallelLinearWithLoRA,
vllm_config=vllm_config,
mesh=mesh)

return linear, lora_linear


def _create_mesh():
axis_names = ("data", "model")
devices = jax.devices()
@@ -374,37 +526,75 @@ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
# We first create a base linear layer, then a lora layer to wrap it.
if repeats == 2:
        # In e2e, MergedColumnParallelLinear is created when we load the
        # model. The base_layer weights are sharded and moved to TPU in
        # VllmUnquantizedLinearMethod.process_weights_after_loading.
linear = MergedColumnParallelLinear(
64, # input_size
[64] * repeats, # output_size
bias=False,
params_dtype=torch.float16)
def _create_merged_column_linear():
return MergedColumnParallelLinear(
64, # input_size
[64] * repeats, # output_size
bias=False,
params_dtype=torch.bfloat16)

linear = _create_merged_column_linear()
linear.weight.data = torch.rand_like(linear.weight.data)

base_linear = MergedColumnParallelLinear(
64, # input_size
[64] * repeats, # output_size
bias=False,
params_dtype=torch.float16)
base_linear.weight.data = linear.weight.data
jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
linear_method = VllmUnquantizedLinearMethod(jax_config)
base_linear.quant_method = linear_method
linear_method.process_weights_after_loading(
base_linear
) # here base_linear.weight is moved to TPU and sharded.
assert jax_view(base_linear.weight).platform(
) == 'tpu', 'base_linear.weight should have been moved to TPU.'
assert not isinstance(
jax_view(
base_linear.weight).sharding, jax.sharding.SingleDeviceSharding
), 'base_linear.weight should have been sharded.'

lora_linear = MergedColumnParallelLinearWithLoRA(base_linear)
base_linear = _create_merged_column_linear()
lora_linear = _create_lora_wrapper(linear, base_linear,
MergedColumnParallelLinearWithLoRA,
vllm_config, mesh, repeats)
elif repeats == 3:
raise NotImplementedError("NYI: for MergedQKVParallelLinear case")

def _create_qkv_linear():
            return QKVParallelLinear(64,  # hidden_size
                                     64,  # head_size
                                     32,  # total_num_heads
                                     bias=False,
                                     params_dtype=torch.bfloat16)

linear = _create_qkv_linear()
linear.weight.data = torch.rand_like(linear.weight.data)

base_linear = _create_qkv_linear()
lora_linear = _create_lora_wrapper(linear, base_linear,
MergedQKVParallelLinearWithLoRA,
vllm_config, mesh, repeats)
else:
raise NotImplementedError("NYI: for QKVParallelLinear case")

def _create_qkv_linear():
            return QKVParallelLinear(64,  # hidden_size
                                     64,  # head_size
                                     32,  # total_num_heads
                                     bias=False,
                                     params_dtype=torch.bfloat16)

linear = _create_qkv_linear()
linear.weight.data = torch.rand_like(linear.weight.data)

base_linear = _create_qkv_linear()
lora_linear = _create_lora_wrapper(linear, base_linear,
QKVParallelLinearWithLoRA,
vllm_config, mesh, repeats)

return linear, lora_linear


def _create_lora_wrapper(linear,
base_linear,
lora_cls,
vllm_config,
mesh,
repeats=1):
base_linear.weight.data = linear.weight.data
jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
linear_method = VllmUnquantizedLinearMethod(jax_config)
base_linear.quant_method = linear_method
linear_method.process_weights_after_loading(
base_linear) # here base_linear.weight is moved to TPU and sharded.
assert jax_view(base_linear.weight).platform(
) == 'tpu', 'base_linear.weight should have been moved to TPU.'
assert not isinstance(
jax_view(base_linear.weight).sharding, jax.sharding.
SingleDeviceSharding), 'base_linear.weight should have been sharded.'

lora_linear = lora_cls(base_linear)

lora_config = vllm_config.lora_config
max_loras = lora_config.max_loras
@@ -427,4 +617,4 @@ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
lora_linear.lora_b_stacked) == n_slices)

return linear, lora_linear
return lora_linear
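
For reference, these tests compare the sharded TPU layer against plain dense
LoRA math: the base projection plus a scaled low-rank update. A minimal,
standalone sketch of that reference computation (the helper name and shapes
are illustrative, not part of this PR; `scaling` plays the role of
`lora.scaling` above):

import torch

def lora_reference(x: torch.Tensor, weight: torch.Tensor,
                   lora_a: torch.Tensor, lora_b: torch.Tensor,
                   scaling: float) -> torch.Tensor:
    """Dense reference for a LoRA-augmented linear layer.

    x:      [num_tokens, in_features]
    weight: [out_features, in_features]
    lora_a: [rank, in_features]
    lora_b: [out_features, rank]
    """
    base = x @ weight.T                 # base projection
    delta = x @ lora_a.T @ lora_b.T     # low-rank update
    return base + delta * scaling

# Example: 4 tokens, 64 features, rank-8 LoRA (matches the test sizes).
x = torch.rand(4, 64, dtype=torch.bfloat16)
w = torch.rand(64, 64, dtype=torch.bfloat16)
a = torch.rand(8, 64, dtype=torch.bfloat16)
b = torch.rand(64, 8, dtype=torch.bfloat16)
y = lora_reference(x, w, a, b, scaling=0.5)
assert y.shape == (4, 64)
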
24 changes: 19 additions & 5 deletions tpu_inference/layers/vllm/sharding.py
@@ -6,8 +6,10 @@
from torch.nn import Parameter
from torch.utils import _pytree as pytree
from torchax.interop import torch_view
from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA,
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA)
from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -127,9 +129,8 @@ def _shard_base_linear_lora_replicated(layer: BaseLinearLayerWithLoRA,
layer.lora_b_stacked = sharded_lora_b_tpu


# TODO: Add custom sharding logic for following lora layers
def _shard_merged_column_parallel_linear_lora(
layer: MergedColumnParallelLinearWithLoRA, mesh: Mesh) -> None:
def _shard_column_linear_lora(layer: ColumnParallelLinearWithLoRA,
mesh: Mesh) -> None:
assert layer.n_slices > 0, "layer.n_slices should be greater than 0"
# lora_a_stacked[i] has shape [max_loras, 1, max_lora_rank, in_features]
sharded_lora_a_tpu = torch.nn.ParameterList()
@@ -150,9 +151,20 @@ def _shard_merged_column_parallel_linear_lora(
layer.lora_b_stacked = sharded_lora_b_tpu


# TODO: Add custom sharding logic for the following LoRA layers
def _shard_qkv_linear_lora(layer: ColumnParallelLinearWithLoRA,
mesh: Mesh) -> None:
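    # QKV is a column-parallel projection (sharded along the output dim),
    # so it reuses the column-parallel sharding logic.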
_shard_column_linear_lora(layer, mesh)


def _shard_merged_column_parallel_linear_lora(
layer: MergedColumnParallelLinearWithLoRA, mesh: Mesh) -> None:
_shard_column_linear_lora(layer, mesh)


def _shard_merged_qkv_parallel_linear_lora(
layer: MergedQKVParallelLinearWithLoRA, mesh: Mesh) -> None:
_shard_merged_column_parallel_linear_lora(layer, mesh)
_shard_column_linear_lora(layer, mesh)


def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
@@ -166,6 +178,8 @@ def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
(ParallelLMHead, _shard_lm_head),
(VocabParallelEmbedding, _shard_vocab_parallel_embedding),
# Shard LoRA layers
(ColumnParallelLinearWithLoRA, _shard_column_linear_lora),
(QKVParallelLinearWithLoRA, _shard_qkv_linear_lora),
(MergedColumnParallelLinearWithLoRA,
_shard_merged_column_parallel_linear_lora),
(MergedQKVParallelLinearWithLoRA, _shard_merged_qkv_parallel_linear_lora),
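
The registry above maps module types to sharding functions. A minimal sketch
of that table-dispatch pattern (simplified: the classes and handlers below
are placeholders, and exact-type matching is an assumption about how
tpu_inference walks the table):

from typing import Callable

class ColumnParallelLinearWithLoRA: ...          # placeholder layer types
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...

def _shard_column_linear_lora(layer, mesh) -> None:
    print(f"column sharding for {type(layer).__name__}")

def _shard_qkv_linear_lora(layer, mesh) -> None:
    # QKV shards the same way as a plain column-parallel layer.
    _shard_column_linear_lora(layer, mesh)

# (type, handler) pairs, mirroring the registry in sharding.py.
_SHARDING_RULES: list[tuple[type, Callable]] = [
    (ColumnParallelLinearWithLoRA, _shard_column_linear_lora),
    (QKVParallelLinearWithLoRA, _shard_qkv_linear_lora),
]

def shard_module(layer, mesh=None) -> None:
    # Exact-type matching sidesteps the fact that these LoRA classes
    # subclass one another; isinstance-based dispatch would make entry
    # order significant.
    for layer_type, handler in _SHARDING_RULES:
        if type(layer) is layer_type:
            handler(layer, mesh)
            return

shard_module(QKVParallelLinearWithLoRA())  # column sharding via QKV handler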