Commit 2d87418

cleaned up
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
1 parent 2eae008 commit 2d87418

File tree: 1 file changed (+29, -43 lines)


tests/lora/test_layers.py

Lines changed: 29 additions & 43 deletions
@@ -1,7 +1,9 @@
 import random
 from typing import Optional
+import copy
 
 import jax
+import numpy as np
 import pytest
 import torch
 import torchax
@@ -27,11 +29,10 @@
                                 JaxCommonConfig, JaxCommonLinearConfig)
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
-from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora
+from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora, _shard_module_to_tpu
 
 # TODO(xiowei):
 # - add test for multi-chip.
-# - add equivalent test for ColumnParallelLinearWithShardedLoRA.
 
 P = PartitionSpec
 
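The import line now also pulls in `_shard_module_to_tpu`, which the test calls later instead of the layer-specific `_shard_merged_column_parallel_linear_lora`. As a rough illustration of why a module-level entry point is convenient, here is a minimal sketch of dispatch-by-layer-type sharding; the names `SHARDING_RULES`, `register_sharding_rule`, and `shard_module` are hypothetical and are not the tpu_inference API.

```python
# Hypothetical sketch (not the tpu_inference implementation): a module-level
# sharding entry point that dispatches to per-layer-type sharding functions.
from typing import Callable

SHARDING_RULES: dict[type, Callable] = {}

def register_sharding_rule(layer_cls: type) -> Callable:
    def wrap(fn: Callable) -> Callable:
        SHARDING_RULES[layer_cls] = fn  # remember how to shard this layer type
        return fn
    return wrap

def shard_module(module, mesh) -> None:
    rule = SHARDING_RULES.get(type(module))
    if rule is not None:
        rule(module, mesh)  # e.g. move the module's stacked lora buffers onto the mesh
```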
@@ -56,7 +57,7 @@ def check_punica_wrapper(punica_wrapper) -> bool:
 def get_random_index_to_id(num_loras: int,
                            num_slots: int,
                            log: bool = True) -> list[Optional[int]]:
-    """Creates a random index_to_lora_id mapping.
+    """Creates a random index_to_lora_id mapping: slot[index] = lora_id.
 
     Args:
         num_loras: The number of active loras in the mapping.
@@ -76,7 +77,7 @@ def get_random_index_to_id(num_loras: int,
     slots: list[Optional[int]] = [None] * num_slots
     random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
     for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
-        # xw32: It seems the slot_idx start at 1.
+        # The lora_id starts at 1.
         slots[slot_idx] = lora_id
 
     if log:
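For concreteness, the loop above can be run in isolation; the seed and the printed list below are illustrative only.

```python
# Illustration of get_random_index_to_id's core loop: pick num_loras random
# slots, then assign lora ids 1..num_loras to them; untouched slots stay None.
import torch

torch.manual_seed(0)  # seed chosen only to make the illustration reproducible
num_loras, num_slots = 2, 4
slots = [None] * num_slots
random_slot_selections = torch.randperm(num_slots)[:num_loras].tolist()
for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
    slots[slot_idx] = lora_id  # lora_id takes the values 1, 2, ..., num_loras
print(slots)  # e.g. [None, 2, None, 1]; which slots are filled is random
```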
@@ -92,7 +93,7 @@ def populate_loras(
     generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
-    """This method populates the lora layers (BaseLayerWithLoRA) with lora weights.
+    """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
 
     Args:
         index_to_id: a list of lora ids. The index of the lora id
@@ -140,8 +141,7 @@ def populate_loras(
             subloras) if repeats > 1 else subloras[0]
 
         # Some of the layer.lora is torchax tensor so it can only do math (slice op) in the torchax env.
-        with torchax.default_env(), jax.default_device(
-                jax.devices("tpu")[0]):
+        with torchax.default_env():
             lora_layer.set_lora(
                 slot_idx,
                 lora_a=lora.lora_a,
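The `set_lora` call now runs under `torchax.default_env()` alone, without pinning `jax.default_device(...)`. A minimal sketch of the pattern is below; the `t2j`/`torch_view` import paths are assumptions (this test file imports them outside the hunks shown).

```python
# Sketch: math on torchax-backed tensors (like the slicing inside set_lora)
# runs under torchax.default_env(); device placement is handled by the
# sharding helpers instead of a jax.default_device(...) pin.
import torch
import torchax
from torchax.interop import torch_view  # assumed import path
from torchax.tensor import t2j          # assumed import path

with torchax.default_env():
    lora_a = torch_view(t2j(torch.randn(8, 256, dtype=torch.float16)))
    left_half = lora_a[:, :128]  # the slice executes through the torchax env
```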
@@ -207,11 +207,10 @@ def create_random_inputs(
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 9])
 @pytest.mark.parametrize("repeats", [2])
-@pytest.mark.parametrize("fully_shard", [False]) # TODO(xiowei): add "True".
-@pytest.mark.parametrize("device", ["cpu"])
+@pytest.mark.parametrize("fully_shard", [False])
 @pytest.mark.parametrize("stage", [True, False])
 def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
-                                device, stage) -> None:
+                                stage) -> None:
     max_loras = 9
     max_lora_rank = 8
     lora_config = LoRAConfig(
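Dropping the `device` parametrization removes one axis from the matrix of generated tests; stacked `pytest.mark.parametrize` decorators multiply out, so the test now expands to 4 (num_loras) x 1 (repeats) x 1 (fully_shard) x 2 (stage) = 8 cases. A standalone toy example of the same mechanism:

```python
# Toy test (separate from test_layers.py): stacked parametrize decorators
# generate the cross-product of their argument lists, 4 x 2 = 8 cases here.
import pytest

@pytest.mark.parametrize("num_loras", [1, 2, 4, 9])
@pytest.mark.parametrize("stage", [True, False])
def test_param_grid(num_loras: int, stage: bool) -> None:
    assert num_loras in (1, 2, 4, 9) and stage in (True, False)
```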
@@ -228,9 +227,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
     mesh = jax.make_mesh(mesh_shape, axis_names, devices=jax.devices())
 
     def create_column_parallel_packed_layer():
-        # Step 1: create a base layer (e.g. MergedColumnParallelLinear) and a vLLM LoRA wrapper.
+        # We first create a base linear layer, then a lora layer to wrap it.
         if repeats == 2:
-            # In e2e, MergedColumnParallelLinear is created when we load th e model. The weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading.
+            # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading.
             linear = MergedColumnParallelLinear(
                 256, # input_size
                 [256] * repeats, # output_size
@@ -245,50 +244,42 @@ def create_column_parallel_packed_layer():
                 params_dtype=torch.float16)
             base_linear.weight.data = linear.weight.data
             vllm_config = dist_init
-            # self.jax_config.mesh.devices[0][0].platform
             jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
             linear_method = VllmUnquantizedLinearMethod(jax_config)
             base_linear.quant_method=linear_method
-            linear_method.process_weights_after_loading(base_linear)
-            # here base_linear.weight is on TPU and sharded.
+            linear_method.process_weights_after_loading(base_linear) # here base_linear.weight is moved to TPU and sharded.
             assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.'
             assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.'
 
-            # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
             lora_linear = MergedColumnParallelLinearWithLoRA(
                 base_linear
             )
         elif repeats == 3:
-            # TODO(xiowei): add test for this case.
             raise NotImplementedError("NYI: for MergedQKVParallelLinear case")
         else:
-            # TODO(xiowei): add test for this case.
             raise NotImplementedError("NYI: for QKVParallelLinear case")
 
-        n_slices = repeats
-        #TODO(xw): check if we can enable torchax globally.
-        # TODO(xw): check if we can calculate both actual and expected output using torchax.
         with torchax.default_env():
-            # create_lora_weights creates global shape weight.
+            # create_lora_weights creates global shape lora weight.
             lora_linear.create_lora_weights(max_loras, lora_config)
-            _shard_merged_column_parallel_linear_lora(lora_linear, mesh)
-            # TODO: assert the lora_a_stacked is on TPU and sharded.
+            # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
+            _shard_module_to_tpu(lora_linear, mesh)
+
            assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.'
            assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
            assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.'
            assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
-
-        # TODO: assert the lora_b_stacked is on TPU and sharded.
+        n_slices = repeats
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == n_slices)
 
-        return linear, lora_linear, linear_method
+        return linear, lora_linear
 
     set_random_seed(6)
 
-    linear, lora_linear, linear_method = create_column_parallel_packed_layer()
+    linear, lora_linear = create_column_parallel_packed_layer()
     with torchax.default_env():
-        # linear.weight has type torch.nn.Parameter, lora_linear.weight has type torchax.tensor.Tensor
+        # lora_linear.weight has type torchax.tensor.Tensor
         # BaseLinearLayerWithLoRA.weight property guarantees this.
         assert torch.equal(linear.weight, lora_linear.weight.to('cpu'))
 
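The platform/sharding assertions repeated above could be folded into a small helper; the sketch below only mirrors the checks already present in this hunk (the helper name is made up for illustration).

```python
# Sketch of a helper mirroring the assertions above: the jax array must live on
# TPU and must not carry a SingleDeviceSharding (i.e. it is actually sharded).
import jax

def assert_on_tpu_and_sharded(arr, name: str) -> None:
    assert arr.platform() == 'tpu', f'{name} should have been moved to TPU.'
    assert not isinstance(arr.sharding, jax.sharding.SingleDeviceSharding), \
        f'{name} should have been sharded.'

# usage mirroring the test:
# assert_on_tpu_and_sharded(jax_view(lora_linear.lora_a_stacked[0]), 'lora_a_stacked')
```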
302293
assert check_punica_wrapper(punica_wrapper)
303294
lora_linear.set_mapping(punica_wrapper)
304295

305-
# load the lora weight, shard it, and send it to TPU.
306-
# create a lora slot index to lora id mapping.
296+
# Populate lora matrices (lora_a and lora_b) in the lora layer.
307297
index_to_id = get_random_index_to_id(num_loras, max_loras)
308298
# lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights
309299
lora_dict, sublora_dict = populate_loras(
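As background for what the "lora matrices (lora_a and lora_b)" contribute once populated, here is the generic LoRA computation in the standard paper layout (A: rank x in_features, B: out_features x rank). This sketches the math only; it is not vLLM's internal storage layout or its Punica kernels.

```python
# Generic LoRA math (sketch): y = x @ W^T + scaling * (x @ A^T) @ B^T
import torch

def lora_forward(x, weight, lora_a, lora_b, scaling=1.0):
    base = x @ weight.t()                  # base linear output
    delta = (x @ lora_a.t()) @ lora_b.t()  # low-rank update
    return base + scaling * delta

x = torch.randn(1, 256)
w = torch.randn(512, 256)   # merged output dim 512, input dim 256
a = torch.randn(8, 256)     # rank 8, matching max_lora_rank above
b = torch.randn(512, 8)
assert lora_forward(x, w, a, b).shape == (1, 512)
```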
@@ -322,10 +312,11 @@ def create_column_parallel_packed_layer():
         input_size=(1, 256),
         input_range=(0, 1),
         input_type=torch.float16,
-        device=device)
+        device='cpu')
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
 
-    with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
+    with torchax.default_env():
+        # Here we move the metadata from cpu to tpu.
         punica_wrapper.update_metadata(
             lora_mapping,
             index_to_id,
@@ -337,17 +328,14 @@ def create_column_parallel_packed_layer():
     assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
 
     jax_inputs = []
-    with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
+    with torchax.default_env():
         for input in inputs:
             # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'`
             # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'`
             jax_input = torch_view(t2j(input))
             jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
             jax_inputs.append(jax_input)
     with torchax.default_env():
-        # lora_result = lora_linear(torch.cat(jax_inputs))[0]
-        # lora_result = j2t(lora_result)
-        # lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs))
         lora_result = lora_linear(torch.cat(jax_inputs))[0]
 
     expected_results: list[torch.Tensor] = []
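`NamedSharding(mesh, P(None, None))` maps no mesh axis to either input dimension, so `jax.device_put` replicates each input across every device in the mesh instead of splitting it. A standalone illustration (the 1-D mesh here is an assumption made for the example; the test builds its own `mesh_shape`/`axis_names`):

```python
# Sketch: P(None, None) => fully replicated placement over the mesh.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((len(jax.devices()),), ("model",))  # assumed 1-D mesh
x = jnp.ones((1, 256), dtype=jnp.float32)
x_repl = jax.device_put(x, NamedSharding(mesh, P(None, None)))
assert x_repl.sharding.is_fully_replicated
```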
@@ -387,10 +375,10 @@ def create_column_parallel_packed_layer():
         input_size=(1, 256),
         input_range=(0, 1),
         input_type=torch.float16,
-        device=device)
+        device='cpu')
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
 
-    with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
+    with torchax.default_env():
         punica_wrapper.update_metadata(
             lora_mapping,
             index_to_id,
@@ -400,16 +388,14 @@ def create_column_parallel_packed_layer():
     )
 
     jax_inputs = []
-    with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
+    with torchax.default_env():
         for input in inputs:
             jax_input = torch_view(t2j(input))
             jax_input.apply_jax_(jax.device_put,
                                  NamedSharding(mesh, P(None, None)))
             jax_inputs.append(jax_input)
     with torchax.default_env():
-        lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs))
-        # lora_result = lora_linear(torch.cat(jax_inputs))[0]
-        # lora_result = j2t(lora_result)
+        lora_result = lora_linear(torch.cat(jax_inputs))[0]
     expected_result = linear(torch.cat(inputs))[0]
 
     rtol, atol = TOLERANCES[lora_result.dtype]
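The diff ends just before the final comparison: `lora_result` is checked against `expected_result` under dtype-dependent tolerances. A hedged sketch of that last step (the tolerance values below are placeholders, not the `TOLERANCES` table defined elsewhere in this test file):

```python
# Sketch of the final comparison; TOLERANCES here is a stand-in with assumed values.
import torch

TOLERANCES = {torch.float16: (1e-2, 1e-2), torch.float32: (1e-5, 1e-5)}

def check_close(actual: torch.Tensor, expected: torch.Tensor) -> None:
    rtol, atol = TOLERANCES[actual.dtype]
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
```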
