diff --git a/src/mcp_server_uyuni/server.py b/src/mcp_server_uyuni/server.py index ec3d2ce..24f18f4 100644 --- a/src/mcp_server_uyuni/server.py +++ b/src/mcp_server_uyuni/server.py @@ -20,6 +20,7 @@ from datetime import datetime, timezone from pydantic import BaseModel +from fastmcp.server.middleware import Middleware, MiddlewareContext from fastmcp import FastMCP, Context from mcp import LoggingLevel, ServerSession, types from mcp_server_uyuni.logging_config import get_logger, Transport @@ -30,9 +31,7 @@ class ActivationKeySchema(BaseModel): activation_key: str REQUIRED_VARS = [ - "UYUNI_SERVER", - "UYUNI_USER", - "UYUNI_PASS", + "UYUNI_SERVER" ] missing_vars = [key for key in REQUIRED_VARS if key not in os.environ] @@ -64,6 +63,30 @@ class ActivationKeySchema(BaseModel): # Sentinel object to indicate an expected timeout for long-running actions TIMEOUT_HAPPENED = object() +class AuthTokenMiddleware(Middleware): + async def on_call_tool(self, ctx: MiddlewareContext, call_next): + """ + Extracts the JWT token from the Authorization header (if present) + and injects it into the context state for other tools to use. + """ + fastmcp_ctx = ctx.fastmcp_context + auth_header = fastmcp_ctx.request_context.request.headers['authorization'] + token = None + if auth_header: + # Expecting "Authorization: Bearer " + parts = auth_header.split() + if len(parts) == 2 and parts[0] == "Bearer": + token = parts[1] + logger.debug("Successfully extracted token from header.") + else: + logger.warning(f"Malformed Authorization header received: {auth_header}") + else: + logger.debug("No Authorization header found in the request.") + + fastmcp_ctx.set_state('token', token) + result = await call_next(ctx) + return result + def write_tool(*decorator_args, **decorator_kwargs): """ A decorator that registers a function as an MCP tool only if write @@ -74,11 +97,11 @@ def decorator(func): if UYUNI_MCP_WRITE_TOOLS_ENABLED: # 3a. If enabled, it applies the @mcp.tool() decorator, registering the function. return mcp.tool(*decorator_args, **decorator_kwargs)(func) - + # 3b. If disabled, it does nothing and just returns the original, # un-decorated function. It is never registered. return func - + # 1. The factory returns the decorator. return decorator @@ -87,6 +110,7 @@ async def _call_uyuni_api( method: str, api_path: str, error_context: str, + token: Optional[str] = None, params: Dict[str, Any] = None, json_body: Dict[str, Any] = None, perform_login: bool = True, @@ -108,9 +132,15 @@ async def _call_uyuni_api( return error_msg if perform_login: - login_data = {"login": UYUNI_USER, "password": UYUNI_PASS} try: - login_response = await client.post(UYUNI_SERVER + '/rhn/manager/api/login', json=login_data) + if token: + login_response = await client.get( + UYUNI_SERVER + '/rhn/manager/api/auth/oidcLogin', + headers={"Authorization": f"Bearer {token}"}) + elif UYUNI_USER and UYUNI_PASS: + login_response = await client.post( + UYUNI_SERVER + '/rhn/manager/api/auth/login', + json={"login": UYUNI_USER, "password": UYUNI_PASS}) login_response.raise_for_status() except httpx.HTTPStatusError as e: logger.error(f"HTTP error during login for {error_context}: {e.request.url} - {e.response.status_code} - {e.response.text}") @@ -188,9 +218,9 @@ async def get_list_of_active_systems(ctx: Context) -> List[Dict[str, Any]]: logger.info(log_string) await ctx.info(log_string) - return await _get_list_of_active_systems() + return await _get_list_of_active_systems(ctx.get_state('token')) -async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]: +async def _get_list_of_active_systems(token: str) -> List[Dict[str, Union[str, int]]]: async with httpx.AsyncClient(verify=UYUNI_MCP_SSL_VERIFY) as client: systems_data_result = await _call_uyuni_api( @@ -198,6 +228,7 @@ async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]: method="GET", api_path="/rhn/manager/api/system/listSystems", error_context="fetching active systems", + token=token, default_on_error=[] ) @@ -213,7 +244,7 @@ async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]: return filtered_systems -async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str]: +async def _resolve_system_id(system_identifier: Union[str, int], token: str) -> Optional[str]: """ Resolves a system identifier, which can be a name or an ID, to a numeric system ID string. @@ -241,6 +272,7 @@ async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str api_path="/rhn/manager/api/system/getId", params={'name': system_name}, error_context=f"resolving system ID for name '{system_name}'", + token=token, default_on_error=[] # Return an empty list on failure ) @@ -283,10 +315,10 @@ async def get_cpu_of_a_system(system_identifier: Union[str, int], ctx: Context) log_string = f"Getting CPU information of system with id {system_identifier}" logger.info(log_string) await ctx.info(log_string) - return await _get_cpu_of_a_system(system_identifier) + return await _get_cpu_of_a_system(system_identifier, ctx.get_state('token')) -async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str, Any]: - system_id = await _resolve_system_id(system_identifier) +async def _get_cpu_of_a_system(system_identifier: Union[str, int], token: str) -> Dict[str, Any]: + system_id = await _resolve_system_id(system_identifier, token) if not system_id: return {} # Helper function already logged the reason for failure. @@ -297,6 +329,7 @@ async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str, api_path="/rhn/manager/api/system/getCpu", params={'sid': system_id}, error_context=f"fetching CPU data for system {system_identifier}", + token=token, default_on_error={} ) @@ -332,7 +365,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]: await ctx.info(log_string) all_systems_cpu_data = [] - active_systems = await _get_list_of_active_systems() # Calls your existing tool + active_systems = await _get_list_of_active_systems(ctx.get_state('token')) if not active_systems: print("Warning: No active systems found or failed to retrieve system list.") @@ -347,7 +380,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]: continue print(f"Fetching CPU info for system: {system_name} (ID: {system_id})") - cpu_info = await _get_cpu_of_a_system(str(system_id)) # Calls your other existing tool + cpu_info = await _get_cpu_of_a_system(str(system_id), ctx.get_state('token')) all_systems_cpu_data.append({ 'system_name': system_name, @@ -429,7 +462,8 @@ async def check_system_updates(system_identifier: Union[str, int], ctx: Context) return await _check_system_updates(system_identifier, ctx) async def _check_system_updates(system_identifier: Union[str, int], ctx: Context) -> Dict[str, Any]: - system_id = await _resolve_system_id(system_identifier) + token = ctx.get_state('token') + system_id = await _resolve_system_id(system_identifier, token) default_error_response = { 'system_identifier': system_identifier, 'has_pending_updates': False, @@ -449,6 +483,7 @@ async def _check_system_updates(system_identifier: Union[str, int], ctx: Context api_path="/rhn/manager/api/system/getRelevantErrata", params={'sid': system_id}, error_context=f"checking updates for system {system_identifier}", + token=token, default_on_error=None # Distinguish API error from empty list ) @@ -458,6 +493,7 @@ async def _check_system_updates(system_identifier: Union[str, int], ctx: Context api_path="/rhn/manager/api/system/getUnscheduledErrata", params={'sid': str(system_id)}, error_context=f"checking unscheduled errata for system ID {system_id}", + token=token, default_on_error=[] # Return empty list on failure ) @@ -549,7 +585,7 @@ async def check_all_systems_for_updates(ctx: Context) -> List[Dict[str, Any]]: await ctx.info(log_string) systems_with_updates = [] - active_systems = await _get_list_of_active_systems() # Get the list of all systems + active_systems = await get_list_of_active_systems(ctx) # Get the list of all systems if not active_systems: print("Warning: No active systems found or failed to retrieve system list.") @@ -614,6 +650,8 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str, if not is_confirmed: return f"CONFIRMATION REQUIRED: This will apply pending updates to the system {system_identifier}. Do you confirm?" + token = ctx.get_state('token') + # 1. Use check_system_updates to get relevant errata update_info = await _check_system_updates(system_identifier, ctx) @@ -634,7 +672,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str, print(f"Could not extract any valid errata IDs for system {system_identifier} from the update information: {errata_list}") return "" - system_id = await _resolve_system_id(system_identifier) + system_id = await _resolve_system_id(system_identifier, token) if not system_id: return "" # Helper function already logged the reason for failure. @@ -649,6 +687,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str, api_path="/rhn/manager/api/system/scheduleApplyErrata", json_body=payload, error_context=f"scheduling errata application for system {system_identifier}", + token=token, default_on_error=None # Helper will return None on error ) @@ -693,7 +732,8 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err return f"Invalid errata ID '{errata_id}'. The ID must be an integer." - system_id = await _resolve_system_id(system_identifier) + token = ctx.get_state('token') + system_id = await _resolve_system_id(system_identifier, token) if not system_id: return "" # Helper function already logged the reason for failure. @@ -711,6 +751,7 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err api_path="/rhn/manager/api/system/scheduleApplyErrata", json_body=payload, error_context=f"scheduling specific update (errata ID: {errata_id_int}) for system {system_identifier}", + token=token, default_on_error=None # Helper returns None on error ) @@ -790,8 +831,10 @@ async def add_system( elif not activation_key: # Fallback if elicitation is not supported return "You need to provide an activation key." + token = ctx.get_state('token') + # Check if the system already exists - active_systems = await _get_list_of_active_systems() + active_systems = await _get_list_of_active_systems(token) for system in active_systems: if system.get('system_name') == host: message = f"System '{host}' already exists in Uyuni. No action taken." @@ -834,6 +877,7 @@ async def add_system( api_path="/rhn/manager/api/system/bootstrapWithPrivateSshKey", json_body=payload, error_context=f"adding system {host}", + token=token, default_on_error=None, expect_timeout=True, ) @@ -881,12 +925,13 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu is_confirmed = _to_bool(confirm) - system_id = await _resolve_system_id(system_identifier) + token = ctx.get_state('token') + system_id = await _resolve_system_id(system_identifier, token) if not system_id: return "" # Helper function already logged the reason for failure. # Check if the system exists before proceeding - active_systems = await _get_list_of_active_systems() + active_systems = await _get_list_of_active_systems(token) if not any(s.get('system_id') == int(system_id) for s in active_systems): message = f"System with ID {system_id} not found." logger.warning(message) @@ -905,6 +950,7 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu api_path="/rhn/manager/api/system/deleteSystem", json_body={"sid": system_id, "cleanupType": cleanup_type}, error_context=f"removing system ID {system_id}", + token=token, default_on_error=None ) @@ -949,6 +995,8 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx: find_by_cve_path = '/rhn/manager/api/errata/findByCve' list_affected_systems_path = '/rhn/manager/api/errata/listAffectedSystems' + token = ctx.get_state('token') + async with httpx.AsyncClient(verify=UYUNI_MCP_SSL_VERIFY) as client: # 1. Call findByCve (login will be handled by the helper) print(f"Searching for errata related to CVE: {cve_identifier}") @@ -958,6 +1006,7 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx: api_path=find_by_cve_path, params={'cveName': cve_identifier}, error_context=f"finding errata for CVE {cve_identifier}", + token=token, default_on_error=None # Distinguish API error from empty list ) @@ -1046,6 +1095,7 @@ async def get_systems_needing_reboot(ctx: Context) -> List[Dict[str, Any]]: # No method="GET", api_path=list_reboot_path, error_context="fetching systems needing reboot", + token=ctx.get_state('token'), default_on_error=[] # Return empty list on error ) @@ -1093,7 +1143,8 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context is_confirmed = _to_bool(confirm) - system_id = await _resolve_system_id(system_identifier) + token = ctx.get_state('token') + system_id = await _resolve_system_id(system_identifier, token) if not system_id: return "" # Helper function already logged the reason for failure. @@ -1113,6 +1164,7 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context api_path=schedule_reboot_path, json_body=payload, error_context=f"scheduling reboot for system {system_identifier}", + token=token, default_on_error=None # Helper returns None on error ) @@ -1159,6 +1211,7 @@ async def list_all_scheduled_actions(ctx: Context) -> List[Dict[str, Any]]: method="GET", api_path=list_actions_path, error_context="listing all scheduled actions", + token=ctx.get_state('token'), default_on_error=[] # Return empty list on error ) @@ -1221,6 +1274,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: Union[bool, str] api_path=cancel_actions_path, json_body=payload, error_context=f"canceling action {action_id}", + token=ctx.get_state('token'), default_on_error=0 # API returns 1 on success, so 0 can signify an error or unexpected response from helper ) if api_result == 1: @@ -1230,7 +1284,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: Union[bool, str] return f"Failed to cancel action: {action_id}. The API did not return success (expected 1, got {api_result}). Check server logs for details." @mcp.tool() -async def list_activation_keys() -> List[Dict[str, str]]: +async def list_activation_keys(ctx: Context) -> List[Dict[str, str]]: """ Fetches a list of activation keys from the Uyuni server. @@ -1252,6 +1306,7 @@ async def list_activation_keys() -> List[Dict[str, str]]: method="GET", api_path=list_keys_path, error_context="listing activation keys", + token=ctx.get_state('token'), default_on_error=[] ) @@ -1293,6 +1348,7 @@ async def get_unscheduled_errata(system_id: int, ctx: Context) -> List[Dict[str, api_path=get_unscheduled_errata, params=payload, error_context=f"fetching unscheduled errata for system ID {system_id}", + token=ctx.get_state('token'), default_on_error=None ) @@ -1312,6 +1368,7 @@ def main_cli(): logger.info("Running Uyuni MCP server.") if UYUNI_MCP_TRANSPORT == Transport.HTTP.value: + mcp.add_middleware(AuthTokenMiddleware()) mcp.run(transport="streamable-http", host=UYUNI_MCP_HOST, port=UYUNI_MCP_PORT) elif UYUNI_MCP_TRANSPORT == Transport.STDIO.value: mcp.run(transport="stdio")