Skip to content

Commit

Permalink
Complete Scanamo update - support both v1 and v2 credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
rtyley committed Feb 1, 2024
1 parent d328c60 commit 7e5bb61
Show file tree
Hide file tree
Showing 15 changed files with 113 additions and 49 deletions.
2 changes: 1 addition & 1 deletion app/di.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class MediaAtomMaker(context: Context)
private val capi = new Capi(config)

private val stores = new DataStores(aws, capi)
private val permissions = new MediaAtomMakerPermissionsProvider(aws.stage, aws.region.getName, aws.credentials.instance)
private val permissions = new MediaAtomMakerPermissionsProvider(aws.stage, aws.region.getName, aws.credentials.instance.v1)

private val reindexer = buildReindexer()

Expand Down
4 changes: 2 additions & 2 deletions app/util/AWS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class AWSConfig(override val config: Config, override val credentials: AwsCreden
lazy val ec2Client = AmazonEC2ClientBuilder
.standard()
.withRegion(region.getName)
.withCredentials(credentials.instance)
.withCredentials(credentials.instance.v1)
.build()

lazy val pinboardLoaderUrl = getString("panda.domain").map(domain => s"https://pinboard.$domain/pinboard.loader.js")
Expand All @@ -43,7 +43,7 @@ class AWSConfig(override val config: Config, override val credentials: AwsCreden
lazy val expiryPollerName = "Expiry"
lazy val expiryPollerLastName = "Poller"

final override def regionName = getString("aws.region")
final override def region = AwsAccess.regionFrom(this)

final override def readTag(tagName: String) = {
val tagsResult = ec2Client.describeTags(
Expand Down
6 changes: 5 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import scala.sys.process._

val scroogeVersion = "4.12.0"
val awsVersion = "1.11.678"
val awsV2Version = "2.21.17"
val pandaVersion = "3.0.1"
val atomMakerVersion = "2.0.0-PREVIEW.fix-tests-under-java-17.2024-01-18T1039.b4d55b3d"
val typesafeConfigVersion = "1.4.0" // to match what we get from Play transitively
Expand Down Expand Up @@ -96,11 +97,13 @@ lazy val common = (project in file("common"))
"com.typesafe" % "config" % typesafeConfigVersion,
"com.amazonaws" % "aws-lambda-java-core" % awsLambdaCoreVersion,
"com.amazonaws" % "aws-java-sdk-s3" % awsVersion,
"software.amazon.awssdk" % "dynamodb" % "2.21.17",
"com.amazonaws" % "aws-java-sdk-dynamodb" % awsVersion,
"software.amazon.awssdk" % "dynamodb" % awsV2Version,
"com.amazonaws" % "aws-java-sdk-kinesis" % awsVersion,
"ai.x" %% "play-json-extensions" % playJsonExtensionsVersion,
"ch.qos.logback" % "logback-classic" % logbackClassicVersion,
"com.amazonaws" % "aws-java-sdk-sts" % awsVersion,
"software.amazon.awssdk" % "sts" % awsV2Version,
"com.amazonaws" % "aws-java-sdk-elastictranscoder" % awsVersion,
"org.scanamo" %% "scanamo" % "1.0.0-M28",
"com.squareup.okhttp" % "okhttp" % okHttpVersion,
Expand All @@ -127,6 +130,7 @@ lazy val app = (project in file("."))
ehcache,
"com.fasterxml.jackson.core" % "jackson-databind" % jacksonDatabindVersion,
"com.amazonaws" % "aws-java-sdk-sts" % awsVersion,
"software.amazon.awssdk" % "sts" % awsV2Version,
"com.amazonaws" % "aws-java-sdk-ec2" % awsVersion,
"org.scalatestplus.play" %% "scalatestplus-play" % scalaTestPlusPlayVersion % "test",
"org.mockito" % "mockito-core" % mockitoVersion % "test",
Expand Down
17 changes: 10 additions & 7 deletions common/src/main/scala/com/gu/media/aws/AwsAccess.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,14 @@ import com.amazonaws.regions.{Region, Regions}
import com.gu.media.Settings

trait AwsAccess { this: Settings =>
def regionName: Option[String]
def readTag(tag: String): Option[String]

val credentials: AwsCredentials

val awsV2Credentials: software.amazon.awssdk.auth.credentials.AwsCredentialsProvider
// To avoid renaming references everywhere
def credsProvider: AWSCredentialsProvider = credentials.instance
def credsProvider: AWSCredentialsProvider = credentials.instance.v1

final def defaultRegion: Region = Region.getRegion(Regions.EU_WEST_1)
final def region: Region = regionName
.map { name => Region.getRegion(Regions.fromName(name)) }
.getOrElse(defaultRegion)
def region: Region

final def awsV2Region: software.amazon.awssdk.regions.Region =
software.amazon.awssdk.regions.Region.of(region.getName)
Expand All @@ -29,3 +24,11 @@ trait AwsAccess { this: Settings =>
final val stack: Option[String] = if (isDev) Some("media-atom-maker") else readTag("Stack")
final val app: String = if (isDev) "media-atom-maker" else readTag("App").getOrElse("media-atom-maker")
}

object AwsAccess {
def regionFrom(maybeName: Option[String]): Region = maybeName
.map { name => Region.getRegion(Regions.fromName(name)) }
.getOrElse(Region.getRegion(Regions.EU_WEST_1))

def regionFrom(settings: Settings): Region = regionFrom(settings.getString("aws.region"))
}
33 changes: 13 additions & 20 deletions common/src/main/scala/com/gu/media/aws/AwsCredentials.scala
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
package com.gu.media.aws

import com.amazonaws.auth._
import com.amazonaws.auth.profile.ProfileCredentialsProvider
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder
import com.gu.media.Settings

case class AwsCredentials(instance: AWSCredentialsProvider, crossAccount: AWSCredentialsProvider,
upload: AWSCredentialsProvider)
case class AwsCredentials(
instance: AwsCredentialsProvidersForBothSdkVersions,
crossAccount: AwsCredentialsProvidersForBothSdkVersions,
upload: AwsCredentialsProvidersForBothSdkVersions
)

object AwsCredentials {
def dev(settings: Settings): AwsCredentials = {
val profile = settings.getMandatoryString("aws.profile")
val instance = new ProfileCredentialsProvider(profile)
val instance = AwsCredentialsProvidersForBothSdkVersions.profile(profile)

// To enable publishing to CAPI code from DEV, update the kinesis streams in config and uncomment below:
// val crossAccount = new ProfileCredentialsProvider("composer")
// val crossAccount = AwsCredentialsProvidersForBothSdkVersions.profile("composer")
val crossAccount = instance

val upload = devUpload(settings)
Expand All @@ -23,38 +23,31 @@ object AwsCredentials {
}

def app(settings: Settings): AwsCredentials = {
val instance = InstanceProfileCredentialsProvider.getInstance()
val instance = AwsCredentialsProvidersForBothSdkVersions.instance()

val crossAccount = assumeCrossAccountRole(instance, settings)

AwsCredentials(instance, crossAccount, upload = instance)
}

def lambda(): AwsCredentials = {
val instance = new EnvironmentVariableCredentialsProvider()
val instance = AwsCredentialsProvidersForBothSdkVersions.environmentVariables()
AwsCredentials(instance, crossAccount = instance, upload = instance)
}

private def devUpload(settings: Settings): AWSCredentialsProvider = {
private def devUpload(settings: Settings): AwsCredentialsProvidersForBothSdkVersions = {
// Only required in dev (because federated credentials such as those from Janus cannot do STS requests).
// Instance profile credentials are sufficient when deployed.
val accessKey = settings.getMandatoryString("aws.upload.accessKey", "This is the AwsId output of the dev cloudformation")
val secretKey = settings.getMandatoryString("aws.upload.secretKey", "This is the AwsSecret output of the dev cloudformation")

new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey))
AwsCredentialsProvidersForBothSdkVersions.static(accessKey, secretKey)
}

private def assumeCrossAccountRole(instance: AWSCredentialsProvider, settings: Settings) = {
private def assumeCrossAccountRole(instance: AwsCredentialsProvidersForBothSdkVersions, settings: Settings): AwsCredentialsProvidersForBothSdkVersions = {
val crossAccountRoleArn = settings.getMandatoryString("aws.kinesis.stsCapiRoleToAssume",
"Role to assume to access CAPI streams (in format arn:aws:iam::<account>:role/<role_name>)")

assumeAccountRole(instance, crossAccountRoleArn, "capi")
}

private def assumeAccountRole(instance: AWSCredentialsProvider, roleArn: String, sessionNameSuffix: String): AWSCredentialsProvider = {
val securityTokens = AWSSecurityTokenServiceClientBuilder.standard().withCredentials(instance).build()

new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, s"media-atom-maker-${sessionNameSuffix}")
.withStsClient(securityTokens).build()
instance.assumeAccountRole(crossAccountRoleArn, "capi", AwsAccess.regionFrom(settings).getName)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.gu.media.aws

import com.amazonaws.auth._
import com.amazonaws.auth.profile.ProfileCredentialsProvider
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder
import com.gu.media.Settings
import com.gu.media.aws.AwsV2Util
import software.amazon.awssdk.auth.{credentials => awsv2}
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.sts.{StsClient, StsClientBuilder}
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest

case class AwsCredentialsProvidersForBothSdkVersions(
v1: com.amazonaws.auth.AWSCredentialsProvider,
v2: software.amazon.awssdk.auth.credentials.AwsCredentialsProvider
) {
def assumeAccountRole(roleArn: String, sessionNameSuffix: String, regionName: String): AwsCredentialsProvidersForBothSdkVersions = {
val roleSessionName = s"media-atom-maker-$sessionNameSuffix"
AwsCredentialsProvidersForBothSdkVersions(
new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, roleSessionName)
.withStsClient(AWSSecurityTokenServiceClientBuilder.standard().withCredentials(v1).build()).build(),
software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider.builder()
.stsClient(AwsV2Util.buildSync[StsClient, StsClientBuilder](StsClient.builder(), v2, Region.of(regionName)))
.refreshRequest(AssumeRoleRequest.builder.roleSessionName(roleSessionName).roleArn(roleArn).build)
.build()
)
}
}

object AwsCredentialsProvidersForBothSdkVersions {
def profile(name: String): AwsCredentialsProvidersForBothSdkVersions = AwsCredentialsProvidersForBothSdkVersions(
new ProfileCredentialsProvider(name),
awsv2.ProfileCredentialsProvider.create(name)
)

def instance(): AwsCredentialsProvidersForBothSdkVersions = AwsCredentialsProvidersForBothSdkVersions(
InstanceProfileCredentialsProvider.getInstance(),
awsv2.InstanceProfileCredentialsProvider.create()
)

def environmentVariables(): AwsCredentialsProvidersForBothSdkVersions = AwsCredentialsProvidersForBothSdkVersions(
new EnvironmentVariableCredentialsProvider(),
awsv2.EnvironmentVariableCredentialsProvider.create()
)

def static(accessKey: String, secretKey: String): AwsCredentialsProvidersForBothSdkVersions = AwsCredentialsProvidersForBothSdkVersions(
new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)),
awsv2.StaticCredentialsProvider.create(awsv2.AwsBasicCredentials.create(accessKey, secretKey))
)
}
15 changes: 15 additions & 0 deletions common/src/main/scala/com/gu/media/aws/AwsV2Util.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.gu.media.aws

import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider
import software.amazon.awssdk.awscore.client.builder.{AwsClientBuilder, AwsSyncClientBuilder}
import software.amazon.awssdk.http.apache.ApacheHttpClient

object AwsV2Util {
def buildSync[T, B <: AwsClientBuilder[B, T] with AwsSyncClientBuilder[B, T]](
builder: B, creds: AwsCredentialsProvider, region: software.amazon.awssdk.regions.Region
): T = builder
.httpClientBuilder(ApacheHttpClient.builder())
.credentialsProvider(creds)
.region(region)
.build()
}
16 changes: 8 additions & 8 deletions common/src/main/scala/com/gu/media/aws/DynamoAccess.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.gu.media.aws

import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder
import com.gu.media.Settings
import com.gu.media.aws.AwsV2Util.buildSync
import org.scanamo.Scanamo
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider
import software.amazon.awssdk.awscore.client.builder.{AwsClientBuilder, AwsSyncClientBuilder}
Expand All @@ -24,15 +25,14 @@ trait DynamoAccess { this: Settings with AwsAccess =>
lazy val plutoCommissionTableName: String = getTableName("pluto-commissions")
lazy val plutoProjectTableName: String = getTableName("pluto-projects")

def buildSync[T, B <: AwsClientBuilder[B, T] with AwsSyncClientBuilder[B, T]](
builder: B, creds: AwsCredentialsProvider = awsV2Credentials
): T = builder
.httpClientBuilder(ApacheHttpClient.builder())
.credentialsProvider(creds)
.region(awsV2Region)
lazy val dynamoDB = AmazonDynamoDBClientBuilder
.standard()
.withCredentials(credsProvider)
.withRegion(region.getName)
.build()

lazy val dynamoDB: DynamoDbClient = buildSync[DynamoDbClient, DynamoDbClientBuilder](DynamoDbClient.builder())
lazy val dynamoDbSdkV2: DynamoDbClient =
buildSync[DynamoDbClient, DynamoDbClientBuilder](DynamoDbClient.builder(), credentials.instance.v2, awsV2Region)

lazy val scanamo: Scanamo = Scanamo(dynamoDB)
lazy val scanamo: Scanamo = Scanamo(dynamoDbSdkV2)
}
4 changes: 2 additions & 2 deletions common/src/main/scala/com/gu/media/aws/KinesisAccess.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ trait KinesisAccess { this: Settings with AwsAccess with Logging =>
val syncWithPluto: Boolean = getBoolean("pluto.sync").getOrElse(false)

lazy val crossAccountKinesisClient = AmazonKinesisClientBuilder.standard()
.withCredentials(credentials.crossAccount)
.withCredentials(credentials.crossAccount.v1)
.withRegion(region.getName)
.build()

lazy val kinesisClient = AmazonKinesisClientBuilder.standard()
.withCredentials(credentials.instance)
.withCredentials(credentials.instance.v1)
.withRegion(region.getName)
.build()

Expand Down
2 changes: 1 addition & 1 deletion common/src/main/scala/com/gu/media/aws/SNSAccess.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ trait SNSAccess { this: Settings with AwsAccess =>
lazy val snsClient =
AmazonSNSClientBuilder.standard()
.withRegion(Regions.fromName(region.getName))
.withCredentials(credentials.instance)
.withCredentials(credentials.instance.v1)
.build()
}
2 changes: 1 addition & 1 deletion common/src/main/scala/com/gu/media/aws/UploadAccess.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ trait UploadAccess { this: Settings with AwsAccess =>
throw new IllegalArgumentException("aws.upload.role must be in ARN format: arn:aws:iam::<account>:role/<role_name>")
}

AWSSecurityTokenServiceClientBuilder.standard().withCredentials(credentials.upload).withRegion(region.getName).build()
AWSSecurityTokenServiceClientBuilder.standard().withCredentials(credentials.upload.v1).withRegion(region.getName).build()
}

private def getPipelineArn() = {
Expand Down
6 changes: 3 additions & 3 deletions common/src/main/scala/com/gu/media/lambda/LambdaBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import com.gu.media.aws.{AwsAccess, AwsCredentials, HMACSettings}
import com.typesafe.config.{Config, ConfigFactory}

trait LambdaBase extends Settings with AwsAccess with HMACSettings {
final override def regionName = sys.env.get("REGION")
final override def region = AwsAccess.regionFrom(sys.env.get("REGION"))
final override def readTag(tag: String) = sys.env.get(tag.toUpperCase(Locale.ENGLISH))

final override val credentials = AwsCredentials.lambda()
final override val credentials: AwsCredentials = AwsCredentials.lambda()

private val remoteConfig = downloadConfig()
private val mergedConfig = remoteConfig.withFallback(ConfigFactory.load())
Expand All @@ -24,7 +24,7 @@ trait LambdaBase extends Settings with AwsAccess with HMACSettings {
case (Some(bucket), Some(key)) =>
val defaultRegionS3 = AmazonS3ClientBuilder
.standard()
.withCredentials(credentials.instance)
.withCredentials(credentials.instance.v1)
.withRegion(region.getName)
.build()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ trait LambdaYoutubeCredentials { self: AwsAccess =>
case (Some(bucket), Some(key)) =>
val defaultRegionS3 = AmazonS3ClientBuilder
.standard()
.withCredentials(credentials.instance)
.withCredentials(credentials.instance.v1)
.withRegion(region.getName)
.build()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class MediaAtomHelpersTest extends FunSuite with MustMatchers {
}

private def assets(atom: Atom): Seq[Asset] = {
atom.data.asInstanceOf[AtomData.Media].media.assets
atom.data.asInstanceOf[AtomData.Media].media.assets.toSeq
}

private def asset(): Asset = Asset(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import org.scanamo.generic.auto._

class AddUploadDataToCache extends LambdaWithParams[Upload, Upload] with DynamoAccess with UploadAccess {
private val table = Table[Upload](this.cacheTableName)
private val scanamo = Scanamo(this.dynamoDB)

override def handle(input: Upload): Upload = {
scanamo.exec(table.put(input))
Expand Down

0 comments on commit 7e5bb61

Please sign in to comment.