@@ -159,7 +159,12 @@ def _transform(self, graph_module: GraphModule):
159159 def _tosa_pipeline (
160160 self , exported_program : ExportedProgram , graph_module : GraphModule
161161 ) -> GraphModule :
162+ # Preprocessing passes
163+
162164 self .add_pass (AnnotateOutputDimOrderPass ())
165+
166+ # Node transformation passes (pre q/dq folding)
167+
163168 self .add_pass (FuseQuantizedActivationPass ())
164169 self .add_pass (RemoveGetItemPass ())
165170 self .add_pass (ConvertToClampPass ())
@@ -174,8 +179,19 @@ def _tosa_pipeline(
174179 self .add_pass (ConvertELUParamsPass ())
175180 self .add_pass (ConvertSplitToSlicePass ())
176181 self .add_pass (QuantizeOperatorArguments ())
182+
183+ # Fold Q/DQ nodes, insert INT8/INT32 rescales.
184+
177185 self .add_pass (FoldAndAnnotateQParamsPass (exported_program )) # type: ignore[call-arg]
178186 self .add_pass (FuseDuplicateUsersPass ())
187+ # TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or
188+ # before FoldAndAnnotateQParamsPass but is unable to at the moment.
189+ # Ticket: MLETORCH-1539
190+ self .add_pass (DecomposeLinearPass ())
191+ self .add_pass (InsertRescaleInt32Pass ())
192+
193+ # Node transformation passes (post q/dq folding)
194+
179195 self .add_pass (DecomposeExpm1Pass ())
180196 self .add_pass (DecomposeLogitPass ())
181197 self .add_pass (DecomposeMaskedFill ())
@@ -196,57 +212,67 @@ def _tosa_pipeline(
196212 self .add_pass (DecomposeSignPass ())
197213 self .add_pass (DecomposeFloorDividePass ())
198214 self .add_pass (DecomposeDivTensorModePass ())
215+ self .add_pass (DecomposeGeluPass ())
216+ self .add_pass (DecomposeAddSubAlphaPass ())
217+ self .add_pass (DecomposeGroupedConv ())
218+ self .add_pass (Conv1dUnsqueezePass ())
219+
220+ # Scalars -> tensors, match tensor dtypes and ranks.
221+
199222 self .add_pass (ReplaceScalarWithTensorByProfilePass ())
223+ self .add_pass (ConvertFullLikeToFullPass ())
224+ self .add_pass (MatchArgDtypePass ())
225+ self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
226+ # TODO: Move DecomposeNotEqualPass to before or after this block of
227+ # passes. Ticket: MLETORCH-1540
228+ self .add_pass (DecomposeNotEqualPass ())
229+ self .add_pass (MatchArgRanksPass (exported_program ))
230+ self .add_pass (FuseConstantArgsPass (exported_program ))
231+
232+ # Node transformation passes (post scalar-removal)
233+
200234 self .add_pass (DecomposeRemainderPass ())
201235 self .add_pass (DecomposeDivTensorModePass ())
202236 self .add_pass (DecomposeEmbeddingPass ())
203237 self .add_pass (FuseBatchnorm2DPass (exported_program ))
204238 self .add_pass (ConvertMmToBmmPass ())
205239 self .add_pass (DecomposeGluPass ())
206- self .add_pass (DecomposeLinearPass ())
207240 self .add_pass (DecomposeLeakyReLUPass ())
208- self .add_pass (DecomposeNotEqualPass ())
209241 self .add_pass (DecomposeDivPass ())
210- self .add_pass (DecomposeAddSubAlphaPass ())
211242 self .add_pass (DecomposeSoftmaxPass ())
212- self .add_pass (DecomposeGeluPass ())
213- self .add_pass (ConvertFullLikeToFullPass ())
214243 self .add_pass (ConvertMinMaxPass ())
215244 self .add_pass (ConvertAnyDefaultDimDimsPass ())
216- self .add_pass (MatchArgDtypePass ())
217- self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
218- self .add_pass (MatchArgRanksPass (exported_program ))
219245 self .add_pass (DecomposeAdaptiveAvgPool2dPass ())
220246 self .add_pass (DecomposeAvgPool2d ())
221247 self .add_pass (
222248 DecorateFp32toInt32CastingPass ()
223249 ) # Require that no new fp32->int32 is introduced after this pass
224250 self .add_pass (ComputeConstantOpsAOT (exported_program ))
225-
226- self .add_pass (DecomposeGroupedConv ())
227251 self .add_pass (ConvertExpandCopyToRepeatPass ())
228252 self .add_pass (UnsqueezeBeforeRepeatPass ())
229253 self .add_pass (DecomposeCumsumPass (exported_program ))
230- self .add_pass (Conv1dUnsqueezePass ())
231254 self .add_pass (DecomposeMaxPool2DPass ())
232255 self .add_pass (SizeAdjustInputPass ())
233256 self .add_pass (DecomposeSelectPass ())
234257 self .add_pass (ConvertSqueezesToViewPass ())
235258 self .add_pass (CastToInt32Pass ())
236259 self .add_pass (BroadcastArgsPass ())
237-
238260 self .add_pass (ConvertPermuteSingletonToViewPass ())
239261 self .add_pass (FuseViewCopyTransform ())
240- self .add_pass (FuseConstantArgsPass (exported_program ))
241262 self .add_pass (DecomposeConv2dWithInt16ActivationPass ())
242- self .add_pass (CastInt64BuffersToInt32Pass ( exported_program ))
263+ self .add_pass (DecomposeSumPass ( ))
243264 self .add_pass (InsertTableOpsPass (exported_program ))
265+
266+ # Aten -> TOSA transformation passes
267+
244268 self .add_pass (RewriteUpsamplePass ())
245269 self .add_pass (RewriteConv2dPass (exported_program ))
246270 self .add_pass (RewriteMatmulPass ())
271+
272+ # Postprocessing/cleanup passes
273+
274+ self .add_pass (CastInt64BuffersToInt32Pass (exported_program ))
247275 self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
248- self .add_pass (InsertRescaleInt32Pass ())
249- self .add_pass (DecomposeSumPass ())
250276 self .add_pass (ToTosaMemoryFormatPass (exported_program ))
251277 self .add_pass (RemoveNoopPass ())
252278 self .add_pass (InsertRescalePass ())
0 commit comments