Skip to content

Commit 57f9883

Browse files
committed
Forward auth token to MLM calls
1 parent b893dee commit 57f9883

File tree

1 file changed

+80
-21
lines changed

1 file changed

+80
-21
lines changed

src/mcp_server_uyuni/server.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from datetime import datetime, timezone
2121
from pydantic import BaseModel
2222

23+
from fastmcp.server.middleware import Middleware, MiddlewareContext
2324
from fastmcp import FastMCP, Context
2425
from mcp import LoggingLevel, ServerSession, types
2526
from mcp_server_uyuni.logging_config import get_logger, Transport
@@ -64,6 +65,30 @@ class ActivationKeySchema(BaseModel):
6465
# Sentinel object to indicate an expected timeout for long-running actions
6566
TIMEOUT_HAPPENED = object()
6667

68+
class AuthTokenMiddleware(Middleware):
69+
async def on_call_tool(self, ctx: MiddlewareContext, call_next):
70+
"""
71+
Extracts the JWT token from the Authorization header (if present)
72+
and injects it into the context state for other tools to use.
73+
"""
74+
fastmcp_ctx = ctx.fastmcp_context
75+
auth_header = fastmcp_ctx.request_context.request.headers['authorization']
76+
token = None
77+
if auth_header:
78+
# Expecting "Authorization: Bearer <token>"
79+
parts = auth_header.split()
80+
if len(parts) == 2 and parts[0] == "Bearer":
81+
token = parts[1]
82+
logger.debug("Successfully extracted token from header.")
83+
else:
84+
logger.warning(f"Malformed Authorization header received: {auth_header}")
85+
else:
86+
logger.debug("No Authorization header found in the request.")
87+
88+
fastmcp_ctx.set_state('token', token)
89+
result = await call_next(ctx)
90+
return result
91+
6792
def write_tool(*decorator_args, **decorator_kwargs):
6893
"""
6994
A decorator that registers a function as an MCP tool only if write
@@ -74,11 +99,11 @@ def decorator(func):
7499
if UYUNI_MCP_WRITE_TOOLS_ENABLED:
75100
# 3a. If enabled, it applies the @mcp.tool() decorator, registering the function.
76101
return mcp.tool(*decorator_args, **decorator_kwargs)(func)
77-
102+
78103
# 3b. If disabled, it does nothing and just returns the original,
79104
# un-decorated function. It is never registered.
80105
return func
81-
106+
82107
# 1. The factory returns the decorator.
83108
return decorator
84109

@@ -87,6 +112,7 @@ async def _call_uyuni_api(
87112
method: str,
88113
api_path: str,
89114
error_context: str,
115+
token: Optional[str] = None,
90116
params: Dict[str, Any] = None,
91117
json_body: Dict[str, Any] = None,
92118
perform_login: bool = True,
@@ -108,9 +134,15 @@ async def _call_uyuni_api(
108134
return error_msg
109135

110136
if perform_login:
111-
login_data = {"login": UYUNI_USER, "password": UYUNI_PASS}
112137
try:
113-
login_response = await client.post(UYUNI_SERVER + '/rhn/manager/api/login', json=login_data)
138+
if token:
139+
login_response = await client.post(
140+
UYUNI_SERVER + '/rhn/manager/api/oidcLogin',
141+
headers={"Authorization": f"Bearer {token}"})
142+
elif UYUNI_USER and UYUNI_PASS:
143+
login_response = await client.post(
144+
UYUNI_SERVER + '/rhn/manager/api/auth/login',
145+
json={"login": UYUNI_USER, "password": UYUNI_PASS})
114146
login_response.raise_for_status()
115147
except httpx.HTTPStatusError as e:
116148
logger.error(f"HTTP error during login for {error_context}: {e.request.url} - {e.response.status_code} - {e.response.text}")
@@ -181,16 +213,17 @@ async def get_list_of_active_systems(ctx: Context) -> List[Dict[str, Any]]:
181213
logger.info(log_string)
182214
await ctx.info(log_string)
183215

184-
return await _get_list_of_active_systems()
216+
return await _get_list_of_active_systems(ctx.get_state('token'))
185217

186-
async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]:
218+
async def _get_list_of_active_systems(token: str) -> List[Dict[str, Union[str, int]]]:
187219

188220
async with httpx.AsyncClient(verify=UYUNI_MCP_SSL_VERIFY) as client:
189221
systems_data_result = await _call_uyuni_api(
190222
client=client,
191223
method="GET",
192224
api_path="/rhn/manager/api/system/listSystems",
193225
error_context="fetching active systems",
226+
token=token,
194227
default_on_error=[]
195228
)
196229

@@ -206,7 +239,7 @@ async def _get_list_of_active_systems() -> List[Dict[str, Union[str, int]]]:
206239

207240
return filtered_systems
208241

209-
async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str]:
242+
async def _resolve_system_id(system_identifier: Union[str, int], token: str) -> Optional[str]:
210243
"""
211244
Resolves a system identifier, which can be a name or an ID, to a numeric system ID string.
212245
@@ -234,6 +267,7 @@ async def _resolve_system_id(system_identifier: Union[str, int]) -> Optional[str
234267
api_path="/rhn/manager/api/system/getId",
235268
params={'name': system_name},
236269
error_context=f"resolving system ID for name '{system_name}'",
270+
token=token,
237271
default_on_error=[] # Return an empty list on failure
238272
)
239273

@@ -276,10 +310,10 @@ async def get_cpu_of_a_system(system_identifier: Union[str, int], ctx: Context)
276310
log_string = f"Getting CPU information of system with id {system_identifier}"
277311
logger.info(log_string)
278312
await ctx.info(log_string)
279-
return await _get_cpu_of_a_system(system_identifier)
313+
return await _get_cpu_of_a_system(system_identifier, ctx.get_state('token'))
280314

281-
async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str, Any]:
282-
system_id = await _resolve_system_id(system_identifier)
315+
async def _get_cpu_of_a_system(system_identifier: Union[str, int], token: str) -> Dict[str, Any]:
316+
system_id = await _resolve_system_id(system_identifier, token)
283317
if not system_id:
284318
return {} # Helper function already logged the reason for failure.
285319

@@ -290,6 +324,7 @@ async def _get_cpu_of_a_system(system_identifier: Union[str, int]) -> Dict[str,
290324
api_path="/rhn/manager/api/system/getCpu",
291325
params={'sid': system_id},
292326
error_context=f"fetching CPU data for system {system_identifier}",
327+
token=token,
293328
default_on_error={}
294329
)
295330

@@ -325,7 +360,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]:
325360
await ctx.info(log_string)
326361

327362
all_systems_cpu_data = []
328-
active_systems = await _get_list_of_active_systems() # Calls your existing tool
363+
active_systems = await _get_list_of_active_systems(ctx.get_state('token'))
329364

330365
if not active_systems:
331366
print("Warning: No active systems found or failed to retrieve system list.")
@@ -340,7 +375,7 @@ async def get_all_systems_cpu_info(ctx: Context) -> List[Dict[str, Any]]:
340375
continue
341376

342377
print(f"Fetching CPU info for system: {system_name} (ID: {system_id})")
343-
cpu_info = await _get_cpu_of_a_system(str(system_id)) # Calls your other existing tool
378+
cpu_info = await _get_cpu_of_a_system(str(system_id), ctx.get_state('token'))
344379

345380
all_systems_cpu_data.append({
346381
'system_name': system_name,
@@ -419,7 +454,8 @@ async def check_system_updates(system_identifier: Union[str, int], ctx: Context)
419454
log_string = f"Checking pending updates for system {system_identifier}"
420455
logger.info(log_string)
421456
await ctx.info(log_string)
422-
system_id = await _resolve_system_id(system_identifier)
457+
token = ctx.get_state('token')
458+
system_id = await _resolve_system_id(system_identifier, token)
423459
default_error_response = {
424460
'system_identifier': system_identifier,
425461
'has_pending_updates': False,
@@ -439,6 +475,7 @@ async def check_system_updates(system_identifier: Union[str, int], ctx: Context)
439475
api_path="/rhn/manager/api/system/getRelevantErrata",
440476
params={'sid': system_id},
441477
error_context=f"checking updates for system {system_identifier}",
478+
token=token,
442479
default_on_error=None # Distinguish API error from empty list
443480
)
444481

@@ -448,6 +485,7 @@ async def check_system_updates(system_identifier: Union[str, int], ctx: Context)
448485
api_path="/rhn/manager/api/system/getUnscheduledErrata",
449486
params={'sid': str(system_id)},
450487
error_context=f"checking unscheduled errata for system ID {system_id}",
488+
token=token,
451489
default_on_error=[] # Return empty list on failure
452490
)
453491

@@ -539,7 +577,7 @@ async def check_all_systems_for_updates(ctx: Context) -> List[Dict[str, Any]]:
539577
await ctx.info(log_string)
540578

541579
systems_with_updates = []
542-
active_systems = await _get_list_of_active_systems() # Get the list of all systems
580+
active_systems = await get_list_of_active_systems(ctx) # Get the list of all systems
543581

544582
if not active_systems:
545583
print("Warning: No active systems found or failed to retrieve system list.")
@@ -598,6 +636,8 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
598636
if not confirm:
599637
return f"CONFIRMATION REQUIRED: This will apply pending updates to the system {system_identifier}. Do you confirm?"
600638

639+
token = ctx.get_state('token')
640+
601641
# 1. Use check_system_updates to get relevant errata
602642
update_info = await check_system_updates(system_identifier, ctx)
603643

@@ -618,7 +658,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
618658
print(f"Could not extract any valid errata IDs for system {system_identifier} from the update information: {errata_list}")
619659
return ""
620660

621-
system_id = await _resolve_system_id(system_identifier)
661+
system_id = await _resolve_system_id(system_identifier, token)
622662
if not system_id:
623663
return "" # Helper function already logged the reason for failure.
624664

@@ -633,6 +673,7 @@ async def schedule_apply_pending_updates_to_system(system_identifier: Union[str,
633673
api_path="/rhn/manager/api/system/scheduleApplyErrata",
634674
json_body=payload,
635675
error_context=f"scheduling errata application for system {system_identifier}",
676+
token=token,
636677
default_on_error=None # Helper will return None on error
637678
)
638679

@@ -664,7 +705,8 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err
664705
log_string = f"Attempting to apply specific update (errata ID: {errata_id}) to system ID: {system_identifier}"
665706
logger.info(log_string)
666707
await ctx.info(log_string)
667-
system_id = await _resolve_system_id(system_identifier)
708+
token = ctx.get_state('token')
709+
system_id = await _resolve_system_id(system_identifier, token)
668710
if not system_id:
669711
return "" # Helper function already logged the reason for failure.
670712

@@ -682,6 +724,7 @@ async def schedule_apply_specific_update(system_identifier: Union[str, int], err
682724
api_path="/rhn/manager/api/system/scheduleApplyErrata",
683725
json_body=payload,
684726
error_context=f"scheduling specific update (errata ID: {errata_id}) for system {system_identifier}",
727+
token=token,
685728
default_on_error=None # Helper returns None on error
686729
)
687730

@@ -757,8 +800,10 @@ async def add_system(
757800
elif not activation_key: # Fallback if elicitation is not supported
758801
return "You need to provide an activation key."
759802

803+
token = ctx.get_state('token')
804+
760805
# Check if the system already exists
761-
active_systems = await _get_list_of_active_systems()
806+
active_systems = await _get_list_of_active_systems(token)
762807
for system in active_systems:
763808
if system.get('system_name') == host:
764809
message = f"System '{host}' already exists in Uyuni. No action taken."
@@ -801,6 +846,7 @@ async def add_system(
801846
api_path="/rhn/manager/api/system/bootstrapWithPrivateSshKey",
802847
json_body=payload,
803848
error_context=f"adding system {host}",
849+
token=token,
804850
default_on_error=None,
805851
expect_timeout=True,
806852
)
@@ -842,12 +888,13 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu
842888
log_string = f"Attempting to remove system with id {system_identifier}"
843889
logger.info(log_string)
844890
await ctx.info(log_string)
845-
system_id = await _resolve_system_id(system_identifier)
891+
token = ctx.get_state('token')
892+
system_id = await _resolve_system_id(system_identifier, token)
846893
if not system_id:
847894
return "" # Helper function already logged the reason for failure.
848895

849896
# Check if the system exists before proceeding
850-
active_systems = await _get_list_of_active_systems()
897+
active_systems = await _get_list_of_active_systems(token)
851898
if not any(s.get('system_id') == int(system_id) for s in active_systems):
852899
message = f"System with ID {system_id} not found."
853900
logger.warning(message)
@@ -866,6 +913,7 @@ async def remove_system(system_identifier: Union[str, int], ctx: Context, cleanu
866913
api_path="/rhn/manager/api/system/deleteSystem",
867914
json_body={"sid": system_id, "cleanupType": cleanup_type},
868915
error_context=f"removing system ID {system_id}",
916+
token=token,
869917
default_on_error=None
870918
)
871919

@@ -910,6 +958,8 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx:
910958
find_by_cve_path = '/rhn/manager/api/errata/findByCve'
911959
list_affected_systems_path = '/rhn/manager/api/errata/listAffectedSystems'
912960

961+
token = ctx.get_state('token')
962+
913963
async with httpx.AsyncClient(verify=UYUNI_MCP_SSL_VERIFY) as client:
914964
# 1. Call findByCve (login will be handled by the helper)
915965
print(f"Searching for errata related to CVE: {cve_identifier}")
@@ -919,6 +969,7 @@ async def get_systems_needing_security_update_for_cve(cve_identifier: str, ctx:
919969
api_path=find_by_cve_path,
920970
params={'cveName': cve_identifier},
921971
error_context=f"finding errata for CVE {cve_identifier}",
972+
token=token,
922973
default_on_error=None # Distinguish API error from empty list
923974
)
924975

@@ -1007,6 +1058,7 @@ async def get_systems_needing_reboot(ctx: Context) -> List[Dict[str, Any]]:
10071058
method="GET",
10081059
api_path=list_reboot_path,
10091060
error_context="fetching systems needing reboot",
1061+
token=ctx.get_state('token'),
10101062
default_on_error=[] # Return empty list on error
10111063
)
10121064

@@ -1047,7 +1099,8 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context
10471099
log_string = f"Schedule system reboot for system {system_identifier}"
10481100
logger.info(log_string)
10491101
await ctx.info(log_string)
1050-
system_id = await _resolve_system_id(system_identifier)
1102+
token = ctx.get_state('token')
1103+
system_id = await _resolve_system_id(system_identifier, token)
10511104
if not system_id:
10521105
return "" # Helper function already logged the reason for failure.
10531106

@@ -1067,6 +1120,7 @@ async def schedule_system_reboot(system_identifier: Union[str, int], ctx:Context
10671120
api_path=schedule_reboot_path,
10681121
json_body=payload,
10691122
error_context=f"scheduling reboot for system {system_identifier}",
1123+
token=token,
10701124
default_on_error=None # Helper returns None on error
10711125
)
10721126

@@ -1113,6 +1167,7 @@ async def list_all_scheduled_actions(ctx: Context) -> List[Dict[str, Any]]:
11131167
method="GET",
11141168
api_path=list_actions_path,
11151169
error_context="listing all scheduled actions",
1170+
token=ctx.get_state('token'),
11161171
default_on_error=[] # Return empty list on error
11171172
)
11181173

@@ -1169,6 +1224,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: bool = False) ->
11691224
api_path=cancel_actions_path,
11701225
json_body=payload,
11711226
error_context=f"canceling action {action_id}",
1227+
token=ctx.get_state('token'),
11721228
default_on_error=0 # API returns 1 on success, so 0 can signify an error or unexpected response from helper
11731229
)
11741230
if api_result == 1:
@@ -1178,7 +1234,7 @@ async def cancel_action(action_id: int, ctx: Context, confirm: bool = False) ->
11781234
return f"Failed to cancel action: {action_id}. The API did not return success (expected 1, got {api_result}). Check server logs for details."
11791235

11801236
@mcp.tool()
1181-
async def list_activation_keys() -> List[Dict[str, str]]:
1237+
async def list_activation_keys(ctx: Context) -> List[Dict[str, str]]:
11821238
"""
11831239
Fetches a list of activation keys from the Uyuni server.
11841240
@@ -1200,6 +1256,7 @@ async def list_activation_keys() -> List[Dict[str, str]]:
12001256
method="GET",
12011257
api_path=list_keys_path,
12021258
error_context="listing activation keys",
1259+
token=ctx.get_state('token'),
12031260
default_on_error=[]
12041261
)
12051262

@@ -1241,6 +1298,7 @@ async def get_unscheduled_errata(system_id: int, ctx: Context) -> List[Dict[str,
12411298
api_path=get_unscheduled_errata,
12421299
params=payload,
12431300
error_context=f"fetching unscheduled errata for system ID {system_id}",
1301+
token=ctx.get_state('token'),
12441302
default_on_error=None
12451303
)
12461304

@@ -1260,6 +1318,7 @@ def main_cli():
12601318
logger.info("Running Uyuni MCP server.")
12611319

12621320
if UYUNI_MCP_TRANSPORT == Transport.HTTP.value:
1321+
mcp.add_middleware(AuthTokenMiddleware())
12631322
mcp.run(transport="streamable-http", host=UYUNI_MCP_HOST, port=UYUNI_MCP_PORT)
12641323
elif UYUNI_MCP_TRANSPORT == Transport.STDIO.value:
12651324
mcp.run(transport="stdio")

0 commit comments

Comments
 (0)