11use crate :: generated_schema:: * ;
22use serde:: ser:: SerializeStruct ;
33use serde_json:: { json, Value } ;
4+ use std:: hash:: { Hash , Hasher } ;
45use std:: { fmt:: Display , str:: FromStr } ;
56
67#[ derive( Debug ) ]
@@ -37,6 +38,47 @@ fn detect_message_type(value: &serde_json::Value) -> MessageTypes {
3738 MessageTypes :: Request
3839}
3940
41+ pub trait MCPMessage {
42+ fn is_response ( & self ) -> bool ;
43+ fn is_request ( & self ) -> bool ;
44+ fn is_notification ( & self ) -> bool ;
45+ fn is_error ( & self ) -> bool ;
46+ fn request_id ( & self ) -> Option < & RequestId > ;
47+ }
48+
49+ //*******************************//
50+ //** RequestId Implementations **//
51+ //*******************************//
52+
53+ // Implement PartialEq and Eq for RequestId
54+ impl PartialEq for RequestId {
55+ fn eq ( & self , other : & Self ) -> bool {
56+ match ( self , other) {
57+ ( RequestId :: String ( a) , RequestId :: String ( b) ) => a == b,
58+ ( RequestId :: Integer ( a) , RequestId :: Integer ( b) ) => a == b,
59+ _ => false , // Different variants are never equal
60+ }
61+ }
62+ }
63+
64+ impl Eq for RequestId { }
65+
66+ // Implement Hash for RequestId, so we can store it in HashMaps, HashSets, etc.
67+ impl Hash for RequestId {
68+ fn hash < H : Hasher > ( & self , state : & mut H ) {
69+ match self {
70+ RequestId :: String ( s) => {
71+ 0u8 . hash ( state) ; // Prefix with 0 for String variant
72+ s. hash ( state) ;
73+ }
74+ RequestId :: Integer ( i) => {
75+ 1u8 . hash ( state) ; // Prefix with 1 for Integer variant
76+ i. hash ( state) ;
77+ }
78+ }
79+ }
80+ }
81+
4082//*******************//
4183//** ClientMessage **//
4284//*******************//
@@ -52,6 +94,43 @@ pub enum ClientMessage {
5294 Error ( JsonrpcError ) ,
5395}
5496
97+ // Implementing the `MCPMessage` trait for `ClientMessage`
98+ impl MCPMessage for ClientMessage {
99+ // Returns true if the message is a response type
100+ fn is_response ( & self ) -> bool {
101+ matches ! ( self , ClientMessage :: Response ( _) )
102+ }
103+
104+ // Returns true if the message is a request type
105+ fn is_request ( & self ) -> bool {
106+ matches ! ( self , ClientMessage :: Request ( _) )
107+ }
108+
109+ // Returns true if the message is a notification type (i.e., does not expect a response)
110+ fn is_notification ( & self ) -> bool {
111+ matches ! ( self , ClientMessage :: Notification ( _) )
112+ }
113+
114+ // Returns true if the message represents an error
115+ fn is_error ( & self ) -> bool {
116+ matches ! ( self , ClientMessage :: Error ( _) )
117+ }
118+
119+ // Retrieves the request ID associated with the message, if applicable
120+ fn request_id ( & self ) -> Option < & RequestId > {
121+ match self {
122+ // If the message is a request, return the associated request ID
123+ ClientMessage :: Request ( client_jsonrpc_request) => Some ( & client_jsonrpc_request. id ) ,
124+ // Notifications do not have request IDs
125+ ClientMessage :: Notification ( _) => None ,
126+ // If the message is a response, return the associated request ID
127+ ClientMessage :: Response ( client_jsonrpc_response) => Some ( & client_jsonrpc_response. id ) ,
128+ // If the message is an error, return the associated request ID
129+ ClientMessage :: Error ( jsonrpc_error) => Some ( & jsonrpc_error. id ) ,
130+ }
131+ }
132+ }
133+
55134//**************************//
56135//** ClientJsonrpcRequest **//
57136//**************************//
@@ -385,6 +464,43 @@ pub enum ServerMessage {
385464 Error ( JsonrpcError ) ,
386465}
387466
467+ // Implementing the `MCPMessage` trait for `ServerMessage`
468+ impl MCPMessage for ServerMessage {
469+ // Returns true if the message is a response type
470+ fn is_response ( & self ) -> bool {
471+ matches ! ( self , ServerMessage :: Response ( _) )
472+ }
473+
474+ // Returns true if the message is a request type
475+ fn is_request ( & self ) -> bool {
476+ matches ! ( self , ServerMessage :: Request ( _) )
477+ }
478+
479+ // Returns true if the message is a notification type (i.e., does not expect a response)
480+ fn is_notification ( & self ) -> bool {
481+ matches ! ( self , ServerMessage :: Notification ( _) )
482+ }
483+
484+ // Returns true if the message represents an error
485+ fn is_error ( & self ) -> bool {
486+ matches ! ( self , ServerMessage :: Error ( _) )
487+ }
488+
489+ // Retrieves the request ID associated with the message, if applicable
490+ fn request_id ( & self ) -> Option < & RequestId > {
491+ match self {
492+ // If the message is a request, return the associated request ID
493+ ServerMessage :: Request ( client_jsonrpc_request) => Some ( & client_jsonrpc_request. id ) ,
494+ // Notifications do not have request IDs
495+ ServerMessage :: Notification ( _) => None ,
496+ // If the message is a response, return the associated request ID
497+ ServerMessage :: Response ( client_jsonrpc_response) => Some ( & client_jsonrpc_response. id ) ,
498+ // If the message is an error, return the associated request ID
499+ ServerMessage :: Error ( jsonrpc_error) => Some ( & jsonrpc_error. id ) ,
500+ }
501+ }
502+ }
503+
388504impl FromStr for ServerMessage {
389505 type Err = JsonrpcErrorError ;
390506
0 commit comments