diff --git a/build.gradle b/build.gradle index 91cd604..0033142 100644 --- a/build.gradle +++ b/build.gradle @@ -41,8 +41,6 @@ group "software.amazon.msk" dependencies { compileOnly('org.apache.kafka:kafka-clients:2.8.1') // aws sdk imports. - implementation(platform('com.amazonaws:aws-java-sdk-bom:1.12.638')) - implementation('com.amazonaws:aws-java-sdk-core') implementation(platform('software.amazon.awssdk:bom:2.23.3')) implementation('software.amazon.awssdk:auth') implementation('software.amazon.awssdk:sso') diff --git a/src/main/java/software/amazon/msk/auth/iam/CompatibilityHelper.java b/src/main/java/software/amazon/msk/auth/iam/CompatibilityHelper.java deleted file mode 100644 index 4943e9a..0000000 --- a/src/main/java/software/amazon/msk/auth/iam/CompatibilityHelper.java +++ /dev/null @@ -1,16 +0,0 @@ -package software.amazon.msk.auth.iam; - -import software.amazon.awssdk.regions.Region; - -public class CompatibilityHelper { - - /** - * Convert region from v1 to v2 - * - * @param region v1 region - * @return v2 region - */ - public static Region toV2Region(com.amazonaws.regions.Region region) { - return Region.of(region.getName()); - } -} diff --git a/src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerToken.java b/src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerToken.java index 7ff63f3..ac35f4f 100644 --- a/src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerToken.java +++ b/src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerToken.java @@ -27,13 +27,11 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; -import org.apache.http.NameValuePair; -import org.apache.http.client.utils.URLEncodedUtils; import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; import software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant; import software.amazon.awssdk.utils.StringUtils; +import software.amazon.msk.auth.iam.internals.utils.URIUtils; /** * Implements the contract provided by OAuthBearerToken interface @@ -61,15 +59,12 @@ public IAMOAuthBearerToken(String token) throws URISyntaxException { byte[] decodedBytes = Base64.getUrlDecoder().decode(tokenBytes); final String decodedPresignedUrl = new String(decodedBytes, StandardCharsets.UTF_8); final URI uri = new URI(decodedPresignedUrl); - List params = URLEncodedUtils.parse(uri, StandardCharsets.UTF_8); - Map paramMap = params.stream() - .collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue)); - int lifeTimeSeconds = Integer.parseInt(paramMap.get(SignerConstant.X_AMZ_EXPIRES)); + + Map> params = URIUtils.parseQueryParams(uri); + int lifeTimeSeconds = Integer.parseInt(params.get(SignerConstant.X_AMZ_EXPIRES).get(0)); final DateTimeFormatter dateFormat = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'"); - final LocalDateTime signedDate = LocalDateTime.parse(paramMap.get(SignerConstant.X_AMZ_DATE), dateFormat); - long signedDateEpochMillis = signedDate.toInstant(ZoneOffset.UTC) - .toEpochMilli(); - this.startTimeMs = signedDateEpochMillis; + final LocalDateTime signedDate = LocalDateTime.parse(params.get(SignerConstant.X_AMZ_DATE).get(0), dateFormat); + this.startTimeMs = signedDate.toInstant(ZoneOffset.UTC).toEpochMilli(); this.lifetimeMs = this.startTimeMs + (lifeTimeSeconds * 1000L); } diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/AuthenticationRequestParams.java b/src/main/java/software/amazon/msk/auth/iam/internals/AuthenticationRequestParams.java index 63971bc..816218d 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/AuthenticationRequestParams.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/AuthenticationRequestParams.java @@ -15,19 +15,14 @@ */ package software.amazon.msk.auth.iam.internals; -import static software.amazon.msk.auth.iam.CompatibilityHelper.toV2Region; - -import com.amazonaws.regions.RegionMetadata; -import com.amazonaws.partitions.PartitionsLoader; -import com.amazonaws.regions.Regions; import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NonNull; -import java.util.Optional; import software.amazon.awssdk.auth.credentials.AwsCredentials; import software.amazon.awssdk.regions.Region; +import software.amazon.msk.auth.iam.internals.utils.RegionUtils; /** * This class represents the parameters that will be used to generate the Sigv4 signature @@ -40,9 +35,6 @@ public class AuthenticationRequestParams { private static final String VERSION_1 = "2020_10_22"; private static final String SERVICE_SCOPE = "kafka-cluster"; - private static RegionMetadata regionMetadata = new RegionMetadata(new PartitionsLoader().build()); - /* we are not using the RegionMetadataFactory.create() method here as one of its path - relies on the LegacyRegionXmlMetadataBuilder which does not implement tryGetRegionByEndpointDnsSuffix */ @NonNull private final String version; @@ -60,13 +52,12 @@ public String getServiceScope() { } public static AuthenticationRequestParams create(@NonNull String host, - AwsCredentials credentials, - @NonNull String userAgent) throws IllegalArgumentException { - com.amazonaws.regions.Region region = Optional.ofNullable(regionMetadata.tryGetRegionByEndpointDnsSuffix(host)) - .orElseGet(() -> Regions.getCurrentRegion()); + AwsCredentials credentials, + @NonNull String userAgent) throws IllegalArgumentException { + Region region = RegionUtils.extractRegionFromHost(host); if (region == null) { throw new IllegalArgumentException("Host " + host + " does not belong to a valid region."); } - return new AuthenticationRequestParams(VERSION_1, host, credentials, toV2Region(region), userAgent); + return new AuthenticationRequestParams(VERSION_1, host, credentials, region, userAgent); } } diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/UserAgentUtils.java b/src/main/java/software/amazon/msk/auth/iam/internals/UserAgentUtils.java index fd05322..e2422bc 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/UserAgentUtils.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/UserAgentUtils.java @@ -25,8 +25,6 @@ import java.util.StringJoiner; import software.amazon.awssdk.core.util.SdkUserAgent; -import static com.amazonaws.util.IOUtils.closeQuietly; - /** * This class is used to generate the user agent for the authentication request. */ @@ -58,19 +56,16 @@ private static final String generateUserAgentString(String[] components) { private static String getLibraryVersion() { String version = "unknown-version"; - InputStream inputStream = getVersionInfoFileAsStream(); - Properties versionProperties = new Properties(); - try { + try (InputStream inputStream = getVersionInfoFileAsStream()) { if (inputStream == null) { log.info("Unable to load version information for msk iam auth plugin"); } else { + Properties versionProperties = new Properties(); versionProperties.load(inputStream); version = versionProperties.getProperty("version"); } } catch (Exception e) { log.info("Unable to load version information for the running SDK: " + e.getMessage()); - } finally { - closeQuietly(inputStream, null); } return version; } diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/utils/RegionUtils.java b/src/main/java/software/amazon/msk/auth/iam/internals/utils/RegionUtils.java new file mode 100644 index 0000000..de41090 --- /dev/null +++ b/src/main/java/software/amazon/msk/auth/iam/internals/utils/RegionUtils.java @@ -0,0 +1,21 @@ +package software.amazon.msk.auth.iam.internals.utils; + +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; + +public class RegionUtils { + + /** + * Try to extract the region from the host. If the region is not found, return the default region + * from the DefaultAwsRegionProviderChain. + * + * @param host The host to extract the region from. + * @return The region extracted from the host. + */ + public static Region extractRegionFromHost(String host) { + return Region.regions().stream() + .filter(region -> host.contains(region.id())) + .findFirst() + .orElseGet(() -> DefaultAwsRegionProviderChain.builder().build().getRegion()); + } +} diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/utils/URIUtils.java b/src/main/java/software/amazon/msk/auth/iam/internals/utils/URIUtils.java new file mode 100644 index 0000000..41dbbe2 --- /dev/null +++ b/src/main/java/software/amazon/msk/auth/iam/internals/utils/URIUtils.java @@ -0,0 +1,50 @@ +package software.amazon.msk.auth.iam.internals.utils; + +import static java.util.stream.Collectors.mapping; +import static java.util.stream.Collectors.toList; + +import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class URIUtils { + + /** + * Parse the query parameters from the URI. + * + * @param url The URI to parse. + * @return A map of query parameters. + */ + public static Map> parseQueryParams(URI url) { + if (url.getQuery() == null || url.getQuery().isEmpty()) { + return Collections.emptyMap(); + } + return Arrays.stream(url.getQuery().split("&")) + .map(URIUtils::splitQueryParameter) + .collect(Collectors.groupingBy(SimpleImmutableEntry::getKey, LinkedHashMap::new, + mapping(Map.Entry::getValue, toList()))); + } + + private static SimpleImmutableEntry splitQueryParameter(String it) { + final int idx = it.indexOf("="); + final String key = idx > 0 ? it.substring(0, idx) : it; + final String value = idx > 0 && it.length() > idx + 1 ? it.substring(idx + 1) : null; + return new SimpleImmutableEntry<>(decodeSilently(key), decodeSilently(value)); + } + + private static String decodeSilently(String s) { + try { + return URLDecoder.decode(s, StandardCharsets.UTF_8.name()); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/test/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandlerTest.java b/src/test/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandlerTest.java index 4e83dff..d38eea6 100644 --- a/src/test/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandlerTest.java +++ b/src/test/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandlerTest.java @@ -29,11 +29,8 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; import javax.security.auth.callback.Callback; import javax.security.auth.callback.UnsupportedCallbackException; -import org.apache.http.NameValuePair; -import org.apache.http.client.utils.URLEncodedUtils; import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; @@ -41,6 +38,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant; +import software.amazon.msk.auth.iam.internals.utils.URIUtils; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -191,17 +189,15 @@ private void assertTokenValidity(OAuthBearerToken token, String region, String a Assertions.assertEquals(String.format("kafka.%s.amazonaws.com", region), uri.getHost()); Assertions.assertEquals("https", uri.getScheme()); - List params = URLEncodedUtils.parse(uri, StandardCharsets.UTF_8); - Map paramMap = params.stream() - .collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue)); - Assertions.assertEquals("kafka-cluster:Connect", paramMap.get("Action")); - Assertions.assertEquals(SignerConstant.AWS4_SIGNING_ALGORITHM, paramMap.get(SignerConstant.X_AMZ_ALGORITHM)); - final Integer expirySeconds = Integer.parseInt(paramMap.get(SignerConstant.X_AMZ_EXPIRES)); + Map> params = URIUtils.parseQueryParams(uri); + Assertions.assertEquals("kafka-cluster:Connect", params.get("Action").get(0)); + Assertions.assertEquals(SignerConstant.AWS4_SIGNING_ALGORITHM, params.get(SignerConstant.X_AMZ_ALGORITHM).get(0)); + final Integer expirySeconds = Integer.parseInt(params.get(SignerConstant.X_AMZ_EXPIRES).get(0)); Assertions.assertTrue(expirySeconds <= 900); - Assertions.assertTrue(token.lifetimeMs() <= System.currentTimeMillis() + Integer.parseInt(paramMap.get(SignerConstant.X_AMZ_EXPIRES)) * 1000); - Assertions.assertEquals(sessionToken, paramMap.get(SignerConstant.X_AMZ_SECURITY_TOKEN)); - Assertions.assertEquals("host", paramMap.get(SignerConstant.X_AMZ_SIGNED_HEADERS)); - String credential = paramMap.get(SignerConstant.X_AMZ_CREDENTIAL); + Assertions.assertTrue(token.lifetimeMs() <= System.currentTimeMillis() + Integer.parseInt(params.get(SignerConstant.X_AMZ_EXPIRES).get(0)) * 1000); + Assertions.assertEquals(sessionToken, params.get(SignerConstant.X_AMZ_SECURITY_TOKEN).get(0)); + Assertions.assertEquals("host", params.get(SignerConstant.X_AMZ_SIGNED_HEADERS).get(0)); + String credential = params.get(SignerConstant.X_AMZ_CREDENTIAL).get(0); Assertions.assertNotNull(credential); String[] credentialArray = credential.split("/"); Assertions.assertEquals(5, credentialArray.length); @@ -209,14 +205,14 @@ private void assertTokenValidity(OAuthBearerToken token, String region, String a Assertions.assertEquals("kafka-cluster", credentialArray[3]); Assertions.assertEquals(SignerConstant.AWS4_TERMINATOR, credentialArray[4]); DateTimeFormatter dateFormat = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'"); - final LocalDateTime signedDate = LocalDateTime.parse(paramMap.get(SignerConstant.X_AMZ_DATE), dateFormat); + final LocalDateTime signedDate = LocalDateTime.parse(params.get(SignerConstant.X_AMZ_DATE).get(0), dateFormat); long signedDateEpochMillis = signedDate.toInstant(ZoneOffset.UTC) .toEpochMilli(); Assertions.assertTrue(signedDateEpochMillis <= Instant.now() .toEpochMilli()); Assertions.assertEquals(signedDateEpochMillis, token.startTimeMs()); Assertions.assertEquals(signedDateEpochMillis + expirySeconds * 1000, token.lifetimeMs()); - String userAgent = paramMap.get("User-Agent"); + String userAgent = params.get("User-Agent").get(0); Assertions.assertNotNull(userAgent); Assertions.assertTrue(userAgent.startsWith("aws-msk-iam-auth")); } diff --git a/src/test/java/software/amazon/msk/auth/iam/internals/AuthenticateRequestParamsTest.java b/src/test/java/software/amazon/msk/auth/iam/internals/AuthenticateRequestParamsTest.java index a2c6960..e7b72f0 100644 --- a/src/test/java/software/amazon/msk/auth/iam/internals/AuthenticateRequestParamsTest.java +++ b/src/test/java/software/amazon/msk/auth/iam/internals/AuthenticateRequestParamsTest.java @@ -15,14 +15,14 @@ */ package software.amazon.msk.auth.iam.internals; +import org.mockito.MockedStatic; +import org.mockito.Mockito; import software.amazon.awssdk.regions.Region; -import com.amazonaws.regions.Regions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.MockedStatic; -import org.mockito.Mockito; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.msk.auth.iam.internals.utils.RegionUtils; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -34,7 +34,6 @@ public class AuthenticateRequestParamsTest { private static final String ACCESS_KEY = "ACCESS_KEY"; private static final String SECRET_KEY = "SECRET_KEY"; private static final String USER_AGENT = "USER_AGENT"; - private static final Region TEST_EC2_REGION = Region.US_WEST_1; @BeforeEach public void setup() { @@ -56,21 +55,25 @@ public void testAllProperties() { @Test public void testInvalidHost() { - try (MockedStatic regionsMockedStatic = Mockito.mockStatic(Regions.class)) { - regionsMockedStatic.when(Regions::getCurrentRegion).thenReturn(null); + try(MockedStatic mockStatic = Mockito.mockStatic(RegionUtils.class)) { + mockStatic + .when(() -> RegionUtils.extractRegionFromHost(HOSTNAME_NO_REGION)) + .thenReturn(null); + assertThrows(IllegalArgumentException.class, - () -> AuthenticationRequestParams.create(HOSTNAME_NO_REGION, credentials, USER_AGENT)); + () -> AuthenticationRequestParams.create(HOSTNAME_NO_REGION, credentials, USER_AGENT)); } } @Test public void testInvalidHostInEC2() { - try (MockedStatic regionsMockedStatic = Mockito.mockStatic(Regions.class)) { - regionsMockedStatic.when(Regions::getCurrentRegion) - .thenReturn(com.amazonaws.regions.Region.getRegion(Regions.US_WEST_1)); + try(MockedStatic mockStatic = Mockito.mockStatic(RegionUtils.class)) { + mockStatic + .when(() -> RegionUtils.extractRegionFromHost(HOSTNAME_NO_REGION)) + .thenReturn(Region.US_WEST_1); + AuthenticationRequestParams params = AuthenticationRequestParams.create(HOSTNAME_NO_REGION, credentials, USER_AGENT); assertEquals(Region.US_WEST_1, params.getRegion()); } } - }