@@ -539,9 +539,10 @@ def from_vector(
539539 )
540540 return cls (metadata + data , subtype = VECTOR_SUBTYPE )
541541
542- def as_vector (self ) -> BinaryVector :
543- """From the Binary, create a list of numbers, along with dtype and padding.
542+ def as_vector (self , return_numpy : bool = False ) -> BinaryVector :
543+ """From the Binary, create a list or 1-d numpy array of numbers, along with dtype and padding.
544544
545+ :param return_numpy: If True, BinaryVector.data will be a one-dimensional numpy array. By default, it is a list.
545546 :return: BinaryVector
546547
547548 .. versionadded:: 4.10
@@ -550,108 +551,90 @@ def as_vector(self) -> BinaryVector:
550551 if self .subtype != VECTOR_SUBTYPE :
551552 raise ValueError (f"Cannot decode subtype { self .subtype } as a vector" )
552553
553- position = 0
554- dtype , padding = struct .unpack_from ("<sB" , self , position )
555- position += 2
554+ dtype , padding = struct .unpack_from ("<sB" , self )
556555 dtype = BinaryVectorDtype (dtype )
557- n_values = len (self ) - position
556+ offset = 2
557+ n_bytes = len (self ) - offset
558558
559559 if padding and dtype != BinaryVectorDtype .PACKED_BIT :
560560 raise ValueError (
561561 f"Corrupt data. Padding ({ padding } ) must be 0 for all but PACKED_BIT dtypes. ({ dtype = } )"
562562 )
563563
564- if dtype == BinaryVectorDtype .INT8 :
565- dtype_format = "b"
566- format_string = f"<{ n_values } { dtype_format } "
567- vector = list (struct .unpack_from (format_string , self , position ))
568- return BinaryVector (vector , dtype , padding )
564+ if not return_numpy :
565+ if dtype == BinaryVectorDtype .INT8 :
566+ dtype_format = "b"
567+ format_string = f"<{ n_bytes } { dtype_format } "
568+ vector = list (struct .unpack_from (format_string , self , offset ))
569+ return BinaryVector (vector , dtype , padding )
570+
571+ elif dtype == BinaryVectorDtype .FLOAT32 :
572+ n_values = n_bytes // 4
573+ if n_bytes % 4 :
574+ raise ValueError (
575+ "Corrupt data. N bytes for a float32 vector must be a multiple of 4."
576+ )
577+ dtype_format = "f"
578+ format_string = f"<{ n_values } { dtype_format } "
579+ vector = list (struct .unpack_from (format_string , self , offset ))
580+ return BinaryVector (vector , dtype , padding )
569581
570- elif dtype == BinaryVectorDtype .FLOAT32 :
571- n_bytes = len (self ) - position
572- n_values = n_bytes // 4
573- if n_bytes % 4 :
574- raise ValueError (
575- "Corrupt data. N bytes for a float32 vector must be a multiple of 4."
576- )
577- dtype_format = "f"
578- format_string = f"<{ n_values } { dtype_format } "
579- vector = list (struct .unpack_from (format_string , self , position ))
580- return BinaryVector (vector , dtype , padding )
581-
582- elif dtype == BinaryVectorDtype .PACKED_BIT :
583- # data packed as uint8
584- if padding and not n_values :
585- raise ValueError ("Corrupt data. Vector has a padding P, but no data." )
586- if padding > 7 or padding < 0 :
587- raise ValueError (f"Corrupt data. Padding ({ padding } ) must be between 0 and 7." )
588- dtype_format = "B"
589- format_string = f"<{ n_values } { dtype_format } "
590- unpacked_uint8s = list (struct .unpack_from (format_string , self , position ))
591- if padding and n_values and unpacked_uint8s [- 1 ] & (1 << padding ) - 1 != 0 :
592- warnings .warn (
593- "Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero." ,
594- DeprecationWarning ,
595- stacklevel = 2 ,
596- )
597- return BinaryVector (unpacked_uint8s , dtype , padding )
582+ elif dtype == BinaryVectorDtype .PACKED_BIT :
583+ # data packed as uint8
584+ if padding and not n_bytes :
585+ raise ValueError ("Corrupt data. Vector has a padding P, but no data." )
586+ if padding > 7 or padding < 0 :
587+ raise ValueError (f"Corrupt data. Padding ({ padding } ) must be between 0 and 7." )
588+ dtype_format = "B"
589+ format_string = f"<{ n_bytes } { dtype_format } "
590+ unpacked_uint8s = list (struct .unpack_from (format_string , self , offset ))
591+ if padding and n_bytes and unpacked_uint8s [- 1 ] & (1 << padding ) - 1 != 0 :
592+ warnings .warn (
593+ "Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero." ,
594+ DeprecationWarning ,
595+ stacklevel = 2 ,
596+ )
597+ return BinaryVector (unpacked_uint8s , dtype , padding )
598598
599- else :
600- raise NotImplementedError ("Binary Vector dtype %s not yet supported" % dtype .name )
599+ else :
600+ raise NotImplementedError ("Binary Vector dtype %s not yet supported" % dtype .name )
601+ else : # create a numpy array
602+ try :
603+ import numpy as np
604+ except ImportError as exc :
605+ raise ImportError (
606+ "Converting binary to numpy.ndarray requires numpy to be installed."
607+ ) from exc
608+ if dtype == BinaryVectorDtype .INT8 :
609+ data = np .frombuffer (self [offset :], dtype = "int8" )
610+ elif dtype == BinaryVectorDtype .FLOAT32 :
611+ if n_bytes % 4 :
612+ raise ValueError (
613+ "Corrupt data. N bytes for a float32 vector must be a multiple of 4."
614+ )
615+ data = np .frombuffer (self [offset :], dtype = "float32" )
616+ elif dtype == BinaryVectorDtype .PACKED_BIT :
617+ # data packed as uint8
618+ if padding and not n_bytes :
619+ raise ValueError ("Corrupt data. Vector has a padding P, but no data." )
620+ if padding > 7 or padding < 0 :
621+ raise ValueError (f"Corrupt data. Padding ({ padding } ) must be between 0 and 7." )
622+ data = np .frombuffer (self [offset :], dtype = "uint8" )
623+ if padding and np .unpackbits (data [- 1 ])[- padding :].sum () > 0 :
624+ warnings .warn (
625+ "Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero." ,
626+ DeprecationWarning ,
627+ stacklevel = 2 ,
628+ )
629+ else :
630+ raise NotImplementedError ("Binary Vector dtype %s not yet supported" % dtype .name )
631+ return BinaryVector (data , dtype , padding )
601632
602633 @property
603634 def subtype (self ) -> int :
604635 """Subtype of this binary data."""
605636 return self .__subtype
606637
607- def as_numpy_vector (self ) -> BinaryVector :
608- """From the Binary, create a BinaryVector where data is a 1-dim numpy array.
609- dtype still follows our typing (BinaryVectorDtype),
610- and padding is as we define it, notably equivalent to a negative value of count
611- in `numpy.unpackbits <https://numpy.org/doc/stable/reference/generated/numpy.unpackbits.html>`_.
612-
613- :return: BinaryVector
614-
615- .. versionadded:: 4.16
616- """
617- if self .subtype != VECTOR_SUBTYPE :
618- raise ValueError (f"Cannot decode subtype { self .subtype } as a vector" )
619- try :
620- import numpy as np
621- except ImportError as exc :
622- raise ImportError (
623- "Converting binary to numpy.ndarray requires numpy to be installed."
624- ) from exc
625-
626- dtype , padding = struct .unpack_from ("<sB" , self , 0 )
627- dtype = BinaryVectorDtype (dtype )
628- n_bytes = len (self ) - 2
629-
630- if dtype == BinaryVectorDtype .INT8 :
631- data = np .frombuffer (self [2 :], dtype = "int8" )
632- elif dtype == BinaryVectorDtype .FLOAT32 :
633- if n_bytes % 4 :
634- raise ValueError (
635- "Corrupt data. N bytes for a float32 vector must be a multiple of 4."
636- )
637- data = np .frombuffer (self [2 :], dtype = "float32" )
638- elif dtype == BinaryVectorDtype .PACKED_BIT :
639- # data packed as uint8
640- if padding and not n_bytes :
641- raise ValueError ("Corrupt data. Vector has a padding P, but no data." )
642- if padding > 7 or padding < 0 :
643- raise ValueError (f"Corrupt data. Padding ({ padding } ) must be between 0 and 7." )
644- data = np .frombuffer (self [2 :], dtype = "uint8" )
645- if padding and np .unpackbits (data [- 1 ])[- padding :].sum () > 0 :
646- warnings .warn (
647- "Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero." ,
648- DeprecationWarning ,
649- stacklevel = 2 ,
650- )
651- else :
652- raise ValueError (f"Unsupported dtype code: { dtype !r} " )
653- return BinaryVector (data , dtype , padding )
654-
655638 def __getnewargs__ (self ) -> Tuple [bytes , int ]: # type: ignore[override]
656639 # Work around http://bugs.python.org/issue7382
657640 data = super ().__getnewargs__ ()[0 ]
0 commit comments