Skip to content

Commit 9420ec1

Browse files
Committed: as_numpy_vector refactored to as_vector(return_numpy=True)
1 parent 8dec0d3 commit 9420ec1

File tree

2 files changed: +76 -93 lines changed

bson/binary.py

Lines changed: 72 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -539,9 +539,10 @@ def from_vector(
539539
)
540540
return cls(metadata + data, subtype=VECTOR_SUBTYPE)
541541

542-
def as_vector(self) -> BinaryVector:
543-
"""From the Binary, create a list of numbers, along with dtype and padding.
542+
def as_vector(self, return_numpy: bool = False) -> BinaryVector:
543+
"""From the Binary, create a list or 1-d numpy array of numbers, along with dtype and padding.
544544
545+
:param return_numpy: If True, BinaryVector.data will be a one-dimensional numpy array. By default, it is a list.
545546
:return: BinaryVector
546547
547548
.. versionadded:: 4.10
@@ -550,108 +551,90 @@ def as_vector(self) -> BinaryVector:
550551
if self.subtype != VECTOR_SUBTYPE:
551552
raise ValueError(f"Cannot decode subtype {self.subtype} as a vector")
552553

553-
position = 0
554-
dtype, padding = struct.unpack_from("<sB", self, position)
555-
position += 2
554+
dtype, padding = struct.unpack_from("<sB", self)
556555
dtype = BinaryVectorDtype(dtype)
557-
n_values = len(self) - position
556+
offset = 2
557+
n_bytes = len(self) - offset
558558

559559
if padding and dtype != BinaryVectorDtype.PACKED_BIT:
560560
raise ValueError(
561561
f"Corrupt data. Padding ({padding}) must be 0 for all but PACKED_BIT dtypes. ({dtype=})"
562562
)
563563

564-
if dtype == BinaryVectorDtype.INT8:
565-
dtype_format = "b"
566-
format_string = f"<{n_values}{dtype_format}"
567-
vector = list(struct.unpack_from(format_string, self, position))
568-
return BinaryVector(vector, dtype, padding)
564+
if not return_numpy:
565+
if dtype == BinaryVectorDtype.INT8:
566+
dtype_format = "b"
567+
format_string = f"<{n_bytes}{dtype_format}"
568+
vector = list(struct.unpack_from(format_string, self, offset))
569+
return BinaryVector(vector, dtype, padding)
570+
571+
elif dtype == BinaryVectorDtype.FLOAT32:
572+
n_values = n_bytes // 4
573+
if n_bytes % 4:
574+
raise ValueError(
575+
"Corrupt data. N bytes for a float32 vector must be a multiple of 4."
576+
)
577+
dtype_format = "f"
578+
format_string = f"<{n_values}{dtype_format}"
579+
vector = list(struct.unpack_from(format_string, self, offset))
580+
return BinaryVector(vector, dtype, padding)
569581

570-
elif dtype == BinaryVectorDtype.FLOAT32:
571-
n_bytes = len(self) - position
572-
n_values = n_bytes // 4
573-
if n_bytes % 4:
574-
raise ValueError(
575-
"Corrupt data. N bytes for a float32 vector must be a multiple of 4."
576-
)
577-
dtype_format = "f"
578-
format_string = f"<{n_values}{dtype_format}"
579-
vector = list(struct.unpack_from(format_string, self, position))
580-
return BinaryVector(vector, dtype, padding)
581-
582-
elif dtype == BinaryVectorDtype.PACKED_BIT:
583-
# data packed as uint8
584-
if padding and not n_values:
585-
raise ValueError("Corrupt data. Vector has a padding P, but no data.")
586-
if padding > 7 or padding < 0:
587-
raise ValueError(f"Corrupt data. Padding ({padding}) must be between 0 and 7.")
588-
dtype_format = "B"
589-
format_string = f"<{n_values}{dtype_format}"
590-
unpacked_uint8s = list(struct.unpack_from(format_string, self, position))
591-
if padding and n_values and unpacked_uint8s[-1] & (1 << padding) - 1 != 0:
592-
warnings.warn(
593-
"Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero.",
594-
DeprecationWarning,
595-
stacklevel=2,
596-
)
597-
return BinaryVector(unpacked_uint8s, dtype, padding)
582+
elif dtype == BinaryVectorDtype.PACKED_BIT:
583+
# data packed as uint8
584+
if padding and not n_bytes:
585+
raise ValueError("Corrupt data. Vector has a padding P, but no data.")
586+
if padding > 7 or padding < 0:
587+
raise ValueError(f"Corrupt data. Padding ({padding}) must be between 0 and 7.")
588+
dtype_format = "B"
589+
format_string = f"<{n_bytes}{dtype_format}"
590+
unpacked_uint8s = list(struct.unpack_from(format_string, self, offset))
591+
if padding and n_bytes and unpacked_uint8s[-1] & (1 << padding) - 1 != 0:
592+
warnings.warn(
593+
"Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero.",
594+
DeprecationWarning,
595+
stacklevel=2,
596+
)
597+
return BinaryVector(unpacked_uint8s, dtype, padding)
598598

599-
else:
600-
raise NotImplementedError("Binary Vector dtype %s not yet supported" % dtype.name)
599+
else:
600+
raise NotImplementedError("Binary Vector dtype %s not yet supported" % dtype.name)
601+
else: # create a numpy array
602+
try:
603+
import numpy as np
604+
except ImportError as exc:
605+
raise ImportError(
606+
"Converting binary to numpy.ndarray requires numpy to be installed."
607+
) from exc
608+
if dtype == BinaryVectorDtype.INT8:
609+
data = np.frombuffer(self[offset:], dtype="int8")
610+
elif dtype == BinaryVectorDtype.FLOAT32:
611+
if n_bytes % 4:
612+
raise ValueError(
613+
"Corrupt data. N bytes for a float32 vector must be a multiple of 4."
614+
)
615+
data = np.frombuffer(self[offset:], dtype="float32")
616+
elif dtype == BinaryVectorDtype.PACKED_BIT:
617+
# data packed as uint8
618+
if padding and not n_bytes:
619+
raise ValueError("Corrupt data. Vector has a padding P, but no data.")
620+
if padding > 7 or padding < 0:
621+
raise ValueError(f"Corrupt data. Padding ({padding}) must be between 0 and 7.")
622+
data = np.frombuffer(self[offset:], dtype="uint8")
623+
if padding and np.unpackbits(data[-1])[-padding:].sum() > 0:
624+
warnings.warn(
625+
"Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero.",
626+
DeprecationWarning,
627+
stacklevel=2,
628+
)
629+
else:
630+
raise NotImplementedError("Binary Vector dtype %s not yet supported" % dtype.name)
631+
return BinaryVector(data, dtype, padding)
601632

602633
@property
603634
def subtype(self) -> int:
604635
"""Subtype of this binary data."""
605636
return self.__subtype
606637

607-
def as_numpy_vector(self) -> BinaryVector:
608-
"""From the Binary, create a BinaryVector where data is a 1-dim numpy array.
609-
dtype still follows our typing (BinaryVectorDtype),
610-
and padding is as we define it, notably equivalent to a negative value of count
611-
in `numpy.unpackbits <https://numpy.org/doc/stable/reference/generated/numpy.unpackbits.html>`_.
612-
613-
:return: BinaryVector
614-
615-
.. versionadded:: 4.16
616-
"""
617-
if self.subtype != VECTOR_SUBTYPE:
618-
raise ValueError(f"Cannot decode subtype {self.subtype} as a vector")
619-
try:
620-
import numpy as np
621-
except ImportError as exc:
622-
raise ImportError(
623-
"Converting binary to numpy.ndarray requires numpy to be installed."
624-
) from exc
625-
626-
dtype, padding = struct.unpack_from("<sB", self, 0)
627-
dtype = BinaryVectorDtype(dtype)
628-
n_bytes = len(self) - 2
629-
630-
if dtype == BinaryVectorDtype.INT8:
631-
data = np.frombuffer(self[2:], dtype="int8")
632-
elif dtype == BinaryVectorDtype.FLOAT32:
633-
if n_bytes % 4:
634-
raise ValueError(
635-
"Corrupt data. N bytes for a float32 vector must be a multiple of 4."
636-
)
637-
data = np.frombuffer(self[2:], dtype="float32")
638-
elif dtype == BinaryVectorDtype.PACKED_BIT:
639-
# data packed as uint8
640-
if padding and not n_bytes:
641-
raise ValueError("Corrupt data. Vector has a padding P, but no data.")
642-
if padding > 7 or padding < 0:
643-
raise ValueError(f"Corrupt data. Padding ({padding}) must be between 0 and 7.")
644-
data = np.frombuffer(self[2:], dtype="uint8")
645-
if padding and np.unpackbits(data[-1])[-padding:].sum() > 0:
646-
warnings.warn(
647-
"Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero.",
648-
DeprecationWarning,
649-
stacklevel=2,
650-
)
651-
else:
652-
raise ValueError(f"Unsupported dtype code: {dtype!r}")
653-
return BinaryVector(data, dtype, padding)
654-
655638
def __getnewargs__(self) -> Tuple[bytes, int]: # type: ignore[override]
656639
# Work around http://bugs.python.org/issue7382
657640
data = super().__getnewargs__()[0]

test/test_bson.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -888,7 +888,7 @@ def test_vector_from_numpy(self):
888888
assert isinstance(vector, BinaryVector)
889889
assert vector.data == arr.tolist()
890890
# as_numpy_vector
891-
vector_np = binary_vector_int8.as_numpy_vector()
891+
vector_np = binary_vector_int8.as_vector(return_numpy=True)
892892
assert isinstance(vector_np, BinaryVector)
893893
assert np.all(vector.data == arr)
894894
# PACKED_BIT
@@ -898,7 +898,7 @@ def test_vector_from_numpy(self):
898898
assert isinstance(vector, BinaryVector)
899899
assert vector.data == arr.tolist()
900900
# as_numpy_vector
901-
vector_np = binary_vector_uint8.as_numpy_vector()
901+
vector_np = binary_vector_uint8.as_vector(return_numpy=True)
902902
assert isinstance(vector_np, BinaryVector)
903903
assert np.all(vector_np.data == arr)
904904
# FLOAT32
@@ -908,7 +908,7 @@ def test_vector_from_numpy(self):
908908
assert isinstance(vector, BinaryVector)
909909
assert vector.data == arr.tolist()
910910
# as_numpy_vector
911-
vector_np = binary_vector_float32.as_numpy_vector()
911+
vector_np = binary_vector_float32.as_vector(return_numpy=True)
912912
assert isinstance(vector_np, BinaryVector)
913913
assert np.all(vector_np.data == arr)
914914

@@ -926,7 +926,7 @@ def test_vector_from_numpy(self):
926926
list_floats = [-1.1, 1.1]
927927
cast_bin = Binary.from_vector(np.array(list_floats), BinaryVectorDtype.INT8)
928928
vector = cast_bin.as_vector()
929-
vector_np = cast_bin.as_numpy_vector()
929+
vector_np = cast_bin.as_vector(return_numpy=True)
930930
assert vector.data != list_floats
931931
assert vector.data == vector_np.data.tolist() == [-1, 1]
932932

0 commit comments

Comments
 (0)