diff --git a/netbox_branching/__init__.py b/netbox_branching/__init__.py index 3b50985..b991e0a 100644 --- a/netbox_branching/__init__.py +++ b/netbox_branching/__init__.py @@ -45,9 +45,10 @@ class AppConfig(PluginConfig): def ready(self): super().ready() + from django.core.signals import request_started, request_finished from . import constants, events, search, signal_receivers, webhook_callbacks # noqa: F401 from .models import Branch - from .utilities import DynamicSchemaDict + from .utilities import DynamicSchemaDict, close_old_branch_connections # Validate required settings if type(settings.DATABASES) is not DynamicSchemaDict: @@ -59,6 +60,14 @@ def ready(self): "netbox_branching: DATABASE_ROUTERS must contain 'netbox_branching.database.BranchAwareRouter'." ) + # Register cleanup handler for branch connections (#358) + # This ensures branch connections are closed when they exceed CONN_MAX_AGE, + # preventing connection leaks. Django's built-in close_old_connections() + # only handles connections in DATABASES.keys(), which doesn't include + # dynamically-created branch aliases. + request_started.connect(close_old_branch_connections) + request_finished.connect(close_old_branch_connections) + # Register the "branching" model feature register_model_feature('branching', supports_branching) diff --git a/netbox_branching/tests/test_connection_lifecycle.py b/netbox_branching/tests/test_connection_lifecycle.py new file mode 100644 index 0000000..11d9b99 --- /dev/null +++ b/netbox_branching/tests/test_connection_lifecycle.py @@ -0,0 +1,93 @@ +import time + +from django.conf import settings +from django.contrib.auth import get_user_model +from django.db import connections +from django.test import tag, TransactionTestCase + +from netbox_branching.models import Branch +from netbox_branching.utilities import activate_branch, close_old_branch_connections + + +@tag('regression') # netbox-branching #358 +class BranchConnectionLifecycleTestCase(TransactionTestCase): + serialized_rollback = True + + def setUp(self): + """Set up test environment with CONN_MAX_AGE=1.""" + self.original_max_age = settings.DATABASES['default'].get('CONN_MAX_AGE', 0) + settings.DATABASES['default']['CONN_MAX_AGE'] = 1 + self.user = get_user_model().objects.create_user(username='testuser', is_superuser=True) + self.branches = [] + + def tearDown(self): + """Clean up branches and restore CONN_MAX_AGE.""" + for branch in self.branches: + try: + connections[branch.connection_name].close() + except Exception: + pass + Branch.objects.filter(pk=branch.pk).delete() + settings.DATABASES['default']['CONN_MAX_AGE'] = self.original_max_age + + def create_and_provision_branch(self, name): + """Create and provision a test branch.""" + branch = Branch(name=name, description=f'Test {name}') + branch.save(provision=False) + branch.provision(self.user) + self.branches.append(branch) + return branch + + def open_branch_connection(self, branch): + """Open a connection to the branch by executing a query.""" + with activate_branch(branch): + from django.contrib.contenttypes.models import ContentType + list(ContentType.objects.using(branch.connection_name).all()[:1]) + + def test_branch_connections_close_after_max_age(self): + """Branch connections should close after CONN_MAX_AGE expires.""" + branch = self.create_and_provision_branch('test-conn-cleanup') + self.open_branch_connection(branch) + + conn = connections[branch.connection_name] + self.assertIsNotNone(conn.connection, "Connection should be open after query") + self.assertIsNotNone(conn.close_at, "close_at should be set when CONN_MAX_AGE > 0") + + time.sleep(2) + close_old_branch_connections() + + self.assertIsNone(conn.connection, "Connection should be closed after CONN_MAX_AGE expires") + + def test_multiple_branch_connections_cleanup(self): + """Multiple branch connections should all close after CONN_MAX_AGE.""" + branches = [self.create_and_provision_branch(f'test-multi-{i}') for i in range(3)] + + for branch in branches: + self.open_branch_connection(branch) + + conns = [connections[b.connection_name] for b in branches] + for conn in conns: + self.assertIsNotNone(conn.connection, "Connection should be open") + + time.sleep(2) + close_old_branch_connections() + + for i, conn in enumerate(conns): + self.assertIsNone(conn.connection, f"Branch {i} connection should be closed") + + def test_cleanup_handles_deleted_branch(self): + """Cleanup should gracefully handle connections to deleted branch schemas.""" + branch = self.create_and_provision_branch('test-deleted-branch') + self.open_branch_connection(branch) + + conn = connections[branch.connection_name] + self.assertIsNotNone(conn.connection, "Connection should be open") + + branch.deprovision() + Branch.objects.filter(pk=branch.pk).delete() + self.branches.remove(branch) + + try: + close_old_branch_connections() + except Exception as e: + self.fail(f"cleanup should not raise exception for deleted branch: {e}") diff --git a/netbox_branching/utilities.py b/netbox_branching/utilities.py index f35f7fd..447c850 100644 --- a/netbox_branching/utilities.py +++ b/netbox_branching/utilities.py @@ -5,8 +5,10 @@ from dataclasses import dataclass from functools import cached_property +from asgiref.local import Local from django.contrib import messages from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist +from django.db import connections from django.db.models import ForeignKey, ManyToManyField from django.http import HttpResponseBadRequest from django.urls import reverse @@ -16,6 +18,11 @@ from .constants import BRANCH_HEADER, COOKIE_NAME, EXEMPT_MODELS, INCLUDE_MODELS, QUERY_PARAM from .contextvars import active_branch +# Thread-local storage for tracking branch connection aliases (matches Django's approach) +# Note: Aliases are tracked once and never removed, matching Django's pattern where +# DATABASES.keys() is static. Memory overhead is negligible (string references only). +_branch_connections_tracker = Local(thread_critical=False) + __all__ = ( 'BranchActionIndicator', 'ChangeSummary', @@ -23,6 +30,7 @@ 'ListHandler', 'ActiveBranchContextManager', 'activate_branch', + 'close_old_branch_connections', 'deactivate_branch', 'get_active_branch', 'get_branchable_object_types', @@ -31,10 +39,23 @@ 'is_api_request', 'record_applied_change', 'supports_branching', + 'track_branch_connection', 'update_object', ) +def _get_tracked_branch_aliases(): + """Get set of tracked branch aliases for current thread.""" + if not hasattr(_branch_connections_tracker, 'aliases'): + _branch_connections_tracker.aliases = set() + return _branch_connections_tracker.aliases + + +def track_branch_connection(alias): + """Register a branch connection alias for cleanup tracking.""" + _get_tracked_branch_aliases().add(alias) + + class DynamicSchemaDict(dict): """ Behaves like a normal dictionary, except for keys beginning with "schema_". Any lookup for @@ -47,6 +68,8 @@ def main_schema(self): def __getitem__(self, item): if type(item) is str and item.startswith('schema_'): if schema := item.removeprefix('schema_'): + track_branch_connection(item) + default_config = super().__getitem__('default') return { **default_config, @@ -62,6 +85,28 @@ def __contains__(self, item): return super().__contains__(item) +def close_old_branch_connections(**kwargs): + """ + Close branch database connections that have exceeded their maximum age. + + This function complements Django's close_old_connections() by handling + dynamically-created branch connections. It tracks branch connection aliases + in thread-local storage and closes them when they exceed CONN_MAX_AGE. + + Django's close_old_connections() only closes connections for database aliases + found in DATABASES.keys(). Since branch aliases are generated dynamically and + not present in that iteration (to avoid test isolation issues), they would never + be cleaned up, causing connection leaks. + + This function is connected to request_started and request_finished signals, + matching Django's cleanup timing. + """ + + for alias in _get_tracked_branch_aliases(): + conn = connections[alias] + conn.close_if_unusable_or_obsolete() + + @contextmanager def activate_branch(branch): """