22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5+ """Declare operator support for min/max along a dimension in TOSA.
6+
7+ Provide support checks ensuring that argmax/argmin indices are not consumed,
8+ restricting to float profiles until index quantization is supported.
9+
10+ """
511
612import torch .fx as fx
713from executorch .backends .arm .operator_support .tosa_supported_operators import (
1420
1521@register_tosa_support_check
1622class MinMaxSupported (SupportedTOSAOperatorCheck ):
23+ """Provide TOSA support check for ``aten.max.dim`` and ``aten.min.dim``."""
24+
1725 targets = [
1826 exir_ops .edge .aten .max .dim ,
1927 exir_ops .edge .aten .min .dim ,
@@ -24,7 +32,16 @@ class MinMaxSupported(SupportedTOSAOperatorCheck):
2432 TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
2533 ]
2634
27- def is_node_tosa_supported (self , node : fx .Node , tosa_spec : TosaSpecification ):
35+ def is_node_tosa_supported (
36+ self , node : fx .Node , tosa_spec : TosaSpecification
37+ ) -> bool :
38+ """Return True if the node is supported by TOSA.
39+
40+ Allow max/min when the argmax/argmin output is unused or dropped (i.e.,
41+ only the value is consumed). Disallow cases where arg indices are
42+ further used.
43+
44+ """
2845 if node .target in [exir_ops .edge .aten .max .dim , exir_ops .edge .aten .min .dim ]:
2946 no_argmax = len (node .users ) == 1
3047 no_argmax_users = (len (node .users ) == 2 ) and (
0 commit comments