2626from netboxlabs .diode .sdk .ingester import Entity
2727from netboxlabs .diode .sdk .version import version_semver
2828
29- _MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES"
30- _DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
31- _DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
3229_CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID"
3330_CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET"
31+ _DEFAULT_STREAM = "latest"
32+ _DIODE_CERT_FILE_ENVVAR_NAME = "DIODE_CERT_FILE"
33+ _DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
34+ _DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
35+ _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME = "DIODE_SKIP_TLS_VERIFY"
3436_DRY_RUN_OUTPUT_DIR_ENVVAR_NAME = "DIODE_DRY_RUN_OUTPUT_DIR"
3537_INGEST_SCOPE = "diode:ingest"
36- _DEFAULT_STREAM = "latest"
3738_LOGGER = logging .getLogger (__name__ )
38-
39+ _MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES"
3940
4041def load_dryrun_entities (file_path : str | Path ) -> Iterable [Entity ]:
4142 """Yield entities from a file with concatenated JSON messages."""
@@ -53,20 +54,35 @@ class DiodeClientInterface:
5354 pass
5455
5556
56- def _load_certs () -> bytes :
57- """Loads cacert.pem."""
58- with open (certifi .where (), "rb" ) as f :
57+ def _load_certs (cert_file : str | None = None ) -> bytes :
58+ """Loads cacert.pem or custom certificate file."""
59+ cert_path = cert_file or certifi .where ()
60+ with open (cert_path , "rb" ) as f :
5961 return f .read ()
6062
6163
64+ def _should_verify_tls (scheme : str ) -> bool :
65+ """Determine if TLS verification should be enabled based on scheme and environment variable."""
66+ # Check if scheme is insecure
67+ insecure_scheme = scheme in ["grpc" , "http" ]
68+
69+ # Check environment variable
70+ skip_tls_env = os .getenv (_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME , "" ).lower ()
71+ skip_tls_from_env = skip_tls_env in ["true" , "1" , "yes" , "on" ]
72+
73+ # TLS verification is enabled by default, disabled only for insecure schemes or env var
74+ return not (insecure_scheme or skip_tls_from_env )
75+
76+
6277def parse_target (target : str ) -> tuple [str , str , bool ]:
6378 """Parse the target into authority, path and tls_verify."""
6479 parsed_target = urlparse (target )
6580
6681 if parsed_target .scheme not in ["grpc" , "grpcs" , "http" , "https" ]:
6782 raise ValueError ("target should start with grpc://, grpcs://, http:// or https://" )
6883
69- tls_verify = parsed_target .scheme in ["grpcs" , "https" ]
84+ # Determine if TLS verification should be enabled
85+ tls_verify = _should_verify_tls (parsed_target .scheme )
7086
7187 authority = parsed_target .netloc
7288
@@ -127,15 +143,22 @@ def __init__(
127143 sentry_traces_sample_rate : float = 1.0 ,
128144 sentry_profiles_sample_rate : float = 1.0 ,
129145 max_auth_retries : int = 3 ,
146+ cert_file : str | None = None ,
130147 ):
131148 """Initiate a new client."""
132149 log_level = os .getenv (_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME , "INFO" ).upper ()
133150 logging .basicConfig (level = log_level )
134151
135- self ._max_auth_retries = _get_optional_config_value (
136- _MAX_RETRIES_ENVVAR_NAME , max_auth_retries
152+ self ._max_auth_retries = int (_get_optional_config_value (
153+ _MAX_RETRIES_ENVVAR_NAME , str (max_auth_retries )
154+ ) or max_auth_retries )
155+ self ._cert_file = _get_optional_config_value (
156+ _DIODE_CERT_FILE_ENVVAR_NAME , cert_file
137157 )
138158 self ._target , self ._path , self ._tls_verify = parse_target (target )
159+
160+ # Load certificates once if needed
161+ self ._certificates = _load_certs (self ._cert_file ) if (self ._tls_verify or self ._cert_file ) else None
139162 self ._app_name = app_name
140163 self ._app_version = app_version
141164 self ._platform = platform .platform ()
@@ -161,12 +184,12 @@ def __init__(
161184 ),
162185 )
163186
164- if self ._tls_verify :
187+ if self ._tls_verify and self . _certificates :
165188 _LOGGER .debug ("Setting up gRPC secure channel" )
166189 self ._channel = grpc .secure_channel (
167190 self ._target ,
168191 grpc .ssl_channel_credentials (
169- root_certificates = _load_certs () ,
192+ root_certificates = self . _certificates ,
170193 ),
171194 options = channel_opts ,
172195 )
@@ -304,6 +327,7 @@ def _authenticate(self, scope: str):
304327 self ._client_id ,
305328 self ._client_secret ,
306329 scope ,
330+ self ._certificates ,
307331 )
308332 access_token = authentication_client .authenticate ()
309333 self ._metadata = list (
@@ -391,20 +415,24 @@ def __init__(
391415 client_id : str ,
392416 client_secret : str ,
393417 scope : str ,
418+ certificates : bytes | None = None ,
394419 ):
395420 self ._target = target
396421 self ._tls_verify = tls_verify
397422 self ._client_id = client_id
398423 self ._client_secret = client_secret
399424 self ._path = path
400425 self ._scope = scope
426+ self ._certificates = certificates
401427
402428 def authenticate (self ) -> str :
403429 """Request an OAuth2 token using client credentials and return it."""
404- if self ._tls_verify :
430+ if self ._tls_verify and self ._certificates :
431+ context = ssl .create_default_context ()
432+ context .load_verify_locations (cadata = self ._certificates .decode ('utf-8' ))
405433 conn = http .client .HTTPSConnection (
406434 self ._target ,
407- context = None if self . _tls_verify else ssl . _create_unverified_context () ,
435+ context = context ,
408436 )
409437 else :
410438 conn = http .client .HTTPConnection (
0 commit comments