diff --git a/README.md b/README.md index 41e69eb..5226e1c 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,8 @@ pip install netboxlabs-diode-sdk * `DIODE_SENTRY_DSN` - Optional Sentry DSN for error reporting * `DIODE_CLIENT_ID` - Client ID for OAuth2 authentication * `DIODE_CLIENT_SECRET` - Client Secret for OAuth2 authentication +* `DIODE_CERT_FILE` - Path to custom certificate file for TLS connections +* `DIODE_SKIP_TLS_VERIFY` - Skip TLS verification (default: `false`) * `DIODE_DRY_RUN_OUTPUT_DIR` - Directory where `DiodeDryRunClient` will write JSON files ### Example @@ -77,6 +79,36 @@ if __name__ == "__main__": ``` +### TLS verification and certificates + +TLS verification is controlled by the target URL scheme: +- **Secure schemes** (`grpcs://`, `https://`): TLS verification enabled +- **Insecure schemes** (`grpc://`, `http://`): TLS verification disabled + +```python +# TLS verification enabled (uses system certificates) +client = DiodeClient(target="grpcs://example.com", ...) + +# TLS verification disabled +client = DiodeClient(target="grpc://example.com", ...) +``` + +#### Using custom certificates + +```python +# Via constructor parameter +client = DiodeClient(target="grpcs://example.com", cert_file="/path/to/cert.pem", ...) + +# Or via environment variable +export DIODE_CERT_FILE=/path/to/cert.pem +``` + +#### Disabling TLS verification + +```bash +export DIODE_SKIP_TLS_VERIFY=true +``` + ### Dry run mode `DiodeDryRunClient` generates ingestion requests without contacting a Diode server. Requests are printed to stdout by default, or written to JSON files when `output_dir` (or the `DIODE_DRY_RUN_OUTPUT_DIR` environment variable) is specified. The `app_name` parameter serves as the filename prefix; if not provided, `dryrun` is used as the default prefix. The file name is suffixed with a nanosecond-precision timestamp, resulting in the format `_.json`. diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 045ab0d..295514b 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -26,16 +26,17 @@ from netboxlabs.diode.sdk.ingester import Entity from netboxlabs.diode.sdk.version import version_semver -_MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES" -_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL" -_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN" _CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID" _CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET" +_DEFAULT_STREAM = "latest" +_DIODE_CERT_FILE_ENVVAR_NAME = "DIODE_CERT_FILE" +_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL" +_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN" +_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME = "DIODE_SKIP_TLS_VERIFY" _DRY_RUN_OUTPUT_DIR_ENVVAR_NAME = "DIODE_DRY_RUN_OUTPUT_DIR" _INGEST_SCOPE = "diode:ingest" -_DEFAULT_STREAM = "latest" _LOGGER = logging.getLogger(__name__) - +_MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES" def load_dryrun_entities(file_path: str | Path) -> Iterable[Entity]: """Yield entities from a file with concatenated JSON messages.""" @@ -53,12 +54,26 @@ class DiodeClientInterface: pass -def _load_certs() -> bytes: - """Loads cacert.pem.""" - with open(certifi.where(), "rb") as f: +def _load_certs(cert_file: str | None = None) -> bytes: + """Loads cacert.pem or custom certificate file.""" + cert_path = cert_file or certifi.where() + with open(cert_path, "rb") as f: return f.read() +def _should_verify_tls(scheme: str) -> bool: + """Determine if TLS verification should be enabled based on scheme and environment variable.""" + # Check if scheme is insecure + insecure_scheme = scheme in ["grpc", "http"] + + # Check environment variable + skip_tls_env = os.getenv(_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME, "").lower() + skip_tls_from_env = skip_tls_env in ["true", "1", "yes", "on"] + + # TLS verification is enabled by default, disabled only for insecure schemes or env var + return not (insecure_scheme or skip_tls_from_env) + + def parse_target(target: str) -> tuple[str, str, bool]: """Parse the target into authority, path and tls_verify.""" parsed_target = urlparse(target) @@ -66,7 +81,8 @@ def parse_target(target: str) -> tuple[str, str, bool]: if parsed_target.scheme not in ["grpc", "grpcs", "http", "https"]: raise ValueError("target should start with grpc://, grpcs://, http:// or https://") - tls_verify = parsed_target.scheme in ["grpcs", "https"] + # Determine if TLS verification should be enabled + tls_verify = _should_verify_tls(parsed_target.scheme) authority = parsed_target.netloc @@ -127,15 +143,22 @@ def __init__( sentry_traces_sample_rate: float = 1.0, sentry_profiles_sample_rate: float = 1.0, max_auth_retries: int = 3, + cert_file: str | None = None, ): """Initiate a new client.""" log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper() logging.basicConfig(level=log_level) - self._max_auth_retries = _get_optional_config_value( - _MAX_RETRIES_ENVVAR_NAME, max_auth_retries + self._max_auth_retries = int(_get_optional_config_value( + _MAX_RETRIES_ENVVAR_NAME, str(max_auth_retries) + ) or max_auth_retries) + self._cert_file = _get_optional_config_value( + _DIODE_CERT_FILE_ENVVAR_NAME, cert_file ) self._target, self._path, self._tls_verify = parse_target(target) + + # Load certificates once if needed + self._certificates = _load_certs(self._cert_file) if (self._tls_verify or self._cert_file) else None self._app_name = app_name self._app_version = app_version self._platform = platform.platform() @@ -161,12 +184,12 @@ def __init__( ), ) - if self._tls_verify: + if self._tls_verify and self._certificates: _LOGGER.debug("Setting up gRPC secure channel") self._channel = grpc.secure_channel( self._target, grpc.ssl_channel_credentials( - root_certificates=_load_certs(), + root_certificates=self._certificates, ), options=channel_opts, ) @@ -304,6 +327,7 @@ def _authenticate(self, scope: str): self._client_id, self._client_secret, scope, + self._certificates, ) access_token = authentication_client.authenticate() self._metadata = list( @@ -391,6 +415,7 @@ def __init__( client_id: str, client_secret: str, scope: str, + certificates: bytes | None = None, ): self._target = target self._tls_verify = tls_verify @@ -398,13 +423,16 @@ def __init__( self._client_secret = client_secret self._path = path self._scope = scope + self._certificates = certificates def authenticate(self) -> str: """Request an OAuth2 token using client credentials and return it.""" - if self._tls_verify: + if self._tls_verify and self._certificates: + context = ssl.create_default_context() + context.load_verify_locations(cadata=self._certificates.decode('utf-8')) conn = http.client.HTTPSConnection( self._target, - context=None if self._tls_verify else ssl._create_unverified_context(), + context=context, ) else: conn = http.client.HTTPConnection( diff --git a/tests/test_client.py b/tests/test_client.py index a265ddc..4652040 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -109,6 +109,7 @@ def test_parse_target_handles_ftp_prefix(): with pytest.raises(ValueError): parse_target("ftp://localhost:8081") + def test_parse_target_parses_authority_correctly(): """Check that parse_target parses the authority correctly.""" authority, path, tls_verify = parse_target("grpc://localhost:8081") @@ -739,24 +740,445 @@ def test_load_dryrun_entities_from_fixture(message_path, tmp_path): ) assert entities[-1].wireless_link.ssid == "P2P-Link-1" - client = DiodeDryRunClient(output_dir=str(tmp_path)) - client._stub = MagicMock() - client.ingest(entities=entities) +def test_diode_authentication_with_custom_certificates(): + """Test _DiodeAuthentication with custom certificates - covers SSL context creation.""" + # Create test certificate content + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) - assert client._stub.Ingest.call_count == 0 - files = list(tmp_path.glob("dryrun*.json")) - assert len(files) == 1 - entities = list(load_dryrun_entities(files[0])) - assert files[0].read_text().startswith("{") + auth = _DiodeAuthentication( + target="example.com:443", + path="/api/v1", + tls_verify=True, + client_id="test_client", + client_secret="test_secret", + scope="test_scope", + certificates=cert_content, + ) - entities = list(load_dryrun_entities(files[0])) + with ( + mock.patch("http.client.HTTPSConnection") as mock_https_conn, + mock.patch("ssl.create_default_context") as mock_ssl_context, + ): + # Setup mocks + mock_context_instance = mock.Mock() + mock_ssl_context.return_value = mock_context_instance - assert len(entities) == 94 - assert isinstance(entities[0], ingester_pb2.Entity) - assert entities[0].asn.asn == 555 - assert entities[33].ip_address.address == "192.168.100.1/24" - assert ( - entities[33].ip_address.assigned_object_interface.name == "GigabitEthernet1/0/1" + mock_conn_instance = mock.Mock() + mock_https_conn.return_value = mock_conn_instance + + mock_response = mock.Mock() + mock_response.status = 200 + mock_response.read.return_value = b'{"access_token": "test_token"}' + mock_conn_instance.getresponse.return_value = mock_response + + # Call authenticate to trigger SSL context creation + token = auth.authenticate() + + # Verify SSL context was created and configured with custom certs + mock_ssl_context.assert_called_once() + mock_context_instance.load_verify_locations.assert_called_once_with( + cadata=cert_content.decode("utf-8") + ) + + # Verify HTTPS connection was created with custom context + mock_https_conn.assert_called_once_with( + "example.com:443", + context=mock_context_instance, + ) + + # Verify token was returned + assert token == "test_token" + + +def test_load_certs_with_custom_cert_file(tmp_path): + """Test _load_certs loads custom certificate file.""" + # Create a dummy certificate file + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" ) - assert entities[-1].wireless_link.ssid == "P2P-Link-1" + cert_file = tmp_path / "custom.pem" + cert_file.write_bytes(cert_content) + + result = _load_certs(str(cert_file)) + assert result == cert_content + + +def test_load_certs_with_none_uses_default(): + """Test _load_certs uses default certifi when cert_file is None.""" + result = _load_certs(None) + assert isinstance(result, bytes) + assert len(result) > 0 + + +def test_client_with_cert_file_parameter(mock_diode_authentication, tmp_path): + """Test DiodeClient with cert_file parameter loads custom cert but respects TLS scheme.""" + # Create a dummy certificate file + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) + cert_file = tmp_path / "custom.pem" + cert_file.write_bytes(cert_content) + + with mock.patch("grpc.insecure_channel") as mock_insecure_channel: + client = DiodeClient( + target="grpc://localhost:8081", # Note: grpc:// insecure scheme + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + cert_file=str(cert_file), + ) + + # Should respect scheme (insecure) even with cert file + assert client.tls_verify is False + + # Should use insecure channel + mock_insecure_channel.assert_called_once() + + # Verify certificate was still loaded for potential use + assert client._certificates == cert_content + + +def test_client_with_cert_file_env_var(mock_diode_authentication, tmp_path): + """Test DiodeClient with DIODE_CERT_FILE environment variable respects scheme.""" + from netboxlabs.diode.sdk.client import _DIODE_CERT_FILE_ENVVAR_NAME + + # Create a dummy certificate file + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) + cert_file = tmp_path / "custom.pem" + cert_file.write_bytes(cert_content) + + # Set environment variable + original_env = os.environ.get(_DIODE_CERT_FILE_ENVVAR_NAME) + os.environ[_DIODE_CERT_FILE_ENVVAR_NAME] = str(cert_file) + + try: + with mock.patch("grpc.insecure_channel") as mock_insecure_channel: + client = DiodeClient( + target="grpc://localhost:8081", # Note: grpc:// insecure scheme + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + ) + + # Should respect scheme (insecure) even with cert file + assert client.tls_verify is False + + # Should use insecure channel + mock_insecure_channel.assert_called_once() + + # Verify certificate was still loaded + assert client._certificates == cert_content + + finally: + # Clean up environment variable + if original_env is not None: + os.environ[_DIODE_CERT_FILE_ENVVAR_NAME] = original_env + else: + if _DIODE_CERT_FILE_ENVVAR_NAME in os.environ: + del os.environ[_DIODE_CERT_FILE_ENVVAR_NAME] + + +def test_client_cert_file_parameter_overrides_env_var( + mock_diode_authentication, tmp_path +): + """Test cert_file parameter takes precedence over environment variable.""" + from netboxlabs.diode.sdk.client import _DIODE_CERT_FILE_ENVVAR_NAME + + # Create two dummy certificate files + env_cert_content = ( + b"-----BEGIN CERTIFICATE-----\nENV CERT\n-----END CERTIFICATE-----\n" + ) + param_cert_content = ( + b"-----BEGIN CERTIFICATE-----\nPARAM CERT\n-----END CERTIFICATE-----\n" + ) + + env_cert_file = tmp_path / "env.pem" + param_cert_file = tmp_path / "param.pem" + + env_cert_file.write_bytes(env_cert_content) + param_cert_file.write_bytes(param_cert_content) + + # Set environment variable + original_env = os.environ.get(_DIODE_CERT_FILE_ENVVAR_NAME) + os.environ[_DIODE_CERT_FILE_ENVVAR_NAME] = str(env_cert_file) + + try: + with mock.patch("netboxlabs.diode.sdk.client._load_certs") as mock_load_certs: + mock_load_certs.return_value = param_cert_content + + client = DiodeClient( + target="grpc://localhost:8081", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + cert_file=str(param_cert_file), + ) + + # Should use the parameter file, not the environment variable + mock_load_certs.assert_called_with(str(param_cert_file)) + # grpc:// scheme should keep tls_verify=False even with cert file + assert client.tls_verify is False + + finally: + # Clean up environment variable + if original_env is not None: + os.environ[_DIODE_CERT_FILE_ENVVAR_NAME] = original_env + else: + if _DIODE_CERT_FILE_ENVVAR_NAME in os.environ: + del os.environ[_DIODE_CERT_FILE_ENVVAR_NAME] + + +def test_client_secure_channel_uses_custom_cert(mock_diode_authentication, tmp_path): + """Test secure channel creation uses custom certificate when provided.""" + # Create a dummy certificate file + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) + cert_file = tmp_path / "custom.pem" + cert_file.write_bytes(cert_content) + + with ( + mock.patch("grpc.secure_channel") as mock_secure_channel, + mock.patch("grpc.ssl_channel_credentials") as mock_ssl_creds, + mock.patch("netboxlabs.diode.sdk.client._load_certs") as mock_load_certs, + ): + mock_load_certs.return_value = cert_content + + _ = DiodeClient( + target="grpcs://localhost:8081", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + cert_file=str(cert_file), + ) + + # Verify _load_certs was called with the custom cert file + mock_load_certs.assert_called_with(str(cert_file)) + + # Verify ssl_channel_credentials was called with the custom cert content + mock_ssl_creds.assert_called_once_with(root_certificates=cert_content) + + # Verify secure_channel was called + mock_secure_channel.assert_called_once() + + +def test_client_without_cert_file_uses_default_certs(mock_diode_authentication): + """Test secure channel uses default certificates when no cert_file provided.""" + with ( + mock.patch("grpc.secure_channel") as mock_secure_channel, + mock.patch("grpc.ssl_channel_credentials") as mock_ssl_creds, + mock.patch("netboxlabs.diode.sdk.client._load_certs") as mock_load_certs, + ): + mock_load_certs.return_value = b"default cert content" + + _ = DiodeClient( + target="grpcs://localhost:8081", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + ) + + # Verify _load_certs was called with None (default) + mock_load_certs.assert_called_with(None) + + # Verify ssl_channel_credentials was called with default cert content + mock_ssl_creds.assert_called_once_with( + root_certificates=b"default cert content" + ) + + # Verify secure_channel was called + mock_secure_channel.assert_called_once() + + +def test_should_verify_tls_with_different_schemes(): + """Test _should_verify_tls with different URL schemes.""" + from netboxlabs.diode.sdk.client import ( + _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME, + _should_verify_tls, + ) + + # Clear environment variable to avoid interference + if _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME in os.environ: + del os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] + + assert _should_verify_tls("grpc") is False # insecure scheme + assert _should_verify_tls("http") is False # insecure scheme + assert _should_verify_tls("grpcs") is True # secure scheme + assert _should_verify_tls("https") is True # secure scheme + + +def test_should_verify_tls_with_skip_env_var(): + """Test _should_verify_tls with DIODE_SKIP_TLS_VERIFY environment variable.""" + from netboxlabs.diode.sdk.client import ( + _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME, + _should_verify_tls, + ) + + original_env = os.environ.get(_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME) + + try: + # Test truthy values that should skip TLS verification + for skip_value in ["true", "True", "TRUE", "1", "yes", "on"]: + os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] = skip_value + assert ( + _should_verify_tls("grpcs") is False + ) # Should skip even for secure schemes + + # Test falsy values that should NOT skip TLS verification + for verify_value in ["false", "0", "no", "off", "", "random"]: + os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] = verify_value + assert ( + _should_verify_tls("grpcs") is True + ) # Should verify for secure schemes + + finally: + # Clean up environment variable + if original_env is not None: + os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] = original_env + else: + if _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME in os.environ: + del os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] + + +def test_client_with_skip_tls_verify_env_var(mock_diode_authentication): + """Test DiodeClient with DIODE_SKIP_TLS_VERIFY environment variable.""" + from netboxlabs.diode.sdk.client import _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME + + original_env = os.environ.get(_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME) + + try: + # Set environment variable to skip TLS verification + os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] = "true" + + with mock.patch("grpc.insecure_channel") as mock_insecure_channel: + client = DiodeClient( + target="grpcs://localhost:8081", # Note: grpcs:// but TLS should be skipped + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + ) + + # Should skip TLS verification due to environment variable + assert client.tls_verify is False + + # Should use insecure channel even with grpcs:// + mock_insecure_channel.assert_called_once() + + finally: + # Clean up environment variable + if original_env is not None: + os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] = original_env + else: + if _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME in os.environ: + del os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] + + +def test_client_cert_file_with_skip_tls_verify_env_var( + mock_diode_authentication, tmp_path +): + """Test cert_file parameter with DIODE_SKIP_TLS_VERIFY environment variable.""" + from netboxlabs.diode.sdk.client import _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME + + # Create a dummy certificate file + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) + cert_file = tmp_path / "custom.pem" + cert_file.write_bytes(cert_content) + + original_skip_env = os.environ.get(_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME) + + try: + # Set environment variable to skip TLS verification + os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] = "true" + + with mock.patch("grpc.insecure_channel") as mock_insecure_channel: + client = DiodeClient( + target="grpcs://localhost:8081", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + cert_file=str(cert_file), + ) + + # Should respect DIODE_SKIP_TLS_VERIFY=true even with cert_file + assert client.tls_verify is False + + # Should use insecure channel due to environment variable + mock_insecure_channel.assert_called_once() + + # Certificate should still be loaded for potential use + assert client._certificates == cert_content + + finally: + # Clean up environment variable + if original_skip_env is not None: + os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] = original_skip_env + else: + if _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME in os.environ: + del os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] + + +def test_certificate_loading_efficiency(tmp_path): + """Test that certificates are loaded only once during client initialization.""" + # Create a dummy certificate file + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) + cert_file = tmp_path / "custom.pem" + cert_file.write_bytes(cert_content) + + with ( + mock.patch("netboxlabs.diode.sdk.client._load_certs") as mock_load_certs, + mock.patch( + "netboxlabs.diode.sdk.client._DiodeAuthentication" + ) as mock_auth_class, + ): + mock_load_certs.return_value = cert_content + mock_auth_instance = mock_auth_class.return_value + mock_auth_instance.authenticate.return_value = "test_token" + + # Create client with custom certificate + client = DiodeClient( + target="grpcs://localhost:8081", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + cert_file=str(cert_file), + ) + + # Verify _load_certs was called exactly once during initialization + mock_load_certs.assert_called_once_with(str(cert_file)) + + # Verify certificates are stored and reused + assert client._certificates == cert_content + + # Verify that the authentication class was created with the certificate bytes + mock_auth_class.assert_called_once() + auth_call_args = mock_auth_class.call_args + + # The last argument should be the certificate bytes + assert auth_call_args[0][-1] == cert_content # certificates parameter + + # Reset the mock to verify no additional calls during authentication + mock_load_certs.reset_mock() + + # Authentication should have already been called during initialization + # and should have used the preloaded certificates + mock_auth_instance.authenticate.assert_called_once() + + # Verify _load_certs was NOT called again (certificates reused) + mock_load_certs.assert_not_called()