Skip to content

Commit 892348b

Browse files
authored
PYTHON-4336 - Add driver metadata to PyMongoArrow (#364)
1 parent 48e37c1 commit 892348b

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

bindings/python/pymongoarrow/api.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@
5151
)
5252
from pymongo.collection import Collection
5353
from pymongo.common import MAX_WRITE_BATCH_SIZE
54+
from pymongo.driver_info import DriverInfo
5455

56+
import pymongoarrow.version as pymongoarrow_version
5557
from pymongoarrow.context import PyMongoArrowContext
5658
from pymongoarrow.errors import ArrowWriteError
5759
from pymongoarrow.result import ArrowWriteResult
@@ -90,6 +92,15 @@
9092
_MAX_WRITE_BATCH_SIZE = max(100000, MAX_WRITE_BATCH_SIZE)
9193

9294

95+
def _add_driver_metadata(collection: Collection):
96+
client = collection.database.client
97+
98+
if callable(client.append_metadata):
99+
client.append_metadata(
100+
DriverInfo(name="PyMongoArrow", version=pymongoarrow_version.__version__)
101+
)
102+
103+
93104
def find_arrow_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
94105
"""Method that returns the results of a find query as a
95106
:class:`pyarrow.Table` instance.
@@ -110,6 +121,7 @@ def find_arrow_all(collection, query, *, schema=None, allow_invalid=False, **kwa
110121
:Returns:
111122
An instance of class:`pyarrow.Table`.
112123
"""
124+
_add_driver_metadata(collection)
113125
context = PyMongoArrowContext(
114126
schema, codec_options=collection.codec_options, allow_invalid=allow_invalid
115127
)
@@ -152,6 +164,7 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, allow_invalid=Fals
152164
:Returns:
153165
An instance of class:`pyarrow.Table`.
154166
"""
167+
_add_driver_metadata(collection)
155168
context = PyMongoArrowContext(
156169
schema, codec_options=collection.codec_options, allow_invalid=allow_invalid
157170
)
@@ -570,6 +583,7 @@ def write(
570583
else:
571584
type_registry = TypeRegistry([*base_codecs, _DecimalCodec()])
572585
codec_options = codec_options.with_options(type_registry=type_registry)
586+
_add_driver_metadata(collection)
573587

574588
while cur_offset < tab_size:
575589
cur_size = 0

bindings/python/test/test_arrow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from pymongo.collection import Collection
5454
from pytz import timezone
5555

56+
import pymongoarrow.version as pymongoarrow_version
5657
from pymongoarrow.api import Schema, aggregate_arrow_all, find_arrow_all, write
5758
from pymongoarrow.errors import ArrowWriteError
5859
from pymongoarrow.monkey import patch_all
@@ -1256,6 +1257,14 @@ def run_test():
12561257
for future in futures:
12571258
future.result()
12581259

1260+
def test_driver_metadata(self):
1261+
self.run_find({}, schema=self.schema)
1262+
1263+
metadata = self.coll.database.client.options.pool_options.metadata
1264+
1265+
self.assertIn("PyMongoArrow", metadata["driver"]["name"])
1266+
self.assertIn(pymongoarrow_version.__version__, metadata["driver"]["version"])
1267+
12591268

12601269
class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
12611270
def run_find(self, *args, **kwargs):

0 commit comments

Comments
 (0)