Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to AWS SDK v2 #1179

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ ThisBuild / scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked", "-Xf

resolvers += DefaultMavenRepository

val awsSdkVersion = "1.12.778"
val awsSdkVersion = "2.29.15"
val playJsonVersion = "3.0.4"
val jacksonVersion = "2.18.1"

Expand All @@ -31,32 +31,33 @@ lazy val hq = (project in file("hq"))
libraryDependencies ++= Seq(
ws,
filters,
"com.gu.play-googleauth" %% "play-v30" % "15.1.1",
"com.gu.play-secret-rotation" %% "play-v30" % "11.3.8",
"com.gu.play-secret-rotation" %% "aws-parameterstore-sdk-v1" % "11.3.8",
"com.gu.play-googleauth" %% "play-v30" % "16.0.0",
"com.gu.play-secret-rotation" %% "play-v30" % "12.0.0",
"com.gu.play-secret-rotation" %% "aws-parameterstore-sdk-v2" % "12.0.0",

"joda-time" % "joda-time" % "2.13.0",
"org.typelevel" %% "cats-core" % "2.12.0",
"com.github.tototoshi" %% "scala-csv" % "2.0.0",
"com.amazonaws" % "aws-java-sdk-s3" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-iam" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-sts" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-support" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-ec2" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-cloudformation" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-efs" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-cloudwatch" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-dynamodb" % awsSdkVersion,
"software.amazon.awssdk" % "iam" % awsSdkVersion,
"software.amazon.awssdk" % "cloudformation" % awsSdkVersion,
"software.amazon.awssdk" % "cloudwatch" % awsSdkVersion,
"software.amazon.awssdk" % "dynamodb" % awsSdkVersion,
"software.amazon.awssdk" % "ec2" % awsSdkVersion,
"software.amazon.awssdk" % "efs" % awsSdkVersion,
"software.amazon.awssdk" % "s3" % awsSdkVersion,
"software.amazon.awssdk" % "sns" % awsSdkVersion,
"software.amazon.awssdk" % "ssm" % awsSdkVersion,
"software.amazon.awssdk" % "sts" % awsSdkVersion,
"software.amazon.awssdk" % "support" % awsSdkVersion,
"com.vladsch.flexmark" % "flexmark" % "0.64.8",
"com.amazonaws" % "aws-java-sdk-sns" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-ssm" % awsSdkVersion,
"io.reactivex" %% "rxscala" % "0.27.0",
"com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion,
"com.fasterxml.jackson.module" %% "jackson-module-scala" % jacksonVersion,
"org.scalatest" %% "scalatest" % "3.2.14" % Test,
"org.scalatestplus" %% "scalacheck-1-16" % "3.2.14.0" % Test,
"org.scalacheck" %% "scalacheck" % "1.18.1" % Test,
"com.github.alexarchambault" %% "scalacheck-shapeless_1.15" % "1.3.0" % Test,
"com.gu" %% "anghammarad-client" % "3.0.0",
"com.gu" %% "anghammarad-client" % "4.0.0",
"ch.qos.logback" % "logback-classic" % "1.5.12",


Expand Down Expand Up @@ -127,13 +128,11 @@ lazy val lambdaCommon = (project in file("lambda/common")).
libraryDependencies ++= Seq(
"com.amazonaws" % "aws-lambda-java-events" % "3.14.0",
"com.amazonaws" % "aws-lambda-java-core" % "1.2.3",
"com.amazonaws" % "aws-java-sdk-lambda" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-config" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-elasticloadbalancing" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-config" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-sns" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-sts" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-s3" % awsSdkVersion,
"software.amazon.awssdk" % "s3" % awsSdkVersion,
"software.amazon.awssdk" % "elasticloadbalancingv2" % awsSdkVersion,
"software.amazon.awssdk" % "sts" % awsSdkVersion,
"software.amazon.awssdk" % "sns" % awsSdkVersion,

"org.scalatest" %% "scalatest" % "3.2.19" % Test,
"org.playframework" %% "play-json" % playJsonVersion,
"com.typesafe.scala-logging" %% "scala-logging" % "3.9.5",
Expand All @@ -149,7 +148,7 @@ lazy val lambdaSecurityGroups = (project in file("lambda/security-groups")).
name := """securitygroups-lambda""",
assembly / assemblyJarName := s"${name.value}-${version.value}.jar",
libraryDependencies ++= Seq(
"com.gu" %% "anghammarad-client" % "3.0.0"
"com.gu" %% "anghammarad-client" % "4.0.0"
)
)

Expand Down
107 changes: 56 additions & 51 deletions hq/app/AppComponents.scala
Original file line number Diff line number Diff line change
@@ -1,18 +1,5 @@
import aws.ec2.EC2
import aws.{AWS, AwsClient}
import com.amazonaws.ClientConfiguration
import com.amazonaws.auth.profile.ProfileCredentialsProvider
import com.amazonaws.auth.{
AWSCredentialsProviderChain,
DefaultAWSCredentialsProviderChain
}
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
import com.amazonaws.regions.{Region, RegionUtils, Regions}
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder
import com.amazonaws.services.ec2.AmazonEC2AsyncClientBuilder
import com.amazonaws.services.s3.AmazonS3ClientBuilder
import com.amazonaws.services.simplesystemsmanagement.AWSSimpleSystemsManagementClientBuilder
import com.amazonaws.services.sns.AmazonSNSAsyncClientBuilder
import config.Config
import controllers._
import db.IamRemediationDb
Expand All @@ -31,8 +18,26 @@ import utils.attempt.Attempt

import scala.concurrent.Await
import scala.concurrent.duration._
import scala.jdk.CollectionConverters._
import scala.language.postfixOps

import software.amazon.awssdk.core.client.config.ClientAsyncConfiguration
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider
import software.amazon.awssdk.http.async.SdkAsyncHttpClient
import software.amazon.awssdk.http.SdkHttpConfigurationOption
import software.amazon.awssdk.services.ec2.Ec2AsyncClient
import software.amazon.awssdk.services.sns.SnsAsyncClient
import software.amazon.awssdk.services.s3.S3Client
import software.amazon.awssdk.services.dynamodb.DynamoDbClient
import software.amazon.awssdk.services.dynamodb.endpoints.{DynamoDbEndpointProvider, DynamoDbEndpointParams}
import software.amazon.awssdk.services.ssm.SsmClient
import software.amazon.awssdk.utils.AttributeMap

import software.amazon.awssdk.core.internal.http.loader.DefaultSdkAsyncHttpClientBuilder

class AppComponents(context: Context)
extends BuiltInComponentsFromContext(context)
with CSRFComponents
Expand All @@ -57,9 +62,8 @@ class AppComponents(context: Context)
// the aim of this is to get all the regions that are available to this account
private val availableRegions: List[Region] = {
val ec2Client = AwsClient(
AmazonEC2AsyncClientBuilder
.standard()
.withRegion(Config.region.getName)
Ec2AsyncClient.builder
.region(Config.region)
.build(),
AwsAccount(stack, stack, stack, stack),
Config.region
Expand All @@ -68,24 +72,24 @@ class AppComponents(context: Context)
val availableRegionsAttempt: Attempt[List[Region]] = for {
ec2RegionList <- EC2.getAvailableRegions(ec2Client)
regionList = ec2RegionList.map(ec2Region =>
RegionUtils.getRegion(ec2Region.getRegionName)
Region.of(ec2Region.regionName)
)
} yield regionList
Await
.result(availableRegionsAttempt.asFuture, 30 seconds)
.getOrElse(List(Config.region, RegionUtils.getRegion("us-east-1")))
.getOrElse(List(Config.region, Region.of("us-east-1")))
} finally {
ec2Client.client.shutdown()
ec2Client.client.close()
}
}

logger.info(
s"Polling in the following regions: ${availableRegions.map(_.getName).mkString(", ")}"
s"Polling in the following regions: ${availableRegions.map(_.id).mkString(", ")}"
)

val regionsNotInSdk: Set[String] = availableRegions
.map(_.getName)
.toSet -- Regions.values().map(_.getName).toSet
.map(_.id)
.toSet -- Region.regions.asScala.map(_.id).toSet
if (regionsNotInSdk.nonEmpty) {
logger.warn(
s"Regions exist that are not in the current SDK (${regionsNotInSdk.mkString(", ")}), update your SDK!"
Expand All @@ -97,47 +101,48 @@ class AppComponents(context: Context)
private val s3Clients = AWS.s3Clients(configuration, availableRegions)
private val iamClients = AWS.iamClients(configuration, availableRegions)

private val securityCredentialsProvider = new AWSCredentialsProviderChain(
new ProfileCredentialsProvider("security"),
DefaultAWSCredentialsProviderChain.getInstance()
private val securityCredentialsProvider: AwsCredentialsProviderChain = AwsCredentialsProviderChain.of(
ProfileCredentialsProvider.create("security"),
DefaultCredentialsProvider.create()
)
private val securitySnsClient = AmazonSNSAsyncClientBuilder
.standard()
.withCredentials(securityCredentialsProvider)
.withRegion(Config.region.getName)
.withClientConfiguration(new ClientConfiguration().withMaxConnections(10))

/*
The casting from SdkHttpConfigurationOption[Integer] to AttributeMap.Key[Any] is required because Scala compiler comlains

Integer <: Any, but Java-defined class Key is invariant in type T.
You may wish to investigate a wildcard type such as `_ <: Any`. (SLS 3.2.10)
*/
private val MAX_10_CONNECTIONS: AttributeMap = AttributeMap.builder().put(SdkHttpConfigurationOption.MAX_CONNECTIONS.asInstanceOf[AttributeMap.Key[Any]], 10).build()

private val securitySnsClient = SnsAsyncClient.builder
.credentialsProvider(securityCredentialsProvider)
.region(Config.region)
.httpClient(new DefaultSdkAsyncHttpClientBuilder().buildWithDefaults(MAX_10_CONNECTIONS))
.build()
private val securitySsmClient = AWSSimpleSystemsManagementClientBuilder
.standard()
.withCredentials(securityCredentialsProvider)
.withRegion(Config.region.getName)
private val securitySsmClient = SsmClient.builder
.credentialsProvider(securityCredentialsProvider)
.region(Config.region)
.build()
private val googleAuthConfig =
Config.googleSettings(stage, stack, configuration, securitySsmClient)

private val securityDynamoDbClient = stage match {
case PROD =>
AmazonDynamoDBClientBuilder
.standard()
.withCredentials(securityCredentialsProvider)
.withRegion(Config.region.getName)
DynamoDbClient.builder()
.credentialsProvider(securityCredentialsProvider)
.region(Config.region)
.build()
case DEV =>
AmazonDynamoDBClientBuilder
.standard()
.withCredentials(securityCredentialsProvider)
.withEndpointConfiguration(
new EndpointConfiguration(
"http://localhost:8000",
Config.region.getName
)
)
DynamoDbClient.builder()
.credentialsProvider(securityCredentialsProvider)
.region(Config.region)
.endpointOverride(new java.net.URI("http://localhost:8000")) //An alternative could be to configure a specific builder DynamoDbEndpointParams.builder().endpoint("http://localhost:8000").region(Config.region)
.build()
}
private val securityS3Client = AmazonS3ClientBuilder
.standard()
.withCredentials(securityCredentialsProvider)
.withRegion(Config.region.getName)
private val securityS3Client = S3Client
.builder
.credentialsProvider(securityCredentialsProvider)
.region(Config.region)
.build()

private val cacheService = new CacheService(
Expand Down
80 changes: 47 additions & 33 deletions hq/app/aws/AWS.scala
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
package aws

import com.amazonaws.ClientConfiguration
import com.amazonaws.auth.profile.ProfileCredentialsProvider
import com.amazonaws.auth.{AWSCredentialsProvider, AWSCredentialsProviderChain, STSAssumeRoleSessionCredentialsProvider}
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
import com.amazonaws.client.builder.{AwsAsyncClientBuilder, AwsClientBuilder}
import com.amazonaws.regions.{Region, RegionUtils, Regions}
import com.amazonaws.services.cloudformation.{AmazonCloudFormationAsync, AmazonCloudFormationAsyncClientBuilder}
import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder}
import com.amazonaws.services.ec2.{AmazonEC2Async, AmazonEC2AsyncClientBuilder}
import com.amazonaws.services.elasticfilesystem.{AmazonElasticFileSystemAsync, AmazonElasticFileSystemAsyncClientBuilder}
import com.amazonaws.services.identitymanagement.{AmazonIdentityManagementAsync, AmazonIdentityManagementAsyncClientBuilder}
import com.amazonaws.services.s3.{AmazonS3, AmazonS3ClientBuilder}
import com.amazonaws.services.simplesystemsmanagement.{AWSSimpleSystemsManagement, AWSSimpleSystemsManagementClientBuilder}
import com.amazonaws.services.support.{AWSSupportAsync, AWSSupportAsyncClientBuilder}
import config.Config
import model.{AwsAccount, DEV, PROD, Stage}
import play.api.Configuration
import utils.attempt.{Attempt, Failure}

import software.amazon.awssdk.core.client.builder.SdkClientBuilder
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder
import software.amazon.awssdk.awscore.client.builder.AwsAsyncClientBuilder
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider
import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain
import software.amazon.awssdk.core.client.config.SdkAdvancedAsyncClientOption
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.iam.IamAsyncClient
import software.amazon.awssdk.services.sts.StsClient
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest
import software.amazon.awssdk.services.cloudformation.CloudFormationAsyncClient
import software.amazon.awssdk.services.dynamodb.DynamoDbClient
import software.amazon.awssdk.services.s3.S3Client
import software.amazon.awssdk.services.ec2.Ec2AsyncClient
import software.amazon.awssdk.services.efs.EfsAsyncClient
import software.amazon.awssdk.services.support.SupportAsyncClient



import java.util.concurrent.Executors.newCachedThreadPool

object AWS {
Expand All @@ -30,48 +38,54 @@ object AWS {
)
}

private def credentialsProvider(account: AwsAccount): AWSCredentialsProviderChain = {
new AWSCredentialsProviderChain(
new STSAssumeRoleSessionCredentialsProvider.Builder(account.roleArn, "security-hq").build(),
new ProfileCredentialsProvider(account.id)
private def stsClientForRoleAssumption(account: AwsAccount): StsClient = {
StsClient.builder.region(Config.region).credentialsProvider(ProfileCredentialsProvider.create(account.id)).build()
}

private def credentialsProvider(account: AwsAccount): AwsCredentialsProviderChain = {
AwsCredentialsProviderChain.of(
StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClientForRoleAssumption(account))
.refreshRequest(AssumeRoleRequest.builder.roleArn(account.roleArn).roleSessionName("security-hq").build()).build(),
ProfileCredentialsProvider.create(account.id)
)
}

private[aws] def clients[A, B <: AwsClientBuilder[B, A]](
builder: AwsClientBuilder[B, A],
builder: AwsClientBuilder[B, A],
configuration: Configuration,
regionList: Region*
): AwsClients[A] = {
for {
account <- Config.getAwsAccounts(configuration)
region <- regionList
client = builder
.withCredentials(credentialsProvider(account))
.withRegion(region.getName)
.withClientConfiguration(new ClientConfiguration().withMaxConnections(10))
.credentialsProvider(credentialsProvider(account))
.region(region)
.build()
} yield AwsClient(client, account, region)
}

private def withCustomThreadPool[A, B <: AwsAsyncClientBuilder[B, A]] = (asyncClientBuilder: AwsAsyncClientBuilder[B, A]) =>
asyncClientBuilder.withExecutorFactory(() => newCachedThreadPool())
asyncClientBuilder.asyncConfiguration(c => c.advancedOption(SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR, newCachedThreadPool())
)

def ec2Clients(configuration: Configuration, regions: List[Region]): AwsClients[AmazonEC2Async] =
clients(withCustomThreadPool(AmazonEC2AsyncClientBuilder.standard()), configuration, regions:_*)
def ec2Clients(configuration: Configuration, regions: List[Region]): AwsClients[Ec2AsyncClient] =
clients(withCustomThreadPool(Ec2AsyncClient.builder), configuration, regions:_*)

def cfnClients(configuration: Configuration, regions: List[Region]): AwsClients[AmazonCloudFormationAsync] =
clients(withCustomThreadPool(AmazonCloudFormationAsyncClientBuilder.standard()), configuration, regions:_*)
def cfnClients(configuration: Configuration, regions: List[Region]): AwsClients[CloudFormationAsyncClient] =
clients(withCustomThreadPool(CloudFormationAsyncClient.builder), configuration, regions:_*)

// Only needs Regions.US_EAST_1
def taClients(configuration: Configuration, region: Region = RegionUtils.getRegion("us-east-1")): AwsClients[AWSSupportAsync] =
clients(withCustomThreadPool(AWSSupportAsyncClientBuilder.standard()), configuration, region)
def taClients(configuration: Configuration, region: Region = Region.of("us-east-1")): AwsClients[SupportAsyncClient] =
clients(withCustomThreadPool(SupportAsyncClient.builder), configuration, region)

def s3Clients(configuration: Configuration, regions: List[Region]): AwsClients[AmazonS3] =
clients(AmazonS3ClientBuilder.standard(), configuration, regions:_*)
def s3Clients(configuration: Configuration, regions: List[Region]): AwsClients[S3Client] =
clients(S3Client.builder, configuration, regions:_*)

def iamClients(configuration: Configuration, regions: List[Region]): AwsClients[AmazonIdentityManagementAsync] =
clients(withCustomThreadPool(AmazonIdentityManagementAsyncClientBuilder.standard()), configuration, regions:_*)
def iamClients(configuration: Configuration, regions: List[Region]): AwsClients[IamAsyncClient] =
clients(withCustomThreadPool(IamAsyncClient.builder), configuration, regions:_*)

def efsClients(configuration: Configuration, regions: List[Region]): AwsClients[AmazonElasticFileSystemAsync] =
clients(withCustomThreadPool(AmazonElasticFileSystemAsyncClientBuilder.standard()), configuration, regions:_*)
def efsClients(configuration: Configuration, regions: List[Region]): AwsClients[EfsAsyncClient] =
clients(withCustomThreadPool(EfsAsyncClient.builder), configuration, regions:_*)
}
Loading