Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion netbox_branching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ class AppConfig(PluginConfig):

def ready(self):
super().ready()
from django.core.signals import request_started, request_finished
from . import constants, events, search, signal_receivers, webhook_callbacks # noqa: F401
from .models import Branch
from .utilities import DynamicSchemaDict
from .utilities import DynamicSchemaDict, close_old_branch_connections

# Validate required settings
if type(settings.DATABASES) is not DynamicSchemaDict:
Expand All @@ -59,6 +60,14 @@ def ready(self):
"netbox_branching: DATABASE_ROUTERS must contain 'netbox_branching.database.BranchAwareRouter'."
)

# Register cleanup handler for branch connections (#358)
# This ensures branch connections are closed when they exceed CONN_MAX_AGE,
# preventing connection leaks. Django's built-in close_old_connections()
# only handles connections in DATABASES.keys(), which doesn't include
# dynamically-created branch aliases.
request_started.connect(close_old_branch_connections)
request_finished.connect(close_old_branch_connections)

# Register the "branching" model feature
register_model_feature('branching', supports_branching)

Expand Down
91 changes: 91 additions & 0 deletions netbox_branching/tests/test_connection_lifecycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import time

from django.conf import settings
from django.contrib.auth import get_user_model
from django.db import connections
from django.test import tag, TransactionTestCase

from netbox_branching.models import Branch
from netbox_branching.utilities import activate_branch, close_old_branch_connections


@tag('regression') # netbox-branching #358
class BranchConnectionLifecycleTestCase(TransactionTestCase):
def setUp(self):
"""Set up test environment with CONN_MAX_AGE=1."""
self.original_max_age = settings.DATABASES['default'].get('CONN_MAX_AGE', 0)
settings.DATABASES['default']['CONN_MAX_AGE'] = 1
self.user = get_user_model().objects.create_user(username='testuser', is_superuser=True)
self.branches = []

def tearDown(self):
"""Clean up branches and restore CONN_MAX_AGE."""
for branch in self.branches:
try:
connections[branch.connection_name].close()
except Exception:
pass
Branch.objects.filter(pk=branch.pk).delete()
settings.DATABASES['default']['CONN_MAX_AGE'] = self.original_max_age

def create_and_provision_branch(self, name):
"""Create and provision a test branch."""
branch = Branch(name=name, description=f'Test {name}')
branch.save(provision=False)
branch.provision(self.user)
self.branches.append(branch)
return branch

def open_branch_connection(self, branch):
"""Open a connection to the branch by executing a query."""
with activate_branch(branch):
from django.contrib.contenttypes.models import ContentType
list(ContentType.objects.using(branch.connection_name).all()[:1])

def test_branch_connections_close_after_max_age(self):
"""Branch connections should close after CONN_MAX_AGE expires."""
branch = self.create_and_provision_branch('test-conn-cleanup')
self.open_branch_connection(branch)

conn = connections[branch.connection_name]
self.assertIsNotNone(conn.connection, "Connection should be open after query")
self.assertIsNotNone(conn.close_at, "close_at should be set when CONN_MAX_AGE > 0")

time.sleep(2)
close_old_branch_connections()

self.assertIsNone(conn.connection, "Connection should be closed after CONN_MAX_AGE expires")

def test_multiple_branch_connections_cleanup(self):
"""Multiple branch connections should all close after CONN_MAX_AGE."""
branches = [self.create_and_provision_branch(f'test-multi-{i}') for i in range(3)]

for branch in branches:
self.open_branch_connection(branch)

conns = [connections[b.connection_name] for b in branches]
for conn in conns:
self.assertIsNotNone(conn.connection, "Connection should be open")

time.sleep(2)
close_old_branch_connections()

for i, conn in enumerate(conns):
self.assertIsNone(conn.connection, f"Branch {i} connection should be closed")

def test_cleanup_handles_deleted_branch(self):
"""Cleanup should gracefully handle connections to deleted branch schemas."""
branch = self.create_and_provision_branch('test-deleted-branch')
self.open_branch_connection(branch)

conn = connections[branch.connection_name]
self.assertIsNotNone(conn.connection, "Connection should be open")

branch.deprovision()
Branch.objects.filter(pk=branch.pk).delete()
self.branches.remove(branch)

try:
close_old_branch_connections()
except Exception as e:
self.fail(f"cleanup should not raise exception for deleted branch: {e}")
45 changes: 45 additions & 0 deletions netbox_branching/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from dataclasses import dataclass
from functools import cached_property

from asgiref.local import Local
from django.contrib import messages
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
from django.db import connections
from django.db.models import ForeignKey, ManyToManyField
from django.http import HttpResponseBadRequest
from django.urls import reverse
Expand All @@ -16,13 +18,19 @@
from .constants import BRANCH_HEADER, COOKIE_NAME, EXEMPT_MODELS, INCLUDE_MODELS, QUERY_PARAM
from .contextvars import active_branch

# Thread-local storage for tracking branch connection aliases (matches Django's approach)
# Note: Aliases are tracked once and never removed, matching Django's pattern where
# DATABASES.keys() is static. Memory overhead is negligible (string references only).
_branch_connections_tracker = Local(thread_critical=False)

__all__ = (
'BranchActionIndicator',
'ChangeSummary',
'DynamicSchemaDict',
'ListHandler',
'ActiveBranchContextManager',
'activate_branch',
'close_old_branch_connections',
'deactivate_branch',
'get_active_branch',
'get_branchable_object_types',
Expand All @@ -31,10 +39,23 @@
'is_api_request',
'record_applied_change',
'supports_branching',
'track_branch_connection',
'update_object',
)


def _get_tracked_branch_aliases():
"""Get set of tracked branch aliases for current thread."""
if not hasattr(_branch_connections_tracker, 'aliases'):
_branch_connections_tracker.aliases = set()
return _branch_connections_tracker.aliases


def track_branch_connection(alias):
"""Register a branch connection alias for cleanup tracking."""
_get_tracked_branch_aliases().add(alias)


class DynamicSchemaDict(dict):
"""
Behaves like a normal dictionary, except for keys beginning with "schema_". Any lookup for
Expand All @@ -47,6 +68,8 @@ def main_schema(self):
def __getitem__(self, item):
if type(item) is str and item.startswith('schema_'):
if schema := item.removeprefix('schema_'):
track_branch_connection(item)

default_config = super().__getitem__('default')
return {
**default_config,
Expand All @@ -62,6 +85,28 @@ def __contains__(self, item):
return super().__contains__(item)


def close_old_branch_connections(**kwargs):
"""
Close branch database connections that have exceeded their maximum age.

This function complements Django's close_old_connections() by handling
dynamically-created branch connections. It tracks branch connection aliases
in thread-local storage and closes them when they exceed CONN_MAX_AGE.

Django's close_old_connections() only closes connections for database aliases
found in DATABASES.keys(). Since branch aliases are generated dynamically and
not present in that iteration (to avoid test isolation issues), they would never
be cleaned up, causing connection leaks.

This function is connected to request_started and request_finished signals,
matching Django's cleanup timing.
"""

for alias in _get_tracked_branch_aliases():
conn = connections[alias]
conn.close_if_unusable_or_obsolete()


@contextmanager
def activate_branch(branch):
"""
Expand Down
Loading