diff --git a/Makefile b/Makefile
index 0398c32..519efd2 100644
--- a/Makefile
+++ b/Makefile
@@ -7,4 +7,4 @@ run:
 		--queue-size 100
 
 install:
-	pip install -e .
\ No newline at end of file
+	pip install -e .
diff --git a/README.md b/README.md
index a2243f9..df8b01b 100644
--- a/README.md
+++ b/README.md
@@ -108,7 +108,9 @@ The server supports six types of MLX models:
 
 ### Flux-Series Image Models
 
-The server supports multiple Flux and Qwen model configurations for advanced image generation and editing:
+> **⚠️ Note:** Image generation and editing capabilities require installation of `mflux`: `pip install "mlx-openai-server[image-generation]"` or `pip install git+https://github.com/cubist38/mflux.git`
+
+The server supports multiple Flux model configurations for advanced image generation and editing:
 
 #### Image Generation Models
 - **`flux-schnell`** - Fast generation with 4 default steps, no guidance (best for quick iterations)
@@ -202,6 +204,9 @@ Follow these steps to set up the MLX-powered server:
    git clone https://github.com/cubist38/mlx-openai-server.git
    cd mlx-openai-server
    pip install -e .
+
+   # Optional: For image generation/editing support (quoted so zsh does not expand the brackets)
+   pip install -e ".[image-generation]"
    ```
 
 ### Using Conda (Recommended)
@@ -236,6 +241,9 @@ For better environment management and to avoid architecture issues, we recommend
    git clone https://github.com/cubist38/mlx-openai-server.git
    cd mlx-openai-server
    pip install -e .
+
+   # Optional: For image generation/editing support (quoted so zsh does not expand the brackets)
+   pip install -e ".[image-generation]"
    ```
 
 ### Optional Dependencies
@@ -253,15 +261,44 @@ pip install mlx-openai-server
 ```
 
 - All core API endpoints and functionality
 
 #### Image Generation & Editing Support
-The server includes support for image generation and editing capabilities:
+For image generation and editing capabilities, install with the image-generation extra:
+
+```bash
+# Install with image generation support
+pip install "mlx-openai-server[image-generation]"
+```
+
+Or install manually:
+```bash
+# First install the base server
+pip install mlx-openai-server
+
+# Then install mflux for image generation/editing support
+pip install git+https://github.com/cubist38/mflux.git
+```
 
-**Additional features:**
+**Additional features with mflux:**
 - Image generation models (`--model-type image-generation`)
 - Image editing models (`--model-type image-edit`)
 - MLX Flux-series model support
 - Qwen Image model support
 - LoRA adapter support for fine-tuned generation and editing
 
+#### Enhanced Caching Support
+For enhanced caching and performance when working with complex ML models and objects, install with the enhanced-caching extra:
+
+```bash
+# Install with enhanced caching support
+pip install "mlx-openai-server[enhanced-caching]"
+```
+
+This enables better serialization and caching of objects from:
+- spaCy (NLP processing)
+- regex (regular expressions)
+- tiktoken (tokenization)
+- torch (PyTorch tensors and models)
+- transformers (Hugging Face models)
+
 #### Whisper Models Support
 For whisper models to work properly, you need to install ffmpeg:
diff --git a/app/__init__.py b/app/__init__.py
index 0b90d88..6375959 100644
--- a/app/__init__.py
+++ b/app/__init__.py
@@ -1,7 +1 @@
-import os
-from .version import __version__
-
-# Suppress transformers warnings
-os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
-
-__all__ = ["__version__"]
\ No newline at end of file
+"""MLX OpenAI Server package."""
diff --git a/app/cli.py b/app/cli.py
index 5b740cb..e440b1e 100644
--- a/app/cli.py
+++ b/app/cli.py
@@ -5,6 +5,8 @@
 the ASGI server.
""" +from __future__ import annotations + import asyncio import sys @@ -17,7 +19,7 @@ from .version import __version__ -class UpperChoice(click.Choice): +class UpperChoice(click.Choice[str]): """Case-insensitive choice type that returns uppercase values. This small convenience subclass normalizes user input in a @@ -26,7 +28,7 @@ class UpperChoice(click.Choice): where the internal representation is uppercased. """ - def normalize_choice(self, choice, ctx): + def normalize_choice(self, choice: str | None, ctx: click.Context | None) -> str | None: # type: ignore[override] """Return the canonical uppercase choice or raise BadParameter. Parameters @@ -75,20 +77,19 @@ def normalize_choice(self, choice, ctx): 🚀 Version: %(version)s """, ) -def cli(): +def cli() -> None: """Top-level Click command group for the MLX server CLI. Subcommands (such as ``launch``) are registered on this group and invoked by the console entry point. """ - pass -@cli.command() +@cli.command(help="Start the MLX OpenAI Server with the supplied flags") @click.option( "--model-path", required=True, - help="Path to the model (required for lm, multimodal, embeddings, image-generation, image-edit, whisper model types). With `image-generation` or `image-edit` model types, it should be the local path to the model.", + help="Path to the model (required for lm, multimodal, embeddings, image-generation, image-edit, whisper model types). Can be a local path or Hugging Face repository ID (e.g., 'blackforestlabs/FLUX.1-dev').", ) @click.option( "--model-type", @@ -186,35 +187,77 @@ def cli(): help="Path to a custom chat template file. Only works with language models (lm) and multimodal models.", ) def launch( - model_path, - model_type, - context_length, - port, - host, - max_concurrency, - queue_timeout, - queue_size, - quantize, - config_name, - lora_paths, - lora_scales, - disable_auto_resize, - log_file, - no_log_file, - log_level, - enable_auto_tool_choice, - tool_call_parser, - reasoning_parser, - trust_remote_code, - chat_template_file, + model_path: str, + model_type: str, + context_length: int, + port: int, + host: str, + max_concurrency: int, + queue_timeout: int, + queue_size: int, + quantize: int, + config_name: str | None, + lora_paths: str | None, + lora_scales: str | None, + disable_auto_resize: bool, + log_file: str | None, + no_log_file: bool, + log_level: str, + enable_auto_tool_choice: bool, + tool_call_parser: str | None, + reasoning_parser: str | None, + trust_remote_code: bool, + chat_template_file: str | None, ) -> None: """Start the FastAPI/Uvicorn server with the supplied flags. The command builds a server configuration object using ``MLXServerConfig`` and then calls the async ``start`` routine which handles the event loop and server lifecycle. - """ + Parameters + ---------- + model_path : str + Path to the model (required for lm, multimodal, embeddings, image-generation, image-edit, whisper model types). + model_type : str + Type of model to run (lm, multimodal, image-generation, image-edit, embeddings, whisper). + context_length : int + Context length for language models. + port : int + Port to run the server on. + host : str + Host to run the server on. + max_concurrency : int + Maximum number of concurrent requests. + queue_timeout : int + Request timeout in seconds. + queue_size : int + Maximum queue size for pending requests. + quantize : int + Quantization level for the model. + config_name : str or None + Config name of the model. + lora_paths : str or None + Path to the LoRA file(s). 
+    lora_scales : str or None
+        Scale factor for the LoRA file(s).
+    disable_auto_resize : bool
+        Disable automatic model resizing.
+    log_file : str or None
+        Path to log file.
+    no_log_file : bool
+        Disable file logging entirely.
+    log_level : str
+        Set the logging level.
+    enable_auto_tool_choice : bool
+        Enable automatic tool choice.
+    tool_call_parser : str or None
+        Specify tool call parser to use.
+    reasoning_parser : str or None
+        Specify reasoning parser to use.
+    trust_remote_code : bool
+        Enable trust_remote_code when loading models.
+    chat_template_file : str or None
+        Path to a custom chat template file. Only works with language
+        models (lm) and multimodal models.
+    """
     args = MLXServerConfig(
         model_path=model_path,
         model_type=model_type,
diff --git a/app/config.py b/app/config.py
index 576e3ea..7eb039c 100644
--- a/app/config.py
+++ b/app/config.py
@@ -47,7 +47,7 @@ class MLXServerConfig:
     lora_paths_str: str | None = None
     lora_scales_str: str | None = None
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Normalize certain CLI fields after instantiation.
 
         - Convert comma-separated ``lora_paths`` and ``lora_scales`` into
@@ -55,7 +55,6 @@ def __post_init__(self):
         - Apply small model-type-specific defaults for ``config_name`` and
           emit warnings when values appear inconsistent.
         """
-
         # Process comma-separated LoRA paths and scales into lists (or None)
         if self.lora_paths_str:
             self.lora_paths = [p.strip() for p in self.lora_paths_str.split(",") if p.strip()]
@@ -74,11 +73,9 @@ def __post_init__(self):
         # image-edit model types. If missing for those types, set defaults.
         if self.config_name and self.model_type not in ["image-generation", "image-edit"]:
             logger.warning(
-                "Config name parameter '%s' provided but model type is '%s'. "
+                f"Config name parameter '{self.config_name}' provided but model type is '{self.model_type}'. "
                 "Config name is only used with image-generation "
-                "and image-edit models.",
-                self.config_name,
-                self.model_type,
+                "and image-edit models."
             )
         elif self.model_type == "image-generation" and not self.config_name:
             logger.warning(
diff --git a/app/main.py b/app/main.py
index 7f05d79..278d3c0 100644
--- a/app/main.py
+++ b/app/main.py
@@ -27,13 +27,19 @@
 from .version import __version__
 
 
-def print_startup_banner(config_args):
-    """Log a compact startup banner describing the selected config.
+def print_startup_banner(config_args: MLXServerConfig) -> None:
+    """
+    Log a compact startup banner describing the selected config.
 
     The function emits human-friendly log messages that summarize the
     runtime configuration (model path/type, host/port, concurrency,
     LoRA settings, and logging options). Intended for the user-facing
     startup output only.
+
+    Parameters
+    ----------
+    config_args : MLXServerConfig
+        Configuration object containing runtime settings to display.
     """
     logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
     logger.info(f"✨ MLX Server v{__version__} Starting ✨")
@@ -78,12 +84,18 @@ def print_startup_banner(config_args):
 
 
 async def start(config: MLXServerConfig) -> None:
-    """Run the ASGI server using the provided configuration.
+    """
+    Run the ASGI server using the provided configuration.
 
     This coroutine wires the configuration into the server setup
     routine, logs progress, and starts the Uvicorn server. It handles
     KeyboardInterrupt and logs any startup failures before exiting the
     process with a non-zero code.
+
+    Parameters
+    ----------
+    config : MLXServerConfig
+        Configuration object for server setup.
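+
+    Raises
+    ------
+    SystemExit
+        Raised via ``sys.exit(1)`` when server startup fails.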
""" try: # Display startup information @@ -98,19 +110,20 @@ async def start(config: MLXServerConfig) -> None: except KeyboardInterrupt: logger.info("Server shutdown requested by user. Exiting...") except Exception as e: - logger.error(f"Server startup failed: {str(e)}") + logger.error(f"Server startup failed. {type(e).__name__}: {e}") sys.exit(1) -def main(): - """Normalize process args and dispatch to the Click CLI. +def main() -> None: + """ + Normalize process args and dispatch to the Click CLI. This helper gathers command-line arguments, inserts the "launch" subcommand when a subcommand is omitted for backwards compatibility, and delegates execution to :func:`app.cli.cli` through ``cli.main``. """ - from .cli import cli + from .cli import cli # noqa: PLC0415 args = [str(x) for x in sys.argv[1:]] # Keep backwards compatibility: Add 'launch' subcommand if none is provided diff --git a/app/server.py b/app/server.py index 0655aa2..779173e 100644 --- a/app/server.py +++ b/app/server.py @@ -12,9 +12,14 @@ ready to run. """ +from __future__ import annotations + import gc +import sys import time -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from http import HTTPStatus import mlx.core as mx import uvicorn @@ -22,6 +27,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from loguru import logger +from starlette.responses import Response from .api.endpoints import router from .config import MLXServerConfig @@ -59,7 +65,7 @@ def configure_logging( # Add console handler logger.add( - lambda msg: print(msg), + sys.stdout, level=log_level, format="{time:YYYY-MM-DD HH:mm:ss} | " "{level: <8} | " @@ -97,11 +103,12 @@ def get_model_identifier(config_args: MLXServerConfig) -> str: str Value that identifies the model for handler initialization. """ - return config_args.model_path -def create_lifespan(config_args: MLXServerConfig): +def create_lifespan( + config_args: MLXServerConfig, +) -> Callable[[FastAPI], AbstractAsyncContextManager[None]]: """Create an async FastAPI lifespan context manager bound to configuration. The returned context manager performs the following actions during @@ -117,17 +124,21 @@ def create_lifespan(config_args: MLXServerConfig): During shutdown the lifespan will attempt to call the handler's ``cleanup`` method and perform final memory cleanup. - Args: - config_args: Object containing CLI configuration attributes used - to initialize handlers (e.g., model_type, model_path, - max_concurrency, queue_timeout, etc.). + Parameters + ---------- + config_args : MLXServerConfig + Object containing CLI configuration attributes used + to initialize handlers (e.g., model_type, model_path, + max_concurrency, queue_timeout, etc.). - Returns: - Callable: An asynccontextmanager usable as FastAPI ``lifespan``. + Returns + ------- + Callable + An asynccontextmanager usable as FastAPI ``lifespan``. """ @asynccontextmanager - async def lifespan(app: FastAPI) -> None: + async def lifespan(app: FastAPI) -> AsyncIterator[None]: """FastAPI lifespan callable that initializes MLX handlers. On startup this function selects and initializes the correct @@ -142,6 +153,13 @@ async def lifespan(app: FastAPI) -> None: FastAPI application instance being started. 
""" try: + handler: ( + MLXVLMHandler + | MLXFluxHandler + | MLXEmbeddingsHandler + | MLXWhisperHandler + | MLXLMHandler + ) model_identifier = get_model_identifier(config_args) if config_args.model_type == "image-generation": logger.info(f"Initializing MLX handler with model name: {model_identifier}") @@ -217,7 +235,7 @@ async def lifespan(app: FastAPI) -> None: app.state.handler = handler except Exception as e: - logger.error(f"Failed to initialize MLX handler: {str(e)}") + logger.exception(f"Failed to initialize MLX handler. {type(e).__name__}: {e}") raise # Initial memory cleanup @@ -235,7 +253,7 @@ async def lifespan(app: FastAPI) -> None: await app.state.handler.cleanup() logger.info("Resources cleaned up successfully") except Exception as e: - logger.error(f"Error during shutdown: {str(e)}") + logger.exception(f"Error during shutdown. {type(e).__name__}: {e}") # Final memory cleanup mx.clear_cache() @@ -244,31 +262,29 @@ async def lifespan(app: FastAPI) -> None: return lifespan -# App instance will be created during setup with the correct lifespan -app = None +# FastAPI app instance is created during setup with the correct lifespan def setup_server(config_args: MLXServerConfig) -> uvicorn.Config: - global app - """Create and configure the FastAPI app and return a Uvicorn config. This function sets up logging, constructs the FastAPI application with a configured lifespan, registers routes and middleware, and returns a :class:`uvicorn.Config` ready to be used to run the server. - Note: This function mutates the module-level ``app`` global variable. - - Args: - args: Configuration object usually produced by the CLI. Expected - to have attributes like ``host``, ``port``, ``log_level``, - and logging-related fields. + Parameters + ---------- + config_args : MLXServerConfig + Configuration object usually produced by the CLI. Expected + to have attributes like ``host``, ``port``, ``log_level``, + and logging-related fields. - Returns: - uvicorn.Config: A configuration object that can be passed to + Returns + ------- + uvicorn.Config + A configuration object that can be passed to ``uvicorn.Server(config).run()`` to start the application. """ - # Configure logging based on CLI parameters configure_logging( log_file=config_args.log_file, @@ -296,12 +312,26 @@ def setup_server(config_args: MLXServerConfig) -> uvicorn.Config: ) @app.middleware("http") - async def add_process_time_header(request: Request, call_next): + async def add_process_time_header( + request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: """Middleware to add processing time header and run cleanup. Measures request processing time, appends an ``X-Process-Time`` header, and increments a simple request counter used to trigger periodic memory cleanup for long-running processes. + + Parameters + ---------- + request : Request + The incoming HTTP request. + call_next : Callable[[Request], Awaitable[Response]] + The next middleware or endpoint in the chain. + + Returns + ------- + Response + The HTTP response with added headers. """ start_time = time.time() response = await call_next(request) @@ -325,16 +355,28 @@ async def add_process_time_header(request: Request, call_next): return response @app.exception_handler(Exception) - async def global_exception_handler(request: Request, exc: Exception): + async def global_exception_handler(_request: Request, exc: Exception) -> JSONResponse: """Global exception handler that logs and returns a 500 payload. 
Logs the exception (with traceback) and returns a generic JSON response with a 500 status code so internal errors do not leak implementation details to clients. + + Parameters + ---------- + _request : Request + The incoming HTTP request (unused). + exc : Exception + The exception that was raised. + + Returns + ------- + JSONResponse + A JSON response with error details. """ - logger.error(f"Global exception handler caught: {str(exc)}", exc_info=True) + logger.exception(f"Global exception handler caught. {type(exc).__name__}: {exc!s}") return JSONResponse( - status_code=500, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, content={"error": {"message": "Internal server error", "type": "internal_error"}}, ) diff --git a/app/version.py b/app/version.py index e590a12..518a5d6 100644 --- a/app/version.py +++ b/app/version.py @@ -1,6 +1,7 @@ +"""Version information for the mlx-openai-server package.""" # Version number format: MAJOR.MINOR.PATCH # Major: Major version number (increments when breaking changes are introduced) # Minor: Minor version number (increments when new features are added) # Patch: Patch version number (increments when bug fixes are made) -__version__ = "1.4.1" \ No newline at end of file +__version__ = "1.4.1" diff --git a/configure_mlx.sh b/configure_mlx.sh index f1cfe6e..bb6f968 100644 --- a/configure_mlx.sh +++ b/configure_mlx.sh @@ -4,12 +4,12 @@ TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024)) # Calculate 80% and TOTAL_MEM_GB-5GB in MB -EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100)) -MINUS_5GB=$((($TOTAL_MEM_MB - 5120))) +EIGHTY_PERCENT=$((TOTAL_MEM_MB * 80 / 100)) +MINUS_5GB=$((TOTAL_MEM_MB - 5120)) # Calculate 70% and TOTAL_MEM_GB-8GB in MB -SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100)) -MINUS_8GB=$((($TOTAL_MEM_MB - 8192))) +SEVENTY_PERCENT=$((TOTAL_MEM_MB * 70 / 100)) +MINUS_8GB=$((TOTAL_MEM_MB - 8192)) # Set WIRED_LIMIT_MB to higher value if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then @@ -40,4 +40,4 @@ else sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \ sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB -fi \ No newline at end of file +fi