11from __future__ import annotations
22
33import base64
4+ import dataclasses
5+ import io
46import logging
57import os
68import platform
79import shlex
10+ import shutil
811import stat
912import subprocess
1013import sys
14+ import tarfile
1115from pathlib import Path
1216from typing import Any
17+ from urllib import request
1318
1419HERE = Path (__file__ ).absolute ().parent
1520ROOT = HERE .parent .parent
1621ENV_FILE = HERE / "test-env.sh"
1722DRIVERS_TOOLS = os .environ .get ("DRIVERS_TOOLS" , "" ).replace (os .sep , "/" )
23+ PLATFORM = "windows" if os .name == "nt" else sys .platform
1824
1925logging .basicConfig (level = logging .INFO )
2026LOGGER = logging .getLogger (__name__ )
7480)
7581
7682
83+ @dataclasses .dataclass
84+ class Distro :
85+ name : str
86+ version_id : str
87+ arch : str
88+
89+
7790def write_env (name : str , value : Any ) -> None :
7891 with ENV_FILE .open ("a" , newline = "\n " ) as fid :
7992 # Remove any existing quote chars.
@@ -92,6 +105,69 @@ def run_command(cmd: str) -> None:
92105 LOGGER .info ("Running command %s... done." , cmd )
93106
94107
108+ def get_distro () -> Distro :
109+ name = ""
110+ version_id = ""
111+ arch = platform .machine ()
112+ with open ("/etc/os-release" ) as fid :
113+ for line in fid .readlines ():
114+ line = line .replace ('"' , "" ) # noqa: PLW2901
115+ if line .startswith ("NAME=" ):
116+ _ , _ , name = line .strip ().partition ("=" )
117+ if line .startswith ("VERSION_ID=" ):
118+ _ , _ , version_id = line .strip ().partition ("=" )
119+ return Distro (name = name , version_id = version_id , arch = arch )
120+
121+
122+ def setup_libmongocrypt ():
123+ target = ""
124+ if PLATFORM == "windows" :
125+ # PYTHON-2808 Ensure this machine has the CA cert for google KMS.
126+ if is_set ("TEST_FLE_GCP_AUTO" ):
127+ run_command ('powershell.exe "Invoke-WebRequest -URI https://oauth2.googleapis.com/"' )
128+ target = "windows-test"
129+
130+ elif PLATFORM == "darwin" :
131+ target = "macos"
132+
133+ else :
134+ distro = get_distro ()
135+ if distro .name .startswith ("Debian" ):
136+ target = f"debian{ distro .version_id } "
137+ elif distro .name .startswith ("Red Hat" ):
138+ if distro .version_id .startswith ("7" ):
139+ target = "rhel-70-64-bit"
140+ elif distro .version_id .startswith ("8" ):
141+ if distro .arch == "aarch64" :
142+ target = "rhel-82-arm64"
143+ else :
144+ target = "rhel-80-64-bit"
145+
146+ if not is_set ("LIBMONGOCRYPT_URL" ):
147+ if not target :
148+ raise ValueError ("Cannot find libmongocrypt target for current platform!" )
149+ url = f"https://s3.amazonaws.com/mciuploads/libmongocrypt/{ target } /master/latest/libmongocrypt.tar.gz"
150+ else :
151+ url = os .environ ["LIBMONGOCRYPT_URL" ]
152+
153+ shutil .rmtree (HERE / "libmongocrypt" , ignore_errors = True )
154+
155+ LOGGER .info (f"Fetching { url } ..." )
156+ with request .urlopen (request .Request (url ), timeout = 15.0 ) as response : # noqa: S310
157+ if response .status == 200 :
158+ fileobj = io .BytesIO (response .read ())
159+ with tarfile .open ("libmongocrypt.tar.gz" , fileobj = fileobj ) as fid :
160+ fid .extractall (Path .cwd () / "libmongocrypt" )
161+ LOGGER .info (f"Fetching { url } ... done." )
162+
163+ run_command ("ls -la libmongocrypt" )
164+ run_command ("ls -la libmongocrypt/nocrypto" )
165+
166+ if PLATFORM == "windows" :
167+ # libmongocrypt's windows dll is not marked executable.
168+ run_command ("chmod +x libmongocrypt/nocrypto/bin/mongocrypt.dll" )
169+
170+
95171def handle_test_env () -> None :
96172 AUTH = os .environ .get ("AUTH" , "noauth" )
97173 SSL = os .environ .get ("SSL" , "nossl" )
@@ -156,7 +232,7 @@ def handle_test_env() -> None:
156232 write_env ("PYMONGO_DISABLE_TEST_COMMANDS" , "1" )
157233
158234 if is_set ("TEST_ENTERPRISE_AUTH" ):
159- if os . name == "nt " :
235+ if PLATFORM == "windows " :
160236 LOGGER .info ("Setting GSSAPI_PASS" )
161237 write_env ("GSSAPI_PASS" , os .environ ["SASL_PASS" ])
162238 write_env ("GSSAPI_CANONICALIZE" , "true" )
@@ -214,19 +290,19 @@ def handle_test_env() -> None:
214290 if is_set ("TEST_ENCRYPTION" ) or is_set ("TEST_FLE_AZURE_AUTO" ) or is_set ("TEST_FLE_GCP_AUTO" ):
215291 # Check for libmongocrypt download.
216292 if not (ROOT / "libmongocrypt" ).exists ():
217- run_command ( f"bash { HERE . as_posix () } /setup-libmongocrypt.sh" )
293+ setup_libmongocrypt ( )
218294
219295 # TODO: Test with 'pip install pymongocrypt'
220296 UV_ARGS .append ("--group pymongocrypt_source" )
221297
222298 # Use the nocrypto build to avoid dependency issues with older windows/python versions.
223299 BASE = ROOT / "libmongocrypt/nocrypto"
224- if sys . platform == "linux" :
300+ if PLATFORM == "linux" :
225301 if (BASE / "lib/libmongocrypt.so" ).exists ():
226302 PYMONGOCRYPT_LIB = BASE / "lib/libmongocrypt.so"
227303 else :
228304 PYMONGOCRYPT_LIB = BASE / "lib64/libmongocrypt.so"
229- elif sys . platform == "darwin" :
305+ elif PLATFORM == "darwin" :
230306 PYMONGOCRYPT_LIB = BASE / "lib/libmongocrypt.dylib"
231307 else :
232308 PYMONGOCRYPT_LIB = BASE / "bin/mongocrypt.dll"
@@ -244,7 +320,7 @@ def handle_test_env() -> None:
244320 if is_set ("TEST_CRYPT_SHARED" ):
245321 CRYPT_SHARED_DIR = Path (os .environ ["CRYPT_SHARED_LIB_PATH" ]).parent .as_posix ()
246322 LOGGER .info ("Using crypt_shared_dir %s" , CRYPT_SHARED_DIR )
247- if os . name == "nt " :
323+ if PLATFORM == "windows " :
248324 write_env ("PATH" , f"{ CRYPT_SHARED_DIR } :$PATH" )
249325 else :
250326 write_env (
0 commit comments