Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions bindings/python/pymongoarrow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@
)
from pymongo.collection import Collection
from pymongo.common import MAX_WRITE_BATCH_SIZE
from pymongo.driver_info import DriverInfo

import pymongoarrow.version as pymongoarrow_version
from pymongoarrow.context import PyMongoArrowContext
from pymongoarrow.errors import ArrowWriteError
from pymongoarrow.result import ArrowWriteResult
Expand Down Expand Up @@ -90,6 +92,15 @@
_MAX_WRITE_BATCH_SIZE = max(100000, MAX_WRITE_BATCH_SIZE)


def _add_driver_metadata(collection: Collection):
client = collection.database.client

if callable(client.append_metadata):
client.append_metadata(
DriverInfo(name="PyMongoArrow", version=pymongoarrow_version.__version__)
)


def find_arrow_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of a find query as a
:class:`pyarrow.Table` instance.
Expand All @@ -110,6 +121,7 @@ def find_arrow_all(collection, query, *, schema=None, allow_invalid=False, **kwa
:Returns:
An instance of class:`pyarrow.Table`.
"""
_add_driver_metadata(collection)
context = PyMongoArrowContext(
schema, codec_options=collection.codec_options, allow_invalid=allow_invalid
)
Expand Down Expand Up @@ -152,6 +164,7 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, allow_invalid=Fals
:Returns:
An instance of class:`pyarrow.Table`.
"""
_add_driver_metadata(collection)
context = PyMongoArrowContext(
schema, codec_options=collection.codec_options, allow_invalid=allow_invalid
)
Expand Down Expand Up @@ -570,6 +583,7 @@ def write(
else:
type_registry = TypeRegistry([*base_codecs, _DecimalCodec()])
codec_options = codec_options.with_options(type_registry=type_registry)
_add_driver_metadata(collection)

while cur_offset < tab_size:
cur_size = 0
Expand Down
9 changes: 9 additions & 0 deletions bindings/python/test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pymongo.collection import Collection
from pytz import timezone

import pymongoarrow.version as pymongoarrow_version
from pymongoarrow.api import Schema, aggregate_arrow_all, find_arrow_all, write
from pymongoarrow.errors import ArrowWriteError
from pymongoarrow.monkey import patch_all
Expand Down Expand Up @@ -1256,6 +1257,14 @@ def run_test():
for future in futures:
future.result()

def test_driver_metadata(self):
self.run_find({}, schema=self.schema)

metadata = self.coll.database.client.options.pool_options.metadata

self.assertIn("PyMongoArrow", metadata["driver"]["name"])
self.assertIn(pymongoarrow_version.__version__, metadata["driver"]["version"])


class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
def run_find(self, *args, **kwargs):
Expand Down
Loading