Skip to content

Commit

Permalink
Separate STS call and retry logic
Browse files Browse the repository at this point in the history
  • Loading branch information
gl-johnson committed Jun 23, 2023
1 parent 4432fe7 commit 3934bd4
Showing 1 changed file with 36 additions and 27 deletions.
63 changes: 36 additions & 27 deletions app/domain/authentication/authn_iam/authenticator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -54,33 +54,34 @@ def extract_relevant_data(response)

# Call to AWS STS endpoint using the provided authentication header
def attempt_signed_request(signed_headers)
sts_host = extract_sts_host(signed_headers)
aws_request = URI("https://#{sts_host}/?Action=GetCallerIdentity&Version=2011-06-15")
begin
response = @client.get_response(aws_request, signed_headers)
return response unless response.code.to_i == 403 && sts_host.include?('us-east-1')

# If the request to `us-east-1` failed with a 403, retry on the global endpoint
retry_signed_request_on_global(signed_headers)
region = extract_sts_region(signed_headers)

# Handle any network failures with a generic verification error
rescue StandardError => e
raise(Errors::Authentication::AuthnIam::VerificationError.new(e))
# Attempt request using the discovered region and return immediately if successful
response = aws_call(region: region, headers: signed_headers)
return response if response.code.to_i == 200

# If the discovered region is `us-east-1`, fallback to the global endpoint
if region == 'us-east-1'
@logger.debug(LogMessages::Authentication::AuthnIam::RetryWithGlobalEndpoint.new)
fallback_response = aws_call(region: 'global', headers: signed_headers)
return fallback_response if fallback_response.code.to_i == 200
end

return response
end

# Retry request on AWS STS global endpoint
def retry_signed_request_on_global(signed_headers)
@logger.debug(
LogMessages::Authentication::AuthnIam::RetryWithGlobalEndpoint.new
)
aws_request = URI('https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15')
def aws_call(region:, headers:)
host = if region == 'global'
'sts.amazonaws.com'
else
"sts.#{region}.amazonaws.com"
end
aws_request = URI("https://#{host}/?Action=GetCallerIdentity&Version=2011-06-15")
begin
@client.get_response(aws_request, signed_headers)

# Handle any network failures with a generic verification error
@client.get_response(aws_request, headers)
rescue StandardError => e
raise(Errors::Authentication::AuthnIam::VerificationError.new(e))
# Handle any network failures with a generic verification error
raise(Errors::Authentication::AuthnIam::VerificationError, e)
end
end

Expand All @@ -97,15 +98,23 @@ def response_from_signed_request(aws_headers)
)
end

# Extract AWS region from the authorization header's credential string, i.e.:
# Extracts the STS region from the host header if it exists.
# If not, we use the authorization header's credential string, i.e.:
# Credential=AKIAIOSFODNN7EXAMPLE/20220830/us-east-1/sts/aws4_request
def extract_sts_host(signed_headers)
return signed_headers['host'] if signed_headers['host'].present?
def extract_sts_region(signed_headers)
host = signed_headers['host']

region = signed_headers['authorization'].match(%r{Credential=[^/]+/[^/]+/([^/]+)/})&.captures&.first
raise(Errors::Authentication::AuthnIam::InvalidAWSHeaders, 'Failed to extract AWS region from authorization header') unless region
if host == 'sts.amazonaws.com'
return 'global'
end

match = host&.match(%r{sts.([\w\-]+).amazonaws.com})
return match&.captures&.first if match

"sts.#{region}.amazonaws.com"
match = signed_headers['authorization']&.match(%r{Credential=[^/]+/[^/]+/([^/]+)/})
return match&.captures&.first if match

raise Errors::Authentication::AuthnIam::InvalidAWSHeaders, 'Failed to extract AWS region from authorization header'
end
end
end
Expand Down

0 comments on commit 3934bd4

Please sign in to comment.