Commit d409a35

clean up
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
1 parent c2082ff commit d409a35

File tree

1 file changed: 3 additions, 14 deletions

tests/lora/test_layers.py

Lines changed: 3 additions & 14 deletions
@@ -8,7 +8,6 @@
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
-# from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu
 from vllm.config import LoRAConfig
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -28,9 +27,6 @@

 from .utils import DummyLoRAManager

-# TODO(xiowei):
-# - add test for multi-chip.
-
 P = PartitionSpec

 TOLERANCES = {
@@ -204,26 +200,20 @@ def create_random_inputs(
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 9])
 @pytest.mark.parametrize("repeats", [2])
-@pytest.mark.parametrize("fully_shard", [False])
 @pytest.mark.parametrize("stage", [True, False])
-def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
-                                stage) -> None:
+def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
     max_loras = 9
     max_lora_rank = 8
     lora_config = LoRAConfig(
         max_loras=max_loras,
         max_lora_rank=max_lora_rank,
-        fully_sharded_loras=fully_shard,
+        fully_sharded_loras=False,
         lora_dtype=torch.float16,
     )

     axis_names = ("data", "model")
     devices = jax.devices()
-    mesh_shape = (
-        1, len(devices)
-        # 1, 1
-    )  # TODO(xiowei): support multi-chip: mesh_shape = (1, len(jax.devices()))
-    print(f'xw32 mesh_shape: {mesh_shape}')
+    mesh_shape = (1, len(devices))
     mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)

     def create_column_parallel_packed_layer():
@@ -264,7 +254,6 @@ def create_column_parallel_packed_layer():
             raise NotImplementedError("NYI: for QKVParallelLinear case")

         with torchax.default_env():
-            # create_lora_weights creates global shape lora weight.
             lora_linear.create_lora_weights(max_loras, lora_config)
             # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
             _shard_module_to_tpu(lora_linear, mesh)
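
For context, a minimal sketch of the mesh pattern the test relies on: a (1, len(devices)) mesh with ("data", "model") axes, used to place an array with a NamedSharding. This is not part of the commit; only jax.make_mesh, NamedSharding, and PartitionSpec come from the changed code, and the array shape and partition spec below are illustrative assumptions.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

# Same mesh construction as in the test: one "data" slice, all devices on "model".
axis_names = ("data", "model")
devices = jax.devices()
mesh = jax.make_mesh((1, len(devices)), axis_names, devices=devices)

# Illustrative only: shard the second dimension of a weight-like array across
# the "model" axis; the first dimension stays replicated.
weight = jnp.zeros((8, 128))
sharded = jax.device_put(weight, NamedSharding(mesh, PartitionSpec(None, "model")))
print(sharded.sharding)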
