|
9 | 9 | import unittest |
10 | 10 |
|
11 | 11 | import torch |
| 12 | +import torch.nn.functional as F |
| 13 | + |
12 | 14 | from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner |
13 | 15 | from executorch.exir import to_edge, to_edge_transform_and_lower |
| 16 | +from executorch.extension.pybindings.portable_lib import ( |
| 17 | + _load_for_executorch_from_buffer, |
| 18 | +) |
14 | 19 | from torch.export import export |
15 | 20 |
|
16 | 21 |
|
@@ -82,3 +87,77 @@ def test_no_warning_for_to_edge_transform_and_lower_workflow(self): |
82 | 87 |
|
83 | 88 | log_contents = log_capture_string.getvalue() |
84 | 89 | self.assertNotIn("DEPRECATION WARNING", log_contents) |
| 90 | + |
| 91 | + def test_multi_method_partitioning_with_shared_weights(self): |
| 92 | + """ |
| 93 | + Test that multi-method models with shared weights are correctly partitioned. |
| 94 | + Verify that: |
| 95 | + 1. Both methods are fully lowered to XNNPACK. |
| 96 | + 2. Constants are not duplicated between named data and constant buffers. |
| 97 | + 3. Program executes correctly. |
| 98 | + """ |
| 99 | + |
| 100 | + class MultiMethodModel(torch.nn.Module): |
| 101 | + def __init__(self): |
| 102 | + super().__init__() |
| 103 | + self.linear = torch.nn.Linear(8, 16) |
| 104 | + self.linear2 = torch.nn.Linear(16, 8) |
| 105 | + |
| 106 | + def forward(self, x): |
| 107 | + return self.linear2(F.sigmoid(self.linear(x))) |
| 108 | + |
| 109 | + def forward_2(self, x): |
| 110 | + return self.linear2(F.relu(self.linear(x))) |
| 111 | + |
| 112 | + def example_inputs(self): |
| 113 | + return (torch.randn(1, 8),) |
| 114 | + |
| 115 | + model = MultiMethodModel() |
| 116 | + |
| 117 | + # Get eager reference output. |
| 118 | + example_inputs = model.example_inputs() |
| 119 | + with torch.no_grad(): |
| 120 | + fwd1_eager = model.forward(*example_inputs) |
| 121 | + fwd2_eager = model.forward_2(*example_inputs) |
| 122 | + |
| 123 | + # Export both methods |
| 124 | + ep_fwd = export(model, model.example_inputs(), strict=True) |
| 125 | + # Patch the forward, as export only traces the 'forward' method. |
| 126 | + model.forward = model.forward_2 |
| 127 | + ep_fwd_2 = export(model, model.example_inputs(), strict=True) |
| 128 | + |
| 129 | + # Convert to edge and lower to executorch |
| 130 | + edge = to_edge({"forward": ep_fwd, "forward_2": ep_fwd_2}) |
| 131 | + lowered = edge.to_backend(XnnpackPartitioner(force_fp32_dynamic_linear=True)) |
| 132 | + executorch = lowered.to_executorch() |
| 133 | + |
| 134 | + # Check that graph is fully delegated. |
| 135 | + nodes_1 = list(lowered._edge_programs["forward"].graph.nodes) |
| 136 | + nodes_2 = list(lowered._edge_programs["forward_2"].graph.nodes) |
| 137 | + self.assertEqual(len(nodes_1), 5) |
| 138 | + self.assertEqual(len(nodes_2), 5) |
| 139 | + expected_node_names = [ |
| 140 | + "x", |
| 141 | + "lowered_module_0", |
| 142 | + "executorch_call_delegate", |
| 143 | + "getitem", |
| 144 | + "output_1", |
| 145 | + ] |
| 146 | + for n in expected_node_names: |
| 147 | + self.assertTrue(any(node.name == n for node in nodes_1)) |
| 148 | + self.assertTrue(any(node.name == n for node in nodes_2)) |
| 149 | + |
| 150 | + # Check that weights are not duplicated. |
| 151 | + self.assertEqual(len(executorch._named_data.pte_data), 4) |
| 152 | + self.assertEqual(len(executorch._named_data.buffers), 4) |
| 153 | + self.assertEqual(len(executorch._named_data.external_data), 0) |
| 154 | + |
| 155 | + # Check that there are no constant buffers (besides the placeholder). |
| 156 | + self.assertEqual(len(executorch._emitter_output.program.constant_buffer), 1) |
| 157 | + |
| 158 | + # Check for model correctness. |
| 159 | + executorch_module = _load_for_executorch_from_buffer(executorch.buffer) |
| 160 | + fwd1_et = executorch_module.run_method("forward", example_inputs) |
| 161 | + fwd2_et = executorch_module.run_method("forward_2", example_inputs) |
| 162 | + self.assertTrue(torch.allclose(fwd1_eager, fwd1_et[0], 1e-3)) |
| 163 | + self.assertTrue(torch.allclose(fwd2_eager, fwd2_et[0], 1e-3)) |
0 commit comments