diff --git a/bindings/python/pymongoarrow/api.py b/bindings/python/pymongoarrow/api.py index de68c8aa..88005ff8 100644 --- a/bindings/python/pymongoarrow/api.py +++ b/bindings/python/pymongoarrow/api.py @@ -51,7 +51,9 @@ ) from pymongo.collection import Collection from pymongo.common import MAX_WRITE_BATCH_SIZE +from pymongo.driver_info import DriverInfo +import pymongoarrow.version as pymongoarrow_version from pymongoarrow.context import PyMongoArrowContext from pymongoarrow.errors import ArrowWriteError from pymongoarrow.result import ArrowWriteResult @@ -90,6 +92,15 @@ _MAX_WRITE_BATCH_SIZE = max(100000, MAX_WRITE_BATCH_SIZE) +def _add_driver_metadata(collection: Collection): + client = collection.database.client + + if callable(client.append_metadata): + client.append_metadata( + DriverInfo(name="PyMongoArrow", version=pymongoarrow_version.__version__) + ) + + def find_arrow_all(collection, query, *, schema=None, allow_invalid=False, **kwargs): """Method that returns the results of a find query as a :class:`pyarrow.Table` instance. @@ -110,6 +121,7 @@ def find_arrow_all(collection, query, *, schema=None, allow_invalid=False, **kwa :Returns: An instance of class:`pyarrow.Table`. """ + _add_driver_metadata(collection) context = PyMongoArrowContext( schema, codec_options=collection.codec_options, allow_invalid=allow_invalid ) @@ -152,6 +164,7 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, allow_invalid=Fals :Returns: An instance of class:`pyarrow.Table`. """ + _add_driver_metadata(collection) context = PyMongoArrowContext( schema, codec_options=collection.codec_options, allow_invalid=allow_invalid ) @@ -570,6 +583,7 @@ def write( else: type_registry = TypeRegistry([*base_codecs, _DecimalCodec()]) codec_options = codec_options.with_options(type_registry=type_registry) + _add_driver_metadata(collection) while cur_offset < tab_size: cur_size = 0 diff --git a/bindings/python/test/test_arrow.py b/bindings/python/test/test_arrow.py index 9e138b1b..a7af89af 100644 --- a/bindings/python/test/test_arrow.py +++ b/bindings/python/test/test_arrow.py @@ -53,6 +53,7 @@ from pymongo.collection import Collection from pytz import timezone +import pymongoarrow.version as pymongoarrow_version from pymongoarrow.api import Schema, aggregate_arrow_all, find_arrow_all, write from pymongoarrow.errors import ArrowWriteError from pymongoarrow.monkey import patch_all @@ -1256,6 +1257,14 @@ def run_test(): for future in futures: future.result() + def test_driver_metadata(self): + self.run_find({}, schema=self.schema) + + metadata = self.coll.database.client.options.pool_options.metadata + + self.assertIn("PyMongoArrow", metadata["driver"]["name"]) + self.assertIn(pymongoarrow_version.__version__, metadata["driver"]["version"]) + class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase): def run_find(self, *args, **kwargs):