Commit 7771799

Qualcomm AI Engine Direct - enable operators adaptive_max_pool2d and grid_sampler 2D and 3D (#15371)
### Summary
Enable operators adaptive_max_pool2d and grid_sampler 2D and 3D.

### Test plan
```bash
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_adaptive_max_pool2d -b build-android -H $HOST -s $SN -m $CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_adaptive_max_pool2d -b build-android -H $HOST -s $SN -m $CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_grid_sampler -b build-android -H $HOST -s $SN -m $CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_grid_sampler -b build-android -H $HOST -s $SN -m $CHIPID
```
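As a rough illustration only (this is not one of the actual test fixtures), a module like the sketch below should export to the newly enabled edge ops `aten.adaptive_max_pool2d.default` and `aten.grid_sampler_2d.default`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GridSampleThenPool(nn.Module):
    """Toy module exercising both newly enabled operators."""

    def __init__(self):
        super().__init__()
        # Should lower to aten.adaptive_max_pool2d.default in the edge dialect.
        self.pool = nn.AdaptiveMaxPool2d((4, 4))

    def forward(self, x, grid):
        # Should lower to aten.grid_sampler_2d.default.
        sampled = F.grid_sample(x, grid, align_corners=False)
        return self.pool(sampled)


x = torch.randn(1, 3, 16, 16)
grid = torch.rand(1, 8, 8, 2) * 2 - 1  # normalized sample coordinates in [-1, 1]
print(GridSampleThenPool()(x, grid).shape)  # torch.Size([1, 3, 4, 4])
```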
1 parent d361573 commit 7771799

File tree

10 files changed: +506, -12 lines

backends/qualcomm/_passes/layout_transform.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -43,9 +43,12 @@ class LayoutTransform(ExportPass):
     layout_sensitive_ops = {
         exir_ops.edge.aten.adaptive_avg_pool2d.default,
         exir_ops.edge.aten._adaptive_avg_pool3d.default,
+        exir_ops.edge.aten.adaptive_max_pool2d.default,
         exir_ops.edge.aten.avg_pool2d.default,
         exir_ops.edge.aten.avg_pool3d.default,
         exir_ops.edge.aten.convolution.default,
+        exir_ops.edge.aten.grid_sampler_2d.default,
+        exir_ops.edge.aten.grid_sampler_3d.default,
         exir_ops.edge.aten.instance_norm.default,
         exir_ops.edge.aten.max_pool2d_with_indices.default,
         exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
```
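For context, registering the new ops as layout sensitive matters because the QNN op builders consume channel-last tensors (the pool builder added below reads height and width from `shape[1]`/`shape[2]`). A minimal, illustrative sketch of what the layout change around such an op amounts to; this is not the pass itself:

```python
import torch

# Illustrative only: "layout sensitive" means the op runs on channel-last
# (NHWC) data, so the pass has to permute around the NCHW graph node.
nchw = torch.randn(1, 3, 16, 16)               # PyTorch default layout
nhwc = nchw.permute(0, 2, 3, 1).contiguous()   # layout expected by the QNN op
restored = nhwc.permute(0, 3, 1, 2)            # back to NCHW after the op
assert torch.equal(nchw, restored)
```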

backends/qualcomm/builders/README.md

Lines changed: 15 additions & 11 deletions
```diff
@@ -2,15 +2,19 @@
 Thank you for contributing to Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly get the essentials of implementing operator builder to unblock yourself and land pull requests more efficiently.
 
 ## Sections
-* [References](#references)
-* [Getting Started](#getting-started)
-* [Identify Unsupported Operator](#identify-unsupported-operator)
-* [Check Operator Spec](#check-operator-spec)
-* [Implementation](#implementation)
-* [Quantizer Annotation](#quantizer-annotation)
-* [Operator Support Status](#operator-support-status)
-* [Issues](#issues)
-* [Pull Requests](#pull-requests)
+- [Contribution for More Operators](#contribution-for-more-operators)
+- [Sections](#sections)
+- [References](#references)
+- [Qualcomm AI Engine Direct](#qualcomm-ai-engine-direct)
+- [PyTorch](#pytorch)
+- [Getting Started](#getting-started)
+- [Identify Unsupported Operator](#identify-unsupported-operator)
+- [Check Operator Spec](#check-operator-spec)
+- [Implementation](#implementation)
+- [Quantizer Annotation](#quantizer-annotation)
+- [Operator Support Status](#operator-support-status)
+- [Issues](#issues)
+- [Pull Requests](#pull-requests)
 
 ## References
 ### Qualcomm AI Engine Direct
@@ -365,7 +369,7 @@ Please help update following table if you are contributing new operators:
 + 🚫 = Deprecated, supported with other QNN Ops
 
 
-| Operators | HTP - 92/116 Enabled |
+| Operators | HTP - 94/116 Enabled |
 |-----------|---------|
 | Argmax | ✓ |
 | Argmin | ✓ |
@@ -431,7 +435,7 @@ Please help update following table if you are contributing new operators:
 | Gelu | ✓ |
 | GetSparseIndices | ✗ |
 | GetSparseValues | ✗ |
-| GridSample | ✗ |
+| GridSample | ✓ |
 | GroupNorm | ✓ |
 | HardSwish | ✓ |
 | InstanceNorm | ✓ |
```

backends/qualcomm/builders/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -8,6 +8,7 @@
     node_visitor,
     op_abs,
     op_adaptive_avg_pool2d,
+    op_adaptive_max_pool2d,
     op_add,
     op_amax,
     op_amin,
@@ -44,6 +45,7 @@
     op_gather,
     op_ge,
     op_gelu,
+    op_grid_sampler_2d,
     op_group_norm,
     op_gt,
     op_hardsigmoid,
@@ -114,6 +116,7 @@
     node_visitor,
     op_abs,
     op_adaptive_avg_pool2d,
+    op_adaptive_max_pool2d,
     op_add,
     op_amax,
     op_amin,
@@ -150,6 +153,7 @@
     op_gather,
     op_ge,
     op_gelu,
+    op_grid_sampler_2d,
     op_group_norm,
     op_gt,
     op_hardswish,
```
backends/qualcomm/builders/op_adaptive_max_pool2d.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@

```python
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import numpy as np

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpPoolMax2d, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class AdaptiveMaxPool2D(NodeVisitor):
    target = ["aten.adaptive_max_pool2d.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = self.get_node(node.args[0])
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        users = list(node.users.keys())
        for user in users:
            if user.target.__name__ == "getitem":
                getitem_index = user.args[1]
                if getitem_index != 0:
                    warnings.warn(
                        f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}",
                        stacklevel=1,
                    )
                    return

        if len(node.args) > 2:
            warnings.warn(
                "[QNN Delegate Op Builder]: The return_indices is not supported, fallback op",
                stacklevel=1,
            )
            return

        input_height = input_tensor.shape[1]
        input_width = input_tensor.shape[2]
        # node.args[1] is the requested output size: a single value or (height, width)
        out_wh = cast(List[int], node.args[1])
        if len(out_wh) == 1:
            output_height = node.args[1][0]
            output_width = node.args[1][0]
        else:
            output_height = node.args[1][0]
            output_width = node.args[1][1]
        if output_height is None:
            output_height = input_height
        if output_width is None:
            output_width = input_width
        # NOTE: The rounding mode has no effect here; the output shape is fully
        # determined by the user-specified output size.
        mode = OpPoolMax2d.RoundingMode.FLOOR

        # Derive a fixed kernel and stride (floor division) that realize the adaptive pooling.
        stride_height = input_height // output_height
        filter_height = input_height - (output_height - 1) * stride_height
        stride_width = input_width // output_width
        filter_width = input_width - (output_width - 1) * stride_width

        filter = [filter_height, filter_width]
        filter_shape = [len(filter)]

        stride = [stride_height, stride_width]
        stride_shape = [len(stride)]

        padding = [0, 0]
        padding_shape = [len(padding), len(padding)]

        out_tensor = self.get_tensor(node, node, 0)
        output_tensor_wrapper = self.define_tensor(
            node,
            node,
            out_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        adaptive_max_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpPoolMax2d.op_name,
        )

        adaptive_max_pool2d_op.AddInputTensors([input_tensor_wrapper])
        adaptive_max_pool2d_op.AddOutputTensors([output_tensor_wrapper])

        adaptive_max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_filter_size,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(filter_shape),
            filter_shape,
            np.array(
                filter,
                dtype=np.uint32,
            ),
            True,
        )

        adaptive_max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_stride,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(stride_shape),
            stride_shape,
            np.array(
                stride,
                dtype=np.uint32,
            ),
            True,
        )

        adaptive_max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_pad_amount,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(padding_shape),
            padding_shape,
            np.array(
                [[padding[0], padding[0]], [padding[1], padding[1]]],
                dtype=np.uint32,
            ),
            True,
        )

        adaptive_max_pool2d_op.AddScalarParam(
            OpPoolMax2d.param_rounding_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(mode)},
        )

        return adaptive_max_pool2d_op
```
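The filter/stride derivation above turns the adaptive pooling into a fixed-kernel max pool. A small, plain-PyTorch sanity check of that arithmetic (not part of the commit): when the input size is an integer multiple of the output size, the derived kernel and stride reproduce `adaptive_max_pool2d` exactly; for non-divisible sizes the fixed-kernel form is an approximation of PyTorch's variable-window adaptive pooling.

```python
import torch
import torch.nn.functional as F

# Derivation used by the builder, for an 8x8 input pooled to 4x4:
#   stride = 8 // 4 = 2
#   filter = 8 - (4 - 1) * 2 = 2
x = torch.randn(1, 3, 8, 8)
adaptive = F.adaptive_max_pool2d(x, (4, 4))
fixed = F.max_pool2d(x, kernel_size=2, stride=2)
assert torch.equal(adaptive, fixed)  # identical when input % output == 0
```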
