|
3 | 3 | """ |
4 | 4 |
|
5 | 5 | import os |
6 | | -import asyncio |
7 | | -import subprocess |
8 | | -import time |
9 | | -from typing import AsyncGenerator, Optional |
| 6 | +from typing import AsyncGenerator |
10 | 7 | from urllib.parse import quote_plus as urlquote |
11 | | -from pathlib import Path |
12 | 8 |
|
13 | 9 | from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession |
14 | 10 | from sqlalchemy.orm import sessionmaker |
15 | 11 | from sqlalchemy.schema import CreateSchema |
16 | 12 | from fastapi import Depends |
17 | | -from alembic.config import Config |
18 | 13 |
|
19 | 14 | from .models import Base, SCHEMA_NAME |
20 | 15 |
|
|
40 | 35 | engine, class_=AsyncSession, expire_on_commit=False |
41 | 36 | ) |
42 | 37 |
|
43 | | -async def run_migrations_with_lock(redis_client=None, lock_timeout: int = 120, max_wait_time: int = 300) -> bool: |
44 | | - """ |
45 | | - Run database migrations using Alembic with a Redis distributed lock. |
46 | | - All workers will wait for the migration to complete before proceeding. |
47 | | - |
48 | | - Args: |
49 | | - redis_client: Redis client instance |
50 | | - lock_timeout: How long the lock should be held (seconds) |
51 | | - max_wait_time: Maximum time to wait for migrations to complete (seconds) |
52 | | - |
53 | | - Returns: |
54 | | - bool: True if migrations were run successfully or completed by another instance, False on timeout or error |
55 | | - """ |
56 | | - if redis_client is None: |
57 | | - # Import here to avoid circular imports |
58 | | - from config import get_redis_client |
59 | | - redis_client = get_redis_client() |
60 | | - |
61 | | - # Keys for Redis coordination |
62 | | - lock_name = "alembic_migration_lock" |
63 | | - status_key = "alembic_migration_status" |
64 | | - lock_value = f"instance_{time.time()}" |
65 | | - |
66 | | - # Check if migrations are already completed |
67 | | - migration_status = redis_client.get(status_key) |
68 | | - if migration_status == "completed": |
69 | | - print("Migrations already completed - continuing startup") |
70 | | - return True |
71 | | - |
72 | | - # Try to acquire the lock - non-blocking |
73 | | - lock_acquired = redis_client.set( |
74 | | - lock_name, |
75 | | - lock_value, |
76 | | - nx=True, # Only set if key doesn't exist |
77 | | - ex=lock_timeout # Expiry in seconds |
78 | | - ) |
79 | | - |
80 | | - if lock_acquired: |
81 | | - print("This instance will run migrations") |
82 | | - try: |
83 | | - # Set status to in-progress |
84 | | - redis_client.set(status_key, "in_progress", ex=lock_timeout) |
85 | | - |
86 | | - # Run migrations |
87 | | - success = await run_migrations_subprocess() |
88 | | - |
89 | | - if success: |
90 | | - # Set status to completed with a longer expiry (1 hour) |
91 | | - redis_client.set(status_key, "completed", ex=3600) |
92 | | - print("Migration completed successfully - signaling other instances") |
93 | | - return True |
94 | | - else: |
95 | | - # Set status to failed |
96 | | - redis_client.set(status_key, "failed", ex=3600) |
97 | | - print("Migration failed - signaling other instances") |
98 | | - return False |
99 | | - finally: |
100 | | - # Release the lock only if we're the owner |
101 | | - current_value = redis_client.get(lock_name) |
102 | | - if current_value == lock_value: |
103 | | - redis_client.delete(lock_name) |
104 | | - else: |
105 | | - print("Another instance is running migrations - waiting for completion") |
106 | | - |
107 | | - # Wait for the migration to complete |
108 | | - start_time = time.time() |
109 | | - while time.time() - start_time < max_wait_time: |
110 | | - # Check migration status |
111 | | - status = redis_client.get(status_key) |
112 | | - |
113 | | - if status == "completed": |
114 | | - print("Migrations completed by another instance - continuing startup") |
115 | | - return True |
116 | | - elif status == "failed": |
117 | | - print("Migrations failed in another instance - continuing startup with caution") |
118 | | - return False |
119 | | - elif status is None: |
120 | | - # No status yet, might be a stale lock or not started |
121 | | - # Check if lock exists |
122 | | - if not redis_client.exists(lock_name): |
123 | | - # Lock released but no status - try to acquire the lock ourselves |
124 | | - print("No active migration lock - attempting to acquire") |
125 | | - return await run_migrations_with_lock(redis_client, lock_timeout, max_wait_time) |
126 | | - |
127 | | - # Wait before checking again |
128 | | - await asyncio.sleep(1) |
129 | | - |
130 | | - # Timeout waiting for migration |
131 | | - print(f"Timeout waiting for migrations after {max_wait_time} seconds") |
132 | | - return False |
133 | | - |
134 | | -async def run_migrations_subprocess() -> bool: |
135 | | - """ |
136 | | - Run Alembic migrations using a subprocess |
137 | | - |
138 | | - Returns: |
139 | | - bool: True if migrations were successful, False otherwise |
140 | | - """ |
141 | | - try: |
142 | | - # Get the path to the database directory |
143 | | - db_dir = Path(__file__).parent |
144 | | - |
145 | | - # Create a subprocess to run alembic |
146 | | - process = await asyncio.create_subprocess_exec( |
147 | | - 'alembic', 'upgrade', 'head', |
148 | | - stdout=asyncio.subprocess.PIPE, |
149 | | - stderr=asyncio.subprocess.PIPE, |
150 | | - cwd=str(db_dir) # Run in the database directory |
151 | | - ) |
152 | | - |
153 | | - # Wait for the process to complete with a timeout |
154 | | - try: |
155 | | - stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=60) |
156 | | - |
157 | | - if process.returncode == 0: |
158 | | - print("Database migrations completed successfully") |
159 | | - if stdout: |
160 | | - print(stdout.decode()) |
161 | | - return True |
162 | | - else: |
163 | | - print(f"Migration failed with error code {process.returncode}") |
164 | | - if stderr: |
165 | | - print(stderr.decode()) |
166 | | - return False |
167 | | - |
168 | | - except asyncio.TimeoutError: |
169 | | - print("Migration timed out after 60 seconds") |
170 | | - # Try to terminate the process |
171 | | - process.terminate() |
172 | | - return False |
173 | | - |
174 | | - except Exception as e: |
175 | | - print(f"Migration failed: {str(e)}") |
176 | | - return False |
177 | | - |
178 | | -async def run_migrations() -> None: |
179 | | - """ |
180 | | - Legacy function to run migrations directly (without lock) |
181 | | - This is kept for backward compatibility |
182 | | - """ |
183 | | - await run_migrations_subprocess() |
184 | 38 |
|
185 | 39 | async def init_db() -> None: |
186 | 40 | """Initialize the database with required tables""" |
187 | | - # Only create tables, let Alembic handle schema creation |
188 | | - async with engine.begin() as conn: |
189 | | - await conn.run_sync(Base.metadata.create_all) |
| 41 | + try: |
| 42 | + # Create schema if it doesn't exist |
| 43 | + async with engine.begin() as conn: |
| 44 | + await conn.execute(CreateSchema(SCHEMA_NAME, if_not_exists=True)) |
| 45 | + |
| 46 | + # Create tables |
| 47 | + async with engine.begin() as conn: |
| 48 | + await conn.run_sync(Base.metadata.create_all) |
| 49 | + |
| 50 | + except Exception as e: |
| 51 | + print(f"Error initializing database: {str(e)}") |
| 52 | + raise |
| 53 | + |
190 | 54 |
|
191 | 55 | async def get_session() -> AsyncGenerator[AsyncSession, None]: |
192 | 56 | """Get a database session""" |
|
0 commit comments