Skip to content

Commit 74aa212

Browse files
authored
fix conv_folding cannot capture Conv -> Mul -> Add etc (#602)
* fix conv_folding cannot capture Conv -> Mul -> Add etc
1 parent 2b63257 commit 74aa212

File tree

4 files changed

+88
-12
lines changed

4 files changed

+88
-12
lines changed

intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_conv_folding.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,11 @@ bool checkConvAndBroadcastingOpPreConditions(Node* conv, Node* op) {
119119
return true;
120120
}
121121

122-
void FoldFrozenConvAddOrSub(Block* b) {
122+
bool FoldFrozenConvAddOrSub(Block* b) {
123+
bool graph_modified = false;
123124
for (Node* n : b->nodes()) {
124125
for (Block* block : n->blocks()) {
125-
FoldFrozenConvAddOrSub(block);
126+
graph_modified |= FoldFrozenConvAddOrSub(block);
126127
}
127128

128129
if (supportedAddOrSub(n) && supportedConvNode(n->inputs().at(0)->node())) {
@@ -174,15 +175,18 @@ void FoldFrozenConvAddOrSub(Block* b) {
174175
add_or_sub->kind().toUnqualString());
175176
conv->replaceInputWith(conv_b_value, fused_conv_b);
176177
add_or_sub->output()->replaceAllUsesWith(conv->output());
178+
graph_modified = true;
177179
// DCE run after cleans up nodes
178180
}
179181
}
182+
return graph_modified;
180183
}
181184

182-
void FoldFrozenConvMulOrDiv(Block* b) {
185+
bool FoldFrozenConvMulOrDiv(Block* b) {
186+
bool graph_modified = false;
183187
for (Node* n : b->nodes()) {
184188
for (Block* block : n->blocks()) {
185-
FoldFrozenConvMulOrDiv(block);
189+
graph_modified |= FoldFrozenConvMulOrDiv(block);
186190
}
187191

188192
if (supportedMulOrDiv(n) && supportedConvNode(n->inputs().at(0)->node())) {
@@ -287,21 +291,35 @@ void FoldFrozenConvMulOrDiv(Block* b) {
287291
mul_or_div->kind().toUnqualString());
288292
conv->replaceInputWith(conv_b_value, fused_conv_bias);
289293
}
294+
graph_modified = true;
290295
// DCE run after cleans up nodes
291296
}
292297
}
298+
return graph_modified;
293299
}
294300

295301
} // namespace
296302

297-
void FoldFrozenConvAddOrSub(std::shared_ptr<Graph>& graph) {
298-
FoldFrozenConvAddOrSub(graph->block());
303+
bool FoldFrozenConvAddOrSub(std::shared_ptr<Graph>& graph) {
304+
bool graph_modified = FoldFrozenConvAddOrSub(graph->block());
299305
EliminateDeadCode(graph);
306+
return graph_modified;
300307
}
301308

302-
void FoldFrozenConvMulOrDiv(std::shared_ptr<Graph>& graph) {
303-
FoldFrozenConvMulOrDiv(graph->block());
309+
bool FoldFrozenConvMulOrDiv(std::shared_ptr<Graph>& graph) {
310+
bool graph_modified = FoldFrozenConvMulOrDiv(graph->block());
304311
EliminateDeadCode(graph);
312+
return graph_modified;
313+
}
314+
315+
void FrozenConvFolding(std::shared_ptr<Graph>& graph) {
316+
// run repeatedly until a fixpoint (no more folds) so chains like Conv -> Mul -> Add are fully captured
317+
bool changed;
318+
do {
319+
changed = false;
320+
changed |= FoldFrozenConvAddOrSub(graph);
321+
changed |= FoldFrozenConvMulOrDiv(graph);
322+
} while (changed);
305323
}
306324

307325
} // namespace jit

intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_conv_folding.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@ namespace jit {
88
// Fuses Convolution -> Add/Sub into a single Convolution by
99
// folding add constant tensor into conv weights.
1010
// This pass only works on Frozen Graphs; otherwise it is a No-Op.
11-
TORCH_API void FoldFrozenConvAddOrSub(std::shared_ptr<Graph>& graph);
11+
TORCH_API bool FoldFrozenConvAddOrSub(std::shared_ptr<Graph>& graph);
1212

1313
// Fuses Convolution -> Mul/Div into a single Convolution by
1414
// folding mul/div constant tensor into conv weights.
1515
// This pass only works on Frozen Graphs; otherwise it is a No-Op.
16-
TORCH_API void FoldFrozenConvMulOrDiv(std::shared_ptr<Graph>& graph);
16+
TORCH_API bool FoldFrozenConvMulOrDiv(std::shared_ptr<Graph>& graph);
17+
18+
// Calls FoldFrozenConvAddOrSub and FoldFrozenConvMulOrDiv repeatedly until no further change
19+
TORCH_API void FrozenConvFolding(std::shared_ptr<Graph>& graph);
1720

1821
} // namespace jit
1922
} // namespace torch

intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,7 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
359359
graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
360360

361361
// convolution folding
362-
FoldFrozenConvAddOrSub(graph);
363-
FoldFrozenConvMulOrDiv(graph);
362+
FrozenConvFolding(graph);
364363

365364
// convolution fusion
366365
graph_rewrite::insertPrePackedConvOp(graph);

tests/cpu/test_jit.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@
7474
from torch._six import inf, nan
7575
from torch.testing._internal.common_utils import TestCase
7676

77+
try:
78+
import torchvision
79+
HAS_TORCHVISION = True
80+
except ImportError:
81+
HAS_TORCHVISION = False
82+
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
83+
7784
device = 'cpu:0'
7885
SIZE = 100
7986

@@ -2508,6 +2515,55 @@ def test_remove_bailout(self):
25082515
kind_not_in_graph="prim::BailOut",
25092516
prec=0.05)
25102517

2518+
@skipIfNoTorchVision
2519+
def test_conv_torchvision_bn_folding(self):
2520+
from torchvision.ops import misc as misc_nn_ops
2521+
class M(nn.Module):
2522+
def __init__(self):
2523+
super(M, self).__init__()
2524+
norm_layer = misc_nn_ops.FrozenBatchNorm2d
2525+
self.inplanes = 64
2526+
self.dilation = 1
2527+
self.groups = 1
2528+
self.base_width = 64
2529+
self.conv1 = torch.nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
2530+
self.bn1 = norm_layer(self.inplanes)
2531+
self.relu = torch.nn.ReLU(inplace=True)
2532+
self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
2533+
2534+
def forward(self, x):
2535+
x = self.conv1(x)
2536+
x = self.bn1(x)
2537+
x = self.relu(x)
2538+
x = self.maxpool(x)
2539+
return x
2540+
2541+
model = M().eval()
2542+
self._test_output(
2543+
model,
2544+
torch.randn(1, 3, 1200, 1200),
2545+
kind_in_graph="ipex_prepack::convolution_relu_run",
2546+
kind_not_in_graph="aten::add")
2547+
2548+
self._test_output(
2549+
model,
2550+
torch.randn(1, 3, 1200, 1200),
2551+
kind_in_graph="ipex_prepack::convolution_relu_run",
2552+
kind_not_in_graph="aten::mul")
2553+
2554+
self._test_output_bf16(
2555+
model,
2556+
torch.randn(1, 3, 1200, 1200),
2557+
kind_in_graph="ipex_prepack::convolution_relu_run",
2558+
kind_not_in_graph="aten::add",
2559+
prec=0.1)
2560+
2561+
self._test_output_bf16(
2562+
model,
2563+
torch.randn(1, 3, 1200, 1200),
2564+
kind_in_graph="ipex_prepack::convolution_relu_run",
2565+
kind_not_in_graph="aten::mul",
2566+
prec=0.1)
25112567

25122568
if __name__ == '__main__':
25132569
torch.manual_seed(2020)

0 commit comments

Comments
 (0)