@@ -248,6 +248,7 @@ def create_column_parallel_packed_layer():
     # self.jax_config.mesh.devices[0][0].platform
     jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
     linear_method = VllmUnquantizedLinearMethod(jax_config)
+    base_linear.quant_method = linear_method
     linear_method.process_weights_after_loading(base_linear)
     # here base_linear.weight is on TPU and sharded.

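Context for the added `base_linear.quant_method = linear_method` line: vLLM linear layers route their forward pass through the `quant_method` attached to them, and the LoRA wrapper exercised later calls `base_layer.quant_method.apply(...)`, so the JAX-backed method has to be attached before the layer is called. A minimal stand-alone sketch of that dispatch pattern, with toy class names that are not vLLM's:

```python
import torch

# Toy illustration of the quant_method dispatch pattern; class names are
# made up and only the attach-before-forward idea mirrors the diff above.
class ToyLinearMethod:
    def apply(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        # A real method could shard the weight or hand the matmul to JAX here.
        return torch.nn.functional.linear(x, layer.weight)


class ToyLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.quant_method = None  # must be attached before the first forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.quant_method.apply(self, x)


layer = ToyLinear(8, 4)
layer.quant_method = ToyLinearMethod()  # analogous to the added line above
print(layer(torch.randn(2, 8)).shape)   # torch.Size([2, 4])
```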
@@ -263,6 +264,8 @@ def create_column_parallel_packed_layer():
         raise NotImplementedError("NYI: for QKVParallelLinear case")

     n_slices = repeats
+    # TODO(xw): check if we can enable torchax globally.
+    # TODO(xw): check if we can calculate both actual and expected output using torchax.
     with torchax.default_env():
         # create_lora_weights creates global shape weight.
         lora_linear.create_lora_weights(max_loras, lora_config)
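Background for the two TODOs: `torchax.default_env()` returns an environment that, used as a context manager, routes torch operations through JAX, whereas `torchax.enable_globally()` switches that interception on process-wide. A hedged sketch of both patterns; the shapes are arbitrary and the exact behaviour (e.g. the `'jax'` device string) may vary across torchax versions:

```python
import torch
import torchax

# Scoped interception, as used in this test: torch ops issued inside the
# context run through JAX. Tensors created on the 'jax' device are backed
# by jax.Array values.
with torchax.default_env():
    x = torch.randn(16, 32, device='jax')
    w = torch.randn(64, 32, device='jax')
    y = torch.nn.functional.linear(x, w)  # executes via JAX / XLA

# Process-wide alternative referenced by the first TODO:
# torchax.enable_globally()
```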
@@ -282,10 +285,11 @@ def create_column_parallel_packed_layer():

     max_num_batched_tokens = 8192
     max_batches = 256
-    punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
-                                        max_batches,
-                                        device,
-                                        max_loras=max_loras)
+    with torchax.default_env():
+        punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
+                                            max_batches,
+                                            'jax',
+                                            max_loras=max_loras)
     assert check_punica_wrapper(punica_wrapper)
     lora_linear.set_mapping(punica_wrapper)

@@ -333,7 +337,8 @@ def create_column_parallel_packed_layer():
     with torchax.default_env():
         # lora_result = lora_linear(torch.cat(jax_inputs))[0]
         # lora_result = j2t(lora_result)
-        lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs))
+        # lora_result = linear_method.apply(lora_linear.base_layer, torch.cat(jax_inputs))
+        lora_result = lora_linear(torch.cat(jax_inputs))[0]

     expected_results: list[torch.Tensor] = []
     for input_, lora_id in zip(inputs, prompt_mapping):
@@ -348,17 +353,18 @@ def create_column_parallel_packed_layer():
     expected_result = torch.cat(expected_results)

     rtol, atol = TOLERANCES[lora_result.dtype]
-    # with torchax.default_env():
-    #     torch.testing.assert_close(lora_result.to('cpu'),
-    #                                expected_result,
-    #                                rtol=rtol,
-    #                                atol=atol)
-    #     print(
-    #         f'Output max diff: {torch.max(torch.abs(expected_result.to('cpu') - lora_result))}'
-    #     )
-    #     print(
-    #         f'Output mean diff: {torch.mean(torch.abs(expected_result.to('cpu') - lora_result))}'
-    #     )
+    with torchax.default_env():
+        lora_result_cpu = lora_result.to('cpu')
+        torch.testing.assert_close(lora_result_cpu,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)
+        print(
+            f'Output max diff: {torch.max(torch.abs(expected_result - lora_result_cpu))}'
+        )
+        print(
+            f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result_cpu))}'
+        )

     # Check that resetting the lora weights succeeds
     # Here we set all lora weight to be empty.
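The re-enabled check follows the usual pattern for comparing a device-backed result against a CPU reference: copy the actual output to CPU once, then compare with dtype-dependent tolerances. A self-contained sketch of that pattern; the tolerance values below are placeholders, not the ones defined in this test file:

```python
import torch

# Placeholder tolerances keyed by dtype; the real TOLERANCES table in the
# test file may use different values.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
    torch.float32: (1e-4, 1e-5),
}


def assert_matches(actual: torch.Tensor, expected_cpu: torch.Tensor) -> None:
    rtol, atol = TOLERANCES[actual.dtype]
    actual_cpu = actual.to('cpu')  # bring the device-backed result to host
    torch.testing.assert_close(actual_cpu, expected_cpu, rtol=rtol, atol=atol)
    print(f'Output max diff: {torch.max(torch.abs(expected_cpu - actual_cpu))}')
    print(f'Output mean diff: {torch.mean(torch.abs(expected_cpu - actual_cpu))}')
```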