@@ -251,6 +251,8 @@ def create_column_parallel_packed_layer():
     base_linear.quant_method = linear_method
     linear_method.process_weights_after_loading(base_linear)
     # here base_linear.weight is on TPU and sharded.
+    assert jax_view(base_linear.weight).platform() == 'tpu', 'base_linear.weight should have been moved to TPU.'
+    assert not isinstance(jax_view(base_linear.weight).sharding, jax.sharding.SingleDeviceSharding), 'base_linear.weight should have been sharded.'

     # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
     lora_linear = MergedColumnParallelLinearWithLoRA(
@@ -270,6 +272,13 @@ def create_column_parallel_packed_layer():
     # create_lora_weights creates global shape weight.
     lora_linear.create_lora_weights(max_loras, lora_config)
     _shard_merged_column_parallel_linear_lora(lora_linear, mesh)
+    # lora_a_stacked should be on TPU and sharded.
+    assert jax_view(lora_linear.lora_a_stacked[0]).platform() == 'tpu', 'lora_a_stacked should have been moved to TPU.'
+    assert not isinstance(jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
+    # lora_b_stacked should be on TPU and sharded.
+    assert jax_view(lora_linear.lora_b_stacked[0]).platform() == 'tpu', 'lora_b_stacked should have been moved to TPU.'
+    assert not isinstance(jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
+
     assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
         lora_linear.lora_b_stacked) == n_slices)

@@ -324,7 +333,8 @@ def create_column_parallel_packed_layer():
         vocab_size=512,
         extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
-    # punica_wrapper.move_to_device(mesh)
+    assert jax_view(punica_wrapper._lora_indices_per_batch).platform() == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+    assert isinstance(jax_view(punica_wrapper._lora_indices_per_batch).sharding, jax.sharding.SingleDeviceSharding), 'punica_wrapper._lora_indices_per_batch should stay on a single device (not sharded).'

     jax_inputs = []
     with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
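
The placement/sharding check added above is repeated verbatim for each tensor. A minimal sketch of how it could be factored into a helper (hypothetical, not part of the patch; it assumes the same torchax `jax_view` interop used in the test and that the unwrapped jax.Array exposes `.platform()` and `.sharding`, as the assertions above do):

import jax
from torchax.interop import jax_view  # assumed import path for the jax_view helper used above


def assert_on_tpu_and_sharded(tensor, name):
    # Unwrap the torchax tensor into the underlying jax.Array.
    arr = jax_view(tensor)
    assert arr.platform() == 'tpu', f'{name} should have been moved to TPU.'
    # SingleDeviceSharding means the array lives on one device, i.e. it was not sharded.
    assert not isinstance(arr.sharding, jax.sharding.SingleDeviceSharding), \
        f'{name} should have been sharded across the mesh.'

# Example usage mirroring the assertions in the diff:
# assert_on_tpu_and_sharded(base_linear.weight, 'base_linear.weight')
# assert_on_tpu_and_sharded(lora_linear.lora_a_stacked[0], 'lora_a_stacked')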