
Commit c9b725c

Currently seeing args error
1 parent 983644a commit c9b725c

2 files changed: 65 additions & 11 deletions

tpu_inference/models/common/model_loader.py

Lines changed: 63 additions & 11 deletions
@@ -2,6 +2,9 @@
 from typing import Any, Optional
 
 import jax
+# NOTE: Force usage of the experimental layout API. User explicitly requested no fallbacks.
+# If this import fails (older JAX), that is acceptable for this hacky test phase.
+from jax.experimental.layout import Layout, Format  # type: ignore
 import torch
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
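
Note on the new import: Format pairs a device-local Layout with a sharding, and Layout.AUTO asks the compiler to choose the in-memory layout while the sharding stays pinned. AUTO layouts get resolved on the ahead-of-time lower/compile path rather than on an eager jit call, which is what the second hunk below relies on. A minimal self-contained sketch of that pattern, assuming a recent JAX where jax.experimental.layout exposes these names; the single-device sharding and the toy function f are placeholders, not from this commit:

    import functools

    import jax
    import jax.numpy as jnp
    from jax.experimental.layout import Layout, Format

    # Placeholder sharding; model_loader.py reuses each param's existing sharding.
    dev_sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])

    @functools.partial(
        jax.jit,
        # Layout.AUTO: let the compiler pick the device-local layout, while
        # the sharding half of the Format stays explicit.
        in_shardings=Format(Layout.AUTO, sharding=dev_sharding),
        out_shardings=Format(Layout.AUTO, sharding=dev_sharding),
    )
    def f(x):
        return x @ x.T

    # AUTO is resolved at compile time, so lower against abstract
    # ShapeDtypeStruct inputs instead of calling f(x) directly.
    x_spec = jax.ShapeDtypeStruct((128, 256), jnp.float32)
    compiled = f.lower(x_spec).compile()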
@@ -209,19 +212,68 @@ def get_flax_model(
     # https://flax.readthedocs.io/en/latest/guides/performance.html
     graphdef, state = nnx.split(jit_model)
 
-    @functools.partial(
-        jax.jit,
-        out_shardings=(
-            kv_cache_sharding,
-            hidden_states_sharding,
-            hidden_states_sharding,  # aux hidden states
-        ),
-        donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=6,  # 6 is layer_name_to_kvcache_index
+    def get_state_shardings_with_auto_layout(state):
+        def wrap_sharding(x):
+            if hasattr(x, 'sharding'):
+                return Format(Layout.AUTO, sharding=x.sharding)
+            return None
+        return jax.tree.map(wrap_sharding, state)
+
+    # Preserve the sharding but add Layout.AUTO
+    state_shardings = get_state_shardings_with_auto_layout(state)
+    state_shapes = jax.tree.map(
+        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype) if hasattr(x, 'shape') else x,
+        state,
     )
 
     def run_model(graphdef, state, *args):
-        model = nnx.merge(graphdef, state)
-        return model(*args)
+        @functools.partial(
+            jax.jit,
+            # Args layout (positional indices for run_model_base):
+            #   0: graphdef
+            #   1: state
+            #   2: kv_cache (donated)
+            #   3: input_ids
+            #   4: attention_metadata
+            #   5: inputs_embeds
+            #   6: layer_name_to_kvcache_index (static)
+            #   7: lora_metadata
+            # Note: For pjit, the static arg (index 6) is not part of the runtime
+            # args tree. Hence in_shardings should only specify shardings for the
+            # non-static positional args: indices [0,1,2,3,4,5,7] => length 7.
+            in_shardings=(None, state_shardings, None, None, None, None, None),
+            out_shardings=(
+                kv_cache_sharding,
+                hidden_states_sharding,
+                hidden_states_sharding,  # aux hidden states
+            ),
+            donate_argnums=2,  # 2 is kv_cache
+            static_argnums=6,  # 6 is layer_name_to_kvcache_index
+        )
+        def run_model_base(graphdef, state, *args):
+            model = nnx.merge(graphdef, state)
+            return model(*args)
+
+        # Compile with ShapeDtypeStruct for state to infer layouts
+        compiled = run_model_base.lower(graphdef, state_shapes, *args).compile()
+
+        # Call with real state
+        runtime_args = args[:4] + args[5:]  # Exclude static arg at index 6
+        return compiled(graphdef, state, *runtime_args)
+
+    # @functools.partial(
+    #     jax.jit,
+    #     out_shardings=(
+    #         kv_cache_sharding,
+    #         hidden_states_sharding,
+    #         hidden_states_sharding,  # aux hidden states
+    #     ),
+    #     donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
+    #     static_argnums=6,  # 6 is layer_name_to_kvcache_index
+    # )
+    # def run_model(graphdef, state, *args):
+    #     model = nnx.merge(graphdef, state)
+    #     return model(*args)
 
     logits_sharding = NamedSharding(
         mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
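
The run_model wrapper above leans on the JAX ahead-of-time (AOT) calling convention: .lower() still receives every argument, static ones as concrete Python values, but the resulting compiled callable accepts only the dynamic arguments, since the static value is baked into the executable. A minimal sketch of just that convention, using a hypothetical toy function (step and flag are illustrative names, not from this commit):

    import functools

    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnums=1)  # arg 1 (`flag`) is static
    def step(x, flag):
        # `flag` is a plain Python bool at trace time, so branching on it is fine.
        return x * 2.0 if flag else x

    x_spec = jax.ShapeDtypeStruct((8,), jnp.float32)

    # Lowering takes all arguments, including the static one...
    compiled = step.lower(x_spec, True).compile()

    # ...but the compiled callable accepts only the dynamic arguments.
    y = compiled(jnp.ones((8,), jnp.float32))

That is the bookkeeping runtime_args = args[:4] + args[5:] performs; per the commit message, the argument accounting (for instance the 7-element in_shardings tuple against the flattened (graphdef, state, *args) tree) is still producing an args error somewhere.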

tpu_inference/models/jax/utils/quantization/quantization_utils.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@
 from qwix.contrib.padded_qarray import PaddedPtqProvider
 from qwix.contrib import padded_qarray as ptq
 
+ptq.QARRAY_KEEP_PADDED_SHAPE = True  # set via the `ptq` alias; the bare name `qwix` is not bound here
+
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 
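
One Python footnote on the added flag: from qwix.contrib import padded_qarray as ptq binds only the alias ptq in this module, so the assignment goes through that alias; spelling it with the bare package name would require a separate import qwix. A minimal sketch of the rule:

    # `from pkg.sub import mod as alias` imports pkg.sub.mod but binds only
    # `alias` here; the bare name `pkg` stays undefined in this namespace.
    from qwix.contrib import padded_qarray as ptq

    ptq.QARRAY_KEEP_PADDED_SHAPE = True  # OK: sets the attribute on the module object
    # qwix.contrib.padded_qarray.QARRAY_KEEP_PADDED_SHAPE = True
    #   -> NameError unless `import qwix` ran in this module first.

Since modules are singletons in sys.modules, setting the attribute through any alias is visible to every other importer of qwix.contrib.padded_qarray.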
