Skip to content

Commit 58cc12f

Browse files
committed
fix(compression): insert DECODE per consumer for alt decompression memory
Insert a separate DECODE immediately before each consumer of a compressed tensor, rather than sharing one DECODE output among all consumers. The interpreter's alternate decompression memory resets its allocation offset for each DECODE's Prepare, causing all DECODE outputs to be allocated at the same address. If two consumers share one DECODE and another DECODE runs between them, the intervening DECODE overwrites the shared output, corrupting data for the second consumer. Update test expectations to reflect the new DECODE-per-consumer behavior and change the integration test from expected-failure to expected-pass.
1 parent: b05656b · commit: 58cc12f

File tree

3 files changed

+46
-32
lines changed

3 files changed

+46
-32
lines changed

tensorflow/lite/micro/compression/compression_integration_test.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -289,7 +289,6 @@ class AltDecompressionMemoryTest(tf.test.TestCase):
289289
between multiple operators and alternate decompression memory is enabled.
290290
"""
291291

292-
@unittest.expectedFailure
293292
def test_shared_compressed_tensor_with_alt_memory(self):
294293
"""Verify correct results when a shared compressed tensor is used with alt
295294
decompression memory.
@@ -303,12 +302,9 @@ def test_shared_compressed_tensor_with_alt_memory(self):
303302
DECODE outputs are allocated at the same address, so they overwrite each
304303
other. A DECODE output can only be used until the next DECODE runs.
305304
306-
To work around this limitation, the DECODE insertion code must insert a
305+
To work around this limitation, the DECODE insertion code inserts a
307306
separate DECODE immediately before each consumer of a compressed tensor,
308307
rather than sharing one DECODE output among all consumers.
309-
310-
This test is expected to fail because the current insertion code does not
311-
yet implement this workaround.
312308
"""
313309
flatbuffer = _build_shared_weights_model()
314310

tensorflow/lite/micro/compression/decode_insert.py

Lines changed: 27 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -143,11 +143,19 @@ def insert_decode_operators(
143143
This function modifies the model in-place, inserting DECODE operators
144144
before any operator that uses a compressed tensor as input.
145145
146-
For each compressed tensor:
146+
A separate DECODE is inserted before each consumer, rather than sharing one
147+
DECODE output among all consumers. This is required because the interpreter's
148+
alternate decompression memory resets its allocation offset for each DECODE's
149+
Prepare, causing all DECODE outputs to be allocated at the same address. If
150+
two consumers share one DECODE and another DECODE runs between them, the
151+
intervening DECODE overwrites the shared output, corrupting data for the
152+
second consumer.
153+
154+
For each consumer of a compressed tensor:
147155
1. Create an ancillary data tensor containing DCM + type-specific data
148156
2. Create an output tensor with the same shape/dtype as the decoded tensor
149-
3. Insert a DECODE operator before the first consumer
150-
4. Rewire all consumers to use the DECODE output instead of the encoded tensor
157+
3. Insert a DECODE operator immediately before the consumer
158+
4. Rewire the consumer to use the DECODE output
151159
152160
Args:
153161
model: The model to modify in-place.
@@ -179,23 +187,27 @@ def insert_decode_operators(
179187
for sg_idx, tensor_infos in by_subgraph.items():
180188
subgraph = model.subgraphs[sg_idx]
181189

182-
# Sort by earliest consumer position (process in reverse order to maintain
183-
# valid positions as we insert)
184-
tensor_infos.sort(
185-
key=lambda info: _find_earliest_consumer_position(
186-
subgraph, info.consumers),
190+
# Collect all (consumer, tensor_info) pairs and sort by consumer position
191+
# in reverse order so insertions don't invalidate positions
192+
consumer_pairs = []
193+
for info in tensor_infos:
194+
for consumer in info.consumers:
195+
consumer_pairs.append((consumer, info))
196+
197+
consumer_pairs.sort(
198+
key=lambda pair: subgraph.operators.index(pair[0]),
187199
reverse=True,
188200
)
189201

190-
for info in tensor_infos:
191-
# Create ancillary data tensor
202+
for consumer, info in consumer_pairs:
203+
# Create ancillary data tensor (one per DECODE)
192204
ancillary_tensor = _create_ancillary_tensor(
193205
info.ancillary_data,
194206
info.tensor,
195207
)
196208
subgraph.tensors.append(ancillary_tensor)
197209

198-
# Create output tensor
210+
# Create output tensor (one per DECODE)
199211
output_tensor = _create_output_tensor(info.tensor)
200212
subgraph.tensors.append(output_tensor)
201213

@@ -207,11 +219,9 @@ def insert_decode_operators(
207219
outputs=[output_tensor],
208220
)
209221

210-
# Find insertion position (before first consumer)
211-
insert_pos = _find_earliest_consumer_position(subgraph, info.consumers)
212-
213-
# Insert DECODE operator
222+
# Insert DECODE immediately before this consumer
223+
insert_pos = subgraph.operators.index(consumer)
214224
subgraph.operators.insert(insert_pos, decode_op)
215225

216-
# Rewire all consumers to use the decoded output
217-
_rewire_consumers(info.consumers, info.tensor, output_tensor)
226+
# Rewire only this consumer to use the decoded output
227+
_rewire_consumers([consumer], info.tensor, output_tensor)

tensorflow/lite/micro/compression/decode_insert_test.py

Lines changed: 18 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -222,8 +222,8 @@ def test_consumer_rewired_to_decode_output(self):
222222
# Original weights tensor should NOT be in FC inputs
223223
self.assertNotIn(weights_tensor, fc_op.inputs)
224224

225-
def test_shared_tensor_single_decode(self):
226-
"""Tensor used by multiple ops gets single DECODE, both rewired."""
225+
def test_shared_tensor_decode_per_consumer(self):
226+
"""Tensor used by multiple ops gets separate DECODE for each consumer."""
227227
model = _build_shared_weights_model()
228228
weights_tensor = model.subgraphs[0].tensors[0]
229229

@@ -239,17 +239,25 @@ def test_shared_tensor_single_decode(self):
239239

240240
sg = model.subgraphs[0]
241241

242-
# Should have 3 operators: 1 DECODE + 2 FC
243-
self.assertEqual(len(sg.operators), 3)
242+
# Should have 4 operators: 2 DECODEs + 2 FCs (DECODE before each FC)
243+
self.assertEqual(len(sg.operators), 4)
244244
self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM)
245+
self.assertEqual(sg.operators[1].opcode,
246+
tflite.BuiltinOperator.FULLY_CONNECTED)
247+
self.assertEqual(sg.operators[2].opcode, tflite.BuiltinOperator.CUSTOM)
248+
self.assertEqual(sg.operators[3].opcode,
249+
tflite.BuiltinOperator.FULLY_CONNECTED)
245250

246-
decode_op = sg.operators[0]
251+
decode_op1 = sg.operators[0]
247252
fc_op1 = sg.operators[1]
248-
fc_op2 = sg.operators[2]
249-
250-
# Both FCs should use DECODE's output
251-
self.assertIs(fc_op1.inputs[1], decode_op.outputs[0])
252-
self.assertIs(fc_op2.inputs[1], decode_op.outputs[0])
253+
decode_op2 = sg.operators[2]
254+
fc_op2 = sg.operators[3]
255+
256+
# Each FC should use its own DECODE's output
257+
self.assertIs(fc_op1.inputs[1], decode_op1.outputs[0])
258+
self.assertIs(fc_op2.inputs[1], decode_op2.outputs[0])
259+
# The two DECODEs should have different outputs
260+
self.assertIsNot(decode_op1.outputs[0], decode_op2.outputs[0])
253261

254262
def test_ancillary_tensor_contains_dcm(self):
255263
"""Ancillary tensor data contains valid DCM header."""

0 commit comments

Comments
 (0)