diff --git a/src/backend/coder.py b/src/backend/coder.py index 0927ddd..83b31ca 100644 --- a/src/backend/coder.py +++ b/src/backend/coder.py @@ -1,26 +1,22 @@ import os import requests -from dotenv import load_dotenv +from config import CODER_API_KEY, CODER_URL, CODER_TEMPLATE_ID, CODER_DEFAULT_ORGANIZATION, CODER_WORKSPACE_NAME class CoderAPI: """ - A class for interacting with the Coder API using credentials from .env file + A class for interacting with the Coder API using credentials from config """ def __init__(self): - # Load environment variables from .env file - load_dotenv() + # Get configuration from config + self.api_key = CODER_API_KEY + self.coder_url = CODER_URL + self.template_id = CODER_TEMPLATE_ID + self.default_organization_id = CODER_DEFAULT_ORGANIZATION - # Get configuration from environment variables - self.api_key = os.getenv("CODER_API_KEY") - self.coder_url = os.getenv("CODER_URL") - self.user_id = os.getenv("USER_ID") - self.template_id = os.getenv("CODER_TEMPLATE_ID") - self.default_organization_id = os.getenv("CODER_DEFAULT_ORGANIZATION") - - # Check if required environment variables are set + # Check if required configuration variables are set if not self.api_key or not self.coder_url: - raise ValueError("CODER_API_KEY and CODER_URL must be set in .env file") + raise ValueError("CODER_API_KEY and CODER_URL must be set in environment variables") # Set up common headers for API requests self.headers = { @@ -56,9 +52,9 @@ def create_workspace(self, user_id, parameter_values=None): template_id = self.template_id if not template_id: - raise ValueError("template_id must be provided or TEMPLATE_ID must be set in .env") + raise ValueError("template_id must be provided or TEMPLATE_ID must be set in environment variables") - name = os.getenv("CODER_WORKSPACE_NAME", "ubuntu") + name = CODER_WORKSPACE_NAME # Prepare the request data data = { @@ -201,7 +197,7 @@ def get_workspace_status_for_user(self, username): Returns: dict: Workspace status data if 
found, None otherwise """ - workspace_name = os.getenv("CODER_WORKSPACE_NAME", "ubuntu") + workspace_name = CODER_WORKSPACE_NAME endpoint = f"{self.coder_url}/api/v2/users/{username}/workspace/{workspace_name}" response = requests.get(endpoint, headers=self.headers) @@ -282,4 +278,4 @@ def stop_workspace(self, workspace_id): headers['Content-Type'] = 'application/json' response = requests.post(endpoint, headers=headers, json=data) response.raise_for_status() - return response.json() \ No newline at end of file + return response.json() diff --git a/src/backend/config.py b/src/backend/config.py index b83ee9a..bce429e 100644 --- a/src/backend/config.py +++ b/src/backend/config.py @@ -3,44 +3,84 @@ import time import httpx import redis +from redis import ConnectionPool, Redis import jwt +from jwt.jwks_client import PyJWKClient from typing import Optional, Dict, Any, Tuple from dotenv import load_dotenv +# Load environment variables once load_dotenv() +# ===== Application Configuration ===== STATIC_DIR = os.getenv("STATIC_DIR") ASSETS_DIR = os.getenv("ASSETS_DIR") +FRONTEND_URL = os.getenv('FRONTEND_URL') -OIDC_CONFIG = { - 'client_id': os.getenv('OIDC_CLIENT_ID'), - 'client_secret': os.getenv('OIDC_CLIENT_SECRET'), - 'server_url': os.getenv('OIDC_SERVER_URL'), - 'realm': os.getenv('OIDC_REALM'), - 'redirect_uri': os.getenv('REDIRECT_URI'), - 'frontend_url': os.getenv('FRONTEND_URL') -} - -# Redis connection -redis_client = redis.Redis( - host=os.getenv('REDIS_HOST', 'localhost'), - password=os.getenv('REDIS_PASSWORD', None), - port=int(os.getenv('REDIS_PORT', 6379)), +MAX_BACKUPS_PER_USER = 10 # Maximum number of backups to keep per user +MIN_INTERVAL_MINUTES = 5 # Minimum interval in minutes between backups +DEFAULT_PAD_NAME = "Untitled" # Default name for new pads +DEFAULT_TEMPLATE_NAME = "default" # Template name to use when a user doesn't have a pad + +# ===== PostHog Configuration ===== +POSTHOG_API_KEY = os.getenv("VITE_PUBLIC_POSTHOG_KEY") +POSTHOG_HOST = 
os.getenv("VITE_PUBLIC_POSTHOG_HOST") + +# ===== OIDC Configuration ===== +OIDC_CLIENT_ID = os.getenv('OIDC_CLIENT_ID') +OIDC_CLIENT_SECRET = os.getenv('OIDC_CLIENT_SECRET') +OIDC_SERVER_URL = os.getenv('OIDC_SERVER_URL') +OIDC_REALM = os.getenv('OIDC_REALM') +OIDC_REDIRECT_URI = os.getenv('REDIRECT_URI') + +# ===== Redis Configuration ===== +REDIS_HOST = os.getenv('REDIS_HOST', 'localhost') +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD', None) +REDIS_PORT = int(os.getenv('REDIS_PORT', 6379)) + +# Create a Redis connection pool +redis_pool = ConnectionPool( + host=REDIS_HOST, + password=REDIS_PASSWORD, + port=REDIS_PORT, db=0, - decode_responses=True + decode_responses=True, + max_connections=10, # Adjust based on your application's needs + socket_timeout=5.0, + socket_connect_timeout=1.0, + health_check_interval=30 ) +# Create a Redis client that uses the connection pool +redis_client = Redis(connection_pool=redis_pool) + +def get_redis_client(): + """Get a Redis client from the connection pool""" + return Redis(connection_pool=redis_pool) + +# ===== Coder API Configuration ===== +CODER_API_KEY = os.getenv("CODER_API_KEY") +CODER_URL = os.getenv("CODER_URL") +CODER_TEMPLATE_ID = os.getenv("CODER_TEMPLATE_ID") +CODER_DEFAULT_ORGANIZATION = os.getenv("CODER_DEFAULT_ORGANIZATION") +CODER_WORKSPACE_NAME = os.getenv("CODER_WORKSPACE_NAME", "ubuntu") + +# Cache for JWKS client +_jwks_client = None + # Session management functions def get_session(session_id: str) -> Optional[Dict[str, Any]]: """Get session data from Redis""" - session_data = redis_client.get(f"session:{session_id}") + client = get_redis_client() + session_data = client.get(f"session:{session_id}") if session_data: return json.loads(session_data) return None def set_session(session_id: str, data: Dict[str, Any], expiry: int) -> None: """Store session data in Redis with expiry in seconds""" - redis_client.setex( + client = get_redis_client() + client.setex( f"session:{session_id}", expiry, json.dumps(data) 
@@ -48,49 +88,47 @@ def set_session(session_id: str, data: Dict[str, Any], expiry: int) -> None: def delete_session(session_id: str) -> None: """Delete session data from Redis""" - redis_client.delete(f"session:{session_id}") - -provisioning_times = {} + client = get_redis_client() + client.delete(f"session:{session_id}") def get_auth_url() -> str: """Generate the authentication URL for Keycloak login""" - auth_url = f"{OIDC_CONFIG['server_url']}/realms/{OIDC_CONFIG['realm']}/protocol/openid-connect/auth" + auth_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/auth" params = { - 'client_id': OIDC_CONFIG['client_id'], + 'client_id': OIDC_CLIENT_ID, 'response_type': 'code', - 'redirect_uri': OIDC_CONFIG['redirect_uri'], + 'redirect_uri': OIDC_REDIRECT_URI, 'scope': 'openid profile email' } return f"{auth_url}?{'&'.join(f'{k}={v}' for k,v in params.items())}" def get_token_url() -> str: """Get the token endpoint URL""" - return f"{OIDC_CONFIG['server_url']}/realms/{OIDC_CONFIG['realm']}/protocol/openid-connect/token" + return f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/token" def is_token_expired(token_data: Dict[str, Any], buffer_seconds: int = 30) -> bool: - """ - Check if the access token is expired or about to expire - - Args: - token_data: The token data containing the access token - buffer_seconds: Buffer time in seconds to refresh token before it actually expires - - Returns: - bool: True if token is expired or about to expire, False otherwise - """ if not token_data or 'access_token' not in token_data: return True try: - # Decode the JWT token without verification to get expiration time - decoded = jwt.decode(token_data['access_token'], options={"verify_signature": False}) + # Get the signing key + jwks_client = get_jwks_client() + signing_key = jwks_client.get_signing_key_from_jwt(token_data['access_token']) - # Get expiration time from token - exp_time = decoded.get('exp', 0) + # Decode with verification + decoded = 
jwt.decode( + token_data['access_token'], + signing_key.key, + algorithms=["RS256"], # Common algorithm for OIDC + audience=OIDC_CLIENT_ID, + ) - # Check if token is expired or about to expire (with buffer) + # Check expiration + exp_time = decoded.get('exp', 0) current_time = time.time() return current_time + buffer_seconds >= exp_time + except jwt.ExpiredSignatureError: + return True except Exception as e: print(f"Error checking token expiration: {str(e)}") return True @@ -115,8 +153,8 @@ async def refresh_token(session_id: str, token_data: Dict[str, Any]) -> Tuple[bo get_token_url(), data={ 'grant_type': 'refresh_token', - 'client_id': OIDC_CONFIG['client_id'], - 'client_secret': OIDC_CONFIG['client_secret'], + 'client_id': OIDC_CLIENT_ID, + 'client_secret': OIDC_CLIENT_SECRET, 'refresh_token': token_data['refresh_token'] } ) @@ -136,3 +174,11 @@ async def refresh_token(session_id: str, token_data: Dict[str, Any]) -> Tuple[bo except Exception as e: print(f"Error refreshing token: {str(e)}") return False, token_data + +def get_jwks_client(): + """Get or create a PyJWKClient for token verification""" + global _jwks_client + if _jwks_client is None: + jwks_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/certs" + _jwks_client = PyJWKClient(jwks_url) + return _jwks_client diff --git a/src/backend/database/__init__.py b/src/backend/database/__init__.py new file mode 100644 index 0000000..6c5c0e7 --- /dev/null +++ b/src/backend/database/__init__.py @@ -0,0 +1,31 @@ +""" +Database module for the application. + +This module provides access to all database components used in the application. 
+""" + +from .database import ( + init_db, + get_session, + get_user_repository, + get_pad_repository, + get_backup_repository, + get_template_pad_repository, + get_user_service, + get_pad_service, + get_backup_service, + get_template_pad_service +) + +__all__ = [ + 'init_db', + 'get_session', + 'get_user_repository', + 'get_pad_repository', + 'get_backup_repository', + 'get_template_pad_repository', + 'get_user_service', + 'get_pad_service', + 'get_backup_service', + 'get_template_pad_service', +] diff --git a/src/backend/database/alembic.ini b/src/backend/database/alembic.ini new file mode 100644 index 0000000..f36d6dd --- /dev/null +++ b/src/backend/database/alembic.ini @@ -0,0 +1,123 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = migrations + +# Use a more descriptive file template that includes date and time +file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. 
+# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. 
+version_path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# The SQLAlchemy connection URL is set in env.py from environment variables +sqlalchemy.url = postgresql://postgres:postgres@localhost/pad + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/src/backend/database/database.py b/src/backend/database/database.py new file mode 100644 index 0000000..40c6d13 --- /dev/null +++ b/src/backend/database/database.py @@ -0,0 +1,129 @@ +""" +Database connection and session management. 
+""" + +import os +import asyncio +from typing import AsyncGenerator +from urllib.parse import quote_plus as urlquote +from pathlib import Path + +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker +from sqlalchemy.schema import CreateSchema +from fastapi import Depends +from alembic.config import Config +from alembic import command + +from .models import Base, SCHEMA_NAME + +from dotenv import load_dotenv + +load_dotenv() + +# PostgreSQL connection configuration +DB_USER = os.getenv('POSTGRES_USER', 'postgres') +DB_PASSWORD = os.getenv('POSTGRES_PASSWORD', 'postgres') +DB_NAME = os.getenv('POSTGRES_DB', 'pad') +DB_HOST = os.getenv('POSTGRES_HOST', 'localhost') +DB_PORT = os.getenv('POSTGRES_PORT', '5432') + +# SQLAlchemy async database URL +DATABASE_URL = f"postgresql+asyncpg://{DB_USER}:{urlquote(DB_PASSWORD)}@{DB_HOST}:{DB_PORT}/{DB_NAME}" + +# Create async engine +engine = create_async_engine(DATABASE_URL, echo=False) + +# Create async session factory +async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False +) + +async def run_migrations() -> None: + """Run database migrations using Alembic""" + # Get the path to the alembic.ini file + alembic_ini_path = Path(__file__).parent / "alembic.ini" + + # Create Alembic configuration + alembic_cfg = Config(str(alembic_ini_path)) + + # Set the script_location to the correct path + # This ensures Alembic finds the migrations directory + alembic_cfg.set_main_option('script_location', str(Path(__file__).parent / "migrations")) + + # Define a function to run in a separate thread + def run_upgrade(): + # Import the command module here to avoid import issues + from alembic import command + + # Set attributes that env.py might need + import sys + from pathlib import Path + + # Add the database directory to sys.path + db_dir = Path(__file__).parent + if str(db_dir) not in sys.path: + sys.path.insert(0, str(db_dir)) + + # Run the upgrade 
command + command.upgrade(alembic_cfg, "head") + + # Run the migrations in a separate thread to avoid blocking the event loop + await asyncio.to_thread(run_upgrade) + +async def init_db() -> None: + """Initialize the database with required tables""" + # Create schema and tables + async with engine.begin() as conn: + await conn.execute(CreateSchema(SCHEMA_NAME, if_not_exists=True)) + await conn.run_sync(Base.metadata.create_all) + +async def get_session() -> AsyncGenerator[AsyncSession, None]: + """Get a database session""" + async with async_session() as session: + try: + yield session + finally: + await session.close() + +# Dependency for getting repositories +async def get_user_repository(session: AsyncSession = Depends(get_session)): + """Get a user repository""" + from .repository import UserRepository + return UserRepository(session) + +async def get_pad_repository(session: AsyncSession = Depends(get_session)): + """Get a pad repository""" + from .repository import PadRepository + return PadRepository(session) + +async def get_backup_repository(session: AsyncSession = Depends(get_session)): + """Get a backup repository""" + from .repository import BackupRepository + return BackupRepository(session) + +async def get_template_pad_repository(session: AsyncSession = Depends(get_session)): + """Get a template pad repository""" + from .repository import TemplatePadRepository + return TemplatePadRepository(session) + +# Dependency for getting services +async def get_user_service(session: AsyncSession = Depends(get_session)): + """Get a user service""" + from .service import UserService + return UserService(session) + +async def get_pad_service(session: AsyncSession = Depends(get_session)): + """Get a pad service""" + from .service import PadService + return PadService(session) + +async def get_backup_service(session: AsyncSession = Depends(get_session)): + """Get a backup service""" + from .service import BackupService + return BackupService(session) + +async def 
get_template_pad_service(session: AsyncSession = Depends(get_session)): + """Get a template pad service""" + from .service import TemplatePadService + return TemplatePadService(session) diff --git a/src/backend/database/migrations/env.py b/src/backend/database/migrations/env.py new file mode 100644 index 0000000..d753507 --- /dev/null +++ b/src/backend/database/migrations/env.py @@ -0,0 +1,106 @@ +from logging.config import fileConfig +import os +import sys +from pathlib import Path + +from sqlalchemy import engine_from_config +from sqlalchemy import pool +from sqlalchemy.engine import URL + +from alembic import context +from dotenv import load_dotenv + +# Add the parent directory to sys.path +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + +# Load environment variables from .env file +load_dotenv() + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# Import the Base metadata from the models +# We need to handle imports differently to avoid module not found errors +import importlib.util +import os + +# Get the absolute path to the models module +models_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "models", "__init__.py") + +# Load the module dynamically +spec = importlib.util.spec_from_file_location("models", models_path) +models = importlib.util.module_from_spec(spec) +spec.loader.exec_module(models) + +# Get Base and SCHEMA_NAME from the loaded module +Base = models.Base +SCHEMA_NAME = models.SCHEMA_NAME +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. 
+ + +# Get database connection details from environment variables +DB_USER = os.getenv('POSTGRES_USER', 'postgres') +DB_PASSWORD = os.getenv('POSTGRES_PASSWORD', 'postgres') +DB_NAME = os.getenv('POSTGRES_DB', 'pad') +DB_HOST = os.getenv('POSTGRES_HOST', 'localhost') +DB_PORT = os.getenv('POSTGRES_PORT', '5432') + +# Override sqlalchemy.url in alembic.ini +config.set_main_option('sqlalchemy.url', f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}") + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode.""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + include_schemas=True, + version_table_schema=SCHEMA_NAME, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + include_schemas=True, + version_table_schema=SCHEMA_NAME + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/src/backend/database/migrations/script.py.mako b/src/backend/database/migrations/script.py.mako new file mode 100644 index 0000000..480b130 --- /dev/null +++ b/src/backend/database/migrations/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/src/backend/database/migrations/versions/2025_05_02_2055-migrate_canvas_data.py b/src/backend/database/migrations/versions/2025_05_02_2055-migrate_canvas_data.py new file mode 100644 index 0000000..f6e9568 --- /dev/null +++ b/src/backend/database/migrations/versions/2025_05_02_2055-migrate_canvas_data.py @@ -0,0 +1,192 @@ +"""Migrate canvas_data and canvas_backups to new schema + +Revision ID: migrate_canvas_data +Revises: +Create Date: 2025-05-02 20:55:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID, JSONB +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm import Session +import uuid +from datetime import datetime + +# revision identifiers, used by Alembic. 
+revision = 'migrate_canvas_data' +down_revision = None +branch_labels = None +depends_on = None + +# Import the schema name from the models using dynamic import +import importlib.util +import os + +# Get the absolute path to the base_model module +base_model_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models", "base_model.py") + +# Load the module dynamically +spec = importlib.util.spec_from_file_location("base_model", base_model_path) +base_model = importlib.util.module_from_spec(spec) +spec.loader.exec_module(base_model) + +# Get SCHEMA_NAME from the loaded module +SCHEMA_NAME = base_model.SCHEMA_NAME + +def upgrade() -> None: + """Migrate data from old tables to new schema""" + + # Create a connection to execute raw SQL + connection = op.get_bind() + + # Define tables for direct SQL operations + metadata = sa.MetaData() + + # Define the old tables in the public schema + canvas_data = sa.Table( + 'canvas_data', + metadata, + sa.Column('user_id', UUID(as_uuid=True), primary_key=True), + sa.Column('data', JSONB), + schema='public' + ) + + canvas_backups = sa.Table( + 'canvas_backups', + metadata, + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('user_id', UUID(as_uuid=True)), + sa.Column('canvas_data', JSONB), + sa.Column('timestamp', sa.DateTime), + schema='public' + ) + + # Define the new tables in the pad_ws schema with all required columns + users = sa.Table( + 'users', + metadata, + sa.Column('id', UUID(as_uuid=True), primary_key=True), + sa.Column('username', sa.String(254)), + sa.Column('email', sa.String(254)), + sa.Column('email_verified', sa.Boolean), + sa.Column('name', sa.String(254)), + sa.Column('given_name', sa.String(254)), + sa.Column('family_name', sa.String(254)), + sa.Column('roles', JSONB), + schema=SCHEMA_NAME + ) + + pads = sa.Table( + 'pads', + metadata, + sa.Column('id', UUID(as_uuid=True), primary_key=True), + sa.Column('owner_id', UUID(as_uuid=True)), + 
sa.Column('display_name', sa.String(100)), + sa.Column('data', JSONB), + schema=SCHEMA_NAME + ) + + backups = sa.Table( + 'backups', + metadata, + sa.Column('id', UUID(as_uuid=True), primary_key=True), + sa.Column('source_id', UUID(as_uuid=True)), + sa.Column('data', JSONB), + sa.Column('created_at', sa.DateTime), + schema=SCHEMA_NAME + ) + + # Create a session for ORM operations + session = Session(connection) + + try: + # Step 1: Get all canvas_data records + canvas_data_records = session.execute(sa.select(canvas_data)).fetchall() + + # Dictionary to store user_id -> pad_id mapping for later use with backups + user_pad_mapping = {} + + # Step 2: For each canvas_data record, create a new pad + for record in canvas_data_records: + user_id = record.user_id + + # Check if the user exists in the new schema + user_exists = session.execute( + sa.select(users).where(users.c.id == user_id) + ).fetchone() + + if not user_exists: + print(f"User {user_id} not found in new schema, creating with placeholder data") + # Create a new user with placeholder data + # The real data will be updated when the user accesses the /me route + session.execute( + users.insert().values( + id=user_id, + username=f"migrated_user_{user_id}", + email=f"migrated_{user_id}@example.com", + email_verified=False, + name="Migrated User", + given_name="Migrated", + family_name="User", + roles=[], + ) + ) + + # Generate a new UUID for the pad + pad_id = uuid.uuid4() + + # Store the mapping for later use + user_pad_mapping[user_id] = pad_id + + # Insert the pad record + session.execute( + pads.insert().values( + id=pad_id, + owner_id=user_id, + display_name="Untitled", + data=record.data, + ) + ) + + # Step 3: Get all canvas_backups records + canvas_backup_records = session.execute(sa.select(canvas_backups)).fetchall() + + # Step 4: For each canvas_backup record, create a new backup + for record in canvas_backup_records: + user_id = record.user_id + + # Skip if we don't have a pad for this user + if 
user_id not in user_pad_mapping: + print(f"Warning: No pad found for user {user_id}, skipping backup") + continue + + pad_id = user_pad_mapping[user_id] + + # Insert the backup record + session.execute( + backups.insert().values( + id=uuid.uuid4(), + source_id=pad_id, + data=record.canvas_data, # Note: using canvas_data field from the record + created_at=record.timestamp, + ) + ) + + # Commit the transaction + session.commit() + + print(f"Migration complete: {len(canvas_data_records)} pads and {len(canvas_backup_records)} backups migrated") + + except Exception as e: + session.rollback() + print(f"Error during migration: {e}") + raise + finally: + session.close() + + +def downgrade() -> None: + """Downgrade is not supported for this migration""" + print("Downgrade is not supported for this data migration") diff --git a/src/backend/database/models/__init__.py b/src/backend/database/models/__init__.py new file mode 100644 index 0000000..41fe372 --- /dev/null +++ b/src/backend/database/models/__init__.py @@ -0,0 +1,20 @@ +""" +Database models for the application. + +This module provides access to all database models used in the application. 
+""" + +from .base_model import Base, BaseModel, SCHEMA_NAME +from .user_model import UserModel +from .pad_model import PadModel, TemplatePadModel +from .backup_model import BackupModel + +__all__ = [ + 'Base', + 'BaseModel', + 'UserModel', + 'PadModel', + 'BackupModel', + 'TemplatePadModel', + 'SCHEMA_NAME', +] diff --git a/src/backend/database/models/backup_model.py b/src/backend/database/models/backup_model.py new file mode 100644 index 0000000..5e2f7d3 --- /dev/null +++ b/src/backend/database/models/backup_model.py @@ -0,0 +1,42 @@ +from typing import Dict, Any, TYPE_CHECKING + +from sqlalchemy import Column, ForeignKey, Index, UUID +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import relationship, Mapped + +from .base_model import Base, BaseModel, SCHEMA_NAME + +if TYPE_CHECKING: + from .pad_model import PadModel + +class BackupModel(Base, BaseModel): + """Model for backups table in app schema""" + __tablename__ = "backups" + __table_args__ = ( + Index("ix_backups_source_id", "source_id"), + Index("ix_backups_created_at", "created_at"), + {"schema": SCHEMA_NAME} + ) + + # Backup-specific fields + source_id = Column( + UUID(as_uuid=True), + ForeignKey(f"{SCHEMA_NAME}.pads.id", ondelete="CASCADE"), + nullable=False + ) + data = Column(JSONB, nullable=False) + + # Relationships + pad: Mapped["PadModel"] = relationship("PadModel", back_populates="backups") + + def __repr__(self) -> str: + return f"" + + def to_dict(self) -> Dict[str, Any]: + """Convert model instance to dictionary with additional fields""" + result = super().to_dict() + # Convert data to dict if it's not already + if isinstance(result["data"], str): + import json + result["data"] = json.loads(result["data"]) + return result diff --git a/src/backend/database/models/base_model.py b/src/backend/database/models/base_model.py new file mode 100644 index 0000000..87f3d8c --- /dev/null +++ b/src/backend/database/models/base_model.py @@ -0,0 +1,28 @@ +from uuid import uuid4 +from 
# Reconstructed from diff: src/backend/database/models/base_model.py
from typing import Any, Dict
from uuid import uuid4

from sqlalchemy import Column, DateTime, UUID, func, MetaData
from sqlalchemy.orm import DeclarativeMeta, declarative_base

# Define schema name in a central location so every model agrees on it
SCHEMA_NAME = "pad_ws"

# Create metadata bound to the schema
metadata = MetaData(schema=SCHEMA_NAME)

# Create a single shared Base for all models with the schema-aware metadata
Base: DeclarativeMeta = declarative_base(metadata=metadata)


class BaseModel:
    """Mixin providing a UUID primary key plus created/updated timestamps."""

    # Primary key using UUID, generated client-side by default
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4, index=True)

    # Server-side timestamps; updated_at is refreshed on every UPDATE
    created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict of this row's column values."""
        return {c.name: getattr(self, c.name) for c in self.__table__.columns}


# Reconstructed from diff: src/backend/database/models/pad_model.py
import json
from typing import List, TYPE_CHECKING

from sqlalchemy import String, ForeignKey, Index
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship, Mapped

if TYPE_CHECKING:
    from .backup_model import BackupModel
    from .user_model import UserModel


class PadModel(Base, BaseModel):
    """Model for pads table in app schema."""

    __tablename__ = "pads"
    __table_args__ = (
        Index("ix_pads_owner_id", "owner_id"),
        Index("ix_pads_display_name", "display_name"),
        {"schema": SCHEMA_NAME},
    )

    # Pad-specific fields
    owner_id = Column(
        UUID(as_uuid=True),
        ForeignKey(f"{SCHEMA_NAME}.users.id", ondelete="CASCADE"),
        nullable=False,
    )
    display_name = Column(String(100), nullable=False)
    data = Column(JSONB, nullable=False)

    # Relationships
    owner: Mapped["UserModel"] = relationship("UserModel", back_populates="pads")
    backups: Mapped[List["BackupModel"]] = relationship(
        "BackupModel",
        back_populates="pad",
        cascade="all, delete-orphan",
        lazy="selectin",
    )

    def __repr__(self) -> str:
        # BUG FIX: __repr__ previously returned an empty f-string.
        return f"<PadModel(id={self.id}, display_name={self.display_name!r})>"

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dict, decoding `data` when the driver returned a JSON string."""
        result = super().to_dict()
        if isinstance(result["data"], str):
            result["data"] = json.loads(result["data"])
        return result


class TemplatePadModel(Base, BaseModel):
    """Model for template pads table in app schema."""

    __tablename__ = "template_pads"
    __table_args__ = (
        Index("ix_template_pads_display_name", "display_name"),
        Index("ix_template_pads_name", "name"),
        {"schema": SCHEMA_NAME},
    )

    name = Column(String(100), nullable=False, unique=True)
    display_name = Column(String(100), nullable=False)
    data = Column(JSONB, nullable=False)

    def __repr__(self) -> str:
        # BUG FIX: __repr__ previously returned an empty f-string.
        return f"<TemplatePadModel(name={self.name!r})>"


# Reconstructed from diff: src/backend/database/models/user_model.py
from sqlalchemy import Boolean


class UserModel(Base, BaseModel):
    """Model for users table in app schema; IDs are supplied by Keycloak."""

    __tablename__ = "users"
    __table_args__ = (
        Index("ix_users_username", "username"),
        Index("ix_users_email", "email"),
        {"schema": SCHEMA_NAME},
    )

    # Override the default id column: the UUID comes from Keycloak, not uuid4()
    id = Column(UUID(as_uuid=True), primary_key=True)

    # User-specific fields
    username = Column(String(254), nullable=False, unique=True)
    email = Column(String(254), nullable=False, unique=True)
    email_verified = Column(Boolean, nullable=False, default=False)
    name = Column(String(254), nullable=True)
    given_name = Column(String(254), nullable=True)
    family_name = Column(String(254), nullable=True)
    # BUG FIX: default=[] was a shared mutable default; use the `list` callable
    # so each new row gets its own fresh list.
    roles = Column(JSONB, nullable=False, default=list)

    # Relationships
    pads: Mapped[List["PadModel"]] = relationship(
        "PadModel",
        back_populates="owner",
        cascade="all, delete-orphan",
        lazy="selectin",
    )

    def __repr__(self) -> str:
        # BUG FIX: __repr__ previously returned an empty f-string.
        return f"<UserModel(id={self.id}, username={self.username!r})>"


# Reconstructed from diff: src/backend/database/repository/__init__.py
"""
Repository module for database operations.

This module provides access to all repositories used for database operations.
"""

from .user_repository import UserRepository
from .pad_repository import PadRepository
from .backup_repository import BackupRepository
from .template_pad_repository import TemplatePadRepository

__all__ = [
    'UserRepository',
    'PadRepository',
    'BackupRepository',
    'TemplatePadRepository',
]
# Reconstructed from diff: src/backend/database/repository/backup_repository.py
"""
Backup repository for database operations related to backups.
"""

from typing import List, Optional, Dict, Any
from uuid import UUID
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import delete, func  # FIX: dropped unused `join` import

from ..models import BackupModel, PadModel


class BackupRepository:
    """Repository for backup-related database operations."""

    def __init__(self, session: AsyncSession):
        """Initialize the repository with a database session."""
        self.session = session

    async def create(self, source_id: UUID, data: Dict[str, Any]) -> BackupModel:
        """Create and persist a new backup for the given source pad."""
        backup = BackupModel(source_id=source_id, data=data)
        self.session.add(backup)
        await self.session.commit()
        await self.session.refresh(backup)
        return backup

    async def get_by_id(self, backup_id: UUID) -> Optional[BackupModel]:
        """Get a backup by ID, or None when it does not exist."""
        stmt = select(BackupModel).where(BackupModel.id == backup_id)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def get_by_source(self, source_id: UUID) -> List[BackupModel]:
        """Get all backups for a specific source pad, newest first."""
        stmt = (
            select(BackupModel)
            .where(BackupModel.source_id == source_id)
            .order_by(BackupModel.created_at.desc())
        )
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def get_latest_by_source(self, source_id: UUID) -> Optional[BackupModel]:
        """Get the most recent backup for a specific source pad."""
        stmt = (
            select(BackupModel)
            .where(BackupModel.source_id == source_id)
            .order_by(BackupModel.created_at.desc())
            .limit(1)
        )
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def get_by_date_range(self, source_id: UUID, start_date: datetime, end_date: datetime) -> List[BackupModel]:
        """Get backups for a source pad created within [start_date, end_date], newest first."""
        stmt = (
            select(BackupModel)
            .where(
                BackupModel.source_id == source_id,
                BackupModel.created_at >= start_date,
                BackupModel.created_at <= end_date,
            )
            .order_by(BackupModel.created_at.desc())
        )
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def delete(self, backup_id: UUID) -> bool:
        """Delete a backup; returns True when a row was actually removed."""
        stmt = delete(BackupModel).where(BackupModel.id == backup_id)
        result = await self.session.execute(stmt)
        await self.session.commit()
        return result.rowcount > 0

    async def delete_older_than(self, source_id: UUID, keep_count: int) -> int:
        """Delete older backups, keeping the `keep_count` most recent; returns rows deleted."""
        # Find the created_at of the backup just past the keep window.
        subquery = (
            select(BackupModel.created_at)
            .where(BackupModel.source_id == source_id)
            .order_by(BackupModel.created_at.desc())
            .offset(keep_count)
            .limit(1)
        )
        result = await self.session.execute(subquery)
        cutoff_date = result.scalar()

        if not cutoff_date:
            return 0  # Not enough backups to delete any

        # Delete backups strictly older than the cutoff date
        stmt = delete(BackupModel).where(
            BackupModel.source_id == source_id,
            BackupModel.created_at < cutoff_date,
        )
        result = await self.session.execute(stmt)
        await self.session.commit()
        return result.rowcount

    async def count_by_source(self, source_id: UUID) -> int:
        """Count the number of backups for a specific source pad."""
        stmt = select(func.count()).select_from(BackupModel).where(BackupModel.source_id == source_id)
        result = await self.session.execute(stmt)
        return result.scalar()

    async def get_backups_by_user(self, user_id: UUID, limit: int = 10) -> List[BackupModel]:
        """
        Get the most recent backups across ALL pads owned by a user using a
        single joined query (avoids the N+1 pattern of fetching pads first).

        DOC FIX: the previous docstring claimed this returned backups for the
        user's "first pad"; the join actually spans every pad the user owns.

        Args:
            user_id: The user ID to get backups for
            limit: Maximum number of backups to return

        Returns:
            List of backup models, newest first
        """
        stmt = (
            select(BackupModel)
            .join(PadModel, BackupModel.source_id == PadModel.id)
            .where(PadModel.owner_id == user_id)
            .order_by(BackupModel.created_at.desc())
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        return result.scalars().all()


# Reconstructed from diff: src/backend/database/repository/pad_repository.py
"""
Pad repository for database operations related to pads.
"""

from sqlalchemy import update


class PadRepository:
    """Repository for pad-related database operations."""

    def __init__(self, session: AsyncSession):
        """Initialize the repository with a database session."""
        self.session = session

    async def create(self, owner_id: UUID, display_name: str, data: Dict[str, Any]) -> PadModel:
        """Create and persist a new pad owned by `owner_id`."""
        pad = PadModel(owner_id=owner_id, display_name=display_name, data=data)
        self.session.add(pad)
        await self.session.commit()
        await self.session.refresh(pad)
        return pad

    async def get_by_id(self, pad_id: UUID) -> Optional[PadModel]:
        """Get a pad by ID, or None when it does not exist."""
        stmt = select(PadModel).where(PadModel.id == pad_id)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def get_by_owner(self, owner_id: UUID) -> List[PadModel]:
        """Get all pads for a specific owner."""
        stmt = select(PadModel).where(PadModel.owner_id == owner_id)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def get_by_name(self, owner_id: UUID, display_name: str) -> Optional[PadModel]:
        """Get a pad by owner and display name."""
        stmt = select(PadModel).where(
            PadModel.owner_id == owner_id,
            PadModel.display_name == display_name,
        )
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def update(self, pad_id: UUID, data: Dict[str, Any]) -> Optional[PadModel]:
        """Update a pad with the given column values; returns the updated row or None."""
        stmt = update(PadModel).where(PadModel.id == pad_id).values(**data).returning(PadModel)
        result = await self.session.execute(stmt)
        await self.session.commit()
        return result.scalars().first()

    async def update_data(self, pad_id: UUID, pad_data: Dict[str, Any]) -> Optional[PadModel]:
        """Update just the data field of a pad."""
        return await self.update(pad_id, {"data": pad_data})

    async def delete(self, pad_id: UUID) -> bool:
        """Delete a pad; returns True when a row was actually removed."""
        stmt = delete(PadModel).where(PadModel.id == pad_id)
        result = await self.session.execute(stmt)
        await self.session.commit()
        return result.rowcount > 0
# Reconstructed from diff: src/backend/database/repository/template_pad_repository.py
"""
Template pad repository for database operations related to template pads.
"""

from typing import List, Optional, Dict, Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update, delete

from ..models import TemplatePadModel


class TemplatePadRepository:
    """Repository for template pad-related database operations."""

    def __init__(self, session: AsyncSession):
        """Initialize the repository with a database session."""
        self.session = session

    async def create(self, name: str, display_name: str, data: Dict[str, Any]) -> TemplatePadModel:
        """Create and persist a new template pad."""
        template_pad = TemplatePadModel(name=name, display_name=display_name, data=data)
        self.session.add(template_pad)
        await self.session.commit()
        await self.session.refresh(template_pad)
        return template_pad

    async def get_by_id(self, template_id: UUID) -> Optional[TemplatePadModel]:
        """Get a template pad by ID, or None when it does not exist."""
        stmt = select(TemplatePadModel).where(TemplatePadModel.id == template_id)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def get_by_name(self, name: str) -> Optional[TemplatePadModel]:
        """Get a template pad by its unique name."""
        stmt = select(TemplatePadModel).where(TemplatePadModel.name == name)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def get_all(self) -> List[TemplatePadModel]:
        """Get all template pads, ordered by display name."""
        stmt = select(TemplatePadModel).order_by(TemplatePadModel.display_name)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def update(self, name: str, data: Dict[str, Any]) -> Optional[TemplatePadModel]:
        """Update the template pad identified by `name`; returns the updated row or None."""
        stmt = update(TemplatePadModel).where(TemplatePadModel.name == name).values(**data).returning(TemplatePadModel)
        result = await self.session.execute(stmt)
        await self.session.commit()
        return result.scalars().first()

    async def update_data(self, name: str, template_data: Dict[str, Any]) -> Optional[TemplatePadModel]:
        """Update just the data field of a template pad."""
        return await self.update(name, {"data": template_data})

    async def delete(self, name: str) -> bool:
        """Delete a template pad by name; returns True when a row was actually removed."""
        stmt = delete(TemplatePadModel).where(TemplatePadModel.name == name)
        result = await self.session.execute(stmt)
        await self.session.commit()
        return result.rowcount > 0


# Reconstructed from diff: src/backend/database/repository/user_repository.py
"""
User repository for database operations related to users.
"""

from ..models import UserModel


class UserRepository:
    """Repository for user-related database operations."""

    def __init__(self, session: AsyncSession):
        """Initialize the repository with a database session."""
        self.session = session

    async def create(self, user_id: UUID, username: str, email: str, email_verified: bool = False,
                     name: Optional[str] = None, given_name: Optional[str] = None,
                     family_name: Optional[str] = None,
                     roles: Optional[list] = None) -> UserModel:
        """Create a new user with a caller-supplied ID and optional profile fields.

        TYPING FIX: optional parameters were annotated `str`/`list` while
        defaulting to None; they are now Optional[...].
        """
        user = UserModel(
            id=user_id,
            username=username,
            email=email,
            email_verified=email_verified,
            name=name,
            given_name=given_name,
            family_name=family_name,
            roles=roles or [],
        )
        self.session.add(user)
        await self.session.commit()
        await self.session.refresh(user)
        return user

    async def get_by_id(self, user_id: UUID) -> Optional[UserModel]:
        """Get a user by ID, or None when it does not exist."""
        stmt = select(UserModel).where(UserModel.id == user_id)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def get_by_username(self, username: str) -> Optional[UserModel]:
        """Get a user by username."""
        stmt = select(UserModel).where(UserModel.username == username)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def get_by_email(self, email: str) -> Optional[UserModel]:
        """Get a user by email."""
        stmt = select(UserModel).where(UserModel.email == email)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def get_all(self) -> List[UserModel]:
        """Get all users."""
        stmt = select(UserModel)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def update(self, user_id: UUID, data: Dict[str, Any]) -> Optional[UserModel]:
        """Update a user with the given column values; returns the updated row or None."""
        stmt = update(UserModel).where(UserModel.id == user_id).values(**data).returning(UserModel)
        result = await self.session.execute(stmt)
        await self.session.commit()
        return result.scalars().first()

    async def delete(self, user_id: UUID) -> bool:
        """Delete a user; returns True when a row was actually removed."""
        stmt = delete(UserModel).where(UserModel.id == user_id)
        result = await self.session.execute(stmt)
        await self.session.commit()
        return result.rowcount > 0
# Reconstructed from diff: src/backend/database/service/__init__.py
"""
Service module for business logic.

This module provides access to all services used for business logic operations.
"""

from .user_service import UserService
from .pad_service import PadService
from .backup_service import BackupService
from .template_pad_service import TemplatePadService

__all__ = [
    'UserService',
    'PadService',
    'BackupService',
    'TemplatePadService',
]


# Reconstructed from diff: src/backend/database/service/backup_service.py
"""
Backup service for business logic related to backups.
"""

from typing import List, Optional, Dict, Any
from uuid import UUID
from datetime import datetime, timezone

from sqlalchemy.ext.asyncio import AsyncSession

from ..repository import BackupRepository, PadRepository, UserRepository


class BackupService:
    """Service for backup-related business logic."""

    def __init__(self, session: AsyncSession):
        """Initialize the service with a database session."""
        self.session = session
        self.repository = BackupRepository(session)
        self.pad_repository = PadRepository(session)

    async def _require_pad(self, source_id: UUID):
        """Return the source pad or raise ValueError when it does not exist.

        Extracted helper: this existence check was duplicated in six methods.
        """
        source_pad = await self.pad_repository.get_by_id(source_id)
        if not source_pad:
            raise ValueError(f"Pad with ID '{source_id}' does not exist")
        return source_pad

    async def create_backup(self, source_id: UUID, data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a new backup for an existing pad.

        Raises:
            ValueError: when `data` is empty or the pad does not exist.
        """
        if not data:
            raise ValueError("Backup data is required")
        await self._require_pad(source_id)
        backup = await self.repository.create(source_id, data)
        return backup.to_dict()

    async def get_backup(self, backup_id: UUID) -> Optional[Dict[str, Any]]:
        """Get a backup by ID as a dict, or None when not found."""
        backup = await self.repository.get_by_id(backup_id)
        return backup.to_dict() if backup else None

    async def get_backups_by_source(self, source_id: UUID) -> List[Dict[str, Any]]:
        """Get all backups for a specific source pad; raises ValueError if the pad is missing."""
        await self._require_pad(source_id)
        backups = await self.repository.get_by_source(source_id)
        return [backup.to_dict() for backup in backups]

    async def get_latest_backup(self, source_id: UUID) -> Optional[Dict[str, Any]]:
        """Get the most recent backup for a source pad; raises ValueError if the pad is missing."""
        await self._require_pad(source_id)
        backup = await self.repository.get_latest_by_source(source_id)
        return backup.to_dict() if backup else None

    async def get_backups_by_date_range(self, source_id: UUID, start_date: datetime, end_date: datetime) -> List[Dict[str, Any]]:
        """Get backups for a source pad within a date range.

        Raises:
            ValueError: when the pad is missing or start_date is after end_date.
        """
        await self._require_pad(source_id)
        if start_date > end_date:
            raise ValueError("Start date must be before end date")
        backups = await self.repository.get_by_date_range(source_id, start_date, end_date)
        return [backup.to_dict() for backup in backups]

    async def delete_backup(self, backup_id: UUID) -> bool:
        """Delete a backup; raises ValueError when it does not exist."""
        backup = await self.repository.get_by_id(backup_id)
        if not backup:
            raise ValueError(f"Backup with ID '{backup_id}' does not exist")
        return await self.repository.delete(backup_id)

    async def manage_backups(self, source_id: UUID, max_backups: int = 10) -> int:
        """Keep only the `max_backups` newest backups for a pad; returns rows deleted."""
        await self._require_pad(source_id)
        if max_backups < 1:
            raise ValueError("Maximum number of backups must be at least 1")

        backup_count = await self.repository.count_by_source(source_id)
        if backup_count > max_backups:
            return await self.repository.delete_older_than(source_id, max_backups)
        return 0  # No backups deleted

    async def get_backups_by_user(self, user_id: UUID, limit: int = 10) -> List[Dict[str, Any]]:
        """
        Get the most recent backups across ALL pads owned by a user via a single
        joined query (avoids the N+1 pattern).

        DOC FIX: the previous docstring claimed this was limited to the user's
        "first pad"; the underlying join spans every pad the user owns.

        Args:
            user_id: The user ID to get backups for
            limit: Maximum number of backups to return

        Returns:
            List of backup dictionaries, newest first
        """
        user_repository = UserRepository(self.session)
        user = await user_repository.get_by_id(user_id)
        if not user:
            raise ValueError(f"User with ID '{user_id}' does not exist")

        backups = await self.repository.get_backups_by_user(user_id, limit)
        return [backup.to_dict() for backup in backups]

    async def create_backup_if_needed(self, source_id: UUID, data: Dict[str, Any],
                                      min_interval_minutes: int = 5,
                                      max_backups: int = 10) -> Optional[Dict[str, Any]]:
        """
        Create a backup only if needed:
        - If there are no existing backups
        - If the latest backup is older than the specified interval

        Args:
            source_id: The ID of the source pad
            data: The data to backup
            min_interval_minutes: Minimum time between backups in minutes
            max_backups: Maximum number of backups to keep

        Returns:
            The created backup dict if a backup was created, None otherwise
        """
        await self._require_pad(source_id)

        latest_backup = await self.repository.get_latest_by_source(source_id)
        current_time = datetime.now(timezone.utc)

        create_backup = False
        if not latest_backup:
            # No backups exist yet, so create one
            create_backup = True
        else:
            last_created = latest_backup.created_at
            # ROBUSTNESS FIX: aware - naive raises TypeError; if the driver hands
            # back a naive timestamp, assume UTC (column is timezone=True).
            if last_created.tzinfo is None:
                last_created = last_created.replace(tzinfo=timezone.utc)
            backup_age = current_time - last_created
            if backup_age.total_seconds() > (min_interval_minutes * 60):
                create_backup = True

        if create_backup:
            backup = await self.repository.create(source_id, data)
            # Clean up old backups beyond the retention window
            await self.manage_backups(source_id, max_backups)
            return backup.to_dict()

        return None
# Reconstructed from diff: src/backend/database/service/pad_service.py
"""
Pad service for business logic related to pads.
"""

from typing import List, Optional, Dict, Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from ..repository import PadRepository, UserRepository


class PadService:
    """Service for pad-related business logic."""

    def __init__(self, session: AsyncSession):
        """Initialize the service with a database session."""
        self.session = session
        self.repository = PadRepository(session)
        self.user_repository = UserRepository(session)

    async def create_pad(self, owner_id: UUID, display_name: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a new pad after validating name, payload, owner, and uniqueness."""
        # Guard clauses for required fields
        if not display_name:
            raise ValueError("Display name is required")
        if not data:
            raise ValueError("Pad data is required")

        # The owner must already exist
        if not await self.user_repository.get_by_id(owner_id):
            raise ValueError(f"User with ID '{owner_id}' does not exist")

        # Per-owner display names must be unique
        if await self.repository.get_by_name(owner_id, display_name):
            raise ValueError(f"Pad with name '{display_name}' already exists for this user")

        created = await self.repository.create(owner_id, display_name, data)
        return created.to_dict()

    async def get_pad(self, pad_id: UUID) -> Optional[Dict[str, Any]]:
        """Get a pad by ID as a dict, or None when not found."""
        row = await self.repository.get_by_id(pad_id)
        return None if row is None else row.to_dict()

    async def get_pads_by_owner(self, owner_id: UUID) -> List[Dict[str, Any]]:
        """Get every pad owned by `owner_id`; raises ValueError for unknown users."""
        if not await self.user_repository.get_by_id(owner_id):
            raise ValueError(f"User with ID '{owner_id}' does not exist")
        return [row.to_dict() for row in await self.repository.get_by_owner(owner_id)]

    async def get_pad_by_name(self, owner_id: UUID, display_name: str) -> Optional[Dict[str, Any]]:
        """Get a pad by owner and display name as a dict, or None when not found."""
        row = await self.repository.get_by_name(owner_id, display_name)
        return None if row is None else row.to_dict()

    async def update_pad(self, pad_id: UUID, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update a pad, enforcing non-empty and per-owner-unique display names."""
        existing = await self.repository.get_by_id(pad_id)
        if not existing:
            raise ValueError(f"Pad with ID '{pad_id}' does not exist")

        new_name = data.get('display_name') if 'display_name' in data else None
        if 'display_name' in data:
            if not data['display_name']:
                raise ValueError("Display name cannot be empty")
            # Renaming: the target name must not collide with another pad of the same owner
            if new_name != existing.display_name:
                clash = await self.repository.get_by_name(existing.owner_id, new_name)
                if clash:
                    raise ValueError(f"Pad with name '{data['display_name']}' already exists for this user")

        refreshed = await self.repository.update(pad_id, data)
        return refreshed.to_dict() if refreshed else None

    async def update_pad_data(self, pad_id: UUID, pad_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update just the data field of a pad; raises ValueError for unknown pads."""
        if not await self.repository.get_by_id(pad_id):
            raise ValueError(f"Pad with ID '{pad_id}' does not exist")
        refreshed = await self.repository.update_data(pad_id, pad_data)
        return refreshed.to_dict() if refreshed else None

    async def delete_pad(self, pad_id: UUID) -> bool:
        """Delete a pad; raises ValueError for unknown pads."""
        if not await self.repository.get_by_id(pad_id):
            raise ValueError(f"Pad with ID '{pad_id}' does not exist")
        return await self.repository.delete(pad_id)
# Reconstructed from diff: src/backend/database/service/template_pad_service.py
"""
Template pad service for business logic related to template pads.
"""

from typing import List, Optional, Dict, Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from ..repository import TemplatePadRepository


class TemplatePadService:
    """Service for template pad-related business logic."""

    def __init__(self, session: AsyncSession):
        """Initialize the service with a database session."""
        self.session = session
        self.repository = TemplatePadRepository(session)

    async def create_template(self, name: str, display_name: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a new template pad after validating inputs and name uniqueness."""
        # Guard clauses for required fields
        if not name:
            raise ValueError("Name is required")
        if not display_name:
            raise ValueError("Display name is required")
        if not data:
            raise ValueError("Template data is required")

        # Template names are globally unique
        if await self.repository.get_by_name(name):
            raise ValueError(f"Template with name '{name}' already exists")

        created = await self.repository.create(name, display_name, data)
        return created.to_dict()

    async def get_template(self, template_id: UUID) -> Optional[Dict[str, Any]]:
        """Get a template pad by ID as a dict, or None when not found."""
        row = await self.repository.get_by_id(template_id)
        return None if row is None else row.to_dict()

    async def get_template_by_name(self, name: str) -> Optional[Dict[str, Any]]:
        """Get a template pad by name as a dict, or None when not found."""
        row = await self.repository.get_by_name(name)
        return None if row is None else row.to_dict()

    async def get_all_templates(self) -> List[Dict[str, Any]]:
        """Get all template pads as dicts."""
        return [row.to_dict() for row in await self.repository.get_all()]

    async def update_template(self, name: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update a template pad, enforcing non-empty fields and name uniqueness."""
        current = await self.repository.get_by_name(name)
        if not current:
            raise ValueError(f"Template pad with name '{name}' does not exist")

        if 'name' in data and not data['name']:
            raise ValueError("Name cannot be empty")
        if 'display_name' in data and not data['display_name']:
            raise ValueError("Display name cannot be empty")

        # Renaming: the target name must not already be taken
        if 'name' in data and data['name'] != current.name:
            if await self.repository.get_by_name(data['name']):
                raise ValueError(f"Template with name '{data['name']}' already exists")

        refreshed = await self.repository.update(name, data)
        return refreshed.to_dict() if refreshed else None

    async def update_template_data(self, name: str, template_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update just the data field of a template pad; raises ValueError when missing."""
        if not await self.repository.get_by_name(name):
            raise ValueError(f"Template pad with name '{name}' does not exist")
        refreshed = await self.repository.update_data(name, template_data)
        return refreshed.to_dict() if refreshed else None

    async def delete_template(self, name: str) -> bool:
        """Delete a template pad by name; raises ValueError when missing."""
        if not await self.repository.get_by_name(name):
            raise ValueError(f"Template pad with name '{name}' does not exist")
        return await self.repository.delete(name)
+1,157 @@ +""" +User service for business logic related to users. +""" + +from typing import List, Optional, Dict, Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from ..repository import UserRepository + +class UserService: + """Service for user-related business logic""" + + def __init__(self, session: AsyncSession): + """Initialize the service with a database session""" + self.session = session + self.repository = UserRepository(session) + + async def create_user(self, user_id: UUID, username: str, email: str, + email_verified: bool = False, name: str = None, + given_name: str = None, family_name: str = None, + roles: list = None) -> Dict[str, Any]: + """Create a new user with specified ID and optional fields""" + # Validate input + if not user_id or not username or not email: + raise ValueError("User ID, username, and email are required") + + # Check if user_id already exists + existing_id = await self.repository.get_by_id(user_id) + if existing_id: + raise ValueError(f"User with ID '{user_id}' already exists") + + # Check if username already exists + existing_user = await self.repository.get_by_username(username) + if existing_user: + raise ValueError(f"Username '{username}' is already taken") + + # Create user + user = await self.repository.create( + user_id=user_id, + username=username, + email=email, + email_verified=email_verified, + name=name, + given_name=given_name, + family_name=family_name, + roles=roles + ) + return user.to_dict() + + async def get_user(self, user_id: UUID) -> Optional[Dict[str, Any]]: + """Get a user by ID""" + user = await self.repository.get_by_id(user_id) + return user.to_dict() if user else None + + async def get_user_by_username(self, username: str) -> Optional[Dict[str, Any]]: + """Get a user by username""" + user = await self.repository.get_by_username(username) + return user.to_dict() if user else None + + async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]: + """Get a 
user by email""" + user = await self.repository.get_by_email(email) + return user.to_dict() if user else None + + async def get_all_users(self) -> List[Dict[str, Any]]: + """Get all users""" + users = await self.repository.get_all() + return [user.to_dict() for user in users] + + async def update_user(self, user_id: UUID, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Update a user""" + # Validate input + if 'username' in data and not data['username']: + raise ValueError("Username cannot be empty") + + if 'email' in data and not data['email']: + raise ValueError("Email cannot be empty") + + # Check if username already exists (if being updated) + if 'username' in data: + existing_user = await self.repository.get_by_username(data['username']) + if existing_user and existing_user.id != user_id: + raise ValueError(f"Username '{data['username']}' is already taken") + + # Check if email already exists (if being updated) + if 'email' in data: + existing_email = await self.repository.get_by_email(data['email']) + if existing_email and existing_email.id != user_id: + raise ValueError(f"Email '{data['email']}' is already registered") + + # Update user + user = await self.repository.update(user_id, data) + return user.to_dict() if user else None + + async def delete_user(self, user_id: UUID) -> bool: + """Delete a user""" + return await self.repository.delete(user_id) + + async def sync_user_with_token_data(self, user_id: UUID, token_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Synchronize user data in the database with data from the authentication token. + If the user doesn't exist, it will be created. If it exists but has different data, + it will be updated to match the token data. 
+ + Args: + user_id: The user's UUID + token_data: Dictionary containing user data from the authentication token + + Returns: + The user data dictionary or None if operation failed + """ + # Check if user exists + user_data = await self.get_user(user_id) + + # If user doesn't exist, create a new one + if not user_data: + try: + return await self.create_user( + user_id=user_id, + username=token_data.get("username", ""), + email=token_data.get("email", ""), + email_verified=token_data.get("email_verified", False), + name=token_data.get("name"), + given_name=token_data.get("given_name"), + family_name=token_data.get("family_name"), + roles=token_data.get("roles", []) + ) + except ValueError as e: + # Handle case where user might have been created in a race condition + if "already exists" in str(e): + user_data = await self.get_user(user_id) + else: + raise e + + # Check if user data needs to be updated + update_data = {} + fields_to_check = [ + "username", "email", "email_verified", + "name", "given_name", "family_name" + ] + + for field in fields_to_check: + token_value = token_data.get(field) + if token_value is not None and user_data.get(field) != token_value: + update_data[field] = token_value + + # Handle roles separately as they might have a different structure + if "roles" in token_data and user_data.get("roles") != token_data["roles"]: + update_data["roles"] = token_data["roles"] + + # Update user if any field has changed + if update_data: + return await self.update_user(user_id, update_data) + + return user_data diff --git a/src/backend/db.py b/src/backend/db.py deleted file mode 100644 index 019f350..0000000 --- a/src/backend/db.py +++ /dev/null @@ -1,208 +0,0 @@ -import os -from typing import Optional, Dict, Any, List -from datetime import datetime -from dotenv import load_dotenv -from sqlalchemy import Column, String, JSON, DateTime, Integer, func, create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.ext.asyncio import 
create_async_engine, AsyncSession -from sqlalchemy.orm import sessionmaker -from sqlalchemy.future import select -from urllib.parse import quote_plus as urlquote - -load_dotenv() - -# PostgreSQL connection configuration -DB_USER = os.getenv('POSTGRES_USER', 'postgres') -DB_PASSWORD = os.getenv('POSTGRES_PASSWORD', 'postgres') -DB_NAME = os.getenv('POSTGRES_DB', 'pad') -DB_HOST = os.getenv('POSTGRES_HOST', 'localhost') -DB_PORT = os.getenv('POSTGRES_PORT', '5432') - -# SQLAlchemy async database URL -DATABASE_URL = f"postgresql+asyncpg://{DB_USER}:{urlquote(DB_PASSWORD)}@{DB_HOST}:{DB_PORT}/{DB_NAME}" - -# Create async engine -engine = create_async_engine(DATABASE_URL, echo=False) - -# Create async session factory -async_session = sessionmaker( - engine, class_=AsyncSession, expire_on_commit=False -) - -# Create base model -Base = declarative_base() - -# Canvas backup configuration -BACKUP_INTERVAL_SECONDS = 300 # 5 minutes between backups -MAX_BACKUPS_PER_USER = 10 # Maximum number of backups to keep per user - -# In-memory dictionaries to track user activity -user_last_backup_time = {} # Tracks when each user last had a backup -user_last_activity_time = {} # Tracks when each user was last active - -# Memory management configuration -INACTIVITY_THRESHOLD_MINUTES = 30 # Remove users from memory after this many minutes of inactivity -MAX_USERS_BEFORE_CLEANUP = 1000 # Trigger cleanup when we have this many users in memory -CLEANUP_INTERVAL_SECONDS = 3600 # Run cleanup at least once per hour (1 hour) -last_cleanup_time = datetime.now() # Track when we last ran cleanup - -class CanvasData(Base): - """Model for canvas data table""" - __tablename__ = "canvas_data" - - user_id = Column(String, primary_key=True) - data = Column(JSON, nullable=False) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) - - def __repr__(self): - return f"" - -class CanvasBackup(Base): - """Model for canvas backups table""" - __tablename__ = 
"canvas_backups" - - id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(String, nullable=False) - timestamp = Column(DateTime(timezone=True), server_default=func.now()) - canvas_data = Column(JSON, nullable=False) - - def __repr__(self): - return f"" - -async def get_db_session(): - """Get a database session""" - async with async_session() as session: - yield session - -async def init_db(): - """Initialize the database with required tables""" - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - -async def cleanup_inactive_users(inactivity_threshold_minutes: int = INACTIVITY_THRESHOLD_MINUTES): - """Remove users from memory tracking if they've been inactive for the specified time""" - current_time = datetime.now() - inactive_users = [] - - for user_id, last_activity in user_last_activity_time.items(): - # Check if user has been inactive for longer than the threshold - if (current_time - last_activity).total_seconds() > (inactivity_threshold_minutes * 60): - inactive_users.append(user_id) - - # Remove inactive users from both dictionaries - for user_id in inactive_users: - if user_id in user_last_backup_time: - del user_last_backup_time[user_id] - if user_id in user_last_activity_time: - del user_last_activity_time[user_id] - - return len(inactive_users) # Return count of removed users for logging - -async def check_if_cleanup_needed(): - """Check if we should run the cleanup function""" - global last_cleanup_time - current_time = datetime.now() - time_since_last_cleanup = (current_time - last_cleanup_time).total_seconds() - - # Run cleanup if we have too many users or it's been too long - if (len(user_last_activity_time) > MAX_USERS_BEFORE_CLEANUP or - time_since_last_cleanup > CLEANUP_INTERVAL_SECONDS): - removed_count = await cleanup_inactive_users() - last_cleanup_time = current_time - print(f"[db.py] Cleanup completed: removed {removed_count} inactive users from memory") - -async def 
store_canvas_data(user_id: str, data: Dict[str, Any]) -> bool: - try: - # Update user's last activity time - current_time = datetime.now() - user_last_activity_time[user_id] = current_time - - # Check if cleanup is needed - await check_if_cleanup_needed() - - async with async_session() as session: - # Check if record exists - stmt = select(CanvasData).where(CanvasData.user_id == user_id) - result = await session.execute(stmt) - canvas_data = result.scalars().first() - - if canvas_data: - # Update existing record - canvas_data.data = data - else: - # Create new record - canvas_data = CanvasData(user_id=user_id, data=data) - session.add(canvas_data) - - # Check if we should create a backup - should_backup = False - if user_id not in user_last_backup_time: - # First time this user is saving, create a backup - should_backup = True - else: - # Check if backup interval has passed since last backup - time_since_last_backup = (current_time - user_last_backup_time[user_id]).total_seconds() - if time_since_last_backup >= BACKUP_INTERVAL_SECONDS: - should_backup = True - - if should_backup: - # Update the backup timestamp - user_last_backup_time[user_id] = current_time - - # Count existing backups for this user - backup_count_stmt = select(func.count()).select_from(CanvasBackup).where(CanvasBackup.user_id == user_id) - backup_count_result = await session.execute(backup_count_stmt) - backup_count = backup_count_result.scalar() - - # If user has reached the maximum number of backups, delete the oldest one - if backup_count >= MAX_BACKUPS_PER_USER: - oldest_backup_stmt = select(CanvasBackup).where(CanvasBackup.user_id == user_id).order_by(CanvasBackup.timestamp).limit(1) - oldest_backup_result = await session.execute(oldest_backup_stmt) - oldest_backup = oldest_backup_result.scalars().first() - - if oldest_backup: - await session.delete(oldest_backup) - - # Create new backup - new_backup = CanvasBackup(user_id=user_id, canvas_data=data) - session.add(new_backup) - - await 
session.commit() - return True - except Exception as e: - print(f"Error storing canvas data: {e}") - return False - -async def get_canvas_data(user_id: str) -> Optional[Dict[str, Any]]: - try: - # Update user's last activity time - user_last_activity_time[user_id] = datetime.now() - - async with async_session() as session: - stmt = select(CanvasData).where(CanvasData.user_id == user_id) - result = await session.execute(stmt) - canvas_data = result.scalars().first() - - if canvas_data: - return canvas_data.data - return None - except Exception as e: - print(f"Error retrieving canvas data: {e}") - return None - -async def get_recent_canvases(user_id: str, limit: int = MAX_BACKUPS_PER_USER) -> List[Dict[str, Any]]: - """Get the most recent canvas backups for a user""" - try: - # Update user's last activity time - user_last_activity_time[user_id] = datetime.now() - - async with async_session() as session: - # Get the most recent backups, limited to MAX_BACKUPS_PER_USER - stmt = select(CanvasBackup).where(CanvasBackup.user_id == user_id).order_by(CanvasBackup.timestamp.desc()).limit(limit) - result = await session.execute(stmt) - backups = result.scalars().all() - - return [{"id": backup.id, "timestamp": backup.timestamp, "data": backup.canvas_data} for backup in backups] - except Exception as e: - print(f"Error retrieving canvas backups: {e}") - return [] diff --git a/src/backend/dependencies.py b/src/backend/dependencies.py index cf86d0e..749c26b 100644 --- a/src/backend/dependencies.py +++ b/src/backend/dependencies.py @@ -1,60 +1,156 @@ -from typing import Optional +import jwt +from typing import Optional, Dict, Any +from uuid import UUID + from fastapi import Request, HTTPException, Depends from config import get_session, is_token_expired, refresh_token +from database.service import UserService +from coder import CoderAPI -class SessionData: - def __init__(self, access_token: str, token_data: dict): +class UserSession: + """ + Unified user session model that 
integrates authentication data with user information. + This provides a single interface for accessing both token data and user details. + """ + def __init__(self, access_token: str, token_data: dict, user_id: UUID = None): self.access_token = access_token - self.token_data = token_data + self._user_data = None + + # Get the signing key and decode with verification + from config import get_jwks_client, OIDC_CLIENT_ID + try: + jwks_client = get_jwks_client() + signing_key = jwks_client.get_signing_key_from_jwt(access_token) + + self.token_data = jwt.decode( + access_token, + signing_key.key, + algorithms=["RS256"], + audience=OIDC_CLIENT_ID + ) + + except jwt.InvalidTokenError as e: + # Log the error and raise an appropriate exception + print(f"Invalid token: {str(e)}") + raise ValueError(f"Invalid authentication token: {str(e)}") + + @property + def is_authenticated(self) -> bool: + """Check if the session is authenticated""" + return bool(self.access_token and self.id) + + @property + def id(self) -> UUID: + """Get user ID from token data""" + return UUID(self.token_data.get("sub")) + + @property + def email(self) -> str: + """Get user email from token data""" + return self.token_data.get("email", "") + + @property + def email_verified(self) -> bool: + """Get email verification status from token data""" + return self.token_data.get("email_verified", False) + + @property + def username(self) -> str: + """Get username from token data""" + return self.token_data.get("preferred_username", "") + + @property + def name(self) -> str: + """Get full name from token data""" + return self.token_data.get("name", "") + + @property + def given_name(self) -> str: + """Get given name from token data""" + return self.token_data.get("given_name", "") + + @property + def family_name(self) -> str: + """Get family name from token data""" + return self.token_data.get("family_name", "") + + @property + def roles(self) -> list: + """Get user roles from token data""" + return 
self.token_data.get("realm_access", {}).get("roles", []) + + @property + def is_admin(self) -> bool: + """Check if user has admin role""" + return "admin" in self.roles + + async def get_user_data(self, user_service: UserService) -> Dict[str, Any]: + """Get user data from database, caching the result""" + if self._user_data is None and self.id: + self._user_data = await user_service.get_user(self.id) + return self._user_data class AuthDependency: - def __init__(self, auto_error: bool = True): + """ + Authentication dependency that validates session tokens and provides + a unified UserSession object for route handlers. + """ + def __init__(self, auto_error: bool = True, require_admin: bool = False): self.auto_error = auto_error + self.require_admin = require_admin - async def __call__(self, request: Request) -> Optional[SessionData]: + async def __call__(self, request: Request) -> Optional[UserSession]: + # Get session ID from cookies session_id = request.cookies.get('session_id') + # Handle missing session ID if not session_id: - if self.auto_error: - raise HTTPException( - status_code=401, - detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"}, - ) - return None + return self._handle_auth_error("Not authenticated") + # Get session data from Redis session = get_session(session_id) if not session: - if self.auto_error: - raise HTTPException( - status_code=401, - detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"}, - ) - return None + return self._handle_auth_error("Not authenticated") - # Check if token is expired and refresh if needed + # Handle token expiration if is_token_expired(session): # Try to refresh the token success, new_session = await refresh_token(session_id, session) if not success: - # Token refresh failed, user needs to re-authenticate - if self.auto_error: - raise HTTPException( - status_code=401, - detail="Session expired", - headers={"WWW-Authenticate": "Bearer"}, - ) - return None - # Use the refreshed token data 
+ return self._handle_auth_error("Session expired") session = new_session - - return SessionData( + + # Create user session object + user_session = UserSession( access_token=session.get('access_token'), token_data=session ) + + # Check admin requirement if specified + if self.require_admin and not user_session.is_admin: + return self._handle_auth_error("Admin privileges required", status_code=403) + + return user_session + + def _handle_auth_error(self, detail: str, status_code: int = 401) -> Optional[None]: + """Handle authentication errors based on auto_error setting""" + if self.auto_error: + headers = {"WWW-Authenticate": "Bearer"} if status_code == 401 else None + raise HTTPException( + status_code=status_code, + detail=detail, + headers=headers, + ) + return None -# Create instances for use in route handlers +# Create dependency instances for use in route handlers require_auth = AuthDependency(auto_error=True) optional_auth = AuthDependency(auto_error=False) +require_admin = AuthDependency(auto_error=True, require_admin=True) + +def get_coder_api(): + """ + Dependency that provides a CoderAPI instance. 
+ """ + return CoderAPI() diff --git a/src/backend/main.py b/src/backend/main.py index a9dc15f..563ef39 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -1,43 +1,118 @@ import os +import json from contextlib import asynccontextmanager from typing import Optional +import posthog from fastapi import FastAPI, Request, Depends from fastapi.responses import FileResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles -from dotenv import load_dotenv -import posthog - -load_dotenv() -POSTHOG_API_KEY = os.environ.get("VITE_PUBLIC_POSTHOG_KEY") -POSTHOG_HOST = os.environ.get("VITE_PUBLIC_POSTHOG_HOST") +from database import init_db +from config import STATIC_DIR, ASSETS_DIR, POSTHOG_API_KEY, POSTHOG_HOST, redis_client, redis_pool +from dependencies import UserSession, optional_auth +from routers.auth_router import auth_router +from routers.user_router import user_router +from routers.workspace_router import workspace_router +from routers.pad_router import pad_router +from routers.template_pad_router import template_pad_router +from database.service import TemplatePadService +from database.database import async_session, run_migrations +# Initialize PostHog if API key is available if POSTHOG_API_KEY: posthog.project_api_key = POSTHOG_API_KEY posthog.host = POSTHOG_HOST -from db import init_db -from config import STATIC_DIR, ASSETS_DIR -from dependencies import SessionData, optional_auth -from routers.auth import auth_router -from routers.canvas import canvas_router -from routers.user import user_router -from routers.workspace import workspace_router +async def load_templates(): + """ + Load all templates from the templates directory into the database if they don't exist. + + This function reads all JSON files in the templates directory, extracts the display name + from the "appState.pad.displayName" field, uses the filename as the name, and stores + the entire JSON as the data. 
+ """ + try: + # Get a session and template service + async with async_session() as session: + template_service = TemplatePadService(session) + + # Get the templates directory path + templates_dir = os.path.join(os.path.dirname(__file__), "templates") + + # Iterate through all JSON files in the templates directory + for filename in os.listdir(templates_dir): + if filename.endswith(".json"): + # Use the filename without extension as the name + name = os.path.splitext(filename)[0] + + # Check if template already exists + existing_template = await template_service.get_template_by_name(name) + + if not existing_template: + + file_path = os.path.join(templates_dir, filename) + + # Read the JSON file + with open(file_path, 'r') as f: + template_data = json.load(f) + + # Extract the display name from the JSON + display_name = template_data.get("appState", {}).get("pad", {}).get("displayName", "Untitled") + + # Create the template if it doesn't exist + await template_service.create_template( + name=name, + display_name=display_name, + data=template_data + ) + print(f"Added template: {name} ({display_name})") + else: + print(f"Template already in database: '{name}'") + + except Exception as e: + print(f"Error loading templates: {str(e)}") @asynccontextmanager async def lifespan(_: FastAPI): + # Initialize database await init_db() print("Database connection established successfully") + + # Run database migrations + try: + await run_migrations() + print("Database migrations completed successfully") + except Exception as e: + print(f"Warning: Failed to run migrations: {str(e)}") + + # Check Redis connection + try: + redis_client.ping() + print("Redis connection established successfully") + except Exception as e: + print(f"Warning: Redis connection failed: {str(e)}") + + # Load all templates from the templates directory + await load_templates() + print("Templates loaded successfully") + yield + + # Clean up connections when shutting down + try: + redis_pool.disconnect() + 
print("Redis connections closed") + except Exception as e: + print(f"Error closing Redis connections: {str(e)}") app = FastAPI(lifespan=lifespan) # CORS middleware setup app.add_middleware( CORSMiddleware, - allow_origins=["https://kc.pad.ws", "https://alex.pad.ws"], + allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -47,14 +122,15 @@ async def lifespan(_: FastAPI): app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") @app.get("/") -async def read_root(request: Request, auth: Optional[SessionData] = Depends(optional_auth)): +async def read_root(request: Request, auth: Optional[UserSession] = Depends(optional_auth)): return FileResponse(os.path.join(STATIC_DIR, "index.html")) # Include routers in the main app with the /api prefix app.include_router(auth_router, prefix="/auth") -app.include_router(canvas_router, prefix="/api/canvas") -app.include_router(user_router, prefix="/api/user") +app.include_router(user_router, prefix="/api/users") app.include_router(workspace_router, prefix="/api/workspace") +app.include_router(pad_router, prefix="/api/pad") +app.include_router(template_pad_router, prefix="/api/templates") if __name__ == "__main__": import uvicorn diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index b24ffde..302d8ff 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -9,3 +9,7 @@ requests sqlalchemy posthog redis +psycopg2-binary +python-multipart +cryptography # Required for JWT key handling +alembic \ No newline at end of file diff --git a/src/backend/routers/auth.py b/src/backend/routers/auth_router.py similarity index 80% rename from src/backend/routers/auth.py rename to src/backend/routers/auth_router.py index f97457f..0a2580b 100644 --- a/src/backend/routers/auth.py +++ b/src/backend/routers/auth_router.py @@ -5,12 +5,12 @@ from fastapi.responses import RedirectResponse, FileResponse, JSONResponse import os -from config import get_auth_url, 
get_token_url, OIDC_CONFIG, set_session, delete_session, STATIC_DIR, get_session -from dependencies import SessionData, require_auth +from config import (get_auth_url, get_token_url, set_session, delete_session, get_session, + FRONTEND_URL, OIDC_CLIENT_ID, OIDC_CLIENT_SECRET, OIDC_SERVER_URL, OIDC_REALM, OIDC_REDIRECT_URI, STATIC_DIR) +from dependencies import get_coder_api from coder import CoderAPI auth_router = APIRouter() -coder_api = CoderAPI() @auth_router.get("/login") async def login(request: Request, kc_idp_hint: str = None, popup: str = None): @@ -32,7 +32,12 @@ async def login(request: Request, kc_idp_hint: str = None, popup: str = None): return response @auth_router.get("/callback") -async def callback(request: Request, code: str, state: str = "default"): +async def callback( + request: Request, + code: str, + state: str = "default", + coder_api: CoderAPI = Depends(get_coder_api) +): session_id = request.cookies.get('session_id') if not session_id: raise HTTPException(status_code=400, detail="No session") @@ -43,10 +48,10 @@ async def callback(request: Request, code: str, state: str = "default"): get_token_url(), data={ 'grant_type': 'authorization_code', - 'client_id': OIDC_CONFIG['client_id'], - 'client_secret': OIDC_CONFIG['client_secret'], + 'client_id': OIDC_CLIENT_ID, + 'client_secret': OIDC_CLIENT_SECRET, 'code': code, - 'redirect_uri': OIDC_CONFIG['redirect_uri'] + 'redirect_uri': OIDC_REDIRECT_URI } ) @@ -87,9 +92,8 @@ async def logout(request: Request): delete_session(session_id) # Create the Keycloak logout URL with redirect back to our app - logout_url = f"{OIDC_CONFIG['server_url']}/realms/{OIDC_CONFIG['realm']}/protocol/openid-connect/logout" - redirect_uri = OIDC_CONFIG['frontend_url'] # Match the frontend redirect URI - full_logout_url = f"{logout_url}?id_token_hint={id_token}&post_logout_redirect_uri={redirect_uri}" + logout_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/logout" + full_logout_url = 
f"{logout_url}?id_token_hint={id_token}&post_logout_redirect_uri={FRONTEND_URL}" # Create a redirect response to Keycloak's logout endpoint response = JSONResponse({"status": "success", "logout_url": full_logout_url}) diff --git a/src/backend/routers/canvas.py b/src/backend/routers/canvas.py deleted file mode 100644 index 5f20b6e..0000000 --- a/src/backend/routers/canvas.py +++ /dev/null @@ -1,67 +0,0 @@ -import json -import jwt -from typing import Dict, Any, List -from fastapi import APIRouter, HTTPException, Depends, Request -from fastapi.responses import JSONResponse - -from dependencies import SessionData, require_auth -from db import store_canvas_data, get_canvas_data, get_recent_canvases, MAX_BACKUPS_PER_USER -import posthog - -canvas_router = APIRouter() - -def get_default_canvas_data(): - try: - with open("default_canvas.json", "r") as f: - return json.load(f) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Failed to load default canvas: {str(e)}" - ) - -@canvas_router.get("/default") -async def get_default_canvas(auth: SessionData = Depends(require_auth)): - try: - with open("default_canvas.json", "r") as f: - canvas_data = json.load(f) - return canvas_data - except Exception as e: - return JSONResponse( - status_code=500, - content={"error": f"Failed to load default canvas: {str(e)}"} - ) - -@canvas_router.post("") -async def save_canvas(data: Dict[str, Any], auth: SessionData = Depends(require_auth), request: Request = None): - access_token = auth.token_data.get("access_token") - decoded = jwt.decode(access_token, options={"verify_signature": False}) - user_id = decoded["sub"] - success = await store_canvas_data(user_id, data) - if not success: - raise HTTPException(status_code=500, detail="Failed to save canvas data") - return {"status": "success"} - -@canvas_router.get("") -async def get_canvas(auth: SessionData = Depends(require_auth)): - access_token = auth.token_data.get("access_token") - decoded = 
jwt.decode(access_token, options={"verify_signature": False}) - user_id = decoded["sub"] - data = await get_canvas_data(user_id) - if data is None: - return get_default_canvas_data() - return data - -@canvas_router.get("/recent") -async def get_recent_canvas_backups(limit: int = MAX_BACKUPS_PER_USER, auth: SessionData = Depends(require_auth)): - """Get the most recent canvas backups for the authenticated user""" - access_token = auth.token_data.get("access_token") - decoded = jwt.decode(access_token, options={"verify_signature": False}) - user_id = decoded["sub"] - - # Limit the number of backups to the maximum configured value - if limit > MAX_BACKUPS_PER_USER: - limit = MAX_BACKUPS_PER_USER - - backups = await get_recent_canvases(user_id, limit) - return {"backups": backups} diff --git a/src/backend/routers/pad_router.py b/src/backend/routers/pad_router.py new file mode 100644 index 0000000..96ff9ec --- /dev/null +++ b/src/backend/routers/pad_router.py @@ -0,0 +1,144 @@ +from uuid import UUID +from typing import Dict, Any + +from fastapi import APIRouter, HTTPException, Depends, Request + +from dependencies import UserSession, require_auth +from database import get_pad_service, get_backup_service, get_template_pad_service +from database.service import PadService, BackupService, TemplatePadService +from config import MAX_BACKUPS_PER_USER, MIN_INTERVAL_MINUTES, DEFAULT_PAD_NAME, DEFAULT_TEMPLATE_NAME +pad_router = APIRouter() + + +@pad_router.post("") +async def save_canvas( + data: Dict[str, Any], + user: UserSession = Depends(require_auth), + pad_service: PadService = Depends(get_pad_service), + backup_service: BackupService = Depends(get_backup_service), +): + """Save canvas data for the authenticated user""" + try: + # Check if user already has a pad + user_pads = await pad_service.get_pads_by_owner(user.id) + + if not user_pads: + # Create a new pad if user doesn't have one + pad = await pad_service.create_pad( + owner_id=user.id, + 
display_name=DEFAULT_PAD_NAME, + data=data + ) + else: + # Update existing pad + pad = user_pads[0] # Use the first pad (assuming one pad per user for now) + await pad_service.update_pad_data(pad["id"], data) + + # Create a backup only if needed (if none exist or latest is > 5 min old) + await backup_service.create_backup_if_needed( + source_id=pad["id"], + data=data, + min_interval_minutes=MIN_INTERVAL_MINUTES, + max_backups=MAX_BACKUPS_PER_USER + ) + + return {"status": "success"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to save canvas data: {str(e)}") + + +@pad_router.get("") +async def get_canvas( + user: UserSession = Depends(require_auth), + pad_service: PadService = Depends(get_pad_service), + template_pad_service: TemplatePadService = Depends(get_template_pad_service), + backup_service: BackupService = Depends(get_backup_service) +): + """Get canvas data for the authenticated user""" + try: + # Get user's pads + user_pads = await pad_service.get_pads_by_owner(user.id) + + if not user_pads: + # Return default canvas if user doesn't have a pad + return await create_pad_from_template( + name=DEFAULT_TEMPLATE_NAME, + display_name=DEFAULT_PAD_NAME, + user=user, + pad_service=pad_service, + template_pad_service=template_pad_service, + backup_service=backup_service + ) + + # Return the first pad's data (assuming one pad per user for now) + return user_pads[0]["data"] + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to get canvas data: {str(e)}") + + +@pad_router.post("/from-template/{name}") +async def create_pad_from_template( + name: str, + display_name: str = DEFAULT_PAD_NAME, + user: UserSession = Depends(require_auth), + pad_service: PadService = Depends(get_pad_service), + template_pad_service: TemplatePadService = Depends(get_template_pad_service), + backup_service: BackupService = Depends(get_backup_service) +): + """Create a new pad from a template""" + + try: + # Get the template + 
template = await template_pad_service.get_template_by_name(name) + if not template: + raise HTTPException(status_code=404, detail="Template not found") + + # Create a new pad using the template data + pad = await pad_service.create_pad( + owner_id=user.id, + display_name=display_name, + data=template["data"] + ) + + # Create an initial backup for the new pad + await backup_service.create_backup_if_needed( + source_id=pad["id"], + data=template["data"], + min_interval_minutes=0, # Always create initial backup + max_backups=MAX_BACKUPS_PER_USER + ) + + return pad + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create pad from template: {str(e)}") + + +@pad_router.get("/recent") +async def get_recent_canvas_backups( + limit: int = MAX_BACKUPS_PER_USER, + user: UserSession = Depends(require_auth), + backup_service: BackupService = Depends(get_backup_service) +): + """Get the most recent canvas backups for the authenticated user""" + # Limit the number of backups to the maximum configured value + if limit > MAX_BACKUPS_PER_USER: + limit = MAX_BACKUPS_PER_USER + + try: + # Get backups directly with a single query + backups_data = await backup_service.get_backups_by_user(user.id, limit) + + # Format backups to match the expected response format + backups = [] + for backup in backups_data: + backups.append({ + "id": backup["id"], + "timestamp": backup["created_at"], + "data": backup["data"] + }) + + return {"backups": backups} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to get canvas backups: {str(e)}") diff --git a/src/backend/routers/template_pad_router.py b/src/backend/routers/template_pad_router.py new file mode 100644 index 0000000..a787274 --- /dev/null +++ b/src/backend/routers/template_pad_router.py @@ -0,0 +1,109 @@ +from typing import Dict, Any + +from fastapi import APIRouter, HTTPException, Depends + +from 
dependencies import UserSession, require_auth, require_admin +from database import get_template_pad_service +from database.service import TemplatePadService + +template_pad_router = APIRouter() + + +@template_pad_router.post("") +async def create_template_pad( + data: Dict[str, Any], + name: str, + display_name: str, + _: bool = Depends(require_admin), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Create a new template pad (admin only)""" + try: + template_pad = await template_pad_service.create_template( + name=name, + display_name=display_name, + data=data + ) + return template_pad + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@template_pad_router.get("") +async def get_all_template_pads( + _: bool = Depends(require_admin), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Get all template pads""" + try: + template_pads = await template_pad_service.get_all_templates() + return template_pads + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to get template pads: {str(e)}") + + +@template_pad_router.get("/{name}") +async def get_template_pad( + name: str, + _: UserSession = Depends(require_auth), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Get a specific template pad by name""" + template_pad = await template_pad_service.get_template_by_name(name) + if not template_pad: + raise HTTPException(status_code=404, detail="Template pad not found") + + return template_pad + + +@template_pad_router.put("/{name}") +async def update_template_pad( + name: str, + data: Dict[str, Any], + _: bool = Depends(require_admin), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Update a template pad (admin only)""" + try: + updated_template = await template_pad_service.update_template(name, data) + if not updated_template: + raise 
HTTPException(status_code=404, detail="Template pad not found") + + return updated_template + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@template_pad_router.put("/{name}/data") +async def update_template_pad_data( + name: str, + data: Dict[str, Any], + _: bool = Depends(require_admin), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Update just the data field of a template pad (admin only)""" + try: + updated_template = await template_pad_service.update_template_data(name, data) + if not updated_template: + raise HTTPException(status_code=404, detail="Template pad not found") + + return updated_template + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@template_pad_router.delete("/{name}") +async def delete_template_pad( + name: str, + _: bool = Depends(require_admin), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Delete a template pad (admin only)""" + try: + success = await template_pad_service.delete_template(name) + if not success: + raise HTTPException(status_code=404, detail="Template pad not found") + + return {"status": "success", "message": "Template pad deleted successfully"} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) diff --git a/src/backend/routers/user.py b/src/backend/routers/user.py deleted file mode 100644 index 969da69..0000000 --- a/src/backend/routers/user.py +++ /dev/null @@ -1,58 +0,0 @@ -import os -import jwt - -import posthog -from fastapi import APIRouter, Depends, Request -from dotenv import load_dotenv - -from dependencies import SessionData, require_auth -from config import redis_client - -load_dotenv() - -user_router = APIRouter() - -@user_router.get("/me") -async def get_user_info(auth: SessionData = Depends(require_auth), request: Request = None): - token_data = auth.token_data - access_token = token_data.get("access_token") - - decoded = 
jwt.decode(access_token, options={"verify_signature": False}) - - # Build full URL (mirroring canvas.py logic) - full_url = None - if request: - full_url = str(request.base_url).rstrip("/") + str(request.url.path) - full_url = full_url.replace("http://", "https://") - - user_data: dict = { - "id": decoded["sub"], # Unique user ID - "email": decoded.get("email", ""), - "username": decoded.get("preferred_username", ""), - "name": decoded.get("name", ""), - "given_name": decoded.get("given_name", ""), - "family_name": decoded.get("family_name", ""), - "email_verified": decoded.get("email_verified", False) - } - - if os.getenv("VITE_PUBLIC_POSTHOG_KEY"): - telemetry = user_data | {"$current_url": full_url} - posthog.identify(distinct_id=decoded["sub"], properties=telemetry) - - return user_data - -@user_router.get("/count") -async def get_user_count(auth: SessionData = Depends(require_auth)): - """ - Get the count of active sessions in Redis - - Returns: - dict: A dictionary containing the count of active sessions - """ - # Count keys that match the session pattern - session_count = len(redis_client.keys("session:*")) - - return { - "active_sessions": session_count, - "message": f"There are currently {session_count} active sessions in Redis" - } \ No newline at end of file diff --git a/src/backend/routers/user_router.py b/src/backend/routers/user_router.py new file mode 100644 index 0000000..f4d84fe --- /dev/null +++ b/src/backend/routers/user_router.py @@ -0,0 +1,111 @@ +import os +from uuid import UUID + +import posthog +from fastapi import APIRouter, Depends, HTTPException + +from config import get_redis_client, FRONTEND_URL +from database import get_user_service +from database.service import UserService +from dependencies import UserSession, require_admin, require_auth + +user_router = APIRouter() + + +@user_router.post("") +async def create_user( + user_id: UUID, + username: str, + email: str, + email_verified: bool = False, + name: str = None, + given_name: str = 
None, + family_name: str = None, + roles: list = None, + _: bool = Depends(require_admin), + user_service: UserService = Depends(get_user_service) +): + """Create a new user (admin only)""" + try: + user = await user_service.create_user( + user_id=user_id, + username=username, + email=email, + email_verified=email_verified, + name=name, + given_name=given_name, + family_name=family_name, + roles=roles + ) + return user + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@user_router.get("") +async def get_all_users( + _: bool = Depends(require_admin), + user_service: UserService = Depends(get_user_service) +): + """Get all users (admin only)""" + users = await user_service.get_all_users() + return users + + +@user_router.get("/me") +async def get_user_info( + user: UserSession = Depends(require_auth), + user_service: UserService = Depends(get_user_service), +): + """Get the current user's information and sync with token data""" + + # Create token data dictionary from UserSession properties + token_data = { + "username": user.username, + "email": user.email, + "email_verified": user.email_verified, + "name": user.name, + "given_name": user.given_name, + "family_name": user.family_name, + "roles": user.roles + } + + try: + # Sync user with token data + user_data = await user_service.sync_user_with_token_data(user.id, token_data) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error syncing user data: {e}" + ) + + if os.getenv("VITE_PUBLIC_POSTHOG_KEY"): + telemetry = user_data.copy() + telemetry["$current_url"] = FRONTEND_URL + posthog.identify(distinct_id=user_data["id"], properties=telemetry) + + return user_data + + +@user_router.get("/count") +async def get_user_count( + _: bool = Depends(require_admin), +): + """Get the number of active sessions (admin only)""" + client = get_redis_client() + session_count = len(client.keys("session:*")) + return {"active_sessions": session_count } + + 
+@user_router.get("/{user_id}") +async def get_user( + user_id: UUID, + _: bool = Depends(require_admin), + user_service: UserService = Depends(get_user_service) +): + """Get a user by ID (admin only)""" + user = await user_service.get_user(user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + return user diff --git a/src/backend/routers/workspace.py b/src/backend/routers/workspace.py deleted file mode 100644 index 56af56f..0000000 --- a/src/backend/routers/workspace.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import Dict, Any -from fastapi import APIRouter, HTTPException, Depends -from fastapi.responses import JSONResponse -from pydantic import BaseModel -import jwt -import os - -from dependencies import SessionData, require_auth -from coder import CoderAPI - -workspace_router = APIRouter() -coder_api = CoderAPI() - -class WorkspaceState(BaseModel): - state: str - workspace_id: str - username: str - base_url: str - agent: str - -@workspace_router.get("/state", response_model=WorkspaceState) -async def get_workspace_state(auth: SessionData = Depends(require_auth)): - """ - Get the current state of the user's workspace - """ - # Get user info from token - access_token = auth.token_data.get("access_token") - decoded = jwt.decode(access_token, options={"verify_signature": False}) - username = decoded.get("preferred_username") - email = decoded.get("email") - - # Get user's workspaces - user = coder_api.get_user_by_email(email) - username = user.get('username', None) - if not username: - raise HTTPException(status_code=404, detail="User not found") - - workspace = coder_api.get_workspace_status_for_user(username) - - if not workspace: - raise HTTPException(status_code=404, detail="No workspace found for user") - - #states can be: - #starting - #running - #stopping - #stopped - #error - - return WorkspaceState( - state=workspace.get('latest_build', {}).get('status', 'error'), - workspace_id=workspace.get('latest_build', 
{}).get('workspace_name', ''), - username=username, - base_url=os.getenv("CODER_URL", ""), - agent="main" - ) - -@workspace_router.post("/start") -async def start_workspace(auth: SessionData = Depends(require_auth)): - """ - Start a workspace for the authenticated user - """ - # Get user info from token - access_token = auth.token_data.get("access_token") - decoded = jwt.decode(access_token, options={"verify_signature": False}) - email = decoded.get("email") - - user = coder_api.get_user_by_email(email) - username = user.get('username', None) - if not username: - raise HTTPException(status_code=404, detail="User not found") - # Get user's workspace - workspace = coder_api.get_workspace_status_for_user(username) - if not workspace: - raise HTTPException(status_code=404, detail="No workspace found for user") - - # Start the workspace - try: - response = coder_api.start_workspace(workspace["id"]) - return JSONResponse(content=response) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@workspace_router.post("/stop") -async def stop_workspace(auth: SessionData = Depends(require_auth)): - """ - Stop a workspace for the authenticated user - """ - # Get user info from token - access_token = auth.token_data.get("access_token") - decoded = jwt.decode(access_token, options={"verify_signature": False}) - email = decoded.get("email") - - user = coder_api.get_user_by_email(email) - username = user.get('username', None) - if not username: - raise HTTPException(status_code=404, detail="User not found") - # Get user's workspace - workspace = coder_api.get_workspace_status_for_user(username) - if not workspace: - raise HTTPException(status_code=404, detail="No workspace found for user") - - # Stop the workspace - try: - response = coder_api.stop_workspace(workspace["id"]) - return JSONResponse(content=response) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - diff --git a/src/backend/routers/workspace_router.py 
b/src/backend/routers/workspace_router.py new file mode 100644 index 0000000..0511702 --- /dev/null +++ b/src/backend/routers/workspace_router.py @@ -0,0 +1,89 @@ +import os + +from pydantic import BaseModel +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse + +from dependencies import UserSession, require_auth, get_coder_api +from coder import CoderAPI + +workspace_router = APIRouter() + +class WorkspaceState(BaseModel): + id: str + state: str + name: str + username: str + base_url: str + agent: str + +@workspace_router.get("/state", response_model=WorkspaceState) +async def get_workspace_state( + user: UserSession = Depends(require_auth), + coder_api: CoderAPI = Depends(get_coder_api) +): + """ + Get the current state of the user's workspace + """ + + coder_user: dict = coder_api.get_user_by_email(user.email) + if not coder_user: + raise HTTPException(status_code=404, detail=f"Coder user not found for user {user.email} ({user.id})") + + coder_username: str = coder_user.get('username', None) + if not coder_username: + raise HTTPException(status_code=404, detail=f"Coder username not found for user {user.email} ({user.id})") + + workspace: dict = coder_api.get_workspace_status_for_user(coder_username) + if not workspace: + raise HTTPException(status_code=404, detail=f"Coder Workspace not found for user {user.email} ({user.id})") + + latest_build: dict = workspace.get('latest_build', {}) + latest_build_status: str = latest_build.get('status', 'error') + workspace_name: str = latest_build.get('workspace_name', None) + workspace_id: str = workspace.get('id', {}) + + return WorkspaceState( + state=latest_build_status, + id=workspace_id, + name=workspace_name, + username=coder_username, + base_url=os.getenv("CODER_URL", ""), + agent="main" + ) + + +@workspace_router.post("/start") +async def start_workspace( + user: UserSession = Depends(require_auth), + coder_api: CoderAPI = Depends(get_coder_api) +): + """ + Start a 
workspace for the authenticated user + """ + + workspace: WorkspaceState = await get_workspace_state(user, coder_api) + + try: + response = coder_api.start_workspace(workspace.id) + return JSONResponse(content=response) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@workspace_router.post("/stop") +async def stop_workspace( + user: UserSession = Depends(require_auth), + coder_api: CoderAPI = Depends(get_coder_api) +): + """ + Stop a workspace for the authenticated user + """ + + workspace: WorkspaceState = await get_workspace_state(user, coder_api) + + try: + response = coder_api.stop_workspace(workspace.id) + return JSONResponse(content=response) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/backend/default_canvas.json b/src/backend/templates/default.json similarity index 99% rename from src/backend/default_canvas.json rename to src/backend/templates/default.json index a73e469..0629611 100644 --- a/src/backend/default_canvas.json +++ b/src/backend/templates/default.json @@ -1,6 +1,9 @@ { "files": {}, "appState": { + "pad": { + "displayName": "Welcome!" 
+ }, "name": "Pad.ws", "zoom": { "value": 1 diff --git a/src/frontend/src/api/hooks.ts b/src/frontend/src/api/hooks.ts index 58d198f..7c8b0d2 100644 --- a/src/frontend/src/api/hooks.ts +++ b/src/frontend/src/api/hooks.ts @@ -5,10 +5,11 @@ import { queryClient } from './queryClient'; // Types export interface WorkspaceState { status: 'running' | 'starting' | 'stopping' | 'stopped' | 'error'; - workspace_id: string | null; username: string | null; + name: string | null; base_url: string | null; agent: string | null; + id: string | null; error?: string; } @@ -56,7 +57,7 @@ export const api = { // User profile getUserProfile: async (): Promise => { try { - const result = await fetchApi('/api/user/me'); + const result = await fetchApi('/api/users/me'); return result; } catch (error) { throw error; @@ -96,7 +97,7 @@ export const api = { // Canvas getCanvas: async (): Promise => { try { - const result = await fetchApi('/api/canvas'); + const result = await fetchApi('/api/pad'); return result; } catch (error) { throw error; @@ -105,7 +106,7 @@ export const api = { saveCanvas: async (data: CanvasData): Promise => { try { - const result = await fetchApi('/api/canvas', { + const result = await fetchApi('/api/pad', { method: 'POST', body: JSON.stringify(data), }); @@ -117,7 +118,7 @@ export const api = { getDefaultCanvas: async (): Promise => { try { - const result = await fetchApi('/api/canvas/default'); + const result = await fetchApi('/api/pad/from-template/default'); return result; } catch (error) { throw error; @@ -127,7 +128,7 @@ export const api = { // Canvas Backups getCanvasBackups: async (limit: number = 10): Promise => { try { - const result = await fetchApi(`/api/canvas/recent?limit=${limit}`); + const result = await fetchApi(`/api/pad/recent?limit=${limit}`); return result; } catch (error) { throw error; diff --git a/src/frontend/src/pad/buttons/ActionButton.tsx b/src/frontend/src/pad/buttons/ActionButton.tsx index 91118f5..c752e05 100644 --- 
a/src/frontend/src/pad/buttons/ActionButton.tsx +++ b/src/frontend/src/pad/buttons/ActionButton.tsx @@ -285,7 +285,7 @@ const ActionButton: React.FC = ({ return ''; } - return `${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.workspace_id}.${workspaceState.agent}/apps/code-server`; + return `${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.name}.${workspaceState.agent}/apps/code-server`; }; // Placement logic has been moved to ExcalidrawElementFactory.placeInScene @@ -345,7 +345,7 @@ const ActionButton: React.FC = ({ return; } - const terminalUrl = `${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.workspace_id}.${workspaceState.agent}/terminal`; + const terminalUrl = `${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.name}.${workspaceState.agent}/terminal`; console.debug(`[pad.ws] Opening terminal in new tab: ${terminalUrl}`); window.open(terminalUrl, '_blank'); } else { @@ -366,7 +366,7 @@ const ActionButton: React.FC = ({ } const owner = workspaceState.username; - const workspace = workspaceState.workspace_id; + const workspace = workspaceState.name; const url = workspaceState.base_url; const agent = workspaceState.agent; diff --git a/src/frontend/src/pad/containers/Terminal.tsx b/src/frontend/src/pad/containers/Terminal.tsx index 9460bba..64195ff 100644 --- a/src/frontend/src/pad/containers/Terminal.tsx +++ b/src/frontend/src/pad/containers/Terminal.tsx @@ -56,7 +56,7 @@ export const Terminal: React.FC = ({ terminalId, baseUrl: workspaceState.base_url, username: workspaceState.username, - workspaceId: workspaceState.workspace_id, + workspaceId: workspaceState.name, agent: workspaceState.agent }; @@ -162,7 +162,7 @@ export const Terminal: React.FC = ({ return ''; } - const baseUrl = `${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.workspace_id}.${workspaceState.agent}/terminal`; + const baseUrl = 
`${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.name}.${workspaceState.agent}/terminal`; // Add reconnect parameter if terminal ID exists if (terminalId) {