Skip to content

Commit

Permalink
Merge branch 'develop_aws' into develop_aws
Browse files — browse the repository at this point in the history
  • Loading branch information
henriqueribeiro authored Dec 22, 2023
2 parents bb9984b + 465d8f5 commit 7647b77
Show file tree
Hide file tree
Showing 16 changed files with 617 additions and 171 deletions.
1 change: 1 addition & 0 deletions backend/src/main/scala/cromwell/backend/backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ object CommonBackendConfigurationAttributes {
"default-runtime-attributes.queueArn",
"default-runtime-attributes.awsBatchRetryAttempts",
"default-runtime-attributes.maxRetries",
"default-runtime-attributes.awsBatchEvaluateOnExit",
"default-runtime-attributes.ulimits",
"default-runtime-attributes.efsDelocalize",
"default-runtime-attributes.efsMakeMD5",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
package cromwell.docker.registryv2.flows.aws

import cats.effect.{IO, Resource}
import cromwell.core.TestKitSuite
import cromwell.docker.registryv2.DockerRegistryV2Abstract
Expand All @@ -12,20 +11,19 @@ import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
import org.scalatest.flatspec.AnyFlatSpecLike
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
import software.amazon.awssdk.services.ecrpublic.model.{AuthorizationData, GetAuthorizationTokenRequest, GetAuthorizationTokenResponse}
import software.amazon.awssdk.services.ecrpublic.EcrPublicClient

class AmazonEcrPublicSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with MockitoSugar with BeforeAndAfter with PrivateMethodTester {
class AmazonEcrPublicSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with BeforeAndAfter with PrivateMethodTester {
behavior of "AmazonEcrPublic"

val goodUri = "public.ecr.aws/amazonlinux/amazonlinux:latest"
val otherUri = "ubuntu:latest"


val mediaType: MediaType = MediaType.parse(DockerRegistryV2Abstract.ManifestV2MediaType).right.get
val mediaType: MediaType = MediaType.parse(DockerRegistryV2Abstract.DockerManifestV2MediaType).getOrElse(fail("Cant parse mediatype"))
val contentType: Header = `Content-Type`(mediaType)
val mockEcrClient: EcrPublicClient = mock[EcrPublicClient]
val mockEcrClient: EcrPublicClient = mock(classOf[EcrPublicClient])
implicit val mockIOClient: Client[IO] = Client({ _: Request[IO] =>
// This response will have an empty body, so we need to be explicit about the typing:
Resource.pure[IO, Response[IO]](Response(headers = Headers.of(contentType))) : Resource[IO, Response[IO]]
Expand All @@ -44,7 +42,7 @@ class AmazonEcrPublicSpec extends TestKitSuite with AnyFlatSpecLike with Matcher
}

it should "have public.ecr.aws as registryHostName" in {
val registryHostNameMethod = PrivateMethod[String]('registryHostName)
val registryHostNameMethod = PrivateMethod[String](Symbol("registryHostName"))
registry invokePrivate registryHostNameMethod(DockerImageIdentifier.fromString(goodUri).get) shouldEqual "public.ecr.aws"
}

Expand All @@ -63,7 +61,7 @@ class AmazonEcrPublicSpec extends TestKitSuite with AnyFlatSpecLike with Matcher
.build())
.build)

val getTokenMethod = PrivateMethod[IO[Option[String]]]('getToken)
val getTokenMethod = PrivateMethod[IO[Option[String]]](Symbol("getToken"))
registry invokePrivate getTokenMethod(context, mockIOClient) ensuring(io => io.unsafeRunSync().get == token)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,18 @@ import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
import org.scalatest.flatspec.AnyFlatSpecLike
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
import software.amazon.awssdk.services.ecr.EcrClient
import software.amazon.awssdk.services.ecr.model.{AuthorizationData, GetAuthorizationTokenResponse}

class AmazonEcrSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with MockitoSugar with BeforeAndAfter with PrivateMethodTester{
class AmazonEcrSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with BeforeAndAfter with PrivateMethodTester{
behavior of "AmazonEcr"

val goodUri = "123456789012.dkr.ecr.us-east-1.amazonaws.com/amazonlinux/amazonlinux:latest"
val otherUri = "ubuntu:latest"

val mediaType: MediaType = MediaType.parse(DockerRegistryV2Abstract.ManifestV2MediaType).right.get
val mediaType: MediaType = MediaType.parse(DockerRegistryV2Abstract.DockerManifestV2MediaType).getOrElse(fail("Can't parse media type"))
val contentType: Header = `Content-Type`(mediaType)
val mockEcrClient: EcrClient = mock[EcrClient]
val mockEcrClient: EcrClient = mock(classOf[EcrClient])
implicit val mockIOClient: Client[IO] = Client({ _: Request[IO] =>
// This response will have an empty body, so we need to be explicit about the typing:
Resource.pure[IO, Response[IO]](Response(headers = Headers.of(contentType))) : Resource[IO, Response[IO]]
Expand All @@ -42,12 +41,12 @@ class AmazonEcrSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with
}

it should "use Basic Auth Scheme" in {
val authSchemeMethod = PrivateMethod[AuthScheme]('authorizationScheme)
val authSchemeMethod = PrivateMethod[AuthScheme](Symbol("authorizationScheme"))
registry invokePrivate authSchemeMethod() shouldEqual AuthScheme.Basic
}

it should "return 123456789012.dkr.ecr.us-east-1.amazonaws.com as registryHostName" in {
val registryHostNameMethod = PrivateMethod[String]('registryHostName)
val registryHostNameMethod = PrivateMethod[String](Symbol("registryHostName"))
registry invokePrivate registryHostNameMethod(DockerImageIdentifier.fromString(goodUri).get) shouldEqual "123456789012.dkr.ecr.us-east-1.amazonaws.com"
}

Expand All @@ -66,7 +65,7 @@ class AmazonEcrSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with
.build())
.build)

val getTokenMethod = PrivateMethod[IO[Option[String]]]('getToken)
val getTokenMethod = PrivateMethod[IO[Option[String]]](Symbol("getToken"))
registry invokePrivate getTokenMethod(context, mockIOClient) ensuring(io => io.unsafeRunSync().get == token)
}
}
5 changes: 1 addition & 4 deletions filesystems/s3/src/main/java/org/lerch/s3fs/S3FileStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,7 @@ public Bucket getBucket() {
}

/**
 * Returns a {@link Bucket} handle for the given bucket name.
 *
 * Builds the handle directly from the name rather than calling ListBuckets and
 * scanning the result: that avoided round trip also means no dependency on the
 * caller holding list-buckets permission. NOTE(review): unlike the previous
 * lookup, this never returns null for an unknown bucket — callers that relied
 * on a null result to detect "bucket does not exist" should be checked.
 *
 * @param bucketName name of the S3 bucket (assumed non-null)
 * @return a Bucket value object carrying only the name
 */
private Bucket getBucket(String bucketName) {
    return Bucket.builder().name(bucketName).build();
}

private boolean hasBucket(String bucketName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import software.amazon.awssdk.services.sns.model.PublishRequest
import spray.json.enrichAny

import scala.concurrent.{ExecutionContextExecutor, Future}
import scala.util.{Failure, Success}
import scala.util.{Failure, Success, Try}


/**
Expand All @@ -61,6 +61,10 @@ class AwsSnsMetadataServiceActor(serviceConfig: Config, globalConfig: Config, se

//setup sns client
val topicArn: String = serviceConfig.getString("aws.topicArn")
val publishStatusOnly: Boolean = Try(serviceConfig.getBoolean("aws.publishStatusOnly")) match {
case Failure(_) => false
case Success(value) => value
}

val awsConfig: AwsConfiguration = AwsConfiguration(globalConfig)
val credentialsProviderChain: AwsCredentialsProviderChain =
Expand All @@ -74,7 +78,11 @@ class AwsSnsMetadataServiceActor(serviceConfig: Config, globalConfig: Config, se
def publishMessages(events: Iterable[MetadataEvent]): Future[Unit] = {
import AwsSnsMetadataServiceActor.EnhancedMetadataEvents

val eventsJson = events.toJson
val eventsJson = if (publishStatusOnly) {
events.filter(_.key.key == "status").toJson
} else {
events.toJson
}
//if there are no events then don't publish anything
if( eventsJson.length < 1) { return Future(())}
log.debug(f"Publishing to $topicArn : $eventsJson")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ object AwsBatchAttributes {
"numSubmitAttempts",
"default-runtime-attributes.scriptBucketName",
"awsBatchRetryAttempts",
"awsBatchEvaluateOnExit",
"ulimits",
"efsDelocalize",
"efsMakeMD5",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
*/
lazy val reconfiguredScript: String = {
//this is the location of the aws cli mounted into the container by the ec2 launch template
val awsCmd = "/usr/local/aws-cli/v2/current/bin/aws "
val awsCmd = "/usr/local/aws-cli/v2/current/bin/aws"
//internal to the container, therefore not mounted
val workDir = "/tmp/scratch"
//working in a mount will cause collisions in long running workers
val replaced = commandScript.replace(AwsBatchWorkingDisk.MountPoint.pathAsString, workDir)
val insertionPoint = replaced.indexOf("\n", replaced.indexOf("#!")) +1 //just after the new line after the shebang!
// load the config
val conf : Config = ConfigFactory.load();

/* generate a series of s3 copy statements to copy any s3 files into the container. */
val inputCopyCommand = inputs.map {
case input: AwsBatchFileInput if input.s3key.startsWith("s3://") && input.s3key.endsWith(".tmp") =>
Expand All @@ -136,14 +136,11 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
|sed -i 's#${AwsBatchWorkingDisk.MountPoint.pathAsString}#$workDir#g' "$workDir/${input.local}"
|""".stripMargin


case input: AwsBatchFileInput if input.s3key.startsWith("s3://") =>
// regular s3 objects : download to working dir.
s"""_s3_localize_with_retry "${input.s3key}" "${input.mount.mountPoint.pathAsString}/${input.local}" """.stripMargin
.replace(AwsBatchWorkingDisk.MountPoint.pathAsString, workDir)



case input: AwsBatchFileInput if efsMntPoint.isDefined && input.s3key.startsWith(efsMntPoint.get) =>
// EFS located file : test for presence on provided path.
Log.debug("EFS input file detected: "+ input.s3key + " / "+ input.local.pathAsString)
Expand Down Expand Up @@ -201,7 +198,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
| LOCALIZATION_FAILED=1
| break
| fi
| # copy
| # copy
| $awsCmd s3 cp --no-progress "$$s3_path" "$$destination" ||
| { echo "attempt $$i to copy $$s3_path failed" && sleep $$((7 * "$$i")) && continue; }
| # check data integrity
Expand All @@ -228,7 +225,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
| # get the multipart chunk size
| chunk_size=$$(_get_multipart_chunk_size "$$local_path")
| local MP_THRESHOLD=${mp_threshold}
| # then set them
| # then set them
| $awsCmd configure set default.s3.multipart_threshold $$MP_THRESHOLD
| $awsCmd configure set default.s3.multipart_chunksize $$chunk_size
|
Expand All @@ -249,7 +246,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
| fi
| # copy ok or try again.
| if [[ -d "$$local_path" ]]; then
| # make sure to strip the trailing / in destination
| # make sure to strip the trailing / in destination
| destination=$${destination%/}
| # glob directory. do recursive copy
| $awsCmd s3 cp --no-progress "$$local_path" "$$destination" --recursive --exclude "cromwell_glob_control_file" ||
Expand Down Expand Up @@ -279,7 +276,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
| # file size
| file_size=$$(stat --printf="%s" "$$file_path")
| # chunk_size : you can have at most 10K parts with at least one 5MB part
| # this reflects the formula in s3-copy commands of cromwell (S3FileSystemProvider.java)
| # this reflects the formula in s3-copy commands of cromwell (S3FileSystemProvider.java)
| # => long partSize = Math.max((objectSize / 10000L) + 1, 5 * 1024 * 1024);
| a=$$(( ( file_size / 10000) + 1 ))
| b=$$(( 5 * 1024 * 1024 ))
Expand All @@ -300,16 +297,16 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
| echo "$$s3_path is not an S3 path with a bucket and key."
| exit 1
| fi
| s3_content_length=$$($awsCmd s3api head-object --bucket "$$bucket" --key "$$key" --query 'ContentLength') ||
| { echo "Attempt to get head of object failed for $$s3_path." && return 1 ; }
| s3_content_length=$$($awsCmd s3api head-object --bucket "$$bucket" --key "$$key" --query 'ContentLength') ||
| { echo "Attempt to get head of object failed for $$s3_path." && return 1; }
| # local
| local_content_length=$$(LC_ALL=C ls -dnL -- "$$local_path" | awk '{print $$5; exit}' ) ||
| { echo "Attempt to get local content length failed for $$_local_path." && return 1; }
| local_content_length=$$(LC_ALL=C ls -dnL -- "$$local_path" | awk '{print $$5; exit}' ) ||
| { echo "Attempt to get local content length failed for $$_local_path." && return 1; }
| # compare
| if [[ "$$s3_content_length" -eq "$$local_content_length" ]]; then
| true
| else
| false
| false
| fi
|}
|
Expand Down Expand Up @@ -457,8 +454,8 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
s"""_s3_delocalize_with_retry "$workDir/${output.local.pathAsString}" "${output.s3key}" """.stripMargin

// file(name (full path), s3key (delocalized path), local (file basename), mount (disk details))
// files on EFS mounts are optionally delocalized.
case output: AwsBatchFileOutput if efsMntPoint.isDefined && output.mount.mountPoint.pathAsString == efsMntPoint.get =>
// files on EFS mounts are optionally delocalized.
case output: AwsBatchFileOutput if efsMntPoint.isDefined && output.mount.mountPoint.pathAsString == efsMntPoint.get =>
Log.debug("EFS output file detected: "+ output.s3key + s" / ${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}")
// EFS located file : test existence or delocalize.
var test_cmd = ""
Expand All @@ -470,24 +467,24 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
// check file for existence
test_cmd = s"""test -e "${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}" || (echo 'output file: ${output.mount.mountPoint.pathAsString}/${output.local.pathAsString} does not exist' && DELOCALIZATION_FAILED=1)""".stripMargin
}
// need to make md5sum?
// need to make md5sum?
var md5_cmd = ""
if (efsMakeMD5.isDefined && efsMakeMD5.getOrElse(false)) {
Log.debug("Add cmd to create MD5 sibling.")
md5_cmd = s"""
|if [[ ! -f '${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}.md5' ]] ; then
| md5sum '${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}' > '${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}.md5' || (echo 'Could not generate ${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}.md5' && DELOCALIZATION_FAILED=1 );
|fi
|""".stripMargin
|""".stripMargin
} else {
md5_cmd = ""
}
}
// return combined result
s"""
|${test_cmd}
|${md5_cmd}
| """.stripMargin

case output: AwsBatchFileOutput =>
//output on a different mount
Log.debug("output data on other mount")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ package cromwell.backend.impl.aws
import scala.collection.mutable.ListBuffer
import cromwell.backend.BackendJobDescriptor
import cromwell.backend.io.JobPaths
import software.amazon.awssdk.services.batch.model.{ContainerProperties, Host, KeyValuePair, MountPoint, ResourceRequirement, ResourceType, RetryStrategy, Ulimit, Volume}
import software.amazon.awssdk.services.batch.model.{ContainerProperties, EvaluateOnExit, Host, KeyValuePair, MountPoint, ResourceRequirement, ResourceType, RetryAction, RetryStrategy, Ulimit, Volume}
import cromwell.backend.impl.aws.io.AwsBatchVolume

import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -198,9 +198,30 @@ trait AwsBatchJobDefinitionBuilder {

/**
  * Builds the AWS Batch retry strategy for a job definition.
  *
  * The strategy carries the configured number of retry attempts plus any
  * `evaluateOnExit` conditions taken from the runtime attributes. Each
  * condition is a map of keys (matched case-insensitively) to values:
  * "action" ("retry"/"exit", value matched case-insensitively) and
  * "onExitCode"/"onReason"/"onStatusReason" (value passed through with its
  * original casing). Unrecognised keys are silently ignored.
  *
  * @param context the job definition context holding the runtime attributes
  * @return the configured RetryStrategy builder, paired with a string that
  *         fingerprints the retry settings (used for job-definition naming /
  *         de-duplication — the string must change whenever the settings do)
  */
def retryStrategyBuilder(context: AwsBatchJobDefinitionContext): (RetryStrategy.Builder, String) = {
  // Translate each runtime-attribute condition map into an SDK EvaluateOnExit.
  var evaluations: Seq[EvaluateOnExit] = Seq()
  context.runtimeAttributes.awsBatchEvaluateOnExit.foreach { evaluate =>
    val evaluateBuilder = evaluate.foldLeft(EvaluateOnExit.builder()) {
      case (acc, (k, v)) => (k.toLowerCase, v.toLowerCase) match {
        case ("action", "retry")   => acc.action(RetryAction.RETRY)
        case ("action", "exit")    => acc.action(RetryAction.EXIT)
        // For the on* keys the original (un-lowercased) value is applied.
        case ("onexitcode", _)     => acc.onExitCode(v)
        case ("onreason", _)       => acc.onReason(v)
        case ("onstatusreason", _) => acc.onStatusReason(v)
        case _                     => acc // unknown key: leave builder unchanged
      }
    }
    evaluations = evaluations :+ evaluateBuilder.build()
  }

  // `val`: the builder is fully configured by this single chained expression.
  val builder = RetryStrategy.builder()
    .attempts(context.runtimeAttributes.awsBatchRetryAttempts)
    .evaluateOnExit(evaluations.asJava)

  (builder,
    s"${context.runtimeAttributes.awsBatchRetryAttempts.toString}${context.runtimeAttributes.awsBatchEvaluateOnExit.toString}")
}


Expand Down
Loading

0 comments on commit 7647b77

Please sign in to comment.