File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change 1212#include < torch/csrc/jit/frontend/error_report.h>
1313#include < torch/csrc/jit/ir/alias_analysis.h>
1414#include < torch/csrc/jit/jit_log.h>
15+ #include < torch/csrc/jit/passes/batch_mm.h>
1516#include < torch/csrc/jit/passes/constant_propagation.h>
1617#include < torch/csrc/jit/passes/graph_rewrite_helper.h>
18+ #include < torch/csrc/jit/passes/lower_tuples.h>
1719#include < torch/csrc/jit/passes/remove_dropout.h>
1820#include < torch/csrc/jit/passes/subgraph_rewrite.h>
1921#include < torch/csrc/jit/passes/tensorexpr_fuser.h>
@@ -384,6 +386,13 @@ void FusionPass(std::shared_ptr<Graph>& graph) {
384386 GRAPH_DUMP (
385387 " After IPEXFusionPass. Before RemoveTensorTypeSpecializations" , graph);
386388
389+ // TODO: workaround here to go throughput the TE fuser pass before
390+ // RemoveTensorTypeSpecializations since TE fuser needs the type
391+ // specializations
392+ LowerSimpleTuples (graph);
393+ BatchMM (graph);
394+ FuseTensorExprs (graph, getFusionGroupInlining () ? 2 : 1 );
395+
387396 RemoveTensorTypeSpecializations (graph);
388397 GRAPH_DUMP (
389398 " After RemoveTensorTypeSpecializations. End of optimization pass" , graph);
You can’t perform that action at this time.
0 commit comments