
Commit 59029c9

[GPT-OSS] Add sharding configs to support Qwix quantization
Signed-off-by: Jordan Dotzel <amishacorns@users.noreply.github.com>
1 parent: e9f570c

File tree: 1 file changed (+16, -13 lines)

tpu_inference/models/jax/gpt_oss.py

Lines changed: 16 additions & 13 deletions
@@ -82,7 +82,7 @@ def __init__(self,
             hidden_size=hidden_size,
             dtype=dtype,
             rngs=self.rng,
-            vd_sharding=(('data', 'model'), None),
+            vd_sharding=P(('data', 'model'), None),
             random_init=self.random_init,
         )
@@ -105,9 +105,9 @@ def __init__(self,
             query_tnh=P(None, 'model', None),
             keyvalue_skh=P(None, 'model', None),
             attn_o_tnh=P(None, 'model', None),
-            dnh_sharding=(None, 'model', None),
-            dkh_sharding=(None, 'model', None),
-            nhd_sharding=('model', None, None),
+            dnh_sharding=P(None, 'model', None),
+            dkh_sharding=P(None, 'model', None),
+            nhd_sharding=P('model', None, None),
             mesh=self.mesh,
         )
@@ -120,9 +120,9 @@ def __init__(self,
             dtype=dtype,
             router_act='softmax',
             random_init=self.random_init,
-            activation_ffw_td=('data', None),
-            ed_sharding=('model', None),
-            e_sharding=('model', ),
+            activation_ffw_td=P('data', None),
+            ed_sharding=P('model', None),
+            e_sharding=P('model'),
         )

         moe_mlp = GptOssMoE(
@@ -135,10 +135,10 @@ def __init__(self,
             router=router,
             swiglu_limit=swiglu_limit,
             # Sharding configuration
-            activation_ffw_td=('data', None),
-            edf_sharding=('model', None, None),
-            efd_sharding=('model', None, None),
-            ed_sharding=('model', None),
+            activation_ffw_td=P('data', None),
+            edf_sharding=P('model', None, None),
+            efd_sharding=P('model', None, None),
+            ed_sharding=P('model', None),
         )

         block = TransformerBlock(
@@ -148,13 +148,15 @@ def __init__(self,
                 epsilon=rms_norm_eps,
                 dtype=dtype,
                 rngs=self.rng,
+                activation_ffw_td=P('data', None),
             ),
             pre_mlp_norm=RMSNorm(
                 dims=hidden_size,
                 random_init=self.random_init,
                 epsilon=rms_norm_eps,
                 dtype=dtype,
                 rngs=self.rng,
+                activation_ffw_td=P('data', None),
             ),
             attn=attn,
             custom_module=moe_mlp,
@@ -167,15 +169,16 @@ def __init__(self,
             random_init=self.random_init,
             epsilon=rms_norm_eps,
             dtype=dtype,
+            activation_ffw_td=P('data', None),
         )

         self.lm_head = LMhead(
             vocab_size=vocab_size,
             hidden_size=hidden_size,
             dtype=dtype,
             rngs=self.rng,
-            vd_sharding=(('data', 'model'), None),
-            dv_sharding=(None, ('data', 'model')),
+            vd_sharding=P(('data', 'model'), None),
+            dv_sharding=P(None, ('data', 'model')),
             random_init=self.random_init,
         )
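Note on the change: P here is jax.sharding.PartitionSpec, so the commit replaces raw axis tuples with proper PartitionSpec objects. Below is a minimal sketch of what such a spec does, assuming a two-axis ('data', 'model') mesh; the mesh construction, the array shapes, and the suggestion that a quantization pass such as Qwix is easier to drive from PartitionSpec objects than from bare tuples are illustrative assumptions, not taken from this commit.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical mesh: lay all available devices out on a ('data', 'model')
# grid; the real model builds its own TPU mesh (self.mesh in the diff).
devs = np.array(jax.devices())
mesh = Mesh(devs.reshape(len(devs), 1), axis_names=('data', 'model'))

# Equivalent of vd_sharding=P(('data', 'model'), None) in the diff: the
# vocab dimension is sharded jointly over the 'data' and 'model' axes,
# while the hidden dimension is replicated.
vd_spec = P(('data', 'model'), None)
weight = jax.device_put(np.zeros((1024, 256), np.float32),
                        NamedSharding(mesh, vd_spec))
print(weight.sharding.spec)  # PartitionSpec(('data', 'model'), None)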
