Skip to content

Commit 3cc5041

Browse files
committed
Numpy now lazily imported. For typing, removed numpy extra. justfile now uses with
1 parent f03b943 commit 3cc5041

File tree

6 files changed

+57
-271
lines changed

6 files changed

+57
-271
lines changed

bson/binary.py

Lines changed: 38 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,9 @@
6565
from array import array as _array
6666
from mmap import mmap as _mmap
6767

68-
69-
_NUMPY_AVAILABLE = False
70-
try:
7168
import numpy as np
7269
import numpy.typing as npt
7370

74-
_NUMPY_AVAILABLE = True
75-
except ImportError:
76-
np = None # type: ignore
77-
7871

7972
class UuidRepresentation:
8073
UNSPECIFIED = 0
@@ -492,9 +485,7 @@ def from_vector(
492485
)
493486
metadata = struct.pack("<sB", dtype.value, padding)
494487

495-
if _NUMPY_AVAILABLE and isinstance(vector, np.ndarray):
496-
data = _numpy_vector_to_bytes(vector, dtype)
497-
else:
488+
if isinstance(vector, list):
498489
if dtype == BinaryVectorDtype.INT8: # pack ints in [-128, 127] as signed int8
499490
format_str = "b"
500491
if padding:
@@ -511,7 +502,36 @@ def from_vector(
511502
raise ValueError(f"padding does not apply to {dtype=}")
512503
else:
513504
raise NotImplementedError("%s not yet supported" % dtype)
514-
data = struct.pack(f"<{len(vector)}{format_str}", *vector) # type: ignore
505+
data = struct.pack(f"<{len(vector)}{format_str}", *vector)
506+
else: # vector is numpy array or incorrect type.
507+
try:
508+
import numpy as np
509+
except ImportError as exc:
510+
raise ImportError(
511+
"Failed to create binary from vector. Check type. If numpy array, numpy must be installed."
512+
) from exc
513+
if not isinstance(vector, np.ndarray):
514+
raise TypeError("Vector must be a numpy array.")
515+
if vector.ndim != 1:
516+
raise ValueError(
517+
"from_numpy_vector only supports 1D arrays as it creates a single vector."
518+
)
519+
520+
if dtype == BinaryVectorDtype.FLOAT32:
521+
vector = vector.astype(np.dtype("float32"), copy=False)
522+
elif dtype == BinaryVectorDtype.INT8:
523+
if vector.min() >= -128 and vector.max() <= 127:
524+
vector = vector.astype(np.dtype("int8"), copy=False)
525+
else:
526+
raise ValueError("Values found outside INT8 range.")
527+
elif dtype == BinaryVectorDtype.PACKED_BIT:
528+
if vector.min() >= 0 and vector.max() <= 127:
529+
vector = vector.astype(np.dtype("uint8"), copy=False)
530+
else:
531+
raise ValueError("Values found outside UINT8 range.")
532+
else:
533+
raise NotImplementedError("%s not yet supported" % dtype)
534+
data = vector.tobytes()
515535

516536
if padding and len(vector) and not (data[-1] & ((1 << padding) - 1)) == 0:
517537
raise ValueError(
@@ -596,8 +616,13 @@ def as_numpy_vector(self) -> BinaryVector:
596616
"""
597617
if self.subtype != VECTOR_SUBTYPE:
598618
raise ValueError(f"Cannot decode subtype {self.subtype} as a vector")
599-
if not _NUMPY_AVAILABLE:
600-
raise ImportError("Converting binary to numpy.ndarray requires numpy to be installed.")
619+
try:
620+
import numpy as np
621+
except ImportError as exc:
622+
raise ImportError(
623+
"Converting binary to numpy.ndarray requires numpy to be installed."
624+
) from exc
625+
601626
dtype, padding = struct.unpack_from("<sB", self, 0)
602627
dtype = BinaryVectorDtype(dtype)
603628

@@ -637,32 +662,3 @@ def __repr__(self) -> str:
637662
return f"<Binary(REDACTED, {self.__subtype})>"
638663
else:
639664
return f"Binary({bytes.__repr__(self)}, {self.__subtype})"
640-
641-
642-
def _numpy_vector_to_bytes(
643-
vector: npt.NDArray[np.number],
644-
dtype: BinaryVectorDtype,
645-
) -> bytes:
646-
if not _NUMPY_AVAILABLE:
647-
raise ImportError("Converting numpy.ndarray to binary requires numpy to be installed.")
648-
649-
if not isinstance(vector, np.ndarray):
650-
raise TypeError("Vector must be a numpy array.")
651-
if vector.ndim != 1:
652-
raise ValueError("from_numpy_vector only supports 1D arrays as it creates a single vector.")
653-
654-
if dtype == BinaryVectorDtype.FLOAT32:
655-
vector = vector.astype(np.dtype("float32"), copy=False)
656-
elif dtype == BinaryVectorDtype.INT8:
657-
if vector.min() >= -128 and vector.max() <= 127:
658-
vector = vector.astype(np.dtype("int8"), copy=False)
659-
else:
660-
raise ValueError("Values found outside INT8 range.")
661-
elif dtype == BinaryVectorDtype.PACKED_BIT:
662-
if vector.min() >= 0 and vector.max() <= 127:
663-
vector = vector.astype(np.dtype("uint8"), copy=False)
664-
else:
665-
raise ValueError("Values found outside UINT8 range.")
666-
else:
667-
raise NotImplementedError("%s not yet supported" % dtype)
668-
return vector.tobytes()

justfile

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
set shell := ["bash", "-c"]
33

44
# Commonly used command segments.
5-
typing_run := "uv run --group typing --extra aws --extra encryption --extra numpy --extra ocsp --extra snappy --extra test --extra zstd"
5+
typing_run := "uv run --group typing --extra aws --extra encryption --with numpy --extra ocsp --extra snappy --extra test --extra zstd"
66
docs_run := "uv run --extra docs"
77
doc_build := "./doc/_build"
88
mypy_args := "--install-types --non-interactive"
@@ -39,14 +39,14 @@ typing: && resync
3939

4040
[group('typing')]
4141
typing-mypy: && resync
42-
{{typing_run}} mypy {{mypy_args}} bson gridfs tools pymongo
43-
{{typing_run}} mypy {{mypy_args}} --config-file mypy_test.ini test
44-
{{typing_run}} mypy {{mypy_args}} test/test_typing.py test/test_typing_strict.py
42+
{{typing_run}} python -m mypy {{mypy_args}} bson gridfs tools pymongo
43+
{{typing_run}} python -m mypy {{mypy_args}} --config-file mypy_test.ini test
44+
{{typing_run}} python -m mypy {{mypy_args}} test/test_typing.py test/test_typing_strict.py
4545

4646
[group('typing')]
4747
typing-pyright: && resync
48-
{{typing_run}} pyright test/test_typing.py test/test_typing_strict.py
49-
{{typing_run}} pyright -p strict_pyrightconfig.json test/test_typing_strict.py
48+
{{typing_run}} python -m pyright test/test_typing.py test/test_typing_strict.py
49+
{{typing_run}} python -m pyright -p strict_pyrightconfig.json test/test_typing_strict.py
5050

5151
[group('lint')]
5252
lint *args="": && resync
@@ -58,7 +58,13 @@ lint-manual *args="": && resync
5858

5959
[group('test')]
6060
test *args="-v --durations=5 --maxfail=10": && resync
61-
uv run --extra test pytest {{args}}
61+
uv run --extra test python -m pytest {{args}}
62+
63+
[group('test')]
64+
test-bson *args="-v --durations=5 --maxfail=10": && resync
65+
uv run --extra test --with numpy python -m pytest test/test_bson.py
66+
67+
6268

6369
[group('test')]
6470
run-tests *args: && resync

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ ocsp = ["requirements/ocsp.txt"]
8787
snappy = ["requirements/snappy.txt"]
8888
test = ["requirements/test.txt"]
8989
zstd = ["requirements/zstd.txt"]
90-
numpy = ["requirements/numpy.txt"]
9190

9291
[tool.pytest.ini_options]
9392
minversion = "7"

requirements/numpy.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

test/test_bson.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import array
2020
import collections
2121
import datetime
22+
import importlib.util
2223
import mmap
2324
import os
2425
import pickle
@@ -71,13 +72,7 @@
7172
from bson.timestamp import Timestamp
7273
from bson.tz_util import FixedOffset, utc
7374

74-
_NUMPY_AVAILABLE = False
75-
try:
76-
import numpy as np
77-
78-
_NUMPY_AVAILABLE = True
79-
except ImportError:
80-
np = None # type: ignore
75+
_NUMPY_AVAILABLE = importlib.util.find_spec("numpy") is not None
8176

8277

8378
class NotADict(abc.MutableMapping):
@@ -883,6 +878,8 @@ def test_binaryvector_equality(self):
883878
def test_vector_from_numpy(self):
884879
"""Follows test_vector except for input type numpy.ndarray"""
885880
# Simple data values could be treated as any of our BinaryVectorDtypes
881+
import numpy as np
882+
886883
arr = np.array([2, 3])
887884
# INT8
888885
binary_vector_int8 = Binary.from_vector(arr, BinaryVectorDtype.INT8)

0 commit comments

Comments
 (0)