@@ -3,7 +3,6 @@

import jax
import jax.numpy as jnp
-import torch
from flax import nnx
from flax.typing import PRNGKey
from jax.sharding import Mesh
@@ -20,11 +19,12 @@
    SharedExpertsTransformerBlock
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.utils.weight_utils import (
-    get_param, model_weights_generator, print_param_info, reshape_params,
-    transpose_params, convert_torch_to_jax_with_view)
+    convert_torch_to_jax_with_view, get_param, model_weights_generator,
+    print_param_info, reshape_params, transpose_params)

logger = init_logger(__name__)

+
class Llama4ForCausalLM(nnx.Module):

    def __init__(self,
@@ -121,30 +121,31 @@ def __init__(self,
                            ed_sharding=(None, 'expert'),
                            random_init=force_random_weights)

-            moe_ffw = MoE(dtype=dtype,
-                          num_local_experts=self.num_local_experts,
-                          apply_expert_weight_before_computation=True,
-                          hidden_size=self.hidden_size,
-                          intermediate_size_moe=self.intermediate_size_moe,
-                          hidden_act=self.hidden_act,
-                          router=router,
-                          rngs=self.rng,
-                          activation_ffw_td=('data', None),
-                          activation_ffw_ted=('data', 'expert', None),
-                          edf_sharding=('expert', None, 'model'),
-                          efd_sharding=('expert', 'model', None),
-                          random_init=force_random_weights) if is_moe_layer else None
-
-
-            dense_ffw = DenseFFW(dtype=dtype,
-                                 hidden_act=self.hidden_act,
-                                 hidden_size=self.hidden_size,
-                                 intermediate_size=self.intermediate_size_mlp,
-                                 random_init=force_random_weights,
-                                 rngs=self.rng,
-                                 df_sharding=(None, 'model'),
-                                 fd_sharding=('model', None),
-                                 activation_ffw_td=('data', None)) if not is_moe_layer else None
+            moe_ffw = MoE(
+                dtype=dtype,
+                num_local_experts=self.num_local_experts,
+                apply_expert_weight_before_computation=True,
+                hidden_size=self.hidden_size,
+                intermediate_size_moe=self.intermediate_size_moe,
+                hidden_act=self.hidden_act,
+                router=router,
+                rngs=self.rng,
+                activation_ffw_td=('data', None),
+                activation_ffw_ted=('data', 'expert', None),
+                edf_sharding=('expert', None, 'model'),
+                efd_sharding=('expert', 'model', None),
+                random_init=force_random_weights) if is_moe_layer else None
+
+            dense_ffw = DenseFFW(
+                dtype=dtype,
+                hidden_act=self.hidden_act,
+                hidden_size=self.hidden_size,
+                intermediate_size=self.intermediate_size_mlp,
+                random_init=force_random_weights,
+                rngs=self.rng,
+                df_sharding=(None, 'model'),
+                fd_sharding=('model', None),
+                activation_ffw_td=('data', None)) if not is_moe_layer else None

            attn = Llama4Attention(
                hidden_size=self.hidden_size,
@@ -524,8 +525,9 @@ def load_weights(self, model_for_loading: nnx.Module):

                loaded_weight = convert_torch_to_jax_with_view(
                    loaded_weight, cast_type)
-                loaded_weight = transpose_params(
-                    loaded_name, loaded_weight, self._transpose_map)
+                loaded_weight = transpose_params(loaded_name,
+                                                 loaded_weight,
+                                                 self._transpose_map)

                buffer_key = f"{mapped_name}_{'scale' if is_scale else 'qvalue'}"
                if buffer_key not in self.expert_weights_buffer: