44import httpx
55import redis
66import jwt
7+ from jwt .jwks_client import PyJWKClient
78from typing import Optional , Dict , Any , Tuple
89from dotenv import load_dotenv
910
4748CODER_DEFAULT_ORGANIZATION = os .getenv ("CODER_DEFAULT_ORGANIZATION" )
4849CODER_WORKSPACE_NAME = os .getenv ("CODER_WORKSPACE_NAME" , "ubuntu" )
4950
51+ # Cache for JWKS client
52+ _jwks_client = None
53+
5054# Session management functions
5155def get_session (session_id : str ) -> Optional [Dict [str , Any ]]:
5256 """Get session data from Redis"""
@@ -67,8 +71,6 @@ def delete_session(session_id: str) -> None:
6771 """Delete session data from Redis"""
6872 redis_client .delete (f"session:{ session_id } " )
6973
70- provisioning_times = {}
71-
7274def get_auth_url () -> str :
7375 """Generate the authentication URL for Keycloak login"""
7476 auth_url = f"{ OIDC_SERVER_URL } /realms/{ OIDC_REALM } /protocol/openid-connect/auth"
@@ -85,29 +87,28 @@ def get_token_url() -> str:
8587 return f"{ OIDC_SERVER_URL } /realms/{ OIDC_REALM } /protocol/openid-connect/token"
8688
8789def is_token_expired (token_data : Dict [str , Any ], buffer_seconds : int = 30 ) -> bool :
88- """
89- Check if the access token is expired or about to expire
90-
91- Args:
92- token_data: The token data containing the access token
93- buffer_seconds: Buffer time in seconds to refresh token before it actually expires
94-
95- Returns:
96- bool: True if token is expired or about to expire, False otherwise
97- """
9890 if not token_data or 'access_token' not in token_data :
9991 return True
10092
10193 try :
102- # Decode the JWT token without verification to get expiration time
103- decoded = jwt .decode (token_data ['access_token' ], options = {"verify_signature" : False })
94+ # Get the signing key
95+ jwks_client = get_jwks_client ()
96+ signing_key = jwks_client .get_signing_key_from_jwt (token_data ['access_token' ])
10497
105- # Get expiration time from token
106- exp_time = decoded .get ('exp' , 0 )
98+ # Decode with verification
99+ decoded = jwt .decode (
100+ token_data ['access_token' ],
101+ signing_key .key ,
102+ algorithms = ["RS256" ], # Common algorithm for OIDC
103+ audience = OIDC_CLIENT_ID ,
104+ )
107105
108- # Check if token is expired or about to expire (with buffer)
106+ # Check expiration
107+ exp_time = decoded .get ('exp' , 0 )
109108 current_time = time .time ()
110109 return current_time + buffer_seconds >= exp_time
110+ except jwt .ExpiredSignatureError :
111+ return True
111112 except Exception as e :
112113 print (f"Error checking token expiration: { str (e )} " )
113114 return True
@@ -153,3 +154,11 @@ async def refresh_token(session_id: str, token_data: Dict[str, Any]) -> Tuple[bo
153154 except Exception as e :
154155 print (f"Error refreshing token: { str (e )} " )
155156 return False , token_data
157+
158+ def get_jwks_client ():
159+ """Get or create a PyJWKClient for token verification"""
160+ global _jwks_client
161+ if _jwks_client is None :
162+ jwks_url = f"{ OIDC_SERVER_URL } /realms/{ OIDC_REALM } /protocol/openid-connect/certs"
163+ _jwks_client = PyJWKClient (jwks_url )
164+ return _jwks_client
0 commit comments