Skip to content

Commit d0b6a2a

Browse files
committed
updating for flake8 and mypy errors
1 parent 1a42368 commit d0b6a2a

File tree

5 files changed

+395
-250
lines changed
  • aws_sra_examples/solutions/genai/bedrock_org/lambda/rules
    • sra_bedrock_check_kb_ingestion_encryption
    • sra_bedrock_check_kb_logging
    • sra_bedrock_check_kb_opensearch_encryption
    • sra_bedrock_check_kb_s3_bucket
    • sra_bedrock_check_kb_vector_store_secret

5 files changed

+395
-250
lines changed

aws_sra_examples/solutions/genai/bedrock_org/lambda/rules/sra_bedrock_check_kb_ingestion_encryption/app.py

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,43 @@
2828
bedrock_agent_client = boto3.client("bedrock-agent", region_name=AWS_REGION)
2929
config_client = boto3.client("config", region_name=AWS_REGION)
3030

31-
def evaluate_compliance(rule_parameters: dict) -> tuple[str, str]:
31+
32+
def check_data_sources(kb_id: str, kb_name: str) -> str | None: # noqa: CFQ004
33+
"""Check if a knowledge base's data sources are encrypted.
34+
35+
Args:
36+
kb_id (str): Knowledge base ID
37+
kb_name (str): Knowledge base name
38+
39+
Raises:
40+
ClientError: If there is an error checking the knowledge base
41+
42+
Returns:
43+
str | None: Error message if non-compliant, None if compliant
44+
"""
45+
try:
46+
data_sources = bedrock_agent_client.list_data_sources(knowledgeBaseId=kb_id)
47+
if not isinstance(data_sources, dict):
48+
return f"{kb_name} (invalid data sources response)"
49+
unencrypted_sources = []
50+
for source in data_sources.get("dataSourceSummaries", []):
51+
if not isinstance(source, dict):
52+
continue
53+
encryption_config = source.get("serverSideEncryptionConfiguration", {})
54+
if not isinstance(encryption_config, dict) or not encryption_config.get("kmsKeyArn"):
55+
unencrypted_sources.append(source.get("name", source["dataSourceId"]))
56+
57+
if unencrypted_sources:
58+
return f"{kb_name} (unencrypted sources: {', '.join(unencrypted_sources)})"
59+
return None
60+
except ClientError as e:
61+
LOGGER.error(f"Error checking data sources for knowledge base {kb_name}: {str(e)}")
62+
if e.response["Error"]["Code"] == "AccessDeniedException":
63+
return f"{kb_name} (access denied)"
64+
raise
65+
66+
67+
def evaluate_compliance(rule_parameters: dict) -> tuple[str, str]: # noqa: U100
3268
"""Evaluate if Bedrock Knowledge Base data sources are encrypted with KMS.
3369
3470
Args:
@@ -38,36 +74,16 @@ def evaluate_compliance(rule_parameters: dict) -> tuple[str, str]:
3874
tuple[str, str]: Compliance type and annotation message.
3975
"""
4076
try:
41-
# List all knowledge bases
4277
non_compliant_kbs = []
4378
paginator = bedrock_agent_client.get_paginator("list_knowledge_bases")
44-
79+
4580
for page in paginator.paginate():
4681
for kb in page["knowledgeBaseSummaries"]:
4782
kb_id = kb["knowledgeBaseId"]
4883
kb_name = kb.get("name", kb_id)
49-
50-
# Get data sources for each knowledge base
51-
try:
52-
data_sources = bedrock_agent_client.list_data_sources(
53-
knowledgeBaseId=kb_id
54-
)
55-
56-
# Check if any data source is not encrypted
57-
unencrypted_sources = []
58-
for source in data_sources.get("dataSourceSummaries", []):
59-
if not source.get("serverSideEncryptionConfiguration", {}).get("kmsKeyArn"):
60-
unencrypted_sources.append(source.get("name", source["dataSourceId"]))
61-
62-
if unencrypted_sources:
63-
non_compliant_kbs.append(f"{kb_name} (unencrypted sources: {', '.join(unencrypted_sources)})")
64-
65-
except ClientError as e:
66-
LOGGER.error(f"Error checking data sources for knowledge base {kb_name}: {str(e)}")
67-
if e.response["Error"]["Code"] == "AccessDeniedException":
68-
non_compliant_kbs.append(f"{kb_name} (access denied)")
69-
else:
70-
raise
84+
error = check_data_sources(kb_id, kb_name)
85+
if error:
86+
non_compliant_kbs.append(error)
7187

7288
if non_compliant_kbs:
7389
return "NON_COMPLIANT", f"The following knowledge bases have unencrypted data sources: {'; '.join(non_compliant_kbs)}"
@@ -77,6 +93,7 @@ def evaluate_compliance(rule_parameters: dict) -> tuple[str, str]:
7793
LOGGER.error(f"Error evaluating Bedrock Knowledge Base encryption: {str(e)}")
7894
return "ERROR", f"Error evaluating compliance: {str(e)}"
7995

96+
8097
def lambda_handler(event: dict, context: Any) -> None: # noqa: U100
8198
"""Lambda handler.
8299
@@ -105,4 +122,4 @@ def lambda_handler(event: dict, context: Any) -> None: # noqa: U100
105122

106123
config_client.put_evaluations(Evaluations=[evaluation], ResultToken=event["resultToken"]) # type: ignore
107124

108-
LOGGER.info("Compliance evaluation complete.")
125+
LOGGER.info("Compliance evaluation complete.")

aws_sra_examples/solutions/genai/bedrock_org/lambda/rules/sra_bedrock_check_kb_logging/app.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,15 @@
2929
config_client = boto3.client("config", region_name=AWS_REGION)
3030

3131

32-
def evaluate_compliance(rule_parameters: dict) -> tuple[str, str]:
32+
def evaluate_compliance(rule_parameters: dict) -> tuple[str, str]: # noqa: CFQ004, U100
3333
"""Evaluate if Bedrock Knowledge Base logging is properly configured.
3434
3535
Args:
3636
rule_parameters (dict): Rule parameters from AWS Config rule.
3737
38+
Raises:
39+
ClientError: If there is an error checking the knowledge base
40+
3841
Returns:
3942
tuple[str, str]: Compliance type and annotation message.
4043
"""
@@ -49,20 +52,20 @@ def evaluate_compliance(rule_parameters: dict) -> tuple[str, str]:
4952
return "COMPLIANT", "No knowledge bases found in the account"
5053

5154
non_compliant_kbs = []
52-
55+
5356
# Check each knowledge base for logging configuration
5457
for kb in kb_list:
5558
kb_id = kb['knowledgeBaseId']
5659
try:
5760
kb_details = bedrock_agent_client.get_knowledge_base(
5861
knowledgeBaseId=kb_id
5962
)
60-
63+
6164
# Check if logging is enabled
6265
logging_config = kb_details.get('loggingConfiguration', {})
63-
if not logging_config or not logging_config.get('enabled', False):
66+
if not isinstance(logging_config, dict) or not logging_config.get('enabled', False):
6467
non_compliant_kbs.append(f"{kb_id} ({kb.get('name', 'unnamed')})")
65-
68+
6669
except ClientError as e:
6770
LOGGER.error(f"Error checking knowledge base {kb_id}: {str(e)}")
6871
if e.response['Error']['Code'] == 'AccessDeniedException':
@@ -107,4 +110,4 @@ def lambda_handler(event: dict, context: Any) -> None: # noqa: U100
107110

108111
config_client.put_evaluations(Evaluations=[evaluation], ResultToken=event["resultToken"]) # type: ignore
109112

110-
LOGGER.info("Compliance evaluation complete.")
113+
LOGGER.info("Compliance evaluation complete.")

aws_sra_examples/solutions/genai/bedrock_org/lambda/rules/sra_bedrock_check_kb_opensearch_encryption/app.py

Lines changed: 111 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,107 @@
2727
# Initialize AWS clients
2828
bedrock_agent_client = boto3.client("bedrock-agent", region_name=AWS_REGION)
2929
opensearch_client = boto3.client("opensearch", region_name=AWS_REGION)
30+
opensearch_serverless_client = boto3.client("opensearchserverless", region_name=AWS_REGION)
3031
config_client = boto3.client("config", region_name=AWS_REGION)
3132

32-
def evaluate_compliance(rule_parameters: dict) -> tuple[str, str]:
33+
34+
def check_opensearch_serverless(collection_id: str, kb_name: str) -> str | None:
35+
"""Check OpenSearch Serverless collection encryption.
36+
37+
Args:
38+
collection_id (str): Collection ID
39+
kb_name (str): Knowledge base name
40+
41+
Returns:
42+
str | None: Error message if non-compliant, None if compliant
43+
"""
44+
try:
45+
collection = opensearch_serverless_client.get_security_policy(
46+
name=collection_id,
47+
type="encryption"
48+
)
49+
security_policy = collection.get("securityPolicyDetail", {})
50+
if security_policy.get("Type") == "encryption":
51+
security_policies = security_policy.get("SecurityPolicies", [])
52+
if isinstance(security_policies, list) and security_policies:
53+
encryption_policy = security_policies[0]
54+
kms_key_arn = encryption_policy.get("KmsARN", "")
55+
if not kms_key_arn or "aws/opensearchserverless" in kms_key_arn:
56+
return f"{kb_name} (OpenSearch Serverless not using CMK)"
57+
except ClientError as e:
58+
LOGGER.error(f"Error checking OpenSearch Serverless collection: {str(e)}")
59+
return f"{kb_name} (error checking OpenSearch Serverless)"
60+
return None
61+
62+
63+
def check_opensearch_domain(domain_name: str, kb_name: str) -> str | None: # noqa: CFQ004
64+
"""Check standard OpenSearch domain encryption.
65+
66+
Args:
67+
domain_name (str): Domain name
68+
kb_name (str): Knowledge base name
69+
70+
Returns:
71+
str | None: Error message if non-compliant, None if compliant
72+
"""
73+
try:
74+
domain = opensearch_client.describe_domain(DomainName=domain_name)
75+
encryption_config = domain.get("DomainStatus", {}).get("EncryptionAtRestOptions", {})
76+
if not encryption_config.get("Enabled", False):
77+
return f"{kb_name} (OpenSearch domain encryption not enabled)"
78+
if not encryption_config.get("KmsKeyId"):
79+
return f"{kb_name} (OpenSearch domain not using CMK)"
80+
except ClientError as e:
81+
LOGGER.error(f"Error checking OpenSearch domain: {str(e)}")
82+
return f"{kb_name} (error checking OpenSearch domain)"
83+
return None
84+
85+
86+
def check_knowledge_base(kb_id: str, kb_name: str) -> str | None: # noqa: CFQ004
87+
"""Check a knowledge base's OpenSearch configuration.
88+
89+
Args:
90+
kb_id (str): Knowledge base ID
91+
kb_name (str): Knowledge base name
92+
93+
Raises:
94+
ClientError: If there is an error checking the knowledge base
95+
96+
Returns:
97+
str | None: Error message if non-compliant, None if compliant
98+
"""
99+
try:
100+
kb_details = bedrock_agent_client.get_knowledge_base(knowledgeBaseId=kb_id)
101+
vector_store = kb_details.get("vectorStoreConfiguration")
102+
103+
if not vector_store or not isinstance(vector_store, dict):
104+
return None
105+
106+
if vector_store.get("vectorStoreType") != "OPENSEARCH":
107+
return None
108+
109+
opensearch_config = vector_store.get("opensearchServerlessConfiguration") or vector_store.get("opensearchConfiguration")
110+
if not opensearch_config:
111+
return f"{kb_name} (missing OpenSearch configuration)"
112+
113+
if "collectionArn" in opensearch_config:
114+
collection_id = opensearch_config["collectionArn"].split("/")[-1]
115+
return check_opensearch_serverless(collection_id, kb_name)
116+
117+
domain_endpoint = opensearch_config.get("endpoint", "")
118+
if not domain_endpoint:
119+
return f"{kb_name} (missing OpenSearch domain endpoint)"
120+
domain_name = domain_endpoint.split(".")[0]
121+
return check_opensearch_domain(domain_name, kb_name)
122+
123+
except ClientError as e:
124+
LOGGER.error(f"Error checking knowledge base {kb_id}: {str(e)}")
125+
if e.response["Error"]["Code"] == "AccessDeniedException":
126+
return f"{kb_name} (access denied)"
127+
raise
128+
129+
130+
def evaluate_compliance(rule_parameters: dict) -> tuple[str, str]: # noqa: U100
33131
"""Evaluate if Bedrock Knowledge Base OpenSearch vector stores are encrypted with KMS CMK.
34132
35133
Args:
@@ -41,85 +139,28 @@ def evaluate_compliance(rule_parameters: dict) -> tuple[str, str]:
41139
try:
42140
non_compliant_kbs = []
43141
paginator = bedrock_agent_client.get_paginator("list_knowledge_bases")
44-
142+
45143
for page in paginator.paginate():
46144
for kb in page["knowledgeBaseSummaries"]:
47145
kb_id = kb["knowledgeBaseId"]
48146
kb_name = kb.get("name", kb_id)
49-
50-
try:
51-
# Get knowledge base details
52-
kb_details = bedrock_agent_client.get_knowledge_base(knowledgeBaseId=kb_id)
53-
vector_store = kb_details.get("vectorStoreConfiguration")
54-
55-
if vector_store and vector_store.get("vectorStoreType") == "OPENSEARCH":
56-
# Extract OpenSearch domain information
57-
opensearch_config = vector_store.get("opensearchServerlessConfiguration") or vector_store.get("opensearchConfiguration")
58-
59-
if not opensearch_config:
60-
non_compliant_kbs.append(f"{kb_name} (missing OpenSearch configuration)")
61-
continue
62-
63-
# Check if it's OpenSearch Serverless or standard OpenSearch
64-
if "collectionArn" in opensearch_config:
65-
# OpenSearch Serverless - always encrypted with AWS owned key at minimum
66-
collection_id = opensearch_config["collectionArn"].split("/")[-1]
67-
try:
68-
collection = opensearch_client.get_security_policy(
69-
Name=collection_id,
70-
Type="encryption"
71-
)
72-
# Check if using customer managed key
73-
security_policy = collection.get("securityPolicyDetail", {})
74-
if security_policy.get("Type") == "encryption":
75-
encryption_policy = security_policy.get("SecurityPolicies", [])[0]
76-
kms_key_arn = encryption_policy.get("KmsARN", "")
77-
78-
# If not using customer managed key
79-
if not kms_key_arn or "aws/opensearchserverless" in kms_key_arn:
80-
non_compliant_kbs.append(f"{kb_name} (OpenSearch Serverless not using CMK)")
81-
except ClientError as e:
82-
LOGGER.error(f"Error checking OpenSearch Serverless collection: {str(e)}")
83-
non_compliant_kbs.append(f"{kb_name} (error checking OpenSearch Serverless)")
84-
else:
85-
# Standard OpenSearch
86-
domain_endpoint = opensearch_config.get("endpoint", "")
87-
if not domain_endpoint:
88-
non_compliant_kbs.append(f"{kb_name} (missing OpenSearch domain endpoint)")
89-
continue
90-
91-
# Extract domain name from endpoint
92-
domain_name = domain_endpoint.split(".")[0]
93-
94-
try:
95-
domain = opensearch_client.describe_domain(DomainName=domain_name)
96-
encryption_config = domain.get("DomainStatus", {}).get("EncryptionAtRestOptions", {})
97-
98-
# Check if encryption is enabled and using CMK
99-
if not encryption_config.get("Enabled", False):
100-
non_compliant_kbs.append(f"{kb_name} (OpenSearch domain encryption not enabled)")
101-
elif not encryption_config.get("KmsKeyId"):
102-
non_compliant_kbs.append(f"{kb_name} (OpenSearch domain not using CMK)")
103-
except ClientError as e:
104-
LOGGER.error(f"Error checking OpenSearch domain: {str(e)}")
105-
non_compliant_kbs.append(f"{kb_name} (error checking OpenSearch domain)")
106-
107-
except ClientError as e:
108-
LOGGER.error(f"Error checking knowledge base {kb_id}: {str(e)}")
109-
if e.response["Error"]["Code"] == "AccessDeniedException":
110-
non_compliant_kbs.append(f"{kb_name} (access denied)")
111-
else:
112-
raise
147+
error = check_knowledge_base(kb_id, kb_name)
148+
if error:
149+
non_compliant_kbs.append(error)
113150

114151
if non_compliant_kbs:
115-
return "NON_COMPLIANT", f"The following knowledge bases have OpenSearch vector stores not encrypted with CMK: {'; '.join(non_compliant_kbs)}"
152+
return "NON_COMPLIANT", (
153+
"The following knowledge bases have OpenSearch vector stores not encrypted with CMK: "
154+
+ f"{'; '.join(non_compliant_kbs)}"
155+
)
116156
return "COMPLIANT", "All knowledge base OpenSearch vector stores are encrypted with KMS CMK"
117157

118158
except Exception as e:
119159
LOGGER.error(f"Error evaluating Bedrock Knowledge Base OpenSearch encryption: {str(e)}")
120160
return "ERROR", f"Error evaluating compliance: {str(e)}"
121161

122-
def lambda_handler(event: dict, context: Any) -> None:
162+
163+
def lambda_handler(event: dict, context: Any) -> None: # noqa: U100
123164
"""Lambda handler.
124165
125166
Args:
@@ -145,6 +186,6 @@ def lambda_handler(event: dict, context: Any) -> None:
145186
LOGGER.info(f"Compliance evaluation result: {compliance_type}")
146187
LOGGER.info(f"Annotation: {annotation}")
147188

148-
config_client.put_evaluations(Evaluations=[evaluation], ResultToken=event["resultToken"])
189+
config_client.put_evaluations(Evaluations=[evaluation], ResultToken=event["resultToken"]) # type: ignore
149190

150-
LOGGER.info("Compliance evaluation complete.")
191+
LOGGER.info("Compliance evaluation complete.")

0 commit comments

Comments
 (0)