Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 81 additions & 24 deletions src/mcp_server_uyuni/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from datetime import datetime, timezone
from pydantic import BaseModel

from fastmcp.server.middleware import Middleware, MiddlewareContext
from fastmcp import FastMCP, Context
from mcp import LoggingLevel, ServerSession, types
from mcp_server_uyuni.logging_config import get_logger, Transport
Expand All @@ -30,9 +31,7 @@ class ActivationKeySchema(BaseModel):
activation_key: str

REQUIRED_VARS = [
"UYUNI_SERVER",
"UYUNI_USER",
"UYUNI_PASS",
"UYUNI_SERVER"
]

missing_vars = [key for key in REQUIRED_VARS if key not in os.environ]
Expand Down Expand Up @@ -64,6 +63,30 @@ class ActivationKeySchema(BaseModel):
# Sentinel object to indicate an expected timeout for long-running actions
TIMEOUT_HAPPENED = object()

class AuthTokenMiddleware(Middleware):
async def on_call_tool(self, ctx: MiddlewareContext, call_next):
"""
Extracts the JWT token from the Authorization header (if present)
and injects it into the context state for other tools to use.
"""
fastmcp_ctx = ctx.fastmcp_context
auth_header = fastmcp_ctx.request_context.request.headers['authorization']
token = None
if auth_header:
# Expecting "Authorization: Bearer <token>"
parts = auth_header.split()
if len(parts) == 2 and parts[0] == "Bearer":
token = parts[1]
logger.debug("Successfully extracted token from header.")
else:
logger.warning(f"Malformed Authorization header received: {auth_header}")
else:
logger.debug("No Authorization header found in the request.")

fastmcp_ctx.set_state('token', token)
result = await call_next(ctx)
return result

def write_tool(*decorator_args, **decorator_kwargs):
"""
A decorator that registers a function as an MCP tool only if write
Expand All @@ -74,11 +97,11 @@ def decorator(func):
if UYUNI_MCP_WRITE_TOOLS_ENABLED:
# 3a. If enabled, it applies the @mcp.tool() decorator, registering the function.
return mcp.tool(*decorator_args, **decorator_kwargs)(func)

# 3b. If disabled, it does nothing and just returns the original,
# un-decorated function. It is never registered.
return func

# 1. The factory returns the decorator.
return decorator

Expand All @@ -87,6 +110,7 @@ async def _call_uyuni_api(
method: str,
api_path: str,
error_context: str,
token: Optional[str] = None,
params: Dict[str, Any] = None,
json_body: Dict[str, Any] = None,
perform_login: bool = True,
Expand All @@ -108,9 +132,15 @@ async def _call_uyuni_api(
return error_msg

if perform_login:
login_data = {"login": UYUNI_USER, "password": UYUNI_PASS}
try:
login_response = await client.post(UYUNI_SERVER + '/rhn/manager/api/login', json=login_data)
if token:
login_response = await client.get(
UYUNI_SERVER + '/rhn/manager/api/auth/oidcLogin',
headers={"Authorization": f"Bearer {token}"})
elif UYUNI_USER and UYUNI_PASS:
login_response = await client.post(
UYUNI_SERVER + '/rhn/manager/api/auth/login',
json={"login": UYUNI_USER, "password": UYUNI_PASS})
login_response.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error during login for {error_context}: {e.request.url} - {e.response.status_code} - {e.response.text}")
Expand Down Expand Up @@ -188,16 +218,17 @@ async def get_list_of_active_systems(ctx: Context) -> List[Dict[str, Any]]:
logger.info(log_string)
await ctx.info(log_string)

return await _get_list_of_active_systems()
return await _get_list_of_active_systems(ctx.get_state('token'))

async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]:
async def _get_list_of_active_systems(token: str) -> List[Dict[str, Union[str, int]]]:

async with httpx.AsyncClient(verify=UYUNI_MCP_SSL_VERIFY) as client:
systems_data_result = await _call_uyuni_api(
client=client,
method="GET",
api_path="/rhn/manager/api/system/listSystems",
error_context="fetching active systems",
token=token,
default_on_error=[]
)

Expand All @@ -213,7 +244,7 @@ async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]:

return filtered_systems

async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str]:
async def _resolve_system_id(system_identifier: Union[str, int], token: str) -> Optional[str]:
"""
Resolves a system identifier, which can be a name or an ID, to a numeric system ID string.

Expand Down Expand Up @@ -241,6 +272,7 @@ async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str
api_path="/rhn/manager/api/system/getId",
params={'name': system_name},
error_context=f"resolving system ID for name '{system_name}'",
token=token,
default_on_error=[] # Return an empty list on failure
)

Expand Down Expand Up @@ -283,10 +315,10 @@ async def get_cpu_of_a_system(system_identifier: Union[str, int], ctx: Context)
log_string = f"Getting CPU information of system with id {system_identifier}"
logger.info(log_string)
await ctx.info(log_string)
return await _get_cpu_of_a_system(system_identifier)
return await _get_cpu_of_a_system(system_identifier, ctx.get_state('token'))

async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str, Any]:
system_id = await _resolve_system_id(system_identifier)
async def _get_cpu_of_a_system(system_identifier: Union[str, int], token: str) -> Dict[str, Any]:
system_id = await _resolve_system_id(system_identifier, token)
if not system_id:
return {} # Helper function already logged the reason for failure.

Expand All @@ -297,6 +329,7 @@ async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str,
api_path="/rhn/manager/api/system/getCpu",
params={'sid': system_id},
error_context=f"fetching CPU data for system {system_identifier}",
token=token,
default_on_error={}
)

Expand Down Expand Up @@ -332,7 +365,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]:
await ctx.info(log_string)

all_systems_cpu_data = []
active_systems = await _get_list_of_active_systems() # Calls your existing tool
active_systems = await _get_list_of_active_systems(ctx.get_state('token'))

if not active_systems:
print("Warning: No active systems found or failed to retrieve system list.")
Expand All @@ -347,7 +380,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]:
continue

print(f"Fetching CPU info for system: {system_name} (ID: {system_id})")
cpu_info = await _get_cpu_of_a_system(str(system_id)) # Calls your other existing tool
cpu_info = await _get_cpu_of_a_system(str(system_id), ctx.get_state('token'))

all_systems_cpu_data.append({
'system_name': system_name,
Expand Down Expand Up @@ -429,7 +462,8 @@ async def check_system_updates(system_identifier: Union[str, int], ctx: Context)
return await _check_system_updates(system_identifier, ctx)

async def _check_system_updates(system_identifier: Union[str, int], ctx: Context) -> Dict[str, Any]:
system_id = await _resolve_system_id(system_identifier)
token = ctx.get_state('token')
system_id = await _resolve_system_id(system_identifier, token)
default_error_response = {
'system_identifier': system_identifier,
'has_pending_updates': False,
Expand All @@ -449,6 +483,7 @@ async def _check_system_updates(system_identifier: Union[str, int], ctx: Context
api_path="/rhn/manager/api/system/getRelevantErrata",
params={'sid': system_id},
error_context=f"checking updates for system {system_identifier}",
token=token,
default_on_error=None # Distinguish API error from empty list
)

Expand All @@ -458,6 +493,7 @@ async def _check_system_updates(system_identifier: Union[str, int], ctx: Context
api_path="/rhn/manager/api/system/getUnscheduledErrata",
params={'sid': str(system_id)},
error_context=f"checking unscheduled errata for system ID {system_id}",
token=token,
default_on_error=[] # Return empty list on failure
)

Expand Down Expand Up @@ -549,7 +585,7 @@ async def check_all_systems_for_updates(ctx: Context) -> List[Dict[str, Any]]:
await ctx.info(log_string)

systems_with_updates = []
active_systems = await _get_list_of_active_systems() # Get the list of all systems
active_systems = await get_list_of_active_systems(ctx) # Get the list of all systems

if not active_systems:
print("Warning: No active systems found or failed to retrieve system list.")
Expand Down Expand Up @@ -614,6 +650,8 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
if not is_confirmed:
return f"CONFIRMATION REQUIRED: This will apply pending updates to the system {system_identifier}. Do you confirm?"

token = ctx.get_state('token')

# 1. Use check_system_updates to get relevant errata
update_info = await _check_system_updates(system_identifier, ctx)

Expand All @@ -634,7 +672,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
print(f"Could not extract any valid errata IDs for system {system_identifier} from the update information: {errata_list}")
return ""

system_id = await _resolve_system_id(system_identifier)
system_id = await _resolve_system_id(system_identifier, token)
if not system_id:
return "" # Helper function already logged the reason for failure.

Expand All @@ -649,6 +687,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
api_path="/rhn/manager/api/system/scheduleApplyErrata",
json_body=payload,
error_context=f"scheduling errata application for system {system_identifier}",
token=token,
default_on_error=None # Helper will return None on error
)

Expand Down Expand Up @@ -693,7 +732,8 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err
return f"Invalid errata ID '{errata_id}'. The ID must be an integer."


system_id = await _resolve_system_id(system_identifier)
token = ctx.get_state('token')
system_id = await _resolve_system_id(system_identifier, token)
if not system_id:
return "" # Helper function already logged the reason for failure.

Expand All @@ -711,6 +751,7 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err
api_path="/rhn/manager/api/system/scheduleApplyErrata",
json_body=payload,
error_context=f"scheduling specific update (errata ID: {errata_id_int}) for system {system_identifier}",
token=token,
default_on_error=None # Helper returns None on error
)

Expand Down Expand Up @@ -790,8 +831,10 @@ async def add_system(
elif not activation_key: # Fallback if elicitation is not supported
return "You need to provide an activation key."

token = ctx.get_state('token')

# Check if the system already exists
active_systems = await _get_list_of_active_systems()
active_systems = await _get_list_of_active_systems(token)
for system in active_systems:
if system.get('system_name') == host:
message = f"System '{host}' already exists in Uyuni. No action taken."
Expand Down Expand Up @@ -834,6 +877,7 @@ async def add_system(
api_path="/rhn/manager/api/system/bootstrapWithPrivateSshKey",
json_body=payload,
error_context=f"adding system {host}",
token=token,
default_on_error=None,
expect_timeout=True,
)
Expand Down Expand Up @@ -881,12 +925,13 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu

is_confirmed = _to_bool(confirm)

system_id = await _resolve_system_id(system_identifier)
token = ctx.get_state('token')
system_id = await _resolve_system_id(system_identifier, token)
if not system_id:
return "" # Helper function already logged the reason for failure.

# Check if the system exists before proceeding
active_systems = await _get_list_of_active_systems()
active_systems = await _get_list_of_active_systems(token)
if not any(s.get('system_id') == int(system_id) for s in active_systems):
message = f"System with ID {system_id} not found."
logger.warning(message)
Expand All @@ -905,6 +950,7 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu
api_path="/rhn/manager/api/system/deleteSystem",
json_body={"sid": system_id, "cleanupType": cleanup_type},
error_context=f"removing system ID {system_id}",
token=token,
default_on_error=None
)

Expand Down Expand Up @@ -949,6 +995,8 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx:
find_by_cve_path = '/rhn/manager/api/errata/findByCve'
list_affected_systems_path = '/rhn/manager/api/errata/listAffectedSystems'

token = ctx.get_state('token')

async with httpx.AsyncClient(verify=UYUNI_MCP_SSL_VERIFY) as client:
# 1. Call findByCve (login will be handled by the helper)
print(f"Searching for errata related to CVE: {cve_identifier}")
Expand All @@ -958,6 +1006,7 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx:
api_path=find_by_cve_path,
params={'cveName': cve_identifier},
error_context=f"finding errata for CVE {cve_identifier}",
token=token,
default_on_error=None # Distinguish API error from empty list
)

Expand Down Expand Up @@ -1046,6 +1095,7 @@ async def get_systems_needing_reboot(ctx: Context) -> List[Dict[str, Any]]: # No
method="GET",
api_path=list_reboot_path,
error_context="fetching systems needing reboot",
token=ctx.get_state('token'),
default_on_error=[] # Return empty list on error
)

Expand Down Expand Up @@ -1093,7 +1143,8 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context

is_confirmed = _to_bool(confirm)

system_id = await _resolve_system_id(system_identifier)
token = ctx.get_state('token')
system_id = await _resolve_system_id(system_identifier, token)
if not system_id:
return "" # Helper function already logged the reason for failure.

Expand All @@ -1113,6 +1164,7 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context
api_path=schedule_reboot_path,
json_body=payload,
error_context=f"scheduling reboot for system {system_identifier}",
token=token,
default_on_error=None # Helper returns None on error
)

Expand Down Expand Up @@ -1159,6 +1211,7 @@ async def list_all_scheduled_actions(ctx: Context) -> List[Dict[str, Any]]:
method="GET",
api_path=list_actions_path,
error_context="listing all scheduled actions",
token=ctx.get_state('token'),
default_on_error=[] # Return empty list on error
)

Expand Down Expand Up @@ -1221,6 +1274,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: Union[bool, str]
api_path=cancel_actions_path,
json_body=payload,
error_context=f"canceling action {action_id}",
token=ctx.get_state('token'),
default_on_error=0 # API returns 1 on success, so 0 can signify an error or unexpected response from helper
)
if api_result == 1:
Expand All @@ -1230,7 +1284,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: Union[bool, str]
return f"Failed to cancel action: {action_id}. The API did not return success (expected 1, got {api_result}). Check server logs for details."

@mcp.tool()
async def list_activation_keys() -> List[Dict[str, str]]:
async def list_activation_keys(ctx: Context) -> List[Dict[str, str]]:
"""
Fetches a list of activation keys from the Uyuni server.

Expand All @@ -1252,6 +1306,7 @@ async def list_activation_keys() -> List[Dict[str, str]]:
method="GET",
api_path=list_keys_path,
error_context="listing activation keys",
token=ctx.get_state('token'),
default_on_error=[]
)

Expand Down Expand Up @@ -1293,6 +1348,7 @@ async def get_unscheduled_errata(system_id: int, ctx: Context) -> List[Dict[str,
api_path=get_unscheduled_errata,
params=payload,
error_context=f"fetching unscheduled errata for system ID {system_id}",
token=ctx.get_state('token'),
default_on_error=None
)

Expand All @@ -1312,6 +1368,7 @@ def main_cli():
logger.info("Running Uyuni MCP server.")

if UYUNI_MCP_TRANSPORT == Transport.HTTP.value:
mcp.add_middleware(AuthTokenMiddleware())
mcp.run(transport="streamable-http", host=UYUNI_MCP_HOST, port=UYUNI_MCP_PORT)
elif UYUNI_MCP_TRANSPORT == Transport.STDIO.value:
mcp.run(transport="stdio")
Expand Down