diff --git a/app/core/__init__.py b/app/core/__init__.py index a3170b40..ea153908 100644 --- a/app/core/__init__.py +++ b/app/core/__init__.py @@ -1,3 +1,5 @@ +"""Core processing utilities for MLX OpenAI server.""" + from .audio_processor import AudioProcessor from .base_processor import BaseProcessor from .image_processor import ImageProcessor diff --git a/app/core/audio_processor.py b/app/core/audio_processor.py index 814fb292..59f4c238 100644 --- a/app/core/audio_processor.py +++ b/app/core/audio_processor.py @@ -1,98 +1,207 @@ -import os -import gc +"""Audio processing utilities for MLX OpenAI server.""" + +from __future__ import annotations + import asyncio -from typing import List +import gc +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + from .base_processor import BaseProcessor class AudioProcessor(BaseProcessor): """Audio processor for handling audio files with caching and validation.""" - - def __init__(self, max_workers: int = 4, cache_size: int = 1000): + + def __init__(self, max_workers: int = 4, cache_size: int = 1000) -> None: + """ + Initialize the AudioProcessor. + + Parameters + ---------- + max_workers : int, optional + Maximum number of worker threads for processing, by default 4. + cache_size : int, optional + Maximum number of cached files to keep, by default 1000. + """ super().__init__(max_workers, cache_size) # Supported audio formats - self._supported_formats = {'.mp3', '.wav'} + self._supported_formats = {".mp3", ".wav", ".m4a", ".ogg", ".flac", ".aac"} - def _get_media_format(self, media_url: str, data: bytes = None) -> str: - """Determine audio format from URL or data.""" + def _get_media_format(self, media_url: str, _data: bytes | None = None) -> str: + """ + Determine audio format from URL or data. + + Parameters + ---------- + media_url : str + The URL or data URL of the audio file. + _data : bytes or None, optional + Audio data bytes, not used in this implementation. + + Returns + ------- + str + The audio format (e.g., 'mp3', 'wav'). + """ if media_url.startswith("data:"): # Extract format from data URL mime_type = media_url.split(";")[0].split(":")[1] if "mp3" in mime_type or "mpeg" in mime_type: return "mp3" - elif "wav" in mime_type: + if "wav" in mime_type: return "wav" - elif "m4a" in mime_type or "mp4" in mime_type: + if "m4a" in mime_type or "mp4" in mime_type: return "m4a" - elif "ogg" in mime_type: + if "ogg" in mime_type: return "ogg" - elif "flac" in mime_type: + if "flac" in mime_type: return "flac" - elif "aac" in mime_type: + if "aac" in mime_type: return "aac" else: # Extract format from file extension - ext = os.path.splitext(media_url.lower())[1] + parsed = urlparse(media_url) + if parsed.scheme: + # It's a URL, get the path part + path = parsed.path + else: + path = media_url + ext = Path(path.lower()).suffix if ext in self._supported_formats: return ext[1:] # Remove the dot - + # Default to mp3 if format cannot be determined return "mp3" def _validate_media_data(self, data: bytes) -> bool: - """Basic validation of audio data.""" + """ + Validate basic audio data. + + Parameters + ---------- + data : bytes + The audio data to validate. + + Returns + ------- + bool + True if the data appears to be valid audio, False otherwise. 
+ """ if len(data) < 100: # Too small to be a valid audio file return False - + # Check for common audio file signatures audio_signatures = [ - b'ID3', # MP3 with ID3 tag - b'\xff\xfb', # MP3 frame header - b'\xff\xf3', # MP3 frame header - b'\xff\xf2', # MP3 frame header - b'RIFF', # WAV/AVI - b'OggS', # OGG - b'fLaC', # FLAC - b'\x00\x00\x00\x20ftypM4A', # M4A + b"ID3", # MP3 with ID3 tag + b"\xff\xfb", # MP3 frame header + b"\xff\xf3", # MP3 frame header + b"\xff\xf2", # MP3 frame header + b"RIFF", # WAV/AVI + b"OggS", # OGG + b"fLaC", # FLAC + b"\x00\x00\x00\x20ftypM4A", # M4A ] - + for sig in audio_signatures: if data.startswith(sig): return True - + # Check for WAV format (RIFF header might be at different position) - if b'WAVE' in data[:50]: + if b"WAVE" in data[:50]: return True - + return True # Allow unknown formats to pass through def _get_timeout(self) -> int: - """Get timeout for HTTP requests.""" + """ + Get timeout for HTTP requests. + + Returns + ------- + int + Timeout in seconds for audio file downloads. + """ return 60 # Longer timeout for audio files def _get_max_file_size(self) -> int: - """Get maximum file size in bytes.""" + """ + Get maximum file size in bytes. + + Returns + ------- + int + Maximum allowed file size for audio files in bytes. + """ return 500 * 1024 * 1024 # 500 MB limit for audio - def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str: - """Process audio data and save to cached path.""" - with open(cached_path, 'wb') as f: + def _process_media_data(self, data: bytes, cached_path: str, **_kwargs: Any) -> str: + """ + Process audio data and save to cached path. + + Parameters + ---------- + data : bytes + The audio data to process. + cached_path : str + Path where the processed audio should be saved. + **_kwargs : Any + Additional keyword arguments (unused). + + Returns + ------- + str + The path to the cached audio file. + """ + with Path(cached_path).open("wb") as f: f.write(data) self._cleanup_old_files() return cached_path def _get_media_type_name(self) -> str: - """Get media type name for logging.""" + """ + Get media type name for logging. + + Returns + ------- + str + The string 'audio' for logging purposes. + """ return "audio" async def process_audio_url(self, audio_url: str) -> str: - """Process a single audio URL and return path to cached file.""" + """ + Process a single audio URL and return path to cached file. + + Parameters + ---------- + audio_url : str + The URL of the audio file to process. + + Returns + ------- + str + Path to the cached audio file. + """ return await self._process_single_media(audio_url) - async def process_audio_urls(self, audio_urls: List[str]) -> List[str]: - """Process multiple audio URLs and return paths to cached files.""" + async def process_audio_urls(self, audio_urls: list[str]) -> list[str | BaseException]: + """ + Process multiple audio URLs and return a list containing either file path strings or BaseException instances for failed items. + + Parameters + ---------- + audio_urls : list[str] + List of audio URLs to process. + + Returns + ------- + list[str | BaseException] + List where each element is either a path to a cached audio file (str) or a BaseException for failed processing. 
+ """ tasks = [self.process_audio_url(url) for url in audio_urls] results = await asyncio.gather(*tasks, return_exceptions=True) # Force garbage collection after batch processing gc.collect() - return results \ No newline at end of file + return results diff --git a/app/core/base_processor.py b/app/core/base_processor.py index 52833fcf..f16354c1 100644 --- a/app/core/base_processor.py +++ b/app/core/base_processor.py @@ -1,30 +1,38 @@ +"""Base processor classes for media processing with caching and validation.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod import base64 +from concurrent.futures import ThreadPoolExecutor +import gc import hashlib -import os +from pathlib import Path import tempfile -import aiohttp import time -import gc +from types import TracebackType +from typing import Any, Self + +import aiofiles +import aiohttp from loguru import logger -from typing import Dict, Optional, Any -from concurrent.futures import ThreadPoolExecutor -from abc import ABC, abstractmethod class BaseProcessor(ABC): """Base class for media processors with common caching and session management.""" - - def __init__(self, max_workers: int = 4, cache_size: int = 1000): + + def __init__(self, max_workers: int = 4, cache_size: int = 1000) -> None: # Use tempfile for macOS-efficient temporary file handling self.temp_dir = tempfile.TemporaryDirectory() - self._session: Optional[aiohttp.ClientSession] = None + self._session: aiohttp.ClientSession | None = None self.executor = ThreadPoolExecutor(max_workers=max_workers) self._cache_size = cache_size self._last_cleanup = time.time() self._cleanup_interval = 3600 # 1 hour # Replace lru_cache with manual cache for better control - self._hash_cache: Dict[str, str] = {} - self._cache_access_times: Dict[str, float] = {} + self._hash_cache: dict[str, str] = {} + self._cache_access_times: dict[str, float] = {} + self._cleaned: bool = False def _get_media_hash(self, media_url: str) -> str: """Get hash for media URL with manual caching that can be cleared.""" @@ -32,176 +40,205 @@ def _get_media_hash(self, media_url: str) -> str: if media_url in self._hash_cache: self._cache_access_times[media_url] = time.time() return self._hash_cache[media_url] - + # Generate hash if media_url.startswith("data:"): _, encoded = media_url.split(",", 1) data = base64.b64decode(encoded) else: - data = media_url.encode('utf-8') - + data = media_url.encode("utf-8") + hash_value = hashlib.md5(data).hexdigest() - + # Add to cache with size management if len(self._hash_cache) >= self._cache_size: self._evict_oldest_cache_entries() - + self._hash_cache[media_url] = hash_value self._cache_access_times[media_url] = time.time() return hash_value - def _evict_oldest_cache_entries(self): + def _evict_oldest_cache_entries(self) -> None: """Remove oldest 20% of cache entries to make room.""" if not self._cache_access_times: return - + # Sort by access time and remove oldest 20% sorted_items = sorted(self._cache_access_times.items(), key=lambda x: x[1]) to_remove = len(sorted_items) // 5 # Remove 20% - + for url, _ in sorted_items[:to_remove]: self._hash_cache.pop(url, None) self._cache_access_times.pop(url, None) - + # Force garbage collection after cache eviction gc.collect() @abstractmethod - def _get_media_format(self, media_url: str, data: bytes = None) -> str: + def _get_media_format(self, media_url: str, data: bytes | None = None) -> str: """Determine media format from URL or data. 
Must be implemented by subclasses.""" - pass @abstractmethod def _validate_media_data(self, data: bytes) -> bool: """Validate media data. Must be implemented by subclasses.""" - pass @abstractmethod def _get_timeout(self) -> int: """Get timeout for HTTP requests. Must be implemented by subclasses.""" - pass @abstractmethod def _get_max_file_size(self) -> int: """Get maximum file size in bytes. Must be implemented by subclasses.""" - pass @abstractmethod - def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> Dict[str, Any]: - """Process media data and save to cached path. Must be implemented by subclasses.""" - pass + def _process_media_data(self, data: bytes, cached_path: str, **kwargs: Any) -> str: + """Process media data and save to cached path and return the cached file path.""" @abstractmethod def _get_media_type_name(self) -> str: """Get media type name for logging. Must be implemented by subclasses.""" - pass async def _get_session(self) -> aiohttp.ClientSession: if self._session is None or self._session.closed: self._session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=self._get_timeout()), - headers={"User-Agent": "mlx-server-OAI-compat/1.0"} + headers={"User-Agent": "mlx-server-OAI-compat/1.0"}, ) return self._session - def _cleanup_old_files(self): + def _cleanup_old_files(self) -> None: current_time = time.time() if current_time - self._last_cleanup > self._cleanup_interval: try: - for file in os.listdir(self.temp_dir.name): - file_path = os.path.join(self.temp_dir.name, file) - if os.path.getmtime(file_path) < current_time - self._cleanup_interval: - os.remove(file_path) + temp_dir_path = Path(self.temp_dir.name) + for file_path in temp_dir_path.iterdir(): + if file_path.stat().st_mtime < current_time - self._cleanup_interval: + file_path.unlink() self._last_cleanup = current_time # Also clean up cache periodically if len(self._hash_cache) > self._cache_size * 0.8: self._evict_oldest_cache_entries() gc.collect() # Force garbage collection after cleanup except Exception as e: - logger.warning(f"Failed to clean up old {self._get_media_type_name()} files: {str(e)}") + logger.warning( + f"Failed to clean up old {self._get_media_type_name()} files. 
{type(e).__name__}: {e}" + ) - async def _process_single_media(self, media_url: str, **kwargs) -> str: + async def _process_single_media(self, media_url: str, **kwargs: Any) -> str: try: media_hash = self._get_media_hash(media_url) media_format = self._get_media_format(media_url) - cached_path = os.path.join(self.temp_dir.name, f"{media_hash}.{media_format}") + cached_path = str(Path(self.temp_dir.name) / f"{media_hash}.{media_format}") - if os.path.exists(cached_path): + if Path(cached_path).exists(): logger.debug(f"Using cached {self._get_media_type_name()}: {cached_path}") return cached_path - if os.path.exists(media_url): + if Path(media_url).exists(): + # Check file size before opening + file_size = Path(media_url).stat().st_size + if file_size > self._get_max_file_size(): + raise ValueError( + f"Local {self._get_media_type_name()} file exceeds size limit: {file_size} > {self._get_max_file_size()}" + ) # Copy local file to cache - with open(media_url, 'rb') as f: - data = f.read() - + async with aiofiles.open(media_url, "rb") as f: + data = await f.read() + + # Validate size after reading (in case file changed) + if len(data) > self._get_max_file_size(): + raise ValueError( + f"Read {self._get_media_type_name()} data exceeds size limit: {len(data)} > {self._get_max_file_size()}" + ) + if not self._validate_media_data(data): raise ValueError(f"Invalid {self._get_media_type_name()} file format") - + return self._process_media_data(data, cached_path, **kwargs) - elif media_url.startswith("data:"): + if media_url.startswith("data:"): _, encoded = media_url.split(",", 1) estimated_size = len(encoded) * 3 / 4 if estimated_size > self._get_max_file_size(): - raise ValueError(f"Base64-encoded {self._get_media_type_name()} exceeds size limit") + raise ValueError( + f"Base64-encoded {self._get_media_type_name()} exceeds size limit" + ) data = base64.b64decode(encoded) - + if not self._validate_media_data(data): raise ValueError(f"Invalid {self._get_media_type_name()} file format") - + + return self._process_media_data(data, cached_path, **kwargs) + session = await self._get_session() + async with session.get(media_url) as response: + response.raise_for_status() + # Check Content-Length if available + content_length = response.headers.get("Content-Length") + if content_length: + try: + size = int(content_length) + if size > self._get_max_file_size(): + raise ValueError( + f"HTTP {self._get_media_type_name()} Content-Length exceeds size limit: {size} > {self._get_max_file_size()}" + ) + except ValueError: + logger.warning(f"Invalid Content-Length header: {content_length}") + data = await response.read() + + # Validate size after reading + if len(data) > self._get_max_file_size(): + raise ValueError( + f"Downloaded {self._get_media_type_name()} data exceeds size limit: {len(data)} > {self._get_max_file_size()}" + ) + + if not self._validate_media_data(data): + raise ValueError(f"Invalid {self._get_media_type_name()} file format") + return self._process_media_data(data, cached_path, **kwargs) - else: - session = await self._get_session() - async with session.get(media_url) as response: - response.raise_for_status() - data = await response.read() - - if not self._validate_media_data(data): - raise ValueError(f"Invalid {self._get_media_type_name()} file format") - - return self._process_media_data(data, cached_path, **kwargs) except Exception as e: - logger.error(f"Failed to process {self._get_media_type_name()}: {str(e)}") - raise ValueError(f"Failed to process {self._get_media_type_name()}: 
{str(e)}") + logger.error(f"Failed to process {self._get_media_type_name()} {type(e).__name__}: {e}") + raise ValueError(f"Failed to process {self._get_media_type_name()}: {e}") from e finally: gc.collect() - def clear_cache(self): + def clear_cache(self) -> None: """Manually clear the hash cache to free memory.""" self._hash_cache.clear() self._cache_access_times.clear() gc.collect() - async def cleanup(self): - if hasattr(self, '_cleaned') and self._cleaned: + async def cleanup(self) -> None: + """Clean up resources and caches.""" + if hasattr(self, "_cleaned") and self._cleaned: return self._cleaned = True try: # Clear caches before cleanup self.clear_cache() - + if self._session and not self._session.closed: await self._session.close() except Exception as e: - logger.warning(f"Exception closing aiohttp session: {str(e)}") + logger.warning(f"Exception closing aiohttp session. {type(e).__name__}: {e}") try: self.executor.shutdown(wait=True) except Exception as e: - logger.warning(f"Exception shutting down executor: {str(e)}") + logger.warning(f"Exception shutting down executor. {type(e).__name__}: {e}") try: self.temp_dir.cleanup() except Exception as e: - logger.warning(f"Exception cleaning up temp directory: {str(e)}") + logger.warning(f"Exception cleaning up temp directory. {type(e).__name__}: {e}") - async def __aenter__(self): + async def __aenter__(self) -> Self: + """Enter async context manager.""" return self - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + """Exit async context manager and cleanup.""" await self.cleanup() - - def __del__(self): - # Async cleanup cannot be reliably performed in __del__ - # Please use 'async with Processor()' or call 'await cleanup()' explicitly. 
- pass \ No newline at end of file diff --git a/app/core/image_processor.py b/app/core/image_processor.py index c38ccc65..f83bbd5c 100644 --- a/app/core/image_processor.py +++ b/app/core/image_processor.py @@ -1,49 +1,56 @@ -import gc +"""Image processing utilities for MLX OpenAI server.""" + +from __future__ import annotations + import asyncio -from PIL import Image -from loguru import logger +import contextlib +import gc from io import BytesIO -from typing import List +from typing import Any + +from loguru import logger +from PIL import Image + from .base_processor import BaseProcessor class ImageProcessor(BaseProcessor): """Image processor for handling image files with caching, validation, and processing.""" - - def __init__(self, max_workers: int = 4, cache_size: int = 1000): + + def __init__(self, max_workers: int = 4, cache_size: int = 1000) -> None: super().__init__(max_workers, cache_size) Image.MAX_IMAGE_PIXELS = 100000000 # Limit to 100 megapixels - def _get_media_format(self, media_url: str, data: bytes = None) -> str: + def _get_media_format(self, _media_url: str, _data: bytes | None = None) -> str: """Determine image format from URL or data.""" - # For images, we always save as JPEG for consistency - return "jpg" + # For images, we always save as PNG for consistency + return "png" def _validate_media_data(self, data: bytes) -> bool: - """Basic validation of image data.""" + """Validate basic image data.""" if len(data) < 100: # Too small to be a valid image file return False - + # Check for common image file signatures image_signatures = [ - b'\xff\xd8\xff', # JPEG - b'\x89PNG\r\n\x1a\n', # PNG - b'GIF87a', # GIF87a - b'GIF89a', # GIF89a - b'BM', # BMP - b'II*\x00', # TIFF (little endian) - b'MM\x00*', # TIFF (big endian) - b'RIFF', # WebP (part of RIFF) + b"\xff\xd8\xff", # JPEG + b"\x89PNG\r\n\x1a\n", # PNG + b"GIF87a", # GIF87a + b"GIF89a", # GIF89a + b"BM", # BMP + b"II*\x00", # TIFF (little endian) + b"MM\x00*", # TIFF (big endian) + b"RIFF", # WebP (part of RIFF) ] - + for sig in image_signatures: if data.startswith(sig): return True - + # Additional check for WebP - if data.startswith(b'RIFF') and b'WEBP' in data[:20]: + if data.startswith(b"RIFF") and b"WEBP" in data[:20]: return True - + return False def _get_timeout(self) -> int: @@ -58,7 +65,9 @@ def _get_media_type_name(self) -> str: """Get media type name for logging.""" return "image" - def _resize_image_keep_aspect_ratio(self, image: Image.Image, max_size: int = 448) -> Image.Image: + def _resize_image_keep_aspect_ratio( + self, image: Image.Image, max_size: int = 448 + ) -> Image.Image: width, height = image.size if width <= max_size and height <= max_size: return image @@ -75,46 +84,46 @@ def _resize_image_keep_aspect_ratio(self, image: Image.Image, max_size: int = 44 return image def _prepare_image_for_saving(self, image: Image.Image) -> Image.Image: - if image.mode in ('RGBA', 'LA'): - background = Image.new('RGB', image.size, (255, 255, 255)) - if image.mode == 'RGBA': + if image.mode in ("RGBA", "LA"): + background = Image.new("RGB", image.size, (255, 255, 255)) + if image.mode == "RGBA": background.paste(image, mask=image.split()[3]) else: background.paste(image, mask=image.split()[1]) return background - elif image.mode != 'RGB': - return image.convert('RGB') + if image.mode != "RGB": + return image.convert("RGB") return image - def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str: + def _process_media_data(self, data: bytes, cached_path: str, **kwargs: Any) -> str: """Process 
image data and save to cached path.""" image = None resize = kwargs.get("resize", True) try: - with Image.open(BytesIO(data), mode='r') as image: + with Image.open(BytesIO(data), mode="r") as image: if resize: image = self._resize_image_keep_aspect_ratio(image) image = self._prepare_image_for_saving(image) - image.save(cached_path, 'PNG', quality=100, optimize=True) - + image.save(cached_path, "PNG", quality=100, optimize=True) + self._cleanup_old_files() return cached_path finally: # Ensure image object is closed to free memory if image: - try: + with contextlib.suppress(Exception): image.close() - except: - pass async def process_image_url(self, image_url: str, resize: bool = True) -> str: """Process a single image URL and return path to cached file.""" return await self._process_single_media(image_url, resize=resize) - async def process_image_urls(self, image_urls: List[str], resize: bool = True) -> List[str]: - """Process multiple image URLs and return paths to cached files.""" + async def process_image_urls( + self, image_urls: list[str], resize: bool = True + ) -> list[str | BaseException]: + """Process multiple image URLs and return paths to cached files (exceptions may be BaseException).""" tasks = [self.process_image_url(url, resize=resize) for url in image_urls] results = await asyncio.gather(*tasks, return_exceptions=True) # Force garbage collection after batch processing gc.collect() - return results \ No newline at end of file + return results diff --git a/app/core/model_registry.py b/app/core/model_registry.py index b47436bb..212c0251 100644 --- a/app/core/model_registry.py +++ b/app/core/model_registry.py @@ -1,8 +1,10 @@ """Model registry for managing multiple model handlers.""" +from __future__ import annotations + import asyncio import time -from typing import Any, Dict, List, Optional +from typing import Any from loguru import logger @@ -17,16 +19,17 @@ class ModelRegistry: In Phase 1, this wraps the existing single-model flow. Future phases will extend this to support multi-model loading and hot-swapping. - Attributes: + Attributes + ---------- _handlers: Dict mapping model_id to handler instance _metadata: Dict mapping model_id to ModelMetadata _lock: Async lock for thread-safe operations """ - def __init__(self): + def __init__(self) -> None: """Initialize empty model registry.""" - self._handlers: Dict[str, Any] = {} - self._metadata: Dict[str, ModelMetadata] = {} + self._handlers: dict[str, Any] = {} + self._metadata: dict[str, ModelMetadata] = {} self._lock = asyncio.Lock() logger.info("Model registry initialized") @@ -35,7 +38,7 @@ async def register_model( model_id: str, handler: Any, model_type: str, - context_length: Optional[int] = None, + context_length: int | None = None, ) -> None: """ Register a model handler with metadata. @@ -46,7 +49,8 @@ async def register_model( model_type: Type of model (lm, multimodal, embeddings, etc.) 
context_length: Maximum context length (if applicable) - Raises: + Raises + ------ ValueError: If model_id already registered """ async with self._lock: @@ -66,8 +70,7 @@ async def register_model( self._metadata[model_id] = metadata logger.info( - f"Registered model: {model_id} (type={model_type}, " - f"context_length={context_length})" + f"Registered model: {model_id} (type={model_type}, context_length={context_length})" ) def get_handler(self, model_id: str) -> Any: @@ -77,21 +80,24 @@ def get_handler(self, model_id: str) -> Any: Args: model_id: Model identifier - Returns: + Returns + ------- Handler instance - Raises: + Raises + ------ KeyError: If model_id not found """ if model_id not in self._handlers: raise KeyError(f"Model '{model_id}' not found in registry") return self._handlers[model_id] - def list_models(self) -> List[Dict[str, Any]]: + def list_models(self) -> list[dict[str, Any]]: """ List all registered models with metadata. - Returns: + Returns + ------- List of model metadata dicts in OpenAI API format """ return [ @@ -111,10 +117,12 @@ def get_metadata(self, model_id: str) -> ModelMetadata: Args: model_id: Model identifier - Returns: + Returns + ------- ModelMetadata instance - Raises: + Raises + ------ KeyError: If model_id not found """ if model_id not in self._metadata: @@ -131,7 +139,8 @@ async def unregister_model(self, model_id: str) -> None: Args: model_id: Model identifier - Raises: + Raises + ------ KeyError: If model_id not found """ async with self._lock: @@ -151,7 +160,8 @@ def has_model(self, model_id: str) -> bool: Args: model_id: Model identifier - Returns: + Returns + ------- True if model is registered, False otherwise """ return model_id in self._handlers @@ -160,7 +170,8 @@ def get_model_count(self) -> int: """ Get count of registered models. - Returns: + Returns + ------- Number of registered models """ return len(self._handlers) diff --git a/app/core/queue.py b/app/core/queue.py index bbc8bf2d..6e7e9f0e 100644 --- a/app/core/queue.py +++ b/app/core/queue.py @@ -1,43 +1,52 @@ +"""Asynchronous request queue with concurrency control.""" + +from __future__ import annotations + import asyncio -import time -from typing import Any, Dict, Optional, Callable, Awaitable, TypeVar, Generic +from collections.abc import Awaitable, Callable +import contextlib import gc +import time +from typing import Any, Generic, TypeVar + from loguru import logger -T = TypeVar('T') +T = TypeVar("T") + class RequestItem(Generic[T]): - """ - Represents a single request in the queue. - """ - def __init__(self, request_id: str, data: Any): + """Represents a single request in the queue.""" + + def __init__(self, request_id: str, data: T) -> None: self.request_id = request_id self.data = data self.created_at = time.time() - self.future = asyncio.Future() - + self.future: asyncio.Future[T] = asyncio.Future() + def set_result(self, result: T) -> None: """Set the result for this request.""" if not self.future.done(): self.future.set_result(result) - + def set_exception(self, exc: Exception) -> None: """Set an exception for this request.""" if not self.future.done(): self.future.set_exception(exc) - + async def get_result(self) -> T: """Wait for and return the result of this request.""" return await self.future + class RequestQueue: - """ - A simple asynchronous request queue with configurable concurrency. 
- """ - def __init__(self, max_concurrency: int = 2, timeout: float = 300.0, queue_size: int = 100): + """A simple asynchronous request queue with configurable concurrency.""" + + def __init__( + self, max_concurrency: int = 2, timeout: float = 300.0, queue_size: int = 100 + ) -> None: """ Initialize the request queue. - + Args: max_concurrency (int): Maximum number of concurrent requests to process. timeout (float): Timeout in seconds for request processing. @@ -47,40 +56,48 @@ def __init__(self, max_concurrency: int = 2, timeout: float = 300.0, queue_size: self.timeout = timeout self.queue_size = queue_size self.semaphore = asyncio.Semaphore(max_concurrency) - self.queue = asyncio.Queue(maxsize=queue_size) - self.active_requests: Dict[str, RequestItem] = {} - self._worker_task = None + self.queue: asyncio.Queue[RequestItem[Any]] = asyncio.Queue(maxsize=queue_size) + self.active_requests: dict[str, RequestItem[Any]] = {} + self._worker_task: asyncio.Task[None] | None = None self._running = False - - async def start(self, processor: Callable[[Any], Awaitable[Any]]): + self._tasks: set[asyncio.Task[None]] = set() + + async def start(self, processor: Callable[[Any], Awaitable[Any]]) -> None: """ Start the queue worker. - + Args: processor: Async function that processes queue items. """ if self._running: return - + self._running = True self._worker_task = asyncio.create_task(self._worker_loop(processor)) logger.info(f"Started request queue with max concurrency: {self.max_concurrency}") - - async def stop(self): + + async def stop(self) -> None: """Stop the queue worker.""" if not self._running: return - + self._running = False - + # Cancel the worker task if self._worker_task and not self._worker_task.done(): self._worker_task.cancel() - try: + with contextlib.suppress(asyncio.CancelledError): await self._worker_task - except asyncio.CancelledError: - pass - + + # Cancel all in-flight tasks + tasks_snapshot = list(self._tasks) + for task in tasks_snapshot: + if not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather(*tasks_snapshot, return_exceptions=True) + self._tasks.clear() + # Cancel all pending requests pending_requests = list(self.active_requests.values()) for request in pending_requests: @@ -88,28 +105,28 @@ async def stop(self): request.future.cancel() # Clean up request data try: - if hasattr(request, 'data'): + if hasattr(request, "data"): del request.data - except: - pass - + except Exception as e: + logger.opt(exception=e).debug("Failed to remove request.data") + self.active_requests.clear() - + # Clear the queue while not self.queue.empty(): try: self.queue.get_nowait() except asyncio.QueueEmpty: break - + # Force garbage collection after cleanup gc.collect() logger.info("Stopped request queue") - - async def _worker_loop(self, processor: Callable[[Any], Awaitable[Any]]): + + async def _worker_loop(self, processor: Callable[[Any], Awaitable[Any]]) -> None: """ - Main worker loop that processes queue items. - + Process queue items in main worker loop. + Args: processor: Async function that processes queue items. 
""" @@ -117,19 +134,23 @@ async def _worker_loop(self, processor: Callable[[Any], Awaitable[Any]]): try: # Get the next item from the queue request = await self.queue.get() - + # Process the request with concurrency control - asyncio.create_task(self._process_request(request, processor)) - + task = asyncio.create_task(self._process_request(request, processor)) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + except asyncio.CancelledError: break except Exception as e: - logger.error(f"Error in worker loop: {str(e)}") - - async def _process_request(self, request: RequestItem, processor: Callable[[Any], Awaitable[Any]]): + logger.error(f"Error in worker loop. {type(e).__name__}: {e}") + + async def _process_request( + self, request: RequestItem[Any], processor: Callable[[Any], Awaitable[Any]] + ) -> None: """ Process a single request with timeout and error handling. - + Args: request: The request to process. processor: Async function that processes the request. @@ -139,97 +160,110 @@ async def _process_request(self, request: RequestItem, processor: Callable[[Any] try: # Process with timeout processing_start = time.time() - result = await asyncio.wait_for( - processor(request.data), - timeout=self.timeout - ) + result = await asyncio.wait_for(processor(request.data), timeout=self.timeout) processing_time = time.time() - processing_start - + # Set the result request.set_result(result) logger.info(f"Request {request.request_id} processed in {processing_time:.2f}s") - - except asyncio.TimeoutError: - request.set_exception(TimeoutError(f"Request processing timed out after {self.timeout}s")) + + except TimeoutError: + request.set_exception( + TimeoutError(f"Request processing timed out after {self.timeout}s") + ) logger.warning(f"Request {request.request_id} timed out after {self.timeout}s") - + except asyncio.CancelledError as e: + # Propagate cancellation but ensure the future is not left hanging + if not request.future.done(): + request.future.set_exception(e) + logger.info(f"Request {request.request_id} was cancelled") + raise except Exception as e: request.set_exception(e) - logger.error(f"Error processing request {request.request_id}: {str(e)}") - + logger.error( + f"Error processing request {request.request_id}. {type(e).__name__}: {e}" + ) + finally: # Always remove from active requests, even if an error occurred removed_request = self.active_requests.pop(request.request_id, None) if removed_request: # Clean up the request object try: - if hasattr(removed_request, 'data'): + if hasattr(removed_request, "data"): del removed_request.data - except: - pass + except Exception as e: + logger.opt(exception=e).debug("Failed to remove request.data") # Force garbage collection periodically to prevent memory buildup if len(self.active_requests) % 10 == 0: # Every 10 requests gc.collect() - - async def enqueue(self, request_id: str, data: Any) -> RequestItem: + + async def enqueue(self, request_id: str, data: Any) -> RequestItem[Any]: """ Add a request to the queue. - + Args: request_id: Unique ID for the request. data: The request data to process. - - Returns: + + Returns + ------- RequestItem: The queued request item. - - Raises: + + Raises + ------ asyncio.QueueFull: If the queue is full. 
""" if not self._running: raise RuntimeError("Queue is not running") - + # Create request item request = RequestItem(request_id, data) - + # Add to active requests and queue self.active_requests[request_id] = request - + try: # This will raise QueueFull if the queue is full await asyncio.wait_for( self.queue.put(request), - timeout=1.0 # Short timeout for queue put + timeout=1.0, # Short timeout for queue put ) + except TimeoutError: + self.active_requests.pop(request_id, None) + raise asyncio.QueueFull( + "Request queue is full and timed out waiting for space" + ) from None + else: queue_time = time.time() - request.created_at logger.info(f"Request {request_id} queued (wait: {queue_time:.2f}s)") return request - - except asyncio.TimeoutError: - self.active_requests.pop(request_id, None) - raise asyncio.QueueFull("Request queue is full and timed out waiting for space") - + async def submit(self, request_id: str, data: Any) -> Any: """ Submit a request and wait for its result. - + Args: request_id: Unique ID for the request. data: The request data to process. - - Returns: + + Returns + ------- The result of processing the request. - - Raises: + + Raises + ------ Various exceptions that may occur during processing. """ request = await self.enqueue(request_id, data) return await request.get_result() - - def get_queue_stats(self) -> Dict[str, Any]: + + def get_queue_stats(self) -> dict[str, Any]: """ Get queue statistics. - - Returns: + + Returns + ------- Dict with queue statistics. """ return { @@ -237,10 +271,10 @@ def get_queue_stats(self) -> Dict[str, Any]: "queue_size": self.queue.qsize(), "max_queue_size": self.queue_size, "active_requests": len(self.active_requests), - "max_concurrency": self.max_concurrency + "max_concurrency": self.max_concurrency, } # Alias for the async stop method to maintain consistency in cleanup interfaces - async def stop_async(self): + async def stop_async(self) -> None: """Alias for stop - stops the queue worker asynchronously.""" - await self.stop() \ No newline at end of file + await self.stop() diff --git a/app/core/video_processor.py b/app/core/video_processor.py index 92d7141d..5a1120a0 100644 --- a/app/core/video_processor.py +++ b/app/core/video_processor.py @@ -1,87 +1,101 @@ -import os -import gc +"""Video processing utilities for MLX OpenAI server.""" + +from __future__ import annotations + import asyncio +import gc +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + from loguru import logger -from typing import List + from .base_processor import BaseProcessor class VideoProcessor(BaseProcessor): """Video processor for handling video files with caching, validation, and processing.""" - - def __init__(self, max_workers: int = 4, cache_size: int = 1000): + + def __init__(self, max_workers: int = 4, cache_size: int = 1000) -> None: super().__init__(max_workers, cache_size) # Supported video formats - self._supported_formats = {'.mp4', '.avi', '.mov'} + self._supported_formats = {".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"} - def _get_media_format(self, media_url: str, data: bytes = None) -> str: + def _get_media_format(self, media_url: str, _data: bytes | None = None) -> str: """Determine video format from URL or data.""" if media_url.startswith("data:"): # Extract format from data URL mime_type = media_url.split(";")[0].split(":")[1] if "mp4" in mime_type: return "mp4" - elif "quicktime" in mime_type or "mov" in mime_type: + if "quicktime" in mime_type or "mov" in mime_type: return "mov" - elif "x-msvideo" in 
mime_type or "avi" in mime_type: + if "x-msvideo" in mime_type or "avi" in mime_type: return "avi" + if "webm" in mime_type: + return "webm" + if "x-matroska" in mime_type or "matroska" in mime_type: + return "mkv" + if "x-flv" in mime_type or "flv" in mime_type: + return "flv" else: # Extract format from file extension - ext = os.path.splitext(media_url.lower())[1] + parsed = urlparse(media_url) + if parsed.scheme: + # It's a URL, get the path part + path = parsed.path + else: + path = media_url + ext = Path(path.lower()).suffix if ext in self._supported_formats: return ext[1:] # Remove the dot - + # Default to mp4 if format cannot be determined return "mp4" def _validate_media_data(self, data: bytes) -> bool: - """Basic validation of video data.""" + """Validate basic video data.""" if len(data) < 100: # Too small to be a valid video file return False - + # Check for common video file signatures video_signatures = [ # MP4/M4V/MOV (ISO Base Media File Format) - (b'\x00\x00\x00\x14ftypisom', 0), # MP4 - (b'\x00\x00\x00\x18ftyp', 0), # MP4/MOV - (b'\x00\x00\x00\x1cftyp', 0), # MP4/MOV - (b'\x00\x00\x00\x20ftyp', 0), # MP4/MOV - (b'ftyp', 4), # MP4/MOV (ftyp at offset 4) - + (b"\x00\x00\x00\x14ftypisom", 0), # MP4 + (b"\x00\x00\x00\x18ftyp", 0), # MP4/MOV + (b"\x00\x00\x00\x1cftyp", 0), # MP4/MOV + (b"\x00\x00\x00\x20ftyp", 0), # MP4/MOV + (b"ftyp", 4), # MP4/MOV (ftyp at offset 4) # AVI - (b'RIFF', 0), # AVI (also check for 'AVI ' at offset 8) - + (b"RIFF", 0), # AVI (also check for 'AVI ' at offset 8) # WebM/MKV (Matroska) - (b'\x1a\x45\xdf\xa3', 0), # Matroska/WebM - + (b"\x1a\x45\xdf\xa3", 0), # Matroska/WebM # FLV - (b'FLV\x01', 0), # Flash Video - + (b"FLV\x01", 0), # Flash Video # MPEG - (b'\x00\x00\x01\xba', 0), # MPEG PS - (b'\x00\x00\x01\xb3', 0), # MPEG PS - + (b"\x00\x00\x01\xba", 0), # MPEG PS + (b"\x00\x00\x01\xb3", 0), # MPEG PS # QuickTime - (b'moov', 0), # QuickTime - (b'mdat', 0), # QuickTime + (b"moov", 0), # QuickTime + (b"mdat", 0), # QuickTime ] - + for sig, offset in video_signatures: - if len(data) > offset + len(sig): - if data[offset:offset+len(sig)] == sig: + if len(data) >= offset + len(sig): + if data[offset : offset + len(sig)] == sig: # Additional validation for AVI - if sig == b'RIFF' and len(data) > 12: - if data[8:12] == b'AVI ': + if sig == b"RIFF" and len(data) > 12: + if data[8:12] == b"AVI ": return True - elif sig == b'RIFF': + elif sig == b"RIFF": continue # Not AVI, might be WAV else: return True - + # Check for ftyp box anywhere in first 32 bytes (MP4/MOV) - if b'ftyp' in data[:32]: + if b"ftyp" in data[:32]: return True - + # Allow unknown formats to pass through for flexibility return True @@ -93,18 +107,18 @@ def _get_max_file_size(self) -> int: """Get maximum file size in bytes.""" return 1024 * 1024 * 1024 # 1 GB limit for videos - def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str: + def _process_media_data(self, data: bytes, cached_path: str, **_kwargs: Any) -> str: """Process video data and save to cached path.""" try: - with open(cached_path, 'wb') as f: + with Path(cached_path).open("wb") as f: f.write(data) - + except Exception as e: + logger.error(f"Failed to save video data. 
{type(e).__name__}: {e}") + raise + else: logger.info(f"Saved video to {cached_path} ({len(data)} bytes)") self._cleanup_old_files() return cached_path - except Exception as e: - logger.error(f"Failed to save video data: {str(e)}") - raise def _get_media_type_name(self) -> str: """Get media type name for logging.""" @@ -113,29 +127,31 @@ def _get_media_type_name(self) -> str: async def process_video_url(self, video_url: str) -> str: """ Process a single video URL and return path to cached file. - + Supports: - HTTP/HTTPS URLs (downloads video) - Local file paths (copies to cache) - Data URLs (base64 encoded videos) - + Args: video_url: URL, file path, or data URL of the video - - Returns: + + Returns + ------- Path to the cached video file in temp directory """ return await self._process_single_media(video_url) - async def process_video_urls(self, video_urls: List[str]) -> List[str]: + async def process_video_urls(self, video_urls: list[str]) -> list[str | BaseException]: """ Process multiple video URLs and return paths to cached files. - + Args: video_urls: List of URLs, file paths, or data URLs of videos - - Returns: - List of paths to cached video files + + Returns + ------- + List of cached file paths or BaseException instances for failed items """ tasks = [self.process_video_url(url) for url in video_urls] results = await asyncio.gather(*tasks, return_exceptions=True)