|
 from typing import Any, Optional
 
 import jax
+# NOTE: Force use of the experimental layout API; no fallback path is provided
+# on purpose. If this import fails (older JAX), that is acceptable for this
+# hacky test phase.
+from jax.experimental.layout import Layout, Format  # type: ignore
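+# (Assumed semantics of this experimental API, not verified against the pinned
+# JAX version: Format pairs a device-local Layout with a Sharding, and
+# Layout.AUTO defers the layout choice to the compiler, which is resolved via
+# the AOT lower()/compile() path used in run_model below.)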
 import torch
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
@@ -209,19 +212,68 @@ def get_flax_model( |
     # https://flax.readthedocs.io/en/latest/guides/performance.html
     graphdef, state = nnx.split(jit_model)
 
-    @functools.partial(
-        jax.jit,
-        out_shardings=(
-            kv_cache_sharding,
-            hidden_states_sharding,
-            hidden_states_sharding,  # aux hidden states
-        ),
-        donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=6,  # 6 is layer_name_to_kvcache_index
+    def get_state_shardings_with_auto_layout(state):
+        def wrap_sharding(x):
+            if hasattr(x, 'sharding'):
+                return Format(Layout.AUTO, sharding=x.sharding)
+            return None
+        return jax.tree.map(wrap_sharding, state)
+
+    # Preserve each leaf's sharding but add Layout.AUTO.
+    state_shardings = get_state_shardings_with_auto_layout(state)
+    # Abstract shapes/dtypes of the state, used for the AOT lowering below.
+    state_shapes = jax.tree.map(
+        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype) if hasattr(x, 'shape') else x,
+        state,
     )
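+    # Illustration only (hypothetical leaf, not taken from the real state): a
+    # bf16 weight sharded as NamedSharding(mesh, PartitionSpec(None, "model"))
+    # would produce
+    #   state_shardings leaf: Format(Layout.AUTO, sharding=<that NamedSharding>)
+    #   state_shapes leaf:    jax.ShapeDtypeStruct((4096, 8192), jnp.bfloat16)
+    # i.e. the partitioning is pinned while the on-device tiling is left to XLA.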
+
     def run_model(graphdef, state, *args):
-        model = nnx.merge(graphdef, state)
-        return model(*args)
+        @functools.partial(
+            jax.jit,
+            # Positional argument layout for run_model_base:
+            #   0: graphdef
+            #   1: state
+            #   2: kv_cache (donated)
+            #   3: input_ids
+            #   4: attention_metadata
+            #   5: inputs_embeds
+            #   6: layer_name_to_kvcache_index (static)
+            #   7: lora_metadata
+            # Note: jax.jit drops the static arg (index 6) from the runtime
+            # argument tree, so in_shardings only covers the non-static
+            # positional args: indices [0, 1, 2, 3, 4, 5, 7] => length 7.
+            in_shardings=(None, state_shardings, None, None, None, None, None),
+            out_shardings=(
+                kv_cache_sharding,
+                hidden_states_sharding,
+                hidden_states_sharding,  # aux hidden states
+            ),
+            donate_argnums=2,  # 2 is kv_cache
+            static_argnums=6,  # 6 is layer_name_to_kvcache_index
+        )
+        def run_model_base(graphdef, state, *args):
+            model = nnx.merge(graphdef, state)
+            return model(*args)
+
+        # Lower with ShapeDtypeStructs in place of the state so the compiler
+        # can choose the parameter layouts, then compile ahead of time.
+        compiled = run_model_base.lower(graphdef, state_shapes, *args).compile()
+
+        # Call the compiled executable with the real state. The static arg
+        # (overall position 6, i.e. args[4] inside *args) is baked into the
+        # executable and is not passed again.
+        runtime_args = args[:4] + args[5:]
+        return compiled(graphdef, state, *runtime_args)
+
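+    # Rough sketch of the AOT pattern used by run_model, with placeholder names
+    # (w, x and mesh_sharding are hypothetical, not from this file), assuming
+    # the experimental layout API behaves as described above:
+    #
+    #   fmt = Format(Layout.AUTO, sharding=mesh_sharding)
+    #   fn = jax.jit(lambda w, x: w @ x, in_shardings=(fmt, None))
+    #   compiled = fn.lower(jax.ShapeDtypeStruct(w.shape, w.dtype), x).compile()
+    #   out = compiled(w, x)  # w uses the compiler-chosen layout
+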
+    # @functools.partial(
+    #     jax.jit,
+    #     out_shardings=(
+    #         kv_cache_sharding,
+    #         hidden_states_sharding,
+    #         hidden_states_sharding,  # aux hidden states
+    #     ),
+    #     donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
+    #     static_argnums=6,  # 6 is layer_name_to_kvcache_index
+    # )
+    # def run_model(graphdef, state, *args):
+    #     model = nnx.merge(graphdef, state)
+    #     return model(*args)
 
     logits_sharding = NamedSharding(
         mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
|