From 209b496ac69fe3ec5dc6b5f302d3538d5fc1779f Mon Sep 17 00:00:00 2001 From: lucylq Date: Fri, 7 Nov 2025 16:48:39 -0800 Subject: [PATCH] Add 'SerializationArtifacts' to hold program+segments --- exir/_serialize/_program.py | 76 ++++++++++++++++++++++++---- exir/_serialize/test/test_program.py | 27 +++++----- 2 files changed, 79 insertions(+), 24 deletions(-) diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index bee5b3438b0..b992ed1f97d 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -45,6 +45,15 @@ # endian. _HEADER_BYTEORDER: Literal["little"] = "little" +@dataclass +class SerializationArtifacts: + """ + Holds data required to serialize into a PTE. + """ + program: Program + mutable_data: Optional[List[Buffer]] + named_data: Optional[NamedDataStoreOutput] + # TODO: add constants here and remove constant buffer. @dataclass class AlignedData: @@ -575,7 +584,7 @@ def serialize_pte_binary( return pte_data -def _restore_segments(program: Program, segment_data: bytes) -> Program: +def _restore_segments(program: Program, segment_data: bytes) -> SerializationArtifacts: """Moves segments from `segment_data` into `program`. This should recreate the original Program that the segments were extracted @@ -589,7 +598,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: the preceding data has been stripped off so that the first segment begins at offset zero. Returns: - The Program with segments restored. + SerializationArtifacts, containing the Program with delegate and constant segments restored, as well as mutable and named data segments. """ # Extract the list of segment data blobs, which parallel program.segments. segments: List[bytes] = [] @@ -624,7 +633,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: # Replace constants from constant_segment into constant_buffer. if program.constant_segment and len(program.constant_segment.offsets) > 0: - buffers: List[Buffer] = [] + constant_buffers: List[Buffer] = [] constant_segment = segments[program.constant_segment.segment_index] for i in range(len(program.constant_segment.offsets)): start_offset = program.constant_segment.offsets[i] @@ -635,17 +644,60 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: if i < len(program.constant_segment.offsets) - 1 else len(constant_segment) ) - buffers.append(Buffer(storage=constant_segment[start_offset:end_offset])) - program.constant_buffer = buffers + constant_buffers.append(Buffer(storage=constant_segment[start_offset:end_offset])) + program.constant_buffer = constant_buffers program.constant_segment.segment_index = 0 program.constant_segment.offsets = [] + # Extract mutable segments. + mutable_data = None + if program.mutable_data_segment and len(program.mutable_data_segments.offsets) > 0: + mutable_buffers: List[Buffer] = [] + mutable_segment = segments[program.mutable_segment.segment_index] + for i in range(len(program.mutable_segments.offsets)): + start_offset = program.mutable_segment.offsets[i] + # Note: this is the original end offset plus any padding between + # it and the next start offset. + end_offset = ( + program.mutable_segment.offsets[i + 1] + if i < len(program.mutable_segment.offsets) - 1 + else len(mutable_segment) + ) + mutable_buffers.append(Buffer(storage=mutable_segment[start_offset:end_offset])) + mutable_data = mutable_buffers + # Is this correct? + program.mutable_segment.segment_index = 0 + program.mutable_segment.offsets = [] + + # Extract named data. + named_data = None + if program.named_data: + named_data_buffers: List[bytes] = [] + pte_data: Dict[str, DataEntry] = {} + + for entry in program.named_data: + if (entry.segment_index >= len(segments)): + raise ValueError( + "Named data segment index " + f"{entry.segment_index} >= num segments {len(segments)}" + ) + named_data_buffers.append(segments[entry.segment_index]) + pte_data[entry.key] = DataEntry( + buffer_index = len(named_data_buffers) - 1, + alignment = 1, # Deserialization does not preserve alignment. + tensor_layout = None + ) + named_data = NamedDataStoreOutput(buffers=named_data_buffers, pte_data=pte_data, external_data=None) + # Clear out the segments list since the original Program didn't have one. program.segments = [] - return program - + return SerializationArtifacts( + program=program, + mutable_data=mutable_data, + named_data=named_data + ) -def deserialize_pte_binary(program_data: bytes) -> Program: +def deserialize_pte_binary(program_data: bytes) -> SerializationArtifacts: """Returns a Program deserialized from the given runtime binary data.""" program_size = len(program_data) segment_base_offset = 0 @@ -664,8 +716,12 @@ def deserialize_pte_binary(program_data: bytes) -> Program: if segment_base_offset != 0: # Move segment data back into the Program. - program = _restore_segments( + return _restore_segments( program=program, segment_data=program_data[segment_base_offset:] ) - return program + return SerializationArtifacts( + program=program, + mutable_data=None, + named_data=None, + ) diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index 80f4b8ca49f..ced2524656d 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -281,13 +281,13 @@ def constant_segment_with_tensor_alignment( ) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs are the same besides constant_buffer, as deserialization # does not preserve constant segment; padding may be added # during serialization. - self.assertEqual(program2.execution_plan, program.execution_plan) + self.assertEqual(deserialized.program.execution_plan, program.execution_plan) # Number of constant tensors should be the same. - self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) + self.assertEqual(len(deserialized.program.constant_buffer), len(program.constant_buffer)) def test_canonicalize_delegate_indices(self) -> None: def make_execution_plan( @@ -426,10 +426,9 @@ def test_round_trip_no_header_no_segments(self) -> None: self.assertIsNone(eh) # Convert back. - program2 = deserialize_pte_binary(pte_data) - + deserialized = deserialize_pte_binary(pte_data) # Programs should be the same. - self.assert_programs_equal(program, program2) + self.assert_programs_equal(program, deserialized.program) def test_round_trip_large_buffer_sizes(self) -> None: """Tests that when the non_const_buffer_sizes contains integers @@ -439,7 +438,7 @@ def test_round_trip_large_buffer_sizes(self) -> None: program = get_test_program() program.execution_plan[0].non_const_buffer_sizes = [0, 2**48] flatbuffer_from_py = bytes(serialize_pte_binary(program)) - self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py)) + self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py).program) def test_round_trip_no_segments_and_no_header(self) -> None: """Tests that a Program serialized with extract_delegate_segments=True @@ -463,10 +462,10 @@ def test_round_trip_no_segments_and_no_header(self) -> None: self.assertEqual(program_with_segments.segments, []) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs should be the same. - self.assert_programs_equal(program, program2) + self.assert_programs_equal(program, deserialized.program) @staticmethod def gen_blob_data(size: int, pattern: bytes) -> bytes: @@ -598,8 +597,8 @@ def test_round_trip_with_segments(self) -> None: # meaning that the segments were moved back to inline. This also # demonstrates that the contents of all segments survived, and weren't # truncated or corrupted. - program2 = deserialize_pte_binary(pte_data) - self.assert_programs_equal(program, program2) + deserialized = deserialize_pte_binary(pte_data) + self.assert_programs_equal(program, deserialized.program) def test_no_constants(self) -> None: program = get_test_program() @@ -884,13 +883,13 @@ def test_constant_delegate_and_named_data_segments(self) -> None: ) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs are the same besides constant_buffer, as deserialization # does not preserve constant segment; padding may be added # during serialization. - self.assertEqual(program2.execution_plan, program.execution_plan) + self.assertEqual(deserialized.program.execution_plan, program.execution_plan) # Number of constant tensors should be the same. - self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) + self.assertEqual(len(deserialized.program.constant_buffer), len(program.constant_buffer)) def test_named_data_segments(self) -> None: # Set segment alignment to 12 to test the padding.