Skip to content

Commit b40a257

Browse files
authored
Refactor aws cloud service and introduce a client provider (#3895)
* add support for aws clients provider * refactor aws cloud service * fix typo
1 parent 1ea514f commit b40a257

File tree

12 files changed

+528
-152
lines changed

12 files changed

+528
-152
lines changed

main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func main() {
8181
ctrl.SetLogger(appLogger)
8282
klog.SetLoggerWithOptions(appLogger, klog.ContextualLogger(true))
8383

84-
cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log)
84+
cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log, nil)
8585
if err != nil {
8686
setupLog.Error(err, "unable to initialize AWS cloud")
8787
os.Exit(1)

pkg/aws/cloud.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/prometheus/client_golang/prometheus"
2525
amerrors "k8s.io/apimachinery/pkg/util/errors"
2626
epresolver "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
27+
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider"
2728
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services"
2829
)
2930

@@ -59,7 +60,7 @@ type Cloud interface {
5960
}
6061

6162
// NewCloud constructs new Cloud implementation.
62-
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger) (Cloud, error) {
63+
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger, awsClientsProvider provider.AWSClientsProvider) (Cloud, error) {
6364
hasIPv4 := true
6465
addrs, err := net.InterfaceAddrs()
6566
if err == nil {
@@ -129,7 +130,14 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l
129130
awsConfig.APIOptions = metrics.WithSDKMetricCollector(metricsCollector, awsConfig.APIOptions)
130131
}
131132

132-
ec2Service := services.NewEC2(awsConfig, endpointsResolver)
133+
if awsClientsProvider == nil {
134+
var err error
135+
awsClientsProvider, err = provider.NewDefaultAWSClientsProvider(awsConfig, endpointsResolver)
136+
if err != nil {
137+
return nil, errors.Wrap(err, "failed to create aws clients provider")
138+
}
139+
}
140+
ec2Service := services.NewEC2(awsClientsProvider)
133141

134142
vpcID, err := getVpcID(cfg, ec2Service, ec2Metadata, logger)
135143
if err != nil {
@@ -139,17 +147,16 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l
139147
return &defaultCloud{
140148
cfg: cfg,
141149
ec2: ec2Service,
142-
elbv2: services.NewELBV2(awsConfig, endpointsResolver),
143-
acm: services.NewACM(awsConfig, endpointsResolver),
144-
wafv2: services.NewWAFv2(awsConfig, endpointsResolver),
145-
wafRegional: services.NewWAFRegional(awsConfig, endpointsResolver, cfg.Region),
146-
shield: services.NewShield(awsConfig, endpointsResolver), //done
147-
rgt: services.NewRGT(awsConfig, endpointsResolver),
150+
elbv2: services.NewELBV2(awsClientsProvider),
151+
acm: services.NewACM(awsClientsProvider),
152+
wafv2: services.NewWAFv2(awsClientsProvider),
153+
wafRegional: services.NewWAFRegional(awsClientsProvider, cfg.Region),
154+
shield: services.NewShield(awsClientsProvider),
155+
rgt: services.NewRGT(awsClientsProvider),
148156
}, nil
149157
}
150158

151159
func getVpcID(cfg CloudConfig, ec2Service services.EC2, ec2Metadata services.EC2Metadata, logger logr.Logger) (string, error) {
152-
153160
if cfg.VpcID != "" {
154161
logger.V(1).Info("vpcid is specified using flag --aws-vpc-id, controller will use the value", "vpc: ", cfg.VpcID)
155162
return cfg.VpcID, nil
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package provider
2+
3+
import (
4+
"context"
5+
"github.com/aws/aws-sdk-go-v2/aws"
6+
"github.com/aws/aws-sdk-go-v2/service/acm"
7+
"github.com/aws/aws-sdk-go-v2/service/ec2"
8+
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
9+
"github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi"
10+
"github.com/aws/aws-sdk-go-v2/service/shield"
11+
"github.com/aws/aws-sdk-go-v2/service/wafregional"
12+
"github.com/aws/aws-sdk-go-v2/service/wafv2"
13+
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
14+
)
15+
16+
type defaultAWSClientsProvider struct {
17+
ec2Client *ec2.Client
18+
elbv2Client *elasticloadbalancingv2.Client
19+
acmClient *acm.Client
20+
wafv2Client *wafv2.Client
21+
wafRegionClient *wafregional.Client
22+
shieldClient *shield.Client
23+
rgtClient *resourcegroupstaggingapi.Client
24+
}
25+
26+
func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) {
27+
ec2CustomEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID)
28+
elbv2CustomEndpoint := endpointsResolver.EndpointFor(elasticloadbalancingv2.ServiceID)
29+
acmCustomEndpoint := endpointsResolver.EndpointFor(acm.ServiceID)
30+
wafv2CustomEndpoint := endpointsResolver.EndpointFor(wafv2.ServiceID)
31+
wafregionalCustomEndpoint := endpointsResolver.EndpointFor(wafregional.ServiceID)
32+
shieldCustomEndpoint := endpointsResolver.EndpointFor(shield.ServiceID)
33+
rgtCustomEndpoint := endpointsResolver.EndpointFor(resourcegroupstaggingapi.ServiceID)
34+
35+
ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) {
36+
if ec2CustomEndpoint != nil {
37+
o.BaseEndpoint = ec2CustomEndpoint
38+
}
39+
})
40+
elbv2Client := elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) {
41+
if elbv2CustomEndpoint != nil {
42+
o.BaseEndpoint = elbv2CustomEndpoint
43+
}
44+
})
45+
acmClient := acm.NewFromConfig(cfg, func(o *acm.Options) {
46+
if acmCustomEndpoint != nil {
47+
o.BaseEndpoint = acmCustomEndpoint
48+
}
49+
})
50+
wafv2Client := wafv2.NewFromConfig(cfg, func(o *wafv2.Options) {
51+
if wafv2CustomEndpoint != nil {
52+
o.BaseEndpoint = wafv2CustomEndpoint
53+
}
54+
})
55+
wafregionalClient := wafregional.NewFromConfig(cfg, func(o *wafregional.Options) {
56+
o.Region = cfg.Region
57+
o.BaseEndpoint = wafregionalCustomEndpoint
58+
})
59+
sheildClient := shield.NewFromConfig(cfg, func(o *shield.Options) {
60+
o.Region = "us-east-1"
61+
o.BaseEndpoint = shieldCustomEndpoint
62+
})
63+
rgtClient := resourcegroupstaggingapi.NewFromConfig(cfg, func(o *resourcegroupstaggingapi.Options) {
64+
if rgtCustomEndpoint != nil {
65+
o.BaseEndpoint = rgtCustomEndpoint
66+
}
67+
})
68+
69+
return &defaultAWSClientsProvider{
70+
ec2Client: ec2Client,
71+
elbv2Client: elbv2Client,
72+
acmClient: acmClient,
73+
wafv2Client: wafv2Client,
74+
wafRegionClient: wafregionalClient,
75+
shieldClient: sheildClient,
76+
rgtClient: rgtClient,
77+
}, nil
78+
}
79+
80+
// DO NOT REMOVE operationName as parameter, this is on purpose
81+
// to retain the default behavior for OSS controller to use the default client for each aws service
82+
// for our internal controller, we will choose different client based on operationName
83+
func (p *defaultAWSClientsProvider) GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) {
84+
return p.ec2Client, nil
85+
}
86+
87+
func (p *defaultAWSClientsProvider) GetELBv2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error) {
88+
return p.elbv2Client, nil
89+
}
90+
91+
func (p *defaultAWSClientsProvider) GetACMClient(ctx context.Context, operationName string) (*acm.Client, error) {
92+
return p.acmClient, nil
93+
}
94+
95+
func (p *defaultAWSClientsProvider) GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error) {
96+
return p.wafv2Client, nil
97+
}
98+
99+
func (p *defaultAWSClientsProvider) GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error) {
100+
return p.wafRegionClient, nil
101+
}
102+
103+
func (p *defaultAWSClientsProvider) GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error) {
104+
return p.shieldClient, nil
105+
}
106+
107+
func (p *defaultAWSClientsProvider) GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error) {
108+
return p.rgtClient, nil
109+
}

pkg/aws/provider/provider.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package provider
2+
3+
import (
4+
"context"
5+
"github.com/aws/aws-sdk-go-v2/service/acm"
6+
"github.com/aws/aws-sdk-go-v2/service/ec2"
7+
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
8+
"github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi"
9+
"github.com/aws/aws-sdk-go-v2/service/shield"
10+
"github.com/aws/aws-sdk-go-v2/service/wafregional"
11+
"github.com/aws/aws-sdk-go-v2/service/wafv2"
12+
)
13+
14+
type AWSClientsProvider interface {
15+
GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error)
16+
GetELBv2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error)
17+
GetACMClient(ctx context.Context, operationName string) (*acm.Client, error)
18+
GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error)
19+
GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error)
20+
GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error)
21+
GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error)
22+
}

pkg/aws/services/acm.go

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@ package services
22

33
import (
44
"context"
5-
"github.com/aws/aws-sdk-go-v2/aws"
65
"github.com/aws/aws-sdk-go-v2/service/acm"
76
"github.com/aws/aws-sdk-go-v2/service/acm/types"
8-
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
7+
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider"
98
)
109

1110
type ACM interface {
@@ -15,24 +14,23 @@ type ACM interface {
1514
}
1615

1716
// NewACM constructs new ACM implementation.
18-
func NewACM(cfg aws.Config, endpointsResolver *endpoints.Resolver) ACM {
19-
customEndpoint := endpointsResolver.EndpointFor(acm.ServiceID)
17+
func NewACM(awsClientsProvider provider.AWSClientsProvider) ACM {
2018
return &acmClient{
21-
acmClient: acm.NewFromConfig(cfg, func(o *acm.Options) {
22-
if customEndpoint != nil {
23-
o.BaseEndpoint = customEndpoint
24-
}
25-
}),
19+
awsClientsProvider: awsClientsProvider,
2620
}
2721
}
2822

2923
type acmClient struct {
30-
acmClient *acm.Client
24+
awsClientsProvider provider.AWSClientsProvider
3125
}
3226

3327
func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListCertificatesInput) ([]types.CertificateSummary, error) {
3428
var result []types.CertificateSummary
35-
paginator := acm.NewListCertificatesPaginator(c.acmClient, input)
29+
client, err := c.awsClientsProvider.GetACMClient(ctx, "ListCertificates")
30+
if err != nil {
31+
return nil, err
32+
}
33+
paginator := acm.NewListCertificatesPaginator(client, input)
3634
for paginator.HasMorePages() {
3735
output, err := paginator.NextPage(ctx)
3836
if err != nil {
@@ -44,5 +42,9 @@ func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListC
4442
}
4543

4644
func (c *acmClient) DescribeCertificateWithContext(ctx context.Context, input *acm.DescribeCertificateInput) (*acm.DescribeCertificateOutput, error) {
47-
return c.acmClient.DescribeCertificate(ctx, input)
45+
client, err := c.awsClientsProvider.GetACMClient(ctx, "DescribeCertificate")
46+
if err != nil {
47+
return nil, err
48+
}
49+
return client.DescribeCertificate(ctx, input)
4850
}

0 commit comments

Comments
 (0)