From 204b794c506b19e0687966bde6627f444d2494e6 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 09:34:15 +0000 Subject: [PATCH 01/32] feat: add BackupModel, PadModel, and UserModel for database schema - Introduced BackupModel to manage backups with relationships to PadModel. - Created PadModel to represent pads with relationships to UserModel and BackupModel. - Added UserModel to handle user data and relationships with PadModel. - Each model includes fields for UUID, timestamps, and relevant data types, utilizing SQLAlchemy for ORM functionality. --- src/backend/database/models/backup_model.py | 24 +++++++++++++++++++ src/backend/database/models/pad_model.py | 26 +++++++++++++++++++++ src/backend/database/models/user_model.py | 23 ++++++++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 src/backend/database/models/backup_model.py create mode 100644 src/backend/database/models/pad_model.py create mode 100644 src/backend/database/models/user_model.py diff --git a/src/backend/database/models/backup_model.py b/src/backend/database/models/backup_model.py new file mode 100644 index 0000000..6a1a6ba --- /dev/null +++ b/src/backend/database/models/backup_model.py @@ -0,0 +1,24 @@ +from uuid import uuid4 + +from sqlalchemy import Column, String, DateTime, func, UUID, ForeignKey +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import DeclarativeMeta, relationship +from sqlalchemy.ext.declarative import declarative_base + +Base: DeclarativeMeta = declarative_base() + +class BackupModel(Base): + """Model for backups table in app schema""" + __tablename__ = "backups" + __table_args__ = {"schema": "padws"} + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) + source_id = Column(UUID(as_uuid=True), ForeignKey("padws.pads.id"), nullable=False) + data = Column(JSONB, nullable=False) + + pad = relationship("PadModel", back_populates="backups") + + def __repr__(self): + return f"" diff --git a/src/backend/database/models/pad_model.py b/src/backend/database/models/pad_model.py new file mode 100644 index 0000000..dc2d7d4 --- /dev/null +++ b/src/backend/database/models/pad_model.py @@ -0,0 +1,26 @@ +from uuid import uuid4 + +from sqlalchemy import Column, String, DateTime, func, UUID, ForeignKey +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import DeclarativeMeta, relationship +from sqlalchemy.ext.declarative import declarative_base + +Base: DeclarativeMeta = declarative_base() + +class PadModel(Base): + """Model for pads table in app schema""" + __tablename__ = "pads" + __table_args__ = {"schema": "padws"} + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) + owner_id = Column(UUID(as_uuid=True), ForeignKey("padws.users.id"), nullable=False) + display_name = Column(String, nullable=False) + data = Column(JSONB, nullable=False) + + owner = relationship("UserModel", back_populates="pads") + backups = relationship("BackupModel", back_populates="pad", cascade="all, delete-orphan") + + def __repr__(self): + return f"" diff --git a/src/backend/database/models/user_model.py b/src/backend/database/models/user_model.py new file mode 100644 index 0000000..798748a --- /dev/null +++ b/src/backend/database/models/user_model.py @@ -0,0 +1,23 @@ +from uuid import uuid4 + +from sqlalchemy import Column, String, DateTime, func, UUID +from sqlalchemy.orm import DeclarativeMeta, relationship +from sqlalchemy.ext.declarative import declarative_base + +Base: DeclarativeMeta = declarative_base() + +class UserModel(Base): + """Model for users table in app schema""" + __tablename__ = "users" + __table_args__ = {"schema": "padws"} + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) + username = Column(String, nullable=False, unique=True) + email = Column(String, nullable=False) + + pads = relationship("PadModel", back_populates="owner", cascade="all, delete-orphan") + + def __repr__(self): + return f"" From dd234bb1404ee30e23ee9b82f09b72460b611ddc Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 09:53:40 +0000 Subject: [PATCH 02/32] feat: introduce base model and refactor existing models for database schema - Added BaseModel to centralize common fields and schema configuration for all models. - Refactored BackupModel, PadModel, and UserModel to inherit from BaseModel, enhancing code reusability. - Updated relationships and added indexing for improved query performance. - Implemented to_dict methods in models for easier data serialization. --- src/backend/database/models/__init__.py | 18 +++++++ src/backend/database/models/backup_model.py | 46 +++++++++++------ src/backend/database/models/base_model.py | 31 ++++++++++++ src/backend/database/models/pad_model.py | 55 +++++++++++++++------ src/backend/database/models/user_model.py | 37 ++++++++------ 5 files changed, 142 insertions(+), 45 deletions(-) create mode 100644 src/backend/database/models/__init__.py create mode 100644 src/backend/database/models/base_model.py diff --git a/src/backend/database/models/__init__.py b/src/backend/database/models/__init__.py new file mode 100644 index 0000000..040f570 --- /dev/null +++ b/src/backend/database/models/__init__.py @@ -0,0 +1,18 @@ +""" +Database models for the application. + +This module provides access to all database models used in the application. +""" + +from .base_model import Base, BaseModel +from .user_model import UserModel +from .pad_model import PadModel +from .backup_model import BackupModel + +__all__ = [ + 'Base', + 'BaseModel', + 'UserModel', + 'PadModel', + 'BackupModel', +] diff --git a/src/backend/database/models/backup_model.py b/src/backend/database/models/backup_model.py index 6a1a6ba..3563628 100644 --- a/src/backend/database/models/backup_model.py +++ b/src/backend/database/models/backup_model.py @@ -1,24 +1,42 @@ -from uuid import uuid4 +from typing import Dict, Any, Optional +from uuid import UUID as UUIDType -from sqlalchemy import Column, String, DateTime, func, UUID, ForeignKey +from sqlalchemy import Column, ForeignKey, Index from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import DeclarativeMeta, relationship -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship, Mapped -Base: DeclarativeMeta = declarative_base() +from backend.database.models.pad_model import PadModel -class BackupModel(Base): +from .base_model import Base, BaseModel + +class BackupModel(Base, BaseModel): """Model for backups table in app schema""" __tablename__ = "backups" - __table_args__ = {"schema": "padws"} + __table_args__ = ( + BaseModel.get_schema(), + Index("ix_backups_source_id", "source_id"), + Index("ix_backups_created_at", "created_at") + ) - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) - source_id = Column(UUID(as_uuid=True), ForeignKey("padws.pads.id"), nullable=False) + # Backup-specific fields + source_id = Column( + UUIDType(as_uuid=True), + ForeignKey(f"{BaseModel.get_schema()['schema']}.pads.id", ondelete="CASCADE"), + nullable=False + ) data = Column(JSONB, nullable=False) - pad = relationship("PadModel", back_populates="backups") + # Relationships + pad: Mapped["PadModel"] = relationship("PadModel", back_populates="backups") - def __repr__(self): - return f"" + def __repr__(self) -> str: + return f"" + + def to_dict(self) -> Dict[str, Any]: + """Convert model instance to dictionary with additional fields""" + result = super().to_dict() + # Convert data to dict if it's not already + if isinstance(result["data"], str): + import json + result["data"] = json.loads(result["data"]) + return result diff --git a/src/backend/database/models/base_model.py b/src/backend/database/models/base_model.py new file mode 100644 index 0000000..35ef728 --- /dev/null +++ b/src/backend/database/models/base_model.py @@ -0,0 +1,31 @@ +from uuid import uuid4 +from datetime import datetime +from typing import Any, Dict + +from sqlalchemy import Column, DateTime, UUID, func +from sqlalchemy.orm import DeclarativeMeta, declarative_base + +# Create a single shared Base for all models +Base: DeclarativeMeta = declarative_base() + +# Define schema name in a central location +SCHEMA_NAME = "padws" + +class BaseModel: + """Base model with common fields for all models""" + + # Primary key using UUID + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4, index=True) + + # Timestamps for creation and updates + created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True) + updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) + + def to_dict(self) -> Dict[str, Any]: + """Convert model instance to dictionary""" + return {c.name: getattr(self, c.name) for c in self.__table__.columns} + + @classmethod + def get_schema(cls) -> Dict[str, Any]: + """Return schema configuration for the model""" + return {"schema": SCHEMA_NAME} diff --git a/src/backend/database/models/pad_model.py b/src/backend/database/models/pad_model.py index dc2d7d4..2fc4cf7 100644 --- a/src/backend/database/models/pad_model.py +++ b/src/backend/database/models/pad_model.py @@ -1,26 +1,49 @@ -from uuid import uuid4 +from typing import List, Dict, Any, Optional +from uuid import UUID as UUIDType -from sqlalchemy import Column, String, DateTime, func, UUID, ForeignKey +from sqlalchemy import Column, String, ForeignKey, Index from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import DeclarativeMeta, relationship -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship, Mapped -Base: DeclarativeMeta = declarative_base() +from .base_model import Base, BaseModel +from .backup_model import BackupModel +from .user_model import UserModel -class PadModel(Base): +class PadModel(Base, BaseModel): """Model for pads table in app schema""" __tablename__ = "pads" - __table_args__ = {"schema": "padws"} + __table_args__ = ( + BaseModel.get_schema(), + Index("ix_pads_owner_id", "owner_id"), + Index("ix_pads_display_name", "display_name") + ) - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) - owner_id = Column(UUID(as_uuid=True), ForeignKey("padws.users.id"), nullable=False) - display_name = Column(String, nullable=False) + # Pad-specific fields + owner_id = Column( + UUIDType(as_uuid=True), + ForeignKey(f"{BaseModel.get_schema()['schema']}.users.id", ondelete="CASCADE"), + nullable=False + ) + display_name = Column(String(100), nullable=False) data = Column(JSONB, nullable=False) - owner = relationship("UserModel", back_populates="pads") - backups = relationship("BackupModel", back_populates="pad", cascade="all, delete-orphan") + # Relationships + owner: Mapped["UserModel"] = relationship("UserModel", back_populates="pads") + backups: Mapped[List["BackupModel"]] = relationship( + "BackupModel", + back_populates="pad", + cascade="all, delete-orphan", + lazy="selectin" + ) - def __repr__(self): - return f"" + def __repr__(self) -> str: + return f"" + + def to_dict(self) -> Dict[str, Any]: + """Convert model instance to dictionary with additional fields""" + result = super().to_dict() + # Convert data to dict if it's not already + if isinstance(result["data"], str): + import json + result["data"] = json.loads(result["data"]) + return result diff --git a/src/backend/database/models/user_model.py b/src/backend/database/models/user_model.py index 798748a..d4dd60a 100644 --- a/src/backend/database/models/user_model.py +++ b/src/backend/database/models/user_model.py @@ -1,23 +1,30 @@ -from uuid import uuid4 +from typing import List +from sqlalchemy import Column, Index, VARCHAR +from sqlalchemy.orm import relationship, Mapped -from sqlalchemy import Column, String, DateTime, func, UUID -from sqlalchemy.orm import DeclarativeMeta, relationship -from sqlalchemy.ext.declarative import declarative_base +from .base_model import Base, BaseModel +from .pad_model import PadModel -Base: DeclarativeMeta = declarative_base() - -class UserModel(Base): +class UserModel(Base, BaseModel): """Model for users table in app schema""" __tablename__ = "users" - __table_args__ = {"schema": "padws"} + __table_args__ = ( + BaseModel.get_schema(), + Index("ix_users_username", "username"), + Index("ix_users_email", "email") + ) - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) - username = Column(String, nullable=False, unique=True) - email = Column(String, nullable=False) + # User-specific fields + username = Column(VARCHAR(254), nullable=False, unique=True) + email = Column(VARCHAR(254), nullable=False) - pads = relationship("PadModel", back_populates="owner", cascade="all, delete-orphan") + # Relationships + pads: Mapped[List["PadModel"]] = relationship( + "PadModel", + back_populates="owner", + cascade="all, delete-orphan", + lazy="selectin" + ) - def __repr__(self): + def __repr__(self) -> str: return f"" From cb0658de024463d3c75b1638592185f79d11bb8b Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 09:53:50 +0000 Subject: [PATCH 03/32] feat: implement repository modules for database operations - Added a new repository module for database operations, including UserRepository, PadRepository, and BackupRepository. - Each repository provides methods for creating, retrieving, updating, and deleting records related to users, pads, and backups. - Enhanced modularity and organization of database interactions within the application. --- src/backend/database/repository/__init__.py | 15 +++ .../database/repository/backup_repository.py | 91 +++++++++++++++++++ .../database/repository/pad_repository.py | 66 ++++++++++++++ .../database/repository/user_repository.py | 65 +++++++++++++ 4 files changed, 237 insertions(+) create mode 100644 src/backend/database/repository/__init__.py create mode 100644 src/backend/database/repository/backup_repository.py create mode 100644 src/backend/database/repository/pad_repository.py create mode 100644 src/backend/database/repository/user_repository.py diff --git a/src/backend/database/repository/__init__.py b/src/backend/database/repository/__init__.py new file mode 100644 index 0000000..1e99327 --- /dev/null +++ b/src/backend/database/repository/__init__.py @@ -0,0 +1,15 @@ +""" +Repository module for database operations. + +This module provides access to all repositories used for database operations. +""" + +from .user_repository import UserRepository +from .pad_repository import PadRepository +from .backup_repository import BackupRepository + +__all__ = [ + 'UserRepository', + 'PadRepository', + 'BackupRepository', +] diff --git a/src/backend/database/repository/backup_repository.py b/src/backend/database/repository/backup_repository.py new file mode 100644 index 0000000..bfba95b --- /dev/null +++ b/src/backend/database/repository/backup_repository.py @@ -0,0 +1,91 @@ +""" +Backup repository for database operations related to backups. +""" + +from typing import List, Optional, Dict, Any +from uuid import UUID +from datetime import datetime + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import delete, func + +from ..models import BackupModel + +class BackupRepository: + """Repository for backup-related database operations""" + + def __init__(self, session: AsyncSession): + """Initialize the repository with a database session""" + self.session = session + + async def create(self, source_id: UUID, data: Dict[str, Any]) -> BackupModel: + """Create a new backup""" + backup = BackupModel(source_id=source_id, data=data) + self.session.add(backup) + await self.session.commit() + await self.session.refresh(backup) + return backup + + async def get_by_id(self, backup_id: UUID) -> Optional[BackupModel]: + """Get a backup by ID""" + stmt = select(BackupModel).where(BackupModel.id == backup_id) + result = await self.session.execute(stmt) + return result.scalars().first() + + async def get_by_source(self, source_id: UUID) -> List[BackupModel]: + """Get all backups for a specific source pad""" + stmt = select(BackupModel).where(BackupModel.source_id == source_id).order_by(BackupModel.created_at.desc()) + result = await self.session.execute(stmt) + return result.scalars().all() + + async def get_latest_by_source(self, source_id: UUID) -> Optional[BackupModel]: + """Get the most recent backup for a specific source pad""" + stmt = select(BackupModel).where(BackupModel.source_id == source_id).order_by(BackupModel.created_at.desc()).limit(1) + result = await self.session.execute(stmt) + return result.scalars().first() + + async def get_by_date_range(self, source_id: UUID, start_date: datetime, end_date: datetime) -> List[BackupModel]: + """Get backups for a specific source pad within a date range""" + stmt = select(BackupModel).where( + BackupModel.source_id == source_id, + BackupModel.created_at >= start_date, + BackupModel.created_at <= end_date + ).order_by(BackupModel.created_at.desc()) + result = await self.session.execute(stmt) + return result.scalars().all() + + async def delete(self, backup_id: UUID) -> bool: + """Delete a backup""" + stmt = delete(BackupModel).where(BackupModel.id == backup_id) + result = await self.session.execute(stmt) + await self.session.commit() + return result.rowcount > 0 + + async def delete_older_than(self, source_id: UUID, keep_count: int) -> int: + """Delete older backups, keeping only the most recent ones""" + # Get the created_at timestamp of the backup at position keep_count + subquery = select(BackupModel.created_at).where( + BackupModel.source_id == source_id + ).order_by(BackupModel.created_at.desc()).offset(keep_count).limit(1) + + result = await self.session.execute(subquery) + cutoff_date = result.scalar() + + if not cutoff_date: + return 0 # Not enough backups to delete any + + # Delete backups older than the cutoff date + stmt = delete(BackupModel).where( + BackupModel.source_id == source_id, + BackupModel.created_at < cutoff_date + ) + result = await self.session.execute(stmt) + await self.session.commit() + return result.rowcount + + async def count_by_source(self, source_id: UUID) -> int: + """Count the number of backups for a specific source pad""" + stmt = select(func.count()).select_from(BackupModel).where(BackupModel.source_id == source_id) + result = await self.session.execute(stmt) + return result.scalar() diff --git a/src/backend/database/repository/pad_repository.py b/src/backend/database/repository/pad_repository.py new file mode 100644 index 0000000..7b8123a --- /dev/null +++ b/src/backend/database/repository/pad_repository.py @@ -0,0 +1,66 @@ +""" +Pad repository for database operations related to pads. +""" + +from typing import List, Optional, Dict, Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import update, delete + +from ..models import PadModel + +class PadRepository: + """Repository for pad-related database operations""" + + def __init__(self, session: AsyncSession): + """Initialize the repository with a database session""" + self.session = session + + async def create(self, owner_id: UUID, display_name: str, data: Dict[str, Any]) -> PadModel: + """Create a new pad""" + pad = PadModel(owner_id=owner_id, display_name=display_name, data=data) + self.session.add(pad) + await self.session.commit() + await self.session.refresh(pad) + return pad + + async def get_by_id(self, pad_id: UUID) -> Optional[PadModel]: + """Get a pad by ID""" + stmt = select(PadModel).where(PadModel.id == pad_id) + result = await self.session.execute(stmt) + return result.scalars().first() + + async def get_by_owner(self, owner_id: UUID) -> List[PadModel]: + """Get all pads for a specific owner""" + stmt = select(PadModel).where(PadModel.owner_id == owner_id) + result = await self.session.execute(stmt) + return result.scalars().all() + + async def get_by_name(self, owner_id: UUID, display_name: str) -> Optional[PadModel]: + """Get a pad by owner and display name""" + stmt = select(PadModel).where( + PadModel.owner_id == owner_id, + PadModel.display_name == display_name + ) + result = await self.session.execute(stmt) + return result.scalars().first() + + async def update(self, pad_id: UUID, data: Dict[str, Any]) -> Optional[PadModel]: + """Update a pad""" + stmt = update(PadModel).where(PadModel.id == pad_id).values(**data).returning(PadModel) + result = await self.session.execute(stmt) + await self.session.commit() + return result.scalars().first() + + async def update_data(self, pad_id: UUID, pad_data: Dict[str, Any]) -> Optional[PadModel]: + """Update just the data field of a pad""" + return await self.update(pad_id, {"data": pad_data}) + + async def delete(self, pad_id: UUID) -> bool: + """Delete a pad""" + stmt = delete(PadModel).where(PadModel.id == pad_id) + result = await self.session.execute(stmt) + await self.session.commit() + return result.rowcount > 0 diff --git a/src/backend/database/repository/user_repository.py b/src/backend/database/repository/user_repository.py new file mode 100644 index 0000000..8253d69 --- /dev/null +++ b/src/backend/database/repository/user_repository.py @@ -0,0 +1,65 @@ +""" +User repository for database operations related to users. +""" + +from typing import List, Optional, Dict, Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import update, delete + +from ..models import UserModel + +class UserRepository: + """Repository for user-related database operations""" + + def __init__(self, session: AsyncSession): + """Initialize the repository with a database session""" + self.session = session + + async def create(self, username: str, email: str) -> UserModel: + """Create a new user""" + user = UserModel(username=username, email=email) + self.session.add(user) + await self.session.commit() + await self.session.refresh(user) + return user + + async def get_by_id(self, user_id: UUID) -> Optional[UserModel]: + """Get a user by ID""" + stmt = select(UserModel).where(UserModel.id == user_id) + result = await self.session.execute(stmt) + return result.scalars().first() + + async def get_by_username(self, username: str) -> Optional[UserModel]: + """Get a user by username""" + stmt = select(UserModel).where(UserModel.username == username) + result = await self.session.execute(stmt) + return result.scalars().first() + + async def get_by_email(self, email: str) -> Optional[UserModel]: + """Get a user by email""" + stmt = select(UserModel).where(UserModel.email == email) + result = await self.session.execute(stmt) + return result.scalars().first() + + async def get_all(self) -> List[UserModel]: + """Get all users""" + stmt = select(UserModel) + result = await self.session.execute(stmt) + return result.scalars().all() + + async def update(self, user_id: UUID, data: Dict[str, Any]) -> Optional[UserModel]: + """Update a user""" + stmt = update(UserModel).where(UserModel.id == user_id).values(**data).returning(UserModel) + result = await self.session.execute(stmt) + await self.session.commit() + return result.scalars().first() + + async def delete(self, user_id: UUID) -> bool: + """Delete a user""" + stmt = delete(UserModel).where(UserModel.id == user_id) + result = await self.session.execute(stmt) + await self.session.commit() + return result.rowcount > 0 From a12827f7bc5009c91cb9f39d30d13afd41394141 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 09:54:04 +0000 Subject: [PATCH 04/32] feat: add service modules for user, pad, and backup management - Introduced UserService, PadService, and BackupService to handle business logic related to users, pads, and backups. - Each service includes methods for creating, retrieving, updating, and deleting records, enhancing modularity and organization of the application. - Added an __init__.py file to facilitate service module imports. --- src/backend/database/service/__init__.py | 15 +++ .../database/service/backup_service.py | 103 ++++++++++++++++++ src/backend/database/service/pad_service.py | 103 ++++++++++++++++++ src/backend/database/service/user_service.py | 87 +++++++++++++++ 4 files changed, 308 insertions(+) create mode 100644 src/backend/database/service/__init__.py create mode 100644 src/backend/database/service/backup_service.py create mode 100644 src/backend/database/service/pad_service.py create mode 100644 src/backend/database/service/user_service.py diff --git a/src/backend/database/service/__init__.py b/src/backend/database/service/__init__.py new file mode 100644 index 0000000..9fb7bd2 --- /dev/null +++ b/src/backend/database/service/__init__.py @@ -0,0 +1,15 @@ +""" +Service module for business logic. + +This module provides access to all services used for business logic operations. +""" + +from .user_service import UserService +from .pad_service import PadService +from .backup_service import BackupService + +__all__ = [ + 'UserService', + 'PadService', + 'BackupService', +] diff --git a/src/backend/database/service/backup_service.py b/src/backend/database/service/backup_service.py new file mode 100644 index 0000000..46f404f --- /dev/null +++ b/src/backend/database/service/backup_service.py @@ -0,0 +1,103 @@ +""" +Backup service for business logic related to backups. +""" + +from typing import List, Optional, Dict, Any +from uuid import UUID +from datetime import datetime + +from sqlalchemy.ext.asyncio import AsyncSession + +from ..repository import BackupRepository, PadRepository + +class BackupService: + """Service for backup-related business logic""" + + def __init__(self, session: AsyncSession): + """Initialize the service with a database session""" + self.session = session + self.repository = BackupRepository(session) + self.pad_repository = PadRepository(session) + + async def create_backup(self, source_id: UUID, data: Dict[str, Any]) -> Dict[str, Any]: + """Create a new backup""" + # Validate input + if not data: + raise ValueError("Backup data is required") + + # Check if source pad exists + source_pad = await self.pad_repository.get_by_id(source_id) + if not source_pad: + raise ValueError(f"Pad with ID '{source_id}' does not exist") + + # Create backup + backup = await self.repository.create(source_id, data) + return backup.to_dict() + + async def get_backup(self, backup_id: UUID) -> Optional[Dict[str, Any]]: + """Get a backup by ID""" + backup = await self.repository.get_by_id(backup_id) + return backup.to_dict() if backup else None + + async def get_backups_by_source(self, source_id: UUID) -> List[Dict[str, Any]]: + """Get all backups for a specific source pad""" + # Check if source pad exists + source_pad = await self.pad_repository.get_by_id(source_id) + if not source_pad: + raise ValueError(f"Pad with ID '{source_id}' does not exist") + + backups = await self.repository.get_by_source(source_id) + return [backup.to_dict() for backup in backups] + + async def get_latest_backup(self, source_id: UUID) -> Optional[Dict[str, Any]]: + """Get the most recent backup for a specific source pad""" + # Check if source pad exists + source_pad = await self.pad_repository.get_by_id(source_id) + if not source_pad: + raise ValueError(f"Pad with ID '{source_id}' does not exist") + + backup = await self.repository.get_latest_by_source(source_id) + return backup.to_dict() if backup else None + + async def get_backups_by_date_range(self, source_id: UUID, start_date: datetime, end_date: datetime) -> List[Dict[str, Any]]: + """Get backups for a specific source pad within a date range""" + # Check if source pad exists + source_pad = await self.pad_repository.get_by_id(source_id) + if not source_pad: + raise ValueError(f"Pad with ID '{source_id}' does not exist") + + # Validate date range + if start_date > end_date: + raise ValueError("Start date must be before end date") + + backups = await self.repository.get_by_date_range(source_id, start_date, end_date) + return [backup.to_dict() for backup in backups] + + async def delete_backup(self, backup_id: UUID) -> bool: + """Delete a backup""" + # Get the backup to check if it exists + backup = await self.repository.get_by_id(backup_id) + if not backup: + raise ValueError(f"Backup with ID '{backup_id}' does not exist") + + return await self.repository.delete(backup_id) + + async def manage_backups(self, source_id: UUID, max_backups: int = 10) -> int: + """Manage backups for a source pad, keeping only the most recent ones""" + # Check if source pad exists + source_pad = await self.pad_repository.get_by_id(source_id) + if not source_pad: + raise ValueError(f"Pad with ID '{source_id}' does not exist") + + # Validate max_backups + if max_backups < 1: + raise ValueError("Maximum number of backups must be at least 1") + + # Count current backups + backup_count = await self.repository.count_by_source(source_id) + + # If we have more backups than the maximum, delete the oldest ones + if backup_count > max_backups: + return await self.repository.delete_older_than(source_id, max_backups) + + return 0 # No backups deleted diff --git a/src/backend/database/service/pad_service.py b/src/backend/database/service/pad_service.py new file mode 100644 index 0000000..1cf6566 --- /dev/null +++ b/src/backend/database/service/pad_service.py @@ -0,0 +1,103 @@ +""" +Pad service for business logic related to pads. +""" + +from typing import List, Optional, Dict, Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from ..repository import PadRepository, UserRepository + +class PadService: + """Service for pad-related business logic""" + + def __init__(self, session: AsyncSession): + """Initialize the service with a database session""" + self.session = session + self.repository = PadRepository(session) + self.user_repository = UserRepository(session) + + async def create_pad(self, owner_id: UUID, display_name: str, data: Dict[str, Any]) -> Dict[str, Any]: + """Create a new pad""" + # Validate input + if not display_name: + raise ValueError("Display name is required") + + if not data: + raise ValueError("Pad data is required") + + # Check if owner exists + owner = await self.user_repository.get_by_id(owner_id) + if not owner: + raise ValueError(f"User with ID '{owner_id}' does not exist") + + # Check if pad with same name already exists for this owner + existing_pad = await self.repository.get_by_name(owner_id, display_name) + if existing_pad: + raise ValueError(f"Pad with name '{display_name}' already exists for this user") + + # Create pad + pad = await self.repository.create(owner_id, display_name, data) + return pad.to_dict() + + async def get_pad(self, pad_id: UUID) -> Optional[Dict[str, Any]]: + """Get a pad by ID""" + pad = await self.repository.get_by_id(pad_id) + return pad.to_dict() if pad else None + + async def get_pads_by_owner(self, owner_id: UUID) -> List[Dict[str, Any]]: + """Get all pads for a specific owner""" + # Check if owner exists + owner = await self.user_repository.get_by_id(owner_id) + if not owner: + raise ValueError(f"User with ID '{owner_id}' does not exist") + + pads = await self.repository.get_by_owner(owner_id) + return [pad.to_dict() for pad in pads] + + async def get_pad_by_name(self, owner_id: UUID, display_name: str) -> Optional[Dict[str, Any]]: + """Get a pad by owner and display name""" + pad = await self.repository.get_by_name(owner_id, display_name) + return pad.to_dict() if pad else None + + async def update_pad(self, pad_id: UUID, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Update a pad""" + # Get the pad to check if it exists + pad = await self.repository.get_by_id(pad_id) + if not pad: + raise ValueError(f"Pad with ID '{pad_id}' does not exist") + + # Validate display_name if it's being updated + if 'display_name' in data and not data['display_name']: + raise ValueError("Display name cannot be empty") + + # Check if new display_name already exists for this owner (if being updated) + if 'display_name' in data and data['display_name'] != pad.display_name: + existing_pad = await self.repository.get_by_name(pad.owner_id, data['display_name']) + if existing_pad: + raise ValueError(f"Pad with name '{data['display_name']}' already exists for this user") + + # Update pad + updated_pad = await self.repository.update(pad_id, data) + return updated_pad.to_dict() if updated_pad else None + + async def update_pad_data(self, pad_id: UUID, pad_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Update just the data field of a pad""" + # Get the pad to check if it exists + pad = await self.repository.get_by_id(pad_id) + if not pad: + raise ValueError(f"Pad with ID '{pad_id}' does not exist") + + # Update pad data + updated_pad = await self.repository.update_data(pad_id, pad_data) + return updated_pad.to_dict() if updated_pad else None + + async def delete_pad(self, pad_id: UUID) -> bool: + """Delete a pad""" + # Get the pad to check if it exists + pad = await self.repository.get_by_id(pad_id) + if not pad: + raise ValueError(f"Pad with ID '{pad_id}' does not exist") + + return await self.repository.delete(pad_id) diff --git a/src/backend/database/service/user_service.py b/src/backend/database/service/user_service.py new file mode 100644 index 0000000..5b66f93 --- /dev/null +++ b/src/backend/database/service/user_service.py @@ -0,0 +1,87 @@ +""" +User service for business logic related to users. +""" + +from typing import List, Optional, Dict, Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from ..repository import UserRepository + +class UserService: + """Service for user-related business logic""" + + def __init__(self, session: AsyncSession): + """Initialize the service with a database session""" + self.session = session + self.repository = UserRepository(session) + + async def create_user(self, username: str, email: str) -> Dict[str, Any]: + """Create a new user""" + # Validate input + if not username or not email: + raise ValueError("Username and email are required") + + # Check if username already exists + existing_user = await self.repository.get_by_username(username) + if existing_user: + raise ValueError(f"Username '{username}' is already taken") + + # Check if email already exists + existing_email = await self.repository.get_by_email(email) + if existing_email: + raise ValueError(f"Email '{email}' is already registered") + + # Create user + user = await self.repository.create(username, email) + return user.to_dict() + + async def get_user(self, user_id: UUID) -> Optional[Dict[str, Any]]: + """Get a user by ID""" + user = await self.repository.get_by_id(user_id) + return user.to_dict() if user else None + + async def get_user_by_username(self, username: str) -> Optional[Dict[str, Any]]: + """Get a user by username""" + user = await self.repository.get_by_username(username) + return user.to_dict() if user else None + + async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]: + """Get a user by email""" + user = await self.repository.get_by_email(email) + return user.to_dict() if user else None + + async def get_all_users(self) -> List[Dict[str, Any]]: + """Get all users""" + users = await self.repository.get_all() + return [user.to_dict() for user in users] + + async def update_user(self, user_id: UUID, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Update a user""" + # Validate input + if 'username' in data and not data['username']: + raise ValueError("Username cannot be empty") + + if 'email' in data and not data['email']: + raise ValueError("Email cannot be empty") + + # Check if username already exists (if being updated) + if 'username' in data: + existing_user = await self.repository.get_by_username(data['username']) + if existing_user and existing_user.id != user_id: + raise ValueError(f"Username '{data['username']}' is already taken") + + # Check if email already exists (if being updated) + if 'email' in data: + existing_email = await self.repository.get_by_email(data['email']) + if existing_email and existing_email.id != user_id: + raise ValueError(f"Email '{data['email']}' is already registered") + + # Update user + user = await self.repository.update(user_id, data) + return user.to_dict() if user else None + + async def delete_user(self, user_id: UUID) -> bool: + """Delete a user""" + return await self.repository.delete(user_id) From 6bc027cba6c57ce6e02aabf397f6de8a6ea7561d Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 09:54:28 +0000 Subject: [PATCH 05/32] feat: implement database module with async support - Added a new database module to manage database connections and session handling using SQLAlchemy with async capabilities. - Introduced init_db function to initialize the database schema and tables. - Created repository and service dependency functions for user, pad, and backup management. - Updated requirements.txt to include psycopg2-binary for PostgreSQL support. --- src/backend/database/__init__.py | 27 +++++++++++ src/backend/database/database.py | 80 ++++++++++++++++++++++++++++++++ src/backend/requirements.txt | 1 + 3 files changed, 108 insertions(+) create mode 100644 src/backend/database/__init__.py create mode 100644 src/backend/database/database.py diff --git a/src/backend/database/__init__.py b/src/backend/database/__init__.py new file mode 100644 index 0000000..0c1fd1e --- /dev/null +++ b/src/backend/database/__init__.py @@ -0,0 +1,27 @@ +""" +Database module for the application. + +This module provides access to all database components used in the application. +""" + +from .database import ( + init_db, + get_session, + get_user_repository, + get_pad_repository, + get_backup_repository, + get_user_service, + get_pad_service, + get_backup_service +) + +__all__ = [ + 'init_db', + 'get_session', + 'get_user_repository', + 'get_pad_repository', + 'get_backup_repository', + 'get_user_service', + 'get_pad_service', + 'get_backup_service', +] diff --git a/src/backend/database/database.py b/src/backend/database/database.py new file mode 100644 index 0000000..b4619f9 --- /dev/null +++ b/src/backend/database/database.py @@ -0,0 +1,80 @@ +""" +Database connection and session management. +""" + +import os +from typing import AsyncGenerator +from urllib.parse import quote_plus as urlquote + +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker +from fastapi import Depends + +from .models import Base + +# PostgreSQL connection configuration +DB_USER = os.getenv('POSTGRES_USER', 'postgres') +DB_PASSWORD = os.getenv('POSTGRES_PASSWORD', 'postgres') +DB_NAME = os.getenv('POSTGRES_DB', 'pad') +DB_HOST = os.getenv('POSTGRES_HOST', 'localhost') +DB_PORT = os.getenv('POSTGRES_PORT', '5432') + +# SQLAlchemy async database URL +DATABASE_URL = f"postgresql+asyncpg://{DB_USER}:{urlquote(DB_PASSWORD)}@{DB_HOST}:{DB_PORT}/{DB_NAME}" + +# Create async engine +engine = create_async_engine(DATABASE_URL, echo=False) + +# Create async session factory +async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False +) + +async def init_db() -> None: + """Initialize the database with required tables""" + async with engine.begin() as conn: + # Create schema if it doesn't exist + await conn.execute(f"CREATE SCHEMA IF NOT EXISTS padws") + + # Create tables + await conn.run_sync(Base.metadata.create_all) + +async def get_session() -> AsyncGenerator[AsyncSession, None]: + """Get a database session""" + async with async_session() as session: + try: + yield session + finally: + await session.close() + +# Dependency for getting repositories +async def get_user_repository(session: AsyncSession = Depends(get_session)): + """Get a user repository""" + from .repository import UserRepository + return UserRepository(session) + +async def get_pad_repository(session: AsyncSession = Depends(get_session)): + """Get a pad repository""" + from .repository import PadRepository + return PadRepository(session) + +async def get_backup_repository(session: AsyncSession = Depends(get_session)): + """Get a backup repository""" + from .repository import BackupRepository + return BackupRepository(session) + +# Dependency for getting services +async def get_user_service(session: AsyncSession = Depends(get_session)): + """Get a user service""" + from .service import UserService + return UserService(session) + +async def get_pad_service(session: AsyncSession = Depends(get_session)): + """Get a pad service""" + from .service import PadService + return PadService(session) + +async def get_backup_service(session: AsyncSession = Depends(get_session)): + """Get a backup service""" + from .service import BackupService + return BackupService(session) diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index b24ffde..cdb0307 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -9,3 +9,4 @@ requests sqlalchemy posthog redis +psycopg2-binary \ No newline at end of file From 24b7611f5e827065734061866f96671218e445a7 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 10:09:48 +0000 Subject: [PATCH 06/32] refactor: update SQLAlchemy column types in UserModel - Changed VARCHAR to String for username and email fields in UserModel to align with SQLAlchemy best practices. - This update enhances code consistency and readability within the database model. --- src/backend/database/models/user_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/backend/database/models/user_model.py b/src/backend/database/models/user_model.py index d4dd60a..cbcadf3 100644 --- a/src/backend/database/models/user_model.py +++ b/src/backend/database/models/user_model.py @@ -1,5 +1,5 @@ from typing import List -from sqlalchemy import Column, Index, VARCHAR +from sqlalchemy import Column, Index, String from sqlalchemy.orm import relationship, Mapped from .base_model import Base, BaseModel @@ -15,8 +15,8 @@ class UserModel(Base, BaseModel): ) # User-specific fields - username = Column(VARCHAR(254), nullable=False, unique=True) - email = Column(VARCHAR(254), nullable=False) + username = Column(String(254), nullable=False, unique=True) + email = Column(String(254), nullable=False) # Relationships pads: Mapped[List["PadModel"]] = relationship( From 8735bc44f825285aefd3ec4c6c5448f2e84054f6 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 10:25:36 +0000 Subject: [PATCH 07/32] refactor: update models to use schema configuration - Refactored BackupModel, PadModel, and UserModel to utilize SCHEMA_NAME for schema configuration, enhancing consistency across models. - Removed the get_schema method from BaseModel to streamline schema handling. - Updated ForeignKey references to align with the new schema approach, improving clarity and maintainability. --- src/backend/database/models/backup_model.py | 13 +++++++------ src/backend/database/models/base_model.py | 6 ------ src/backend/database/models/pad_model.py | 16 +++++++++------- src/backend/database/models/user_model.py | 12 +++++++----- 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/src/backend/database/models/backup_model.py b/src/backend/database/models/backup_model.py index 3563628..3750660 100644 --- a/src/backend/database/models/backup_model.py +++ b/src/backend/database/models/backup_model.py @@ -1,27 +1,28 @@ -from typing import Dict, Any, Optional +from typing import Dict, Any, TYPE_CHECKING from uuid import UUID as UUIDType from sqlalchemy import Column, ForeignKey, Index from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import relationship, Mapped -from backend.database.models.pad_model import PadModel +from .base_model import Base, BaseModel, SCHEMA_NAME -from .base_model import Base, BaseModel +if TYPE_CHECKING: + from .pad_model import PadModel class BackupModel(Base, BaseModel): """Model for backups table in app schema""" __tablename__ = "backups" __table_args__ = ( - BaseModel.get_schema(), Index("ix_backups_source_id", "source_id"), - Index("ix_backups_created_at", "created_at") + Index("ix_backups_created_at", "created_at"), + {"schema": SCHEMA_NAME} ) # Backup-specific fields source_id = Column( UUIDType(as_uuid=True), - ForeignKey(f"{BaseModel.get_schema()['schema']}.pads.id", ondelete="CASCADE"), + ForeignKey(f"{SCHEMA_NAME}.pads.id", ondelete="CASCADE"), nullable=False ) data = Column(JSONB, nullable=False) diff --git a/src/backend/database/models/base_model.py b/src/backend/database/models/base_model.py index 35ef728..db9f309 100644 --- a/src/backend/database/models/base_model.py +++ b/src/backend/database/models/base_model.py @@ -1,5 +1,4 @@ from uuid import uuid4 -from datetime import datetime from typing import Any, Dict from sqlalchemy import Column, DateTime, UUID, func @@ -24,8 +23,3 @@ class BaseModel: def to_dict(self) -> Dict[str, Any]: """Convert model instance to dictionary""" return {c.name: getattr(self, c.name) for c in self.__table__.columns} - - @classmethod - def get_schema(cls) -> Dict[str, Any]: - """Return schema configuration for the model""" - return {"schema": SCHEMA_NAME} diff --git a/src/backend/database/models/pad_model.py b/src/backend/database/models/pad_model.py index 2fc4cf7..6493371 100644 --- a/src/backend/database/models/pad_model.py +++ b/src/backend/database/models/pad_model.py @@ -1,27 +1,29 @@ -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, TYPE_CHECKING from uuid import UUID as UUIDType from sqlalchemy import Column, String, ForeignKey, Index from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import relationship, Mapped -from .base_model import Base, BaseModel -from .backup_model import BackupModel -from .user_model import UserModel +from .base_model import Base, BaseModel, SCHEMA_NAME + +if TYPE_CHECKING: + from .backup_model import BackupModel + from .user_model import UserModel class PadModel(Base, BaseModel): """Model for pads table in app schema""" __tablename__ = "pads" __table_args__ = ( - BaseModel.get_schema(), Index("ix_pads_owner_id", "owner_id"), - Index("ix_pads_display_name", "display_name") + Index("ix_pads_display_name", "display_name"), + {"schema": SCHEMA_NAME} ) # Pad-specific fields owner_id = Column( UUIDType(as_uuid=True), - ForeignKey(f"{BaseModel.get_schema()['schema']}.users.id", ondelete="CASCADE"), + ForeignKey(f"{SCHEMA_NAME}.users.id", ondelete="CASCADE"), nullable=False ) display_name = Column(String(100), nullable=False) diff --git a/src/backend/database/models/user_model.py b/src/backend/database/models/user_model.py index cbcadf3..3fb9f5c 100644 --- a/src/backend/database/models/user_model.py +++ b/src/backend/database/models/user_model.py @@ -1,17 +1,19 @@ -from typing import List +from typing import List, TYPE_CHECKING from sqlalchemy import Column, Index, String from sqlalchemy.orm import relationship, Mapped -from .base_model import Base, BaseModel -from .pad_model import PadModel +from .base_model import Base, BaseModel, SCHEMA_NAME + +if TYPE_CHECKING: + from .pad_model import PadModel class UserModel(Base, BaseModel): """Model for users table in app schema""" __tablename__ = "users" __table_args__ = ( - BaseModel.get_schema(), Index("ix_users_username", "username"), - Index("ix_users_email", "email") + Index("ix_users_email", "email"), + {"schema": SCHEMA_NAME} ) # User-specific fields From 026c6355d069f799ed043c6452ff675514d3761b Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 10:27:52 +0000 Subject: [PATCH 08/32] refactor: update UUID type usage in BackupModel and PadModel - Replaced custom UUIDType with SQLAlchemy's built-in UUID type for source_id and owner_id fields in BackupModel and PadModel, respectively. - This change enhances code clarity and aligns with SQLAlchemy best practices for UUID handling. --- src/backend/database/models/backup_model.py | 5 ++--- src/backend/database/models/pad_model.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/backend/database/models/backup_model.py b/src/backend/database/models/backup_model.py index 3750660..5e2f7d3 100644 --- a/src/backend/database/models/backup_model.py +++ b/src/backend/database/models/backup_model.py @@ -1,7 +1,6 @@ from typing import Dict, Any, TYPE_CHECKING -from uuid import UUID as UUIDType -from sqlalchemy import Column, ForeignKey, Index +from sqlalchemy import Column, ForeignKey, Index, UUID from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import relationship, Mapped @@ -21,7 +20,7 @@ class BackupModel(Base, BaseModel): # Backup-specific fields source_id = Column( - UUIDType(as_uuid=True), + UUID(as_uuid=True), ForeignKey(f"{SCHEMA_NAME}.pads.id", ondelete="CASCADE"), nullable=False ) diff --git a/src/backend/database/models/pad_model.py b/src/backend/database/models/pad_model.py index 6493371..e718f0b 100644 --- a/src/backend/database/models/pad_model.py +++ b/src/backend/database/models/pad_model.py @@ -1,7 +1,6 @@ from typing import List, Dict, Any, TYPE_CHECKING -from uuid import UUID as UUIDType -from sqlalchemy import Column, String, ForeignKey, Index +from sqlalchemy import Column, String, ForeignKey, Index, UUID from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import relationship, Mapped @@ -22,7 +21,7 @@ class PadModel(Base, BaseModel): # Pad-specific fields owner_id = Column( - UUIDType(as_uuid=True), + UUID(as_uuid=True), ForeignKey(f"{SCHEMA_NAME}.users.id", ondelete="CASCADE"), nullable=False ) From 1274cb37c0a7b7f586d4e6b54eb5f6a92c07a06e Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 10:38:58 +0000 Subject: [PATCH 09/32] refactor: centralize schema configuration and update database initialization - Introduced SCHEMA_NAME in base_model.py for consistent schema handling across models. - Updated init_db function to use CreateSchema for schema creation, improving clarity and maintainability. - Adjusted imports in models to reflect the new schema configuration, enhancing code organization. --- src/backend/database/database.py | 8 +++----- src/backend/database/models/__init__.py | 3 ++- src/backend/database/models/base_model.py | 13 ++++++++----- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/backend/database/database.py b/src/backend/database/database.py index b4619f9..122d673 100644 --- a/src/backend/database/database.py +++ b/src/backend/database/database.py @@ -8,9 +8,10 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker +from sqlalchemy.schema import CreateSchema from fastapi import Depends -from .models import Base +from .models import Base, SCHEMA_NAME # PostgreSQL connection configuration DB_USER = os.getenv('POSTGRES_USER', 'postgres') @@ -33,10 +34,7 @@ async def init_db() -> None: """Initialize the database with required tables""" async with engine.begin() as conn: - # Create schema if it doesn't exist - await conn.execute(f"CREATE SCHEMA IF NOT EXISTS padws") - - # Create tables + await conn.execute(CreateSchema(SCHEMA_NAME, if_not_exists=True)) await conn.run_sync(Base.metadata.create_all) async def get_session() -> AsyncGenerator[AsyncSession, None]: diff --git a/src/backend/database/models/__init__.py b/src/backend/database/models/__init__.py index 040f570..82fc969 100644 --- a/src/backend/database/models/__init__.py +++ b/src/backend/database/models/__init__.py @@ -4,7 +4,7 @@ This module provides access to all database models used in the application. """ -from .base_model import Base, BaseModel +from .base_model import Base, BaseModel, SCHEMA_NAME from .user_model import UserModel from .pad_model import PadModel from .backup_model import BackupModel @@ -15,4 +15,5 @@ 'UserModel', 'PadModel', 'BackupModel', + 'SCHEMA_NAME', ] diff --git a/src/backend/database/models/base_model.py b/src/backend/database/models/base_model.py index db9f309..87f3d8c 100644 --- a/src/backend/database/models/base_model.py +++ b/src/backend/database/models/base_model.py @@ -1,14 +1,17 @@ from uuid import uuid4 from typing import Any, Dict -from sqlalchemy import Column, DateTime, UUID, func +from sqlalchemy import Column, DateTime, UUID, func, MetaData from sqlalchemy.orm import DeclarativeMeta, declarative_base -# Create a single shared Base for all models -Base: DeclarativeMeta = declarative_base() - # Define schema name in a central location -SCHEMA_NAME = "padws" +SCHEMA_NAME = "pad_ws" + +# Create metadata with schema +metadata = MetaData(schema=SCHEMA_NAME) + +# Create a single shared Base for all models with the schema-aware metadata +Base: DeclarativeMeta = declarative_base(metadata=metadata) class BaseModel: """Base model with common fields for all models""" From 5cf9698d6c6f0863209dfb4d6ebb3dd6630e6cef Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 11:22:45 +0000 Subject: [PATCH 10/32] refactor: update UserModel and UserRepository to use UUID for user IDs - Modified UserModel to use SQLAlchemy's UUID type for the primary key, aligning with Keycloak's UUID requirements. - Updated UserRepository's create method to accept a user_id parameter, allowing for explicit user ID assignment during user creation. --- src/backend/database/models/user_model.py | 5 ++++- src/backend/database/repository/user_repository.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/backend/database/models/user_model.py b/src/backend/database/models/user_model.py index 3fb9f5c..060e6bb 100644 --- a/src/backend/database/models/user_model.py +++ b/src/backend/database/models/user_model.py @@ -1,5 +1,5 @@ from typing import List, TYPE_CHECKING -from sqlalchemy import Column, Index, String +from sqlalchemy import Column, Index, String, UUID from sqlalchemy.orm import relationship, Mapped from .base_model import Base, BaseModel, SCHEMA_NAME @@ -16,6 +16,9 @@ class UserModel(Base, BaseModel): {"schema": SCHEMA_NAME} ) + # Override the default id column to use Keycloak's UUID + id = Column(UUID(as_uuid=True), primary_key=True) + # User-specific fields username = Column(String(254), nullable=False, unique=True) email = Column(String(254), nullable=False) diff --git a/src/backend/database/repository/user_repository.py b/src/backend/database/repository/user_repository.py index 8253d69..226190d 100644 --- a/src/backend/database/repository/user_repository.py +++ b/src/backend/database/repository/user_repository.py @@ -18,9 +18,9 @@ def __init__(self, session: AsyncSession): """Initialize the repository with a database session""" self.session = session - async def create(self, username: str, email: str) -> UserModel: - """Create a new user""" - user = UserModel(username=username, email=email) + async def create(self, user_id: UUID, username: str, email: str) -> UserModel: + """Create a new user with specified ID""" + user = UserModel(id=user_id, username=username, email=email) self.session.add(user) await self.session.commit() await self.session.refresh(user) From 1ed0774aaeb602fc16e55478abee4ed8cc43d4f4 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 13:55:46 +0000 Subject: [PATCH 11/32] feat: enhance user model and repository for additional user attributes - Updated UserModel to include new fields: email_verified, name, given_name, family_name, and roles, allowing for more comprehensive user data management. - Modified UserRepository's create method to accept these new fields, facilitating their inclusion during user creation. - Introduced a new user router to handle user creation and retrieval endpoints, improving API functionality and user management capabilities. --- src/backend/database/models/user_model.py | 8 +- .../database/repository/user_repository.py | 17 +++- src/backend/database/service/user_service.py | 32 ++++--- src/backend/routers/user_router.py | 83 +++++++++++++++++++ 4 files changed, 126 insertions(+), 14 deletions(-) create mode 100644 src/backend/routers/user_router.py diff --git a/src/backend/database/models/user_model.py b/src/backend/database/models/user_model.py index 060e6bb..6edb2ae 100644 --- a/src/backend/database/models/user_model.py +++ b/src/backend/database/models/user_model.py @@ -1,5 +1,6 @@ from typing import List, TYPE_CHECKING -from sqlalchemy import Column, Index, String, UUID +from sqlalchemy import Column, Index, String, UUID, Boolean +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import relationship, Mapped from .base_model import Base, BaseModel, SCHEMA_NAME @@ -22,6 +23,11 @@ class UserModel(Base, BaseModel): # User-specific fields username = Column(String(254), nullable=False, unique=True) email = Column(String(254), nullable=False) + email_verified = Column(Boolean, nullable=False, default=False) + name = Column(String(254), nullable=True) + given_name = Column(String(254), nullable=True) + family_name = Column(String(254), nullable=True) + roles = Column(JSONB, nullable=False, default=[]) # Relationships pads: Mapped[List["PadModel"]] = relationship( diff --git a/src/backend/database/repository/user_repository.py b/src/backend/database/repository/user_repository.py index 226190d..2e25e26 100644 --- a/src/backend/database/repository/user_repository.py +++ b/src/backend/database/repository/user_repository.py @@ -18,9 +18,20 @@ def __init__(self, session: AsyncSession): """Initialize the repository with a database session""" self.session = session - async def create(self, user_id: UUID, username: str, email: str) -> UserModel: - """Create a new user with specified ID""" - user = UserModel(id=user_id, username=username, email=email) + async def create(self, user_id: UUID, username: str, email: str, email_verified: bool = False, + name: str = None, given_name: str = None, family_name: str = None, + roles: list = None) -> UserModel: + """Create a new user with specified ID and optional fields""" + user = UserModel( + id=user_id, + username=username, + email=email, + email_verified=email_verified, + name=name, + given_name=given_name, + family_name=family_name, + roles=roles or [] + ) self.session.add(user) await self.session.commit() await self.session.refresh(user) diff --git a/src/backend/database/service/user_service.py b/src/backend/database/service/user_service.py index 5b66f93..1677a7e 100644 --- a/src/backend/database/service/user_service.py +++ b/src/backend/database/service/user_service.py @@ -17,24 +17,36 @@ def __init__(self, session: AsyncSession): self.session = session self.repository = UserRepository(session) - async def create_user(self, username: str, email: str) -> Dict[str, Any]: - """Create a new user""" + async def create_user(self, user_id: UUID, username: str, email: str, + email_verified: bool = False, name: str = None, + given_name: str = None, family_name: str = None, + roles: list = None) -> Dict[str, Any]: + """Create a new user with specified ID and optional fields""" # Validate input - if not username or not email: - raise ValueError("Username and email are required") + if not user_id or not username or not email: + raise ValueError("User ID, username, and email are required") + + # Check if user_id already exists + existing_id = await self.repository.get_by_id(user_id) + if existing_id: + raise ValueError(f"User with ID '{user_id}' already exists") # Check if username already exists existing_user = await self.repository.get_by_username(username) if existing_user: raise ValueError(f"Username '{username}' is already taken") - # Check if email already exists - existing_email = await self.repository.get_by_email(email) - if existing_email: - raise ValueError(f"Email '{email}' is already registered") - # Create user - user = await self.repository.create(username, email) + user = await self.repository.create( + user_id=user_id, + username=username, + email=email, + email_verified=email_verified, + name=name, + given_name=given_name, + family_name=family_name, + roles=roles + ) return user.to_dict() async def get_user(self, user_id: UUID) -> Optional[Dict[str, Any]]: diff --git a/src/backend/routers/user_router.py b/src/backend/routers/user_router.py new file mode 100644 index 0000000..7a5bf2d --- /dev/null +++ b/src/backend/routers/user_router.py @@ -0,0 +1,83 @@ +import os +from uuid import UUID + +import posthog +from fastapi import APIRouter, Depends, HTTPException + +from config import redis_client, OIDC_CONFIG +from database import get_user_service +from database.service import UserService +from dependencies import get_current_user, require_admin + +user_router = APIRouter() + +@user_router.post("/") +async def create_user( + user_id: UUID, + username: str, + email: str, + email_verified: bool = False, + name: str = None, + given_name: str = None, + family_name: str = None, + roles: list = None, + _: bool = Depends(require_admin), + user_service: UserService = Depends(get_user_service) +): + try: + user = await user_service.create_user( + user_id=user_id, + username=username, + email=email, + email_verified=email_verified, + name=name, + given_name=given_name, + family_name=family_name, + roles=roles + ) + return user + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@user_router.get("/") +async def get_all_users( + _: bool = Depends(require_admin), + user_service: UserService = Depends(get_user_service) +): + users = await user_service.get_all_users() + return users + + +@user_router.get("/me") +async def get_user_info( + user: dict = Depends(get_current_user), +): + + if os.getenv("VITE_PUBLIC_POSTHOG_KEY"): + telemetry = user.copy() + telemetry["$current_url"] = OIDC_CONFIG["frontend_url"] + posthog.identify(distinct_id=user["id"], properties=telemetry) + + return user + + +@user_router.get("/count") +async def get_user_count( + _: bool = Depends(require_admin), +): + session_count = len(redis_client.keys("session:*")) + return {"active_sessions": session_count } + + +@user_router.get("/{user_id}") +async def get_user( + user_id: UUID, + _: bool = Depends(require_admin), + user_service: UserService = Depends(get_user_service) +): + user = await user_service.get_user(user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + return user \ No newline at end of file From eefd4a232a9571768a2e7d677066f25534949c47 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 13:55:58 +0000 Subject: [PATCH 12/32] feat: add JWT token handling and user management dependencies - Introduced functions for decoding JWT tokens and retrieving current user information, enhancing authentication flow. - Implemented user creation logic for new users based on token data, improving user management capabilities. - Added an admin role check to enforce access control, ensuring only authorized users can access certain resources. --- src/backend/dependencies.py | 60 ++++++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/src/backend/dependencies.py b/src/backend/dependencies.py index cf86d0e..f0204ec 100644 --- a/src/backend/dependencies.py +++ b/src/backend/dependencies.py @@ -1,7 +1,11 @@ -from typing import Optional +import jwt +from typing import Optional, Dict, Any + from fastapi import Request, HTTPException, Depends from config import get_session, is_token_expired, refresh_token +from database import get_user_service +from database.service import UserService class SessionData: def __init__(self, access_token: str, token_data: dict): @@ -58,3 +62,57 @@ async def __call__(self, request: Request) -> Optional[SessionData]: # Create instances for use in route handlers require_auth = AuthDependency(auto_error=True) optional_auth = AuthDependency(auto_error=False) + +# JWT token handling dependencies +async def get_decoded_token( + auth: SessionData = Depends(require_auth) +) -> Dict[str, Any]: + + token_data = auth.token_data + access_token = token_data.get("access_token") + + return jwt.decode(access_token, options={"verify_signature": False}) + + +async def get_current_user( + decoded_token: Dict[str, Any] = Depends(get_decoded_token), + user_service: UserService = Depends(get_user_service), +) -> Dict[str, Any]: + + user_id = decoded_token["sub"] + user_info = await user_service.get_user(user_id) + + if not user_info: + try: + user_info = await user_service.create_user( + user_id=user_id, + username=decoded_token["preferred_username"], + email=decoded_token["email"], + email_verified=decoded_token["email_verified"], + name=decoded_token["name"], + given_name=decoded_token["given_name"], + family_name=decoded_token["family_name"], + roles=decoded_token["realm_access"]["roles"], + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error creating user: {e}" + ) + + return user_info + + +async def require_admin( + decoded_token: Dict[str, Any] = Depends(get_decoded_token) +) -> bool: + + roles = decoded_token.get("realm_access", {}).get("roles", []) + + if "admin" not in roles: + raise HTTPException( + status_code=403, + detail=f"Admin privileges required" + ) + + return True From 289df9ec4c44b238d086eae2bdaa8db577aa4a44 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 14:29:07 +0000 Subject: [PATCH 13/32] feat: add TemplatePadModel for managing template pads - Introduced TemplatePadModel to represent the template pads table in the app schema, enhancing the database structure. - Defined columns for name, display_name, and data, ensuring comprehensive data management for template pads. - Added an index on display_name for improved query performance. --- src/backend/database/models/pad_model.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/backend/database/models/pad_model.py b/src/backend/database/models/pad_model.py index e718f0b..d370f47 100644 --- a/src/backend/database/models/pad_model.py +++ b/src/backend/database/models/pad_model.py @@ -48,3 +48,20 @@ def to_dict(self) -> Dict[str, Any]: import json result["data"] = json.loads(result["data"]) return result + + +class TemplatePadModel(Base, BaseModel): + """Model for template pads table in app schema""" + __tablename__ = "template_pads" + __table_args__ = ( + Index("ix_template_pads_display_name", "display_name"), + Index("ix_template_pads_name", "name"), + {"schema": SCHEMA_NAME} + ) + + name = Column(String(100), nullable=False) + display_name = Column(String(100), nullable=False) + data = Column(JSONB, nullable=False) + + def __repr__(self) -> str: + return f"" \ No newline at end of file From ff53a8a88dede81171cd2268c913f9859d75cd63 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 15:08:07 +0000 Subject: [PATCH 14/32] feat: add TemplatePad repository, service, and router for template pad management - Introduced TemplatePadRepository for database operations related to template pads, including create, read, update, and delete functionalities. - Added TemplatePadService to encapsulate business logic for template pad management, ensuring data validation and error handling. - Created a new router for template pad endpoints, providing API access for creating, retrieving, updating, and deleting template pads, with admin access control. - Updated existing modules to integrate the new template pad features, enhancing overall application functionality. --- src/backend/database/__init__.py | 6 +- src/backend/database/database.py | 10 ++ src/backend/database/models/__init__.py | 3 +- src/backend/database/models/pad_model.py | 2 +- src/backend/database/repository/__init__.py | 2 + .../repository/template_pad_repository.py | 63 +++++++++++ src/backend/database/service/__init__.py | 2 + .../database/service/template_pad_service.py | 98 +++++++++++++++++ src/backend/routers/template_pad_router.py | 104 ++++++++++++++++++ 9 files changed, 287 insertions(+), 3 deletions(-) create mode 100644 src/backend/database/repository/template_pad_repository.py create mode 100644 src/backend/database/service/template_pad_service.py create mode 100644 src/backend/routers/template_pad_router.py diff --git a/src/backend/database/__init__.py b/src/backend/database/__init__.py index 0c1fd1e..6c5c0e7 100644 --- a/src/backend/database/__init__.py +++ b/src/backend/database/__init__.py @@ -10,9 +10,11 @@ get_user_repository, get_pad_repository, get_backup_repository, + get_template_pad_repository, get_user_service, get_pad_service, - get_backup_service + get_backup_service, + get_template_pad_service ) __all__ = [ @@ -21,7 +23,9 @@ 'get_user_repository', 'get_pad_repository', 'get_backup_repository', + 'get_template_pad_repository', 'get_user_service', 'get_pad_service', 'get_backup_service', + 'get_template_pad_service', ] diff --git a/src/backend/database/database.py b/src/backend/database/database.py index 122d673..7e9e01b 100644 --- a/src/backend/database/database.py +++ b/src/backend/database/database.py @@ -61,6 +61,11 @@ async def get_backup_repository(session: AsyncSession = Depends(get_session)): from .repository import BackupRepository return BackupRepository(session) +async def get_template_pad_repository(session: AsyncSession = Depends(get_session)): + """Get a template pad repository""" + from .repository import TemplatePadRepository + return TemplatePadRepository(session) + # Dependency for getting services async def get_user_service(session: AsyncSession = Depends(get_session)): """Get a user service""" @@ -76,3 +81,8 @@ async def get_backup_service(session: AsyncSession = Depends(get_session)): """Get a backup service""" from .service import BackupService return BackupService(session) + +async def get_template_pad_service(session: AsyncSession = Depends(get_session)): + """Get a template pad service""" + from .service import TemplatePadService + return TemplatePadService(session) diff --git a/src/backend/database/models/__init__.py b/src/backend/database/models/__init__.py index 82fc969..41fe372 100644 --- a/src/backend/database/models/__init__.py +++ b/src/backend/database/models/__init__.py @@ -6,7 +6,7 @@ from .base_model import Base, BaseModel, SCHEMA_NAME from .user_model import UserModel -from .pad_model import PadModel +from .pad_model import PadModel, TemplatePadModel from .backup_model import BackupModel __all__ = [ @@ -15,5 +15,6 @@ 'UserModel', 'PadModel', 'BackupModel', + 'TemplatePadModel', 'SCHEMA_NAME', ] diff --git a/src/backend/database/models/pad_model.py b/src/backend/database/models/pad_model.py index d370f47..fafddbc 100644 --- a/src/backend/database/models/pad_model.py +++ b/src/backend/database/models/pad_model.py @@ -59,7 +59,7 @@ class TemplatePadModel(Base, BaseModel): {"schema": SCHEMA_NAME} ) - name = Column(String(100), nullable=False) + name = Column(String(100), nullable=False, unique=True) display_name = Column(String(100), nullable=False) data = Column(JSONB, nullable=False) diff --git a/src/backend/database/repository/__init__.py b/src/backend/database/repository/__init__.py index 1e99327..a2433d7 100644 --- a/src/backend/database/repository/__init__.py +++ b/src/backend/database/repository/__init__.py @@ -7,9 +7,11 @@ from .user_repository import UserRepository from .pad_repository import PadRepository from .backup_repository import BackupRepository +from .template_pad_repository import TemplatePadRepository __all__ = [ 'UserRepository', 'PadRepository', 'BackupRepository', + 'TemplatePadRepository', ] diff --git a/src/backend/database/repository/template_pad_repository.py b/src/backend/database/repository/template_pad_repository.py new file mode 100644 index 0000000..4322244 --- /dev/null +++ b/src/backend/database/repository/template_pad_repository.py @@ -0,0 +1,63 @@ +""" +Template pad repository for database operations related to template pads. +""" + +from typing import List, Optional, Dict, Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import update, delete + +from ..models import TemplatePadModel + +class TemplatePadRepository: + """Repository for template pad-related database operations""" + + def __init__(self, session: AsyncSession): + """Initialize the repository with a database session""" + self.session = session + + async def create(self, name: str, display_name: str, data: Dict[str, Any]) -> TemplatePadModel: + """Create a new template pad""" + template_pad = TemplatePadModel(name=name, display_name=display_name, data=data) + self.session.add(template_pad) + await self.session.commit() + await self.session.refresh(template_pad) + return template_pad + + async def get_by_id(self, template_id: UUID) -> Optional[TemplatePadModel]: + """Get a template pad by ID""" + stmt = select(TemplatePadModel).where(TemplatePadModel.id == template_id) + result = await self.session.execute(stmt) + return result.scalars().first() + + async def get_by_name(self, name: str) -> Optional[TemplatePadModel]: + """Get a template pad by name""" + stmt = select(TemplatePadModel).where(TemplatePadModel.name == name) + result = await self.session.execute(stmt) + return result.scalars().first() + + async def get_all(self) -> List[TemplatePadModel]: + """Get all template pads""" + stmt = select(TemplatePadModel).order_by(TemplatePadModel.display_name) + result = await self.session.execute(stmt) + return result.scalars().all() + + async def update(self, template_id: UUID, data: Dict[str, Any]) -> Optional[TemplatePadModel]: + """Update a template pad""" + stmt = update(TemplatePadModel).where(TemplatePadModel.id == template_id).values(**data).returning(TemplatePadModel) + result = await self.session.execute(stmt) + await self.session.commit() + return result.scalars().first() + + async def update_data(self, template_id: UUID, template_data: Dict[str, Any]) -> Optional[TemplatePadModel]: + """Update just the data field of a template pad""" + return await self.update(template_id, {"data": template_data}) + + async def delete(self, template_id: UUID) -> bool: + """Delete a template pad""" + stmt = delete(TemplatePadModel).where(TemplatePadModel.id == template_id) + result = await self.session.execute(stmt) + await self.session.commit() + return result.rowcount > 0 diff --git a/src/backend/database/service/__init__.py b/src/backend/database/service/__init__.py index 9fb7bd2..d362c0b 100644 --- a/src/backend/database/service/__init__.py +++ b/src/backend/database/service/__init__.py @@ -7,9 +7,11 @@ from .user_service import UserService from .pad_service import PadService from .backup_service import BackupService +from .template_pad_service import TemplatePadService __all__ = [ 'UserService', 'PadService', 'BackupService', + 'TemplatePadService', ] diff --git a/src/backend/database/service/template_pad_service.py b/src/backend/database/service/template_pad_service.py new file mode 100644 index 0000000..ead220b --- /dev/null +++ b/src/backend/database/service/template_pad_service.py @@ -0,0 +1,98 @@ +""" +Template pad service for business logic related to template pads. +""" + +from typing import List, Optional, Dict, Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from ..repository import TemplatePadRepository + +class TemplatePadService: + """Service for template pad-related business logic""" + + def __init__(self, session: AsyncSession): + """Initialize the service with a database session""" + self.session = session + self.repository = TemplatePadRepository(session) + + async def create_template(self, name: str, display_name: str, data: Dict[str, Any]) -> Dict[str, Any]: + """Create a new template pad""" + # Validate input + if not name: + raise ValueError("Name is required") + + if not display_name: + raise ValueError("Display name is required") + + if not data: + raise ValueError("Template data is required") + + # Check if template with same name already exists + existing_template = await self.repository.get_by_name(name) + if existing_template: + raise ValueError(f"Template with name '{name}' already exists") + + # Create template pad + template_pad = await self.repository.create(name, display_name, data) + return template_pad.to_dict() + + async def get_template(self, template_id: UUID) -> Optional[Dict[str, Any]]: + """Get a template pad by ID""" + template_pad = await self.repository.get_by_id(template_id) + return template_pad.to_dict() if template_pad else None + + async def get_template_by_name(self, name: str) -> Optional[Dict[str, Any]]: + """Get a template pad by name""" + template_pad = await self.repository.get_by_name(name) + return template_pad.to_dict() if template_pad else None + + async def get_all_templates(self) -> List[Dict[str, Any]]: + """Get all template pads""" + template_pads = await self.repository.get_all() + return [template_pad.to_dict() for template_pad in template_pads] + + async def update_template(self, name: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Update a template pad""" + # Get the template pad to check if it exists + template_pad = await self.repository.get_by_name(name) + if not template_pad: + raise ValueError(f"Template pad with name '{name}' does not exist") + + # Validate name and display_name if they're being updated + if 'name' in data and not data['name']: + raise ValueError("Name cannot be empty") + + if 'display_name' in data and not data['display_name']: + raise ValueError("Display name cannot be empty") + + # Check if new name already exists (if being updated) + if 'name' in data and data['name'] != template_pad.name: + existing_template = await self.repository.get_by_name(data['name']) + if existing_template: + raise ValueError(f"Template with name '{data['name']}' already exists") + + # Update template pad + updated_template = await self.repository.update(name, data) + return updated_template.to_dict() if updated_template else None + + async def update_template_data(self, name: str, template_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Update just the data field of a template pad""" + # Get the template pad to check if it exists + template_pad = await self.repository.get_by_name(name) + if not template_pad: + raise ValueError(f"Template pad with name '{name}' does not exist") + + # Update template pad data + updated_template = await self.repository.update_data(name, template_data) + return updated_template.to_dict() if updated_template else None + + async def delete_template(self, name: str) -> bool: + """Delete a template pad""" + # Get the template pad to check if it exists + template_pad = await self.repository.get_by_name(name) + if not template_pad: + raise ValueError(f"Template pad with name '{name}' does not exist") + + return await self.repository.delete(name) diff --git a/src/backend/routers/template_pad_router.py b/src/backend/routers/template_pad_router.py new file mode 100644 index 0000000..c009ffe --- /dev/null +++ b/src/backend/routers/template_pad_router.py @@ -0,0 +1,104 @@ +from uuid import UUID +from typing import Dict, Any + +from fastapi import APIRouter, HTTPException, Depends + +from dependencies import SessionData, require_auth, require_admin +from database import get_template_pad_service +from database.service import TemplatePadService + +template_pad_router = APIRouter() + +@template_pad_router.post("/") +async def create_template_pad( + data: Dict[str, Any], + name: str, + display_name: str, + _: bool = Depends(require_admin), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Create a new template pad (admin only)""" + try: + template_pad = await template_pad_service.create_template( + name=name, + display_name=display_name, + data=data + ) + return template_pad + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + +@template_pad_router.get("/") +async def get_all_template_pads( + _: bool = Depends(require_admin), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Get all template pads""" + try: + template_pads = await template_pad_service.get_all_templates() + return template_pads + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to get template pads: {str(e)}") + +@template_pad_router.get("/{name}") +async def get_template_pad( + name: str, + _: SessionData = Depends(require_auth), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Get a specific template pad by name""" + template_pad = await template_pad_service.get_template_by_name(name) + if not template_pad: + raise HTTPException(status_code=404, detail="Template pad not found") + + return template_pad + +@template_pad_router.put("/{name}") +async def update_template_pad( + name: str, + data: Dict[str, Any], + _: bool = Depends(require_admin), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Update a template pad (admin only)""" + try: + updated_template = await template_pad_service.update_template(name, data) + if not updated_template: + raise HTTPException(status_code=404, detail="Template pad not found") + + return updated_template + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + +@template_pad_router.put("/{name}/data") +async def update_template_pad_data( + name: str, + data: Dict[str, Any], + _: bool = Depends(require_admin), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Update just the data field of a template pad (admin only)""" + try: + updated_template = await template_pad_service.update_template_data(name, data) + if not updated_template: + raise HTTPException(status_code=404, detail="Template pad not found") + + return updated_template + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + +@template_pad_router.delete("/{name}") +async def delete_template_pad( + name: str, + _: bool = Depends(require_admin), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Delete a template pad (admin only)""" + try: + success = await template_pad_service.delete_template(name) + if not success: + raise HTTPException(status_code=404, detail="Template pad not found") + + return {"status": "success", "message": "Template pad deleted successfully"} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) From 287fd754ab7cb5f6316fca481522654ff26fb22e Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 15:11:02 +0000 Subject: [PATCH 15/32] feat: add docstrings to user router functions for improved clarity - Added docstrings to the create_user, get_all_users, get_user_info, get_user_count, and get_user functions to clarify their purpose and access restrictions (admin only). - This enhancement improves code documentation and aids in understanding the functionality of user management endpoints. --- src/backend/routers/user_router.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/backend/routers/user_router.py b/src/backend/routers/user_router.py index 7a5bf2d..02b9c85 100644 --- a/src/backend/routers/user_router.py +++ b/src/backend/routers/user_router.py @@ -24,6 +24,7 @@ async def create_user( _: bool = Depends(require_admin), user_service: UserService = Depends(get_user_service) ): + """Create a new user (admin only)""" try: user = await user_service.create_user( user_id=user_id, @@ -45,6 +46,7 @@ async def get_all_users( _: bool = Depends(require_admin), user_service: UserService = Depends(get_user_service) ): + """Get all users (admin only)""" users = await user_service.get_all_users() return users @@ -53,7 +55,7 @@ async def get_all_users( async def get_user_info( user: dict = Depends(get_current_user), ): - + """Get the current user's information""" if os.getenv("VITE_PUBLIC_POSTHOG_KEY"): telemetry = user.copy() telemetry["$current_url"] = OIDC_CONFIG["frontend_url"] @@ -66,6 +68,7 @@ async def get_user_info( async def get_user_count( _: bool = Depends(require_admin), ): + """Get the number of active sessions (admin only)""" session_count = len(redis_client.keys("session:*")) return {"active_sessions": session_count } @@ -76,6 +79,7 @@ async def get_user( _: bool = Depends(require_admin), user_service: UserService = Depends(get_user_service) ): + """Get a user by ID (admin only)""" user = await user_service.get_user(user_id) if not user: raise HTTPException(status_code=404, detail="User not found") From 0c496d9e662706e67fe26017c953439a81fa193f Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 15:11:10 +0000 Subject: [PATCH 16/32] feat: update requirements for database and file handling - Added python-multipart to requirements.txt to support file uploads in the application. - Ensured psycopg2-binary remains included for PostgreSQL database interactions, maintaining necessary dependencies for backend functionality. --- src/backend/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index cdb0307..3303e91 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -9,4 +9,5 @@ requests sqlalchemy posthog redis -psycopg2-binary \ No newline at end of file +psycopg2-binary +python-multipart \ No newline at end of file From 30e1f2742bfbd6b43bc4aaec32f26e6466ed49cf Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 15:34:03 +0000 Subject: [PATCH 17/32] feat: enhance template management and application structure - Added a new function to load templates from JSON files into the database, ensuring templates are created if they do not already exist. - Updated the lifespan of the FastAPI application to include the loading of templates during startup. - Refactored the TemplatePadRepository to update and delete templates using their name instead of ID, improving usability. - Introduced a new router for template pad endpoints, expanding API capabilities for template management. - Added a default template JSON file to provide a starting point for users, enhancing user experience. --- .../repository/template_pad_repository.py | 12 +-- src/backend/main.py | 83 +++++++++++++++++-- src/backend/routers/template_pad_router.py | 7 +- src/backend/routers/user_router.py | 1 + .../default.json} | 3 + 5 files changed, 90 insertions(+), 16 deletions(-) rename src/backend/{default_canvas.json => templates/default.json} (99%) diff --git a/src/backend/database/repository/template_pad_repository.py b/src/backend/database/repository/template_pad_repository.py index 4322244..ff85e39 100644 --- a/src/backend/database/repository/template_pad_repository.py +++ b/src/backend/database/repository/template_pad_repository.py @@ -44,20 +44,20 @@ async def get_all(self) -> List[TemplatePadModel]: result = await self.session.execute(stmt) return result.scalars().all() - async def update(self, template_id: UUID, data: Dict[str, Any]) -> Optional[TemplatePadModel]: + async def update(self, name: str, data: Dict[str, Any]) -> Optional[TemplatePadModel]: """Update a template pad""" - stmt = update(TemplatePadModel).where(TemplatePadModel.id == template_id).values(**data).returning(TemplatePadModel) + stmt = update(TemplatePadModel).where(TemplatePadModel.name == name).values(**data).returning(TemplatePadModel) result = await self.session.execute(stmt) await self.session.commit() return result.scalars().first() - async def update_data(self, template_id: UUID, template_data: Dict[str, Any]) -> Optional[TemplatePadModel]: + async def update_data(self, name: str, template_data: Dict[str, Any]) -> Optional[TemplatePadModel]: """Update just the data field of a template pad""" - return await self.update(template_id, {"data": template_data}) + return await self.update(name, {"data": template_data}) - async def delete(self, template_id: UUID) -> bool: + async def delete(self, name: str) -> bool: """Delete a template pad""" - stmt = delete(TemplatePadModel).where(TemplatePadModel.id == template_id) + stmt = delete(TemplatePadModel).where(TemplatePadModel.name == name) result = await self.session.execute(stmt) await self.session.commit() return result.rowcount > 0 diff --git a/src/backend/main.py b/src/backend/main.py index a9dc15f..b39f1f7 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -1,13 +1,28 @@ import os +import json from contextlib import asynccontextmanager from typing import Optional +import posthog from fastapi import FastAPI, Request, Depends from fastapi.responses import FileResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from dotenv import load_dotenv -import posthog + +from db import init_db +from database import init_db as init_database +from config import STATIC_DIR, ASSETS_DIR +from dependencies import SessionData, optional_auth +from routers.auth import auth_router +from routers.canvas import canvas_router +from routers.user import user_router +from routers.user_router import user_router as user_router_v2 +from routers.workspace import workspace_router +from routers.pad_router import pad_router +from routers.template_pad_router import template_pad_router +from database.service import TemplatePadService +from database.database import async_session load_dotenv() @@ -18,18 +33,65 @@ posthog.project_api_key = POSTHOG_API_KEY posthog.host = POSTHOG_HOST -from db import init_db -from config import STATIC_DIR, ASSETS_DIR -from dependencies import SessionData, optional_auth -from routers.auth import auth_router -from routers.canvas import canvas_router -from routers.user import user_router -from routers.workspace import workspace_router +async def load_templates(): + """ + Load all templates from the templates directory into the database if they don't exist. + + This function reads all JSON files in the templates directory, extracts the display name + from the "appState.pad.displayName" field, uses the filename as the name, and stores + the entire JSON as the data. + """ + try: + # Get a session and template service + async with async_session() as session: + template_service = TemplatePadService(session) + + # Get the templates directory path + templates_dir = os.path.join(os.path.dirname(__file__), "templates") + + # Iterate through all JSON files in the templates directory + for filename in os.listdir(templates_dir): + if filename.endswith(".json"): + # Use the filename without extension as the name + name = os.path.splitext(filename)[0] + + # Check if template already exists + existing_template = await template_service.get_template_by_name(name) + + if not existing_template: + + file_path = os.path.join(templates_dir, filename) + + # Read the JSON file + with open(file_path, 'r') as f: + template_data = json.load(f) + + # Extract the display name from the JSON + display_name = template_data.get("appState", {}).get("pad", {}).get("displayName", "Untitled") + + # Create the template if it doesn't exist + await template_service.create_template( + name=name, + display_name=display_name, + data=template_data + ) + print(f"Added template: {name} ({display_name})") + else: + print(f"Template already in database: '{name}'") + + except Exception as e: + print(f"Error loading templates: {str(e)}") @asynccontextmanager async def lifespan(_: FastAPI): await init_db() + await init_database() print("Database connection established successfully") + + # Load all templates from the templates directory + await load_templates() + print("Templates loaded successfully") + yield app = FastAPI(lifespan=lifespan) @@ -37,7 +99,7 @@ async def lifespan(_: FastAPI): # CORS middleware setup app.add_middleware( CORSMiddleware, - allow_origins=["https://kc.pad.ws", "https://alex.pad.ws"], + allow_origins=["https://kc.pad.ws", "*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -54,7 +116,10 @@ async def read_root(request: Request, auth: Optional[SessionData] = Depends(opti app.include_router(auth_router, prefix="/auth") app.include_router(canvas_router, prefix="/api/canvas") app.include_router(user_router, prefix="/api/user") +app.include_router(user_router_v2, prefix="/api/users") app.include_router(workspace_router, prefix="/api/workspace") +app.include_router(pad_router, prefix="/api/pad") +app.include_router(template_pad_router, prefix="/api/templates") if __name__ == "__main__": import uvicorn diff --git a/src/backend/routers/template_pad_router.py b/src/backend/routers/template_pad_router.py index c009ffe..e2c9d35 100644 --- a/src/backend/routers/template_pad_router.py +++ b/src/backend/routers/template_pad_router.py @@ -1,4 +1,3 @@ -from uuid import UUID from typing import Dict, Any from fastapi import APIRouter, HTTPException, Depends @@ -9,6 +8,7 @@ template_pad_router = APIRouter() + @template_pad_router.post("/") async def create_template_pad( data: Dict[str, Any], @@ -28,6 +28,7 @@ async def create_template_pad( except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) + @template_pad_router.get("/") async def get_all_template_pads( _: bool = Depends(require_admin), @@ -40,6 +41,7 @@ async def get_all_template_pads( except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get template pads: {str(e)}") + @template_pad_router.get("/{name}") async def get_template_pad( name: str, @@ -53,6 +55,7 @@ async def get_template_pad( return template_pad + @template_pad_router.put("/{name}") async def update_template_pad( name: str, @@ -70,6 +73,7 @@ async def update_template_pad( except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) + @template_pad_router.put("/{name}/data") async def update_template_pad_data( name: str, @@ -87,6 +91,7 @@ async def update_template_pad_data( except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) + @template_pad_router.delete("/{name}") async def delete_template_pad( name: str, diff --git a/src/backend/routers/user_router.py b/src/backend/routers/user_router.py index 02b9c85..692a9fe 100644 --- a/src/backend/routers/user_router.py +++ b/src/backend/routers/user_router.py @@ -11,6 +11,7 @@ user_router = APIRouter() + @user_router.post("/") async def create_user( user_id: UUID, diff --git a/src/backend/default_canvas.json b/src/backend/templates/default.json similarity index 99% rename from src/backend/default_canvas.json rename to src/backend/templates/default.json index a73e469..0629611 100644 --- a/src/backend/default_canvas.json +++ b/src/backend/templates/default.json @@ -1,6 +1,9 @@ { "files": {}, "appState": { + "pad": { + "displayName": "Welcome!" + }, "name": "Pad.ws", "zoom": { "value": 1 From d96ff3ceff039430859470ae75f82eac44523f91 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 16:08:06 +0000 Subject: [PATCH 18/32] feat: refactor authentication and user session management - Introduced UserSession class to unify user session handling, integrating authentication data with user information. - Updated AuthDependency to utilize UserSession, enhancing session validation and user data retrieval. - Replaced SessionData references with UserSession across various routers for consistent user session management. - Added new methods to UserSession for accessing user attributes and caching user data from the database. - Removed deprecated canvas router and introduced a new pad router for managing user pads and backups, improving API structure and functionality. --- src/backend/dependencies.py | 202 ++++++++++++--------- src/backend/main.py | 12 +- src/backend/routers/auth.py | 2 +- src/backend/routers/canvas.py | 67 ------- src/backend/routers/pad_router.py | 169 +++++++++++++++++ src/backend/routers/template_pad_router.py | 4 +- src/backend/routers/user.py | 58 ------ src/backend/routers/user_router.py | 34 +++- src/backend/routers/workspace.py | 25 +-- 9 files changed, 330 insertions(+), 243 deletions(-) delete mode 100644 src/backend/routers/canvas.py create mode 100644 src/backend/routers/pad_router.py delete mode 100644 src/backend/routers/user.py diff --git a/src/backend/dependencies.py b/src/backend/dependencies.py index f0204ec..d32e279 100644 --- a/src/backend/dependencies.py +++ b/src/backend/dependencies.py @@ -1,5 +1,6 @@ import jwt -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Union +from uuid import UUID from fastapi import Request, HTTPException, Depends @@ -7,112 +8,141 @@ from database import get_user_service from database.service import UserService -class SessionData: - def __init__(self, access_token: str, token_data: dict): +class UserSession: + """ + Unified user session model that integrates authentication data with user information. + This provides a single interface for accessing both token data and user details. + """ + def __init__(self, access_token: str, token_data: dict, user_id: UUID = None): self.access_token = access_token self.token_data = token_data + self._user_data = None + + @property + def is_authenticated(self) -> bool: + """Check if the session is authenticated""" + return bool(self.access_token and self.id) + + @property + def id(self) -> UUID: + """Get user ID from token data""" + return UUID(self.token_data.get("sub")) + + @property + def email(self) -> str: + """Get user email from token data""" + return self.token_data.get("email", "") + + @property + def email_verified(self) -> bool: + """Get email verification status from token data""" + return self.token_data.get("email_verified", False) + + @property + def username(self) -> str: + """Get username from token data""" + return self.token_data.get("preferred_username", "") + + @property + def name(self) -> str: + """Get full name from token data""" + return self.token_data.get("name", "") + + @property + def given_name(self) -> str: + """Get given name from token data""" + return self.token_data.get("given_name", "") + + @property + def family_name(self) -> str: + """Get family name from token data""" + return self.token_data.get("family_name", "") + + @property + def roles(self) -> list: + """Get user roles from token data""" + return self.token_data.get("realm_access", {}).get("roles", []) + + @property + def is_admin(self) -> bool: + """Check if user has admin role""" + return "admin" in self.roles + + def to_dict(self) -> Dict[str, Any]: + """Convert user session to dictionary with common user fields""" + return { + "id": str(self.id), + "email": self.email, + "username": self.username, + "name": self.name, + "given_name": self.given_name, + "family_name": self.family_name, + "email_verified": self.email_verified, + "roles": self.roles + } + + async def get_user_data(self, user_service: UserService) -> Dict[str, Any]: + """Get user data from database, caching the result""" + if self._user_data is None and self.id: + self._user_data = await user_service.get_user(self.id) + return self._user_data class AuthDependency: - def __init__(self, auto_error: bool = True): + """ + Authentication dependency that validates session tokens and provides + a unified UserSession object for route handlers. + """ + def __init__(self, auto_error: bool = True, require_admin: bool = False): self.auto_error = auto_error + self.require_admin = require_admin - async def __call__(self, request: Request) -> Optional[SessionData]: + async def __call__(self, + request: Request, + user_service: UserService = Depends(get_user_service)) -> Optional[UserSession]: + # Get session ID from cookies session_id = request.cookies.get('session_id') + # Handle missing session ID if not session_id: - if self.auto_error: - raise HTTPException( - status_code=401, - detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"}, - ) - return None + return self._handle_auth_error("Not authenticated") + # Get session data from Redis session = get_session(session_id) if not session: - if self.auto_error: - raise HTTPException( - status_code=401, - detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"}, - ) - return None + return self._handle_auth_error("Not authenticated") - # Check if token is expired and refresh if needed + # Handle token expiration if is_token_expired(session): # Try to refresh the token success, new_session = await refresh_token(session_id, session) if not success: - # Token refresh failed, user needs to re-authenticate - if self.auto_error: - raise HTTPException( - status_code=401, - detail="Session expired", - headers={"WWW-Authenticate": "Bearer"}, - ) - return None - # Use the refreshed token data + return self._handle_auth_error("Session expired") session = new_session - - return SessionData( + + # Create user session object + user_session = UserSession( access_token=session.get('access_token'), token_data=session ) - -# Create instances for use in route handlers -require_auth = AuthDependency(auto_error=True) -optional_auth = AuthDependency(auto_error=False) - -# JWT token handling dependencies -async def get_decoded_token( - auth: SessionData = Depends(require_auth) -) -> Dict[str, Any]: - - token_data = auth.token_data - access_token = token_data.get("access_token") - - return jwt.decode(access_token, options={"verify_signature": False}) - - -async def get_current_user( - decoded_token: Dict[str, Any] = Depends(get_decoded_token), - user_service: UserService = Depends(get_user_service), -) -> Dict[str, Any]: - - user_id = decoded_token["sub"] - user_info = await user_service.get_user(user_id) - - if not user_info: - try: - user_info = await user_service.create_user( - user_id=user_id, - username=decoded_token["preferred_username"], - email=decoded_token["email"], - email_verified=decoded_token["email_verified"], - name=decoded_token["name"], - given_name=decoded_token["given_name"], - family_name=decoded_token["family_name"], - roles=decoded_token["realm_access"]["roles"], - ) - except Exception as e: + + # Check admin requirement if specified + if self.require_admin and not user_session.is_admin: + return self._handle_auth_error("Admin privileges required", status_code=403) + + return user_session + + def _handle_auth_error(self, detail: str, status_code: int = 401) -> Optional[None]: + """Handle authentication errors based on auto_error setting""" + if self.auto_error: + headers = {"WWW-Authenticate": "Bearer"} if status_code == 401 else None raise HTTPException( - status_code=500, - detail=f"Error creating user: {e}" + status_code=status_code, + detail=detail, + headers=headers, ) - - return user_info - - -async def require_admin( - decoded_token: Dict[str, Any] = Depends(get_decoded_token) -) -> bool: - - roles = decoded_token.get("realm_access", {}).get("roles", []) + return None - if "admin" not in roles: - raise HTTPException( - status_code=403, - detail=f"Admin privileges required" - ) - - return True +# Create dependency instances for use in route handlers +require_auth = AuthDependency(auto_error=True) +optional_auth = AuthDependency(auto_error=False) +require_admin = AuthDependency(auto_error=True, require_admin=True) diff --git a/src/backend/main.py b/src/backend/main.py index b39f1f7..8d99b49 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -13,11 +13,9 @@ from db import init_db from database import init_db as init_database from config import STATIC_DIR, ASSETS_DIR -from dependencies import SessionData, optional_auth +from dependencies import UserSession, optional_auth from routers.auth import auth_router -from routers.canvas import canvas_router -from routers.user import user_router -from routers.user_router import user_router as user_router_v2 +from routers.user_router import user_router from routers.workspace import workspace_router from routers.pad_router import pad_router from routers.template_pad_router import template_pad_router @@ -109,14 +107,12 @@ async def lifespan(_: FastAPI): app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") @app.get("/") -async def read_root(request: Request, auth: Optional[SessionData] = Depends(optional_auth)): +async def read_root(request: Request, auth: Optional[UserSession] = Depends(optional_auth)): return FileResponse(os.path.join(STATIC_DIR, "index.html")) # Include routers in the main app with the /api prefix app.include_router(auth_router, prefix="/auth") -app.include_router(canvas_router, prefix="/api/canvas") -app.include_router(user_router, prefix="/api/user") -app.include_router(user_router_v2, prefix="/api/users") +app.include_router(user_router, prefix="/api/users") app.include_router(workspace_router, prefix="/api/workspace") app.include_router(pad_router, prefix="/api/pad") app.include_router(template_pad_router, prefix="/api/templates") diff --git a/src/backend/routers/auth.py b/src/backend/routers/auth.py index f97457f..dbc2b0f 100644 --- a/src/backend/routers/auth.py +++ b/src/backend/routers/auth.py @@ -6,7 +6,7 @@ import os from config import get_auth_url, get_token_url, OIDC_CONFIG, set_session, delete_session, STATIC_DIR, get_session -from dependencies import SessionData, require_auth +from dependencies import UserSession, require_auth from coder import CoderAPI auth_router = APIRouter() diff --git a/src/backend/routers/canvas.py b/src/backend/routers/canvas.py deleted file mode 100644 index 5f20b6e..0000000 --- a/src/backend/routers/canvas.py +++ /dev/null @@ -1,67 +0,0 @@ -import json -import jwt -from typing import Dict, Any, List -from fastapi import APIRouter, HTTPException, Depends, Request -from fastapi.responses import JSONResponse - -from dependencies import SessionData, require_auth -from db import store_canvas_data, get_canvas_data, get_recent_canvases, MAX_BACKUPS_PER_USER -import posthog - -canvas_router = APIRouter() - -def get_default_canvas_data(): - try: - with open("default_canvas.json", "r") as f: - return json.load(f) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Failed to load default canvas: {str(e)}" - ) - -@canvas_router.get("/default") -async def get_default_canvas(auth: SessionData = Depends(require_auth)): - try: - with open("default_canvas.json", "r") as f: - canvas_data = json.load(f) - return canvas_data - except Exception as e: - return JSONResponse( - status_code=500, - content={"error": f"Failed to load default canvas: {str(e)}"} - ) - -@canvas_router.post("") -async def save_canvas(data: Dict[str, Any], auth: SessionData = Depends(require_auth), request: Request = None): - access_token = auth.token_data.get("access_token") - decoded = jwt.decode(access_token, options={"verify_signature": False}) - user_id = decoded["sub"] - success = await store_canvas_data(user_id, data) - if not success: - raise HTTPException(status_code=500, detail="Failed to save canvas data") - return {"status": "success"} - -@canvas_router.get("") -async def get_canvas(auth: SessionData = Depends(require_auth)): - access_token = auth.token_data.get("access_token") - decoded = jwt.decode(access_token, options={"verify_signature": False}) - user_id = decoded["sub"] - data = await get_canvas_data(user_id) - if data is None: - return get_default_canvas_data() - return data - -@canvas_router.get("/recent") -async def get_recent_canvas_backups(limit: int = MAX_BACKUPS_PER_USER, auth: SessionData = Depends(require_auth)): - """Get the most recent canvas backups for the authenticated user""" - access_token = auth.token_data.get("access_token") - decoded = jwt.decode(access_token, options={"verify_signature": False}) - user_id = decoded["sub"] - - # Limit the number of backups to the maximum configured value - if limit > MAX_BACKUPS_PER_USER: - limit = MAX_BACKUPS_PER_USER - - backups = await get_recent_canvases(user_id, limit) - return {"backups": backups} diff --git a/src/backend/routers/pad_router.py b/src/backend/routers/pad_router.py new file mode 100644 index 0000000..7c61d99 --- /dev/null +++ b/src/backend/routers/pad_router.py @@ -0,0 +1,169 @@ +from uuid import UUID +from typing import Dict, Any + +from fastapi import APIRouter, HTTPException, Depends, Request + +from dependencies import UserSession, require_auth +from database import get_pad_service, get_backup_service, get_template_pad_service +from database.service import PadService, BackupService, TemplatePadService + +# Default template name to use when a user doesn't have a pad +DEFAULT_TEMPLATE_NAME = "default" + +pad_router = APIRouter() + +# Constants +MAX_BACKUPS_PER_USER = 10 # Maximum number of backups to keep per user +DEFAULT_PAD_NAME = "Untitled" # Default name for new pads + +@pad_router.post("") +async def save_canvas( + data: Dict[str, Any], + auth: UserSession = Depends(require_auth), + pad_service: PadService = Depends(get_pad_service), + backup_service: BackupService = Depends(get_backup_service), + request: Request = None +): + """Save canvas data for the authenticated user""" + # Get user ID from session + user_id = auth.user_id + + try: + # Check if user already has a pad + user_pads = await pad_service.get_pads_by_owner(user_id) + + if not user_pads: + # Create a new pad if user doesn't have one + pad = await pad_service.create_pad( + owner_id=user_id, + display_name=DEFAULT_PAD_NAME, + data=data + ) + pad_id = UUID(pad["id"]) + else: + # Update existing pad + pad = user_pads[0] # Use the first pad (assuming one pad per user for now) + pad_id = UUID(pad["id"]) + await pad_service.update_pad_data(pad_id, data) + + # Create a backup + await backup_service.create_backup(pad_id, data) + + # Manage backups (keep only the most recent ones) + await backup_service.manage_backups(pad_id, MAX_BACKUPS_PER_USER) + + return {"status": "success"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to save canvas data: {str(e)}") + +async def get_default_canvas_data(template_pad_service: TemplatePadService) -> Dict[str, Any]: + """Get default canvas data from the template pad with name 'default'""" + try: + # Get the default template pad + default_template = await template_pad_service.get_template_by_name(DEFAULT_TEMPLATE_NAME) + if not default_template: + # Return empty data if default template doesn't exist + return {} + + # Return the template data + return default_template["data"] + except Exception as e: + # Log the error but return empty data to avoid breaking the application + print(f"Error getting default canvas data: {str(e)}") + return {} + +@pad_router.get("") +async def get_canvas( + auth: UserSession = Depends(require_auth), + pad_service: PadService = Depends(get_pad_service), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Get canvas data for the authenticated user""" + # Get user ID from session + user_id = auth.user_id + + try: + # Get user's pads + user_pads = await pad_service.get_pads_by_owner(user_id) + + if not user_pads: + # Return default canvas if user doesn't have a pad + return await get_default_canvas_data(template_pad_service) + + # Return the first pad's data (assuming one pad per user for now) + return user_pads[0]["data"] + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to get canvas data: {str(e)}") + +@pad_router.post("/from-template/{template_id}") +async def create_pad_from_template( + template_id: UUID, + display_name: str = DEFAULT_PAD_NAME, + auth: UserSession = Depends(require_auth), + pad_service: PadService = Depends(get_pad_service), + template_pad_service: TemplatePadService = Depends(get_template_pad_service) +): + """Create a new pad from a template""" + # Get user ID from session + user_id = auth.user_id + + try: + # Get the template + template = await template_pad_service.get_template(template_id) + if not template: + raise HTTPException(status_code=404, detail="Template not found") + + # Create a new pad using the template data + pad = await pad_service.create_pad( + owner_id=user_id, + display_name=display_name, + data=template["data"] + ) + + return pad + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create pad from template: {str(e)}") + +@pad_router.get("/recent") +async def get_recent_canvas_backups( + limit: int = MAX_BACKUPS_PER_USER, + auth: UserSession = Depends(require_auth), + pad_service: PadService = Depends(get_pad_service), + backup_service: BackupService = Depends(get_backup_service) +): + """Get the most recent canvas backups for the authenticated user""" + # Get user ID from session + user_id = auth.user_id + + # Limit the number of backups to the maximum configured value + if limit > MAX_BACKUPS_PER_USER: + limit = MAX_BACKUPS_PER_USER + + try: + # Get user's pads + user_pads = await pad_service.get_pads_by_owner(user_id) + + if not user_pads: + # Return empty list if user doesn't have a pad + return {"backups": []} + + # Get the first pad's ID (assuming one pad per user for now) + pad_id = UUID(user_pads[0]["id"]) + + # Get backups for the pad + backups_data = await backup_service.get_backups_by_source(pad_id) + + # Format backups to match the expected response format + backups = [] + for backup in backups_data[:limit]: + backups.append({ + "id": backup["id"], + "timestamp": backup["created_at"], + "data": backup["data"] + }) + + return {"backups": backups} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to get canvas backups: {str(e)}") diff --git a/src/backend/routers/template_pad_router.py b/src/backend/routers/template_pad_router.py index e2c9d35..77e4928 100644 --- a/src/backend/routers/template_pad_router.py +++ b/src/backend/routers/template_pad_router.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, HTTPException, Depends -from dependencies import SessionData, require_auth, require_admin +from dependencies import UserSession, require_auth, require_admin from database import get_template_pad_service from database.service import TemplatePadService @@ -45,7 +45,7 @@ async def get_all_template_pads( @template_pad_router.get("/{name}") async def get_template_pad( name: str, - _: SessionData = Depends(require_auth), + _: UserSession = Depends(require_auth), template_pad_service: TemplatePadService = Depends(get_template_pad_service) ): """Get a specific template pad by name""" diff --git a/src/backend/routers/user.py b/src/backend/routers/user.py deleted file mode 100644 index 969da69..0000000 --- a/src/backend/routers/user.py +++ /dev/null @@ -1,58 +0,0 @@ -import os -import jwt - -import posthog -from fastapi import APIRouter, Depends, Request -from dotenv import load_dotenv - -from dependencies import SessionData, require_auth -from config import redis_client - -load_dotenv() - -user_router = APIRouter() - -@user_router.get("/me") -async def get_user_info(auth: SessionData = Depends(require_auth), request: Request = None): - token_data = auth.token_data - access_token = token_data.get("access_token") - - decoded = jwt.decode(access_token, options={"verify_signature": False}) - - # Build full URL (mirroring canvas.py logic) - full_url = None - if request: - full_url = str(request.base_url).rstrip("/") + str(request.url.path) - full_url = full_url.replace("http://", "https://") - - user_data: dict = { - "id": decoded["sub"], # Unique user ID - "email": decoded.get("email", ""), - "username": decoded.get("preferred_username", ""), - "name": decoded.get("name", ""), - "given_name": decoded.get("given_name", ""), - "family_name": decoded.get("family_name", ""), - "email_verified": decoded.get("email_verified", False) - } - - if os.getenv("VITE_PUBLIC_POSTHOG_KEY"): - telemetry = user_data | {"$current_url": full_url} - posthog.identify(distinct_id=decoded["sub"], properties=telemetry) - - return user_data - -@user_router.get("/count") -async def get_user_count(auth: SessionData = Depends(require_auth)): - """ - Get the count of active sessions in Redis - - Returns: - dict: A dictionary containing the count of active sessions - """ - # Count keys that match the session pattern - session_count = len(redis_client.keys("session:*")) - - return { - "active_sessions": session_count, - "message": f"There are currently {session_count} active sessions in Redis" - } \ No newline at end of file diff --git a/src/backend/routers/user_router.py b/src/backend/routers/user_router.py index 692a9fe..488314a 100644 --- a/src/backend/routers/user_router.py +++ b/src/backend/routers/user_router.py @@ -1,13 +1,13 @@ import os from uuid import UUID - +from typing import Dict, Any import posthog from fastapi import APIRouter, Depends, HTTPException from config import redis_client, OIDC_CONFIG from database import get_user_service from database.service import UserService -from dependencies import get_current_user, require_admin +from dependencies import UserSession, require_admin, require_auth user_router = APIRouter() @@ -54,13 +54,35 @@ async def get_all_users( @user_router.get("/me") async def get_user_info( - user: dict = Depends(get_current_user), + user: UserSession = Depends(require_auth), + user_service: UserService = Depends(get_user_service), ): """Get the current user's information""" + + user_data = await user.get_user_data(user_service) + + if not user_data: + try: + user = await user_service.create_user( + user_id=user.id, + username=user.username, + email=user.email, + email_verified=user.email_verified, + name=user.name, + given_name=user.given_name, + family_name=user.family_name, + roles=user.roles, + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error creating user: {e}" + ) + if os.getenv("VITE_PUBLIC_POSTHOG_KEY"): - telemetry = user.copy() + telemetry = user_data.copy() telemetry["$current_url"] = OIDC_CONFIG["frontend_url"] - posthog.identify(distinct_id=user["id"], properties=telemetry) + posthog.identify(distinct_id=user_data["id"], properties=telemetry) return user @@ -85,4 +107,4 @@ async def get_user( if not user: raise HTTPException(status_code=404, detail="User not found") - return user \ No newline at end of file + return user diff --git a/src/backend/routers/workspace.py b/src/backend/routers/workspace.py index 56af56f..1254021 100644 --- a/src/backend/routers/workspace.py +++ b/src/backend/routers/workspace.py @@ -2,10 +2,9 @@ from fastapi import APIRouter, HTTPException, Depends from fastapi.responses import JSONResponse from pydantic import BaseModel -import jwt import os -from dependencies import SessionData, require_auth +from dependencies import UserSession, require_auth from coder import CoderAPI workspace_router = APIRouter() @@ -19,13 +18,12 @@ class WorkspaceState(BaseModel): agent: str @workspace_router.get("/state", response_model=WorkspaceState) -async def get_workspace_state(auth: SessionData = Depends(require_auth)): +async def get_workspace_state(auth: UserSession = Depends(require_auth)): """ Get the current state of the user's workspace """ - # Get user info from token - access_token = auth.token_data.get("access_token") - decoded = jwt.decode(access_token, options={"verify_signature": False}) + # Get user info from token data + decoded = auth.token_data username = decoded.get("preferred_username") email = decoded.get("email") @@ -56,13 +54,12 @@ async def get_workspace_state(auth: SessionData = Depends(require_auth)): ) @workspace_router.post("/start") -async def start_workspace(auth: SessionData = Depends(require_auth)): +async def start_workspace(auth: UserSession = Depends(require_auth)): """ Start a workspace for the authenticated user """ - # Get user info from token - access_token = auth.token_data.get("access_token") - decoded = jwt.decode(access_token, options={"verify_signature": False}) + # Get user info from token data + decoded = auth.token_data email = decoded.get("email") user = coder_api.get_user_by_email(email) @@ -82,13 +79,12 @@ async def start_workspace(auth: SessionData = Depends(require_auth)): raise HTTPException(status_code=500, detail=str(e)) @workspace_router.post("/stop") -async def stop_workspace(auth: SessionData = Depends(require_auth)): +async def stop_workspace(auth: UserSession = Depends(require_auth)): """ Stop a workspace for the authenticated user """ - # Get user info from token - access_token = auth.token_data.get("access_token") - decoded = jwt.decode(access_token, options={"verify_signature": False}) + # Get user info from token data + decoded = auth.token_data email = decoded.get("email") user = coder_api.get_user_by_email(email) @@ -106,4 +102,3 @@ async def stop_workspace(auth: SessionData = Depends(require_auth)): return JSONResponse(content=response) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - From 0d143715a81ff1c8447ae58f4ed37626e29f5eee Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 16:14:03 +0000 Subject: [PATCH 19/32] refactor: streamline user session handling and update API endpoints - Simplified UserSession initialization by decoding the JWT token directly, enhancing security and clarity. - Removed unused to_dict method from UserSession to reduce code complexity. - Updated AuthDependency to eliminate unnecessary user service dependency, streamlining session validation. - Corrected API endpoint in hooks.ts from '/api/user/me' to '/api/users/me' for consistency with routing. - Cleaned up imports in user_router.py to maintain code organization. --- src/backend/dependencies.py | 24 ++++-------------------- src/backend/routers/user_router.py | 2 +- src/frontend/src/api/hooks.ts | 2 +- 3 files changed, 6 insertions(+), 22 deletions(-) diff --git a/src/backend/dependencies.py b/src/backend/dependencies.py index d32e279..b8de96c 100644 --- a/src/backend/dependencies.py +++ b/src/backend/dependencies.py @@ -1,11 +1,10 @@ import jwt -from typing import Optional, Dict, Any, Union +from typing import Optional, Dict, Any from uuid import UUID -from fastapi import Request, HTTPException, Depends +from fastapi import Request, HTTPException from config import get_session, is_token_expired, refresh_token -from database import get_user_service from database.service import UserService class UserSession: @@ -15,7 +14,7 @@ class UserSession: """ def __init__(self, access_token: str, token_data: dict, user_id: UUID = None): self.access_token = access_token - self.token_data = token_data + self.token_data = jwt.decode(access_token, options={"verify_signature": False}) self._user_data = None @property @@ -68,19 +67,6 @@ def is_admin(self) -> bool: """Check if user has admin role""" return "admin" in self.roles - def to_dict(self) -> Dict[str, Any]: - """Convert user session to dictionary with common user fields""" - return { - "id": str(self.id), - "email": self.email, - "username": self.username, - "name": self.name, - "given_name": self.given_name, - "family_name": self.family_name, - "email_verified": self.email_verified, - "roles": self.roles - } - async def get_user_data(self, user_service: UserService) -> Dict[str, Any]: """Get user data from database, caching the result""" if self._user_data is None and self.id: @@ -96,9 +82,7 @@ def __init__(self, auto_error: bool = True, require_admin: bool = False): self.auto_error = auto_error self.require_admin = require_admin - async def __call__(self, - request: Request, - user_service: UserService = Depends(get_user_service)) -> Optional[UserSession]: + async def __call__(self, request: Request) -> Optional[UserSession]: # Get session ID from cookies session_id = request.cookies.get('session_id') diff --git a/src/backend/routers/user_router.py b/src/backend/routers/user_router.py index 488314a..2cb1bcd 100644 --- a/src/backend/routers/user_router.py +++ b/src/backend/routers/user_router.py @@ -1,6 +1,6 @@ import os from uuid import UUID -from typing import Dict, Any + import posthog from fastapi import APIRouter, Depends, HTTPException diff --git a/src/frontend/src/api/hooks.ts b/src/frontend/src/api/hooks.ts index 58d198f..f3cf285 100644 --- a/src/frontend/src/api/hooks.ts +++ b/src/frontend/src/api/hooks.ts @@ -56,7 +56,7 @@ export const api = { // User profile getUserProfile: async (): Promise => { try { - const result = await fetchApi('/api/user/me'); + const result = await fetchApi('/api/users/me'); return result; } catch (error) { throw error; From 20923d69b542dd02abb8c30ebd0b363dcb69c9f3 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 16:44:35 +0000 Subject: [PATCH 20/32] refactor: reorganize workspace routing and implement user email uniqueness - Updated import statement in main.py to reflect the new workspace_router file structure. - Modified UserModel to enforce unique email addresses for users, improving data integrity. - Introduced a new workspace_router.py file to manage workspace-related endpoints, replacing the deprecated workspace.py, streamlining the API structure. --- src/backend/database/models/user_model.py | 2 +- src/backend/main.py | 2 +- src/backend/routers/workspace.py | 104 ---------------------- src/backend/routers/workspace_router.py | 81 +++++++++++++++++ 4 files changed, 83 insertions(+), 106 deletions(-) delete mode 100644 src/backend/routers/workspace.py create mode 100644 src/backend/routers/workspace_router.py diff --git a/src/backend/database/models/user_model.py b/src/backend/database/models/user_model.py index 6edb2ae..28526b5 100644 --- a/src/backend/database/models/user_model.py +++ b/src/backend/database/models/user_model.py @@ -22,7 +22,7 @@ class UserModel(Base, BaseModel): # User-specific fields username = Column(String(254), nullable=False, unique=True) - email = Column(String(254), nullable=False) + email = Column(String(254), nullable=False, unique=True) email_verified = Column(Boolean, nullable=False, default=False) name = Column(String(254), nullable=True) given_name = Column(String(254), nullable=True) diff --git a/src/backend/main.py b/src/backend/main.py index 8d99b49..2604360 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -16,7 +16,7 @@ from dependencies import UserSession, optional_auth from routers.auth import auth_router from routers.user_router import user_router -from routers.workspace import workspace_router +from routers.workspace_router import workspace_router from routers.pad_router import pad_router from routers.template_pad_router import template_pad_router from database.service import TemplatePadService diff --git a/src/backend/routers/workspace.py b/src/backend/routers/workspace.py deleted file mode 100644 index 1254021..0000000 --- a/src/backend/routers/workspace.py +++ /dev/null @@ -1,104 +0,0 @@ -from typing import Dict, Any -from fastapi import APIRouter, HTTPException, Depends -from fastapi.responses import JSONResponse -from pydantic import BaseModel -import os - -from dependencies import UserSession, require_auth -from coder import CoderAPI - -workspace_router = APIRouter() -coder_api = CoderAPI() - -class WorkspaceState(BaseModel): - state: str - workspace_id: str - username: str - base_url: str - agent: str - -@workspace_router.get("/state", response_model=WorkspaceState) -async def get_workspace_state(auth: UserSession = Depends(require_auth)): - """ - Get the current state of the user's workspace - """ - # Get user info from token data - decoded = auth.token_data - username = decoded.get("preferred_username") - email = decoded.get("email") - - # Get user's workspaces - user = coder_api.get_user_by_email(email) - username = user.get('username', None) - if not username: - raise HTTPException(status_code=404, detail="User not found") - - workspace = coder_api.get_workspace_status_for_user(username) - - if not workspace: - raise HTTPException(status_code=404, detail="No workspace found for user") - - #states can be: - #starting - #running - #stopping - #stopped - #error - - return WorkspaceState( - state=workspace.get('latest_build', {}).get('status', 'error'), - workspace_id=workspace.get('latest_build', {}).get('workspace_name', ''), - username=username, - base_url=os.getenv("CODER_URL", ""), - agent="main" - ) - -@workspace_router.post("/start") -async def start_workspace(auth: UserSession = Depends(require_auth)): - """ - Start a workspace for the authenticated user - """ - # Get user info from token data - decoded = auth.token_data - email = decoded.get("email") - - user = coder_api.get_user_by_email(email) - username = user.get('username', None) - if not username: - raise HTTPException(status_code=404, detail="User not found") - # Get user's workspace - workspace = coder_api.get_workspace_status_for_user(username) - if not workspace: - raise HTTPException(status_code=404, detail="No workspace found for user") - - # Start the workspace - try: - response = coder_api.start_workspace(workspace["id"]) - return JSONResponse(content=response) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@workspace_router.post("/stop") -async def stop_workspace(auth: UserSession = Depends(require_auth)): - """ - Stop a workspace for the authenticated user - """ - # Get user info from token data - decoded = auth.token_data - email = decoded.get("email") - - user = coder_api.get_user_by_email(email) - username = user.get('username', None) - if not username: - raise HTTPException(status_code=404, detail="User not found") - # Get user's workspace - workspace = coder_api.get_workspace_status_for_user(username) - if not workspace: - raise HTTPException(status_code=404, detail="No workspace found for user") - - # Stop the workspace - try: - response = coder_api.stop_workspace(workspace["id"]) - return JSONResponse(content=response) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/backend/routers/workspace_router.py b/src/backend/routers/workspace_router.py new file mode 100644 index 0000000..da02524 --- /dev/null +++ b/src/backend/routers/workspace_router.py @@ -0,0 +1,81 @@ +import os + +from pydantic import BaseModel +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse + +from dependencies import UserSession, require_auth +from coder import CoderAPI + +workspace_router = APIRouter() +coder_api = CoderAPI() + +class WorkspaceState(BaseModel): + id: str + state: str + name: str + username: str + base_url: str + agent: str + +@workspace_router.get("/state", response_model=WorkspaceState) +async def get_workspace_state(user: UserSession = Depends(require_auth)): + """ + Get the current state of the user's workspace + """ + + coder_user: dict = coder_api.get_user_by_email(user.email) + if not coder_user: + raise HTTPException(status_code=404, detail=f"Coder user not found for user {user.email} ({user.id})") + + coder_username: str = coder_user.get('username', None) + if not coder_username: + raise HTTPException(status_code=404, detail=f"Coder username not found for user {user.email} ({user.id})") + + workspace: dict = coder_api.get_workspace_status_for_user(coder_username) + if not workspace: + raise HTTPException(status_code=404, detail=f"Coder Workspace not found for user {user.email} ({user.id})") + + latest_build: dict = workspace.get('latest_build', {}) + latest_build_status: str = latest_build.get('status', 'error') + workspace_name: str = latest_build.get('workspace_name', None) + workspace_id: str = workspace.get('id', {}) + + return WorkspaceState( + state=latest_build_status, + id=workspace_id, + name=workspace_name, + username=coder_username, + base_url=os.getenv("CODER_URL", ""), + agent="main" + ) + + +@workspace_router.post("/start") +async def start_workspace(user: UserSession = Depends(require_auth)): + """ + Start a workspace for the authenticated user + """ + + workspace: WorkspaceState = await get_workspace_state(user) + + try: + response = coder_api.start_workspace(workspace.id) + return JSONResponse(content=response) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@workspace_router.post("/stop") +async def stop_workspace(user: UserSession = Depends(require_auth)): + """ + Stop a workspace for the authenticated user + """ + + workspace: WorkspaceState = await get_workspace_state(user) + + try: + response = coder_api.stop_workspace(workspace.id) + return JSONResponse(content=response) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file From 7b57a99ab39a1e43b8bfc75331e84ec0e5e8a583 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 17:00:11 +0000 Subject: [PATCH 21/32] refactor: update pad router and API endpoint consistency - Refactored pad router to streamline user pad management, including saving and retrieving canvas data. - Updated user session handling to use user.id instead of user_id for consistency. - Changed API endpoints in frontend hooks to align with new pad routing structure, enhancing clarity and usability. - Removed deprecated default canvas data retrieval function, simplifying the codebase. --- src/backend/routers/pad_router.py | 82 ++++++++++--------------------- src/frontend/src/api/hooks.ts | 8 +-- 2 files changed, 29 insertions(+), 61 deletions(-) diff --git a/src/backend/routers/pad_router.py b/src/backend/routers/pad_router.py index 7c61d99..5985574 100644 --- a/src/backend/routers/pad_router.py +++ b/src/backend/routers/pad_router.py @@ -7,115 +7,89 @@ from database import get_pad_service, get_backup_service, get_template_pad_service from database.service import PadService, BackupService, TemplatePadService -# Default template name to use when a user doesn't have a pad -DEFAULT_TEMPLATE_NAME = "default" - -pad_router = APIRouter() - # Constants MAX_BACKUPS_PER_USER = 10 # Maximum number of backups to keep per user DEFAULT_PAD_NAME = "Untitled" # Default name for new pads +DEFAULT_TEMPLATE_NAME = "default" # Template name to use when a user doesn't have a pad + +pad_router = APIRouter() -@pad_router.post("") + +@pad_router.post("/") async def save_canvas( data: Dict[str, Any], - auth: UserSession = Depends(require_auth), + user: UserSession = Depends(require_auth), pad_service: PadService = Depends(get_pad_service), backup_service: BackupService = Depends(get_backup_service), - request: Request = None ): """Save canvas data for the authenticated user""" - # Get user ID from session - user_id = auth.user_id - try: # Check if user already has a pad - user_pads = await pad_service.get_pads_by_owner(user_id) + user_pads = await pad_service.get_pads_by_owner(user.id) if not user_pads: # Create a new pad if user doesn't have one pad = await pad_service.create_pad( - owner_id=user_id, + owner_id=user.id, display_name=DEFAULT_PAD_NAME, data=data ) - pad_id = UUID(pad["id"]) else: # Update existing pad pad = user_pads[0] # Use the first pad (assuming one pad per user for now) - pad_id = UUID(pad["id"]) - await pad_service.update_pad_data(pad_id, data) + await pad_service.update_pad_data(pad["id"], data) # Create a backup - await backup_service.create_backup(pad_id, data) + await backup_service.create_backup(pad["id"], data) # Manage backups (keep only the most recent ones) - await backup_service.manage_backups(pad_id, MAX_BACKUPS_PER_USER) + await backup_service.manage_backups(pad["id"], MAX_BACKUPS_PER_USER) return {"status": "success"} except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to save canvas data: {str(e)}") -async def get_default_canvas_data(template_pad_service: TemplatePadService) -> Dict[str, Any]: - """Get default canvas data from the template pad with name 'default'""" - try: - # Get the default template pad - default_template = await template_pad_service.get_template_by_name(DEFAULT_TEMPLATE_NAME) - if not default_template: - # Return empty data if default template doesn't exist - return {} - - # Return the template data - return default_template["data"] - except Exception as e: - # Log the error but return empty data to avoid breaking the application - print(f"Error getting default canvas data: {str(e)}") - return {} -@pad_router.get("") +@pad_router.get("/") async def get_canvas( - auth: UserSession = Depends(require_auth), + user: UserSession = Depends(require_auth), pad_service: PadService = Depends(get_pad_service), template_pad_service: TemplatePadService = Depends(get_template_pad_service) ): """Get canvas data for the authenticated user""" - # Get user ID from session - user_id = auth.user_id - try: # Get user's pads - user_pads = await pad_service.get_pads_by_owner(user_id) + user_pads = await pad_service.get_pads_by_owner(user.id) if not user_pads: # Return default canvas if user doesn't have a pad - return await get_default_canvas_data(template_pad_service) + return await create_pad_from_template(DEFAULT_TEMPLATE_NAME, DEFAULT_PAD_NAME, user, pad_service, template_pad_service) # Return the first pad's data (assuming one pad per user for now) return user_pads[0]["data"] except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get canvas data: {str(e)}") -@pad_router.post("/from-template/{template_id}") + +@pad_router.post("/from-template/{name}") async def create_pad_from_template( - template_id: UUID, + name: str, display_name: str = DEFAULT_PAD_NAME, - auth: UserSession = Depends(require_auth), + user: UserSession = Depends(require_auth), pad_service: PadService = Depends(get_pad_service), template_pad_service: TemplatePadService = Depends(get_template_pad_service) ): """Create a new pad from a template""" - # Get user ID from session - user_id = auth.user_id - + try: # Get the template - template = await template_pad_service.get_template(template_id) + template = await template_pad_service.get_template_by_name(name) if not template: raise HTTPException(status_code=404, detail="Template not found") # Create a new pad using the template data pad = await pad_service.create_pad( - owner_id=user_id, + owner_id=user.id, display_name=display_name, data=template["data"] ) @@ -126,28 +100,22 @@ async def create_pad_from_template( except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create pad from template: {str(e)}") + @pad_router.get("/recent") async def get_recent_canvas_backups( limit: int = MAX_BACKUPS_PER_USER, - auth: UserSession = Depends(require_auth), + user: UserSession = Depends(require_auth), pad_service: PadService = Depends(get_pad_service), backup_service: BackupService = Depends(get_backup_service) ): """Get the most recent canvas backups for the authenticated user""" - # Get user ID from session - user_id = auth.user_id - # Limit the number of backups to the maximum configured value if limit > MAX_BACKUPS_PER_USER: limit = MAX_BACKUPS_PER_USER try: # Get user's pads - user_pads = await pad_service.get_pads_by_owner(user_id) - - if not user_pads: - # Return empty list if user doesn't have a pad - return {"backups": []} + user_pads = await pad_service.get_pads_by_owner(user.id) # Get the first pad's ID (assuming one pad per user for now) pad_id = UUID(user_pads[0]["id"]) diff --git a/src/frontend/src/api/hooks.ts b/src/frontend/src/api/hooks.ts index f3cf285..932cf28 100644 --- a/src/frontend/src/api/hooks.ts +++ b/src/frontend/src/api/hooks.ts @@ -96,7 +96,7 @@ export const api = { // Canvas getCanvas: async (): Promise => { try { - const result = await fetchApi('/api/canvas'); + const result = await fetchApi('/api/pad/'); return result; } catch (error) { throw error; @@ -105,7 +105,7 @@ export const api = { saveCanvas: async (data: CanvasData): Promise => { try { - const result = await fetchApi('/api/canvas', { + const result = await fetchApi('/api/pad/', { method: 'POST', body: JSON.stringify(data), }); @@ -117,7 +117,7 @@ export const api = { getDefaultCanvas: async (): Promise => { try { - const result = await fetchApi('/api/canvas/default'); + const result = await fetchApi('/api/pad/from-template/default'); return result; } catch (error) { throw error; @@ -127,7 +127,7 @@ export const api = { // Canvas Backups getCanvasBackups: async (limit: number = 10): Promise => { try { - const result = await fetchApi(`/api/canvas/recent?limit=${limit}`); + const result = await fetchApi(`/api/pad/recent?limit=${limit}`); return result; } catch (error) { throw error; From 63447ed06c47a24eb5657330413dcd5faec17f5c Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 17:03:24 +0000 Subject: [PATCH 22/32] refactor: update workspace state management in API and components - Modified WorkspaceState interface to include 'name' and 'id' properties, enhancing workspace identification. - Updated ActionButton and Terminal components to utilize 'name' instead of 'workspace_id' for URL generation, improving consistency across the application. - Ensured all relevant references to workspace identification are aligned with the new state structure, streamlining the codebase. --- src/frontend/src/api/hooks.ts | 3 ++- src/frontend/src/pad/buttons/ActionButton.tsx | 6 +++--- src/frontend/src/pad/containers/Terminal.tsx | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/frontend/src/api/hooks.ts b/src/frontend/src/api/hooks.ts index 932cf28..653d3ce 100644 --- a/src/frontend/src/api/hooks.ts +++ b/src/frontend/src/api/hooks.ts @@ -5,10 +5,11 @@ import { queryClient } from './queryClient'; // Types export interface WorkspaceState { status: 'running' | 'starting' | 'stopping' | 'stopped' | 'error'; - workspace_id: string | null; username: string | null; + name: string | null; base_url: string | null; agent: string | null; + id: string | null; error?: string; } diff --git a/src/frontend/src/pad/buttons/ActionButton.tsx b/src/frontend/src/pad/buttons/ActionButton.tsx index 91118f5..c752e05 100644 --- a/src/frontend/src/pad/buttons/ActionButton.tsx +++ b/src/frontend/src/pad/buttons/ActionButton.tsx @@ -285,7 +285,7 @@ const ActionButton: React.FC = ({ return ''; } - return `${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.workspace_id}.${workspaceState.agent}/apps/code-server`; + return `${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.name}.${workspaceState.agent}/apps/code-server`; }; // Placement logic has been moved to ExcalidrawElementFactory.placeInScene @@ -345,7 +345,7 @@ const ActionButton: React.FC = ({ return; } - const terminalUrl = `${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.workspace_id}.${workspaceState.agent}/terminal`; + const terminalUrl = `${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.name}.${workspaceState.agent}/terminal`; console.debug(`[pad.ws] Opening terminal in new tab: ${terminalUrl}`); window.open(terminalUrl, '_blank'); } else { @@ -366,7 +366,7 @@ const ActionButton: React.FC = ({ } const owner = workspaceState.username; - const workspace = workspaceState.workspace_id; + const workspace = workspaceState.name; const url = workspaceState.base_url; const agent = workspaceState.agent; diff --git a/src/frontend/src/pad/containers/Terminal.tsx b/src/frontend/src/pad/containers/Terminal.tsx index 9460bba..64195ff 100644 --- a/src/frontend/src/pad/containers/Terminal.tsx +++ b/src/frontend/src/pad/containers/Terminal.tsx @@ -56,7 +56,7 @@ export const Terminal: React.FC = ({ terminalId, baseUrl: workspaceState.base_url, username: workspaceState.username, - workspaceId: workspaceState.workspace_id, + workspaceId: workspaceState.name, agent: workspaceState.agent }; @@ -162,7 +162,7 @@ export const Terminal: React.FC = ({ return ''; } - const baseUrl = `${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.workspace_id}.${workspaceState.agent}/terminal`; + const baseUrl = `${workspaceState.base_url}/@${workspaceState.username}/${workspaceState.name}.${workspaceState.agent}/terminal`; // Add reconnect parameter if terminal ID exists if (terminalId) { From de080bb1bec8ff3a515253b4492f9061b3419fbd Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 17:20:15 +0000 Subject: [PATCH 23/32] refactor: remove db.py and reorganize database initialization - Deleted db.py to streamline database management and reduce redundancy. - Updated main.py to directly import init_db from the database module, simplifying the database initialization process. - Introduced auth_router.py to handle authentication routes, enhancing modularity and organization of the codebase. - Ensured dotenv loading is handled in the database module for consistent environment variable access. --- src/backend/database/database.py | 4 + src/backend/db.py | 208 ------------------ src/backend/main.py | 6 +- .../routers/{auth.py => auth_router.py} | 3 +- 4 files changed, 7 insertions(+), 214 deletions(-) delete mode 100644 src/backend/db.py rename src/backend/routers/{auth.py => auth_router.py} (96%) diff --git a/src/backend/database/database.py b/src/backend/database/database.py index 7e9e01b..09e695d 100644 --- a/src/backend/database/database.py +++ b/src/backend/database/database.py @@ -13,6 +13,10 @@ from .models import Base, SCHEMA_NAME +from dotenv import load_dotenv + +load_dotenv() + # PostgreSQL connection configuration DB_USER = os.getenv('POSTGRES_USER', 'postgres') DB_PASSWORD = os.getenv('POSTGRES_PASSWORD', 'postgres') diff --git a/src/backend/db.py b/src/backend/db.py deleted file mode 100644 index 019f350..0000000 --- a/src/backend/db.py +++ /dev/null @@ -1,208 +0,0 @@ -import os -from typing import Optional, Dict, Any, List -from datetime import datetime -from dotenv import load_dotenv -from sqlalchemy import Column, String, JSON, DateTime, Integer, func, create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession -from sqlalchemy.orm import sessionmaker -from sqlalchemy.future import select -from urllib.parse import quote_plus as urlquote - -load_dotenv() - -# PostgreSQL connection configuration -DB_USER = os.getenv('POSTGRES_USER', 'postgres') -DB_PASSWORD = os.getenv('POSTGRES_PASSWORD', 'postgres') -DB_NAME = os.getenv('POSTGRES_DB', 'pad') -DB_HOST = os.getenv('POSTGRES_HOST', 'localhost') -DB_PORT = os.getenv('POSTGRES_PORT', '5432') - -# SQLAlchemy async database URL -DATABASE_URL = f"postgresql+asyncpg://{DB_USER}:{urlquote(DB_PASSWORD)}@{DB_HOST}:{DB_PORT}/{DB_NAME}" - -# Create async engine -engine = create_async_engine(DATABASE_URL, echo=False) - -# Create async session factory -async_session = sessionmaker( - engine, class_=AsyncSession, expire_on_commit=False -) - -# Create base model -Base = declarative_base() - -# Canvas backup configuration -BACKUP_INTERVAL_SECONDS = 300 # 5 minutes between backups -MAX_BACKUPS_PER_USER = 10 # Maximum number of backups to keep per user - -# In-memory dictionaries to track user activity -user_last_backup_time = {} # Tracks when each user last had a backup -user_last_activity_time = {} # Tracks when each user was last active - -# Memory management configuration -INACTIVITY_THRESHOLD_MINUTES = 30 # Remove users from memory after this many minutes of inactivity -MAX_USERS_BEFORE_CLEANUP = 1000 # Trigger cleanup when we have this many users in memory -CLEANUP_INTERVAL_SECONDS = 3600 # Run cleanup at least once per hour (1 hour) -last_cleanup_time = datetime.now() # Track when we last ran cleanup - -class CanvasData(Base): - """Model for canvas data table""" - __tablename__ = "canvas_data" - - user_id = Column(String, primary_key=True) - data = Column(JSON, nullable=False) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) - - def __repr__(self): - return f"" - -class CanvasBackup(Base): - """Model for canvas backups table""" - __tablename__ = "canvas_backups" - - id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(String, nullable=False) - timestamp = Column(DateTime(timezone=True), server_default=func.now()) - canvas_data = Column(JSON, nullable=False) - - def __repr__(self): - return f"" - -async def get_db_session(): - """Get a database session""" - async with async_session() as session: - yield session - -async def init_db(): - """Initialize the database with required tables""" - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - -async def cleanup_inactive_users(inactivity_threshold_minutes: int = INACTIVITY_THRESHOLD_MINUTES): - """Remove users from memory tracking if they've been inactive for the specified time""" - current_time = datetime.now() - inactive_users = [] - - for user_id, last_activity in user_last_activity_time.items(): - # Check if user has been inactive for longer than the threshold - if (current_time - last_activity).total_seconds() > (inactivity_threshold_minutes * 60): - inactive_users.append(user_id) - - # Remove inactive users from both dictionaries - for user_id in inactive_users: - if user_id in user_last_backup_time: - del user_last_backup_time[user_id] - if user_id in user_last_activity_time: - del user_last_activity_time[user_id] - - return len(inactive_users) # Return count of removed users for logging - -async def check_if_cleanup_needed(): - """Check if we should run the cleanup function""" - global last_cleanup_time - current_time = datetime.now() - time_since_last_cleanup = (current_time - last_cleanup_time).total_seconds() - - # Run cleanup if we have too many users or it's been too long - if (len(user_last_activity_time) > MAX_USERS_BEFORE_CLEANUP or - time_since_last_cleanup > CLEANUP_INTERVAL_SECONDS): - removed_count = await cleanup_inactive_users() - last_cleanup_time = current_time - print(f"[db.py] Cleanup completed: removed {removed_count} inactive users from memory") - -async def store_canvas_data(user_id: str, data: Dict[str, Any]) -> bool: - try: - # Update user's last activity time - current_time = datetime.now() - user_last_activity_time[user_id] = current_time - - # Check if cleanup is needed - await check_if_cleanup_needed() - - async with async_session() as session: - # Check if record exists - stmt = select(CanvasData).where(CanvasData.user_id == user_id) - result = await session.execute(stmt) - canvas_data = result.scalars().first() - - if canvas_data: - # Update existing record - canvas_data.data = data - else: - # Create new record - canvas_data = CanvasData(user_id=user_id, data=data) - session.add(canvas_data) - - # Check if we should create a backup - should_backup = False - if user_id not in user_last_backup_time: - # First time this user is saving, create a backup - should_backup = True - else: - # Check if backup interval has passed since last backup - time_since_last_backup = (current_time - user_last_backup_time[user_id]).total_seconds() - if time_since_last_backup >= BACKUP_INTERVAL_SECONDS: - should_backup = True - - if should_backup: - # Update the backup timestamp - user_last_backup_time[user_id] = current_time - - # Count existing backups for this user - backup_count_stmt = select(func.count()).select_from(CanvasBackup).where(CanvasBackup.user_id == user_id) - backup_count_result = await session.execute(backup_count_stmt) - backup_count = backup_count_result.scalar() - - # If user has reached the maximum number of backups, delete the oldest one - if backup_count >= MAX_BACKUPS_PER_USER: - oldest_backup_stmt = select(CanvasBackup).where(CanvasBackup.user_id == user_id).order_by(CanvasBackup.timestamp).limit(1) - oldest_backup_result = await session.execute(oldest_backup_stmt) - oldest_backup = oldest_backup_result.scalars().first() - - if oldest_backup: - await session.delete(oldest_backup) - - # Create new backup - new_backup = CanvasBackup(user_id=user_id, canvas_data=data) - session.add(new_backup) - - await session.commit() - return True - except Exception as e: - print(f"Error storing canvas data: {e}") - return False - -async def get_canvas_data(user_id: str) -> Optional[Dict[str, Any]]: - try: - # Update user's last activity time - user_last_activity_time[user_id] = datetime.now() - - async with async_session() as session: - stmt = select(CanvasData).where(CanvasData.user_id == user_id) - result = await session.execute(stmt) - canvas_data = result.scalars().first() - - if canvas_data: - return canvas_data.data - return None - except Exception as e: - print(f"Error retrieving canvas data: {e}") - return None - -async def get_recent_canvases(user_id: str, limit: int = MAX_BACKUPS_PER_USER) -> List[Dict[str, Any]]: - """Get the most recent canvas backups for a user""" - try: - # Update user's last activity time - user_last_activity_time[user_id] = datetime.now() - - async with async_session() as session: - # Get the most recent backups, limited to MAX_BACKUPS_PER_USER - stmt = select(CanvasBackup).where(CanvasBackup.user_id == user_id).order_by(CanvasBackup.timestamp.desc()).limit(limit) - result = await session.execute(stmt) - backups = result.scalars().all() - - return [{"id": backup.id, "timestamp": backup.timestamp, "data": backup.canvas_data} for backup in backups] - except Exception as e: - print(f"Error retrieving canvas backups: {e}") - return [] diff --git a/src/backend/main.py b/src/backend/main.py index 2604360..9d0cfa0 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -10,11 +10,10 @@ from fastapi.staticfiles import StaticFiles from dotenv import load_dotenv -from db import init_db -from database import init_db as init_database +from database import init_db from config import STATIC_DIR, ASSETS_DIR from dependencies import UserSession, optional_auth -from routers.auth import auth_router +from routers.auth_router import auth_router from routers.user_router import user_router from routers.workspace_router import workspace_router from routers.pad_router import pad_router @@ -83,7 +82,6 @@ async def load_templates(): @asynccontextmanager async def lifespan(_: FastAPI): await init_db() - await init_database() print("Database connection established successfully") # Load all templates from the templates directory diff --git a/src/backend/routers/auth.py b/src/backend/routers/auth_router.py similarity index 96% rename from src/backend/routers/auth.py rename to src/backend/routers/auth_router.py index dbc2b0f..6df890a 100644 --- a/src/backend/routers/auth.py +++ b/src/backend/routers/auth_router.py @@ -1,12 +1,11 @@ import secrets import jwt import httpx -from fastapi import APIRouter, Request, HTTPException, Depends +from fastapi import APIRouter, Request, HTTPException from fastapi.responses import RedirectResponse, FileResponse, JSONResponse import os from config import get_auth_url, get_token_url, OIDC_CONFIG, set_session, delete_session, STATIC_DIR, get_session -from dependencies import UserSession, require_auth from coder import CoderAPI auth_router = APIRouter() From d1756da61c66e29aee17349dea6b0fe326311d5e Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 17:36:15 +0000 Subject: [PATCH 24/32] refactor: integrate CoderAPI into authentication and workspace management - Added a new dependency function to provide a CoderAPI instance for use in routers. - Updated auth_router to utilize CoderAPI in the callback endpoint, enhancing authentication flow. - Modified workspace_router to incorporate CoderAPI in workspace state retrieval and management functions, improving workspace operations. - Streamlined dependency management across routers for better modularity and code organization. --- src/backend/dependencies.py | 9 ++++++++- src/backend/routers/auth_router.py | 11 ++++++++--- src/backend/routers/workspace_router.py | 24 ++++++++++++++++-------- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/backend/dependencies.py b/src/backend/dependencies.py index b8de96c..71c1489 100644 --- a/src/backend/dependencies.py +++ b/src/backend/dependencies.py @@ -2,10 +2,11 @@ from typing import Optional, Dict, Any from uuid import UUID -from fastapi import Request, HTTPException +from fastapi import Request, HTTPException, Depends from config import get_session, is_token_expired, refresh_token from database.service import UserService +from coder import CoderAPI class UserSession: """ @@ -130,3 +131,9 @@ def _handle_auth_error(self, detail: str, status_code: int = 401) -> Optional[No require_auth = AuthDependency(auto_error=True) optional_auth = AuthDependency(auto_error=False) require_admin = AuthDependency(auto_error=True, require_admin=True) + +def get_coder_api(): + """ + Dependency that provides a CoderAPI instance. + """ + return CoderAPI() diff --git a/src/backend/routers/auth_router.py b/src/backend/routers/auth_router.py index 6df890a..332483a 100644 --- a/src/backend/routers/auth_router.py +++ b/src/backend/routers/auth_router.py @@ -1,15 +1,15 @@ import secrets import jwt import httpx -from fastapi import APIRouter, Request, HTTPException +from fastapi import APIRouter, Request, HTTPException, Depends from fastapi.responses import RedirectResponse, FileResponse, JSONResponse import os from config import get_auth_url, get_token_url, OIDC_CONFIG, set_session, delete_session, STATIC_DIR, get_session +from dependencies import get_coder_api from coder import CoderAPI auth_router = APIRouter() -coder_api = CoderAPI() @auth_router.get("/login") async def login(request: Request, kc_idp_hint: str = None, popup: str = None): @@ -31,7 +31,12 @@ async def login(request: Request, kc_idp_hint: str = None, popup: str = None): return response @auth_router.get("/callback") -async def callback(request: Request, code: str, state: str = "default"): +async def callback( + request: Request, + code: str, + state: str = "default", + coder_api: CoderAPI = Depends(get_coder_api) +): session_id = request.cookies.get('session_id') if not session_id: raise HTTPException(status_code=400, detail="No session") diff --git a/src/backend/routers/workspace_router.py b/src/backend/routers/workspace_router.py index da02524..0511702 100644 --- a/src/backend/routers/workspace_router.py +++ b/src/backend/routers/workspace_router.py @@ -4,11 +4,10 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import JSONResponse -from dependencies import UserSession, require_auth +from dependencies import UserSession, require_auth, get_coder_api from coder import CoderAPI workspace_router = APIRouter() -coder_api = CoderAPI() class WorkspaceState(BaseModel): id: str @@ -19,7 +18,10 @@ class WorkspaceState(BaseModel): agent: str @workspace_router.get("/state", response_model=WorkspaceState) -async def get_workspace_state(user: UserSession = Depends(require_auth)): +async def get_workspace_state( + user: UserSession = Depends(require_auth), + coder_api: CoderAPI = Depends(get_coder_api) +): """ Get the current state of the user's workspace """ @@ -52,12 +54,15 @@ async def get_workspace_state(user: UserSession = Depends(require_auth)): @workspace_router.post("/start") -async def start_workspace(user: UserSession = Depends(require_auth)): +async def start_workspace( + user: UserSession = Depends(require_auth), + coder_api: CoderAPI = Depends(get_coder_api) +): """ Start a workspace for the authenticated user """ - workspace: WorkspaceState = await get_workspace_state(user) + workspace: WorkspaceState = await get_workspace_state(user, coder_api) try: response = coder_api.start_workspace(workspace.id) @@ -67,15 +72,18 @@ async def start_workspace(user: UserSession = Depends(require_auth)): @workspace_router.post("/stop") -async def stop_workspace(user: UserSession = Depends(require_auth)): +async def stop_workspace( + user: UserSession = Depends(require_auth), + coder_api: CoderAPI = Depends(get_coder_api) +): """ Stop a workspace for the authenticated user """ - workspace: WorkspaceState = await get_workspace_state(user) + workspace: WorkspaceState = await get_workspace_state(user, coder_api) try: response = coder_api.stop_workspace(workspace.id) return JSONResponse(content=response) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file + raise HTTPException(status_code=500, detail=str(e)) From 01499cc7aa04bc66d460d11f919ebaca0fd099ad Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 17:43:50 +0000 Subject: [PATCH 25/32] refactor: migrate CoderAPI configuration to centralized config module - Replaced dotenv loading in CoderAPI with centralized configuration from config.py, enhancing consistency and maintainability. - Updated CoderAPI initialization to use environment variables directly from the config module. - Adjusted error messages to reflect the new configuration approach, improving clarity for required variables. - Streamlined imports across various modules to utilize the new configuration structure, promoting better organization. --- src/backend/coder.py | 30 ++++++++++------------- src/backend/config.py | 39 +++++++++++++++++++++--------- src/backend/main.py | 9 ++----- src/backend/routers/auth_router.py | 14 +++++------ src/backend/routers/user_router.py | 4 +-- 5 files changed, 52 insertions(+), 44 deletions(-) diff --git a/src/backend/coder.py b/src/backend/coder.py index 0927ddd..83b31ca 100644 --- a/src/backend/coder.py +++ b/src/backend/coder.py @@ -1,26 +1,22 @@ import os import requests -from dotenv import load_dotenv +from config import CODER_API_KEY, CODER_URL, CODER_TEMPLATE_ID, CODER_DEFAULT_ORGANIZATION, CODER_WORKSPACE_NAME class CoderAPI: """ - A class for interacting with the Coder API using credentials from .env file + A class for interacting with the Coder API using credentials from config """ def __init__(self): - # Load environment variables from .env file - load_dotenv() + # Get configuration from config + self.api_key = CODER_API_KEY + self.coder_url = CODER_URL + self.template_id = CODER_TEMPLATE_ID + self.default_organization_id = CODER_DEFAULT_ORGANIZATION - # Get configuration from environment variables - self.api_key = os.getenv("CODER_API_KEY") - self.coder_url = os.getenv("CODER_URL") - self.user_id = os.getenv("USER_ID") - self.template_id = os.getenv("CODER_TEMPLATE_ID") - self.default_organization_id = os.getenv("CODER_DEFAULT_ORGANIZATION") - - # Check if required environment variables are set + # Check if required configuration variables are set if not self.api_key or not self.coder_url: - raise ValueError("CODER_API_KEY and CODER_URL must be set in .env file") + raise ValueError("CODER_API_KEY and CODER_URL must be set in environment variables") # Set up common headers for API requests self.headers = { @@ -56,9 +52,9 @@ def create_workspace(self, user_id, parameter_values=None): template_id = self.template_id if not template_id: - raise ValueError("template_id must be provided or TEMPLATE_ID must be set in .env") + raise ValueError("template_id must be provided or TEMPLATE_ID must be set in environment variables") - name = os.getenv("CODER_WORKSPACE_NAME", "ubuntu") + name = CODER_WORKSPACE_NAME # Prepare the request data data = { @@ -201,7 +197,7 @@ def get_workspace_status_for_user(self, username): Returns: dict: Workspace status data if found, None otherwise """ - workspace_name = os.getenv("CODER_WORKSPACE_NAME", "ubuntu") + workspace_name = CODER_WORKSPACE_NAME endpoint = f"{self.coder_url}/api/v2/users/{username}/workspace/{workspace_name}" response = requests.get(endpoint, headers=self.headers) @@ -282,4 +278,4 @@ def stop_workspace(self, workspace_id): headers['Content-Type'] = 'application/json' response = requests.post(endpoint, headers=headers, json=data) response.raise_for_status() - return response.json() \ No newline at end of file + return response.json() diff --git a/src/backend/config.py b/src/backend/config.py index b83ee9a..98ca007 100644 --- a/src/backend/config.py +++ b/src/backend/config.py @@ -7,29 +7,46 @@ from typing import Optional, Dict, Any, Tuple from dotenv import load_dotenv +# Load environment variables once load_dotenv() +# ===== Application Configuration ===== STATIC_DIR = os.getenv("STATIC_DIR") ASSETS_DIR = os.getenv("ASSETS_DIR") +FRONTEND_URL = os.getenv('FRONTEND_URL') -OIDC_CONFIG = { - 'client_id': os.getenv('OIDC_CLIENT_ID'), - 'client_secret': os.getenv('OIDC_CLIENT_SECRET'), - 'server_url': os.getenv('OIDC_SERVER_URL'), - 'realm': os.getenv('OIDC_REALM'), - 'redirect_uri': os.getenv('REDIRECT_URI'), - 'frontend_url': os.getenv('FRONTEND_URL') -} +# ===== PostHog Configuration ===== +POSTHOG_API_KEY = os.getenv("VITE_PUBLIC_POSTHOG_KEY") +POSTHOG_HOST = os.getenv("VITE_PUBLIC_POSTHOG_HOST") + +# ===== OIDC Configuration ===== +OIDC_CLIENT_ID = os.getenv('OIDC_CLIENT_ID') +OIDC_CLIENT_SECRET = os.getenv('OIDC_CLIENT_SECRET') +OIDC_SERVER_URL = os.getenv('OIDC_SERVER_URL') +OIDC_REALM = os.getenv('OIDC_REALM') +OIDC_REDIRECT_URI = os.getenv('REDIRECT_URI') + +# ===== Redis Configuration ===== +REDIS_HOST = os.getenv('REDIS_HOST', 'localhost') +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD', None) +REDIS_PORT = int(os.getenv('REDIS_PORT', 6379)) # Redis connection redis_client = redis.Redis( - host=os.getenv('REDIS_HOST', 'localhost'), - password=os.getenv('REDIS_PASSWORD', None), - port=int(os.getenv('REDIS_PORT', 6379)), + host=REDIS_HOST, + password=REDIS_PASSWORD, + port=REDIS_PORT, db=0, decode_responses=True ) +# ===== Coder API Configuration ===== +CODER_API_KEY = os.getenv("CODER_API_KEY") +CODER_URL = os.getenv("CODER_URL") +CODER_TEMPLATE_ID = os.getenv("CODER_TEMPLATE_ID") +CODER_DEFAULT_ORGANIZATION = os.getenv("CODER_DEFAULT_ORGANIZATION") +CODER_WORKSPACE_NAME = os.getenv("CODER_WORKSPACE_NAME", "ubuntu") + # Session management functions def get_session(session_id: str) -> Optional[Dict[str, Any]]: """Get session data from Redis""" diff --git a/src/backend/main.py b/src/backend/main.py index 9d0cfa0..2736fd7 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -8,10 +8,9 @@ from fastapi.responses import FileResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles -from dotenv import load_dotenv from database import init_db -from config import STATIC_DIR, ASSETS_DIR +from config import STATIC_DIR, ASSETS_DIR, POSTHOG_API_KEY, POSTHOG_HOST from dependencies import UserSession, optional_auth from routers.auth_router import auth_router from routers.user_router import user_router @@ -21,11 +20,7 @@ from database.service import TemplatePadService from database.database import async_session -load_dotenv() - -POSTHOG_API_KEY = os.environ.get("VITE_PUBLIC_POSTHOG_KEY") -POSTHOG_HOST = os.environ.get("VITE_PUBLIC_POSTHOG_HOST") - +# Initialize PostHog if API key is available if POSTHOG_API_KEY: posthog.project_api_key = POSTHOG_API_KEY posthog.host = POSTHOG_HOST diff --git a/src/backend/routers/auth_router.py b/src/backend/routers/auth_router.py index 332483a..0a2580b 100644 --- a/src/backend/routers/auth_router.py +++ b/src/backend/routers/auth_router.py @@ -5,7 +5,8 @@ from fastapi.responses import RedirectResponse, FileResponse, JSONResponse import os -from config import get_auth_url, get_token_url, OIDC_CONFIG, set_session, delete_session, STATIC_DIR, get_session +from config import (get_auth_url, get_token_url, set_session, delete_session, get_session, + FRONTEND_URL, OIDC_CLIENT_ID, OIDC_CLIENT_SECRET, OIDC_SERVER_URL, OIDC_REALM, OIDC_REDIRECT_URI, STATIC_DIR) from dependencies import get_coder_api from coder import CoderAPI @@ -47,10 +48,10 @@ async def callback( get_token_url(), data={ 'grant_type': 'authorization_code', - 'client_id': OIDC_CONFIG['client_id'], - 'client_secret': OIDC_CONFIG['client_secret'], + 'client_id': OIDC_CLIENT_ID, + 'client_secret': OIDC_CLIENT_SECRET, 'code': code, - 'redirect_uri': OIDC_CONFIG['redirect_uri'] + 'redirect_uri': OIDC_REDIRECT_URI } ) @@ -91,9 +92,8 @@ async def logout(request: Request): delete_session(session_id) # Create the Keycloak logout URL with redirect back to our app - logout_url = f"{OIDC_CONFIG['server_url']}/realms/{OIDC_CONFIG['realm']}/protocol/openid-connect/logout" - redirect_uri = OIDC_CONFIG['frontend_url'] # Match the frontend redirect URI - full_logout_url = f"{logout_url}?id_token_hint={id_token}&post_logout_redirect_uri={redirect_uri}" + logout_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/logout" + full_logout_url = f"{logout_url}?id_token_hint={id_token}&post_logout_redirect_uri={FRONTEND_URL}" # Create a redirect response to Keycloak's logout endpoint response = JSONResponse({"status": "success", "logout_url": full_logout_url}) diff --git a/src/backend/routers/user_router.py b/src/backend/routers/user_router.py index 2cb1bcd..bea69fb 100644 --- a/src/backend/routers/user_router.py +++ b/src/backend/routers/user_router.py @@ -4,7 +4,7 @@ import posthog from fastapi import APIRouter, Depends, HTTPException -from config import redis_client, OIDC_CONFIG +from config import redis_client, FRONTEND_URL from database import get_user_service from database.service import UserService from dependencies import UserSession, require_admin, require_auth @@ -81,7 +81,7 @@ async def get_user_info( if os.getenv("VITE_PUBLIC_POSTHOG_KEY"): telemetry = user_data.copy() - telemetry["$current_url"] = OIDC_CONFIG["frontend_url"] + telemetry["$current_url"] = FRONTEND_URL posthog.identify(distinct_id=user_data["id"], properties=telemetry) return user From c9dab93a7fc25677939956dce436b645fe8ec36e Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 17:47:04 +0000 Subject: [PATCH 26/32] refactor: update authentication URL configuration for Keycloak - Replaced direct references to OIDC_CONFIG with environment variables for OIDC_SERVER_URL, OIDC_REALM, OIDC_CLIENT_ID, and OIDC_CLIENT_SECRET, enhancing configuration management. - Improved code clarity and maintainability by centralizing authentication URL generation in the config module. --- src/backend/config.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/backend/config.py b/src/backend/config.py index 98ca007..b4b4358 100644 --- a/src/backend/config.py +++ b/src/backend/config.py @@ -71,18 +71,18 @@ def delete_session(session_id: str) -> None: def get_auth_url() -> str: """Generate the authentication URL for Keycloak login""" - auth_url = f"{OIDC_CONFIG['server_url']}/realms/{OIDC_CONFIG['realm']}/protocol/openid-connect/auth" + auth_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/auth" params = { - 'client_id': OIDC_CONFIG['client_id'], + 'client_id': OIDC_CLIENT_ID, 'response_type': 'code', - 'redirect_uri': OIDC_CONFIG['redirect_uri'], + 'redirect_uri': OIDC_REDIRECT_URI, 'scope': 'openid profile email' } return f"{auth_url}?{'&'.join(f'{k}={v}' for k,v in params.items())}" def get_token_url() -> str: """Get the token endpoint URL""" - return f"{OIDC_CONFIG['server_url']}/realms/{OIDC_CONFIG['realm']}/protocol/openid-connect/token" + return f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/token" def is_token_expired(token_data: Dict[str, Any], buffer_seconds: int = 30) -> bool: """ @@ -132,8 +132,8 @@ async def refresh_token(session_id: str, token_data: Dict[str, Any]) -> Tuple[bo get_token_url(), data={ 'grant_type': 'refresh_token', - 'client_id': OIDC_CONFIG['client_id'], - 'client_secret': OIDC_CONFIG['client_secret'], + 'client_id': OIDC_CLIENT_ID, + 'client_secret': OIDC_CLIENT_SECRET, 'refresh_token': token_data['refresh_token'] } ) From 4f915320ad381dd9120295e92d1e209f0cf15703 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 19:40:15 +0000 Subject: [PATCH 27/32] refactor: enhance JWT token handling and session management - Integrated PyJWKClient for secure JWT verification, improving token validation by using signing keys from JWKS. - Updated UserSession initialization to decode tokens with verification, enhancing security and error handling. - Added a caching mechanism for the JWKS client to optimize performance. - Cleaned up token expiration checks to ensure accurate validation and error reporting. --- src/backend/config.py | 43 ++++++++++++++++++++++-------------- src/backend/dependencies.py | 19 +++++++++++++++- src/backend/requirements.txt | 3 ++- 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/src/backend/config.py b/src/backend/config.py index b4b4358..3c70812 100644 --- a/src/backend/config.py +++ b/src/backend/config.py @@ -4,6 +4,7 @@ import httpx import redis import jwt +from jwt.jwks_client import PyJWKClient from typing import Optional, Dict, Any, Tuple from dotenv import load_dotenv @@ -47,6 +48,9 @@ CODER_DEFAULT_ORGANIZATION = os.getenv("CODER_DEFAULT_ORGANIZATION") CODER_WORKSPACE_NAME = os.getenv("CODER_WORKSPACE_NAME", "ubuntu") +# Cache for JWKS client +_jwks_client = None + # Session management functions def get_session(session_id: str) -> Optional[Dict[str, Any]]: """Get session data from Redis""" @@ -67,8 +71,6 @@ def delete_session(session_id: str) -> None: """Delete session data from Redis""" redis_client.delete(f"session:{session_id}") -provisioning_times = {} - def get_auth_url() -> str: """Generate the authentication URL for Keycloak login""" auth_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/auth" @@ -85,29 +87,28 @@ def get_token_url() -> str: return f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/token" def is_token_expired(token_data: Dict[str, Any], buffer_seconds: int = 30) -> bool: - """ - Check if the access token is expired or about to expire - - Args: - token_data: The token data containing the access token - buffer_seconds: Buffer time in seconds to refresh token before it actually expires - - Returns: - bool: True if token is expired or about to expire, False otherwise - """ if not token_data or 'access_token' not in token_data: return True try: - # Decode the JWT token without verification to get expiration time - decoded = jwt.decode(token_data['access_token'], options={"verify_signature": False}) + # Get the signing key + jwks_client = get_jwks_client() + signing_key = jwks_client.get_signing_key_from_jwt(token_data['access_token']) - # Get expiration time from token - exp_time = decoded.get('exp', 0) + # Decode with verification + decoded = jwt.decode( + token_data['access_token'], + signing_key.key, + algorithms=["RS256"], # Common algorithm for OIDC + audience=OIDC_CLIENT_ID, + ) - # Check if token is expired or about to expire (with buffer) + # Check expiration + exp_time = decoded.get('exp', 0) current_time = time.time() return current_time + buffer_seconds >= exp_time + except jwt.ExpiredSignatureError: + return True except Exception as e: print(f"Error checking token expiration: {str(e)}") return True @@ -153,3 +154,11 @@ async def refresh_token(session_id: str, token_data: Dict[str, Any]) -> Tuple[bo except Exception as e: print(f"Error refreshing token: {str(e)}") return False, token_data + +def get_jwks_client(): + """Get or create a PyJWKClient for token verification""" + global _jwks_client + if _jwks_client is None: + jwks_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/certs" + _jwks_client = PyJWKClient(jwks_url) + return _jwks_client \ No newline at end of file diff --git a/src/backend/dependencies.py b/src/backend/dependencies.py index 71c1489..749c26b 100644 --- a/src/backend/dependencies.py +++ b/src/backend/dependencies.py @@ -15,8 +15,25 @@ class UserSession: """ def __init__(self, access_token: str, token_data: dict, user_id: UUID = None): self.access_token = access_token - self.token_data = jwt.decode(access_token, options={"verify_signature": False}) self._user_data = None + + # Get the signing key and decode with verification + from config import get_jwks_client, OIDC_CLIENT_ID + try: + jwks_client = get_jwks_client() + signing_key = jwks_client.get_signing_key_from_jwt(access_token) + + self.token_data = jwt.decode( + access_token, + signing_key.key, + algorithms=["RS256"], + audience=OIDC_CLIENT_ID + ) + + except jwt.InvalidTokenError as e: + # Log the error and raise an appropriate exception + print(f"Invalid token: {str(e)}") + raise ValueError(f"Invalid authentication token: {str(e)}") @property def is_authenticated(self) -> bool: diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index 3303e91..358a51d 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -10,4 +10,5 @@ sqlalchemy posthog redis psycopg2-binary -python-multipart \ No newline at end of file +python-multipart +cryptography # Required for JWT key handling \ No newline at end of file From e69cb4e6b7440ecdf62680d305a945b6ee26408e Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 20:00:01 +0000 Subject: [PATCH 28/32] refactor: enhance Redis connection management and backup functionality - Introduced a Redis connection pool to optimize Redis client management, improving performance and resource utilization. - Updated session management functions to utilize the new Redis connection pool, ensuring efficient access to session data. - Implemented a new method in BackupService to retrieve backups for a user's first pad using a join operation, addressing the N+1 query problem. - Enhanced pad router to create backups conditionally based on time intervals, improving backup efficiency and management. - Streamlined user router to utilize the new Redis client retrieval method, promoting better code organization and maintainability. --- src/backend/config.py | 34 +++++++-- .../database/repository/backup_repository.py | 29 ++++++- .../database/service/backup_service.py | 75 ++++++++++++++++++- src/backend/main.py | 17 ++++- src/backend/routers/pad_router.py | 55 ++++++++------ src/backend/routers/user_router.py | 5 +- 6 files changed, 178 insertions(+), 37 deletions(-) diff --git a/src/backend/config.py b/src/backend/config.py index 3c70812..bce429e 100644 --- a/src/backend/config.py +++ b/src/backend/config.py @@ -3,6 +3,7 @@ import time import httpx import redis +from redis import ConnectionPool, Redis import jwt from jwt.jwks_client import PyJWKClient from typing import Optional, Dict, Any, Tuple @@ -16,6 +17,11 @@ ASSETS_DIR = os.getenv("ASSETS_DIR") FRONTEND_URL = os.getenv('FRONTEND_URL') +MAX_BACKUPS_PER_USER = 10 # Maximum number of backups to keep per user +MIN_INTERVAL_MINUTES = 5 # Minimum interval in minutes between backups +DEFAULT_PAD_NAME = "Untitled" # Default name for new pads +DEFAULT_TEMPLATE_NAME = "default" # Template name to use when a user doesn't have a pad + # ===== PostHog Configuration ===== POSTHOG_API_KEY = os.getenv("VITE_PUBLIC_POSTHOG_KEY") POSTHOG_HOST = os.getenv("VITE_PUBLIC_POSTHOG_HOST") @@ -32,15 +38,26 @@ REDIS_PASSWORD = os.getenv('REDIS_PASSWORD', None) REDIS_PORT = int(os.getenv('REDIS_PORT', 6379)) -# Redis connection -redis_client = redis.Redis( +# Create a Redis connection pool +redis_pool = ConnectionPool( host=REDIS_HOST, password=REDIS_PASSWORD, port=REDIS_PORT, db=0, - decode_responses=True + decode_responses=True, + max_connections=10, # Adjust based on your application's needs + socket_timeout=5.0, + socket_connect_timeout=1.0, + health_check_interval=30 ) +# Create a Redis client that uses the connection pool +redis_client = Redis(connection_pool=redis_pool) + +def get_redis_client(): + """Get a Redis client from the connection pool""" + return Redis(connection_pool=redis_pool) + # ===== Coder API Configuration ===== CODER_API_KEY = os.getenv("CODER_API_KEY") CODER_URL = os.getenv("CODER_URL") @@ -54,14 +71,16 @@ # Session management functions def get_session(session_id: str) -> Optional[Dict[str, Any]]: """Get session data from Redis""" - session_data = redis_client.get(f"session:{session_id}") + client = get_redis_client() + session_data = client.get(f"session:{session_id}") if session_data: return json.loads(session_data) return None def set_session(session_id: str, data: Dict[str, Any], expiry: int) -> None: """Store session data in Redis with expiry in seconds""" - redis_client.setex( + client = get_redis_client() + client.setex( f"session:{session_id}", expiry, json.dumps(data) @@ -69,7 +88,8 @@ def set_session(session_id: str, data: Dict[str, Any], expiry: int) -> None: def delete_session(session_id: str) -> None: """Delete session data from Redis""" - redis_client.delete(f"session:{session_id}") + client = get_redis_client() + client.delete(f"session:{session_id}") def get_auth_url() -> str: """Generate the authentication URL for Keycloak login""" @@ -161,4 +181,4 @@ def get_jwks_client(): if _jwks_client is None: jwks_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/certs" _jwks_client = PyJWKClient(jwks_url) - return _jwks_client \ No newline at end of file + return _jwks_client diff --git a/src/backend/database/repository/backup_repository.py b/src/backend/database/repository/backup_repository.py index bfba95b..dcae35a 100644 --- a/src/backend/database/repository/backup_repository.py +++ b/src/backend/database/repository/backup_repository.py @@ -8,9 +8,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from sqlalchemy import delete, func +from sqlalchemy import delete, func, join -from ..models import BackupModel +from ..models import BackupModel, PadModel class BackupRepository: """Repository for backup-related database operations""" @@ -89,3 +89,28 @@ async def count_by_source(self, source_id: UUID) -> int: stmt = select(func.count()).select_from(BackupModel).where(BackupModel.source_id == source_id) result = await self.session.execute(stmt) return result.scalar() + + async def get_backups_by_user(self, user_id: UUID, limit: int = 10) -> List[BackupModel]: + """ + Get backups for a user's first pad directly using a join operation. + This eliminates the N+1 query problem by fetching the pad and its backups in a single query. + + Args: + user_id: The user ID to get backups for + limit: Maximum number of backups to return + + Returns: + List of backup models + """ + # Create a join between PadModel and BackupModel + stmt = select(BackupModel).join( + PadModel, + BackupModel.source_id == PadModel.id + ).where( + PadModel.owner_id == user_id + ).order_by( + BackupModel.created_at.desc() + ).limit(limit) + + result = await self.session.execute(stmt) + return result.scalars().all() diff --git a/src/backend/database/service/backup_service.py b/src/backend/database/service/backup_service.py index 46f404f..e625e58 100644 --- a/src/backend/database/service/backup_service.py +++ b/src/backend/database/service/backup_service.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession -from ..repository import BackupRepository, PadRepository +from ..repository import BackupRepository, PadRepository, UserRepository class BackupService: """Service for backup-related business logic""" @@ -101,3 +101,76 @@ async def manage_backups(self, source_id: UUID, max_backups: int = 10) -> int: return await self.repository.delete_older_than(source_id, max_backups) return 0 # No backups deleted + + async def get_backups_by_user(self, user_id: UUID, limit: int = 10) -> List[Dict[str, Any]]: + """ + Get backups for a user's first pad directly using a join operation. + This eliminates the N+1 query problem by fetching the pad and its backups in a single query. + + Args: + user_id: The user ID to get backups for + limit: Maximum number of backups to return + + Returns: + List of backup dictionaries + """ + # Check if user exists + user_repository = UserRepository(self.session) + user = await user_repository.get_by_id(user_id) + if not user: + raise ValueError(f"User with ID '{user_id}' does not exist") + + # Get backups directly with a single query + backups = await self.repository.get_backups_by_user(user_id, limit) + return [backup.to_dict() for backup in backups] + + async def create_backup_if_needed(self, source_id: UUID, data: Dict[str, Any], + min_interval_minutes: int = 5, + max_backups: int = 10) -> Optional[Dict[str, Any]]: + """ + Create a backup only if needed: + - If there are no existing backups + - If the latest backup is older than the specified interval + + Args: + source_id: The ID of the source pad + data: The data to backup + min_interval_minutes: Minimum time between backups in minutes + max_backups: Maximum number of backups to keep + + Returns: + The created backup dict if a backup was created, None otherwise + """ + # Check if source pad exists + source_pad = await self.pad_repository.get_by_id(source_id) + if not source_pad: + raise ValueError(f"Pad with ID '{source_id}' does not exist") + + # Get the latest backup + latest_backup = await self.repository.get_latest_by_source(source_id) + + # Calculate the current time + current_time = datetime.now() + + # Determine if we need to create a backup + create_backup = False + + if not latest_backup: + # No backups exist yet, so create one + create_backup = True + else: + # Check if the latest backup is older than the minimum interval + backup_age = current_time - latest_backup.created_at + if backup_age.total_seconds() > (min_interval_minutes * 60): + create_backup = True + + # Create a backup if needed + if create_backup: + backup = await self.repository.create(source_id, data) + + # Manage backups (clean up old ones) + await self.manage_backups(source_id, max_backups) + + return backup.to_dict() + + return None diff --git a/src/backend/main.py b/src/backend/main.py index 2736fd7..45b19a9 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -10,7 +10,7 @@ from fastapi.staticfiles import StaticFiles from database import init_db -from config import STATIC_DIR, ASSETS_DIR, POSTHOG_API_KEY, POSTHOG_HOST +from config import STATIC_DIR, ASSETS_DIR, POSTHOG_API_KEY, POSTHOG_HOST, redis_client, redis_pool from dependencies import UserSession, optional_auth from routers.auth_router import auth_router from routers.user_router import user_router @@ -76,14 +76,29 @@ async def load_templates(): @asynccontextmanager async def lifespan(_: FastAPI): + # Initialize database await init_db() print("Database connection established successfully") + # Check Redis connection + try: + redis_client.ping() + print("Redis connection established successfully") + except Exception as e: + print(f"Warning: Redis connection failed: {str(e)}") + # Load all templates from the templates directory await load_templates() print("Templates loaded successfully") yield + + # Clean up connections when shutting down + try: + redis_pool.disconnect() + print("Redis connections closed") + except Exception as e: + print(f"Error closing Redis connections: {str(e)}") app = FastAPI(lifespan=lifespan) diff --git a/src/backend/routers/pad_router.py b/src/backend/routers/pad_router.py index 5985574..e5c4639 100644 --- a/src/backend/routers/pad_router.py +++ b/src/backend/routers/pad_router.py @@ -6,12 +6,7 @@ from dependencies import UserSession, require_auth from database import get_pad_service, get_backup_service, get_template_pad_service from database.service import PadService, BackupService, TemplatePadService - -# Constants -MAX_BACKUPS_PER_USER = 10 # Maximum number of backups to keep per user -DEFAULT_PAD_NAME = "Untitled" # Default name for new pads -DEFAULT_TEMPLATE_NAME = "default" # Template name to use when a user doesn't have a pad - +from config import MAX_BACKUPS_PER_USER, MIN_INTERVAL_MINUTES, DEFAULT_PAD_NAME, DEFAULT_TEMPLATE_NAME pad_router = APIRouter() @@ -39,11 +34,13 @@ async def save_canvas( pad = user_pads[0] # Use the first pad (assuming one pad per user for now) await pad_service.update_pad_data(pad["id"], data) - # Create a backup - await backup_service.create_backup(pad["id"], data) - - # Manage backups (keep only the most recent ones) - await backup_service.manage_backups(pad["id"], MAX_BACKUPS_PER_USER) + # Create a backup only if needed (if none exist or latest is > 5 min old) + await backup_service.create_backup_if_needed( + source_id=pad["id"], + data=data, + min_interval_minutes=MIN_INTERVAL_MINUTES, + max_backups=MAX_BACKUPS_PER_USER + ) return {"status": "success"} except Exception as e: @@ -54,7 +51,8 @@ async def save_canvas( async def get_canvas( user: UserSession = Depends(require_auth), pad_service: PadService = Depends(get_pad_service), - template_pad_service: TemplatePadService = Depends(get_template_pad_service) + template_pad_service: TemplatePadService = Depends(get_template_pad_service), + backup_service: BackupService = Depends(get_backup_service) ): """Get canvas data for the authenticated user""" try: @@ -63,7 +61,14 @@ async def get_canvas( if not user_pads: # Return default canvas if user doesn't have a pad - return await create_pad_from_template(DEFAULT_TEMPLATE_NAME, DEFAULT_PAD_NAME, user, pad_service, template_pad_service) + return await create_pad_from_template( + name=DEFAULT_TEMPLATE_NAME, + display_name=DEFAULT_PAD_NAME, + user=user, + pad_service=pad_service, + template_pad_service=template_pad_service, + backup_service=backup_service + ) # Return the first pad's data (assuming one pad per user for now) return user_pads[0]["data"] @@ -77,7 +82,8 @@ async def create_pad_from_template( display_name: str = DEFAULT_PAD_NAME, user: UserSession = Depends(require_auth), pad_service: PadService = Depends(get_pad_service), - template_pad_service: TemplatePadService = Depends(get_template_pad_service) + template_pad_service: TemplatePadService = Depends(get_template_pad_service), + backup_service: BackupService = Depends(get_backup_service) ): """Create a new pad from a template""" @@ -94,6 +100,14 @@ async def create_pad_from_template( data=template["data"] ) + # Create an initial backup for the new pad + await backup_service.create_backup_if_needed( + source_id=pad["id"], + data=template["data"], + min_interval_minutes=0, # Always create initial backup + max_backups=MAX_BACKUPS_PER_USER + ) + return pad except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -105,7 +119,6 @@ async def create_pad_from_template( async def get_recent_canvas_backups( limit: int = MAX_BACKUPS_PER_USER, user: UserSession = Depends(require_auth), - pad_service: PadService = Depends(get_pad_service), backup_service: BackupService = Depends(get_backup_service) ): """Get the most recent canvas backups for the authenticated user""" @@ -114,18 +127,12 @@ async def get_recent_canvas_backups( limit = MAX_BACKUPS_PER_USER try: - # Get user's pads - user_pads = await pad_service.get_pads_by_owner(user.id) - - # Get the first pad's ID (assuming one pad per user for now) - pad_id = UUID(user_pads[0]["id"]) - - # Get backups for the pad - backups_data = await backup_service.get_backups_by_source(pad_id) + # Get backups directly with a single query + backups_data = await backup_service.get_backups_by_user(user.id, limit) # Format backups to match the expected response format backups = [] - for backup in backups_data[:limit]: + for backup in backups_data: backups.append({ "id": backup["id"], "timestamp": backup["created_at"], diff --git a/src/backend/routers/user_router.py b/src/backend/routers/user_router.py index bea69fb..9a76d93 100644 --- a/src/backend/routers/user_router.py +++ b/src/backend/routers/user_router.py @@ -4,7 +4,7 @@ import posthog from fastapi import APIRouter, Depends, HTTPException -from config import redis_client, FRONTEND_URL +from config import get_redis_client, FRONTEND_URL from database import get_user_service from database.service import UserService from dependencies import UserSession, require_admin, require_auth @@ -92,7 +92,8 @@ async def get_user_count( _: bool = Depends(require_admin), ): """Get the number of active sessions (admin only)""" - session_count = len(redis_client.keys("session:*")) + client = get_redis_client() + session_count = len(client.keys("session:*")) return {"active_sessions": session_count } From 8a0abd7220270c7dd507229229a86a3acc741f3d Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 20:04:52 +0000 Subject: [PATCH 29/32] refactor: standardize API endpoint paths and enhance datetime handling - Updated API endpoint paths in pad, template pad, and user routers to remove trailing slashes for consistency. - Enhanced datetime handling in BackupService to include timezone information, improving accuracy in backup creation timing. - Adjusted frontend API calls to align with the updated endpoint structure, ensuring seamless integration across the application. --- src/backend/database/service/backup_service.py | 6 +++--- src/backend/routers/pad_router.py | 4 ++-- src/backend/routers/template_pad_router.py | 4 ++-- src/backend/routers/user_router.py | 4 ++-- src/frontend/src/api/hooks.ts | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/backend/database/service/backup_service.py b/src/backend/database/service/backup_service.py index e625e58..a9e44db 100644 --- a/src/backend/database/service/backup_service.py +++ b/src/backend/database/service/backup_service.py @@ -4,7 +4,7 @@ from typing import List, Optional, Dict, Any from uuid import UUID -from datetime import datetime +from datetime import datetime, timezone from sqlalchemy.ext.asyncio import AsyncSession @@ -149,8 +149,8 @@ async def create_backup_if_needed(self, source_id: UUID, data: Dict[str, Any], # Get the latest backup latest_backup = await self.repository.get_latest_by_source(source_id) - # Calculate the current time - current_time = datetime.now() + # Calculate the current time with timezone information + current_time = datetime.now(timezone.utc) # Determine if we need to create a backup create_backup = False diff --git a/src/backend/routers/pad_router.py b/src/backend/routers/pad_router.py index e5c4639..96ff9ec 100644 --- a/src/backend/routers/pad_router.py +++ b/src/backend/routers/pad_router.py @@ -10,7 +10,7 @@ pad_router = APIRouter() -@pad_router.post("/") +@pad_router.post("") async def save_canvas( data: Dict[str, Any], user: UserSession = Depends(require_auth), @@ -47,7 +47,7 @@ async def save_canvas( raise HTTPException(status_code=500, detail=f"Failed to save canvas data: {str(e)}") -@pad_router.get("/") +@pad_router.get("") async def get_canvas( user: UserSession = Depends(require_auth), pad_service: PadService = Depends(get_pad_service), diff --git a/src/backend/routers/template_pad_router.py b/src/backend/routers/template_pad_router.py index 77e4928..a787274 100644 --- a/src/backend/routers/template_pad_router.py +++ b/src/backend/routers/template_pad_router.py @@ -9,7 +9,7 @@ template_pad_router = APIRouter() -@template_pad_router.post("/") +@template_pad_router.post("") async def create_template_pad( data: Dict[str, Any], name: str, @@ -29,7 +29,7 @@ async def create_template_pad( raise HTTPException(status_code=400, detail=str(e)) -@template_pad_router.get("/") +@template_pad_router.get("") async def get_all_template_pads( _: bool = Depends(require_admin), template_pad_service: TemplatePadService = Depends(get_template_pad_service) diff --git a/src/backend/routers/user_router.py b/src/backend/routers/user_router.py index 9a76d93..cb2f583 100644 --- a/src/backend/routers/user_router.py +++ b/src/backend/routers/user_router.py @@ -12,7 +12,7 @@ user_router = APIRouter() -@user_router.post("/") +@user_router.post("") async def create_user( user_id: UUID, username: str, @@ -42,7 +42,7 @@ async def create_user( raise HTTPException(status_code=400, detail=str(e)) -@user_router.get("/") +@user_router.get("") async def get_all_users( _: bool = Depends(require_admin), user_service: UserService = Depends(get_user_service) diff --git a/src/frontend/src/api/hooks.ts b/src/frontend/src/api/hooks.ts index 653d3ce..7c8b0d2 100644 --- a/src/frontend/src/api/hooks.ts +++ b/src/frontend/src/api/hooks.ts @@ -97,7 +97,7 @@ export const api = { // Canvas getCanvas: async (): Promise => { try { - const result = await fetchApi('/api/pad/'); + const result = await fetchApi('/api/pad'); return result; } catch (error) { throw error; @@ -106,7 +106,7 @@ export const api = { saveCanvas: async (data: CanvasData): Promise => { try { - const result = await fetchApi('/api/pad/', { + const result = await fetchApi('/api/pad', { method: 'POST', body: JSON.stringify(data), }); From e68d32eb5d5ea29e906132992913dfda42946035 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 20:09:33 +0000 Subject: [PATCH 30/32] refactor: update CORS middleware configuration - Changed CORS middleware to allow all origins by updating the allow_origins parameter to ["*"], enhancing flexibility for cross-origin requests. --- src/backend/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/main.py b/src/backend/main.py index 45b19a9..afd8eb8 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -105,7 +105,7 @@ async def lifespan(_: FastAPI): # CORS middleware setup app.add_middleware( CORSMiddleware, - allow_origins=["https://kc.pad.ws", "*"], + allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], From d9620cc3457d05f5d4755fb72a4602624c22b039 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 20:40:03 +0000 Subject: [PATCH 31/32] feat: add user synchronization with authentication token data - Implemented a new method in UserService to synchronize user data with information from the authentication token, creating or updating the user as necessary. - Updated the user router to utilize this new synchronization method, enhancing user data management and ensuring consistency between the database and authentication token data. --- src/backend/database/service/user_service.py | 58 ++++++++++++++++++++ src/backend/routers/user_router.py | 48 ++++++++-------- 2 files changed, 82 insertions(+), 24 deletions(-) diff --git a/src/backend/database/service/user_service.py b/src/backend/database/service/user_service.py index 1677a7e..6dc9d8d 100644 --- a/src/backend/database/service/user_service.py +++ b/src/backend/database/service/user_service.py @@ -97,3 +97,61 @@ async def update_user(self, user_id: UUID, data: Dict[str, Any]) -> Optional[Dic async def delete_user(self, user_id: UUID) -> bool: """Delete a user""" return await self.repository.delete(user_id) + + async def sync_user_with_token_data(self, user_id: UUID, token_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Synchronize user data in the database with data from the authentication token. + If the user doesn't exist, it will be created. If it exists but has different data, + it will be updated to match the token data. + + Args: + user_id: The user's UUID + token_data: Dictionary containing user data from the authentication token + + Returns: + The user data dictionary or None if operation failed + """ + # Check if user exists + user_data = await self.get_user(user_id) + + # If user doesn't exist, create a new one + if not user_data: + try: + return await self.create_user( + user_id=user_id, + username=token_data.get("username", ""), + email=token_data.get("email", ""), + email_verified=token_data.get("email_verified", False), + name=token_data.get("name"), + given_name=token_data.get("given_name"), + family_name=token_data.get("family_name"), + roles=token_data.get("roles", []) + ) + except ValueError as e: + # Handle case where user might have been created in a race condition + if "already exists" in str(e): + user_data = await self.get_user(user_id) + else: + raise e + + # Check if user data needs to be updated + update_data = {} + fields_to_check = [ + "username", "email", "email_verified", + "name", "given_name", "family_name" + ] + + for field in fields_to_check: + token_value = token_data.get(field) + if token_value is not None and user_data.get(field) != token_value: + update_data[field] = token_value + + # Handle roles separately as they might have a different structure + if "roles" in token_data and user_data.get("roles") != token_data["roles"]: + update_data["roles"] = token_data["roles"] + + # Update user if any field has changed + if update_data: + return await self.update_user(user_id, update_data) + + return user_data diff --git a/src/backend/routers/user_router.py b/src/backend/routers/user_router.py index cb2f583..f4d84fe 100644 --- a/src/backend/routers/user_router.py +++ b/src/backend/routers/user_router.py @@ -57,34 +57,34 @@ async def get_user_info( user: UserSession = Depends(require_auth), user_service: UserService = Depends(get_user_service), ): - """Get the current user's information""" - - user_data = await user.get_user_data(user_service) - - if not user_data: - try: - user = await user_service.create_user( - user_id=user.id, - username=user.username, - email=user.email, - email_verified=user.email_verified, - name=user.name, - given_name=user.given_name, - family_name=user.family_name, - roles=user.roles, - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error creating user: {e}" - ) - + """Get the current user's information and sync with token data""" + + # Create token data dictionary from UserSession properties + token_data = { + "username": user.username, + "email": user.email, + "email_verified": user.email_verified, + "name": user.name, + "given_name": user.given_name, + "family_name": user.family_name, + "roles": user.roles + } + + try: + # Sync user with token data + user_data = await user_service.sync_user_with_token_data(user.id, token_data) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error syncing user data: {e}" + ) + if os.getenv("VITE_PUBLIC_POSTHOG_KEY"): telemetry = user_data.copy() telemetry["$current_url"] = FRONTEND_URL posthog.identify(distinct_id=user_data["id"], properties=telemetry) - - return user + + return user_data @user_router.get("/count") From d773ad6403bd68dafc5779f6a28ffa59d49313b7 Mon Sep 17 00:00:00 2001 From: Alex TYRODE Date: Fri, 2 May 2025 21:18:10 +0000 Subject: [PATCH 32/32] feat: implement database migration system with Alembic - Added Alembic configuration and migration scripts to facilitate database schema changes. - Implemented a run_migrations function to execute migrations during application startup. - Created migration scripts to transfer data from the old public schema to the new pad_ws schema, ensuring data integrity and consistency. - Updated requirements.txt to include Alembic as a dependency for migration management. --- src/backend/database/alembic.ini | 123 +++++++++++ src/backend/database/database.py | 37 ++++ src/backend/database/migrations/env.py | 106 ++++++++++ .../database/migrations/script.py.mako | 28 +++ .../2025_05_02_2055-migrate_canvas_data.py | 192 ++++++++++++++++++ src/backend/main.py | 9 +- src/backend/requirements.txt | 3 +- 7 files changed, 496 insertions(+), 2 deletions(-) create mode 100644 src/backend/database/alembic.ini create mode 100644 src/backend/database/migrations/env.py create mode 100644 src/backend/database/migrations/script.py.mako create mode 100644 src/backend/database/migrations/versions/2025_05_02_2055-migrate_canvas_data.py diff --git a/src/backend/database/alembic.ini b/src/backend/database/alembic.ini new file mode 100644 index 0000000..f36d6dd --- /dev/null +++ b/src/backend/database/alembic.ini @@ -0,0 +1,123 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = migrations + +# Use a more descriptive file template that includes date and time +file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +version_path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# The SQLAlchemy connection URL is set in env.py from environment variables +sqlalchemy.url = postgresql://postgres:postgres@localhost/pad + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/src/backend/database/database.py b/src/backend/database/database.py index 09e695d..40c6d13 100644 --- a/src/backend/database/database.py +++ b/src/backend/database/database.py @@ -3,13 +3,17 @@ """ import os +import asyncio from typing import AsyncGenerator from urllib.parse import quote_plus as urlquote +from pathlib import Path from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker from sqlalchemy.schema import CreateSchema from fastapi import Depends +from alembic.config import Config +from alembic import command from .models import Base, SCHEMA_NAME @@ -35,8 +39,41 @@ engine, class_=AsyncSession, expire_on_commit=False ) +async def run_migrations() -> None: + """Run database migrations using Alembic""" + # Get the path to the alembic.ini file + alembic_ini_path = Path(__file__).parent / "alembic.ini" + + # Create Alembic configuration + alembic_cfg = Config(str(alembic_ini_path)) + + # Set the script_location to the correct path + # This ensures Alembic finds the migrations directory + alembic_cfg.set_main_option('script_location', str(Path(__file__).parent / "migrations")) + + # Define a function to run in a separate thread + def run_upgrade(): + # Import the command module here to avoid import issues + from alembic import command + + # Set attributes that env.py might need + import sys + from pathlib import Path + + # Add the database directory to sys.path + db_dir = Path(__file__).parent + if str(db_dir) not in sys.path: + sys.path.insert(0, str(db_dir)) + + # Run the upgrade command + command.upgrade(alembic_cfg, "head") + + # Run the migrations in a separate thread to avoid blocking the event loop + await asyncio.to_thread(run_upgrade) + async def init_db() -> None: """Initialize the database with required tables""" + # Create schema and tables async with engine.begin() as conn: await conn.execute(CreateSchema(SCHEMA_NAME, if_not_exists=True)) await conn.run_sync(Base.metadata.create_all) diff --git a/src/backend/database/migrations/env.py b/src/backend/database/migrations/env.py new file mode 100644 index 0000000..d753507 --- /dev/null +++ b/src/backend/database/migrations/env.py @@ -0,0 +1,106 @@ +from logging.config import fileConfig +import os +import sys +from pathlib import Path + +from sqlalchemy import engine_from_config +from sqlalchemy import pool +from sqlalchemy.engine import URL + +from alembic import context +from dotenv import load_dotenv + +# Add the parent directory to sys.path +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + +# Load environment variables from .env file +load_dotenv() + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# Import the Base metadata from the models +# We need to handle imports differently to avoid module not found errors +import importlib.util +import os + +# Get the absolute path to the models module +models_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "models", "__init__.py") + +# Load the module dynamically +spec = importlib.util.spec_from_file_location("models", models_path) +models = importlib.util.module_from_spec(spec) +spec.loader.exec_module(models) + +# Get Base and SCHEMA_NAME from the loaded module +Base = models.Base +SCHEMA_NAME = models.SCHEMA_NAME +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +# Get database connection details from environment variables +DB_USER = os.getenv('POSTGRES_USER', 'postgres') +DB_PASSWORD = os.getenv('POSTGRES_PASSWORD', 'postgres') +DB_NAME = os.getenv('POSTGRES_DB', 'pad') +DB_HOST = os.getenv('POSTGRES_HOST', 'localhost') +DB_PORT = os.getenv('POSTGRES_PORT', '5432') + +# Override sqlalchemy.url in alembic.ini +config.set_main_option('sqlalchemy.url', f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}") + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode.""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + include_schemas=True, + version_table_schema=SCHEMA_NAME, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + include_schemas=True, + version_table_schema=SCHEMA_NAME + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/src/backend/database/migrations/script.py.mako b/src/backend/database/migrations/script.py.mako new file mode 100644 index 0000000..480b130 --- /dev/null +++ b/src/backend/database/migrations/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/src/backend/database/migrations/versions/2025_05_02_2055-migrate_canvas_data.py b/src/backend/database/migrations/versions/2025_05_02_2055-migrate_canvas_data.py new file mode 100644 index 0000000..f6e9568 --- /dev/null +++ b/src/backend/database/migrations/versions/2025_05_02_2055-migrate_canvas_data.py @@ -0,0 +1,192 @@ +"""Migrate canvas_data and canvas_backups to new schema + +Revision ID: migrate_canvas_data +Revises: +Create Date: 2025-05-02 20:55:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID, JSONB +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm import Session +import uuid +from datetime import datetime + +# revision identifiers, used by Alembic. +revision = 'migrate_canvas_data' +down_revision = None +branch_labels = None +depends_on = None + +# Import the schema name from the models using dynamic import +import importlib.util +import os + +# Get the absolute path to the base_model module +base_model_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models", "base_model.py") + +# Load the module dynamically +spec = importlib.util.spec_from_file_location("base_model", base_model_path) +base_model = importlib.util.module_from_spec(spec) +spec.loader.exec_module(base_model) + +# Get SCHEMA_NAME from the loaded module +SCHEMA_NAME = base_model.SCHEMA_NAME + +def upgrade() -> None: + """Migrate data from old tables to new schema""" + + # Create a connection to execute raw SQL + connection = op.get_bind() + + # Define tables for direct SQL operations + metadata = sa.MetaData() + + # Define the old tables in the public schema + canvas_data = sa.Table( + 'canvas_data', + metadata, + sa.Column('user_id', UUID(as_uuid=True), primary_key=True), + sa.Column('data', JSONB), + schema='public' + ) + + canvas_backups = sa.Table( + 'canvas_backups', + metadata, + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('user_id', UUID(as_uuid=True)), + sa.Column('canvas_data', JSONB), + sa.Column('timestamp', sa.DateTime), + schema='public' + ) + + # Define the new tables in the pad_ws schema with all required columns + users = sa.Table( + 'users', + metadata, + sa.Column('id', UUID(as_uuid=True), primary_key=True), + sa.Column('username', sa.String(254)), + sa.Column('email', sa.String(254)), + sa.Column('email_verified', sa.Boolean), + sa.Column('name', sa.String(254)), + sa.Column('given_name', sa.String(254)), + sa.Column('family_name', sa.String(254)), + sa.Column('roles', JSONB), + schema=SCHEMA_NAME + ) + + pads = sa.Table( + 'pads', + metadata, + sa.Column('id', UUID(as_uuid=True), primary_key=True), + sa.Column('owner_id', UUID(as_uuid=True)), + sa.Column('display_name', sa.String(100)), + sa.Column('data', JSONB), + schema=SCHEMA_NAME + ) + + backups = sa.Table( + 'backups', + metadata, + sa.Column('id', UUID(as_uuid=True), primary_key=True), + sa.Column('source_id', UUID(as_uuid=True)), + sa.Column('data', JSONB), + sa.Column('created_at', sa.DateTime), + schema=SCHEMA_NAME + ) + + # Create a session for ORM operations + session = Session(connection) + + try: + # Step 1: Get all canvas_data records + canvas_data_records = session.execute(sa.select(canvas_data)).fetchall() + + # Dictionary to store user_id -> pad_id mapping for later use with backups + user_pad_mapping = {} + + # Step 2: For each canvas_data record, create a new pad + for record in canvas_data_records: + user_id = record.user_id + + # Check if the user exists in the new schema + user_exists = session.execute( + sa.select(users).where(users.c.id == user_id) + ).fetchone() + + if not user_exists: + print(f"User {user_id} not found in new schema, creating with placeholder data") + # Create a new user with placeholder data + # The real data will be updated when the user accesses the /me route + session.execute( + users.insert().values( + id=user_id, + username=f"migrated_user_{user_id}", + email=f"migrated_{user_id}@example.com", + email_verified=False, + name="Migrated User", + given_name="Migrated", + family_name="User", + roles=[], + ) + ) + + # Generate a new UUID for the pad + pad_id = uuid.uuid4() + + # Store the mapping for later use + user_pad_mapping[user_id] = pad_id + + # Insert the pad record + session.execute( + pads.insert().values( + id=pad_id, + owner_id=user_id, + display_name="Untitled", + data=record.data, + ) + ) + + # Step 3: Get all canvas_backups records + canvas_backup_records = session.execute(sa.select(canvas_backups)).fetchall() + + # Step 4: For each canvas_backup record, create a new backup + for record in canvas_backup_records: + user_id = record.user_id + + # Skip if we don't have a pad for this user + if user_id not in user_pad_mapping: + print(f"Warning: No pad found for user {user_id}, skipping backup") + continue + + pad_id = user_pad_mapping[user_id] + + # Insert the backup record + session.execute( + backups.insert().values( + id=uuid.uuid4(), + source_id=pad_id, + data=record.canvas_data, # Note: using canvas_data field from the record + created_at=record.timestamp, + ) + ) + + # Commit the transaction + session.commit() + + print(f"Migration complete: {len(canvas_data_records)} pads and {len(canvas_backup_records)} backups migrated") + + except Exception as e: + session.rollback() + print(f"Error during migration: {e}") + raise + finally: + session.close() + + +def downgrade() -> None: + """Downgrade is not supported for this migration""" + print("Downgrade is not supported for this data migration") diff --git a/src/backend/main.py b/src/backend/main.py index afd8eb8..563ef39 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -18,7 +18,7 @@ from routers.pad_router import pad_router from routers.template_pad_router import template_pad_router from database.service import TemplatePadService -from database.database import async_session +from database.database import async_session, run_migrations # Initialize PostHog if API key is available if POSTHOG_API_KEY: @@ -80,6 +80,13 @@ async def lifespan(_: FastAPI): await init_db() print("Database connection established successfully") + # Run database migrations + try: + await run_migrations() + print("Database migrations completed successfully") + except Exception as e: + print(f"Warning: Failed to run migrations: {str(e)}") + # Check Redis connection try: redis_client.ping() diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index 358a51d..302d8ff 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -11,4 +11,5 @@ posthog redis psycopg2-binary python-multipart -cryptography # Required for JWT key handling \ No newline at end of file +cryptography # Required for JWT key handling +alembic \ No newline at end of file