2424from cryptography .hazmat .primitives .asymmetric import rsa
2525from cryptography .hazmat .backends import default_backend
2626
27+ REQUESTS_TIMEOUT = 30
2728class CognitoHelper :
2829 """Handles user authentication with AWS Cognito."""
2930
@@ -70,7 +71,9 @@ def jwt_to_pem(self, n, e):
7071 # https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html
7172 def get_cognito_jwk (self , kid ):
7273 url = f"https://cognito-idp.{ self .region } .amazonaws.com/{ self .user_pool_id } /.well-known/jwks.json"
73- jwks = requests .get (url ).json ()
74+ response = requests .get (url , timeout = REQUESTS_TIMEOUT )
75+ response .raise_for_status ()
76+ jwks = response .json ()
7477 # Extract the specific key from jwks for verification
7578 for jwk in jwks ["keys" ]:
7679 if jwk ["kid" ] == kid :
@@ -82,7 +85,7 @@ def decode_id_token(self, id_token = None):
8285 if id_token is None :
8386 id_token = st .session_state .get ("id_token" , "" )
8487
85- if id_token != "" :
88+ if id_token != "" : # nosec B105
8689 jwt_headers = jwt .get_unverified_header (id_token )
8790 jwk = self .get_cognito_jwk (jwt_headers ["kid" ])
8891 public_key = self .jwt_to_pem (jwk ["n" ], jwk ["e" ])
@@ -93,15 +96,15 @@ def decode_id_token(self, id_token = None):
9396 def get_user_tokens (self , auth_code = None ):
9497 """Gets user access and ID tokens using auth code."""
9598
96- access_token = ""
97- id_token = ""
99+ access_token = "" # nosec B105
100+ id_token = "" # nosec B105
98101
99102 # if auth_code is not provided, try to get credentianls from the session state.
100103 if not auth_code :
101104 access_token = st .session_state .get ("access_token" , "" )
102105 id_token = st .session_state .get ("id_token" , "" )
103106
104- if access_token != "" and id_token != "" :
107+ if access_token != "" and id_token != "" : # nosec B105
105108 return access_token , id_token
106109
107110 try :
@@ -118,13 +121,14 @@ def get_user_tokens(self, auth_code = None):
118121 "redirect_uri" : self .app_uri ,
119122 }
120123
121- token_response = requests .post (self .token_url , headers = headers , data = body )
124+ token_response = requests .post (self .token_url , headers = headers , data = body , timeout = REQUESTS_TIMEOUT )
125+ token_response .raise_for_status ()
122126 access_token = token_response .json ()["access_token" ]
123127 id_token = token_response .json ()["id_token" ]
124128
125129 except (KeyError , TypeError ):
126- access_token = ""
127- id_token = ""
130+ access_token = "" # nosec B105
131+ id_token = "" # nosec B105
128132
129133 return access_token , id_token
130134
@@ -175,11 +179,11 @@ def set_session_state(self):
175179 auth_code = auth_query_params ["code" ]
176180 access_token , id_token = self .get_user_tokens (auth_code )
177181
178- if access_token != "" :
182+ if access_token != "" : # nosec B105
179183 st .session_state ["auth_code" ] = auth_code
180184 st .session_state ["access_token" ] = access_token
181185
182- if id_token != "" :
186+ if id_token != "" : # nosec B105
183187 st .session_state ["id_token" ] = id_token
184188 credentials = self .get_user_temporary_credentials (id_token )
185189 st .session_state ["access_key_id" ] = credentials ["AccessKeyId" ]
@@ -198,7 +202,7 @@ def is_authenticated(self):
198202 session_token = st .session_state .get ("session_token" , "" )
199203 expiration = st .session_state .get ("expiration" )
200204
201- is_valid_session = (access_key_id != "" and secret_access_key != "" and session_token != "" )
205+ is_valid_session = (access_key_id != "" and secret_access_key != "" and session_token != "" ) # nosec B105
202206 # +5 seconds to consider a expiry buffer. If the session is about to expire, we need to renew it.
203207 has_not_expired = (expiration .timestamp () > (time .time () + 5 )) if expiration else True
204208
0 commit comments