|
4 | 4 |
|
5 | 5 | import os |
6 | 6 | import asyncio |
7 | | -from typing import AsyncGenerator |
| 7 | +import subprocess |
| 8 | +import time |
| 9 | +from typing import AsyncGenerator, Optional |
8 | 10 | from urllib.parse import quote_plus as urlquote |
9 | 11 | from pathlib import Path |
10 | 12 |
|
|
13 | 15 | from sqlalchemy.schema import CreateSchema |
14 | 16 | from fastapi import Depends |
15 | 17 | from alembic.config import Config |
16 | | -from alembic import command |
17 | 18 |
|
18 | 19 | from .models import Base, SCHEMA_NAME |
19 | 20 |
|
|
39 | 40 | engine, class_=AsyncSession, expire_on_commit=False |
40 | 41 | ) |
41 | 42 |
|
42 | | -async def run_migrations() -> None: |
43 | | - """Run database migrations using Alembic""" |
44 | | - # Get the path to the alembic.ini file |
45 | | - alembic_ini_path = Path(__file__).parent / "alembic.ini" |
def _redis_text(value):
    """Return a Redis reply as ``str`` so it can be compared with literals.

    A client created without ``decode_responses=True`` returns ``bytes``;
    comparing those with ``"completed"`` would always be False. ``None``
    (missing key) and already-decoded strings pass through unchanged.
    """
    if isinstance(value, (bytes, bytearray)):
        return value.decode("utf-8", errors="replace")
    return value


def _release_migration_lock(redis_client, lock_name: str, lock_value: str) -> None:
    """Delete the lock key, but only while it still holds our own token."""
    # NOTE: get-then-delete is not atomic. If the lock expires and is taken
    # by another instance between the two calls, that instance's lock is
    # removed. The window is tiny compared with the lock expiry.
    if _redis_text(redis_client.get(lock_name)) == lock_value:
        redis_client.delete(lock_name)


async def run_migrations_with_lock(redis_client=None, lock_timeout: int = 120, max_wait_time: int = 300) -> bool:
    """
    Run database migrations using Alembic with a Redis distributed lock.

    The instance that wins the lock runs the migrations and publishes the
    outcome under a status key; every other instance polls that key once per
    second until it reads "completed" or "failed", or until max_wait_time
    has elapsed.

    Args:
        redis_client: Redis client instance. When None, one is obtained from
            ``config.get_redis_client()``. Replies may be ``str`` or ``bytes``.
        lock_timeout: Expiry (seconds) of the lock key and of the
            "in_progress" status.
        max_wait_time: Total time (seconds) a non-owner waits for the result.
            The budget covers the whole call, including lock re-acquisition
            attempts.

    Returns:
        bool: True if migrations were run successfully or completed by another
        instance, False on failure or timeout.
    """
    import uuid  # local import: only needed to build the lock token

    if redis_client is None:
        # Import here to avoid circular imports
        from config import get_redis_client
        redis_client = get_redis_client()

    # Keys for Redis coordination
    lock_name = "alembic_migration_lock"
    status_key = "alembic_migration_status"
    # A random token cannot collide between workers that start at the same
    # instant, unlike a wall-clock timestamp.
    lock_value = f"instance_{uuid.uuid4().hex}"

    # One deadline for the whole call; monotonic so clock adjustments cannot
    # stretch or shrink the wait.
    deadline = time.monotonic() + max_wait_time
    waiting_announced = False

    while True:
        # Check if migrations are already completed
        if _redis_text(redis_client.get(status_key)) == "completed":
            print("Migrations already completed - continuing startup")
            return True

        # Try to acquire the lock - non-blocking (SET NX with expiry)
        lock_acquired = redis_client.set(
            lock_name,
            lock_value,
            nx=True,
            ex=lock_timeout,
        )

        if lock_acquired:
            print("This instance will run migrations")
            try:
                # Tell waiters that work has started
                redis_client.set(status_key, "in_progress", ex=lock_timeout)

                success = await run_migrations_subprocess()

                if success:
                    # Keep the result visible for an hour so late starters skip
                    redis_client.set(status_key, "completed", ex=3600)
                    print("Migration completed successfully - signaling other instances")
                    return True

                redis_client.set(status_key, "failed", ex=3600)
                print("Migration failed - signaling other instances")
                return False
            finally:
                # Release the lock only if we're still the owner
                _release_migration_lock(redis_client, lock_name, lock_value)

        if not waiting_announced:
            print("Another instance is running migrations - waiting for completion")
            waiting_announced = True

        # Wait for the lock owner to publish a result
        retry_lock = False
        while time.monotonic() < deadline:
            status = _redis_text(redis_client.get(status_key))

            if status == "completed":
                print("Migrations completed by another instance - continuing startup")
                return True
            if status == "failed":
                print("Migrations failed in another instance - continuing startup with caution")
                return False
            if status is None and not redis_client.exists(lock_name):
                # Lock released (or expired) without a status: nobody is
                # migrating, so go back and try to take the lock ourselves.
                print("No active migration lock - attempting to acquire")
                retry_lock = True
                break

            # Wait before checking again
            await asyncio.sleep(1)

        if not retry_lock:
            print(f"Timeout waiting for migrations after {max_wait_time} seconds")
            return False
| 133 | + |
async def _terminate_and_reap(process, grace_period: float = 5) -> None:
    """Stop a child process and wait for it so it is not left unreaped."""
    try:
        process.terminate()
    except ProcessLookupError:
        return  # already exited
    try:
        await asyncio.wait_for(process.wait(), timeout=grace_period)
    except asyncio.TimeoutError:
        # The child ignored the terminate request; force it down.
        try:
            process.kill()
        except ProcessLookupError:
            return
        await process.wait()


async def run_migrations_subprocess(timeout: float = 60) -> bool:
    """
    Run Alembic migrations using a subprocess

    Executes ``alembic upgrade head`` with the directory containing this
    module as the working directory, and prints the captured output.

    Args:
        timeout: Seconds to wait for the command before it is stopped.

    Returns:
        bool: True if the command exited with code 0, False on a non-zero
        exit code, on timeout, or when the process could not be started.
    """
    try:
        # Get the path to the database directory
        db_dir = Path(__file__).parent

        # Create a subprocess to run alembic
        process = await asyncio.create_subprocess_exec(
            'alembic', 'upgrade', 'head',
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(db_dir),  # Run in the database directory
        )

        # Wait for the process to complete with a timeout
        try:
            stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
        except asyncio.TimeoutError:
            print(f"Migration timed out after {timeout} seconds")
            # Stop the child and wait for it, escalating to kill if needed
            await _terminate_and_reap(process)
            return False

        if process.returncode == 0:
            print("Database migrations completed successfully")
            if stdout:
                # errors="replace": undecodable output must not turn a
                # successful migration into a reported failure
                print(stdout.decode(errors="replace"))
            return True

        print(f"Migration failed with error code {process.returncode}")
        if stderr:
            print(stderr.decode(errors="replace"))
        return False

    except Exception as e:
        # Boundary handler: e.g. the alembic executable is missing. Callers
        # rely on a boolean result rather than an exception.
        print(f"Migration failed: {str(e)}")
        return False
| 177 | + |
async def run_migrations() -> None:
    """
    Run the Alembic upgrade directly, without the Redis lock.

    Backward-compatibility shim around run_migrations_subprocess(). The
    success flag it returns is discarded, so a failed migration does not
    raise here.
    """
    _succeeded = await run_migrations_subprocess()
    return None
73 | 184 |
|
74 | 185 | async def init_db() -> None: |
75 | 186 | """Initialize the database with required tables""" |
|
0 commit comments