Skip to content

Commit f5e2f09

Browse files
committed
Forward auth token to MLM calls
1 parent c22a433 commit f5e2f09

File tree

1 file changed

+80
-21
lines changed

1 file changed

+80
-21
lines changed

src/mcp_server_uyuni/server.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from datetime import datetime, timezone
2121
from pydantic import BaseModel
2222

23+
from fastmcp.server.middleware import Middleware, MiddlewareContext
2324
from fastmcp import FastMCP, Context
2425
from mcp import LoggingLevel, ServerSession, types
2526
from mcp_server_uyuni.logging_config import get_logger, Transport
@@ -64,6 +65,30 @@ class ActivationKeySchema(BaseModel):
6465
# Sentinel object to indicate an expected timeout for long-running actions
6566
TIMEOUT_HAPPENED = object()
6667

68+
class AuthTokenMiddleware(Middleware):
69+
async def on_call_tool(self, ctx: MiddlewareContext, call_next):
70+
"""
71+
Extracts the JWT token from the Authorization header (if present)
72+
and injects it into the context state for other tools to use.
73+
"""
74+
fastmcp_ctx = ctx.fastmcp_context
75+
auth_header = fastmcp_ctx.request_context.request.headers['authorization']
76+
token = None
77+
if auth_header:
78+
# Expecting "Authorization: Bearer <token>"
79+
parts = auth_header.split()
80+
if len(parts) == 2 and parts[0] == "Bearer":
81+
token = parts[1]
82+
logger.debug("Successfully extracted token from header.")
83+
else:
84+
logger.warning(f"Malformed Authorization header received: {auth_header}")
85+
else:
86+
logger.debug("No Authorization header found in the request.")
87+
88+
fastmcp_ctx.set_state('token', token)
89+
result = await call_next(ctx)
90+
return result
91+
6792
def write_tool(*decorator_args, **decorator_kwargs):
6893
"""
6994
A decorator that registers a function as an MCP tool only if write
@@ -74,11 +99,11 @@ def decorator(func):
7499
if UYUNI_MCP_WRITE_TOOLS_ENABLED:
75100
# 3a. If enabled, it applies the @mcp.tool() decorator, registering the function.
76101
return mcp.tool(*decorator_args, **decorator_kwargs)(func)
77-
102+
78103
# 3b. If disabled, it does nothing and just returns the original,
79104
# un-decorated function. It is never registered.
80105
return func
81-
106+
82107
# 1. The factory returns the decorator.
83108
return decorator
84109

@@ -87,6 +112,7 @@ async def _call_uyuni_api(
87112
method: str,
88113
api_path: str,
89114
error_context: str,
115+
token: Optional[str] = None,
90116
params: Dict[str, Any] = None,
91117
json_body: Dict[str, Any] = None,
92118
perform_login: bool = True,
@@ -108,9 +134,15 @@ async def _call_uyuni_api(
108134
return error_msg
109135

110136
if perform_login:
111-
login_data = {"login": UYUNI_USER, "password": UYUNI_PASS}
112137
try:
113-
login_response = await client.post(UYUNI_SERVER + '/rhn/manager/api/login', json=login_data)
138+
if token:
139+
login_response = await client.post(
140+
UYUNI_SERVER + '/rhn/manager/api/oidcLogin',
141+
headers={"Authorization": f"Bearer {token}"})
142+
elif UYUNI_USER and UYUNI_PASS:
143+
login_response = await client.post(
144+
UYUNI_SERVER + '/rhn/manager/api/auth/login',
145+
json={"login": UYUNI_USER, "password": UYUNI_PASS})
114146
login_response.raise_for_status()
115147
except httpx.HTTPStatusError as e:
116148
logger.error(f"HTTP error during login for {error_context}: {e.request.url} - {e.response.status_code} - {e.response.text}")
@@ -188,16 +220,17 @@ async def get_list_of_active_systems(ctx: Context) -> List[Dict[str, Any]]:
188220
logger.info(log_string)
189221
await ctx.info(log_string)
190222

191-
return await _get_list_of_active_systems()
223+
return await _get_list_of_active_systems(ctx.get_state('token'))
192224

193-
async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]:
225+
async def _get_list_of_active_systems(token: str) -> List[Dict[str, Union[str, int]]]:
194226

195227
async with httpx.AsyncClient(verify=UYUNI_MCP_SSL_VERIFY) as client:
196228
systems_data_result = await _call_uyuni_api(
197229
client=client,
198230
method="GET",
199231
api_path="/rhn/manager/api/system/listSystems",
200232
error_context="fetching active systems",
233+
token=token,
201234
default_on_error=[]
202235
)
203236

@@ -213,7 +246,7 @@ async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]:
213246

214247
return filtered_systems
215248

216-
async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str]:
249+
async def _resolve_system_id(system_identifier: Union[str, int], token: str) -> Optional[str]:
217250
"""
218251
Resolves a system identifier, which can be a name or an ID, to a numeric system ID string.
219252
@@ -241,6 +274,7 @@ async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str
241274
api_path="/rhn/manager/api/system/getId",
242275
params={'name': system_name},
243276
error_context=f"resolving system ID for name '{system_name}'",
277+
token=token,
244278
default_on_error=[] # Return an empty list on failure
245279
)
246280

@@ -283,10 +317,10 @@ async def get_cpu_of_a_system(system_identifier: Union[str, int], ctx: Context)
283317
log_string = f"Getting CPU information of system with id {system_identifier}"
284318
logger.info(log_string)
285319
await ctx.info(log_string)
286-
return await _get_cpu_of_a_system(system_identifier)
320+
return await _get_cpu_of_a_system(system_identifier, ctx.get_state('token'))
287321

288-
async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str, Any]:
289-
system_id = await _resolve_system_id(system_identifier)
322+
async def _get_cpu_of_a_system(system_identifier: Union[str, int], token: str) -> Dict[str, Any]:
323+
system_id = await _resolve_system_id(system_identifier, token)
290324
if not system_id:
291325
return {} # Helper function already logged the reason for failure.
292326

@@ -297,6 +331,7 @@ async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str,
297331
api_path="/rhn/manager/api/system/getCpu",
298332
params={'sid': system_id},
299333
error_context=f"fetching CPU data for system {system_identifier}",
334+
token=token,
300335
default_on_error={}
301336
)
302337

@@ -332,7 +367,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]:
332367
await ctx.info(log_string)
333368

334369
all_systems_cpu_data = []
335-
active_systems = await _get_list_of_active_systems() # Calls your existing tool
370+
active_systems = await _get_list_of_active_systems(ctx.get_state('token'))
336371

337372
if not active_systems:
338373
print("Warning: No active systems found or failed to retrieve system list.")
@@ -347,7 +382,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]:
347382
continue
348383

349384
print(f"Fetching CPU info for system: {system_name} (ID: {system_id})")
350-
cpu_info = await _get_cpu_of_a_system(str(system_id)) # Calls your other existing tool
385+
cpu_info = await _get_cpu_of_a_system(str(system_id), ctx.get_state('token'))
351386

352387
all_systems_cpu_data.append({
353388
'system_name': system_name,
@@ -429,7 +464,8 @@ async def check_system_updates(system_identifier: Union[str, int], ctx: Context)
429464
return await _check_system_updates(system_identifier, ctx)
430465

431466
async def _check_system_updates(system_identifier: Union[str, int], ctx: Context) -> Dict[str, Any]:
432-
system_id = await _resolve_system_id(system_identifier)
467+
token = ctx.get_state('token')
468+
system_id = await _resolve_system_id(system_identifier, token)
433469
default_error_response = {
434470
'system_identifier': system_identifier,
435471
'has_pending_updates': False,
@@ -449,6 +485,7 @@ async def _check_system_updates(system_identifier: Union[str, int], ctx: Context
449485
api_path="/rhn/manager/api/system/getRelevantErrata",
450486
params={'sid': system_id},
451487
error_context=f"checking updates for system {system_identifier}",
488+
token=token,
452489
default_on_error=None # Distinguish API error from empty list
453490
)
454491

@@ -458,6 +495,7 @@ async def _check_system_updates(system_identifier: Union[str, int], ctx: Context
458495
api_path="/rhn/manager/api/system/getUnscheduledErrata",
459496
params={'sid': str(system_id)},
460497
error_context=f"checking unscheduled errata for system ID {system_id}",
498+
token=token,
461499
default_on_error=[] # Return empty list on failure
462500
)
463501

@@ -549,7 +587,7 @@ async def check_all_systems_for_updates(ctx: Context) -> List[Dict[str, Any]]:
549587
await ctx.info(log_string)
550588

551589
systems_with_updates = []
552-
active_systems = await _get_list_of_active_systems() # Get the list of all systems
590+
active_systems = await get_list_of_active_systems(ctx) # Get the list of all systems
553591

554592
if not active_systems:
555593
print("Warning: No active systems found or failed to retrieve system list.")
@@ -614,6 +652,8 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
614652
if not is_confirmed:
615653
return f"CONFIRMATION REQUIRED: This will apply pending updates to the system {system_identifier}. Do you confirm?"
616654

655+
token = ctx.get_state('token')
656+
617657
# 1. Use check_system_updates to get relevant errata
618658
update_info = await _check_system_updates(system_identifier, ctx)
619659

@@ -634,7 +674,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
634674
print(f"Could not extract any valid errata IDs for system {system_identifier} from the update information: {errata_list}")
635675
return ""
636676

637-
system_id = await _resolve_system_id(system_identifier)
677+
system_id = await _resolve_system_id(system_identifier, token)
638678
if not system_id:
639679
return "" # Helper function already logged the reason for failure.
640680

@@ -649,6 +689,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
649689
api_path="/rhn/manager/api/system/scheduleApplyErrata",
650690
json_body=payload,
651691
error_context=f"scheduling errata application for system {system_identifier}",
692+
token=token,
652693
default_on_error=None # Helper will return None on error
653694
)
654695

@@ -693,7 +734,8 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err
693734
return f"Invalid errata ID '{errata_id}'. The ID must be an integer."
694735

695736

696-
system_id = await _resolve_system_id(system_identifier)
737+
token = ctx.get_state('token')
738+
system_id = await _resolve_system_id(system_identifier, token)
697739
if not system_id:
698740
return "" # Helper function already logged the reason for failure.
699741

@@ -711,6 +753,7 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err
711753
api_path="/rhn/manager/api/system/scheduleApplyErrata",
712754
json_body=payload,
713755
error_context=f"scheduling specific update (errata ID: {errata_id_int}) for system {system_identifier}",
756+
token=token,
714757
default_on_error=None # Helper returns None on error
715758
)
716759

@@ -790,8 +833,10 @@ async def add_system(
790833
elif not activation_key: # Fallback if elicitation is not supported
791834
return "You need to provide an activation key."
792835

836+
token = ctx.get_state('token')
837+
793838
# Check if the system already exists
794-
active_systems = await _get_list_of_active_systems()
839+
active_systems = await _get_list_of_active_systems(token)
795840
for system in active_systems:
796841
if system.get('system_name') == host:
797842
message = f"System '{host}' already exists in Uyuni. No action taken."
@@ -834,6 +879,7 @@ async def add_system(
834879
api_path="/rhn/manager/api/system/bootstrapWithPrivateSshKey",
835880
json_body=payload,
836881
error_context=f"adding system {host}",
882+
token=token,
837883
default_on_error=None,
838884
expect_timeout=True,
839885
)
@@ -881,12 +927,13 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu
881927

882928
is_confirmed = _to_bool(confirm)
883929

884-
system_id = await _resolve_system_id(system_identifier)
930+
token = ctx.get_state('token')
931+
system_id = await _resolve_system_id(system_identifier, token)
885932
if not system_id:
886933
return "" # Helper function already logged the reason for failure.
887934

888935
# Check if the system exists before proceeding
889-
active_systems = await _get_list_of_active_systems()
936+
active_systems = await _get_list_of_active_systems(token)
890937
if not any(s.get('system_id') == int(system_id) for s in active_systems):
891938
message = f"System with ID {system_id} not found."
892939
logger.warning(message)
@@ -905,6 +952,7 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu
905952
api_path="/rhn/manager/api/system/deleteSystem",
906953
json_body={"sid": system_id, "cleanupType": cleanup_type},
907954
error_context=f"removing system ID {system_id}",
955+
token=token,
908956
default_on_error=None
909957
)
910958

@@ -949,6 +997,8 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx:
949997
find_by_cve_path = '/rhn/manager/api/errata/findByCve'
950998
list_affected_systems_path = '/rhn/manager/api/errata/listAffectedSystems'
951999

1000+
token = ctx.get_state('token')
1001+
9521002
async with httpx.AsyncClient(verify=UYUNI_MCP_SSL_VERIFY) as client:
9531003
# 1. Call findByCve (login will be handled by the helper)
9541004
print(f"Searching for errata related to CVE: {cve_identifier}")
@@ -958,6 +1008,7 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx:
9581008
api_path=find_by_cve_path,
9591009
params={'cveName': cve_identifier},
9601010
error_context=f"finding errata for CVE {cve_identifier}",
1011+
token=token,
9611012
default_on_error=None # Distinguish API error from empty list
9621013
)
9631014

@@ -1046,6 +1097,7 @@ async def get_systems_needing_reboot(ctx: Context) -> List[Dict[str, Any]]: # No
10461097
method="GET",
10471098
api_path=list_reboot_path,
10481099
error_context="fetching systems needing reboot",
1100+
token=ctx.get_state('token'),
10491101
default_on_error=[] # Return empty list on error
10501102
)
10511103

@@ -1093,7 +1145,8 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context
10931145

10941146
is_confirmed = _to_bool(confirm)
10951147

1096-
system_id = await _resolve_system_id(system_identifier)
1148+
token = ctx.get_state('token')
1149+
system_id = await _resolve_system_id(system_identifier, token)
10971150
if not system_id:
10981151
return "" # Helper function already logged the reason for failure.
10991152

@@ -1113,6 +1166,7 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context
11131166
api_path=schedule_reboot_path,
11141167
json_body=payload,
11151168
error_context=f"scheduling reboot for system {system_identifier}",
1169+
token=token,
11161170
default_on_error=None # Helper returns None on error
11171171
)
11181172

@@ -1159,6 +1213,7 @@ async def list_all_scheduled_actions(ctx: Context) -> List[Dict[str, Any]]:
11591213
method="GET",
11601214
api_path=list_actions_path,
11611215
error_context="listing all scheduled actions",
1216+
token=ctx.get_state('token'),
11621217
default_on_error=[] # Return empty list on error
11631218
)
11641219

@@ -1221,6 +1276,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: Union[bool, str]
12211276
api_path=cancel_actions_path,
12221277
json_body=payload,
12231278
error_context=f"canceling action {action_id}",
1279+
token=ctx.get_state('token'),
12241280
default_on_error=0 # API returns 1 on success, so 0 can signify an error or unexpected response from helper
12251281
)
12261282
if api_result == 1:
@@ -1230,7 +1286,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: Union[bool, str]
12301286
return f"Failed to cancel action: {action_id}. The API did not return success (expected 1, got {api_result}). Check server logs for details."
12311287

12321288
@mcp.tool()
1233-
async def list_activation_keys() -> List[Dict[str, str]]:
1289+
async def list_activation_keys(ctx: Context) -> List[Dict[str, str]]:
12341290
"""
12351291
Fetches a list of activation keys from the Uyuni server.
12361292
@@ -1252,6 +1308,7 @@ async def list_activation_keys() -> List[Dict[str, str]]:
12521308
method="GET",
12531309
api_path=list_keys_path,
12541310
error_context="listing activation keys",
1311+
token=ctx.get_state('token'),
12551312
default_on_error=[]
12561313
)
12571314

@@ -1293,6 +1350,7 @@ async def get_unscheduled_errata(system_id: int, ctx: Context) -> List[Dict[str,
12931350
api_path=get_unscheduled_errata,
12941351
params=payload,
12951352
error_context=f"fetching unscheduled errata for system ID {system_id}",
1353+
token=ctx.get_state('token'),
12961354
default_on_error=None
12971355
)
12981356

@@ -1312,6 +1370,7 @@ def main_cli():
13121370
logger.info("Running Uyuni MCP server.")
13131371

13141372
if UYUNI_MCP_TRANSPORT == Transport.HTTP.value:
1373+
mcp.add_middleware(AuthTokenMiddleware())
13151374
mcp.run(transport="streamable-http", host=UYUNI_MCP_HOST, port=UYUNI_MCP_PORT)
13161375
elif UYUNI_MCP_TRANSPORT == Transport.STDIO.value:
13171376
mcp.run(transport="stdio")

0 commit comments

Comments
 (0)