@@ -82,7 +82,7 @@ def __init__(self,
         hidden_size=hidden_size,
         dtype=dtype,
         rngs=self.rng,
-        vd_sharding=(('data', 'model'), None),
+        vd_sharding=P(('data', 'model'), None),
         random_init=self.random_init,
     )
 
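Throughout this commit, bare sharding tuples are wrapped in `P`, presumably `jax.sharding.PartitionSpec` imported under that alias. A tuple describes the same layout, but `jax.sharding.NamedSharding` accepts only a real `PartitionSpec`, so the tuple form breaks once the spec is turned into a device sharding. A minimal sketch of the distinction, assuming an illustrative two-axis ('data', 'model') mesh (the mesh shape and variable names here are not from this repo):

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Illustrative mesh: all local devices along 'model', one 'data' row.
    mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('data', 'model'))

    # P(('data', 'model'), None): shard dim 0 (vocab) jointly over both
    # mesh axes and replicate dim 1 (hidden), matching the embedder spec
    # in the hunk above.
    spec = P(('data', 'model'), None)
    sharding = NamedSharding(mesh, spec)  # a bare tuple here is rejected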
@@ -105,9 +105,9 @@ def __init__(self,
         query_tnh=P(None, 'model', None),
         keyvalue_skh=P(None, 'model', None),
         attn_o_tnh=P(None, 'model', None),
-        dnh_sharding=(None, 'model', None),
-        dkh_sharding=(None, 'model', None),
-        nhd_sharding=('model', None, None),
+        dnh_sharding=P(None, 'model', None),
+        dkh_sharding=P(None, 'model', None),
+        nhd_sharding=P('model', None, None),
         mesh=self.mesh,
     )
 
@@ -120,9 +120,9 @@ def __init__(self,
         dtype=dtype,
         router_act='softmax',
         random_init=self.random_init,
-        activation_ffw_td=('data', None),
-        ed_sharding=('model', None),
-        e_sharding=('model',),
+        activation_ffw_td=P('data', None),
+        ed_sharding=P('model', None),
+        e_sharding=P('model'),
     )
 
     moe_mlp = GptOssMoE(
@@ -135,10 +135,10 @@ def __init__(self,
         router=router,
         swiglu_limit=swiglu_limit,
         # Sharding configuration
-        activation_ffw_td=('data', None),
-        edf_sharding=('model', None, None),
-        efd_sharding=('model', None, None),
-        ed_sharding=('model', None),
+        activation_ffw_td=P('data', None),
+        edf_sharding=P('model', None, None),
+        efd_sharding=P('model', None, None),
+        ed_sharding=P('model', None),
     )
 
     block = TransformerBlock(
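In the MoE hunk above, the argument suffixes appear to name the weight axes (e = expert, d = model width, f = feed-forward width), so `edf_sharding=P('model', None, None)` reads as: split the expert dimension across the 'model' mesh axis and replicate the rest, i.e. expert parallelism. A hedged sketch of what that spec does to a weight table (shapes and names are illustrative, reusing the mesh pattern from the sketch above):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('data', 'model'))

    # Hypothetical sizes; the real ones come from the model config.
    E, D, F = 32, 2880, 2880
    w_edf = jnp.zeros((E, D, F))

    # Shard the leading expert axis over 'model'; each device then holds
    # E / mesh.shape['model'] experts, with D and F replicated.
    w_edf = jax.device_put(w_edf, NamedSharding(mesh, P('model', None, None)))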
@@ -148,13 +148,15 @@ def __init__(self,
             epsilon=rms_norm_eps,
             dtype=dtype,
             rngs=self.rng,
+            activation_ffw_td=P('data', None),
         ),
         pre_mlp_norm=RMSNorm(
             dims=hidden_size,
             random_init=self.random_init,
             epsilon=rms_norm_eps,
             dtype=dtype,
             rngs=self.rng,
+            activation_ffw_td=P('data', None),
         ),
         attn=attn,
         custom_module=moe_mlp,
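Beyond the tuple-to-`P` conversion, this hunk also starts passing `activation_ffw_td` to the two `RMSNorm` layers (and to the final norm below), which by its name constrains the [tokens, dims] activations rather than the weights. A minimal sketch of how a norm might honor such a spec, assuming it uses `jax.lax.with_sharding_constraint` internally (the function body and names are illustrative, not this repo's implementation):

    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P

    def rms_norm(x_td, scale_d, epsilon=1e-6, activation_ffw_td=P('data', None)):
        # Pin the [tokens, dims] activations: tokens split over 'data',
        # dims replicated. With a bare PartitionSpec this needs an ambient
        # mesh (e.g. inside jit under a mesh context).
        x_td = jax.lax.with_sharding_constraint(x_td, activation_ffw_td)
        var = jnp.mean(jnp.square(x_td), axis=-1, keepdims=True)
        return x_td * jax.lax.rsqrt(var + epsilon) * scale_d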
@@ -167,15 +169,16 @@ def __init__(self,
         random_init=self.random_init,
         epsilon=rms_norm_eps,
         dtype=dtype,
+        activation_ffw_td=P('data', None),
     )
 
     self.lm_head = LMhead(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
         dtype=dtype,
         rngs=self.rng,
-        vd_sharding=(('data', 'model'), None),
-        dv_sharding=(None, ('data', 'model')),
+        vd_sharding=P(('data', 'model'), None),
+        dv_sharding=P(None, ('data', 'model')),
         random_init=self.random_init,
     )
 