Skip to content

Commit e69cb4e

Browse files
committed
refactor: enhance Redis connection management and backup functionality
- Introduced a Redis connection pool to optimize Redis client management, improving performance and resource utilization. - Updated session management functions to utilize the new Redis connection pool, ensuring efficient access to session data. - Implemented a new method in BackupService to retrieve backups for a user's first pad using a join operation, addressing the N+1 query problem. - Enhanced pad router to create backups conditionally based on time intervals, improving backup efficiency and management. - Streamlined user router to utilize the new Redis client retrieval method, promoting better code organization and maintainability.
1 parent 4f91532 commit e69cb4e

File tree

6 files changed

+178
-37
lines changed

6 files changed

+178
-37
lines changed

src/backend/config.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44
import httpx
55
import redis
6+
from redis import ConnectionPool, Redis
67
import jwt
78
from jwt.jwks_client import PyJWKClient
89
from typing import Optional, Dict, Any, Tuple
@@ -16,6 +17,11 @@
1617
ASSETS_DIR = os.getenv("ASSETS_DIR")
1718
FRONTEND_URL = os.getenv('FRONTEND_URL')
1819

20+
MAX_BACKUPS_PER_USER = 10 # Maximum number of backups to keep per user
21+
MIN_INTERVAL_MINUTES = 5 # Minimum interval in minutes between backups
22+
DEFAULT_PAD_NAME = "Untitled" # Default name for new pads
23+
DEFAULT_TEMPLATE_NAME = "default" # Template name to use when a user doesn't have a pad
24+
1925
# ===== PostHog Configuration =====
2026
POSTHOG_API_KEY = os.getenv("VITE_PUBLIC_POSTHOG_KEY")
2127
POSTHOG_HOST = os.getenv("VITE_PUBLIC_POSTHOG_HOST")
@@ -32,15 +38,26 @@
3238
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD', None)
3339
REDIS_PORT = int(os.getenv('REDIS_PORT', 6379))
3440

35-
# Redis connection
36-
redis_client = redis.Redis(
41+
# Create a Redis connection pool
42+
redis_pool = ConnectionPool(
3743
host=REDIS_HOST,
3844
password=REDIS_PASSWORD,
3945
port=REDIS_PORT,
4046
db=0,
41-
decode_responses=True
47+
decode_responses=True,
48+
max_connections=10, # Adjust based on your application's needs
49+
socket_timeout=5.0,
50+
socket_connect_timeout=1.0,
51+
health_check_interval=30
4252
)
4353

54+
# Create a Redis client that uses the connection pool
55+
redis_client = Redis(connection_pool=redis_pool)
56+
57+
def get_redis_client():
58+
"""Get a Redis client from the connection pool"""
59+
return Redis(connection_pool=redis_pool)
60+
4461
# ===== Coder API Configuration =====
4562
CODER_API_KEY = os.getenv("CODER_API_KEY")
4663
CODER_URL = os.getenv("CODER_URL")
@@ -54,22 +71,25 @@
5471
# Session management functions
5572
def get_session(session_id: str) -> Optional[Dict[str, Any]]:
5673
"""Get session data from Redis"""
57-
session_data = redis_client.get(f"session:{session_id}")
74+
client = get_redis_client()
75+
session_data = client.get(f"session:{session_id}")
5876
if session_data:
5977
return json.loads(session_data)
6078
return None
6179

6280
def set_session(session_id: str, data: Dict[str, Any], expiry: int) -> None:
6381
"""Store session data in Redis with expiry in seconds"""
64-
redis_client.setex(
82+
client = get_redis_client()
83+
client.setex(
6584
f"session:{session_id}",
6685
expiry,
6786
json.dumps(data)
6887
)
6988

7089
def delete_session(session_id: str) -> None:
7190
"""Delete session data from Redis"""
72-
redis_client.delete(f"session:{session_id}")
91+
client = get_redis_client()
92+
client.delete(f"session:{session_id}")
7393

7494
def get_auth_url() -> str:
7595
"""Generate the authentication URL for Keycloak login"""
@@ -161,4 +181,4 @@ def get_jwks_client():
161181
if _jwks_client is None:
162182
jwks_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/certs"
163183
_jwks_client = PyJWKClient(jwks_url)
164-
return _jwks_client
184+
return _jwks_client

src/backend/database/repository/backup_repository.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
from sqlalchemy.ext.asyncio import AsyncSession
1010
from sqlalchemy.future import select
11-
from sqlalchemy import delete, func
11+
from sqlalchemy import delete, func, join
1212

13-
from ..models import BackupModel
13+
from ..models import BackupModel, PadModel
1414

1515
class BackupRepository:
1616
"""Repository for backup-related database operations"""
@@ -89,3 +89,28 @@ async def count_by_source(self, source_id: UUID) -> int:
8989
stmt = select(func.count()).select_from(BackupModel).where(BackupModel.source_id == source_id)
9090
result = await self.session.execute(stmt)
9191
return result.scalar()
92+
93+
async def get_backups_by_user(self, user_id: UUID, limit: int = 10) -> List[BackupModel]:
94+
"""
95+
Get backups for a user's first pad directly using a join operation.
96+
This eliminates the N+1 query problem by fetching the pad and its backups in a single query.
97+
98+
Args:
99+
user_id: The user ID to get backups for
100+
limit: Maximum number of backups to return
101+
102+
Returns:
103+
List of backup models
104+
"""
105+
# Create a join between PadModel and BackupModel
106+
stmt = select(BackupModel).join(
107+
PadModel,
108+
BackupModel.source_id == PadModel.id
109+
).where(
110+
PadModel.owner_id == user_id
111+
).order_by(
112+
BackupModel.created_at.desc()
113+
).limit(limit)
114+
115+
result = await self.session.execute(stmt)
116+
return result.scalars().all()

src/backend/database/service/backup_service.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from sqlalchemy.ext.asyncio import AsyncSession
1010

11-
from ..repository import BackupRepository, PadRepository
11+
from ..repository import BackupRepository, PadRepository, UserRepository
1212

1313
class BackupService:
1414
"""Service for backup-related business logic"""
@@ -101,3 +101,76 @@ async def manage_backups(self, source_id: UUID, max_backups: int = 10) -> int:
101101
return await self.repository.delete_older_than(source_id, max_backups)
102102

103103
return 0 # No backups deleted
104+
105+
async def get_backups_by_user(self, user_id: UUID, limit: int = 10) -> List[Dict[str, Any]]:
106+
"""
107+
Get backups for a user's first pad directly using a join operation.
108+
This eliminates the N+1 query problem by fetching the pad and its backups in a single query.
109+
110+
Args:
111+
user_id: The user ID to get backups for
112+
limit: Maximum number of backups to return
113+
114+
Returns:
115+
List of backup dictionaries
116+
"""
117+
# Check if user exists
118+
user_repository = UserRepository(self.session)
119+
user = await user_repository.get_by_id(user_id)
120+
if not user:
121+
raise ValueError(f"User with ID '{user_id}' does not exist")
122+
123+
# Get backups directly with a single query
124+
backups = await self.repository.get_backups_by_user(user_id, limit)
125+
return [backup.to_dict() for backup in backups]
126+
127+
async def create_backup_if_needed(self, source_id: UUID, data: Dict[str, Any],
128+
min_interval_minutes: int = 5,
129+
max_backups: int = 10) -> Optional[Dict[str, Any]]:
130+
"""
131+
Create a backup only if needed:
132+
- If there are no existing backups
133+
- If the latest backup is older than the specified interval
134+
135+
Args:
136+
source_id: The ID of the source pad
137+
data: The data to backup
138+
min_interval_minutes: Minimum time between backups in minutes
139+
max_backups: Maximum number of backups to keep
140+
141+
Returns:
142+
The created backup dict if a backup was created, None otherwise
143+
"""
144+
# Check if source pad exists
145+
source_pad = await self.pad_repository.get_by_id(source_id)
146+
if not source_pad:
147+
raise ValueError(f"Pad with ID '{source_id}' does not exist")
148+
149+
# Get the latest backup
150+
latest_backup = await self.repository.get_latest_by_source(source_id)
151+
152+
# Calculate the current time
153+
current_time = datetime.now()
154+
155+
# Determine if we need to create a backup
156+
create_backup = False
157+
158+
if not latest_backup:
159+
# No backups exist yet, so create one
160+
create_backup = True
161+
else:
162+
# Check if the latest backup is older than the minimum interval
163+
backup_age = current_time - latest_backup.created_at
164+
if backup_age.total_seconds() > (min_interval_minutes * 60):
165+
create_backup = True
166+
167+
# Create a backup if needed
168+
if create_backup:
169+
backup = await self.repository.create(source_id, data)
170+
171+
# Manage backups (clean up old ones)
172+
await self.manage_backups(source_id, max_backups)
173+
174+
return backup.to_dict()
175+
176+
return None

src/backend/main.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from fastapi.staticfiles import StaticFiles
1111

1212
from database import init_db
13-
from config import STATIC_DIR, ASSETS_DIR, POSTHOG_API_KEY, POSTHOG_HOST
13+
from config import STATIC_DIR, ASSETS_DIR, POSTHOG_API_KEY, POSTHOG_HOST, redis_client, redis_pool
1414
from dependencies import UserSession, optional_auth
1515
from routers.auth_router import auth_router
1616
from routers.user_router import user_router
@@ -76,14 +76,29 @@ async def load_templates():
7676

7777
@asynccontextmanager
7878
async def lifespan(_: FastAPI):
79+
# Initialize database
7980
await init_db()
8081
print("Database connection established successfully")
8182

83+
# Check Redis connection
84+
try:
85+
redis_client.ping()
86+
print("Redis connection established successfully")
87+
except Exception as e:
88+
print(f"Warning: Redis connection failed: {str(e)}")
89+
8290
# Load all templates from the templates directory
8391
await load_templates()
8492
print("Templates loaded successfully")
8593

8694
yield
95+
96+
# Clean up connections when shutting down
97+
try:
98+
redis_pool.disconnect()
99+
print("Redis connections closed")
100+
except Exception as e:
101+
print(f"Error closing Redis connections: {str(e)}")
87102

88103
app = FastAPI(lifespan=lifespan)
89104

src/backend/routers/pad_router.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,7 @@
66
from dependencies import UserSession, require_auth
77
from database import get_pad_service, get_backup_service, get_template_pad_service
88
from database.service import PadService, BackupService, TemplatePadService
9-
10-
# Constants
11-
MAX_BACKUPS_PER_USER = 10 # Maximum number of backups to keep per user
12-
DEFAULT_PAD_NAME = "Untitled" # Default name for new pads
13-
DEFAULT_TEMPLATE_NAME = "default" # Template name to use when a user doesn't have a pad
14-
9+
from config import MAX_BACKUPS_PER_USER, MIN_INTERVAL_MINUTES, DEFAULT_PAD_NAME, DEFAULT_TEMPLATE_NAME
1510
pad_router = APIRouter()
1611

1712

@@ -39,11 +34,13 @@ async def save_canvas(
3934
pad = user_pads[0] # Use the first pad (assuming one pad per user for now)
4035
await pad_service.update_pad_data(pad["id"], data)
4136

42-
# Create a backup
43-
await backup_service.create_backup(pad["id"], data)
44-
45-
# Manage backups (keep only the most recent ones)
46-
await backup_service.manage_backups(pad["id"], MAX_BACKUPS_PER_USER)
37+
# Create a backup only if needed (if none exist or latest is > 5 min old)
38+
await backup_service.create_backup_if_needed(
39+
source_id=pad["id"],
40+
data=data,
41+
min_interval_minutes=MIN_INTERVAL_MINUTES,
42+
max_backups=MAX_BACKUPS_PER_USER
43+
)
4744

4845
return {"status": "success"}
4946
except Exception as e:
@@ -54,7 +51,8 @@ async def save_canvas(
5451
async def get_canvas(
5552
user: UserSession = Depends(require_auth),
5653
pad_service: PadService = Depends(get_pad_service),
57-
template_pad_service: TemplatePadService = Depends(get_template_pad_service)
54+
template_pad_service: TemplatePadService = Depends(get_template_pad_service),
55+
backup_service: BackupService = Depends(get_backup_service)
5856
):
5957
"""Get canvas data for the authenticated user"""
6058
try:
@@ -63,7 +61,14 @@ async def get_canvas(
6361

6462
if not user_pads:
6563
# Return default canvas if user doesn't have a pad
66-
return await create_pad_from_template(DEFAULT_TEMPLATE_NAME, DEFAULT_PAD_NAME, user, pad_service, template_pad_service)
64+
return await create_pad_from_template(
65+
name=DEFAULT_TEMPLATE_NAME,
66+
display_name=DEFAULT_PAD_NAME,
67+
user=user,
68+
pad_service=pad_service,
69+
template_pad_service=template_pad_service,
70+
backup_service=backup_service
71+
)
6772

6873
# Return the first pad's data (assuming one pad per user for now)
6974
return user_pads[0]["data"]
@@ -77,7 +82,8 @@ async def create_pad_from_template(
7782
display_name: str = DEFAULT_PAD_NAME,
7883
user: UserSession = Depends(require_auth),
7984
pad_service: PadService = Depends(get_pad_service),
80-
template_pad_service: TemplatePadService = Depends(get_template_pad_service)
85+
template_pad_service: TemplatePadService = Depends(get_template_pad_service),
86+
backup_service: BackupService = Depends(get_backup_service)
8187
):
8288
"""Create a new pad from a template"""
8389

@@ -94,6 +100,14 @@ async def create_pad_from_template(
94100
data=template["data"]
95101
)
96102

103+
# Create an initial backup for the new pad
104+
await backup_service.create_backup_if_needed(
105+
source_id=pad["id"],
106+
data=template["data"],
107+
min_interval_minutes=0, # Always create initial backup
108+
max_backups=MAX_BACKUPS_PER_USER
109+
)
110+
97111
return pad
98112
except ValueError as e:
99113
raise HTTPException(status_code=400, detail=str(e))
@@ -105,7 +119,6 @@ async def create_pad_from_template(
105119
async def get_recent_canvas_backups(
106120
limit: int = MAX_BACKUPS_PER_USER,
107121
user: UserSession = Depends(require_auth),
108-
pad_service: PadService = Depends(get_pad_service),
109122
backup_service: BackupService = Depends(get_backup_service)
110123
):
111124
"""Get the most recent canvas backups for the authenticated user"""
@@ -114,18 +127,12 @@ async def get_recent_canvas_backups(
114127
limit = MAX_BACKUPS_PER_USER
115128

116129
try:
117-
# Get user's pads
118-
user_pads = await pad_service.get_pads_by_owner(user.id)
119-
120-
# Get the first pad's ID (assuming one pad per user for now)
121-
pad_id = UUID(user_pads[0]["id"])
122-
123-
# Get backups for the pad
124-
backups_data = await backup_service.get_backups_by_source(pad_id)
130+
# Get backups directly with a single query
131+
backups_data = await backup_service.get_backups_by_user(user.id, limit)
125132

126133
# Format backups to match the expected response format
127134
backups = []
128-
for backup in backups_data[:limit]:
135+
for backup in backups_data:
129136
backups.append({
130137
"id": backup["id"],
131138
"timestamp": backup["created_at"],

src/backend/routers/user_router.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import posthog
55
from fastapi import APIRouter, Depends, HTTPException
66

7-
from config import redis_client, FRONTEND_URL
7+
from config import get_redis_client, FRONTEND_URL
88
from database import get_user_service
99
from database.service import UserService
1010
from dependencies import UserSession, require_admin, require_auth
@@ -92,7 +92,8 @@ async def get_user_count(
9292
_: bool = Depends(require_admin),
9393
):
9494
"""Get the number of active sessions (admin only)"""
95-
session_count = len(redis_client.keys("session:*"))
95+
client = get_redis_client()
96+
session_count = len(client.keys("session:*"))
9697
return {"active_sessions": session_count }
9798

9899

0 commit comments

Comments
 (0)