Skip to content

Commit e3335c2

Browse files
committed
Working to fix a bug.
1 parent 089feeb commit e3335c2

File tree

1 file changed

+86
-53
lines changed
  • aws_sra_examples/solutions/genai/bedrock_org/lambda/rules/sra_bedrock_check_kb_s3_bucket

1 file changed

+86
-53
lines changed

aws_sra_examples/solutions/genai/bedrock_org/lambda/rules/sra_bedrock_check_kb_s3_bucket/app.py

Lines changed: 86 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -45,60 +45,92 @@ def evaluate_compliance(rule_parameters: dict) -> tuple[str, str]:
4545

4646
for page in paginator.paginate():
4747
for kb in page["knowledgeBaseSummaries"]:
48-
kb_details = bedrock_agent_client.get_knowledge_base(knowledgeBaseId=kb["knowledgeBaseId"])
49-
data_source = bedrock_agent_client.get_data_source(
50-
knowledgeBaseId=kb["knowledgeBaseId"],
51-
dataSourceId=kb_details["dataSource"]["dataSourceId"]
52-
)
48+
kb_id = kb["knowledgeBaseId"]
5349

54-
# Extract bucket name from S3 path
55-
s3_path = data_source["configuration"]["s3Configuration"]["bucketName"]
56-
bucket_name = s3_path.split("/")[0]
50+
# List data sources for this knowledge base
51+
data_sources_paginator = bedrock_agent_client.get_paginator("list_data_sources")
5752

58-
issues = []
59-
60-
# Check retention
61-
if rule_parameters.get("check_retention", "true").lower() == "true":
62-
try:
63-
lifecycle = s3_client.get_bucket_lifecycle_configuration(Bucket=bucket_name)
64-
if not any(rule.get("Expiration") for rule in lifecycle.get("Rules", [])):
65-
issues.append("retention")
66-
except ClientError as e:
67-
if e.response["Error"]["Code"] == "NoSuchLifecycleConfiguration":
68-
issues.append("retention")
69-
70-
# Check encryption
71-
if rule_parameters.get("check_encryption", "true").lower() == "true":
72-
try:
73-
encryption = s3_client.get_bucket_encryption(Bucket=bucket_name)
74-
if not encryption.get("ServerSideEncryptionConfiguration"):
75-
issues.append("encryption")
76-
except ClientError:
77-
issues.append("encryption")
78-
79-
# Check server access logging
80-
if rule_parameters.get("check_access_logging", "true").lower() == "true":
81-
logging_config = s3_client.get_bucket_logging(Bucket=bucket_name)
82-
if not logging_config.get("LoggingEnabled"):
83-
issues.append("access logging")
84-
85-
# Check object lock
86-
if rule_parameters.get("check_object_locking", "true").lower() == "true":
87-
try:
88-
lock_config = s3_client.get_bucket_object_lock_configuration(Bucket=bucket_name)
89-
if not lock_config.get("ObjectLockConfiguration"):
90-
issues.append("object locking")
91-
except ClientError:
92-
issues.append("object locking")
93-
94-
# Check versioning
95-
if rule_parameters.get("check_versioning", "true").lower() == "true":
96-
versioning = s3_client.get_bucket_versioning(Bucket=bucket_name)
97-
if versioning.get("Status") != "Enabled":
98-
issues.append("versioning")
99-
100-
if issues:
101-
non_compliant_buckets.append(f"{bucket_name} (missing: {', '.join(issues)})")
53+
for ds_page in data_sources_paginator.paginate(knowledgeBaseId=kb_id):
54+
for ds in ds_page.get("dataSourceSummaries", []):
55+
data_source = bedrock_agent_client.get_data_source(
56+
knowledgeBaseId=kb_id,
57+
dataSourceId=ds["dataSourceId"]
58+
)
59+
60+
# Check if this is an S3 data source and extract bucket name
61+
LOGGER.info(f"Data source structure: {json.dumps(data_source)}")
62+
if "s3Configuration" in data_source.get("dataSource", {}).get("dataSourceConfiguration", {}):
63+
s3_config = data_source["dataSource"]["dataSourceConfiguration"]["s3Configuration"]
64+
bucket_name = s3_config.get("bucketName", "")
65+
else:
66+
continue
67+
68+
if not bucket_name:
69+
LOGGER.info(f"No bucket name found for data source {ds['dataSourceId']}")
70+
continue
71+
72+
# If bucket name contains a path, extract just the bucket name
73+
if "/" in bucket_name:
74+
bucket_name = bucket_name.split("/")[0]
75+
76+
LOGGER.info(f"Checking S3 bucket: {bucket_name}")
77+
78+
issues = []
79+
80+
# Check retention
81+
if rule_parameters.get("check_retention", "true").lower() == "true":
82+
try:
83+
lifecycle = s3_client.get_bucket_lifecycle_configuration(Bucket=bucket_name)
84+
if not any(rule.get("Expiration") for rule in lifecycle.get("Rules", [])):
85+
issues.append("retention")
86+
except ClientError as e:
87+
if e.response["Error"]["Code"] == "NoSuchLifecycleConfiguration":
88+
issues.append("retention")
89+
elif e.response["Error"]["Code"] != "NoSuchBucket":
90+
LOGGER.error(f"Error checking retention for bucket {bucket_name}: {str(e)}")
91+
92+
# Check encryption
93+
if rule_parameters.get("check_encryption", "true").lower() == "true":
94+
try:
95+
encryption = s3_client.get_bucket_encryption(Bucket=bucket_name)
96+
if not encryption.get("ServerSideEncryptionConfiguration"):
97+
issues.append("encryption")
98+
except ClientError as e:
99+
if e.response["Error"]["Code"] != "NoSuchBucket":
100+
issues.append("encryption")
101+
102+
# Check server access logging
103+
if rule_parameters.get("check_access_logging", "true").lower() == "true":
104+
try:
105+
logging_config = s3_client.get_bucket_logging(Bucket=bucket_name)
106+
if not logging_config.get("LoggingEnabled"):
107+
issues.append("access logging")
108+
except ClientError as e:
109+
if e.response["Error"]["Code"] != "NoSuchBucket":
110+
issues.append("access logging")
111+
112+
# Check object lock
113+
if rule_parameters.get("check_object_locking", "true").lower() == "true":
114+
try:
115+
lock_config = s3_client.get_object_lock_configuration(Bucket=bucket_name)
116+
if not lock_config.get("ObjectLockConfiguration"):
117+
issues.append("object locking")
118+
except ClientError as e:
119+
if e.response["Error"]["Code"] != "NoSuchBucket":
120+
issues.append("object locking")
121+
122+
# Check versioning
123+
if rule_parameters.get("check_versioning", "true").lower() == "true":
124+
try:
125+
versioning = s3_client.get_bucket_versioning(Bucket=bucket_name)
126+
if versioning.get("Status") != "Enabled":
127+
issues.append("versioning")
128+
except ClientError as e:
129+
if e.response["Error"]["Code"] != "NoSuchBucket":
130+
issues.append("versioning")
131+
132+
if issues:
133+
non_compliant_buckets.append(f"{bucket_name} (missing: {', '.join(issues)})")
102134

103135
if non_compliant_buckets:
104136
return "NON_COMPLIANT", f"The following KB S3 buckets are non-compliant: {'; '.join(non_compliant_buckets)}"
@@ -136,4 +168,5 @@ def lambda_handler(event: dict, context: Any) -> None:
136168

137169
config_client.put_evaluations(Evaluations=[evaluation], ResultToken=event["resultToken"])
138170

139-
LOGGER.info("Compliance evaluation complete.")
171+
LOGGER.info("Compliance evaluation complete.")
172+

0 commit comments

Comments
 (0)