Skip to content

Commit 4f91532

Browse files
committed
refactor: enhance JWT token handling and session management
- Integrated PyJWKClient for secure JWT verification, improving token validation by using signing keys from JWKS. - Updated UserSession initialization to decode tokens with verification, enhancing security and error handling. - Added a caching mechanism for the JWKS client to optimize performance. - Cleaned up token expiration checks to ensure accurate validation and error reporting.
1 parent c9dab93 commit 4f91532

File tree

3 files changed

+46
-19
lines changed

3 files changed

+46
-19
lines changed

src/backend/config.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import httpx
55
import redis
66
import jwt
7+
from jwt.jwks_client import PyJWKClient
78
from typing import Optional, Dict, Any, Tuple
89
from dotenv import load_dotenv
910

@@ -47,6 +48,9 @@
4748
CODER_DEFAULT_ORGANIZATION = os.getenv("CODER_DEFAULT_ORGANIZATION")
4849
CODER_WORKSPACE_NAME = os.getenv("CODER_WORKSPACE_NAME", "ubuntu")
4950

51+
# Cache for JWKS client
52+
_jwks_client = None
53+
5054
# Session management functions
5155
def get_session(session_id: str) -> Optional[Dict[str, Any]]:
5256
"""Get session data from Redis"""
@@ -67,8 +71,6 @@ def delete_session(session_id: str) -> None:
6771
"""Delete session data from Redis"""
6872
redis_client.delete(f"session:{session_id}")
6973

70-
provisioning_times = {}
71-
7274
def get_auth_url() -> str:
7375
"""Generate the authentication URL for Keycloak login"""
7476
auth_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/auth"
@@ -85,29 +87,28 @@ def get_token_url() -> str:
8587
return f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/token"
8688

8789
def is_token_expired(token_data: Dict[str, Any], buffer_seconds: int = 30) -> bool:
88-
"""
89-
Check if the access token is expired or about to expire
90-
91-
Args:
92-
token_data: The token data containing the access token
93-
buffer_seconds: Buffer time in seconds to refresh token before it actually expires
94-
95-
Returns:
96-
bool: True if token is expired or about to expire, False otherwise
97-
"""
9890
if not token_data or 'access_token' not in token_data:
9991
return True
10092

10193
try:
102-
# Decode the JWT token without verification to get expiration time
103-
decoded = jwt.decode(token_data['access_token'], options={"verify_signature": False})
94+
# Get the signing key
95+
jwks_client = get_jwks_client()
96+
signing_key = jwks_client.get_signing_key_from_jwt(token_data['access_token'])
10497

105-
# Get expiration time from token
106-
exp_time = decoded.get('exp', 0)
98+
# Decode with verification
99+
decoded = jwt.decode(
100+
token_data['access_token'],
101+
signing_key.key,
102+
algorithms=["RS256"], # Common algorithm for OIDC
103+
audience=OIDC_CLIENT_ID,
104+
)
107105

108-
# Check if token is expired or about to expire (with buffer)
106+
# Check expiration
107+
exp_time = decoded.get('exp', 0)
109108
current_time = time.time()
110109
return current_time + buffer_seconds >= exp_time
110+
except jwt.ExpiredSignatureError:
111+
return True
111112
except Exception as e:
112113
print(f"Error checking token expiration: {str(e)}")
113114
return True
@@ -153,3 +154,11 @@ async def refresh_token(session_id: str, token_data: Dict[str, Any]) -> Tuple[bo
153154
except Exception as e:
154155
print(f"Error refreshing token: {str(e)}")
155156
return False, token_data
157+
158+
def get_jwks_client():
159+
"""Get or create a PyJWKClient for token verification"""
160+
global _jwks_client
161+
if _jwks_client is None:
162+
jwks_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/certs"
163+
_jwks_client = PyJWKClient(jwks_url)
164+
return _jwks_client

src/backend/dependencies.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,25 @@ class UserSession:
1515
"""
1616
def __init__(self, access_token: str, token_data: dict, user_id: UUID = None):
1717
self.access_token = access_token
18-
self.token_data = jwt.decode(access_token, options={"verify_signature": False})
1918
self._user_data = None
19+
20+
# Get the signing key and decode with verification
21+
from config import get_jwks_client, OIDC_CLIENT_ID
22+
try:
23+
jwks_client = get_jwks_client()
24+
signing_key = jwks_client.get_signing_key_from_jwt(access_token)
25+
26+
self.token_data = jwt.decode(
27+
access_token,
28+
signing_key.key,
29+
algorithms=["RS256"],
30+
audience=OIDC_CLIENT_ID
31+
)
32+
33+
except jwt.InvalidTokenError as e:
34+
# Log the error and raise an appropriate exception
35+
print(f"Invalid token: {str(e)}")
36+
raise ValueError(f"Invalid authentication token: {str(e)}")
2037

2138
@property
2239
def is_authenticated(self) -> bool:

src/backend/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ sqlalchemy
1010
posthog
1111
redis
1212
psycopg2-binary
13-
python-multipart
13+
python-multipart
14+
cryptography # Required for JWT key handling

0 commit comments

Comments
 (0)