@@ -33,6 +33,21 @@ class SRASTS:
3333 log_level : str = os .environ .get ("LOG_LEVEL" , "INFO" )
3434 LOGGER .setLevel (log_level )
3535
36+ def _get_partition_for_region (self , region_name : str ) -> str :
37+ """Get AWS partition for a given region.
38+
39+ Args:
40+ region_name (str): AWS region name
41+
42+ Returns:
43+ str: AWS partition name (aws, aws-cn, aws-us-gov)
44+ """
45+ if region_name .startswith ('us-gov-' ):
46+ return 'aws-us-gov'
47+ elif region_name .startswith ('cn-' ):
48+ return 'aws-cn'
49+ return 'aws'
50+
3651 def __init__ (self , profile : str = "default" ) -> None :
3752 """Initialize class object.
3853
@@ -56,14 +71,14 @@ def __init__(self, profile: str = "default") -> None:
5671 self .STS_CLIENT = self .MANAGEMENT_ACCOUNT_SESSION .client ("sts" )
5772 self .HOME_REGION = self .MANAGEMENT_ACCOUNT_SESSION .region_name
5873 self .LOGGER .info (f"STS detected home region: { self .HOME_REGION } " )
59- self .PARTITION = self .MANAGEMENT_ACCOUNT_SESSION . get_partition_for_region (self .HOME_REGION )
74+ self .PARTITION = self ._get_partition_for_region (self .HOME_REGION )
6075 except botocore .exceptions .ClientError as error :
6176 if error .response ["Error" ]["Code" ] == "ExpiredToken" :
6277 self .LOGGER .info ("Token has expired, please re-run with proper credentials set." )
6378 self .MANAGEMENT_ACCOUNT_SESSION = boto3 .Session ()
6479 self .STS_CLIENT = self .MANAGEMENT_ACCOUNT_SESSION .client ("sts" )
6580 self .HOME_REGION = self .MANAGEMENT_ACCOUNT_SESSION .region_name
66- self .PARTITION = self .MANAGEMENT_ACCOUNT_SESSION . get_partition_for_region (self .HOME_REGION )
81+ self .PARTITION = self ._get_partition_for_region (self .HOME_REGION )
6782
6883 else :
6984 self .LOGGER .info (f"Error: { error } " )
0 commit comments