Skip to content

Commit cc47bab

Browse files
committed
add multi-chip test case
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
1 parent c480658 commit cc47bab

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

.buildkite/pipeline_jax.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ steps:
157157
queue: tpu_v6e_queue
158158
commands:
159159
- |
160+
<<<<<<< HEAD
160161
if [[ "$$NIGHTLY" == "1" ]]; then
161162
.buildkite/scripts/run_in_docker.sh \
162163
bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \
@@ -165,6 +166,12 @@ steps:
165166
echo "Skipping: NIGHTLY environment variable not set"
166167
exit 0
167168
fi
169+
=======
170+
.buildkite/scripts/run_in_docker.sh \
171+
bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \
172+
python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py && \
173+
python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py'
174+
>>>>>>> c17bacea (add multi-chip test case)
168175

169176
- label: "E2E MLPerf tests for JAX + vLLM models on multiple chips"
170177
key: test_11
@@ -212,13 +219,19 @@ steps:
212219
queue: tpu_v6e_8_queue
213220
commands:
214221
- |
222+
<<<<<<< HEAD
215223
if [[ "$$NIGHTLY" == "1" ]]; then
216224
.buildkite/scripts/run_in_docker.sh \
217225
bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py'
218226
else
219227
echo "Skipping: NIGHTLY environment variable not set"
220228
exit 0
221229
fi
230+
=======
231+
.buildkite/scripts/run_in_docker.sh \
232+
bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \
233+
python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py'
234+
>>>>>>> c17bacea (add multi-chip test case)
222235

223236

224237
# -----------------------------------------------------------------

tests/lora/test_layers.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,13 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
221221
)
222222

223223
axis_names = ("data", "model")
224+
devices = jax.devices()
224225
mesh_shape = (
225-
1, 1
226+
1, len(devices)
227+
# 1, 1
226228
) # TODO(xiowei): support multi-chip: mesh_shape = (1, len(jax.devices()))
227-
mesh = jax.make_mesh(mesh_shape, axis_names, devices=jax.devices())
229+
print(f'xw32 mesh_shape: {mesh_shape}')
230+
mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
228231

229232
def create_column_parallel_packed_layer():
230233
# We first create a base linear layer, then a lora layer to wrap it.
@@ -281,7 +284,10 @@ def create_column_parallel_packed_layer():
281284
with torchax.default_env():
282285
# lora_linear.weight has type torchax.tensor.Tensor
283286
# BaseLinearLayerWithLoRA.weight property guarantees this.
284-
assert torch.equal(linear.weight, lora_linear.weight.to('cpu'))
287+
# if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix.
288+
# So the below check will fail.
289+
if len(devices) == 1:
290+
assert torch.equal(linear.weight.data, lora_linear.weight.to('cpu'))
285291

286292
max_num_batched_tokens = 8192
287293
max_batches = 256

0 commit comments

Comments (0)