
Commit ab09b7c

Extract the function _create_column_parallel_packed_layer out of the test.

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>

1 parent: 7ea2939
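
In short, the nested helper create_column_parallel_packed_layer, which previously closed over test-local variables (dist_init, max_loras, lora_config), becomes a module-level function that receives its inputs explicitly. Below is a condensed sketch of the new call contract; the names are taken from the diff further down, and the helper body plus the rest of the test are elided:

    def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
        # Build the base MergedColumnParallelLinear and its LoRA wrapper.
        # LoRA settings are now derived from vllm_config instead of being
        # captured from the enclosing test scope.
        lora_config = vllm_config.lora_config
        max_loras = lora_config.max_loras
        ...  # see the full body in the diff below

    # The test attaches the LoRA config to vllm_config before calling it:
    #     vllm_config = dist_init
    #     vllm_config.lora_config = lora_config
    #     linear, lora_linear = _create_column_parallel_packed_layer(
    #         repeats, vllm_config, mesh)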

File tree

1 file changed: 64 additions, 59 deletions

tests/lora/test_layers.py

Lines changed: 64 additions & 59 deletions
@@ -210,73 +210,18 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
         fully_sharded_loras=False,
         lora_dtype=torch.float16,
     )
+    vllm_config = dist_init
+    vllm_config.lora_config = lora_config
 
     axis_names = ("data", "model")
     devices = jax.devices()
     mesh_shape = (1, len(devices))
     mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
 
-    def create_column_parallel_packed_layer():
-        # We first create a base linear layer, then a lora layer to wrap it.
-        if repeats == 2:
-            # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading.
-            linear = MergedColumnParallelLinear(
-                64,  # input_size
-                [64] * repeats,  # output_size
-                bias=False,
-                params_dtype=torch.float16)
-            linear.weight.data = torch.rand_like(linear.weight.data)
-
-            base_linear = MergedColumnParallelLinear(
-                64,  # input_size
-                [64] * repeats,  # output_size
-                bias=False,
-                params_dtype=torch.float16)
-            base_linear.weight.data = linear.weight.data
-            vllm_config = dist_init
-            jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
-            linear_method = VllmUnquantizedLinearMethod(jax_config)
-            base_linear.quant_method = linear_method
-            linear_method.process_weights_after_loading(
-                base_linear
-            )  # here base_linear.weight is moved to TPU and sharded.
-            assert jax_view(base_linear.weight).platform(
-            ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
-            assert not isinstance(
-                jax_view(base_linear.weight).sharding,
-                jax.sharding.SingleDeviceSharding
-            ), 'base_linear.weight should have been sharded.'
-
-            lora_linear = MergedColumnParallelLinearWithLoRA(base_linear)
-        elif repeats == 3:
-            raise NotImplementedError("NYI: for MergedQKVParallelLinear case")
-        else:
-            raise NotImplementedError("NYI: for QKVParallelLinear case")
-
-        with torchax.default_env():
-            lora_linear.create_lora_weights(max_loras, lora_config)
-            # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
-            _shard_module_to_tpu(lora_linear, mesh)
-
-        assert jax_view(lora_linear.lora_a_stacked[0]).platform(
-        ) == 'tpu', 'lora_a_stacked should have been moved to TPU.'
-        assert not isinstance(
-            jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.
-            SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
-        assert jax_view(lora_linear.lora_b_stacked[0]).platform(
-        ) == 'tpu', 'lora_b_stacked should have been moved to TPU.'
-        assert not isinstance(
-            jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.
-            SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
-        n_slices = repeats
-        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
-            lora_linear.lora_b_stacked) == n_slices)
-
-        return linear, lora_linear
-
     set_random_seed(6)
 
-    linear, lora_linear = create_column_parallel_packed_layer()
+    linear, lora_linear = _create_column_parallel_packed_layer(
+        repeats, vllm_config, mesh)
     with torchax.default_env():
         # lora_linear.weight has type torchax.tensor.Tensor
         # BaseLinearLayerWithLoRA.weight property guarantees this.
@@ -419,3 +364,63 @@ def create_column_parallel_packed_layer():
     print(
         f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}'
     )
+
+
+def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
+    # We first create a base linear layer, then a lora layer to wrap it.
+    if repeats == 2:
+        # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading.
+        linear = MergedColumnParallelLinear(
+            64,  # input_size
+            [64] * repeats,  # output_size
+            bias=False,
+            params_dtype=torch.float16)
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = MergedColumnParallelLinear(
+            64,  # input_size
+            [64] * repeats,  # output_size
+            bias=False,
+            params_dtype=torch.float16)
+        base_linear.weight.data = linear.weight.data
+        jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
+        linear_method = VllmUnquantizedLinearMethod(jax_config)
+        base_linear.quant_method = linear_method
+        linear_method.process_weights_after_loading(
+            base_linear
+        )  # here base_linear.weight is moved to TPU and sharded.
+        assert jax_view(base_linear.weight).platform(
+        ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
+        assert not isinstance(
+            jax_view(
+                base_linear.weight).sharding, jax.sharding.SingleDeviceSharding
+        ), 'base_linear.weight should have been sharded.'
+
+        lora_linear = MergedColumnParallelLinearWithLoRA(base_linear)
+    elif repeats == 3:
+        raise NotImplementedError("NYI: for MergedQKVParallelLinear case")
+    else:
+        raise NotImplementedError("NYI: for QKVParallelLinear case")
+
+    lora_config = vllm_config.lora_config
+    max_loras = lora_config.max_loras
+    with torchax.default_env():
+        lora_linear.create_lora_weights(max_loras, lora_config)
+        # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
+        _shard_module_to_tpu(lora_linear, mesh)
+
+    assert jax_view(lora_linear.lora_a_stacked[0]).platform(
+    ) == 'tpu', 'lora_a_stacked should have been moved to TPU.'
+    assert not isinstance(
+        jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.
+        SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
+    assert jax_view(lora_linear.lora_b_stacked[0]).platform(
+    ) == 'tpu', 'lora_b_stacked should have been moved to TPU.'
+    assert not isinstance(
+        jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.
+        SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
+    n_slices = repeats
+    assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
+        lora_linear.lora_b_stacked) == n_slices)
+
+    return linear, lora_linear
