From f4abcc54778fc1849f45272fb767b6e7e9d013f6 Mon Sep 17 00:00:00 2001 From: Jean-Vincent D'Adda Date: Fri, 19 Apr 2024 09:01:56 +0200 Subject: [PATCH] fix: add missing InstanceProfileCredentialsProvider --- .../iam/internals/MSKCredentialProvider.java | 4 +- .../internals/MSKCredentialProviderTest.java | 131 ++++++++++++++++-- 2 files changed, 124 insertions(+), 11 deletions(-) diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java index 80850b0..3945400 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java @@ -36,6 +36,7 @@ import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; import software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.auth.credentials.SystemPropertyCredentialsProvider; @@ -157,7 +158,8 @@ protected AwsCredentialsProvider getDefaultProvider() { SystemPropertyCredentialsProvider.create(), WebIdentityTokenFileCredentialsProvider.create(), ProfileCredentialsProvider.create(), - ContainerCredentialsProvider.builder().build() + ContainerCredentialsProvider.builder().build(), + InstanceProfileCredentialsProvider.create() ); } diff --git a/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java b/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java index fb90425..728f6ac 100644 --- a/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java +++ b/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java @@ -36,6 +36,7 @@ import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; import software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; import software.amazon.awssdk.auth.credentials.SystemPropertyCredentialsProvider; import software.amazon.awssdk.auth.credentials.WebIdentityTokenFileCredentialsProvider; @@ -132,7 +133,8 @@ protected AwsCredentialsProvider getDefaultProvider() { SystemPropertyCredentialsProvider.create(), WebIdentityTokenFileCredentialsProvider.create(), ProfileCredentialsProvider.builder().profileFile(profileFile).build(), - ContainerCredentialsProvider.builder().build() + ContainerCredentialsProvider.builder().build(), + InstanceProfileCredentialsProvider.create() ); } }; @@ -245,19 +247,54 @@ StsClient getStsClientForDebuggingCreds(AwsCredentials credentials) { } @Test - public void testEc2CredsWithDebugCredsNoAccessToSts_Succeed() { + public void testEcsCredsWithDebugCredsNoAccessToSts_Succeed() { Map optionsMap = new HashMap<>(); optionsMap.put(AWS_DEBUG_CREDS_NAME, "true"); - ContainerCredentialsProvider mockEc2CredsProvider = Mockito.mock(ContainerCredentialsProvider.class); - Mockito.when(mockEc2CredsProvider.resolveIdentity()) + ContainerCredentialsProvider mockEcsCredsProvider = Mockito.mock(ContainerCredentialsProvider.class); + Mockito.when(mockEcsCredsProvider.resolveIdentity()) .thenAnswer(i -> CompletableFuture.completedFuture(AwsBasicCredentials.create(ACCESS_KEY_VALUE_TWO, SECRET_KEY_VALUE_TWO))); StsClient mockSts = Mockito.mock(StsClient.class); Mockito.when(mockSts.getCallerIdentity()) .thenThrow(SdkClientException.create("TEST TEST")); + MSKCredentialProvider provider = new MSKCredentialProvider(optionsMap) { + protected AwsCredentialsProvider getDefaultProvider() { + return mockEcsCredsProvider; + } + + StsClient getStsClientForDebuggingCreds(AwsCredentials credentials) { + return mockSts; + } + }; + assertTrue(provider.getShouldDebugCreds()); + + AwsCredentials credentials = provider.resolveCredentials(); + + validateBasicCredentialsTwo(credentials); + + provider.close(); + Mockito.verify(mockSts, times(1)).getCallerIdentity(); + Mockito.verify(mockEcsCredsProvider, times(1)).resolveIdentity(); + Mockito.verifyNoMoreInteractions(mockEcsCredsProvider); + } + + @Test + public void testEc2CredsWithDebugCredsNoAccessToSts_Succeed() { + Map optionsMap = new HashMap<>(); + optionsMap.put(AWS_DEBUG_CREDS_NAME, "true"); + + + InstanceProfileCredentialsProvider mockEc2CredsProvider = Mockito.mock(InstanceProfileCredentialsProvider.class); + Mockito.when(mockEc2CredsProvider.resolveIdentity()) + .thenAnswer(i -> CompletableFuture.completedFuture(AwsBasicCredentials.create(ACCESS_KEY_VALUE_TWO, SECRET_KEY_VALUE_TWO))); + + StsClient mockSts = Mockito.mock(StsClient.class); + Mockito.when(mockSts.getCallerIdentity()) + .thenThrow(SdkClientException.create("TEST TEST")); + MSKCredentialProvider provider = new MSKCredentialProvider(optionsMap) { protected AwsCredentialsProvider getDefaultProvider() { return mockEc2CredsProvider; @@ -460,7 +497,7 @@ public void testEc2CredsWithSixRetriableErrorsCustomRetry_ThrowsException() { Map optionsMap = new HashMap<>(); optionsMap.put("awsMaxRetries", "5"); - AwsCredentialsProvider mockEc2CredsProvider = setupMockDefaultProviderWithRetriableExceptions(numExceptions); + AwsCredentialsProvider mockEc2CredsProvider = setupMockEc2DefaultProviderWithRetriableExceptions(numExceptions); MSKCredentialProvider provider = new MSKCredentialProvider(optionsMap) { protected AwsCredentialsProvider getDefaultProvider() { @@ -481,7 +518,7 @@ public void testEc2CredsWithOnrRetriableErrorsCustomZeroRetry_ThrowsException() Map optionsMap = new HashMap<>(); optionsMap.put("awsMaxRetries", "0"); - AwsCredentialsProvider mockEc2CredsProvider = setupMockDefaultProviderWithRetriableExceptions(numExceptions); + AwsCredentialsProvider mockEc2CredsProvider = setupMockEc2DefaultProviderWithRetriableExceptions(numExceptions); MSKCredentialProvider provider = new MSKCredentialProvider(optionsMap) { protected AwsCredentialsProvider getDefaultProvider() { @@ -500,7 +537,7 @@ private void testEc2CredsWithRetriableErrorsCustomRetry(int numExceptions) { Map optionsMap = new HashMap<>(); optionsMap.put("awsMaxRetries", "5"); - AwsCredentialsProvider mockEc2CredsProvider = setupMockDefaultProviderWithRetriableExceptions(numExceptions); + AwsCredentialsProvider mockEc2CredsProvider = setupMockEc2DefaultProviderWithRetriableExceptions(numExceptions); MSKCredentialProvider provider = new MSKCredentialProvider(optionsMap) { protected AwsCredentialsProvider getDefaultProvider() { @@ -518,6 +555,70 @@ protected AwsCredentialsProvider getDefaultProvider() { Mockito.verifyNoMoreInteractions(mockEc2CredsProvider); } + @Test + public void testEcsCredsWithSixRetriableErrorsCustomRetry_ThrowsException() { + int numExceptions = 6; + Map optionsMap = new HashMap<>(); + optionsMap.put("awsMaxRetries", "5"); + + AwsCredentialsProvider mockEcsCredsProvider = setupMockEcsDefaultProviderWithRetriableExceptions(numExceptions); + + MSKCredentialProvider provider = new MSKCredentialProvider(optionsMap) { + protected AwsCredentialsProvider getDefaultProvider() { + return mockEcsCredsProvider; + } + }; + assertFalse(provider.getShouldDebugCreds()); + + assertThrows(SdkClientException.class, () -> provider.resolveCredentials()); + + Mockito.verify(mockEcsCredsProvider, times(numExceptions)).resolveIdentity(); + Mockito.verifyNoMoreInteractions(mockEcsCredsProvider); + } + + @Test + public void testEcsCredsWithOnrRetriableErrorsCustomZeroRetry_ThrowsException() { + int numExceptions = 1; + Map optionsMap = new HashMap<>(); + optionsMap.put("awsMaxRetries", "0"); + + AwsCredentialsProvider mockEcsCredsProvider = setupMockEcsDefaultProviderWithRetriableExceptions(numExceptions); + + MSKCredentialProvider provider = new MSKCredentialProvider(optionsMap) { + protected AwsCredentialsProvider getDefaultProvider() { + return mockEcsCredsProvider; + } + }; + assertFalse(provider.getShouldDebugCreds()); + + assertThrows(SdkClientException.class, () -> provider.resolveCredentials()); + + Mockito.verify(mockEcsCredsProvider, times(numExceptions)).resolveIdentity(); + Mockito.verifyNoMoreInteractions(mockEcsCredsProvider); + } + + private void testEcsCredsWithRetriableErrorsCustomRetry(int numExceptions) { + Map optionsMap = new HashMap<>(); + optionsMap.put("awsMaxRetries", "5"); + + AwsCredentialsProvider mockEcsCredsProvider = setupMockEcsDefaultProviderWithRetriableExceptions(numExceptions); + + MSKCredentialProvider provider = new MSKCredentialProvider(optionsMap) { + protected AwsCredentialsProvider getDefaultProvider() { + return mockEcsCredsProvider; + } + }; + assertFalse(provider.getShouldDebugCreds()); + + AwsCredentials credentials = provider.resolveCredentials(); + + validateBasicCredentialsTwo(credentials); + + provider.close(); + Mockito.verify(mockEcsCredsProvider, times(numExceptions + 1)).resolveIdentity(); + Mockito.verifyNoMoreInteractions(mockEcsCredsProvider); + } + private void testRoleCredsWithRetriableErrors(int numExceptions) { StsAssumeRoleCredentialsProvider mockStsRoleProvider = setupMockStsRoleCredentialsProviderWithRetriableExceptions( numExceptions); @@ -608,13 +709,23 @@ private SdkException[] getSdkBaseExceptions(int numErrors) { .collect(Collectors.toList()).toArray(new SdkException[numErrors]); } - private AwsCredentialsProvider setupMockDefaultProviderWithRetriableExceptions(int numErrors) { + private AwsCredentialsProvider setupMockEcsDefaultProviderWithRetriableExceptions(int numErrors) { SdkException[] exceptionsToThrow = getSdkBaseExceptions(numErrors); - ContainerCredentialsProvider mockEc2Provider = Mockito.mock(ContainerCredentialsProvider.class); + ContainerCredentialsProvider mockEcsProvider = Mockito.mock(ContainerCredentialsProvider.class); - Mockito.when(mockEc2Provider.resolveIdentity()) + Mockito.when(mockEcsProvider.resolveIdentity()) .thenThrow(exceptionsToThrow) .thenAnswer(i -> CompletableFuture.completedFuture(AwsBasicCredentials.create(ACCESS_KEY_VALUE_TWO, SECRET_KEY_VALUE_TWO))); + return mockEcsProvider; + } + + private AwsCredentialsProvider setupMockEc2DefaultProviderWithRetriableExceptions(int numErrors) { + SdkException[] exceptionsToThrow = getSdkBaseExceptions(numErrors); + InstanceProfileCredentialsProvider mockEc2Provider = Mockito.mock(InstanceProfileCredentialsProvider.class); + + Mockito.when(mockEc2Provider.resolveIdentity()) + .thenThrow(exceptionsToThrow) + .thenAnswer(i -> CompletableFuture.completedFuture(AwsBasicCredentials.create(ACCESS_KEY_VALUE_TWO, SECRET_KEY_VALUE_TWO))); return mockEc2Provider; }