11#!/usr/bin/env python
22# Copyright 2024 NetBox Labs Inc
33"""NetBox Labs, Diode - SDK - Client."""
4+
45import collections
6+ import http .client
7+ import json
58import logging
69import os
710import platform
11+ import ssl
812import uuid
913from collections .abc import Iterable
10- from urllib .parse import urlparse
14+ from urllib .parse import urlencode , urlparse
1115
1216import certifi
1317import grpc
1822from netboxlabs .diode .sdk .ingester import Entity
1923from netboxlabs .diode .sdk .version import version_semver
2024
21- _DIODE_API_KEY_ENVVAR_NAME = "DIODE_API_KEY "
25+ _MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES "
2226_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
2327_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
28+ _CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID"
29+ _CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET"
2430_DEFAULT_STREAM = "latest"
2531_LOGGER = logging .getLogger (__name__ )
2632
@@ -31,17 +37,6 @@ def _load_certs() -> bytes:
3137 return f .read ()
3238
3339
34- def _get_api_key (api_key : str | None = None ) -> str :
35- """Get API Key either from provided value or environment variable."""
36- if api_key is None :
37- api_key = os .getenv (_DIODE_API_KEY_ENVVAR_NAME )
38- if api_key is None :
39- raise DiodeConfigError (
40- f"api_key param or { _DIODE_API_KEY_ENVVAR_NAME } environment variable required"
41- )
42- return api_key
43-
44-
4540def parse_target (target : str ) -> tuple [str , str , bool ]:
4641 """Parse the target into authority, path and tls_verify."""
4742 parsed_target = urlparse (target )
@@ -66,6 +61,21 @@ def _get_sentry_dsn(sentry_dsn: str | None = None) -> str | None:
6661 return sentry_dsn
6762
6863
64+ def _get_required_config_value (env_var_name : str , value : str | None = None ) -> str :
65+ """Get required config value either from provided value or environment variable."""
66+ if value is None :
67+ value = os .getenv (env_var_name )
68+ if value is None :
69+ raise DiodeConfigError (f"parameter or { env_var_name } environment variable required" )
70+ return value
71+
72+ def _get_optional_config_value (env_var_name : str , value : str | None = None ) -> str | None :
73+ """Get optional config value either from provided value or environment variable."""
74+ if value is None :
75+ value = os .getenv (env_var_name )
76+ return value
77+
78+
6979class DiodeClient :
7080 """Diode Client."""
7181
@@ -81,30 +91,40 @@ def __init__(
8191 target : str ,
8292 app_name : str ,
8393 app_version : str ,
84- api_key : str | None = None ,
94+ client_id : str | None = None ,
95+ client_secret : str | None = None ,
8596 sentry_dsn : str = None ,
8697 sentry_traces_sample_rate : float = 1.0 ,
8798 sentry_profiles_sample_rate : float = 1.0 ,
99+ max_auth_retries : int = 3 ,
88100 ):
89101 """Initiate a new client."""
90102 log_level = os .getenv (_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME , "INFO" ).upper ()
91103 logging .basicConfig (level = log_level )
92104
105+ self ._max_auth_retries = _get_optional_config_value (_MAX_RETRIES_ENVVAR_NAME , max_auth_retries )
93106 self ._target , self ._path , self ._tls_verify = parse_target (target )
94107 self ._app_name = app_name
95108 self ._app_version = app_version
96109 self ._platform = platform .platform ()
97110 self ._python_version = platform .python_version ()
98111
99- api_key = _get_api_key (api_key )
112+ # Read client credentials from environment variables
113+ self ._client_id = _get_required_config_value (_CLIENT_ID_ENVVAR_NAME , client_id )
114+ self ._client_secret = _get_required_config_value (_CLIENT_SECRET_ENVVAR_NAME , client_secret )
115+
100116 self ._metadata = (
101- ("diode-api-key" , api_key ),
102117 ("platform" , self ._platform ),
103118 ("python-version" , self ._python_version ),
104119 )
105120
121+ self ._authenticate ()
122+
106123 channel_opts = (
107- ("grpc.primary_user_agent" , f"{ self ._name } /{ self ._version } { self ._app_name } /{ self ._app_version } " ),
124+ (
125+ "grpc.primary_user_agent" ,
126+ f"{ self ._name } /{ self ._version } { self ._app_name } /{ self ._app_version } " ,
127+ ),
108128 )
109129
110130 if self ._tls_verify :
@@ -129,9 +149,7 @@ def __init__(
129149 _LOGGER .debug (f"Setting up gRPC interceptor for path: { self ._path } " )
130150 rpc_method_interceptor = DiodeMethodClientInterceptor (subpath = self ._path )
131151
132- intercept_channel = grpc .intercept_channel (
133- self ._channel , rpc_method_interceptor
134- )
152+ intercept_channel = grpc .intercept_channel (self ._channel , rpc_method_interceptor )
135153 channel = intercept_channel
136154
137155 self ._stub = ingester_pb2_grpc .IngesterServiceStub (channel )
@@ -140,9 +158,7 @@ def __init__(
140158
141159 if self ._sentry_dsn is not None :
142160 _LOGGER .debug ("Setting up Sentry" )
143- self ._setup_sentry (
144- self ._sentry_dsn , sentry_traces_sample_rate , sentry_profiles_sample_rate
145- )
161+ self ._setup_sentry (self ._sentry_dsn , sentry_traces_sample_rate , sentry_profiles_sample_rate )
146162
147163 @property
148164 def name (self ) -> str :
@@ -202,24 +218,28 @@ def ingest(
202218 stream : str | None = _DEFAULT_STREAM ,
203219 ) -> ingester_pb2 .IngestResponse :
204220 """Ingest entities."""
205- try :
206- request = ingester_pb2 .IngestRequest (
207- stream = stream ,
208- id = str (uuid .uuid4 ()),
209- entities = entities ,
210- sdk_name = self .name ,
211- sdk_version = self .version ,
212- producer_app_name = self .app_name ,
213- producer_app_version = self .app_version ,
214- )
215-
216- return self ._stub .Ingest (request , metadata = self ._metadata )
217- except grpc .RpcError as err :
218- raise DiodeClientError (err ) from err
219-
220- def _setup_sentry (
221- self , dsn : str , traces_sample_rate : float , profiles_sample_rate : float
222- ):
221+ for attempt in range (self ._max_auth_retries ):
222+ try :
223+ request = ingester_pb2 .IngestRequest (
224+ stream = stream ,
225+ id = str (uuid .uuid4 ()),
226+ entities = entities ,
227+ sdk_name = self .name ,
228+ sdk_version = self .version ,
229+ producer_app_name = self .app_name ,
230+ producer_app_version = self .app_version ,
231+ )
232+ return self ._stub .Ingest (request , metadata = self ._metadata )
233+ except grpc .RpcError as err :
234+ if err .code () == grpc .StatusCode .UNAUTHENTICATED :
235+ if attempt < self ._max_auth_retries - 1 :
236+ _LOGGER .info (f"Retrying ingestion due to UNAUTHENTICATED error, attempt { attempt + 1 } " )
237+ self ._authenticate ()
238+ continue
239+ raise DiodeClientError (err ) from err
240+ return RuntimeError ("Max retries exceeded" )
241+
242+ def _setup_sentry (self , dsn : str , traces_sample_rate : float , profiles_sample_rate : float ):
223243 sentry_sdk .init (
224244 dsn = dsn ,
225245 release = self .version ,
@@ -234,6 +254,59 @@ def _setup_sentry(
234254 sentry_sdk .set_tag ("platform" , self ._platform )
235255 sentry_sdk .set_tag ("python_version" , self ._python_version )
236256
257+ def _authenticate (self ):
258+ authentication_client = _DiodeAuthentication (self ._target , self ._path , self ._tls_verify , self ._client_id , self ._client_secret )
259+ access_token = authentication_client .authenticate ()
260+ self ._metadata = list (filter (lambda x : x [0 ] != "authorization" , self ._metadata )) + \
261+ [("authorization" , f"Bearer { access_token } " )]
262+
263+
264+ class _DiodeAuthentication :
265+ def __init__ (self , target : str , path : str , tls_verify : bool , client_id : str , client_secret : str ):
266+ self ._target = target
267+ self ._tls_verify = tls_verify
268+ self ._client_id = client_id
269+ self ._client_secret = client_secret
270+ self ._path = path
271+
272+ def authenticate (self ) -> str :
273+ """Request an OAuth2 token using client credentials and return it."""
274+ if self ._tls_verify :
275+ conn = http .client .HTTPSConnection (
276+ self ._target ,
277+ context = None if self ._tls_verify else ssl ._create_unverified_context (),
278+ )
279+ else :
280+ conn = http .client .HTTPConnection (
281+ self ._target ,
282+ )
283+ headers = {"Content-type" : "application/x-www-form-urlencoded" }
284+ data = urlencode (
285+ {
286+ "grant_type" : "client_credentials" ,
287+ "client_id" : self ._client_id ,
288+ "client_secret" : self ._client_secret ,
289+ }
290+ )
291+ url = self ._get_auth_url ()
292+ conn .request ("POST" , url , data , headers )
293+ response = conn .getresponse ()
294+ if response .status != 200 :
295+ raise DiodeConfigError (f"Failed to obtain access token: { response .reason } " )
296+ token_info = json .loads (response .read ().decode ())
297+ access_token = token_info .get ("access_token" )
298+ if not access_token :
299+ raise DiodeConfigError (f"Failed to obtain access token for client { self ._client_id } " )
300+
301+ _LOGGER .debug (f"Access token obtained for client { self ._client_id } " )
302+ return access_token
303+
304+ def _get_auth_url (self ) -> str :
305+ """Construct the authentication URL, handling trailing slashes in the path."""
306+ # Ensure the path does not have trailing slashes
307+ path = self ._path .rstrip ('/' ) if self ._path else ''
308+ return f"{ path } /auth/token"
309+
237310
238311class _ClientCallDetails (
239312 collections .namedtuple (
@@ -259,9 +332,7 @@ class _ClientCallDetails(
259332 pass
260333
261334
262- class DiodeMethodClientInterceptor (
263- grpc .UnaryUnaryClientInterceptor , grpc .StreamUnaryClientInterceptor
264- ):
335+ class DiodeMethodClientInterceptor (grpc .UnaryUnaryClientInterceptor , grpc .StreamUnaryClientInterceptor ):
265336 """
266337 Diode Method Client Interceptor class.
267338
@@ -300,8 +371,6 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
300371 """Intercept unary unary."""
301372 return self ._intercept_call (continuation , client_call_details , request )
302373
303- def intercept_stream_unary (
304- self , continuation , client_call_details , request_iterator
305- ):
374+ def intercept_stream_unary (self , continuation , client_call_details , request_iterator ):
306375 """Intercept stream unary."""
307376 return self ._intercept_call (continuation , client_call_details , request_iterator )
0 commit comments