Skip to content

Commit d8cd254

Browse files
chunyuan-wEikanWang
authored andcommitted
add TE fuser pass before removing type specializations (#387)
1 parent a91f1be commit d8cd254

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

torch_ipex/csrc/jit/fusion_pass.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
#include <torch/csrc/jit/frontend/error_report.h>
1313
#include <torch/csrc/jit/ir/alias_analysis.h>
1414
#include <torch/csrc/jit/jit_log.h>
15+
#include <torch/csrc/jit/passes/batch_mm.h>
1516
#include <torch/csrc/jit/passes/constant_propagation.h>
1617
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
18+
#include <torch/csrc/jit/passes/lower_tuples.h>
1719
#include <torch/csrc/jit/passes/remove_dropout.h>
1820
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
1921
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
@@ -384,6 +386,13 @@ void FusionPass(std::shared_ptr<Graph>& graph) {
384386
GRAPH_DUMP(
385387
"After IPEXFusionPass. Before RemoveTensorTypeSpecializations", graph);
386388

389+
// TODO: workaround here to go throughput the TE fuser pass before
390+
// RemoveTensorTypeSpecializations since TE fuser needs the type
391+
// specializations
392+
LowerSimpleTuples(graph);
393+
BatchMM(graph);
394+
FuseTensorExprs(graph, getFusionGroupInlining() ? 2 : 1);
395+
387396
RemoveTensorTypeSpecializations(graph);
388397
GRAPH_DUMP(
389398
"After RemoveTensorTypeSpecializations. End of optimization pass", graph);

0 commit comments

Comments
 (0)