33"""
44import lzma
55import struct
6+ import sys
67
78from osi3 .osi_sensorview_pb2 import SensorView
89from osi3 .osi_groundtruth_pb2 import GroundTruth
@@ -24,17 +25,18 @@ def map_message_type(type_name):
2425 """Map the type name to the protobuf message type."""
2526 return MESSAGES_TYPE [type_name ]
2627
27- def __init__ (self , path = None , type_name = "SensorView" ):
def __init__(self, path=None, type_name="SensorView", cache_messages=False):
    """Set up an empty trace reader and optionally open a trace file.

    Args:
        path: Trace file to open right away; when falsy no file is opened.
        type_name: Name of the message type, resolved through
            ``map_message_type``.
        cache_messages: When true, an (initially empty) dict is created to
            hold parsed messages; otherwise the cache is ``None``.
    """
    self.type = self.map_message_type(type_name)
    # Size in bytes of the length prefix read in front of every message.
    self._header_length = 4

    # Reader state; everything stays unset until a file is attached.
    self.file = None
    self.current_index = None
    self.message_offsets = None
    self.read_complete = False
    self.message_cache = None if not cache_messages else {}

    if path:
        self.from_file(path, type_name, cache_messages)
3638
37- def from_file (self , path , type_name = "SensorView" ):
39+ def from_file (self , path , type_name = "SensorView" , cache_messages = False ):
3840 """Import a trace from a file"""
3941 self .type = self .map_message_type (type_name )
4042
@@ -44,48 +46,78 @@ def from_file(self, path, type_name="SensorView"):
4446 self .file = open (path , "rb" )
4547
4648 self .read_complete = False
47- self .read_limit = 0
49+ self .current_index = 0
4850 self .message_offsets = [0 ]
51+ self .message_cache = {} if cache_messages else None
4952
def retrieve_offsets(self, limit=None):
    """Retrieve the offsets of the messages from the file.

    Scanning resumes at the last known offset and skips over message
    payloads without parsing them.

    Args:
        limit: When truthy, stop as soon as more than ``limit`` offsets are
            known. When ``None`` (or ``0``), scan until the end of the file.

    Returns:
        The list of message offsets known so far (the same list object that
        is stored on the instance).
    """
    if not self.read_complete:
        self.current_index = len(self.message_offsets) - 1
        self.file.seek(self.message_offsets[-1], 0)
    # The parentheses matter: the end-of-file flag must stop the loop both
    # with and without a limit. Without them an unlimited scan compared an
    # int with None at end of file, and a limit beyond the end of the file
    # kept looping forever.
    while not self.read_complete and (
        not limit or len(self.message_offsets) <= limit
    ):
        self.retrieve_message(skip=True)
    return self.message_offsets
5763
58- def read_message (self , offset = None , skip = False ):
59- """Read a message from the file at the given offset."""
60- if offset :
61- self .file .seek (offset , 0 )
62- message = self .type ()
def retrieve_message(self, index=None, skip=False):
    """Retrieve the next message, or the one at ``index``, or skip over it.

    Args:
        index: When not ``None``, jump to the message with this index first
            (its offset must already be known). Otherwise continue at the
            current file position.
        skip: When true, do not parse the message; only move past it.

    Returns:
        The parsed message, or, when ``skip`` is true, the file position
        right after the message. ``None`` when no complete message is
        available at the current position; the file position is then
        restored to where the attempt started.
    """
    if index is not None:
        self.current_index = index
        self.file.seek(self.message_offsets[index], 0)

    if self.message_cache is not None and self.current_index in self.message_cache:
        message = self.message_cache[self.current_index]
        self.current_index += 1
        # Keep the file position in sync with the index. After the last
        # message there is no stored offset, so move to the end of the file.
        if self.current_index < len(self.message_offsets):
            next_pos = self.file.seek(self.message_offsets[self.current_index], 0)
        else:
            next_pos = self.file.seek(0, 2)
        return next_pos if skip else message

    start = self.file.tell()

    def give_up():
        # The first failure at the scanning frontier means the file is
        # exhausted: the pending offset does not start a complete message,
        # so drop it. Later failures must leave the offsets untouched.
        if (
            not self.read_complete
            and self.message_offsets
            and start == self.message_offsets[-1]
        ):
            self.message_offsets.pop()
            self.read_complete = True
        self.file.seek(start, 0)
        return None

    header = self.file.read(self._header_length)
    if len(header) < self._header_length:
        return give_up()
    message_length = struct.unpack("<L", header)[0]

    if skip:
        if message_length:
            # Seeking past the end of a plain file succeeds silently, so
            # confirm the payload is really there by reading its last byte.
            self.file.seek(message_length - 1, 1)
            if len(self.file.read(1)) < 1:
                return give_up()
        new_pos = self.file.tell()
        self.current_index += 1
        if not self.read_complete and start == self.message_offsets[-1]:
            self.message_offsets.append(new_pos)
        return new_pos

    message_data = self.file.read(message_length)
    if len(message_data) < message_length:
        return give_up()
    self.current_index += 1
    message = self.type()
    message.ParseFromString(message_data)
    if start == self.message_offsets[-1]:
        if self.message_cache is not None:
            self.message_cache[len(self.message_offsets) - 1] = message
        # Once the scan is complete the list holds only message starts;
        # appending the end-of-file position would add a phantom entry.
        if not self.read_complete:
            self.message_offsets.append(self.file.tell())
    return message
75116
76- def retrieve_message (self , skip = False ):
77- """Retrieve the next message from the file, or skip it if skip is true."""
78- result = self .read_message (skip = skip )
79- if result is None :
80- self .message_offsets .pop ()
81- self .read_complete = True
82- if skip :
83- self .read_limit = result
84- self .message_offsets .append (result )
85- else :
86- self .read_limit = self .file .tell ()
87- self .message_offsets .append (self .read_limit )
88- return result
def restart(self, index=None):
    """Move the read position back to the start or to a given message.

    Args:
        index: Index of the message to continue reading from. A falsy
            value (``None`` or ``0``) selects the first message.
    """
    target = index or 0
    self.current_index = target
    self.file.seek(self.message_offsets[target], 0)
89121
90122 def __iter__ (self ):
91123 while message := self .retrieve_message ():
def get_message_by_index(self, index):
    """
    Return the message stored at position ``index`` of the trace.

    Offsets are scanned first when ``index`` is not covered by the known
    ones. A cached message is returned directly; otherwise the message is
    read through ``retrieve_message``.
    """
    if len(self.message_offsets) <= index:
        self.retrieve_offsets(index)
    cache = self.message_cache
    if cache is None or index not in cache:
        return self.retrieve_message(index=index)
    return cache[index]
101135
def get_messages(self):
    """
    Return an iterator over every message, starting at index 0 with no upper bound.
    """
    return self.get_messages_in_index_range(begin=0, end=None)
104141
def get_messages_in_index_range(self, begin, end):
    """
    Yield the messages with indexes from begin (included) up to end (excluded).

    Args:
        begin: Index of the first message to yield.
        end: Index at which to stop; the message at ``end`` itself is not
            yielded. ``None`` means "until the end of the file".

    Nothing is yielded when ``begin`` lies beyond the last message.
    """
    if begin >= len(self.message_offsets):
        self.retrieve_offsets(begin)
        if begin >= len(self.message_offsets):
            # The file holds fewer messages than requested.
            return
    self.restart(begin)
    current = begin
    while end is None or current < end:
        # retrieve_message() serves cached messages itself and keeps the
        # current index and file position in step. Yielding from the cache
        # here without advancing them made later reads start over at
        # `begin` and repeat messages.
        message = self.retrieve_message()
        if message is None:
            break
        yield message
        current += 1
123159
124160 def close (self ):
125161 if self .file :
126162 self .file .close ()
127163 self .file = None
164+ self .current_index = None
165+ self .message_cache = None
128166 self .message_offsets = None
129167 self .read_complete = False
130168 self .read_limit = None
0 commit comments