Commit cb1bb91

pre-commit

1 parent 646762b

File tree

2 files changed: 35 additions & 32 deletions


tpu_inference/models/jax/llama4.py

Lines changed: 31 additions & 29 deletions
@@ -3,7 +3,6 @@
 
 import jax
 import jax.numpy as jnp
-import torch
 from flax import nnx
 from flax.typing import PRNGKey
 from jax.sharding import Mesh
@@ -20,11 +19,12 @@
     SharedExpertsTransformerBlock
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (
-    get_param, model_weights_generator, print_param_info, reshape_params,
-    transpose_params, convert_torch_to_jax_with_view)
+    convert_torch_to_jax_with_view, get_param, model_weights_generator,
+    print_param_info, reshape_params, transpose_params)
 
 logger = init_logger(__name__)
 
+
 class Llama4ForCausalLM(nnx.Module):
 
     def __init__(self,
@@ -121,30 +121,31 @@ def __init__(self,
                 ed_sharding=(None, 'expert'),
                 random_init=force_random_weights)
 
-            moe_ffw = MoE(dtype=dtype,
-                          num_local_experts=self.num_local_experts,
-                          apply_expert_weight_before_computation=True,
-                          hidden_size=self.hidden_size,
-                          intermediate_size_moe=self.intermediate_size_moe,
-                          hidden_act=self.hidden_act,
-                          router=router,
-                          rngs=self.rng,
-                          activation_ffw_td=('data', None),
-                          activation_ffw_ted=('data', 'expert', None),
-                          edf_sharding=('expert', None, 'model'),
-                          efd_sharding=('expert', 'model', None),
-                          random_init=force_random_weights) if is_moe_layer else None
-
-
-            dense_ffw = DenseFFW(dtype=dtype,
-                                 hidden_act=self.hidden_act,
-                                 hidden_size=self.hidden_size,
-                                 intermediate_size=self.intermediate_size_mlp,
-                                 random_init=force_random_weights,
-                                 rngs=self.rng,
-                                 df_sharding=(None, 'model'),
-                                 fd_sharding=('model', None),
-                                 activation_ffw_td=('data', None)) if not is_moe_layer else None
+            moe_ffw = MoE(
+                dtype=dtype,
+                num_local_experts=self.num_local_experts,
+                apply_expert_weight_before_computation=True,
+                hidden_size=self.hidden_size,
+                intermediate_size_moe=self.intermediate_size_moe,
+                hidden_act=self.hidden_act,
+                router=router,
+                rngs=self.rng,
+                activation_ffw_td=('data', None),
+                activation_ffw_ted=('data', 'expert', None),
+                edf_sharding=('expert', None, 'model'),
+                efd_sharding=('expert', 'model', None),
+                random_init=force_random_weights) if is_moe_layer else None
+
+            dense_ffw = DenseFFW(
+                dtype=dtype,
+                hidden_act=self.hidden_act,
+                hidden_size=self.hidden_size,
+                intermediate_size=self.intermediate_size_mlp,
+                random_init=force_random_weights,
+                rngs=self.rng,
+                df_sharding=(None, 'model'),
+                fd_sharding=('model', None),
+                activation_ffw_td=('data', None)) if not is_moe_layer else None
 
             attn = Llama4Attention(
                 hidden_size=self.hidden_size,
@@ -524,8 +525,9 @@ def load_weights(self, model_for_loading: nnx.Module):
 
                 loaded_weight = convert_torch_to_jax_with_view(
                     loaded_weight, cast_type)
-                loaded_weight = transpose_params(
-                    loaded_name, loaded_weight, self._transpose_map)
+                loaded_weight = transpose_params(loaded_name,
+                                                 loaded_weight,
+                                                 self._transpose_map)
 
                 buffer_key = f"{mapped_name}_{'scale' if is_scale else 'qvalue'}"
                 if buffer_key not in self.expert_weights_buffer:

tpu_inference/models/jax/utils/weight_utils.py

Lines changed: 4 additions & 3 deletions
@@ -10,10 +10,9 @@
 from dataclasses import dataclass, field
 from typing import Any, Optional
 
-import torch
-
 import jax
 import jax.numpy as jnp
+import torch
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
@@ -89,7 +88,9 @@ def model_weights_generator(
         st_file, framework, filter_regex):
     yield name, weight_tensor
 
-def convert_torch_to_jax_with_view(loaded_weight: torch.Tensor, cast_type: jnp.dtype) -> jax.Array:
+
+def convert_torch_to_jax_with_view(loaded_weight: torch.Tensor,
+                                   cast_type: jnp.dtype) -> jax.Array:
     """
     Converts a PyTorch tensor to a JAX array by reinterpreting its
    bit representation using a dtype view map.
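
For readers skimming the diff: the reformatted convert_torch_to_jax_with_view helper is described by its docstring as reinterpreting a PyTorch tensor's bit pattern through a dtype view map rather than converting values. Its body is not part of this commit, so the following is only a minimal sketch of that idea under stated assumptions; the name convert_torch_to_jax_with_view_sketch, the _VIEW_MAP table, and the bfloat16-through-int16 route are illustrative, not the repository's actual code.

# Minimal sketch only -- NOT the repository's implementation.
# Assumption: a dtype "view map" sends torch dtypes that numpy cannot
# represent (e.g. torch.bfloat16) through a same-width integer view,
# so the bits survive the torch -> numpy -> JAX hop unchanged.
import jax
import jax.numpy as jnp
import torch

_VIEW_MAP = {torch.bfloat16: (torch.int16, jnp.bfloat16)}  # hypothetical map


def convert_torch_to_jax_with_view_sketch(
        loaded_weight: torch.Tensor, cast_type: jnp.dtype) -> jax.Array:
    if loaded_weight.dtype in _VIEW_MAP:
        torch_view, jax_view = _VIEW_MAP[loaded_weight.dtype]
        # Reinterpret the bits as int16, cross into numpy/JAX, then view
        # the same bits back as bfloat16; no numeric conversion happens.
        bits = loaded_weight.view(torch_view).numpy()
        return jnp.asarray(bits).view(jax_view).astype(cast_type)
    # Dtypes numpy understands natively can round-trip directly.
    return jnp.asarray(loaded_weight.numpy()).astype(cast_type)

The view step matters because numpy has no native bfloat16, so a plain .numpy() call on such a tensor would raise; the final astype(cast_type) then performs any actual dtype change.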
