Conversation

@richardsliu

Description

This fixes running the Torchax backend on Pathways. On Pathways, the following line causes the weights to be loaded onto CPU devices on the controller: https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/models/vllm/vllm_model_wrapper.py#L108C9-L109C74

This later fails when the weights are transferred to the TPU with jax.device_put(), since the target device is not a PJRT device.

The fix is to detach the original PyTorch/JAX tensor wrapper and pass a plain NumPy array to jax.device_put() instead.
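A minimal sketch of the idea (not the actual patch; the helper name and the stand-in weight array are hypothetical): materialize the weight as a host-side NumPy array before handing it to jax.device_put(), so JAX never inspects the source tensor's device.

```python
import numpy as np
import jax

def safe_device_put(weight, device=None):
    # Hypothetical helper illustrating the fix: convert the loaded weight
    # to a plain NumPy array first. For a torch tensor this would be
    # weight.detach().cpu().numpy(); np.asarray() covers the general case.
    host_array = np.asarray(weight)
    # jax.device_put() now only sees a host buffer with no attached
    # (possibly non-PJRT) device, so the transfer target is unambiguous.
    return jax.device_put(host_array, device)

# Stand-in for a weight loaded on the controller.
w = np.ones((2, 3), dtype=np.float32)
on_device = safe_device_put(w)
print(on_device.shape)  # (2, 3)
```

The key point is that the array passed to jax.device_put() carries no device placement of its own, so the call cannot fail on a source device that PJRT does not recognize.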

FIXES: b/444030476

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

Signed-off-by: Richard Liu <ricliu@google.com>
