Migrate AWS Verifier to aws-sdk-go-v2

This commit is contained in:
Peter Rifel 2024-04-21 07:06:16 -04:00
parent 3d43f9eba5
commit 62df0dba04
No known key found for this signature in database
6 changed files with 60 additions and 66 deletions

View File

@ -130,7 +130,7 @@ func main() {
var verifiers []bootstrap.Verifier
var err error
if opt.Server.Provider.AWS != nil {
verifier, err := awsup.NewAWSVerifier(opt.Server.Provider.AWS)
verifier, err := awsup.NewAWSVerifier(ctx, opt.Server.Provider.AWS)
if err != nil {
setupLog.Error(err, "unable to create verifier")
os.Exit(1)

View File

@ -52,7 +52,7 @@ func (b BootstrapClientBuilder) Build(c *fi.NodeupModelBuilderContext) error {
switch b.CloudProvider() {
case kops.CloudProviderAWS:
a, err := awsup.NewAWSAuthenticator(b.Cloud.Region())
a, err := awsup.NewAWSAuthenticator(c.Context(), b.Cloud.Region())
if err != nil {
return err
}

View File

@ -25,17 +25,15 @@ import (
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/service/sts"
smithyhttp "github.com/aws/smithy-go/transport/http"
"k8s.io/kops/pkg/bootstrap"
)
const AWSAuthenticationTokenPrefix = "x-aws-sts "
type awsAuthenticator struct {
sts *sts.STS
sts *sts.Client
}
var _ bootstrap.Authenticator = &awsAuthenticator{}
@ -55,32 +53,28 @@ func RegionFromMetadata(ctx context.Context) (string, error) {
return resp.Region, nil
}
func NewAWSAuthenticator(region string) (bootstrap.Authenticator, error) {
config := aws.NewConfig().
WithCredentialsChainVerboseErrors(true).
WithRegion(region).
WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)
sess, err := session.NewSession(config)
func NewAWSAuthenticator(ctx context.Context, region string) (bootstrap.Authenticator, error) {
config, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to load aws config: %w", err)
}
return &awsAuthenticator{
sts: sts.New(sess, config),
sts: sts.NewFromConfig(config),
}, nil
}
func (a *awsAuthenticator) CreateToken(body []byte) (string, error) {
sha := sha256.Sum256(body)
stsRequest, _ := a.sts.GetCallerIdentityRequest(nil)
presignClient := sts.NewPresignClient(a.sts)
// Ensure the signature is only valid for this particular body content.
stsRequest.HTTPRequest.Header.Add("X-Kops-Request-SHA", base64.RawStdEncoding.EncodeToString(sha[:]))
stsRequest, _ := presignClient.PresignGetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) {
po.ClientOptions = append(po.ClientOptions, func(o *sts.Options) {
o.APIOptions = append(o.APIOptions, smithyhttp.AddHeaderValue("X-Kops-Request-SHA", base64.RawStdEncoding.EncodeToString(sha[:])))
})
})
if err := stsRequest.Sign(); err != nil {
return "", err
}
headers, _ := json.Marshal(stsRequest.HTTPRequest.Header)
headers, _ := json.Marshal(stsRequest.SignedHeader)
return AWSAuthenticationTokenPrefix + base64.StdEncoding.EncodeToString(headers), nil
}

View File

@ -46,7 +46,6 @@ import (
"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/route53"
"github.com/aws/aws-sdk-go-v2/service/sts"
ec2v1 "github.com/aws/aws-sdk-go/service/ec2"
"k8s.io/klog/v2"
v1 "k8s.io/api/core/v1"
@ -2358,7 +2357,7 @@ func GetRolesInInstanceProfile(c AWSCloud, profileName string) ([]string, error)
// GetInstanceCertificateNames returns the instance hostname and addresses that should go into certificates.
// The first value is the node name and any additional values are the DNS name and IP addresses.
func GetInstanceCertificateNames(instances *ec2v1.DescribeInstancesOutput) (addrs []string, err error) {
func GetInstanceCertificateNames(instances *ec2.DescribeInstancesOutput) (addrs []string, err error) {
if len(instances.Reservations) != 1 {
return nil, fmt.Errorf("too many reservations returned for the single instance-id")
}

View File

@ -27,15 +27,15 @@ import (
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/sts"
"k8s.io/kops/pkg/bootstrap"
nodeidentityaws "k8s.io/kops/pkg/nodeidentity/aws"
"k8s.io/kops/pkg/wellknownports"
@ -53,39 +53,38 @@ type awsVerifier struct {
partition string
opt AWSVerifierOptions
ec2 *ec2.EC2
sts *sts.STS
ec2 *ec2.Client
sts *sts.PresignClient
client http.Client
}
var _ bootstrap.Verifier = &awsVerifier{}
func NewAWSVerifier(opt *AWSVerifierOptions) (bootstrap.Verifier, error) {
config := aws.NewConfig().
WithCredentialsChainVerboseErrors(true).
WithRegion(opt.Region).
WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)
sess, err := session.NewSession(config)
func NewAWSVerifier(ctx context.Context, opt *AWSVerifierOptions) (bootstrap.Verifier, error) {
config, err := awsconfig.LoadDefaultConfig(
ctx,
awsconfig.WithRegion(opt.Region),
)
if err != nil {
return nil, fmt.Errorf("failed to load aws config: %w", err)
}
stsClient := sts.NewFromConfig(config)
identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return nil, err
}
stsClient := sts.New(sess, config)
identity, err := stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return nil, err
}
partition := strings.Split(aws.ToString(identity.Arn), ":")[1]
partition := strings.Split(aws.StringValue(identity.Arn), ":")[1]
ec2Client := ec2.New(sess, config)
ec2Client := ec2.NewFromConfig(config)
return &awsVerifier{
accountId: aws.StringValue(identity.Account),
accountId: aws.ToString(identity.Account),
partition: partition,
opt: *opt,
ec2: ec2Client,
sts: stsClient,
sts: sts.NewPresignClient(stsClient),
client: http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
@ -128,35 +127,37 @@ func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request,
token = strings.TrimPrefix(token, AWSAuthenticationTokenPrefix)
// We rely on the client and server using the same version of the same STS library.
stsRequest, _ := a.sts.GetCallerIdentityRequest(nil)
err := stsRequest.Sign()
stsRequest, err := a.sts.PresignGetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return nil, fmt.Errorf("creating identity request: %v", err)
}
stsRequest.HTTPRequest.Header = nil
stsRequest.SignedHeader = nil
tokenBytes, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return nil, fmt.Errorf("decoding authorization token: %v", err)
}
err = json.Unmarshal(tokenBytes, &stsRequest.HTTPRequest.Header)
err = json.Unmarshal(tokenBytes, &stsRequest.SignedHeader)
if err != nil {
return nil, fmt.Errorf("unmarshalling authorization token: %v", err)
}
// Verify the token has signed the body content.
sha := sha256.Sum256(body)
if stsRequest.HTTPRequest.Header.Get("X-Kops-Request-SHA") != base64.RawStdEncoding.EncodeToString(sha[:]) {
if stsRequest.SignedHeader.Get("X-Kops-Request-SHA") != base64.RawStdEncoding.EncodeToString(sha[:]) {
return nil, fmt.Errorf("incorrect SHA")
}
requestBytes, _ := io.ReadAll(stsRequest.Body)
_, _ = stsRequest.Body.Seek(0, io.SeekStart)
if stsRequest.HTTPRequest.Header.Get("Content-Length") != strconv.Itoa(len(requestBytes)) {
return nil, fmt.Errorf("incorrect content-length")
reqURL, err := url.Parse(stsRequest.URL)
if err != nil {
return nil, fmt.Errorf("parsing STS request URL: %v", err)
}
response, err := a.client.Do(stsRequest.HTTPRequest)
req := &http.Request{
URL: reqURL,
Method: stsRequest.Method,
Header: stsRequest.SignedHeader,
}
response, err := a.client.Do(req)
if err != nil {
return nil, fmt.Errorf("sending STS request: %v", err)
}
@ -217,8 +218,8 @@ func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request,
}
instanceID := resource[2]
instances, err := a.ec2.DescribeInstances(&ec2.DescribeInstancesInput{
InstanceIds: aws.StringSlice([]string{instanceID}),
instances, err := a.ec2.DescribeInstances(ctx, &ec2.DescribeInstancesInput{
InstanceIds: []string{instanceID},
})
if err != nil {
return nil, fmt.Errorf("describing instance for arn %q", arn)
@ -240,17 +241,17 @@ func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request,
var challengeEndpoints []string
for _, nic := range instance.NetworkInterfaces {
if ip := aws.StringValue(nic.PrivateIpAddress); ip != "" {
if ip := aws.ToString(nic.PrivateIpAddress); ip != "" {
challengeEndpoints = append(challengeEndpoints, net.JoinHostPort(ip, strconv.Itoa(wellknownports.NodeupChallenge)))
}
for _, a := range nic.PrivateIpAddresses {
if ip := aws.StringValue(a.PrivateIpAddress); ip != "" {
if ip := aws.ToString(a.PrivateIpAddress); ip != "" {
challengeEndpoints = append(challengeEndpoints, net.JoinHostPort(ip, strconv.Itoa(wellknownports.NodeupChallenge)))
}
}
for _, a := range nic.Ipv6Addresses {
if ip := aws.StringValue(a.Ipv6Address); ip != "" {
if ip := aws.ToString(a.Ipv6Address); ip != "" {
challengeEndpoints = append(challengeEndpoints, net.JoinHostPort(ip, strconv.Itoa(wellknownports.NodeupChallenge)))
}
}
@ -267,9 +268,9 @@ func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request,
}
for _, tag := range instance.Tags {
tagKey := aws.StringValue(tag.Key)
tagKey := aws.ToString(tag.Key)
if tagKey == nodeidentityaws.CloudTagInstanceGroupName {
result.InstanceGroupName = aws.StringValue(tag.Value)
result.InstanceGroupName = aws.ToString(tag.Value)
}
}

View File

@ -625,7 +625,7 @@ func getNodeConfigFromServers(ctx context.Context, bootConfig *nodeup.BootConfig
switch bootConfig.CloudProvider {
case api.CloudProviderAWS:
a, err := awsup.NewAWSAuthenticator(region)
a, err := awsup.NewAWSAuthenticator(ctx, region)
if err != nil {
return nil, err
}