Skip to content

Commit 044234c

Browse files
authored
Revert "Arm backend: Propagate node info from quantizer to backend" (#15760)
Reverts #15300
1 parent b005f10 commit 044234c

File tree

11 files changed

+19
-251
lines changed

11 files changed

+19
-251
lines changed

backends/arm/common/annotation_meta.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 3 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
FuseQuantizedActivationPass,
2222
)
2323
from executorch.backends.arm._passes.insert_table_ops import TableOps
24-
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
2524
from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
2625
from executorch.backends.arm.operator_support.ethos_u55_support import (
2726
EthosU55CastCheck,
@@ -141,7 +140,6 @@ def tosa_support_factory(
141140
]
142141

143142
if not tosa_spec.support_float():
144-
negative_checks.append(CheckArmQuantized(reporter))
145143
negative_checks.append(CheckProperQuantization(reporter))
146144
if tosa_spec.is_U55_subset:
147145
negative_checks.append(EthosU55NotSupported(reporter))
@@ -169,6 +167,7 @@ class TOSAProINTSupportList(OperatorSupportBase):
169167
def is_node_supported(
170168
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
171169
) -> bool:
170+
172171
return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList
173172

174173

@@ -181,78 +180,8 @@ class TOSAProFPSupportList(OperatorSupportBase):
181180
def is_node_supported(
182181
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
183182
) -> bool:
184-
return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList
185-
186-
187-
class CheckArmQuantized(OperatorSupportBase):
188-
"""
189-
Check if the node was marked as quantized in the Arm backend.
190-
This is used to ensure that nodes that were quantized in the Arm backend
191-
are only partitioned if they are supported by the TOSA backend.
192-
"""
193-
194-
def __init__(self, reporter: WhyNoPartitionReporter):
195-
self.reporter = reporter
196-
197-
def _is_quantized(self, node: torch.fx.Node) -> bool:
198-
"""Checks if the node is quantized.
199183

200-
A node is considered quantized if at least one criterion is met:
201-
- Its dtype is not floating point or complex => integer
202-
- It is one of the special cases where the node has been created in to_edge, e.g.
203-
.Scalar operations that have been promoted to .Tensor operations
204-
where the scalar is replaced by a full op.
205-
- It has been marked as quantized in the ArmAnnotationInfo custom meta.
206-
207-
Args:
208-
node (torch.fx.Node): The FX node to check.
209-
210-
Returns:
211-
bool: True if the node is quantized, False otherwise.
212-
"""
213-
node_dtype = get_first_fake_tensor(node).dtype
214-
if not node_dtype.is_complex and not node_dtype.is_floating_point:
215-
return True
216-
if node.target in (
217-
exir_ops.edge.aten.full_like.default,
218-
*ComputeConstantOpsAOT.targeted_ops,
219-
):
220-
# Special cases where nodes have been created in to_edge, e.g.
221-
# .Scalar operations that have been promoted to .Tensor operations
222-
# where the scalar is replaced by a full op.
223-
if all(user.target in Q_OPS for user in node.users):
224-
return True
225-
for user in node.users:
226-
if (
227-
user.target
228-
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
229-
):
230-
dim_order_dtype = get_first_fake_tensor(user).dtype
231-
if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
232-
return False
233-
else:
234-
return False
235-
return True
236-
return (
237-
ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
238-
and node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY].quantized
239-
)
240-
241-
def is_node_supported(
242-
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
243-
) -> bool:
244-
if node.op != "call_function":
245-
return False
246-
247-
if node.target in (*DQ_OPS, *Q_OPS):
248-
return True
249-
250-
if not self._is_quantized(node):
251-
self.reporter.report_reject(
252-
node, "Node was not marked as quantized in the Arm backend."
253-
)
254-
return False
255-
return True
184+
return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList
256185

257186

258187
class CheckProperQuantization(OperatorSupportBase):
@@ -498,6 +427,7 @@ def is_node_supported(
498427

499428

500429
class CheckFloat64Inputs(OperatorSupportBase):
430+
501431
def __init__(
502432
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
503433
):

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# All rights reserved.
32
# Copyright 2024-2025 Arm Limited and/or its affiliates.
3+
# All rights reserved.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -14,8 +14,6 @@
1414

1515
from typing import cast
1616

17-
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
18-
1917
from torch.fx import Node
2018

2119
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
@@ -67,10 +65,4 @@ def mark_node_as_annotated(node: Node) -> None:
6765
"""
6866
if Q_ANNOTATION_KEY not in node.meta:
6967
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
70-
annotation_info = ArmAnnotationInfo(
71-
quantized=True,
72-
)
7368
node.meta[Q_ANNOTATION_KEY]._annotated = True
74-
meta_custom = node.meta.get("custom", {})
75-
meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = annotation_info
76-
node.meta["custom"] = meta_custom

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,6 @@ def _match_pattern(
394394
torch.ops.aten.view.default,
395395
torch.ops.aten.view_as.default,
396396
torch.ops.aten.view_copy.default,
397-
torch.ops.aten._unsafe_view.default,
398397
torch.ops.aten.select.int,
399398
torch.ops.aten.select_copy.int,
400399
torch.ops.aten.slice.Tensor,
@@ -427,7 +426,6 @@ def _match_pattern(
427426
]
428427

429428
_one_to_one_shared_input_or_input_act_qspec = [
430-
torch.ops.aten.alias.default,
431429
torch.ops.aten.clone.default,
432430
torch.ops.aten.hardtanh.default,
433431
torch.ops.aten.hardtanh_.default,
@@ -695,10 +693,10 @@ def any_or_hardtanh_min_zero(n: Node):
695693
]
696694
quant_properties.quant_output = None
697695
elif node.target in [
696+
torch.ops.aten.scalar_tensor.default,
698697
torch.ops.aten.full.default,
699698
torch.ops.aten.full,
700699
torch.ops.aten.fill_.Scalar,
701-
torch.ops.aten.scalar_tensor.default,
702700
]:
703701
quant_properties.quant_inputs = []
704702
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)

backends/arm/test/misc/test_int64.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def forward(self, x: torch.Tensor):
6868
ConstAdd(torch.int64, 2**40),
6969
(torch.rand(10) - 0.5,),
7070
),
71+
"int64_in+float_const": (
72+
ConstAdd(torch.float32),
73+
(torch.randint(0, 10, (10,)),),
74+
),
7175
"fp32_in+int64_buffer_chain": (
7276
BufferChainAdd(torch.int64),
7377
(torch.rand(2, 5, 3) - 0.5,),
@@ -90,7 +94,7 @@ def test_int64_tosa_FP(test_data: Tuple):
9094
ArmTester(
9195
model,
9296
inputs,
93-
common.get_tosa_compile_spec("TOSA-1.0+FP"),
97+
common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"),
9498
)
9599
.export()
96100
.to_edge_transform_and_lower()

backends/arm/test/misc/test_quant_custom_meta.py

Lines changed: 0 additions & 100 deletions
This file was deleted.

backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ class TestSD3Transformer2DModel:
3939

4040
ops_after_partitioner_INT = {
4141
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
42-
"torch.ops.higher_order.executorch_call_delegate": 3,
43-
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
42+
"torch.ops.higher_order.executorch_call_delegate": 2,
4443
}
4544

4645
def _prepare_inputs(

backends/arm/test/models/test_nn_functional.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def test_nn_functional_FP(test_data):
102102
@parametrize(
103103
"test_data",
104104
module_tests,
105+
{"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"},
105106
)
106107
def test_nn_functional_INT(test_data):
107108
module, inputs = test_data

backends/arm/test/models/test_torch_functions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,10 @@ def test_torch_fns_FP(test_data):
126126
xfails={
127127
"nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). "
128128
"Requires dynamic output shape.",
129-
"eye": "ValueError: Failed processing buffer placeholder: aten_arange_start_step_1_pre_computed_common. "
130-
"Is the original torch function supported?",
131129
"topk": "NotImplementedError: No registered serialization name for <class 'torch.return_types.topk'> found",
132130
"sort": "NotImplementedError: No registered serialization name for <class 'torch.return_types.sort'> found",
133131
},
134-
strict=True,
132+
strict=False,
135133
)
136134
def test_torch_fns_INT(test_data):
137135
module, inputs = test_data

backends/arm/test/ops/test_eye.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_eye_u85_INT(test_data: test_data_t):
9595
input_data(),
9696
EyeAdd.aten_op,
9797
use_to_edge_transform_and_lower=True,
98-
)
98+
).dump_artifact("to_edge_transform_and_lower")
9999
pipeline.pop_stage("check.quant_nodes")
100100
pipeline.run()
101101

0 commit comments

Comments
 (0)