Commit 58237e9

Don't cache source partitions
Differential Revision: D86149805
Pull Request resolved: #15541
Parent: 386c5fb
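
The hunk in gemm_configs.py below removes the per-instance cache of source partitions, and the new test exercises a multi-method model, which suggests the motivation: if the same config object is asked about more than one exported program, a cache filled from the first graph would be handed back for every later graph. A minimal, runnable sketch of that stale-cache pattern (hypothetical names, not the ExecuTorch implementation):

# Hypothetical illustration only: a cache keyed on nothing but the instance
# ignores which graph it was computed from.
class CachedLookup:
    def __init__(self):
        self._cache = None

    def partitions_for(self, graph):
        if self._cache is None:
            # Stand-in for get_source_partitions(graph, wanted_modules).
            self._cache = {"computed_from": graph}
        return self._cache  # stale whenever a different graph is passed later


lookup = CachedLookup()
first = lookup.partitions_for("graph_of_forward")
second = lookup.partitions_for("graph_of_forward_2")
assert first is second  # the second method silently reuses the first method's partitions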

3 files changed: 91 additions, 3 deletions

backends/xnnpack/partition/config/gemm_configs.py
1 addition, 3 deletions

@@ -458,9 +458,7 @@ def get_deps(
         a bool indicating if the deps are valid and a list of all the
         dep nodes. This handles the src partition for
         """
-        if self.src_partitions is None:
-            # Cache src partitions so we don't have to recompute them every time
-            self.src_partitions = get_source_partitions(ep.graph, self.linear_modules)
+        self.src_partitions = get_source_partitions(ep.graph, self.linear_modules)

         # src_partition is None if node is not in source partition,
         # otherwise gives us the linear source partition it belongs to
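
For reference, the call on the added line is torch.fx's source-partition lookup, now recomputed from the given graph on every invocation rather than reused from a cache. A small standalone sketch of that lookup (the module and input shapes here are illustrative assumptions, not taken from the commit):

import torch
from torch.export import export
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class TwoLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 16)
        self.linear2 = torch.nn.Linear(16, 8)

    def forward(self, x):
        return self.linear2(torch.nn.functional.relu(self.linear(x)))


ep = export(TwoLinear(), (torch.randn(1, 8),), strict=True)
# Recomputed per graph, so each exported program (e.g. each method of a
# multi-method model) gets partitions derived from its own graph.
partitions = get_source_partitions(ep.graph, [torch.nn.Linear])
print({src: len(parts) for src, parts in partitions.items()})  # e.g. {torch.nn.Linear: 2}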

backends/xnnpack/test/TARGETS
11 additions, 0 deletions

@@ -113,3 +113,14 @@ runtime.python_test(
         "//executorch/examples/xnnpack:models",  # @manual
     ],
 )
+
+runtime.python_test(
+    name = "test_xnnpack_partitioner",
+    srcs = ["test_xnnpack_partitioner.py"],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/exir:lib",
+        "//executorch/extension/pybindings:portable_lib",
+    ],
+)
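
Assuming Buck is set up for the repository, the new target can presumably be run with something like buck2 test //executorch/backends/xnnpack/test:test_xnnpack_partitioner; the exact target path is inferred from the TARGETS file's location and is not stated in the commit.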

backends/xnnpack/test/test_xnnpack_partitioner.py
79 additions, 0 deletions

@@ -9,8 +9,13 @@
 import unittest

 import torch
+import torch.nn.functional as F
+
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.exir import to_edge, to_edge_transform_and_lower
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch_from_buffer,
+)
 from torch.export import export


@@ -82,3 +87,77 @@ def test_no_warning_for_to_edge_transform_and_lower_workflow(self):

         log_contents = log_capture_string.getvalue()
         self.assertNotIn("DEPRECATION WARNING", log_contents)
+
+    def test_multi_method_partitioning_with_shared_weights(self):
+        """
+        Test that multi-method models with shared weights are correctly partitioned.
+        Verify that:
+        1. Both methods are fully lowered to XNNPACK.
+        2. Constants are not duplicated between named data and constant buffers.
+        3. Program executes correctly.
+        """
+
+        class MultiMethodModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(8, 16)
+                self.linear2 = torch.nn.Linear(16, 8)
+
+            def forward(self, x):
+                return self.linear2(F.sigmoid(self.linear(x)))
+
+            def forward_2(self, x):
+                return self.linear2(F.relu(self.linear(x)))
+
+            def example_inputs(self):
+                return (torch.randn(1, 8),)
+
+        model = MultiMethodModel()
+
+        # Get eager reference output.
+        example_inputs = model.example_inputs()
+        with torch.no_grad():
+            fwd1_eager = model.forward(*example_inputs)
+            fwd2_eager = model.forward_2(*example_inputs)
+
+        # Export both methods
+        ep_fwd = export(model, model.example_inputs(), strict=True)
+        # Patch the forward, as export only traces the 'forward' method.
+        model.forward = model.forward_2
+        ep_fwd_2 = export(model, model.example_inputs(), strict=True)
+
+        # Convert to edge and lower to executorch
+        edge = to_edge({"forward": ep_fwd, "forward_2": ep_fwd_2})
+        lowered = edge.to_backend(XnnpackPartitioner(force_fp32_dynamic_linear=True))
+        executorch = lowered.to_executorch()
+
+        # Check that graph is fully delegated.
+        nodes_1 = list(lowered._edge_programs["forward"].graph.nodes)
+        nodes_2 = list(lowered._edge_programs["forward_2"].graph.nodes)
+        self.assertEqual(len(nodes_1), 5)
+        self.assertEqual(len(nodes_2), 5)
+        expected_node_names = [
+            "x",
+            "lowered_module_0",
+            "executorch_call_delegate",
+            "getitem",
+            "output_1",
+        ]
+        for n in expected_node_names:
+            self.assertTrue(any(node.name == n for node in nodes_1))
+            self.assertTrue(any(node.name == n for node in nodes_2))
+
+        # Check that weights are not duplicated.
+        self.assertEqual(len(executorch._named_data.pte_data), 4)
+        self.assertEqual(len(executorch._named_data.buffers), 4)
+        self.assertEqual(len(executorch._named_data.external_data), 0)
+
+        # Check that there are no constant buffers (besides the placeholder).
+        self.assertEqual(len(executorch._emitter_output.program.constant_buffer), 1)
+
+        # Check for model correctness.
+        executorch_module = _load_for_executorch_from_buffer(executorch.buffer)
+        fwd1_et = executorch_module.run_method("forward", example_inputs)
+        fwd2_et = executorch_module.run_method("forward_2", example_inputs)
+        self.assertTrue(torch.allclose(fwd1_eager, fwd1_et[0], 1e-3))
+        self.assertTrue(torch.allclose(fwd2_eager, fwd2_et[0], 1e-3))
