|
8 | 8 | from jax.sharding import NamedSharding, PartitionSpec |
9 | 9 | from torchax.interop import jax_view, torch_view |
10 | 10 | from torchax.ops.mappings import t2j |
11 | | -# from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu |
12 | 11 | from vllm.config import LoRAConfig |
13 | 12 | # yapf conflicts with isort for this block |
14 | 13 | # yapf: disable |
|
28 | 27 |
|
29 | 28 | from .utils import DummyLoRAManager |
30 | 29 |
|
31 | | -# TODO(xiowei): |
32 | | -# - add test for multi-chip. |
33 | | - |
34 | 30 | P = PartitionSpec |
35 | 31 |
|
36 | 32 | TOLERANCES = { |
@@ -204,26 +200,20 @@ def create_random_inputs( |
204 | 200 | @torch.inference_mode() |
205 | 201 | @pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) |
206 | 202 | @pytest.mark.parametrize("repeats", [2]) |
207 | | -@pytest.mark.parametrize("fully_shard", [False]) |
208 | 203 | @pytest.mark.parametrize("stage", [True, False]) |
209 | | -def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, |
210 | | - stage) -> None: |
| 204 | +def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None: |
211 | 205 | max_loras = 9 |
212 | 206 | max_lora_rank = 8 |
213 | 207 | lora_config = LoRAConfig( |
214 | 208 | max_loras=max_loras, |
215 | 209 | max_lora_rank=max_lora_rank, |
216 | | - fully_sharded_loras=fully_shard, |
| 210 | + fully_sharded_loras=False, |
217 | 211 | lora_dtype=torch.float16, |
218 | 212 | ) |
219 | 213 |
|
220 | 214 | axis_names = ("data", "model") |
221 | 215 | devices = jax.devices() |
222 | | - mesh_shape = ( |
223 | | - 1, len(devices) |
224 | | - # 1, 1 |
225 | | - ) # TODO(xiowei): support multi-chip: mesh_shape = (1, len(jax.devices())) |
226 | | - print(f'xw32 mesh_shape: {mesh_shape}') |
| 216 | + mesh_shape = (1, len(devices)) |
227 | 217 | mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices) |
228 | 218 |
|
229 | 219 | def create_column_parallel_packed_layer(): |
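Note: the hunk above replaces the commented-out mesh shape and debug print with a plain `(1, len(devices))` mesh over the axes `("data", "model")`. Below is a minimal, self-contained sketch of the same mesh layout plus a `NamedSharding` that splits a tensor's last dimension across the `"model"` axis; the array shape and `PartitionSpec` are illustrative and not taken from the test.

```python
# Illustrative sketch only: an equivalent (1, N) mesh built with jax.sharding.Mesh,
# and a NamedSharding that shards the last dimension across the "model" axis.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.asarray(jax.devices()).reshape(1, -1)   # mesh_shape == (1, len(devices))
mesh = Mesh(devices, axis_names=("data", "model"))

sharding = NamedSharding(mesh, P(None, "model"))     # replicate dim 0, shard dim 1
weight = jax.device_put(jnp.zeros((8, 128), jnp.float16), sharding)
print(weight.sharding)                               # shows the mesh and PartitionSpec placement
```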
@@ -264,7 +254,6 @@ def create_column_parallel_packed_layer(): |
264 | 254 | raise NotImplementedError("NYI: for QKVParallelLinear case") |
265 | 255 |
|
266 | 256 | with torchax.default_env(): |
267 | | - # create_lora_weights creates global shape lora weight. |
268 | 257 | lora_linear.create_lora_weights(max_loras, lora_config) |
269 | 258 | # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu. |
270 | 259 | _shard_module_to_tpu(lora_linear, mesh) |
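Note: the test creates the LoRA weights with their global shapes and then relies on the project-internal `_shard_module_to_tpu` helper to place them on the mesh, mirroring the e2e path. As a rough, hedged sketch of that kind of transfer (not the helper's actual implementation; the function name, default `PartitionSpec`, and the commented usage are assumptions for illustration), a single tensor could be converted and placed like this:

```python
# Hedged sketch: convert one torch tensor to a jax.Array, place it on the mesh
# with a chosen PartitionSpec, and wrap it back into a torch-visible view.
# `to_tpu_sharded` is a hypothetical helper, not part of the codebase.
import torch
import jax
from jax.sharding import NamedSharding, PartitionSpec as P
from torchax.interop import torch_view
from torchax.ops.mappings import t2j

def to_tpu_sharded(t: torch.Tensor, mesh, spec=P()) -> torch.Tensor:
    arr = t2j(t.detach().cpu())                           # torch.Tensor -> jax.Array
    arr = jax.device_put(arr, NamedSharding(mesh, spec))  # shard/replicate on the mesh
    return torch_view(arr)                                # use under torchax.default_env()

# e.g. replicate a small LoRA-A matrix (shape is illustrative):
# lora_a_tpu = to_tpu_sharded(torch.zeros(8, 4096, dtype=torch.float16), mesh)
```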
|