Skip to content

Commit 4ea9ddf

Browse files
Arm backend: Add docstrings for operator_support/minmax_support.py (#15555)
Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent ac57fde commit 4ea9ddf

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

backends/arm/operator_support/minmax_support.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Declare operator support for min/max along a dimension in TOSA.
6+
7+
Provide support checks ensuring that argmax/argmin indices are not consumed,
8+
restricting to float profiles until index quantization is supported.
9+
10+
"""
511

612
import torch.fx as fx
713
from executorch.backends.arm.operator_support.tosa_supported_operators import (
@@ -14,6 +20,8 @@
1420

1521
@register_tosa_support_check
1622
class MinMaxSupported(SupportedTOSAOperatorCheck):
23+
"""Provide TOSA support check for ``aten.max.dim`` and ``aten.min.dim``."""
24+
1725
targets = [
1826
exir_ops.edge.aten.max.dim,
1927
exir_ops.edge.aten.min.dim,
@@ -24,7 +32,16 @@ class MinMaxSupported(SupportedTOSAOperatorCheck):
2432
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2533
]
2634

27-
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
35+
def is_node_tosa_supported(
36+
self, node: fx.Node, tosa_spec: TosaSpecification
37+
) -> bool:
38+
"""Return True if the node is supported by TOSA.
39+
40+
Allow max/min when the argmax/argmin output is unused or dropped (i.e.,
41+
only the value is consumed). Disallow cases where arg indices are
42+
further used.
43+
44+
"""
2845
if node.target in [exir_ops.edge.aten.max.dim, exir_ops.edge.aten.min.dim]:
2946
no_argmax = len(node.users) == 1
3047
no_argmax_users = (len(node.users) == 2) and (

0 commit comments

Comments
 (0)