@@ -79,7 +79,7 @@ def b(x):
7979 return x .encode ('ascii' )
8080
8181# Windows version of Python does not provide it
82- # for compatibility with older versions of Windows.
82+ # for compatibility with older versions of Windows.
8383if not hasattr (socket , 'inet_pton' ):
8484 def inet_pton (t , addr ):
8585 import ctypes
@@ -95,14 +95,23 @@ def inet_pton(t, addr):
9595 return out_addr_p .raw
9696 socket .inet_pton = inet_pton
9797
98- def is_ipv4 (hostname ):
99- ip_parts = hostname .split ('.' )
100- for i in range (0 ,len (ip_parts )):
101- if int (ip_parts [i ]) > 255 :
98+ def is_ipv4 (ip ):
99+ if '.' in ip :
100+ ip_parts = ip .split ('.' )
101+ if len (ip_parts ) == 4 :
102+ for i in range (0 ,len (ip_parts )):
103+ if str (ip_parts [i ]).isdigit ():
104+ if int (ip_parts [i ]) > 255 :
105+ return False
106+ else :
107+ return False
108+ pattern = r'^([0-9]{1,3}[.]){3}[0-9]{1,3}$'
109+ if match (pattern , ip ) is not None :
110+ return 4
111+ else :
102112 return False
103- pattern = r'^([0-9]{1,3}[.]){3}[0-9]{1,3}$'
104- if match (pattern , hostname ) is not None :
105- return 4
113+ else :
114+ return False
106115 return False
107116
108117def is_ipv6 (hostname ):
@@ -187,20 +196,26 @@ def open(self, filename):
187196 self ._f = open (filename , 'rb' )
188197 else :
189198 raise ValueError ("Invalid mode. Please enter either FILE_IO or SHARED_MEMORY." )
190- self ._dbtype = struct .unpack ('B' , self ._f .read (1 ))[0 ]
191- self ._dbcolumn = struct .unpack ('B' , self ._f .read (1 ))[0 ]
192- self ._dbyear = struct .unpack ('B' , self ._f .read (1 ))[0 ]
193- self ._dbmonth = struct .unpack ('B' , self ._f .read (1 ))[0 ]
194- self ._dbday = struct .unpack ('B' , self ._f .read (1 ))[0 ]
195- self ._ipv4dbcount = struct .unpack ('<I' , self ._f .read (4 ))[0 ]
196- self ._ipv4dbaddr = struct .unpack ('<I' , self ._f .read (4 ))[0 ]
197- self ._ipv6dbcount = struct .unpack ('<I' , self ._f .read (4 ))[0 ]
198- self ._ipv6dbaddr = struct .unpack ('<I' , self ._f .read (4 ))[0 ]
199- self ._ipv4indexbaseaddr = struct .unpack ('<I' , self ._f .read (4 ))[0 ]
200- self ._ipv6indexbaseaddr = struct .unpack ('<I' , self ._f .read (4 ))[0 ]
201- self ._productcode = struct .unpack ('B' , self ._f .read (1 ))[0 ]
202- self ._licensecode = struct .unpack ('B' , self ._f .read (1 ))[0 ]
203- self ._databasesize = struct .unpack ('B' , self ._f .read (1 ))[0 ]
199+ if (self .mode == 'SHARED_MEMORY' ):
200+ # We can directly use slice notation to read content from mmap object. https://docs.python.org/3/library/mmap.html?highlight=mmap#module-mmap
201+ header_row = self ._f [0 :32 ]
202+ else :
203+ self ._f .seek (0 )
204+ header_row = self ._f .read (32 )
205+ self ._dbtype = struct .unpack ('B' , header_row [0 :1 ])[0 ]
206+ self ._dbcolumn = struct .unpack ('B' , header_row [1 :2 ])[0 ]
207+ self ._dbyear = struct .unpack ('B' , header_row [2 :3 ])[0 ]
208+ self ._dbmonth = struct .unpack ('B' , header_row [3 :4 ])[0 ]
209+ self ._dbday = struct .unpack ('B' , header_row [4 :5 ])[0 ]
210+ self ._ipv4dbcount = struct .unpack ('<I' , header_row [5 :9 ])[0 ]
211+ self ._ipv4dbaddr = struct .unpack ('<I' , header_row [9 :13 ])[0 ]
212+ self ._ipv6dbcount = struct .unpack ('<I' , header_row [13 :17 ])[0 ]
213+ self ._ipv6dbaddr = struct .unpack ('<I' , header_row [17 :21 ])[0 ]
214+ self ._ipv4indexbaseaddr = struct .unpack ('<I' , header_row [21 :25 ])[0 ]
215+ self ._ipv6indexbaseaddr = struct .unpack ('<I' , header_row [25 :29 ])[0 ]
216+ self ._productcode = struct .unpack ('B' , header_row [29 :30 ])[0 ]
217+ self ._licensecode = struct .unpack ('B' , header_row [30 :31 ])[0 ]
218+ self ._databasesize = struct .unpack ('B' , header_row [31 :32 ])[0 ]
204219 if (self ._productcode != 1 ) :
205220 if (self ._dbyear > 20 and self ._productcode != 0 ) :
206221 self ._f .close ()
@@ -329,12 +344,14 @@ def find(self, addr):
329344
330345 def _reads (self , offset ):
331346 self ._f .seek (offset - 1 )
332- n = struct .unpack ('B' , self ._f .read (1 ))[0 ]
333- # return u(self._f.read(n))
347+ ''''''
348+ data = self ._f .read (257 )
349+ char_count = struct .unpack ('B' , data [0 :1 ])[0 ]
350+ string = data [1 :char_count + 1 ]
334351 if sys .version < '3' :
335- return str (self . _f . read ( n ) .decode ('iso-8859-1' ).encode ('utf-8' ))
352+ return str (string .decode ('iso-8859-1' ).encode ('utf-8' ))
336353 else :
337- return u (self . _f . read ( n ) .decode ('iso-8859-1' ).encode ('utf-8' ))
354+ return u (string .decode ('iso-8859-1' ).encode ('utf-8' ))
338355
339356 def _readi (self , offset ):
340357 self ._f .seek (offset - 1 )
@@ -470,6 +487,29 @@ def _ip2no(self, addr):
470487 no = no + block [0 ] * 256 * 256 * 256
471488 return int (no )
472489
490+ def calc_off (self , off , baseaddr , what , mid ):
491+ # return baseaddr + mid * (self._dbcolumn * 4 + off) + off + 4 * (what[self._dbtype]-1)
492+ return baseaddr + mid * (self ._dbcolumn * 4 + off ) + off + 4 * (what - 1 )
493+
494+ def read32x2 (self , offset ):
495+ self ._f .seek (offset - 1 )
496+ data = self ._f .read (8 )
497+ return struct .unpack ('<L' , data [0 :4 ])[0 ], struct .unpack ('<L' , data [4 :8 ])[0 ]
498+
499+ def readRow32 (self , offset ):
500+ data_length = self ._dbcolumn * 4 + 4
501+ self ._f .seek (offset - 1 )
502+ raw_data = self ._f .read (data_length )
503+ ip_from = struct .unpack ('<L' , raw_data [0 :4 ])[0 ]
504+ ip_to = struct .unpack ('<L' , raw_data [data_length - 4 :data_length ])[0 ]
505+ return (ip_from , ip_to )
506+
507+ def readRow128 (self , offset ):
508+ data_length = self ._dbcolumn * 4 + 12 + 16
509+ self ._f .seek (offset - 1 )
510+ raw_data = self ._f .read (data_length )
511+ return ((struct .unpack ('<L' , raw_data [12 :16 ])[0 ] << 96 ) | (struct .unpack ('<L' , raw_data [8 :12 ])[0 ] << 64 ) | (struct .unpack ('<L' , raw_data [4 :8 ])[0 ] << 32 ) | struct .unpack ('<L' , raw_data [0 :4 ])[0 ], (struct .unpack ('<L' , raw_data [data_length - 4 :data_length ])[0 ] << 96 ) | (struct .unpack ('<L' , raw_data [data_length - 8 :data_length - 4 ])[0 ] << 64 ) | (struct .unpack ('<L' , raw_data [data_length - 12 :data_length - 8 ])[0 ] << 32 ) | struct .unpack ('<L' , raw_data [data_length - 16 :data_length - 12 ])[0 ])
512+
473513 def _parse_addr (self , addr ):
474514 ''' Parses address and returns IP version. Raises exception on invalid argument '''
475515 ipv = 0
@@ -537,12 +577,8 @@ def _parse_addr(self, addr):
537577 return ipv , ipnum
538578
539579 def _get_record (self , ip ):
540- # global original_ip
541580 self .original_ip = ip
542581 low = 0
543- # ipv = self._parse_addr(ip)
544- # ipv = self._parse_addr(ip)[0]
545- # ipnum = self._parse_addr(ip)[1]
546582 ipv , ipnum = self ._parse_addr (ip )
547583 if ipv == 0 :
548584 rec = IP2LocationRecord ()
@@ -581,8 +617,9 @@ def _get_record(self, ip):
581617 high = self ._ipv4dbcount
582618 if self ._ipv4indexbaseaddr > 0 :
583619 indexpos = ((ipno >> 16 ) << 3 ) + self ._ipv4indexbaseaddr
584- low = self ._readi (indexpos )
585- high = self ._readi (indexpos + 4 )
620+ # low = self._readi(indexpos)
621+ # high = self._readi(indexpos + 4)
622+ low ,high = self .read32x2 (indexpos )
586623
587624 elif ipv == 6 :
588625 if self ._ipv6dbcount == 0 :
@@ -627,8 +664,10 @@ def _get_record(self, ip):
627664
628665 while low <= high :
629666 mid = int ((low + high ) / 2 )
630- ipfrom = self ._readip (baseaddr + (mid ) * (self ._dbcolumn * 4 + off ), ipv )
631- ipto = self ._readip (baseaddr + (mid + 1 ) * (self ._dbcolumn * 4 + off ), ipv )
667+ if ipv == 4 :
668+ ipfrom , ipto = self .readRow32 (baseaddr + mid * self ._dbcolumn * 4 )
669+ elif ipv == 6 :
670+ ipfrom , ipto = self .readRow128 (baseaddr + mid * ((self ._dbcolumn * 4 ) + 12 ) )
632671
633672 if ipfrom <= ipno < ipto :
634673 return self ._read_record (mid , ipv )
0 commit comments