From d8baa104247f1c9f702091a2055376f6b763ad92 Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Tue, 15 Apr 2025 14:13:51 +0100 Subject: [PATCH 01/16] adds oauth2 authentication --- .gitignore | 3 +- README.md | 8 +- netboxlabs/diode/sdk/client.py | 103 ++++++++++++------ tests/test_client.py | 193 +++++++++++++++++++-------------- 4 files changed, 189 insertions(+), 118 deletions(-) diff --git a/.gitignore b/.gitignore index 257791d..65f0a15 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,5 @@ __pycache__/ build/ dist/ .eggs/ -*.egg-info \ No newline at end of file +*.egg-info +poetry.lock diff --git a/README.md b/README.md index 53cac60..c181a55 100644 --- a/README.md +++ b/README.md @@ -20,9 +20,10 @@ pip install netboxlabs-diode-sdk ### Environment variables -* `DIODE_API_KEY` - API key for the Diode service * `DIODE_SDK_LOG_LEVEL` - Log level for the SDK (default: `INFO`) * `DIODE_SENTRY_DSN` - Optional Sentry DSN for error reporting +* `DIODE_CLIENT_ID` - Client ID for OAuth2 authentication +* `DIODE_CLIENT_SECRET` - Client Secret for OAuth2 authentication ### Example @@ -94,8 +95,7 @@ if __name__ == "__main__": ## Development notes -Code in `netboxlabs/diode/sdk/diode/*` is generated from Protocol Buffers definitions (will be published and referred -here soon). +Code in `netboxlabs/diode/sdk/diode/*` is generated from Protocol Buffers definitions (will be published and referenced here soon). #### Linting @@ -107,7 +107,7 @@ black netboxlabs/ #### Testing ```shell -pytest tests/ +PYTHONPATH=$(pwd) pytest ``` ## License diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 44049c8..d528d04 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -1,13 +1,17 @@ #!/usr/bin/env python # Copyright 2024 NetBox Labs Inc """NetBox Labs, Diode - SDK - Client.""" + import collections +from collections.abc import Iterable +import http.client +import json import logging import os import platform +import ssl +from urllib.parse import urlparse, urlencode import uuid -from collections.abc import Iterable -from urllib.parse import urlparse import certifi import grpc @@ -18,9 +22,10 @@ from netboxlabs.diode.sdk.ingester import Entity from netboxlabs.diode.sdk.version import version_semver -_DIODE_API_KEY_ENVVAR_NAME = "DIODE_API_KEY" _DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL" _DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN" +_CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID" +_CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET" _DEFAULT_STREAM = "latest" _LOGGER = logging.getLogger(__name__) @@ -31,17 +36,6 @@ def _load_certs() -> bytes: return f.read() -def _get_api_key(api_key: str | None = None) -> str: - """Get API Key either from provided value or environment variable.""" - if api_key is None: - api_key = os.getenv(_DIODE_API_KEY_ENVVAR_NAME) - if api_key is None: - raise DiodeConfigError( - f"api_key param or {_DIODE_API_KEY_ENVVAR_NAME} environment variable required" - ) - return api_key - - def parse_target(target: str) -> tuple[str, str, bool]: """Parse the target into authority, path and tls_verify.""" parsed_target = urlparse(target) @@ -66,6 +60,15 @@ def _get_sentry_dsn(sentry_dsn: str | None = None) -> str | None: return sentry_dsn +def _get_required_config_value(env_var_name: str, value: str | None = None) -> str: + """Get required config value either from provided value or environment variable.""" + if value is None: + value = os.getenv(env_var_name) + if value is None: + raise DiodeConfigError(f"parameter or {env_var_name} environment variable required") + return value + + class DiodeClient: """Diode Client.""" @@ -81,7 +84,8 @@ def __init__( target: str, app_name: str, app_version: str, - api_key: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, sentry_dsn: str = None, sentry_traces_sample_rate: float = 1.0, sentry_profiles_sample_rate: float = 1.0, @@ -96,15 +100,23 @@ def __init__( self._platform = platform.platform() self._python_version = platform.python_version() - api_key = _get_api_key(api_key) + # Read client credentials from environment variables + client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id) + client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret) + + authentication_client = _DiodeAuthentication(self._target, self._tls_verify, client_id, client_secret) + access_token = authentication_client.authenticate() self._metadata = ( - ("diode-api-key", api_key), ("platform", self._platform), ("python-version", self._python_version), + ("authorization", f"Bearer {access_token}"), ) channel_opts = ( - ("grpc.primary_user_agent", f"{self._name}/{self._version} {self._app_name}/{self._app_version}"), + ( + "grpc.primary_user_agent", + f"{self._name}/{self._version} {self._app_name}/{self._app_version}", + ), ) if self._tls_verify: @@ -129,9 +141,7 @@ def __init__( _LOGGER.debug(f"Setting up gRPC interceptor for path: {self._path}") rpc_method_interceptor = DiodeMethodClientInterceptor(subpath=self._path) - intercept_channel = grpc.intercept_channel( - self._channel, rpc_method_interceptor - ) + intercept_channel = grpc.intercept_channel(self._channel, rpc_method_interceptor) channel = intercept_channel self._stub = ingester_pb2_grpc.IngesterServiceStub(channel) @@ -140,9 +150,7 @@ def __init__( if self._sentry_dsn is not None: _LOGGER.debug("Setting up Sentry") - self._setup_sentry( - self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate - ) + self._setup_sentry(self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate) @property def name(self) -> str: @@ -212,14 +220,11 @@ def ingest( producer_app_name=self.app_name, producer_app_version=self.app_version, ) - return self._stub.Ingest(request, metadata=self._metadata) except grpc.RpcError as err: raise DiodeClientError(err) from err - def _setup_sentry( - self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float - ): + def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float): sentry_sdk.init( dsn=dsn, release=self.version, @@ -235,6 +240,40 @@ def _setup_sentry( sentry_sdk.set_tag("python_version", self._python_version) +class _DiodeAuthentication: + def __init__(self, target: str, tls_verify: bool, client_id: str, client_secret: str): + self._target = target + self._tls_verify = tls_verify + self._client_id = client_id + self._client_secret = client_secret + + def authenticate(self) -> str: + """Request an OAuth2 token using client credentials and return it.""" + conn = http.client.HTTPSConnection( + self._target, + context=None if self._tls_verify else ssl._create_unverified_context(), + ) + headers = {"Content-type": "application/x-www-form-urlencoded"} + data = urlencode( + { + "grant_type": "client_credentials", + "client_id": self._client_id, + "client_secret": self._client_secret, + } + ) + conn.request("POST", "/token", data, headers) + response = conn.getresponse() + if response.status != 200: + raise DiodeConfigError(f"Failed to obtain access token: {response.reason}") + token_info = json.loads(response.read().decode()) + access_token = token_info.get("access_token") + if not access_token: + raise DiodeConfigError(f"Failed to obtain access token for client {self._client_id}") + + _LOGGER.debug(f"Access token obtained for client {self._client_id}") + return access_token + + class _ClientCallDetails( collections.namedtuple( "_ClientCallDetails", @@ -259,9 +298,7 @@ class _ClientCallDetails( pass -class DiodeMethodClientInterceptor( - grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor -): +class DiodeMethodClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor): """ Diode Method Client Interceptor class. @@ -300,8 +337,6 @@ def intercept_unary_unary(self, continuation, client_call_details, request): """Intercept unary unary.""" return self._intercept_call(continuation, client_call_details, request) - def intercept_stream_unary( - self, continuation, client_call_details, request_iterator - ): + def intercept_stream_unary(self, continuation, client_call_details, request_iterator): """Intercept stream unary.""" return self._intercept_call(continuation, client_call_details, request_iterator) diff --git a/tests/test_client.py b/tests/test_client.py index 02a319c..45be386 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,19 +1,19 @@ #!/usr/bin/env python # Copyright 2024 NetBox Labs Inc """NetBox Labs - Tests.""" + import os from unittest import mock +from unittest.mock import patch, MagicMock import grpc import pytest from netboxlabs.diode.sdk.client import ( - _DIODE_API_KEY_ENVVAR_NAME, _DIODE_SENTRY_DSN_ENVVAR_NAME, DiodeClient, DiodeMethodClientInterceptor, _ClientCallDetails, - _get_api_key, _get_sentry_dsn, _load_certs, parse_target, @@ -22,13 +22,14 @@ from netboxlabs.diode.sdk.version import version_semver -def test_init(): +def test_init(mock_diode_authentication): """Check we can initiate a client configuration.""" config = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) assert config.target == "localhost:8081" assert config.name == "diode-sdk-python" @@ -39,25 +40,36 @@ def test_init(): assert config.path == "" -def test_config_error(): +@pytest.mark.parametrize( + "client_id,client_secret,env_var_name", + [ + (None, "123", "DIODE_CLIENT_ID"), + ("123", None, "DIODE_CLIENT_SECRET"), + (None, None, "DIODE_CLIENT_ID"), + ], +) +def test_config_errors(client_id, client_secret, env_var_name): """Check we can raise a config error.""" with pytest.raises(DiodeConfigError) as err: DiodeClient( - target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1" + target="grpc://localhost:8081", + app_name="my-producer", + app_version="0.0.1", + client_id=client_id, + client_secret=client_secret, ) - assert ( - str(err.value) == "api_key param or DIODE_API_KEY environment variable required" - ) + assert str(err.value) == f"parameter or {env_var_name} environment variable required" -def test_client_error(): +def test_client_error(mock_diode_authentication): """Check we can raise a client error.""" with pytest.raises(DiodeClientError) as err: client = DiodeClient( target="grpc://invalid:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) client.ingest(entities=[]) assert err.value.status_code == grpc.StatusCode.UNAVAILABLE @@ -72,10 +84,7 @@ def test_diode_client_error_repr_returns_correct_string(): error = DiodeClientError(grpc_error) error._status_code = grpc.StatusCode.UNAVAILABLE error._details = "Some details about the error" - assert ( - repr(error) - == "" - ) + assert repr(error) == "" def test_load_certs_returns_bytes(): @@ -83,26 +92,6 @@ def test_load_certs_returns_bytes(): assert isinstance(_load_certs(), bytes) -def test_get_api_key_returns_env_var_when_no_input(): - """Check that _get_api_key returns the env var when no input is provided.""" - os.environ[_DIODE_API_KEY_ENVVAR_NAME] = "env_var_key" - assert _get_api_key() == "env_var_key" - - -def test_get_api_key_returns_input_when_provided(): - """Check that _get_api_key returns the input when provided.""" - os.environ[_DIODE_API_KEY_ENVVAR_NAME] = "env_var_key" - assert _get_api_key("input_key") == "input_key" - - -def test_get_api_key_raises_error_when_no_input_or_env_var(): - """Check that _get_api_key raises an error when no input or env var is provided.""" - if _DIODE_API_KEY_ENVVAR_NAME in os.environ: - del os.environ[_DIODE_API_KEY_ENVVAR_NAME] - with pytest.raises(DiodeConfigError): - _get_api_key() - - def test_parse_target_handles_http_prefix(): """Check that parse_target raises an error when the target contains http://.""" with pytest.raises(ValueError): @@ -166,13 +155,14 @@ def test_get_sentry_dsn_returns_none_when_no_input_or_env_var(): assert _get_sentry_dsn() is None -def test_setup_sentry_initializes_with_correct_parameters(): +def test_setup_sentry_initializes_with_correct_parameters(mock_diode_authentication): """Check that DiodeClient._setup_sentry() initializes with the correct parameters.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with mock.patch("sentry_sdk.init") as mock_init: client._setup_sentry("https://user@password.mock.dsn/123456", 0.5, 0.5) @@ -184,13 +174,14 @@ def test_setup_sentry_initializes_with_correct_parameters(): ) -def test_client_sets_up_secure_channel_when_grpcs_scheme_is_found_in_target(): +def test_client_sets_up_secure_channel_when_grpcs_scheme_is_found_in_target(mock_diode_authentication): """Check that DiodeClient.__init__() sets up the gRPC secure channel when grpcs:// scheme is found in the target.""" client = DiodeClient( target="grpcs://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with ( mock.patch("grpc.secure_channel") as mock_secure_channel, @@ -200,20 +191,22 @@ def test_client_sets_up_secure_channel_when_grpcs_scheme_is_found_in_target(): target="grpcs://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_debug.assert_called_once_with("Setting up gRPC secure channel") mock_secure_channel.assert_called_once() -def test_client_sets_up_insecure_channel_when_grpc_scheme_is_found_in_target(): +def test_client_sets_up_insecure_channel_when_grpc_scheme_is_found_in_target(mock_diode_authentication): """Check that DiodeClient.__init__() sets up the gRPC insecure channel when grpc:// scheme is found in the target.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with ( mock.patch("grpc.insecure_channel") as mock_insecure_channel, @@ -223,7 +216,8 @@ def test_client_sets_up_insecure_channel_when_grpc_scheme_is_found_in_target(): target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_debug.assert_called_with( @@ -232,14 +226,15 @@ def test_client_sets_up_insecure_channel_when_grpc_scheme_is_found_in_target(): mock_insecure_channel.assert_called_once() -def test_insecure_channel_options_with_primary_user_agent(): +def test_insecure_channel_options_with_primary_user_agent(mock_diode_authentication): """Check that DiodeClient.__init__() sets the gRPC primary_user_agent option for insecure channel.""" with mock.patch("grpc.insecure_channel") as mock_insecure_channel: client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_insecure_channel.assert_called_once() @@ -252,14 +247,15 @@ def test_insecure_channel_options_with_primary_user_agent(): ) -def test_secure_channel_options_with_primary_user_agent(): +def test_secure_channel_options_with_primary_user_agent(mock_diode_authentication): """Check that DiodeClient.__init__() sets the gRPC primary_user_agent option for secure channel.""" with mock.patch("grpc.secure_channel") as mock_secure_channel: client = DiodeClient( target="grpcs://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_secure_channel.assert_called_once() @@ -272,13 +268,14 @@ def test_secure_channel_options_with_primary_user_agent(): ) -def test_client_interceptor_setup_with_path(): +def test_client_interceptor_setup_with_path(mock_diode_authentication): """Check that DiodeClient.__init__() sets up the gRPC interceptor when a path is provided.""" client = DiodeClient( target="grpc://localhost:8081/my-path", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with ( mock.patch("grpc.intercept_channel") as mock_intercept_channel, @@ -288,7 +285,8 @@ def test_client_interceptor_setup_with_path(): target="grpc://localhost:8081/my-path", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_debug.assert_called_with( @@ -297,13 +295,14 @@ def test_client_interceptor_setup_with_path(): mock_intercept_channel.assert_called_once() -def test_client_interceptor_not_setup_without_path(): +def test_client_interceptor_not_setup_without_path(mock_diode_authentication): """Check that DiodeClient.__init__() does not set up the gRPC interceptor when no path is provided.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with ( mock.patch("grpc.intercept_channel") as mock_intercept_channel, @@ -313,7 +312,8 @@ def test_client_interceptor_not_setup_without_path(): target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_debug.assert_called_with( @@ -322,13 +322,14 @@ def test_client_interceptor_not_setup_without_path(): mock_intercept_channel.assert_not_called() -def test_client_setup_sentry_called_when_sentry_dsn_exists(): +def test_client_setup_sentry_called_when_sentry_dsn_exists(mock_diode_authentication): """Check that DiodeClient._setup_sentry() is called when sentry_dsn exists.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", sentry_dsn="https://user@password.mock.dsn/123456", ) with mock.patch.object(client, "_setup_sentry") as mock_setup_sentry: @@ -336,39 +337,41 @@ def test_client_setup_sentry_called_when_sentry_dsn_exists(): target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", sentry_dsn="https://user@password.mock.dsn/123456", ) - mock_setup_sentry.assert_called_once_with( - "https://user@password.mock.dsn/123456", 1.0, 1.0 - ) + mock_setup_sentry.assert_called_once_with("https://user@password.mock.dsn/123456", 1.0, 1.0) -def test_client_setup_sentry_not_called_when_sentry_dsn_not_exists(): +def test_client_setup_sentry_not_called_when_sentry_dsn_not_exists(mock_diode_authentication): """Check that DiodeClient._setup_sentry() is not called when sentry_dsn does not exist.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with mock.patch.object(client, "_setup_sentry") as mock_setup_sentry: client.__init__( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_setup_sentry.assert_not_called() -def test_client_properties_return_expected_values(): +def test_client_properties_return_expected_values(mock_diode_authentication): """Check that DiodeClient properties return the expected values.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) assert client.name == "diode-sdk-python" assert client.version == version_semver() @@ -380,50 +383,54 @@ def test_client_properties_return_expected_values(): assert isinstance(client.channel, grpc.Channel) -def test_client_enter_returns_self(): +def test_client_enter_returns_self(mock_diode_authentication): """Check that DiodeClient.__enter__() returns self.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) assert client.__enter__() is client -def test_client_exit_closes_channel(): +def test_client_exit_closes_channel(mock_diode_authentication): """Check that DiodeClient.__exit__() closes the channel.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with mock.patch.object(client._channel, "close") as mock_close: client.__exit__(None, None, None) mock_close.assert_called_once() -def test_client_close_closes_channel(): +def test_client_close_closes_channel(mock_diode_authentication): """Check that DiodeClient.close() closes the channel.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with mock.patch.object(client._channel, "close") as mock_close: client.close() mock_close.assert_called_once() -def test_setup_sentry_sets_correct_tags(): +def test_setup_sentry_sets_correct_tags(mock_diode_authentication): """Check that DiodeClient._setup_sentry() sets the correct tags.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with mock.patch("sentry_sdk.set_tag") as mock_set_tag: client._setup_sentry("https://user@password.mock.dsn/123456", 0.5, 0.5) @@ -458,10 +465,7 @@ def continuation(x, _): None, ) request = None - assert ( - interceptor.intercept_unary_unary(continuation, client_call_details, request) - == "/my/path/diode.v1.IngesterService/Ingest" - ) + assert interceptor.intercept_unary_unary(continuation, client_call_details, request) == "/my/path/diode.v1.IngesterService/Ingest" def test_interceptor_intercepts_stream_unary_calls(): @@ -481,8 +485,39 @@ def continuation(x, _): ) request_iterator = None assert ( - interceptor.intercept_stream_unary( - continuation, client_call_details, request_iterator - ) + interceptor.intercept_stream_unary(continuation, client_call_details, request_iterator) == "/my/path/diode.v1.IngesterService/Ingest" ) + + +@pytest.fixture +def mock_diode_authentication(): + """ + Fixture to mock the Diode authentication process. + + This mock replaces the _DiodeAuthentication class with a mock object + that returns a mocked token for authentication. + """ + with patch("netboxlabs.diode.sdk.client._DiodeAuthentication") as MockAuth: + mock_instance = MockAuth.return_value + mock_instance.authenticate.return_value = "mocked_token" + yield MockAuth + + +def test_diode_client_with_mocked_authentication(mock_diode_authentication): + """ + Test the DiodeClient initialization with mocked authentication. + + This test verifies that the client is initialized correctly with the mocked + authentication token and that the metadata includes the expected platform + and authorization headers. + """ + client = DiodeClient( + target="grpc://localhost:8080/diode", + app_name="my-test-app", + app_version="0.0.1", + client_id="test_client_id", + client_secret="test_client_secret", + ) + assert client._metadata[0] == ("platform", client._platform) + assert client._metadata[-1] == ("authorization", "Bearer mocked_token") \ No newline at end of file From 6b1df2542633ae987217017c59216976275c7952 Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Tue, 15 Apr 2025 14:17:50 +0100 Subject: [PATCH 02/16] fixes PR labeller --- .github/workflows/labeler.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/labeler.yaml b/.github/workflows/labeler.yaml index f4687be..8c1fa95 100644 --- a/.github/workflows/labeler.yaml +++ b/.github/workflows/labeler.yaml @@ -11,6 +11,7 @@ jobs: permissions: contents: read pull-requests: write + issues: write runs-on: ubuntu-latest timeout-minutes: 5 steps: From 9907b73037212d31c9e902231155188f99cd6f3f Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Tue, 15 Apr 2025 14:31:27 +0100 Subject: [PATCH 03/16] linting --- netboxlabs/diode/sdk/client.py | 4 ++-- tests/test_client.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index d528d04..40d255f 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -3,15 +3,15 @@ """NetBox Labs, Diode - SDK - Client.""" import collections -from collections.abc import Iterable import http.client import json import logging import os import platform import ssl -from urllib.parse import urlparse, urlencode import uuid +from collections.abc import Iterable +from urllib.parse import urlencode, urlparse import certifi import grpc diff --git a/tests/test_client.py b/tests/test_client.py index 45be386..8acd9e9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,7 +4,7 @@ import os from unittest import mock -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import grpc import pytest @@ -520,4 +520,4 @@ def test_diode_client_with_mocked_authentication(mock_diode_authentication): client_secret="test_client_secret", ) assert client._metadata[0] == ("platform", client._platform) - assert client._metadata[-1] == ("authorization", "Bearer mocked_token") \ No newline at end of file + assert client._metadata[-1] == ("authorization", "Bearer mocked_token") From 1e0d9c54ef8e276e6543e176a3504b8404f10e7b Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Tue, 15 Apr 2025 15:36:38 +0100 Subject: [PATCH 04/16] adds retries when 401 response --- netboxlabs/diode/sdk/client.py | 65 ++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 40d255f..425ea6d 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -101,17 +101,17 @@ def __init__( self._python_version = platform.python_version() # Read client credentials from environment variables - client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id) - client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret) + self._client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id) + self._client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret) + - authentication_client = _DiodeAuthentication(self._target, self._tls_verify, client_id, client_secret) - access_token = authentication_client.authenticate() self._metadata = ( ("platform", self._platform), ("python-version", self._python_version), - ("authorization", f"Bearer {access_token}"), ) + self._authenticate() + channel_opts = ( ( "grpc.primary_user_agent", @@ -210,19 +210,27 @@ def ingest( stream: str | None = _DEFAULT_STREAM, ) -> ingester_pb2.IngestResponse: """Ingest entities.""" - try: - request = ingester_pb2.IngestRequest( - stream=stream, - id=str(uuid.uuid4()), - entities=entities, - sdk_name=self.name, - sdk_version=self.version, - producer_app_name=self.app_name, - producer_app_version=self.app_version, - ) - return self._stub.Ingest(request, metadata=self._metadata) - except grpc.RpcError as err: - raise DiodeClientError(err) from err + max_retries = 3 + for attempt in range(max_retries): + try: + request = ingester_pb2.IngestRequest( + stream=stream, + id=str(uuid.uuid4()), + entities=entities, + sdk_name=self.name, + sdk_version=self.version, + producer_app_name=self.app_name, + producer_app_version=self.app_version, + ) + return self._stub.Ingest(request, metadata=self._metadata) + except grpc.RpcError as err: + if err.code() == grpc.StatusCode.UNAUTHENTICATED: + self._authenticate() + if attempt < max_retries - 1: + _LOGGER.info(f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}") + continue + raise DiodeClientError(err) from err + return None # should never hit this but it makes the linter happy def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float): sentry_sdk.init( @@ -239,6 +247,12 @@ def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rat sentry_sdk.set_tag("platform", self._platform) sentry_sdk.set_tag("python_version", self._python_version) + def _authenticate(self): + authentication_client = _DiodeAuthentication(self._target, self._tls_verify, self._client_id, self._client_secret) + access_token = authentication_client.authenticate() + self._metadata = list(filter(lambda x: x[0] != "authorization", self._metadata)) + \ + [("authorization", f"Bearer {access_token}")] + class _DiodeAuthentication: def __init__(self, target: str, tls_verify: bool, client_id: str, client_secret: str): @@ -249,10 +263,15 @@ def __init__(self, target: str, tls_verify: bool, client_id: str, client_secret: def authenticate(self) -> str: """Request an OAuth2 token using client credentials and return it.""" - conn = http.client.HTTPSConnection( - self._target, - context=None if self._tls_verify else ssl._create_unverified_context(), - ) + if self._tls_verify: + conn = http.client.HTTPSConnection( + self._target, + context=None if self._tls_verify else ssl._create_unverified_context(), + ) + else: + conn = http.client.HTTPConnection( + self._target, + ) headers = {"Content-type": "application/x-www-form-urlencoded"} data = urlencode( { @@ -261,7 +280,7 @@ def authenticate(self) -> str: "client_secret": self._client_secret, } ) - conn.request("POST", "/token", data, headers) + conn.request("POST", "/diode/auth/token", data, headers) response = conn.getresponse() if response.status != 200: raise DiodeConfigError(f"Failed to obtain access token: {response.reason}") From 6c1a76ea3df5837e9da7234d87d3be0d904549fc Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Tue, 15 Apr 2025 15:39:28 +0100 Subject: [PATCH 05/16] adds config for auth retries --- netboxlabs/diode/sdk/client.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 425ea6d..ab267f2 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -22,6 +22,7 @@ from netboxlabs.diode.sdk.ingester import Entity from netboxlabs.diode.sdk.version import version_semver +_MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES" _DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL" _DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN" _CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID" @@ -94,6 +95,7 @@ def __init__( log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper() logging.basicConfig(level=log_level) + self._max_auth_retries = os.getenv(_MAX_RETRIES_ENVVAR_NAME, 3) self._target, self._path, self._tls_verify = parse_target(target) self._app_name = app_name self._app_version = app_version @@ -210,8 +212,7 @@ def ingest( stream: str | None = _DEFAULT_STREAM, ) -> ingester_pb2.IngestResponse: """Ingest entities.""" - max_retries = 3 - for attempt in range(max_retries): + for attempt in range(self._max_auth_retries): try: request = ingester_pb2.IngestRequest( stream=stream, @@ -226,7 +227,7 @@ def ingest( except grpc.RpcError as err: if err.code() == grpc.StatusCode.UNAUTHENTICATED: self._authenticate() - if attempt < max_retries - 1: + if attempt < self._max_auth_retries - 1: _LOGGER.info(f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}") continue raise DiodeClientError(err) from err From 6b4ceea3aa4435699bfa68c74b0d9f0e844a0a95 Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Tue, 15 Apr 2025 15:52:24 +0100 Subject: [PATCH 06/16] add write access to contents for labeller action --- .github/workflows/labeler.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/labeler.yaml b/.github/workflows/labeler.yaml index 8c1fa95..0a05de4 100644 --- a/.github/workflows/labeler.yaml +++ b/.github/workflows/labeler.yaml @@ -9,7 +9,7 @@ concurrency: jobs: triage: permissions: - contents: read + contents: write pull-requests: write issues: write runs-on: ubuntu-latest From 24e5c2c98ce6263b3d29e10e75e4afc9c5ded5d2 Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Tue, 15 Apr 2025 20:36:48 +0100 Subject: [PATCH 07/16] reverts action permissions --- .github/workflows/labeler.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/labeler.yaml b/.github/workflows/labeler.yaml index 0a05de4..f4687be 100644 --- a/.github/workflows/labeler.yaml +++ b/.github/workflows/labeler.yaml @@ -9,9 +9,8 @@ concurrency: jobs: triage: permissions: - contents: write + contents: read pull-requests: write - issues: write runs-on: ubuntu-latest timeout-minutes: 5 steps: From 0f87d1f343d63129a0ba0af8e8f3ef9f585e2c4d Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Tue, 15 Apr 2025 20:39:39 +0100 Subject: [PATCH 08/16] lock pytest-cov to 6.0.0 to avoid output parsing issue --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3388bec..ae49496 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ [project.optional-dependencies] # Optional dev = ["black", "check-manifest", "ruff"] -test = ["coverage", "pytest", "pytest-cov"] +test = ["coverage", "pytest", "pytest-cov==6.0.0"] [tool.coverage.run] omit = [ From 0197fc2c7b5140e35061e0090e0b4dd133c9ab51 Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Wed, 16 Apr 2025 10:18:17 +0100 Subject: [PATCH 09/16] refactors error handling --- netboxlabs/diode/sdk/client.py | 10 +++++----- tests/test_client.py | 28 +++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index ab267f2..178fb65 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -226,12 +226,11 @@ def ingest( return self._stub.Ingest(request, metadata=self._metadata) except grpc.RpcError as err: if err.code() == grpc.StatusCode.UNAUTHENTICATED: - self._authenticate() if attempt < self._max_auth_retries - 1: _LOGGER.info(f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}") + self._authenticate() continue raise DiodeClientError(err) from err - return None # should never hit this but it makes the linter happy def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float): sentry_sdk.init( @@ -249,18 +248,19 @@ def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rat sentry_sdk.set_tag("python_version", self._python_version) def _authenticate(self): - authentication_client = _DiodeAuthentication(self._target, self._tls_verify, self._client_id, self._client_secret) + authentication_client = _DiodeAuthentication(self._target, self._tls_verify, self._client_id, self._client_secret, self._path) access_token = authentication_client.authenticate() self._metadata = list(filter(lambda x: x[0] != "authorization", self._metadata)) + \ [("authorization", f"Bearer {access_token}")] class _DiodeAuthentication: - def __init__(self, target: str, tls_verify: bool, client_id: str, client_secret: str): + def __init__(self, target: str, path: str, tls_verify: bool, client_id: str, client_secret: str): self._target = target self._tls_verify = tls_verify self._client_id = client_id self._client_secret = client_secret + self._path = path def authenticate(self) -> str: """Request an OAuth2 token using client credentials and return it.""" @@ -281,7 +281,7 @@ def authenticate(self) -> str: "client_secret": self._client_secret, } ) - conn.request("POST", "/diode/auth/token", data, headers) + conn.request("POST", f"{self._path}/auth/token", data, headers) response = conn.getresponse() if response.status != 200: raise DiodeConfigError(f"Failed to obtain access token: {response.reason}") diff --git a/tests/test_client.py b/tests/test_client.py index 8acd9e9..4329c62 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,7 +4,7 @@ import os from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import patch, MagicMock import grpc import pytest @@ -521,3 +521,29 @@ def test_diode_client_with_mocked_authentication(mock_diode_authentication): ) assert client._metadata[0] == ("platform", client._platform) assert client._metadata[-1] == ("authorization", "Bearer mocked_token") + +def test_ingest_retries_on_unauthenticated_error(mock_diode_authentication): + """Test that the ingest method retries on UNAUTHENTICATED error.""" + # Create a mock stub that raises UNAUTHENTICATED error + mock_stub = MagicMock() + mock_stub.Ingest.side_effect = grpc.RpcError() + mock_stub.Ingest.side_effect.code = lambda: grpc.StatusCode.UNAUTHENTICATED + mock_stub.Ingest.side_effect.details = lambda: "Something went wrong" + + client = DiodeClient( + target="grpc://localhost:8081", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + ) + + # Patch the DiodeClient to use the mock stub + client._stub = mock_stub + + # Attempt to ingest entities and expect a DiodeClientError after retries + with pytest.raises(DiodeClientError): + client.ingest(entities=[]) + + # Verify that the Ingest method was called the expected number of times + assert mock_stub.Ingest.call_count == client._max_auth_retries \ No newline at end of file From 5f3e93b5601a287bd282524ba0ba1a56e995ff70 Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Wed, 16 Apr 2025 16:09:17 +0100 Subject: [PATCH 10/16] expose max auth retries on client constructor --- netboxlabs/diode/sdk/client.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 178fb65..79ec5e5 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -69,6 +69,12 @@ def _get_required_config_value(env_var_name: str, value: str | None = None) -> s raise DiodeConfigError(f"parameter or {env_var_name} environment variable required") return value +def _get_optional_config_value(env_var_name: str, value: str | None = None) -> str | None: + """Get optional config value either from provided value or environment variable.""" + if value is None: + value = os.getenv(env_var_name) + return value + class DiodeClient: """Diode Client.""" @@ -90,12 +96,13 @@ def __init__( sentry_dsn: str = None, sentry_traces_sample_rate: float = 1.0, sentry_profiles_sample_rate: float = 1.0, + max_auth_retries: int = 3, ): """Initiate a new client.""" log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper() logging.basicConfig(level=log_level) - self._max_auth_retries = os.getenv(_MAX_RETRIES_ENVVAR_NAME, 3) + self._max_auth_retries = _get_optional_config_value(_MAX_RETRIES_ENVVAR_NAME, max_auth_retries) self._target, self._path, self._tls_verify = parse_target(target) self._app_name = app_name self._app_version = app_version From d97975f9dfef76f1e42b2cd2a7a84313b31e268c Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Wed, 16 Apr 2025 17:07:06 +0100 Subject: [PATCH 11/16] fixes linting --- netboxlabs/diode/sdk/client.py | 1 + tests/test_client.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 79ec5e5..758ce0f 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -238,6 +238,7 @@ def ingest( self._authenticate() continue raise DiodeClientError(err) from err + return RuntimeError("Max retries exceeded") def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float): sentry_sdk.init( diff --git a/tests/test_client.py b/tests/test_client.py index 4329c62..b965bd8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,7 +4,7 @@ import os from unittest import mock -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import grpc import pytest @@ -546,4 +546,4 @@ def test_ingest_retries_on_unauthenticated_error(mock_diode_authentication): client.ingest(entities=[]) # Verify that the Ingest method was called the expected number of times - assert mock_stub.Ingest.call_count == client._max_auth_retries \ No newline at end of file + assert mock_stub.Ingest.call_count == client._max_auth_retries From 4b7ff6e9f5551a1bdb81bf23948a039f99a438c9 Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Thu, 17 Apr 2025 09:38:54 +0100 Subject: [PATCH 12/16] fixes linting --- netboxlabs/diode/sdk/ingester.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/netboxlabs/diode/sdk/ingester.py b/netboxlabs/diode/sdk/ingester.py index c5c84ac..3bf7f0b 100644 --- a/netboxlabs/diode/sdk/ingester.py +++ b/netboxlabs/diode/sdk/ingester.py @@ -11,7 +11,9 @@ import datetime import re from typing import Any + from google.protobuf import timestamp_pb2 as _timestamp_pb2 + import netboxlabs.diode.sdk.diode.v1.ingester_pb2 as pb PRIMARY_VALUE_MAP = { From 9bc9e0c74d004cb5a2e0221985755f4a03d9c433 Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Thu, 17 Apr 2025 16:21:50 +0100 Subject: [PATCH 13/16] handle empty paths and adds tests --- netboxlabs/diode/sdk/client.py | 5 +-- tests/test_client.py | 58 ++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 758ce0f..ee7ae81 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -256,7 +256,7 @@ def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rat sentry_sdk.set_tag("python_version", self._python_version) def _authenticate(self): - authentication_client = _DiodeAuthentication(self._target, self._tls_verify, self._client_id, self._client_secret, self._path) + authentication_client = _DiodeAuthentication(self._target, self._path, self._tls_verify, self._client_id, self._client_secret) access_token = authentication_client.authenticate() self._metadata = list(filter(lambda x: x[0] != "authorization", self._metadata)) + \ [("authorization", f"Bearer {access_token}")] @@ -289,7 +289,8 @@ def authenticate(self) -> str: "client_secret": self._client_secret, } ) - conn.request("POST", f"{self._path}/auth/token", data, headers) + url = f"{self._path}/auth/token" if self._path and self._path != "/" else "/auth/token" + conn.request("POST", url, data, headers) response = conn.getresponse() if response.status != 200: raise DiodeConfigError(f"Failed to obtain access token: {response.reason}") diff --git a/tests/test_client.py b/tests/test_client.py index b965bd8..7c22a44 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -5,6 +5,7 @@ import os from unittest import mock from unittest.mock import MagicMock, patch +import json import grpc import pytest @@ -17,6 +18,7 @@ _get_sentry_dsn, _load_certs, parse_target, + _DiodeAuthentication, ) from netboxlabs.diode.sdk.exceptions import DiodeClientError, DiodeConfigError from netboxlabs.diode.sdk.version import version_semver @@ -522,6 +524,7 @@ def test_diode_client_with_mocked_authentication(mock_diode_authentication): assert client._metadata[0] == ("platform", client._platform) assert client._metadata[-1] == ("authorization", "Bearer mocked_token") + def test_ingest_retries_on_unauthenticated_error(mock_diode_authentication): """Test that the ingest method retries on UNAUTHENTICATED error.""" # Create a mock stub that raises UNAUTHENTICATED error @@ -547,3 +550,58 @@ def test_ingest_retries_on_unauthenticated_error(mock_diode_authentication): # Verify that the Ingest method was called the expected number of times assert mock_stub.Ingest.call_count == client._max_auth_retries + + +def test_diode_authentication_success(mock_diode_authentication): + """Test successful authentication in _DiodeAuthentication.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path="/diode", + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + ) + with mock.patch("http.client.HTTPConnection") as mock_http_conn: + mock_conn_instance = mock_http_conn.return_value + mock_conn_instance.getresponse.return_value.status = 200 + mock_conn_instance.getresponse.return_value.read.return_value = json.dumps({"access_token": "mocked_token"}).encode() + + token = auth.authenticate() + assert token == "mocked_token" + + +def test_diode_authentication_failure(mock_diode_authentication): + """Test authentication failure in _DiodeAuthentication.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path="/diode", + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + ) + with mock.patch("http.client.HTTPConnection") as mock_http_conn: + mock_conn_instance = mock_http_conn.return_value + mock_conn_instance.getresponse.return_value.status = 401 + mock_conn_instance.getresponse.return_value.reason = "Unauthorized" + + with pytest.raises(DiodeConfigError) as excinfo: + auth.authenticate() + assert "Failed to obtain access token" in str(excinfo.value) + +@pytest.mark.parametrize("path", ["/diode", "", None]) +def test_diode_authentication_url_with_path(mock_diode_authentication, path): + """Test that the authentication URL is correctly formatted with a path.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path=path, + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + ) + with mock.patch("http.client.HTTPConnection") as mock_http_conn: + mock_conn_instance = mock_http_conn.return_value + mock_conn_instance.getresponse.return_value.status = 200 + mock_conn_instance.getresponse.return_value.read.return_value = json.dumps({"access_token": "mocked_token"}).encode() + auth.authenticate() + mock_conn_instance.request.assert_called_once_with("POST", f"{path or ''}/auth/token", mock.ANY, mock.ANY) + From 90f33a9de3bf9ac7a3bd9b2b75c51a9d2734f526 Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Thu, 17 Apr 2025 16:22:16 +0100 Subject: [PATCH 14/16] linting --- tests/test_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 7c22a44..9273c35 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,10 +2,10 @@ # Copyright 2024 NetBox Labs Inc """NetBox Labs - Tests.""" +import json import os from unittest import mock from unittest.mock import MagicMock, patch -import json import grpc import pytest @@ -15,10 +15,10 @@ DiodeClient, DiodeMethodClientInterceptor, _ClientCallDetails, + _DiodeAuthentication, _get_sentry_dsn, _load_certs, parse_target, - _DiodeAuthentication, ) from netboxlabs.diode.sdk.exceptions import DiodeClientError, DiodeConfigError from netboxlabs.diode.sdk.version import version_semver From 210719a9fac437013b3ecdc23e2fe4a924baa6cf Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Thu, 17 Apr 2025 16:40:30 +0100 Subject: [PATCH 15/16] adds better bath handling --- netboxlabs/diode/sdk/client.py | 9 +++++++-- tests/test_client.py | 11 +++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index ee7ae81..429cab5 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -113,7 +113,6 @@ def __init__( self._client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id) self._client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret) - self._metadata = ( ("platform", self._platform), ("python-version", self._python_version), @@ -289,7 +288,7 @@ def authenticate(self) -> str: "client_secret": self._client_secret, } ) - url = f"{self._path}/auth/token" if self._path and self._path != "/" else "/auth/token" + url = self._get_auth_url() conn.request("POST", url, data, headers) response = conn.getresponse() if response.status != 200: @@ -302,6 +301,12 @@ def authenticate(self) -> str: _LOGGER.debug(f"Access token obtained for client {self._client_id}") return access_token + def _get_auth_url(self) -> str: + """Construct the authentication URL, handling trailing slashes in the path.""" + # Ensure the path does not have trailing slashes + path = self._path.rstrip('/') if self._path else '' + return f"{path}/auth/token" + class _ClientCallDetails( collections.namedtuple( diff --git a/tests/test_client.py b/tests/test_client.py index 9273c35..4cc510e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -588,7 +588,14 @@ def test_diode_authentication_failure(mock_diode_authentication): auth.authenticate() assert "Failed to obtain access token" in str(excinfo.value) -@pytest.mark.parametrize("path", ["/diode", "", None]) +@pytest.mark.parametrize("path", [ + "/diode", + "", + None, + "/diode/", + "diode", + "diode/", + ]) def test_diode_authentication_url_with_path(mock_diode_authentication, path): """Test that the authentication URL is correctly formatted with a path.""" auth = _DiodeAuthentication( @@ -603,5 +610,5 @@ def test_diode_authentication_url_with_path(mock_diode_authentication, path): mock_conn_instance.getresponse.return_value.status = 200 mock_conn_instance.getresponse.return_value.read.return_value = json.dumps({"access_token": "mocked_token"}).encode() auth.authenticate() - mock_conn_instance.request.assert_called_once_with("POST", f"{path or ''}/auth/token", mock.ANY, mock.ANY) + mock_conn_instance.request.assert_called_once_with("POST", f"{(path or '').rstrip('/')}/auth/token", mock.ANY, mock.ANY) From a285af3d43d36e4ac0b06c912f38ab257cde561d Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Thu, 17 Apr 2025 16:45:53 +0100 Subject: [PATCH 16/16] omits ingester from coverage report --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index ae49496..31b892b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ test = ["coverage", "pytest", "pytest-cov==6.0.0"] [tool.coverage.run] omit = [ + "*/netboxlabs/diode/sdk/ingester.py", "*/netboxlabs/diode/sdk/diode/*", "*/netboxlabs/diode/sdk/validate/*", "*/tests/*",