
Commit 2eae008

also check that the device placement is correct and the sharding is correct.
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
1 parent 11c7ea2 commit 2eae008

File tree

1 file changed (+11, -1 lines changed)

tests/lora/test_layers.py

Lines changed: 11 additions & 1 deletion
@@ -251,6 +251,8 @@ def create_column_parallel_packed_layer():
     base_linear.quant_method=linear_method
     linear_method.process_weights_after_loading(base_linear)
     # here base_linear.weight is on TPU and sharded.
+    assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.'
+    assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.'
 
     # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
     lora_linear = MergedColumnParallelLinearWithLoRA(
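
For context on the two asserts added above: jax_view exposes the JAX array that backs the torchax tensor, and the checks then look at its device platform and its sharding. The standalone sketch below is not part of this commit; it performs the same two checks on a plain jax.Array. The "model" mesh axis name and the toy shapes are assumptions, and jax.devices() stands in for the TPU cores used by the test.

# Standalone sketch of the device/sharding checks used above.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec, SingleDeviceSharding

devices = jax.devices()                       # TPU cores in the real test
mesh = Mesh(devices, axis_names=("model",))   # 1-D mesh; axis name is arbitrary

x = jnp.zeros((4 * len(devices), 8))
x = jax.device_put(x, NamedSharding(mesh, PartitionSpec("model", None)))

# Device placement check: every shard sits on the expected platform.
assert all(d.platform == devices[0].platform for d in x.devices())
# Sharding check: a sharded array does not use SingleDeviceSharding.
assert not isinstance(x.sharding, SingleDeviceSharding)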
@@ -270,6 +272,13 @@ def create_column_parallel_packed_layer():
     # create_lora_weights creates global shape weight.
     lora_linear.create_lora_weights(max_loras, lora_config)
     _shard_merged_column_parallel_linear_lora(lora_linear, mesh)
+    # TODO: assert the lora_a_stacked is on TPU and sharded.
+    assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.'
+    assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
+    assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.'
+    assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
+
+    # TODO: assert the lora_b_stacked is on TPU and sharded.
     assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
         lora_linear.lora_b_stacked) == n_slices)

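The block above depends on _shard_merged_column_parallel_linear_lora having placed each stacked LoRA buffer on the mesh. The sketch below is only an illustration of that idea, not the repo's helper: the function name shard_stacked, the partition spec, and the toy shapes are assumptions.

# Illustrative sketch: shard each stacked LoRA buffer across a 1-D "model" mesh
# so that checks like the ones above (on-device, not SingleDeviceSharding) hold.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec, SingleDeviceSharding

def shard_stacked(buffers, mesh):
    # Assumed layout: shard the third axis across "model", replicate the rest.
    sharding = NamedSharding(mesh, PartitionSpec(None, None, "model", None))
    return tuple(jax.device_put(b, sharding) for b in buffers)

mesh = Mesh(jax.devices(), axis_names=("model",))
n_dev = len(jax.devices())
lora_a_stacked = (jnp.zeros((8, 1, 4 * n_dev, 16)),)   # toy stacked buffer, one slice
lora_a_stacked = shard_stacked(lora_a_stacked, mesh)

assert not isinstance(lora_a_stacked[0].sharding, SingleDeviceSharding)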
@@ -324,7 +333,8 @@ def create_column_parallel_packed_layer():
         vocab_size=512,
         extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
-    # punica_wrapper.move_to_device(mesh)
+    assert jax_view(punica_wrapper._lora_indices_per_batch).platform() == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+    assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
 
     jax_inputs = []
     with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
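
Unlike the weight checks earlier, the second assert above expects SingleDeviceSharding: the per-batch LoRA index buffer is only moved onto a TPU device, not sharded across the mesh. A small standalone sketch of that distinction (jax.devices()[0] stands in for jax.devices("tpu")[0] when no TPU is attached):

# Placing a small index array on one device yields a SingleDeviceSharding,
# which is what the punica_wrapper assert checks for.
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

device = jax.devices()[0]
indices = jax.device_put(jnp.zeros((8,), dtype=jnp.int32), device)

assert device in indices.devices()
assert isinstance(indices.sharding, SingleDeviceSharding)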
