From 19a708d2cf6eb1b61ac8b07baec21b6d1688c019 Mon Sep 17 00:00:00 2001 From: Harshavardhana Date: Tue, 11 Jan 2022 12:51:10 -0800 Subject: [PATCH] add fallback code for STS ErrorResponse v/s S3 Error (#1607) --- pkg/credentials/assume_role.go | 13 +++++++++++- pkg/credentials/error_response.go | 30 ++++++++++++++++++++++++++++ pkg/credentials/sts_client_grants.go | 20 ++++++++++++++----- pkg/credentials/sts_ldap_identity.go | 19 ++++++++++++++---- pkg/credentials/sts_tls_identity.go | 19 ++++++++++++++---- pkg/credentials/sts_web_identity.go | 19 ++++++++++++++---- 6 files changed, 102 insertions(+), 18 deletions(-) diff --git a/pkg/credentials/assume_role.go b/pkg/credentials/assume_role.go index d14f5239f..107a11b14 100644 --- a/pkg/credentials/assume_role.go +++ b/pkg/credentials/assume_role.go @@ -18,6 +18,7 @@ package credentials import ( + "bytes" "encoding/hex" "encoding/xml" "errors" @@ -185,10 +186,20 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume defer closeResponse(resp) if resp.StatusCode != http.StatusOK { var errResp ErrorResponse - _, err = xmlDecodeAndBody(resp.Body, &errResp) + buf, err := ioutil.ReadAll(resp.Body) if err != nil { return AssumeRoleResponse{}, err } + _, err = xmlDecodeAndBody(bytes.NewReader(buf), &errResp) + if err != nil { + var s3Err Error + if _, err = xmlDecodeAndBody(bytes.NewReader(buf), &s3Err); err != nil { + return AssumeRoleResponse{}, err + } + errResp.RequestID = s3Err.RequestID + errResp.STSError.Code = s3Err.Code + errResp.STSError.Message = s3Err.Message + } return AssumeRoleResponse{}, errResp } diff --git a/pkg/credentials/error_response.go b/pkg/credentials/error_response.go index 73e53f616..f4b027a41 100644 --- a/pkg/credentials/error_response.go +++ b/pkg/credentials/error_response.go @@ -38,6 +38,36 @@ type ErrorResponse struct { RequestID string `xml:"RequestId"` } +// Error - Is the typed error returned by all API operations. +type Error struct { + XMLName xml.Name `xml:"Error" json:"-"` + Code string + Message string + BucketName string + Key string + Resource string + RequestID string `xml:"RequestId"` + HostID string `xml:"HostId"` + + // Region where the bucket is located. This header is returned + // only in HEAD bucket and ListObjects response. + Region string + + // Captures the server string returned in response header. + Server string + + // Underlying HTTP status code for the returned error + StatusCode int `xml:"-" json:"-"` +} + +// Error - Returns S3 error string. +func (e Error) Error() string { + if e.Message == "" { + return fmt.Sprintf("Error response code %s.", e.Code) + } + return e.Message +} + // Error - Returns STS error string. func (e ErrorResponse) Error() string { if e.STSError.Message == "" { diff --git a/pkg/credentials/sts_client_grants.go b/pkg/credentials/sts_client_grants.go index 85cd4599d..b6712b19d 100644 --- a/pkg/credentials/sts_client_grants.go +++ b/pkg/credentials/sts_client_grants.go @@ -18,9 +18,11 @@ package credentials import ( + "bytes" "encoding/xml" "errors" "fmt" + "io/ioutil" "net/http" "net/url" "time" @@ -133,12 +135,20 @@ func getClientGrantsCredentials(clnt *http.Client, endpoint string, defer resp.Body.Close() if resp.StatusCode != http.StatusOK { var errResp ErrorResponse - _, err = xmlDecodeAndBody(resp.Body, &errResp) + buf, err := ioutil.ReadAll(resp.Body) if err != nil { - errResp := ErrorResponse{} - errResp.STSError.Code = "InvalidArgument" - errResp.STSError.Message = err.Error() - return AssumeRoleWithClientGrantsResponse{}, errResp + return AssumeRoleWithClientGrantsResponse{}, err + + } + _, err = xmlDecodeAndBody(bytes.NewReader(buf), &errResp) + if err != nil { + var s3Err Error + if _, err = xmlDecodeAndBody(bytes.NewReader(buf), &s3Err); err != nil { + return AssumeRoleWithClientGrantsResponse{}, err + } + errResp.RequestID = s3Err.RequestID + errResp.STSError.Code = s3Err.Code + errResp.STSError.Message = s3Err.Message } return AssumeRoleWithClientGrantsResponse{}, errResp } diff --git a/pkg/credentials/sts_ldap_identity.go b/pkg/credentials/sts_ldap_identity.go index ec2f1a31b..39c7892b6 100644 --- a/pkg/credentials/sts_ldap_identity.go +++ b/pkg/credentials/sts_ldap_identity.go @@ -18,8 +18,10 @@ package credentials import ( + "bytes" "encoding/xml" "fmt" + "io/ioutil" "net/http" "net/url" "time" @@ -169,11 +171,20 @@ func (k *LDAPIdentity) Retrieve() (value Value, err error) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { var errResp ErrorResponse - _, err = xmlDecodeAndBody(resp.Body, &errResp) + buf, err := ioutil.ReadAll(resp.Body) if err != nil { - errResp.STSError.Code = "InvalidArgument" - errResp.STSError.Message = err.Error() - return value, errResp + return value, err + + } + _, err = xmlDecodeAndBody(bytes.NewReader(buf), &errResp) + if err != nil { + var s3Err Error + if _, err = xmlDecodeAndBody(bytes.NewReader(buf), &s3Err); err != nil { + return value, err + } + errResp.RequestID = s3Err.RequestID + errResp.STSError.Code = s3Err.Code + errResp.STSError.Message = s3Err.Message } return value, errResp } diff --git a/pkg/credentials/sts_tls_identity.go b/pkg/credentials/sts_tls_identity.go index 105fc209a..7f485d639 100644 --- a/pkg/credentials/sts_tls_identity.go +++ b/pkg/credentials/sts_tls_identity.go @@ -16,10 +16,12 @@ package credentials import ( + "bytes" "crypto/tls" "encoding/xml" "errors" "io" + "io/ioutil" "net" "net/http" "net/url" @@ -150,11 +152,20 @@ func (i *STSCertificateIdentity) Retrieve() (Value, error) { } if resp.StatusCode != http.StatusOK { var errResp ErrorResponse - _, err = xmlDecodeAndBody(resp.Body, &errResp) + buf, err := ioutil.ReadAll(resp.Body) if err != nil { - errResp.STSError.Code = "InvalidArgument" - errResp.STSError.Message = err.Error() - return Value{}, errResp + return Value{}, err + + } + _, err = xmlDecodeAndBody(bytes.NewReader(buf), &errResp) + if err != nil { + var s3Err Error + if _, err = xmlDecodeAndBody(bytes.NewReader(buf), &s3Err); err != nil { + return Value{}, err + } + errResp.RequestID = s3Err.RequestID + errResp.STSError.Code = s3Err.Code + errResp.STSError.Message = s3Err.Message } return Value{}, errResp } diff --git a/pkg/credentials/sts_web_identity.go b/pkg/credentials/sts_web_identity.go index 70e8f9674..98f6ea653 100644 --- a/pkg/credentials/sts_web_identity.go +++ b/pkg/credentials/sts_web_identity.go @@ -18,9 +18,11 @@ package credentials import ( + "bytes" "encoding/xml" "errors" "fmt" + "io/ioutil" "net/http" "net/url" "strconv" @@ -151,11 +153,20 @@ func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSession defer resp.Body.Close() if resp.StatusCode != http.StatusOK { var errResp ErrorResponse - _, err = xmlDecodeAndBody(resp.Body, &errResp) + buf, err := ioutil.ReadAll(resp.Body) if err != nil { - errResp.STSError.Code = "InvalidArgument" - errResp.STSError.Message = err.Error() - return AssumeRoleWithWebIdentityResponse{}, errResp + return AssumeRoleWithWebIdentityResponse{}, err + + } + _, err = xmlDecodeAndBody(bytes.NewReader(buf), &errResp) + if err != nil { + var s3Err Error + if _, err = xmlDecodeAndBody(bytes.NewReader(buf), &s3Err); err != nil { + return AssumeRoleWithWebIdentityResponse{}, err + } + errResp.RequestID = s3Err.RequestID + errResp.STSError.Code = s3Err.Code + errResp.STSError.Message = s3Err.Message } return AssumeRoleWithWebIdentityResponse{}, errResp }