From da7dff4c2525d2d343b09a6b6328f42223a260c9 Mon Sep 17 00:00:00 2001 From: Glen Johnson Date: Thu, 6 Jul 2023 08:34:33 -0600 Subject: [PATCH] Separate call and retry logic --- .../authentication/authn_iam/authenticator.rb | 63 +++++++++++-------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/app/domain/authentication/authn_iam/authenticator.rb b/app/domain/authentication/authn_iam/authenticator.rb index cfef517adb..cb83428a18 100755 --- a/app/domain/authentication/authn_iam/authenticator.rb +++ b/app/domain/authentication/authn_iam/authenticator.rb @@ -54,33 +54,34 @@ def extract_relevant_data(response) # Call to AWS STS endpoint using the provided authentication header def attempt_signed_request(signed_headers) - sts_host = extract_sts_host(signed_headers) - aws_request = URI("https://#{sts_host}/?Action=GetCallerIdentity&Version=2011-06-15") - begin - response = @client.get_response(aws_request, signed_headers) - return response unless response.code.to_i == 403 && sts_host.include?('us-east-1') - - # If the request to `us-east-1` failed with a 403, retry on the global endpoint - retry_signed_request_on_global(signed_headers) + region = extract_sts_region(signed_headers) - # Handle any network failures with a generic verification error - rescue StandardError => e - raise(Errors::Authentication::AuthnIam::VerificationError.new(e)) + # Attempt request using the discovered region and return immediately if successful + response = aws_call(region: region, headers: signed_headers) + return response if response.code.to_i == 200 + + # If the discovered region is `us-east-1`, fallback to the global endpoint + if region == 'us-east-1' + @logger.debug(LogMessages::Authentication::AuthnIam::RetryWithGlobalEndpoint.new) + fallback_response = aws_call(region: 'global', headers: signed_headers) + return fallback_response if fallback_response.code.to_i == 200 end + + return response end - # Retry request on AWS STS global endpoint - def retry_signed_request_on_global(signed_headers) - @logger.debug( - LogMessages::Authentication::AuthnIam::RetryWithGlobalEndpoint.new - ) - aws_request = URI('https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15') + def aws_call(region:, headers:) + host = if region == 'global' + 'sts.amazonaws.com' + else + "sts.#{region}.amazonaws.com" + end + aws_request = URI("https://#{host}/?Action=GetCallerIdentity&Version=2011-06-15") begin - @client.get_response(aws_request, signed_headers) - - # Handle any network failures with a generic verification error + @client.get_response(aws_request, headers) rescue StandardError => e - raise(Errors::Authentication::AuthnIam::VerificationError.new(e)) + # Handle any network failures with a generic verification error + raise(Errors::Authentication::AuthnIam::VerificationError, e) end end @@ -97,15 +98,23 @@ def response_from_signed_request(aws_headers) ) end - # Extract AWS region from the authorization header's credential string, i.e.: + # Extracts the STS region from the host header if it exists. + # If not, we use the authorization header's credential string, i.e.: # Credential=AKIAIOSFODNN7EXAMPLE/20220830/us-east-1/sts/aws4_request - def extract_sts_host(signed_headers) - return signed_headers['host'] if signed_headers['host'].present? + def extract_sts_region(signed_headers) + host = signed_headers['host'] - region = signed_headers['authorization'].match(%r{Credential=[^/]+/[^/]+/([^/]+)/})&.captures&.first - raise(Errors::Authentication::AuthnIam::InvalidAWSHeaders, 'Failed to extract AWS region from authorization header') unless region + if host == 'sts.amazonaws.com' + return 'global' + end + + match = host&.match(%r{sts.([\w\-]+).amazonaws.com}) + return match.captures.first if match - "sts.#{region}.amazonaws.com" + match = signed_headers['authorization']&.match(%r{Credential=[^/]+/[^/]+/([^/]+)/}) + return match.captures.first if match + + raise Errors::Authentication::AuthnIam::InvalidAWSHeaders, 'Failed to extract AWS region from authorization header' end end end