From a33a4879769546a835e1dc5499c16548e22175d8 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Fri, 7 Nov 2025 18:04:54 -0400 Subject: [PATCH 1/2] feat: refactor and improve middleware pipeline --- .../id_generator/snow_flake_id_generator.rs | 2 +- .../src/hyper_servers/routes/sse_routes.rs | 6 +- .../rust-mcp-sdk/src/hyper_servers/server.rs | 21 +- crates/rust-mcp-sdk/src/lib.rs | 2 +- crates/rust-mcp-sdk/src/mcp_http.rs | 13 +- crates/rust-mcp-sdk/src/mcp_http/app_state.rs | 16 - .../{mcp_http_utils.rs => http_utils.rs} | 90 +-- .../src/mcp_http/mcp_http_handler.rs | 173 +++-- .../src/mcp_http/mcp_http_middleware.rs | 389 ----------- .../rust-mcp-sdk/src/mcp_http/middleware.rs | 486 ++++++++++++++ .../mcp_http/middleware/cors_middleware.rs | 614 ++++++++++++++++++ .../middleware/dns_rebind_protector.rs | 136 ++++ .../mcp_http/middleware/logging_middleware.rs | 36 + crates/rust-mcp-sdk/src/mcp_http/types.rs | 41 ++ .../rust-mcp-sdk/tests/common/test_client.rs | 5 +- 15 files changed, 1457 insertions(+), 573 deletions(-) rename crates/rust-mcp-sdk/src/mcp_http/{mcp_http_utils.rs => http_utils.rs} (91%) delete mode 100644 crates/rust-mcp-sdk/src/mcp_http/mcp_http_middleware.rs create mode 100644 crates/rust-mcp-sdk/src/mcp_http/middleware.rs create mode 100644 crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs create mode 100644 crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs create mode 100644 crates/rust-mcp-sdk/src/mcp_http/middleware/logging_middleware.rs create mode 100644 crates/rust-mcp-sdk/src/mcp_http/types.rs diff --git a/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs b/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs index 5ab2cb8..6f378b1 100644 --- a/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs +++ b/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs @@ -56,7 +56,7 @@ impl SnowflakeIdGenerator { .expect("invalid system time!") .as_millis() as u64; - now - *SHORTER_EPOCH + now.saturating_sub(*SHORTER_EPOCH) } fn next_id(&self) -> u64 { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index c85d81f..63e38a9 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -1,6 +1,7 @@ use crate::hyper_servers::error::TransportServerResult; use crate::mcp_http::{McpAppState, McpHttpHandler}; use axum::{extract::State, response::IntoResponse, routing::get, Extension, Router}; +use http::{HeaderMap, Method, Uri}; use std::sync::Arc; #[derive(Clone)] @@ -35,13 +36,16 @@ pub fn routes(sse_endpoint: &str, sse_message_endpoint: &str) -> Router` - The SSE response stream or an error pub async fn handle_sse( + headers: HeaderMap, + uri: Uri, Extension(sse_message_endpoint): Extension, Extension(http_handler): Extension>, State(state): State>, ) -> TransportServerResult { let SseMessageEndpoint(sse_message_endpoint) = sse_message_endpoint; + let request = McpHttpHandler::create_request(Method::GET, uri, headers, None); let generic_response = http_handler - .handle_sse_connection(state.clone(), Some(&sse_message_endpoint)) + .handle_sse_connection(request, state.clone(), Some(&sse_message_endpoint)) .await?; let (parts, body) = generic_response.into_parts(); let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index f3e0983..74a3d77 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -2,9 +2,10 @@ use crate::{ error::SdkResult, id_generator::{FastIdGenerator, UuidGenerator}, mcp_http::{ - utils::{ + http_utils::{ DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT, }, + middleware::dns_rebind_protector::DnsRebindProtector, McpAppState, McpHttpHandler, }, mcp_server::hyper_runtime::HyperRuntime, @@ -203,6 +204,11 @@ impl HyperServerOptions { .as_deref() .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT) } + + pub fn needs_dns_protection(&self) -> bool { + self.dns_rebinding_protection + && (self.allowed_hosts.is_some() || self.allowed_origins.is_some()) + } } /// Default implementation for HyperServerOptions @@ -270,13 +276,18 @@ impl HyperServer { ping_interval: server_options.ping_interval, transport_options: Arc::clone(&server_options.transport_options), enable_json_response: server_options.enable_json_response.unwrap_or(false), - allowed_hosts: server_options.allowed_hosts.take(), - allowed_origins: server_options.allowed_origins.take(), - dns_rebinding_protection: server_options.dns_rebinding_protection, event_store: server_options.event_store.as_ref().map(Arc::clone), }); - let http_handler = McpHttpHandler::new(); //TODO: add auth handlers + let mut http_handler = McpHttpHandler::new(); + + if server_options.needs_dns_protection() { + http_handler.add_middleware(DnsRebindProtector::new( + server_options.allowed_hosts.take(), + server_options.allowed_origins.take(), + )); + } + let app = app_routes(Arc::clone(&state), &server_options, http_handler); Self { app, diff --git a/crates/rust-mcp-sdk/src/lib.rs b/crates/rust-mcp-sdk/src/lib.rs index 9a9e0a9..0d668a0 100644 --- a/crates/rust-mcp-sdk/src/lib.rs +++ b/crates/rust-mcp-sdk/src/lib.rs @@ -3,7 +3,7 @@ pub mod error; mod hyper_servers; mod mcp_handlers; #[cfg(feature = "hyper-server")] -pub(crate) mod mcp_http; +pub mod mcp_http; mod mcp_macros; mod mcp_runtimes; mod mcp_traits; diff --git a/crates/rust-mcp-sdk/src/mcp_http.rs b/crates/rust-mcp-sdk/src/mcp_http.rs index 2e5d8fd..17c8236 100644 --- a/crates/rust-mcp-sdk/src/mcp_http.rs +++ b/crates/rust-mcp-sdk/src/mcp_http.rs @@ -1,13 +1,12 @@ mod app_state; +pub(crate) mod http_utils; mod mcp_http_handler; -pub(crate) mod mcp_http_utils; - -mod mcp_http_middleware; //TODO: +pub mod middleware; +mod types; pub use app_state::*; +pub use http_utils::*; pub use mcp_http_handler::*; -pub use mcp_http_middleware::Middleware; +pub use types::*; -pub(crate) mod utils { - pub use super::mcp_http_utils::*; -} +pub use middleware::Middleware; diff --git a/crates/rust-mcp-sdk/src/mcp_http/app_state.rs b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs index cada97d..b068612 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/app_state.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs @@ -19,23 +19,7 @@ pub struct McpAppState { pub ping_interval: Duration, pub transport_options: Arc, pub enable_json_response: bool, - /// List of allowed host header values for DNS rebinding protection. - /// If not specified, host validation is disabled. - pub allowed_hosts: Option>, - /// List of allowed origin header values for DNS rebinding protection. - /// If not specified, origin validation is disabled. - pub allowed_origins: Option>, - /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). - /// Default is false for backwards compatibility. - pub dns_rebinding_protection: bool, /// Event store for resumability support /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages pub event_store: Option>, } - -impl McpAppState { - pub fn needs_dns_protection(&self) -> bool { - self.dns_rebinding_protection - && (self.allowed_hosts.is_some() || self.allowed_origins.is_some()) - } -} diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/http_utils.rs similarity index 91% rename from crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs rename to crates/rust-mcp-sdk/src/mcp_http/http_utils.rs index 06020d1..29cff4f 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/http_utils.rs @@ -1,3 +1,4 @@ +use crate::mcp_http::types::GenericBody; use crate::schema::schema_utils::{ClientMessage, SdkError}; use crate::{ error::SdkResult, @@ -11,10 +12,10 @@ use crate::{ use axum::http::HeaderValue; use bytes::Bytes; use futures::stream; -use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE, HOST, ORIGIN}; +use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE}; use http_body::Frame; use http_body_util::StreamBody; -use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use http_body_util::{BodyExt, Full}; use hyper::{HeaderMap, StatusCode}; use rust_mcp_transport::{ EventId, McpDispatch, SessionId, SseEvent, SseTransport, StreamId, ID_SEPARATOR, @@ -32,8 +33,6 @@ pub(crate) const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages"; pub(crate) const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp"; const DUPLEX_BUFFER_SIZE: usize = 8192; -pub type GenericBody = BoxBody; - /// Creates an empty HTTP response body. /// /// This function constructs a `GenericBody` containing an empty `Bytes` buffer, @@ -45,6 +44,20 @@ pub fn empty_response() -> GenericBody { .boxed() } +pub fn build_response( + status_code: StatusCode, + payload: String, +) -> Result, TransportServerError> { + let body = Full::new(Bytes::from(payload)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + + http::Response::builder() + .status(status_code) + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) +} + /// Creates an initial SSE event that returns the messages endpoint /// /// Constructs an SSE event containing the messages endpoint URL with the session ID. @@ -251,7 +264,7 @@ fn is_result(json_str: &str) -> Result { } } -pub async fn create_standalone_stream( +pub(crate) async fn create_standalone_stream( session_id: SessionId, last_event_id: Option, state: Arc, @@ -287,7 +300,7 @@ pub async fn create_standalone_stream( Ok(response) } -pub async fn start_new_session( +pub(crate) async fn start_new_session( state: Arc, payload: &str, ) -> TransportServerResult> { @@ -421,7 +434,7 @@ async fn single_shot_stream( } } -pub async fn process_incoming_message_return( +pub(crate) async fn process_incoming_message_return( session_id: SessionId, state: Arc, payload: &str, @@ -446,7 +459,7 @@ pub async fn process_incoming_message_return( } } -pub async fn process_incoming_message( +pub(crate) async fn process_incoming_message( session_id: SessionId, state: Arc, payload: &str, @@ -499,11 +512,11 @@ pub async fn process_incoming_message( } } -pub fn is_empty_sse_message(sse_payload: &str) -> bool { +pub(crate) fn is_empty_sse_message(sse_payload: &str) -> bool { sse_payload.is_empty() || sse_payload.trim() == ":" } -pub async fn delete_session( +pub(crate) async fn delete_session( session_id: SessionId, state: Arc, ) -> TransportServerResult> { @@ -529,7 +542,7 @@ pub async fn delete_session( } } -pub fn acceptable_content_type(headers: &HeaderMap) -> bool { +pub(crate) fn acceptable_content_type(headers: &HeaderMap) -> bool { let accept_header = headers .get("content-type") .and_then(|val| val.to_str().ok()) @@ -539,7 +552,7 @@ pub fn acceptable_content_type(headers: &HeaderMap) -> bool { .any(|val| val.trim().starts_with("application/json")) } -pub fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()> { +pub(crate) fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()> { let protocol_version_header = headers .get(MCP_PROTOCOL_VERSION_HEADER) .and_then(|val| val.to_str().ok()) @@ -553,7 +566,7 @@ pub fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<() validate_mcp_protocol_version(protocol_version_header) } -pub fn accepts_event_stream(headers: &HeaderMap) -> bool { +pub(crate) fn accepts_event_stream(headers: &HeaderMap) -> bool { let accept_header = headers .get(ACCEPT) .and_then(|val| val.to_str().ok()) @@ -564,7 +577,7 @@ pub fn accepts_event_stream(headers: &HeaderMap) -> bool { .any(|val| val.trim().starts_with("text/event-stream")) } -pub fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool { +pub(crate) fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool { let accept_header = headers .get(ACCEPT) .and_then(|val| val.to_str().ok()) @@ -593,53 +606,6 @@ pub fn error_response( .map_err(|err| TransportServerError::HttpError(err.to_string())) } -// Protect against DNS rebinding attacks by validating Host and Origin headers. -pub(crate) async fn protect_dns_rebinding( - headers: &http::HeaderMap, - state: Arc, -) -> Result<(), SdkError> { - if !state.needs_dns_protection() { - // If protection is not needed, pass the request to the next handler - return Ok(()); - } - - if let Some(allowed_hosts) = state.allowed_hosts.as_ref() { - if !allowed_hosts.is_empty() { - let Some(host) = headers.get(HOST).and_then(|h| h.to_str().ok()) else { - return Err(SdkError::bad_request().with_message("Invalid Host header: [unknown] ")); - }; - - if !allowed_hosts - .iter() - .any(|allowed| allowed.eq_ignore_ascii_case(host)) - { - return Err(SdkError::bad_request() - .with_message(format!("Invalid Host header: \"{host}\" ").as_str())); - } - } - } - - if let Some(allowed_origins) = state.allowed_origins.as_ref() { - if !allowed_origins.is_empty() { - let Some(origin) = headers.get(ORIGIN).and_then(|h| h.to_str().ok()) else { - return Err( - SdkError::bad_request().with_message("Invalid Origin header: [unknown] ") - ); - }; - - if !allowed_origins - .iter() - .any(|allowed| allowed.eq_ignore_ascii_case(origin)) - { - return Err(SdkError::bad_request() - .with_message(format!("Invalid Origin header: \"{origin}\" ").as_str())); - } - } - } - - Ok(()) -} - /// Extracts the value of a query parameter from an HTTP request by key. /// /// This function parses the query string from the request URI and searches @@ -653,7 +619,7 @@ pub(crate) async fn protect_dns_rebinding( /// * `Some(String)` containing the value of the query parameter if found. /// * `None` if the query string is missing or the key is not present. /// -pub fn query_param(request: &http::Request<&str>, key: &str) -> Option { +pub(crate) fn query_param(request: &http::Request<&str>, key: &str) -> Option { request.uri().query().and_then(|query| { for pair in query.split('&') { let mut split = pair.splitn(2, '='); diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs index c60b4dc..cb17689 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -1,21 +1,20 @@ #[cfg(feature = "sse")] -use super::utils::handle_sse_connection; -use crate::mcp_http::mcp_http_middleware::MiddlewareChain; -use crate::mcp_http::utils::{ - accepts_event_stream, empty_response, error_response, query_param, +use super::http_utils::{ + accepts_event_stream, empty_response, error_response, handle_sse_connection, query_param, validate_mcp_protocol_version_header, }; -use crate::mcp_http::Middleware; +use super::types::GenericBody; +use crate::mcp_http::{middleware::compose, BoxFutureResponse, Middleware, RequestHandler}; use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID; use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::SdkError; use crate::{ error::McpSdkError, mcp_http::{ - utils::{ + http_utils::{ acceptable_content_type, create_standalone_stream, delete_session, - process_incoming_message, process_incoming_message_return, protect_dns_rebinding, - start_new_session, valid_streaming_http_accept_header, GenericBody, + process_incoming_message, process_incoming_message_return, start_new_session, + valid_streaming_http_accept_header, }, McpAppState, }, @@ -26,9 +25,31 @@ use http::{self, HeaderMap, Method, StatusCode, Uri}; use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; use std::sync::Arc; +/// A helper macro to wrap an async handler method into a `RequestHandler` +/// and compose it with middlewares. +/// +/// # Example +/// ```ignore +/// let handle = with_middlewares!(self, Self::internal_handle_sse_message); +/// handle +/// ``` +#[macro_export] +macro_rules! with_middlewares { + ($self:ident, $handler:path) => {{ + let final_handler: RequestHandler = std::sync::Arc::new( + move |req: http::Request<&str>, + state: std::sync::Arc| + -> BoxFutureResponse<'_> { + Box::pin(async move { $handler(req, state).await }) + }, + ); + $crate::mcp_http::middleware::compose(&$self.middlewares, final_handler) + }}; +} + #[derive(Clone)] pub struct McpHttpHandler { - middleware_chain: MiddlewareChain, + middlewares: Vec>, } impl Default for McpHttpHandler { @@ -40,12 +61,12 @@ impl Default for McpHttpHandler { impl McpHttpHandler { pub fn new() -> Self { McpHttpHandler { - middleware_chain: MiddlewareChain::new(), + middlewares: vec![], } } - - pub fn add_middleware(&mut self, middleware: M) { - self.middleware_chain.add_middleware(middleware); + pub fn add_middleware(&mut self, middleware: M) { + let m: Arc = Arc::new(middleware); + self.middlewares.push(m); } /// An `http::Request<&str>` initialized with the specified method, URI, headers, and body. @@ -87,10 +108,17 @@ impl McpHttpHandler { #[cfg(feature = "sse")] pub async fn handle_sse_connection( &self, + request: http::Request<&str>, state: Arc, sse_message_endpoint: Option<&str>, ) -> TransportServerResult> { - handle_sse_connection(state, sse_message_endpoint).await + let sse_endpoint = Arc::from(sse_message_endpoint.map(|s| s.to_string())); + let final_handler: RequestHandler = Arc::new(move |_req, state| { + let sse_endpoint = sse_endpoint.clone(); + Box::pin(async move { handle_sse_connection(state, sse_endpoint.as_deref()).await }) + }); + let handle = compose(&self.middlewares, final_handler); + handle(request, state).await } /// Handles incoming MCP messages from the client after an SSE connection is established. @@ -113,32 +141,14 @@ impl McpHttpHandler { /// - `SessionIdInvalid`: if the session ID does not map to a valid session in the session store. /// - `StreamIoError`: if an error occurs while writing to the stream. /// - `HttpError`: if constructing the HTTP response fails. + #[cfg(feature = "sse")] pub async fn handle_sse_message( &self, request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { - let session_id = - query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?; - - // transmit to the readable stream, that transport is reading from - let transmit = state.session_store.get(&session_id).await.ok_or( - TransportServerError::SessionIdInvalid(session_id.to_string()), - )?; - - let message = *request.body(); - transmit - .consume_payload_string(DEFAULT_STREAM_ID, message) - .await - .map_err(|err| { - tracing::trace!("{}", err); - TransportServerError::StreamIoError(err.to_string()) - })?; - - http::Response::builder() - .status(StatusCode::ACCEPTED) - .body(empty_response()) - .map_err(|err| TransportServerError::HttpError(err.to_string())) + let handle = with_middlewares!(self, Self::internal_handle_sse_message); + handle(request, state).await } /// Handles incoming MCP messages over the StreamableHTTP transport. @@ -167,25 +177,47 @@ impl McpHttpHandler { request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { - let request = self - .middleware_chain - .process_request(request) + let handle = with_middlewares!(self, Self::internal_handle_streamable_http); + handle(request, state).await + } + + async fn internal_handle_sse_message( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let session_id = + query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?; + + // transmit to the readable stream, that transport is reading from + let transmit = state.session_store.get(&session_id).await.ok_or( + TransportServerError::SessionIdInvalid(session_id.to_string()), + )?; + + let message = request.body(); + + transmit + .consume_payload_string(DEFAULT_STREAM_ID, message.as_ref()) .await - .map_err(|e| TransportServerError::HttpError(e.to_string()))?; + .map_err(|err| { + tracing::trace!("{}", err); + TransportServerError::StreamIoError(err.to_string()) + })?; - // Enforces DNS rebinding protection if required by state. - // If protection fails, respond with HTTP 403 Forbidden. - if state.needs_dns_protection() { - if let Err(error) = protect_dns_rebinding(request.headers(), state.clone()).await { - return error_response(StatusCode::FORBIDDEN, error); - } - } + http::Response::builder() + .status(StatusCode::ACCEPTED) + .body(empty_response()) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + async fn internal_handle_streamable_http( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { let method = request.method(); let response = match method { - &http::Method::GET => return self.handle_http_get(request, state).await, - &http::Method::POST => return self.handle_http_post(request, state).await, - &http::Method::DELETE => return self.handle_http_delete(request, state).await, + &http::Method::GET => return Self::handle_http_get(request, state).await, + &http::Method::POST => return Self::handle_http_post(request, state).await, + &http::Method::DELETE => return Self::handle_http_delete(request, state).await, other => { let error = SdkError::bad_request().with_message(&format!( "'{other}' is not a valid HTTP method for StreamableHTTP transport." @@ -194,24 +226,14 @@ impl McpHttpHandler { } }; - self.middleware_chain - .process_response(response?) - .await - .map_err(|e| TransportServerError::HttpError(e.to_string())) + response } /// Processes POST requests for the Streamable HTTP Protocol async fn handle_http_post( - &self, request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { - let request = self - .middleware_chain - .process_request(request) - .await - .map_err(|e| TransportServerError::HttpError(e.to_string()))?; - let headers = request.headers(); if !valid_streaming_http_accept_header(headers) { @@ -237,7 +259,7 @@ impl McpHttpHandler { .and_then(|value| value.to_str().ok()) .map(|s| s.to_string()); - let payload = *request.body(); + let payload = request.body(); let response = match session_id { // has session-id => write to the existing stream @@ -260,24 +282,14 @@ impl McpHttpHandler { }, }; - self.middleware_chain - .process_response(response?) - .await - .map_err(|e| TransportServerError::HttpError(e.to_string())) + response } /// Processes GET requests for the Streamable HTTP Protocol async fn handle_http_get( - &self, request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { - let request = self - .middleware_chain - .process_request(request) - .await - .map_err(|e| TransportServerError::HttpError(e.to_string()))?; - let headers = request.headers(); if !accepts_event_stream(headers) { @@ -313,24 +325,14 @@ impl McpHttpHandler { } }; - self.middleware_chain - .process_response(response?) - .await - .map_err(|e| TransportServerError::HttpError(e.to_string())) + response } /// Processes DELETE requests for the Streamable HTTP Protocol async fn handle_http_delete( - &self, request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { - let request = self - .middleware_chain - .process_request(request) - .await - .map_err(|e| TransportServerError::HttpError(e.to_string()))?; - let headers = request.headers(); if let Err(parse_error) = validate_mcp_protocol_version_header(headers) { @@ -352,9 +354,6 @@ impl McpHttpHandler { } }; - self.middleware_chain - .process_response(response?) - .await - .map_err(|e| TransportServerError::HttpError(e.to_string())) + response } } diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_middleware.rs deleted file mode 100644 index 22027d7..0000000 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_middleware.rs +++ /dev/null @@ -1,389 +0,0 @@ -use crate::mcp_http::utils::GenericBody; -use crate::mcp_server::error::TransportServerResult; -use http::{Request, Response}; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; - -/// Defines a middleware trait for processing HTTP requests and responses. -/// -/// Implementors of this trait can define custom logic to modify or inspect HTTP -/// requests before they reach the handler and HTTP responses before they are sent -/// back to the client. Middleware must be thread-safe (`Send + Sync`) and have a -/// static lifetime. -pub trait Middleware: Send + Sync + 'static { - /// Processes an incoming HTTP request. - /// - /// This method takes a request, applies middleware-specific logic, and returns - /// a future that resolves to a `TransportServerResult` containing the modified - /// request or an error. - /// - /// # Arguments - /// * `request` - The incoming HTTP request with a string body reference. - /// - /// # Returns - /// A pinned boxed future resolving to a `TransportServerResult` containing the - /// processed request. - fn process_request<'a, 'b>( - &'a self, - request: Request<&'b str>, - ) -> Pin>> + Send + 'a>> - where - 'b: 'a; // Ensure the request's lifetime outlives the future - - /// Processes an outgoing HTTP response. - /// - /// This method takes a response, applies middleware-specific logic, and returns - /// a future that resolves to a `TransportServerResult` containing the modified - /// response or an error. - /// - /// # Arguments - /// * `response` - The HTTP response with a `GenericBody`. - /// - /// # Returns - /// A pinned boxed future resolving to a `TransportServerResult` containing the - /// processed response. - fn process_response<'a, 'b>( - &'a self, - response: Response, - ) -> Pin>> + Send + 'a>> - where - 'b: 'a; // Optional, included for consistency -} - -/// A chain of middleware to process HTTP requests and responses sequentially. -/// -/// `MiddlewareChain` allows multiple middleware instances to be registered and -/// executed in order for requests (forward order) and responses (reverse order). -#[derive(Clone)] -pub struct MiddlewareChain { - middlewares: Vec>, -} - -impl MiddlewareChain { - /// Creates a new, empty middleware chain. - /// - /// # Returns - /// A new `MiddlewareChain` instance with no middleware registered. - pub fn new() -> Self { - MiddlewareChain { - middlewares: Vec::new(), - } - } - - /// Adds a middleware to the chain. - /// - /// The middleware is wrapped in an `Arc` to ensure thread-safety and shared - /// ownership. Middleware will be executed in the order they are added for - /// requests and in reverse order for responses. - /// - /// # Arguments - /// * `middleware` - The middleware to add to the chain. - pub fn add_middleware(&mut self, middleware: M) { - self.middlewares.push(Arc::new(middleware)); - } - - /// Processes an HTTP request through all registered middleware. - /// - /// Each middleware's `process_request` method is called in the order they - /// were added. If any middleware returns an error, processing stops and the - /// error is returned. - /// - /// # Arguments - /// * `request` - The HTTP request to process. - /// - /// # Returns - /// A `TransportServerResult` containing the processed request or an error. - pub async fn process_request<'a>( - &self, - request: http::Request<&'a str>, - ) -> TransportServerResult> { - let mut request = request; - for middleware in &self.middlewares { - request = middleware.process_request(request).await?; - } - Ok(request) - } - - /// Processes an HTTP response through all registered middleware. - /// - /// Each middleware's `process_response` method is called in the reverse order - /// of their addition. If any middleware returns an error, processing stops and - /// the error is returned. - /// - /// # Arguments - /// * `response` - The HTTP response to process. - /// - /// # Returns - /// A `TransportServerResult` containing the processed response or an error. - pub async fn process_response( - &self, - response: http::Response, - ) -> TransportServerResult> { - let mut response = response; - for middleware in self.middlewares.iter().rev() { - response = middleware.process_response(response).await?; - } - Ok(response) - } -} - -// Sample Middleware -pub struct LoggingMiddleware; - -impl Middleware for LoggingMiddleware { - fn process_request<'a, 'b>( - &'a self, - request: http::Request<&'b str>, - ) -> Pin>> + Send + 'a>> - where - 'b: 'a, - { - Box::pin(async move { - tracing::info!("Request: {} {}", request.method(), request.uri()); - Ok(request) - }) - } - - fn process_response<'a, 'b>( - &'a self, - response: http::Response, - ) -> Pin>> + Send + 'a>> - where - 'b: 'a, - { - Box::pin(async move { - tracing::info!("Response: {}", response.status()); - Ok(response) - }) - } -} - -#[cfg(test)] -mod tests { - use crate::{mcp_http::utils::empty_response, mcp_server::error::TransportServerError}; - - use super::*; - use async_trait::async_trait; - use bytes::Bytes; - use http::{Request, Response}; - use http_body_util::{BodyExt, Full}; - use std::sync::Mutex; - use thiserror::Error; - - /// Custom error type for test middleware. - #[derive(Error, Debug)] - enum TestMiddlewareError { - #[error("Request processing failed: {0}")] - RequestError(String), - #[error("Response processing failed: {0}")] - ResponseError(String), - } - - /// A test middleware that records its interactions with requests and responses. - struct TestMiddleware { - /// Tracks request calls with their input bodies. - request_calls: Arc>>, - /// Tracks response calls with their status codes. - response_calls: Arc>>, - /// Optional error to simulate failure in request processing. - request_error: Option, - /// Optional error to simulate failure in response processing. - response_error: Option, - } - - impl TestMiddleware { - fn new() -> Self { - TestMiddleware { - request_calls: Arc::new(Mutex::new(Vec::new())), - response_calls: Arc::new(Mutex::new(Vec::new())), - request_error: None, - response_error: None, - } - } - - fn with_errors(request_error: Option, response_error: Option) -> Self { - TestMiddleware { - request_calls: Arc::new(Mutex::new(Vec::new())), - response_calls: Arc::new(Mutex::new(Vec::new())), - request_error, - response_error, - } - } - } - - #[async_trait] - impl Middleware for TestMiddleware { - fn process_request<'a, 'b>( - &'a self, - request: Request<&'b str>, - ) -> Pin>> + Send + 'a>> - where - 'b: 'a, - { - Box::pin(async move { - if let Some(err) = &self.request_error { - return Err(TransportServerError::HttpError(err.to_string())); - } - self.request_calls - .lock() - .unwrap() - .push(request.body().to_string()); - Ok(request) - }) - } - - fn process_response<'a, 'b>( - &'a self, - response: Response, - ) -> Pin>> + Send + 'a>> - where - 'b: 'a, - { - Box::pin(async move { - if let Some(err) = &self.response_error { - return Err(TransportServerError::HttpError(err.to_string())); - } - self.response_calls - .lock() - .unwrap() - .push(response.status().as_u16()); - Ok(response) - }) - } - } - - #[tokio::test] - async fn test_empty_middleware_chain() { - let chain = MiddlewareChain::new(); - let request = Request::builder().body("test").unwrap(); - - let response = Response::builder() - .status(200) - .body(empty_response()) - .unwrap(); - - let result_request = chain.process_request(request).await.unwrap(); - let result_response = chain.process_response(response).await.unwrap(); - - assert_eq!(result_request.body().to_ascii_lowercase(), "test"); - assert_eq!(result_response.status(), 200); - } - - #[tokio::test] - async fn test_single_middleware() { - let mut chain = MiddlewareChain::new(); - let middleware = TestMiddleware::new(); - let request_calls = middleware.request_calls.clone(); - let response_calls = middleware.response_calls.clone(); - - chain.add_middleware(middleware); - - let request = Request::builder().body("test").unwrap(); - let response = Response::builder() - .status(200) - .body(empty_response()) - .unwrap(); - - let result_request = chain.process_request(request).await.unwrap(); - let result_response = chain.process_response(response).await.unwrap(); - - assert_eq!(result_request.body().to_ascii_lowercase(), "test"); - assert_eq!(result_response.status(), 200); - assert_eq!(request_calls.lock().unwrap().as_slice(), &["test"]); - assert_eq!(response_calls.lock().unwrap().as_slice(), &[200]); - } - - #[tokio::test] - async fn test_multiple_middlewares_request_order() { - let mut chain = MiddlewareChain::new(); - let middleware1 = TestMiddleware::new(); - let middleware2 = TestMiddleware::new(); - let request_calls1 = middleware1.request_calls.clone(); - let request_calls2 = middleware2.request_calls.clone(); - - chain.add_middleware(middleware1); - chain.add_middleware(middleware2); - - let request = Request::builder().body("test").unwrap(); - - let result = chain.process_request(request).await.unwrap(); - assert_eq!(result.body().to_ascii_lowercase(), "test"); - - // Check order of execution - assert_eq!(request_calls1.lock().unwrap().as_slice(), &["test"]); - assert_eq!(request_calls2.lock().unwrap().as_slice(), &["test"]); - } - - #[tokio::test] - async fn test_multiple_middlewares_response_reverse_order() { - let mut chain = MiddlewareChain::new(); - let middleware1 = TestMiddleware::new(); - let middleware2 = TestMiddleware::new(); - let response_calls1 = middleware1.response_calls.clone(); - let response_calls2 = middleware2.response_calls.clone(); - - chain.add_middleware(middleware1); - chain.add_middleware(middleware2); - - let response = Response::builder() - .status(200) - .body(empty_response()) - .unwrap(); - - let result = chain.process_response(response).await.unwrap(); - assert_eq!(result.status(), 200); - - // Check reverse order of execution - assert_eq!(response_calls2.lock().unwrap().as_slice(), &[200]); - assert_eq!(response_calls1.lock().unwrap().as_slice(), &[200]); - } - - #[tokio::test] - async fn test_middleware_request_error() { - let mut chain = MiddlewareChain::new(); - let middleware = TestMiddleware::with_errors(Some("request error".to_string()), None); - chain.add_middleware(middleware); - - let request = Request::builder().body("test").unwrap(); - - let result = chain.process_request(request).await; - assert!(result.is_err()); - assert_eq!(result.unwrap_err().to_string(), "request error"); - } - - #[tokio::test] - async fn test_middleware_response_error() { - let mut chain = MiddlewareChain::new(); - let middleware = TestMiddleware::with_errors(None, Some("response error".to_string())); - chain.add_middleware(middleware); - - let response = Response::builder() - .status(200) - .body(empty_response()) - .unwrap(); - - let result = chain.process_response(response).await; - assert!(result.is_err()); - assert_eq!(result.unwrap_err().to_string(), "response error"); - } - - #[tokio::test] - async fn test_middleware_chain_clone() { - let mut chain = MiddlewareChain::new(); - let middleware = TestMiddleware::new(); - let request_calls = middleware.request_calls.clone(); - - chain.add_middleware(middleware); - let chain_clone = chain.clone(); - - let request = Request::builder().body("test").unwrap(); - - // Process on original and clone - chain.process_request(request.clone()).await.unwrap(); - chain_clone.process_request(request).await.unwrap(); - - // Both should have processed the request - assert_eq!(request_calls.lock().unwrap().as_slice(), &["test", "test"]); - } -} diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware.rs new file mode 100644 index 0000000..bd7d2ad --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware.rs @@ -0,0 +1,486 @@ +pub mod cors_middleware; +pub(crate) mod dns_rebind_protector; +pub mod logging_middleware; + +use super::types::{BoxFutureResponse, GenericBody, RequestHandler}; +use crate::mcp_http::{McpAppState, MiddlewareNext}; +use crate::mcp_server::error::TransportServerResult; +use http::{Request, Response}; +use std::sync::Arc; + +#[async_trait::async_trait] +pub trait Middleware: Send + Sync + 'static { + async fn handle<'req>( + &self, + req: Request<&'req str>, + state: Arc, + next: MiddlewareNext<'req>, + ) -> TransportServerResult>; +} + +/// Build the final handler by folding the middlewares **in reverse**. +pub fn compose( + middlewares: &Vec>, + final_handler: RequestHandler, +) -> RequestHandler { + let mut handler = final_handler; + + for mw in middlewares.iter().rev() { + let mw = mw.clone(); + let next = handler.clone(); + + handler = Arc::new(move |req: Request<&str>, state: Arc| { + let mw = mw.clone(); + let next = next.clone(); + + Box::pin(async move { mw.handle(req, state, next).await }) as BoxFutureResponse<'_> + }); + } + + handler +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities}; + use crate::{ + id_generator::{FastIdGenerator, UuidGenerator}, + mcp_http::{ + middleware::{cors_middleware::CorsMiddleware, logging_middleware::LoggingMiddleware}, + types::GenericBodyExt, + }, + mcp_server::{error::TransportServerError, ServerHandler, ToMcpServerHandler}, + session_store::InMemorySessionStore, + }; + use async_trait::async_trait; + use http::{HeaderName, Request, Response, StatusCode}; + use http_body_util::BodyExt; + use std::{ + sync::{Arc, Mutex}, + time::Duration, + }; + struct TestHandler; + impl ServerHandler for TestHandler {} + + fn app_state() -> Arc { + let handler = TestHandler {}; + + Arc::new(McpAppState { + session_store: Arc::new(InMemorySessionStore::new()), + id_generator: Arc::new(UuidGenerator {}), + stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))), + server_details: Arc::new(InitializeResult { + capabilities: ServerCapabilities { + ..Default::default() + }, + instructions: None, + meta: None, + protocol_version: ProtocolVersion::V2025_06_18.to_string(), + server_info: Implementation { + name: "server".to_string(), + title: None, + version: "0.1.0".to_string(), + }, + }), + handler: handler.to_mcp_server_handler(), + ping_interval: Duration::from_secs(15), + transport_options: Arc::new(rust_mcp_transport::TransportOptions::default()), + enable_json_response: false, + event_store: None, + }) + } + + /// Helper: Convert response to string + async fn response_string(res: Response) -> String { + let (_parts, body) = res.into_parts(); + let bytes = body.collect().await.unwrap().to_bytes(); + String::from_utf8(bytes.to_vec()).unwrap() + } + + /// Test Middleware – records everything, modifies req/res, supports early return + #[derive(Clone)] + struct TestMiddleware { + id: usize, + request_calls: Arc)>>>, + response_calls: Arc)>>>, + add_req_header: Option<(String, String)>, + add_res_header: Option<(String, String)>, + + // ---- early return (clone-able) ---- + early_return_status: Option, + early_return_body: Option, + + fail_request: bool, + fail_response: bool, + } + + impl TestMiddleware { + fn new(id: usize) -> Self { + Self { + id, + request_calls: Arc::new(Mutex::new(Vec::new())), + response_calls: Arc::new(Mutex::new(Vec::new())), + add_req_header: None, + add_res_header: None, + early_return_status: None, + early_return_body: None, + fail_request: false, + fail_response: false, + } + } + + fn with_req_header(mut self, name: &str, value: &str) -> Self { + self.add_req_header = Some((name.to_string(), value.to_string())); + self + } + + fn with_res_header(mut self, name: &str, value: &str) -> Self { + self.add_res_header = Some((name.to_string(), value.to_string())); + self + } + + fn early_return_200(mut self) -> Self { + self.early_return_status = Some(StatusCode::OK); + self.early_return_body = Some(format!("early-{}", self.id)); + self + } + + #[allow(unused)] + fn early_return(mut self, status: StatusCode, body: impl Into) -> Self { + self.early_return_status = Some(status); + self.early_return_body = Some(body.into()); + self + } + + fn fail_request(mut self) -> Self { + self.fail_request = true; + self + } + + fn fail_response(mut self) -> Self { + self.fail_response = true; + self + } + } + + #[async_trait] + impl Middleware for TestMiddleware { + async fn handle<'req>( + &self, + mut req: Request<&'req str>, + state: Arc, + next: MiddlewareNext<'req>, + ) -> TransportServerResult> { + // ---- record request ------------------------------------------------- + let headers = req + .headers() + .iter() + .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + self.request_calls + .lock() + .unwrap() + .push((self.id, req.body().to_string(), headers)); + + if self.fail_request { + return Err(TransportServerError::HttpError(format!( + "middleware {} failed request", + self.id + ))); + } + + // ---- add request header -------------------------------------------- + if let Some((name, value)) = &self.add_req_header { + req.headers_mut().insert( + HeaderName::from_bytes(name.as_bytes()).unwrap(), + value.parse().unwrap(), + ); + } + + // ---- early return --------------------------------------------------- + if let (Some(status), Some(body)) = (&self.early_return_status, &self.early_return_body) + { + return Ok(Response::builder() + .status(*status) + .body(GenericBody::from_string(body.to_string())) + .unwrap()); + } + + // ---- call next ------------------------------------------------------ + let mut res = next(req, state).await?; + // ---- add response header -------------------------------------------- + if let Some((name, value)) = &self.add_res_header { + res.headers_mut().insert( + HeaderName::from_bytes(name.as_bytes()).unwrap(), + value.parse().unwrap(), + ); + } + + if self.fail_response { + return Err(TransportServerError::HttpError(format!( + "middleware {} failed response", + self.id + ))); + } + + // ---- record response ------------------------------------------------ + let headers = res + .headers() + .iter() + .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + + self.response_calls + .lock() + .unwrap() + .push((self.id, res.status().as_u16(), headers)); + + Ok(res) + } + } + + /// Final handler – returns a fixed response + fn final_handler(body: &'static str, status: StatusCode) -> RequestHandler { + Arc::new(move |_req, _| { + let resp = Response::builder() + .status(status) + .body(GenericBody::from_string(body.to_string())) + .unwrap(); + Box::pin(async move { Ok(resp) }) + }) + } + + // TESTS + + /// Middleware order (request → final → response) + #[tokio::test] + async fn test_middleware_order() { + let mw1 = Arc::new(TestMiddleware::new(1)); + let mw2 = Arc::new(TestMiddleware::new(2)); + let mw3 = Arc::new(TestMiddleware::new(3)); + + let middlewares: Vec> = vec![mw1.clone(), mw2.clone(), mw3.clone()]; + let handler = final_handler("final", StatusCode::OK); + let composed = compose(&middlewares, handler); + + let req = Request::builder().body("").unwrap(); + let _ = composed(req, app_state()).await.unwrap(); + + // request order: 3 → 2 → 1 → final + let rc3 = mw3.request_calls.lock().unwrap(); + let rc2 = mw2.request_calls.lock().unwrap(); + let rc1 = mw1.request_calls.lock().unwrap(); + assert_eq!(rc3[0].0, 3); + assert_eq!(rc2[0].0, 2); + assert_eq!(rc1[0].0, 1); + + // response order: 1 → 2 → 3 + let pc1 = mw1.response_calls.lock().unwrap(); + let pc2 = mw2.response_calls.lock().unwrap(); + let pc3 = mw3.response_calls.lock().unwrap(); + assert_eq!(pc1[0].0, 1); + assert_eq!(pc2[0].0, 2); + assert_eq!(pc3[0].0, 3); + } + + /// Request header added by earlier middleware is visible later + #[tokio::test] + async fn test_request_header_propagation() { + let mw1 = Arc::new(TestMiddleware::new(1).with_req_header("x-mid", "1")); + let mw2 = Arc::new(TestMiddleware::new(2)); + + let middlewares: Vec> = vec![mw1.clone(), mw2.clone()]; + let handler = final_handler("ok", StatusCode::OK); + let composed = compose(&middlewares, handler); + + let req = Request::builder().body("").unwrap(); + let _ = composed(req, app_state()).await.unwrap(); + + let rc = mw2.request_calls.lock().unwrap(); + let hdr = rc[0].2.iter().find(|(k, _)| k == "x-mid").map(|(_, v)| v); + assert_eq!(hdr, Some(&"1".to_string())); + } + + /// Response header added by later middleware is visible earlier + #[tokio::test] + async fn test_response_header_propagation() { + let mw1 = Arc::new(TestMiddleware::new(1)); + let mw2 = Arc::new(TestMiddleware::new(2).with_res_header("x-mid", "1")); + + let middlewares: Vec> = vec![mw1.clone(), mw2.clone()]; + let handler = final_handler("ok", StatusCode::OK); + let composed = compose(&middlewares, handler); + + let req = Request::builder().body("").unwrap(); + let res = composed(req, app_state()).await.unwrap(); + + let pc1 = mw1.response_calls.lock().unwrap(); + + let hdr = pc1[0].2.iter().find(|(k, _)| k == "x-mid").map(|(_, v)| v); + assert_eq!(hdr, Some(&"1".to_string())); + + assert_eq!(res.headers().get("x-mid").unwrap().to_str().unwrap(), "1"); + } + + /// Early return stops the chain + #[tokio::test] + async fn test_early_return_stops_chain() { + let mw1 = Arc::new(TestMiddleware::new(1).early_return_200()); + let mw2 = Arc::new(TestMiddleware::new(2)); + let mw3 = Arc::new(TestMiddleware::new(3)); + + let middlewares: Vec> = vec![mw1.clone(), mw2.clone(), mw3.clone()]; + let handler = final_handler("should-not-see", StatusCode::OK); + let composed = compose(&middlewares, handler); + + let req = Request::builder().body("").unwrap(); + let res = composed(req, app_state()).await.unwrap(); + + assert_eq!(response_string(res).await, "early-1"); + + assert!(mw2.request_calls.lock().unwrap().is_empty()); + assert!(mw3.request_calls.lock().unwrap().is_empty()); + } + + /// Request error stops response processing + #[tokio::test] + async fn test_request_error_stops_response_chain() { + let mw1 = Arc::new(TestMiddleware::new(1).fail_request()); + let mw2 = Arc::new(TestMiddleware::new(2)); + + let middlewares: Vec> = vec![mw1.clone(), mw2.clone()]; + let handler = final_handler("ok", StatusCode::OK); + let composed = compose(&middlewares, handler); + + let req = Request::builder().body("").unwrap(); + let result = composed(req, app_state()).await; + + assert!(result.is_err()); + assert!(mw2.request_calls.lock().unwrap().is_empty()); + assert!(mw2.response_calls.lock().unwrap().is_empty()); + } + + ///Response error after next() + #[tokio::test] + async fn test_response_error_after_next() { + let mw1 = Arc::new(TestMiddleware::new(1).fail_response()); + let mw2 = Arc::new(TestMiddleware::new(2)); + + let middlewares: Vec> = vec![mw1.clone(), mw2.clone()]; + let handler = final_handler("ok", StatusCode::OK); + let composed = compose(&middlewares, handler); + + let req = Request::builder().body("").unwrap(); + let result = composed(req, app_state()).await; + + assert!(result.is_err()); + assert!(!mw1.request_calls.lock().unwrap().is_empty()); + // response_calls is empty because we error before recording + assert!(mw1.response_calls.lock().unwrap().is_empty()); + } + + /// No middleware → direct handler + #[tokio::test] + async fn test_no_middleware() { + let middlewares: Vec> = vec![]; + let handler = final_handler("direct", StatusCode::IM_A_TEAPOT); + let composed = compose(&middlewares, handler); + + let req = Request::builder().body("").unwrap(); + let res = composed(req, app_state()).await.unwrap(); + + assert_eq!(res.status(), StatusCode::IM_A_TEAPOT); + assert_eq!(response_string(res).await, "direct"); + } + + /// Multiple headers accumulate correctly + #[tokio::test] + async fn test_multiple_headers_accumulate() { + let mw1 = Arc::new( + TestMiddleware::new(1) + .with_req_header("x-a", "1") + .with_res_header("x-b", "1"), + ); + let mw2 = Arc::new( + TestMiddleware::new(2) + .with_req_header("x-c", "2") + .with_res_header("x-d", "2"), + ); + + let mw3 = Arc::new(TestMiddleware::new(3)); + + let middlewares: Vec> = vec![mw1.clone(), mw2.clone(), mw3.clone()]; + let handler = final_handler("ok", StatusCode::OK); + let composed = compose(&middlewares, handler); + + let req = Request::builder().body("").unwrap(); + let res = composed(req, app_state()).await.unwrap(); + + let h = res.headers(); + assert_eq!(h["x-b"], "1"); + assert_eq!(h["x-d"], "2"); + + // Request headers are NOT in response + assert!(!h.contains_key("x-a")); + assert!(!h.contains_key("x-c")); + + // But they were added to the request + let req_calls_mw3 = mw3.request_calls.lock().unwrap(); + let req_headers = &req_calls_mw3[0].2; + + assert!(req_headers.iter().any(|(k, v)| k == "x-a" && v == "1")); + assert!(req_headers.iter().any(|(k, v)| k == "x-c" && v == "2")); + } + + /// Request body is passed unchanged + #[tokio::test] + async fn test_request_body_unchanged() { + let mw1 = Arc::new(TestMiddleware::new(1)); + let mw2 = Arc::new(TestMiddleware::new(2)); + + let middlewares: Vec> = vec![mw1.clone(), mw2.clone()]; + let handler: RequestHandler = Arc::new(move |req, _| { + let body = req.into_body().to_string(); + Box::pin(async move { + Ok(Response::builder() + .body(GenericBody::from_string(format!("echo:{body}"))) + .unwrap()) + }) + }); + let composed = compose(&middlewares, handler); + + let req = Request::builder().body("secret-payload").unwrap(); + let res = composed(req, app_state()).await.unwrap(); + assert_eq!(response_string(res).await, "echo:secret-payload"); + } + + // Integration: CORS + Logger (order matters) + #[tokio::test] + async fn test_cors_and_logger_integration() { + let cors = Arc::new(CorsMiddleware::permissive()); + let logger = Arc::new(LoggingMiddleware); + + // Order in the vector is the order they are *registered*. + // compose folds in reverse, so logger runs *first* (request) and *last* (response). + let middlewares: Vec> = vec![cors.clone(), logger.clone()]; + let handler = final_handler("ok", StatusCode::OK); + let composed = compose(&middlewares, handler); + + let req = Request::builder() + .method(http::Method::GET) + .uri("/api") + .header("Origin", "https://example.com") + .body("") + .unwrap(); + + let res = composed(req, app_state()).await.unwrap(); + + // CORS headers added by CorsMiddleware + assert_eq!( + res.headers()["access-control-allow-origin"], + "https://example.com" + ); + assert_eq!(res.headers()["access-control-allow-credentials"], "true"); + } +} diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs new file mode 100644 index 0000000..2f2608d --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs @@ -0,0 +1,614 @@ +//! # CORS Middleware +//! +//! A configurable CORS middleware that follows the +//! [WHATWG CORS specification](https://fetch.spec.whatwg.org/#http-cors-protocol). +//! +//! ## Features +//! - Full preflight (`OPTIONS`) handling +//! - Configurable origins: `*`, explicit list, or echo +//! - Credential support (with correct `Access-Control-Allow-Origin` behavior) +//! - Header/method validation +//! - `Access-Control-Expose-Headers` support + +use crate::{ + mcp_http::{ + http_utils::{build_response, empty_response}, + types::GenericBody, + McpAppState, Middleware, MiddlewareNext, + }, + mcp_server::error::TransportServerResult, +}; +use http::{ + header::{ + self, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS, + ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, + ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, + ACCESS_CONTROL_REQUEST_METHOD, + }, + Method, Request, Response, StatusCode, +}; +use std::{collections::HashSet, sync::Arc}; + +/// Configuration for CORS behavior. +/// +/// See [MDN CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) for details. +#[derive(Clone)] +pub struct CorsConfig { + /// Which origins are allowed to make requests. + pub allow_origins: AllowOrigins, + + /// HTTP methods allowed in preflight and actual requests. + pub allow_methods: Vec, + + /// Request headers allowed in preflight. + pub allow_headers: Vec, + + /// Whether to allow credentials (cookies, HTTP auth, etc). + /// + /// **Important**: When `true`, `allow_origins` cannot be `Any` — browsers reject `*`. + pub allow_credentials: bool, + + /// How long (in seconds) the preflight response can be cached. + pub max_age: Option, + + /// Headers that should be exposed to the client JavaScript. + pub expose_headers: Vec, +} + +impl Default for CorsConfig { + fn default() -> Self { + Self { + allow_origins: AllowOrigins::Any, + allow_methods: vec![Method::GET, Method::POST, Method::OPTIONS], + allow_headers: vec![header::CONTENT_TYPE, header::AUTHORIZATION], + allow_credentials: false, + max_age: Some(86_400), // 24 hours + expose_headers: vec![], + } + } +} + +/// Policy for allowed origins. +#[derive(Clone, Debug)] +pub enum AllowOrigins { + /// Allow any origin (`*`). + /// + /// **Cannot** be used with `allow_credentials = true`. + Any, + + /// Allow only specific origins. + List(HashSet), + + /// Echo the `Origin` header back (required when `allow_credentials = true`). + Echo, +} + +/// CORS middleware implementing the `Middleware` trait. +/// +/// Handles both **preflight** (`OPTIONS`) and **actual** requests, +/// adding appropriate CORS headers and rejecting invalid origins/methods/headers. +#[derive(Clone)] +pub struct CorsMiddleware { + config: Arc, +} + +impl CorsMiddleware { + /// Create a new CORS middleware with custom config. + pub fn new(config: CorsConfig) -> Self { + Self { + config: Arc::new(config), + } + } + + /// Create a permissive CORS config — useful for public APIs or local dev. + /// + /// Allows all common methods, credentials, and common headers. + pub fn permissive() -> Self { + Self::new(CorsConfig { + allow_origins: AllowOrigins::Any, + allow_methods: vec![ + Method::GET, + Method::POST, + Method::PUT, + Method::DELETE, + Method::PATCH, + Method::OPTIONS, + Method::HEAD, + ], + allow_headers: vec![ + header::CONTENT_TYPE, + header::AUTHORIZATION, + header::ACCEPT, + header::ORIGIN, + ], + allow_credentials: true, + max_age: Some(86_400), + expose_headers: vec![], + }) + } + + // Internal: resolve allowed origin header value + fn resolve_allowed_origin(&self, origin: &str) -> Option { + match &self.config.allow_origins { + AllowOrigins::Any => { + // Only return "*" if credentials are not allowed + if self.config.allow_credentials { + // rule MDN , RFC 6454 + // If Access-Control-Allow-Credentials: true is set, + // then Access-Control-Allow-Origin CANNOT be *. + // It MUST be the exact origin (e.g., https://example.com). + Some(origin.to_string()) + } else { + Some("*".to_string()) + } + } + AllowOrigins::List(allowed) => { + if allowed.contains(origin) { + Some(origin.to_string()) + } else { + None + } + } + AllowOrigins::Echo => Some(origin.to_string()), + } + } + + // Build preflight response (204 No Content) + fn preflight_response(&self, origin: &str) -> Response { + let allowed_origin = self.resolve_allowed_origin(origin); + let mut resp = Response::builder() + .status(StatusCode::NO_CONTENT) + .body(empty_response()) + .expect("preflight response is static"); + + let headers = resp.headers_mut(); + + if let Some(origin) = allowed_origin { + headers.insert( + ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_str(&origin).expect("origin is validated"), + ); + } + + if self.config.allow_credentials { + headers.insert( + ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + + if let Some(age) = self.config.max_age { + headers.insert( + ACCESS_CONTROL_MAX_AGE, + HeaderValue::from_str(&age.to_string()).expect("u32 is valid"), + ); + } + + let methods = self + .config + .allow_methods + .iter() + .map(|m| m.as_str()) + .collect::>() + .join(", "); + headers.insert( + ACCESS_CONTROL_ALLOW_METHODS, + HeaderValue::from_str(&methods).expect("methods are static"), + ); + + let headers_list = self + .config + .allow_headers + .iter() + .map(|h| h.as_str()) + .collect::>() + .join(", "); + headers.insert( + ACCESS_CONTROL_ALLOW_HEADERS, + HeaderValue::from_str(&headers_list).expect("headers are static"), + ); + + resp + } + + // Add CORS headers to normal response + fn add_cors_to_response( + &self, + mut resp: Response, + origin: &str, + ) -> Response { + let allowed_origin = self.resolve_allowed_origin(origin); + let headers = resp.headers_mut(); + + if let Some(origin) = allowed_origin { + headers.insert( + ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_str(&origin).expect("origin is validated"), + ); + } + + if self.config.allow_credentials { + headers.insert( + ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + + if !self.config.expose_headers.is_empty() { + let expose = self + .config + .expose_headers + .iter() + .map(|h| h.as_str()) + .collect::>() + .join(", "); + headers.insert( + ACCESS_CONTROL_EXPOSE_HEADERS, + HeaderValue::from_str(&expose).expect("expose headers are static"), + ); + } + + resp + } +} + +// Middleware trait implementation +#[async_trait::async_trait] +impl Middleware for CorsMiddleware { + /// Process a request, handling preflight or adding CORS headers. + /// + /// - For `OPTIONS` with `Access-Control-Request-Method`: performs preflight. + /// - For other requests: passes to `next`, then adds CORS headers. + async fn handle<'req>( + &self, + req: Request<&'req str>, + state: Arc, + next: MiddlewareNext<'req>, + ) -> TransportServerResult> { + let origin = req + .headers() + .get(header::ORIGIN) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + // Preflight: OPTIONS + Access-Control-Request-Method + if *req.method() == Method::OPTIONS { + let requested_method = req + .headers() + .get(ACCESS_CONTROL_REQUEST_METHOD) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + + let requested_headers = req + .headers() + .get(ACCESS_CONTROL_REQUEST_HEADERS) + .and_then(|v| v.to_str().ok()) + .map(|s| { + s.split(',') + .map(|h| h.trim().to_ascii_lowercase()) + .collect::>() + }) + .unwrap_or_default(); + + let origin = match origin { + Some(o) => o, + None => { + // Some tools send preflight without Origin — allow if Any + if matches!(self.config.allow_origins, AllowOrigins::Any) + && !self.config.allow_credentials + { + return Ok(self.preflight_response("*")); + } else { + let response = build_response( + StatusCode::BAD_REQUEST, + "CORS origin missing in preflight".to_string(), + ); + return response; + } + } + }; + + // Validate origin + if self.resolve_allowed_origin(&origin).is_none() { + let response = + build_response(StatusCode::FORBIDDEN, "CORS origin not allowed".to_string()); + return response; + } + + // Validate method + if let Some(m) = requested_method { + if !self.config.allow_methods.contains(&m) { + let response = build_response( + StatusCode::METHOD_NOT_ALLOWED, + "CORS method not allowed".to_string(), + ); + return response; + } + } + + // Validate headers + let allowed = self + .config + .allow_headers + .iter() + .map(|h| h.as_str().to_ascii_lowercase()) + .collect::>(); + + if !requested_headers.is_subset(&allowed) { + let response = build_response( + StatusCode::BAD_REQUEST, + "CORS header not allowed".to_string(), + ); + return response; + } + + // All good — return preflight + return Ok(self.preflight_response(&origin)); + } + + // Normal request: forward to next handler + let mut resp = next(req, state).await?; + if let Some(origin) = origin { + if self.resolve_allowed_origin(&origin).is_some() { + resp = self.add_cors_to_response(resp, &origin); + } + } + + Ok(resp) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + id_generator::{FastIdGenerator, UuidGenerator}, + mcp_http::{types::GenericBodyExt, MiddlewareNext}, + mcp_server::{ServerHandler, ToMcpServerHandler}, + schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities}, + session_store::InMemorySessionStore, + }; + use http::{header, Request, Response, StatusCode}; + use std::time::Duration; + + type TestResult = Result<(), Box>; + struct TestHandler; + impl ServerHandler for TestHandler {} + + fn app_state() -> Arc { + let handler = TestHandler {}; + + Arc::new(McpAppState { + session_store: Arc::new(InMemorySessionStore::new()), + id_generator: Arc::new(UuidGenerator {}), + stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))), + server_details: Arc::new(InitializeResult { + capabilities: ServerCapabilities { + ..Default::default() + }, + instructions: None, + meta: None, + protocol_version: ProtocolVersion::V2025_06_18.to_string(), + server_info: Implementation { + name: "server".to_string(), + title: None, + version: "0.1.0".to_string(), + }, + }), + handler: handler.to_mcp_server_handler(), + ping_interval: Duration::from_secs(15), + transport_options: Arc::new(rust_mcp_transport::TransportOptions::default()), + enable_json_response: false, + event_store: None, + }) + } + + fn make_handler<'req>(status: StatusCode, body: &'static str) -> MiddlewareNext<'req> { + Arc::new(move |_, _| { + let resp = Response::builder() + .status(status) + .body(GenericBody::from_string(body.to_string())) + .unwrap(); + Box::pin(async { Ok(resp) }) + }) + } + + #[tokio::test] + async fn test_preflight_allowed() -> TestResult { + let cors = CorsMiddleware::permissive(); + let handler = make_handler(StatusCode::OK, "should not see"); + + let req = Request::builder() + .method(Method::OPTIONS) + .uri("/") + .header(header::ORIGIN, "https://example.com") + .header(ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + ACCESS_CONTROL_REQUEST_HEADERS, + "content-type, authorization", + ) + .body("")?; + + let resp = cors.handle(req, app_state(), handler).await?; + + assert_eq!(resp.status(), StatusCode::NO_CONTENT); + assert_eq!( + resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN], + "https://example.com" + ); + assert_eq!( + resp.headers()[ACCESS_CONTROL_ALLOW_METHODS], + "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD" + ); + Ok(()) + } + + #[tokio::test] + async fn test_preflight_disallowed_origin() -> TestResult { + let mut allowed = HashSet::new(); + allowed.insert("https://trusted.com".to_string()); + + let cors = CorsMiddleware::new(CorsConfig { + allow_origins: AllowOrigins::List(allowed), + allow_methods: vec![Method::GET], + allow_headers: vec![], + allow_credentials: false, + max_age: None, + expose_headers: vec![], + }); + + let handler = make_handler(StatusCode::OK, "irrelevant"); + + let req = Request::builder() + .method(Method::OPTIONS) + .uri("/") + .header(header::ORIGIN, "https://evil.com") + .header(ACCESS_CONTROL_REQUEST_METHOD, "GET") + .body("")?; + + let result: Response = cors.handle(req, app_state(), handler).await.unwrap(); + let (parts, _body) = result.into_parts(); + assert_eq!(parts.status, 403); + Ok(()) + } + + #[tokio::test] + async fn test_normal_request_with_origin() -> TestResult { + let cors = CorsMiddleware::permissive(); + let handler = make_handler(StatusCode::OK, "hello"); + + let req = Request::builder() + .method(Method::GET) + .uri("/") + .header(header::ORIGIN, "https://client.com") + .body("")?; + + let resp = cors.handle(req, app_state(), handler).await?; + + assert_eq!(resp.status(), StatusCode::OK); + + assert_eq!( + resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN], + "https://client.com" + ); + assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true"); + Ok(()) + } + + #[tokio::test] + async fn test_wildcard_with_no_credentials() -> TestResult { + let cors = CorsMiddleware::new(CorsConfig { + allow_origins: AllowOrigins::Any, + allow_methods: vec![Method::GET], + allow_headers: vec![], + allow_credentials: false, + max_age: None, + expose_headers: vec![], + }); + + let handler = make_handler(StatusCode::OK, "ok"); + + let req = Request::builder() + .method(Method::GET) + .uri("/") + .header(header::ORIGIN, "https://any.com") + .body("")?; + + let resp = cors.handle(req, app_state(), handler).await?; + assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN], "*"); + Ok(()) + } + + #[tokio::test] + async fn test_no_wildcard_with_credentials() -> TestResult { + let cors = CorsMiddleware::new(CorsConfig { + allow_origins: AllowOrigins::Any, + allow_methods: vec![Method::GET], + allow_headers: vec![], + allow_credentials: true, // This should prevent "*" + max_age: None, + expose_headers: vec![], + }); + + let handler = make_handler(StatusCode::OK, "ok"); + + let req = Request::builder() + .method(Method::GET) + .uri("/") + .header(header::ORIGIN, "https://any.com") + .body("")?; + + let resp = cors.handle(req, app_state(), handler).await?; + + // Should NOT have "*" even though config says Any + let origin_header = resp + .headers() + .get(ACCESS_CONTROL_ALLOW_ORIGIN) + .expect("CORS header missing"); + assert_eq!(origin_header, "https://any.com"); + + // And credentials should be allowed + assert_eq!( + resp.headers() + .get(ACCESS_CONTROL_ALLOW_CREDENTIALS) + .unwrap(), + "true" + ); + Ok(()) + } + + #[tokio::test] + async fn test_echo_origin_with_credentials() -> TestResult { + let cors = CorsMiddleware::new(CorsConfig { + allow_origins: AllowOrigins::Echo, + allow_methods: vec![Method::GET], + allow_headers: vec![], + allow_credentials: true, + max_age: None, + expose_headers: vec![], + }); + + let handler = make_handler(StatusCode::OK, "ok"); + + let req = Request::builder() + .method(Method::GET) + .uri("/") + .header(header::ORIGIN, "https://dynamic.com") + .body("")?; + + let resp = cors.handle(req, app_state(), handler).await?; + assert_eq!( + resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN], + "https://dynamic.com" + ); + assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true"); + Ok(()) + } + + #[tokio::test] + async fn test_expose_headers() -> TestResult { + let cors = CorsMiddleware::new(CorsConfig { + allow_origins: AllowOrigins::Any, + allow_methods: vec![Method::GET], + allow_headers: vec![], + allow_credentials: false, + max_age: None, + expose_headers: vec![HeaderName::from_static("x-ratelimit-remaining")], + }); + + let handler = make_handler(StatusCode::OK, "ok"); + + let req = Request::builder() + .method(Method::GET) + .uri("/") + .header(header::ORIGIN, "https://client.com") + .body("")?; + + let resp = cors.handle(req, app_state(), handler).await?; + assert_eq!( + resp.headers()[ACCESS_CONTROL_EXPOSE_HEADERS], + "x-ratelimit-remaining" + ); + Ok(()) + } +} diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs new file mode 100644 index 0000000..b351dbc --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs @@ -0,0 +1,136 @@ +//! DNS Rebinding Protection Middleware +//! +//! This module provides a middleware that protects against DNS rebinding attacks +//! by validating the `Host` and `Origin` headers against configurable allowlists. +//! +//! DNS rebinding is an attack where a malicious site tricks a client's DNS resolver +//! into resolving a domain (e.g., `attacker.com`) to a private IP (like `127.0.0.1`), +//! allowing it to bypass same-origin policy and access internal services. +//! +//! # Security Model +//! +//! - If `allowed_hosts` is `Some(vec![..])` and non-empty → `Host` header **must** match (case-insensitive) +//! - If `allowed_origins` is `Some(vec![..])` and non-empty → `Origin` header **must** match (case-insensitive) +//! - Missing or unparseable headers → treated as invalid → 403 Forbidden +//! - If allowlist is `None` or empty → that check is skipped + +use crate::{ + mcp_http::{ + error_response, middleware::BoxFutureResponse, types::GenericBody, McpAppState, Middleware, + }, + mcp_server::error::TransportServerResult, + schema::schema_utils::SdkError, +}; +use async_trait::async_trait; +use http::{ + header::{HOST, ORIGIN}, + Request, Response, StatusCode, +}; +use std::sync::Arc; + +/// DNS Rebinding Protection Middleware +/// +/// Validates `Host` and `Origin` headers against allowlists to prevent DNS rebinding attacks. +/// Returns `403 Forbidden` with a descriptive error if validation fails. +/// +/// This middleware should be placed **early** in the chain (before routing) to ensure +/// protection even for unmatched routes. +/// +/// # When to use +/// - Public-facing APIs +/// - Services accessible via custom domains +/// - Any server that should **never** be accessible via `127.0.0.1`, `localhost`, or raw IPs +/// +/// # Security Considerations +/// - Always pin exact hostnames (e.g., `app.example.com:8443`) +/// - Avoid wildcards or overly broad patterns +/// - For local development, include `localhost:PORT` explicitly +/// - Never allow raw IP addresses in production allowlists +pub(crate) struct DnsRebindProtector { + /// List of allowed host header values for DNS rebinding protection. + /// If not specified, host validation is disabled. + pub allowed_hosts: Option>, + /// List of allowed origin header values for DNS rebinding protection. + /// If not specified, origin validation is disabled. + pub allowed_origins: Option>, +} + +#[async_trait] +impl Middleware for DnsRebindProtector { + /// Processes the incoming request and applies DNS rebinding protection. + /// + /// # Arguments + /// + /// * `req` - The incoming HTTP request with `&str` body (pre-read) + /// * `state` - Shared application state + /// * `next` - The next middleware/handler in the chain + /// + /// # Returns + /// + /// * `Ok(Response)` - If validation passes, forwards to next handler + /// * `Err` via `error_response(403, ...)` - If Host/Origin validation fails + async fn handle<'req>( + &self, + req: Request<&'req str>, + state: Arc, + next: Arc< + dyn Fn(Request<&'req str>, Arc) -> BoxFutureResponse<'req> + Send + Sync, + >, + ) -> TransportServerResult> { + if let Err(error) = self.protect_dns_rebinding(req.headers()).await { + return error_response(StatusCode::FORBIDDEN, error); + } + next(req, state).await + } +} + +impl DnsRebindProtector { + pub fn new(allowed_hosts: Option>, allowed_origins: Option>) -> Self { + Self { + allowed_hosts, + allowed_origins, + } + } + + // Protect against DNS rebinding attacks by validating Host and Origin headers. + // If protection fails, respond with HTTP 403 Forbidden. + async fn protect_dns_rebinding(&self, headers: &http::HeaderMap) -> Result<(), SdkError> { + if let Some(allowed_hosts) = self.allowed_hosts.as_ref() { + if !allowed_hosts.is_empty() { + let Some(host) = headers.get(HOST).and_then(|h| h.to_str().ok()) else { + return Err( + SdkError::bad_request().with_message("Invalid Host header: [unknown] ") + ); + }; + + if !allowed_hosts + .iter() + .any(|allowed| allowed.eq_ignore_ascii_case(host)) + { + return Err(SdkError::bad_request() + .with_message(format!("Invalid Host header: \"{host}\" ").as_str())); + } + } + } + + if let Some(allowed_origins) = self.allowed_origins.as_ref() { + if !allowed_origins.is_empty() { + let Some(origin) = headers.get(ORIGIN).and_then(|h| h.to_str().ok()) else { + return Err( + SdkError::bad_request().with_message("Invalid Origin header: [unknown] ") + ); + }; + + if !allowed_origins + .iter() + .any(|allowed| allowed.eq_ignore_ascii_case(origin)) + { + return Err(SdkError::bad_request() + .with_message(format!("Invalid Origin header: \"{origin}\" ").as_str())); + } + } + } + + Ok(()) + } +} diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware/logging_middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware/logging_middleware.rs new file mode 100644 index 0000000..49f2e52 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware/logging_middleware.rs @@ -0,0 +1,36 @@ +//! A very simple example middleware for inspiration. +//! +//! This demonstrates how to implement a basic logging middleware +//! using the `Middleware` trait. It logs incoming requests and outgoing +//! responses. In a real-world application, you might extend this to +//! include structured logging, tracing, timing, or error reporting. +use crate::{ + mcp_http::{middleware::BoxFutureResponse, types::GenericBody, McpAppState, Middleware}, + mcp_server::error::TransportServerResult, +}; +use async_trait::async_trait; +use http::{Request, Response}; +use std::sync::Arc; + +/// A minimal middleware that logs request URIs and response statuses. +/// +/// This is just a *very, very* simple example meant for inspiration. +/// It shows how to wrap a request/response cycle inside a middleware layer. +pub struct LoggingMiddleware; + +#[async_trait] +impl Middleware for LoggingMiddleware { + async fn handle<'req>( + &self, + req: Request<&'req str>, + state: Arc, + next: Arc< + dyn Fn(Request<&'req str>, Arc) -> BoxFutureResponse<'req> + Send + Sync, + >, + ) -> TransportServerResult> { + println!("➡️ Logging request: {}", req.uri()); + let res = next(req, state).await?; + println!("⬅️ Logging response: {}", res.status()); + Ok(res) + } +} diff --git a/crates/rust-mcp-sdk/src/mcp_http/types.rs b/crates/rust-mcp-sdk/src/mcp_http/types.rs new file mode 100644 index 0000000..59645d2 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http/types.rs @@ -0,0 +1,41 @@ +use crate::{ + mcp_http::McpAppState, + mcp_server::error::{TransportServerError, TransportServerResult}, +}; +use bytes::Bytes; +use http::{Request, Response}; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use std::{future::Future, pin::Pin, sync::Arc}; + +pub type GenericBody = BoxBody; + +pub trait GenericBodyExt { + fn from_string(s: String) -> Self; +} + +impl GenericBodyExt for GenericBody { + fn from_string(s: String) -> Self { + Full::new(Bytes::from(s)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed() + } +} + +pub type BoxFutureResponse<'req> = + Pin>> + Send + 'req>>; + +// Define a short alias for your handler function type. +/// A handler function that processes an HTTP request and shared state, +/// returning an async response future. +pub type RequestHandlerFn = + dyn for<'req> Fn(Request<&'req str>, Arc) -> BoxFutureResponse<'req> + Send + Sync; + +/// A shared, reference-counted request handler. +pub type RequestHandler = Arc; + +// pub type RequestHandler = Arc< +// dyn for<'req> FnOnce(Request<&'req str>) -> BoxFutureResponse<'req> + Send + Sync +// >; + +pub type MiddlewareNext<'req> = + Arc, Arc) -> BoxFutureResponse<'req> + Send + Sync>; diff --git a/crates/rust-mcp-sdk/tests/common/test_client.rs b/crates/rust-mcp-sdk/tests/common/test_client.rs index 46a8525..912501a 100644 --- a/crates/rust-mcp-sdk/tests/common/test_client.rs +++ b/crates/rust-mcp-sdk/tests/common/test_client.rs @@ -131,9 +131,6 @@ pub mod test_client_common { } } -// Custom responder for SSE with 10 ping messages -struct SsePingResponder; - // Test handler pub struct TestClientHandler { message_history: Arc>>, @@ -151,7 +148,7 @@ impl ClientHandler for TestClientHandler { async fn handle_ping_request( &self, request: PingRequest, - runtime: &dyn McpClient, + _runtime: &dyn McpClient, ) -> std::result::Result { self.register_message(&request.into()).await; From 5917fe5849277d8f803bb7e7f21708e876ac7505 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Fri, 7 Nov 2025 18:19:14 -0400 Subject: [PATCH 2/2] fix: typo --- .../src/mcp_http/middleware/dns_rebind_protector.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs index b351dbc..77e7013 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs @@ -11,7 +11,7 @@ //! //! - If `allowed_hosts` is `Some(vec![..])` and non-empty → `Host` header **must** match (case-insensitive) //! - If `allowed_origins` is `Some(vec![..])` and non-empty → `Origin` header **must** match (case-insensitive) -//! - Missing or unparseable headers → treated as invalid → 403 Forbidden +//! - Missing or unparsable headers → treated as invalid → 403 Forbidden //! - If allowlist is `None` or empty → that check is skipped use crate::{