Skip to content

Commit 11c7ea2

Browse files
committed
ok, the test passed. Need to make it simpler next.
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
1 parent 8044fb7 commit 11c7ea2

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

tests/lora/test_layers.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def create_column_parallel_packed_layer():
248248
# self.jax_config.mesh.devices[0][0].platform
249249
jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
250250
linear_method = VllmUnquantizedLinearMethod(jax_config)
251+
base_linear.quant_method=linear_method
251252
linear_method.process_weights_after_loading(base_linear)
252253
# here base_linear.weight is on TPU and sharded.
253254

@@ -263,6 +264,8 @@ def create_column_parallel_packed_layer():
263264
raise NotImplementedError("NYI: for QKVParallelLinear case")
264265

265266
n_slices = repeats
267+
# TODO(xw): check if we can enable torchax globally.
268+
# TODO(xw): check if we can calculate both actual and expected output using torchax.
266269
with torchax.default_env():
267270
# create_lora_weights creates global shape weight.
268271
lora_linear.create_lora_weights(max_loras, lora_config)
@@ -282,10 +285,11 @@ def create_column_parallel_packed_layer():
282285

283286
max_num_batched_tokens = 8192
284287
max_batches = 256
285-
punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
286-
max_batches,
287-
device,
288-
max_loras=max_loras)
288+
with torchax.default_env():
289+
punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
290+
max_batches,
291+
'jax',
292+
max_loras=max_loras)
289293
assert check_punica_wrapper(punica_wrapper)
290294
lora_linear.set_mapping(punica_wrapper)
291295

@@ -333,7 +337,8 @@ def create_column_parallel_packed_layer():
333337
with torchax.default_env():
334338
# lora_result = lora_linear(torch.cat(jax_inputs))[0]
335339
# lora_result = j2t(lora_result)
336-
lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs))
340+
# lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs))
341+
lora_result = lora_linear(torch.cat(jax_inputs))[0]
337342

338343
expected_results: list[torch.Tensor] = []
339344
for input_, lora_id in zip(inputs, prompt_mapping):
@@ -348,17 +353,18 @@ def create_column_parallel_packed_layer():
348353
expected_result = torch.cat(expected_results)
349354

350355
rtol, atol = TOLERANCES[lora_result.dtype]
351-
# with torchax.default_env():
352-
# torch.testing.assert_close(lora_result.to('cpu'),
353-
# expected_result,
354-
# rtol=rtol,
355-
# atol=atol)
356-
# print(
357-
# f'Output max diff: {torch.max(torch.abs(expected_result.to('cpu') - lora_result))}'
358-
# )
359-
# print(
360-
# f'Output mean diff: {torch.mean(torch.abs(expected_result.to('cpu') - lora_result))}'
361-
# )
356+
with torchax.default_env():
357+
lora_result_cpu = lora_result.to('cpu')
358+
torch.testing.assert_close(lora_result_cpu,
359+
expected_result,
360+
rtol=rtol,
361+
atol=atol)
362+
print(
363+
f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}'
364+
)
365+
print(
366+
f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}'
367+
)
362368

363369
# Check that resetting the lora weights succeeds
364370
# Here we set all lora weight to be empty.

tpu_inference/lora/torch_punica_tpu.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ class PunicaWrapperTPU(PunicaWrapperBase):
2323
PunicaWrapperTPU is designed to manage and provide metadata for the punica
2424
kernel. The main function is to maintain the state information for
2525
Multi-LoRA, and to provide the interface for the pytorch punica ops.
26+
27+
It is created by get_punica_wrapper during load_lora_model -> create_lora_manager. The device is TPU.
2628
"""
2729

2830
def __init__(self, max_num_batched_tokens: int, max_batches: int,

0 commit comments

Comments
 (0)