
Commit c2082ff

fix the format
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
1 parent cc47bab commit c2082ff

3 files changed (+39, -28 lines)


tests/lora/test_layers.py

Lines changed: 37 additions & 24 deletions
@@ -1,16 +1,13 @@
 import random
 from typing import Optional
-import copy
 
 import jax
-import numpy as np
 import pytest
 import torch
 import torchax
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.interop import torch_view
-from torchax.ops.mappings import j2t, t2j
 from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
 # from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu
 from vllm.config import LoRAConfig
 # yapf conflicts with isort for this block
@@ -24,12 +21,12 @@
 from vllm.model_executor.utils import set_random_seed
 from vllm.platforms import current_platform
 
-from .utils import DummyLoRAManager
-from tpu_inference.layers.vllm.quantization.common import (
-    JaxCommonConfig, JaxCommonLinearConfig)
+from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
-from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, _shard_merged_column_parallel_linear_lora, _shard_module_to_tpu
+from tpu_inference.layers.vllm.sharding import _shard_module_to_tpu
+
+from .utils import DummyLoRAManager
 
 # TODO(xiowei):
 # - add test for multi-chip.
@@ -239,7 +236,7 @@ def create_column_parallel_packed_layer():
                 bias=False,
                 params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
-
+
             base_linear = MergedColumnParallelLinear(
                 64,  # input_size
                 [64] * repeats,  # output_size
@@ -249,14 +246,18 @@ def create_column_parallel_packed_layer():
             vllm_config = dist_init
             jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
             linear_method = VllmUnquantizedLinearMethod(jax_config)
-            base_linear.quant_method=linear_method
-            linear_method.process_weights_after_loading(base_linear) # here base_linear.weight is moved to TPU and sharded.
-            assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.'
-            assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.'
-
-            lora_linear = MergedColumnParallelLinearWithLoRA(
+            base_linear.quant_method = linear_method
+            linear_method.process_weights_after_loading(
                 base_linear
-            )
+            )  # here base_linear.weight is moved to TPU and sharded.
+            assert jax_view(base_linear.weight).platform(
+            ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
+            assert not isinstance(
+                jax_view(base_linear.weight).sharding,
+                jax.sharding.SingleDeviceSharding
+            ), 'base_linear.weight should have been sharded.'
+
+            lora_linear = MergedColumnParallelLinearWithLoRA(base_linear)
         elif repeats == 3:
             raise NotImplementedError("NYI: for MergedQKVParallelLinear case")
         else:
@@ -268,10 +269,16 @@ def create_column_parallel_packed_layer():
         # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
         _shard_module_to_tpu(lora_linear, mesh)
 
-        assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.'
-        assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
-        assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.'
-        assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
+        assert jax_view(lora_linear.lora_a_stacked[0]).platform(
+        ) == 'tpu', 'lora_a_stacked should have been moved to TPU.'
+        assert not isinstance(
+            jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.
+            SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
+        assert jax_view(lora_linear.lora_b_stacked[0]).platform(
+        ) == 'tpu', 'lora_b_stacked should have been moved to TPU.'
+        assert not isinstance(
+            jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.
+            SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
         n_slices = repeats
         assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
             lora_linear.lora_b_stacked) == n_slices)
@@ -287,7 +294,8 @@ def create_column_parallel_packed_layer():
     # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix.
     # So the below check will fail.
     if len(devices) == 1:
-        assert torch.equal(linear.weight.data, lora_linear.weight.to('cpu'))
+        assert torch.equal(linear.weight.data,
+                           lora_linear.weight.to('cpu'))
 
     max_num_batched_tokens = 8192
     max_batches = 256
@@ -330,16 +338,21 @@ def create_column_parallel_packed_layer():
         vocab_size=512,
         extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
-    assert jax_view(punica_wrapper._lora_indices_per_batch).platform() == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
-    assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+    assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
+    ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+    assert isinstance(
+        jax_view(punica_wrapper._lora_indices_per_batch).sharding,
+        jax.sharding.SingleDeviceSharding
+    ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
 
     jax_inputs = []
     with torchax.default_env():
         for input in inputs:
             # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'`
             # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'`
             jax_input = torch_view(t2j(input))
-            jax_input.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+            jax_input.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
             jax_inputs.append(jax_input)
     with torchax.default_env():
         lora_result = lora_linear(torch.cat(jax_inputs))[0]
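
The reworked assertions in this file all follow one pattern: view a torch parameter through jax_view, then check where the backing JAX array lives and how it is sharded. A minimal sketch of that pattern as a standalone helper (the helper name and its arguments are illustrative, not part of this commit):

```python
import jax
from torchax.interop import jax_view


def assert_on_tpu_and_sharded(param, name):
    # View the torchax tensor as the JAX array that backs it.
    arr = jax_view(param)
    # Same checks the test performs inline: the array must live on TPU ...
    assert arr.platform() == 'tpu', f'{name} should have been moved to TPU.'
    # ... and must not be confined to a single device.
    assert not isinstance(arr.sharding, jax.sharding.SingleDeviceSharding), \
        f'{name} should have been sharded.'
```

Factoring the checks this way would also collapse the four inline asserts on lora_a_stacked[0] and lora_b_stacked[0] into two calls, which is one way to keep the reflowed assert lines short.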

tpu_inference/lora/torch_punica_tpu.py

Lines changed: 2 additions & 3 deletions
@@ -8,8 +8,6 @@
 import torch.nn.functional as F
 import torchax
 from vllm.lora.punica_wrapper.utils import convert_mapping
-from torchax.interop import jax_view, torch_view
-
 
 if TYPE_CHECKING:
     # avoid circuit import
@@ -285,7 +283,8 @@ def _update_prefill_metadata(self,
         self.batch_size = 1
         self._lora_indices_per_batch[:self.
                                      batch_size] = token_lora_tensor[:self.
-                                                                     batch_size].torch()
+                                                                     batch_size].torch(
+                                                                     )
 
     def _pad_prompt_mapping(
             self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
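
The change here is yapf reflowing a slice assignment at the column limit; the behavior is unchanged. Stripped of the class context and the torchax-to-torch conversion, the statement amounts to the sketch below (buffer names follow the diff, sizes and values are made up for illustration):

```python
import torch

# Preallocated per-batch LoRA index buffer and per-token LoRA ids
# (hypothetical shapes and contents, for illustration only).
lora_indices_per_batch = torch.zeros(8, dtype=torch.long)
token_lora_tensor = torch.tensor([3, 5, 5, 0, 0, 0, 0, 0])
batch_size = 1

# Copy the indices for the active batch into the preallocated buffer, mirroring
# `self._lora_indices_per_batch[:batch_size] = token_lora_tensor[:batch_size].torch()`.
lora_indices_per_batch[:batch_size] = token_lora_tensor[:batch_size]
```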

tpu_inference/runner/compilation_manager.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@
 import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
-from vllm.utils.math_utils import cdiv
 
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
