From 3baa86912acc0d5c7d31168485b6e4643b95f8fb Mon Sep 17 00:00:00 2001
From: xquek <46759870+xquek@users.noreply.github.com>
Date: Mon, 16 Oct 2023 08:41:31 -0700
Subject: [PATCH 1/4] Options to publish status only (#36)

* add options to publish status only

* updated readme.md

---------

Co-authored-by: quekx
---
 .../metadata/impl/aws/SnsMetadataServiceActor.scala | 12 ++++++++++--
 .../main/scala/cromwell/backend/impl/aws/README.md  |  4 +++-
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/services/src/main/scala/cromwell/services/metadata/impl/aws/SnsMetadataServiceActor.scala b/services/src/main/scala/cromwell/services/metadata/impl/aws/SnsMetadataServiceActor.scala
index da66ae4f7f4..4fc9f43328e 100644
--- a/services/src/main/scala/cromwell/services/metadata/impl/aws/SnsMetadataServiceActor.scala
+++ b/services/src/main/scala/cromwell/services/metadata/impl/aws/SnsMetadataServiceActor.scala
@@ -46,7 +46,7 @@ import software.amazon.awssdk.services.sns.model.PublishRequest
 import spray.json.enrichAny
 
 import scala.concurrent.{ExecutionContextExecutor, Future}
-import scala.util.{Failure, Success}
+import scala.util.{Failure, Success, Try}
 
 
 /**
@@ -61,6 +61,10 @@ class AwsSnsMetadataServiceActor(serviceConfig: Config, globalConfig: Config, se
 
   //setup sns client
   val topicArn: String = serviceConfig.getString("aws.topicArn")
+  val publishStatusOnly: Boolean = Try(serviceConfig.getBoolean("aws.publishStatusOnly")) match {
+    case Failure(_) => false
+    case Success(value) => value
+  }
 
   val awsConfig: AwsConfiguration = AwsConfiguration(globalConfig)
   val credentialsProviderChain: AwsCredentialsProviderChain =
@@ -74,7 +78,11 @@ class AwsSnsMetadataServiceActor(serviceConfig: Config, globalConfig: Config, se
   def publishMessages(events: Iterable[MetadataEvent]): Future[Unit] = {
     import AwsSnsMetadataServiceActor.EnhancedMetadataEvents
 
-    val eventsJson = events.toJson
+    val eventsJson = if (publishStatusOnly) {
+      events.filter(_.key.key == "status").toJson
+    } else {
+      events.toJson
+    }
     //if there are no events then don't publish anything
     if( eventsJson.length < 1) { return Future(())}
     log.debug(f"Publishing to $topicArn : $eventsJson")
diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md
index 613607da42a..48bc3adc442 100644
--- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md
+++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md
@@ -227,6 +227,7 @@ In the AWS backend those notifications can be send to **SNS Topic** or **EventBr
 #### AWS SNS
 
 1. Create an SNS topic, add the following to your `cromwell.conf` file and replace `topicArn` with the topic's ARN you just created:
+2. By default, all Cromwell events are published to SNS. Set `publishStatusOnly = true` to publish only events that are `status` updates.
 
 ```
 services {
@@ -241,12 +242,13 @@ services {
       }]
       region = "us-east-1"
       topicArn = ""
+      publishStatusOnly = true
     }
   }
 }
}
```
 
-2. Add `sns:Publish` IAM policy to your Cromwell server IAM role.
+3. Add `sns:Publish` IAM policy to your Cromwell server IAM role.
 
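For reference, the filtering applied by `AwsSnsMetadataServiceActor` when `publishStatusOnly = true` reduces to a single predicate on the metadata key. A minimal Scala sketch, using simplified stand-ins for Cromwell's `MetadataEvent` and `MetadataKey` types:

```
final case class MetadataKey(key: String)
final case class MetadataEvent(key: MetadataKey, value: String)

// Only events whose metadata key is exactly "status" (workflow and call
// status transitions) are serialized and published to the topic.
def eventsToPublish(events: Iterable[MetadataEvent], publishStatusOnly: Boolean): Iterable[MetadataEvent] =
  if (publishStatusOnly) events.filter(_.key.key == "status") else events
```
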
#### AWS EventBridge

From 7bc78bde3b940e00a8cd161f8582bdc8274872e9 Mon Sep 17 00:00:00 2001
From: xquek <46759870+xquek@users.noreply.github.com>
Date: Mon, 30 Oct 2023 05:11:17 -0700
Subject: [PATCH 2/4] Fix aws unit tests (#39)

* checkpoint

* fix ecr and batch tests

* fix AwsBatchJobSpec.scala

---------

Co-authored-by: quekx
---
 .../flows/aws/AmazonEcrPublicSpec.scala            |  12 +-
 .../registryv2/flows/aws/AmazonEcrSpec.scala       |  13 +-
 .../backend/impl/aws/AwsBatchJob.scala             |  65 +++--
 ...tchAsyncBackendJobExecutionActorSpec.scala      |   3 +-
 .../impl/aws/AwsBatchCallPathsSpec.scala           |   3 +-
 .../backend/impl/aws/AwsBatchJobSpec.scala         | 261 ++++++++++--------
 .../aws/AwsBatchRuntimeAttributesSpec.scala        |   7 +-
 .../impl/aws/AwsBatchWorkflowPathsSpec.scala       |   1 +
 8 files changed, 203 insertions(+), 162 deletions(-)

diff --git a/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrPublicSpec.scala b/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrPublicSpec.scala
index 18a0a23232e..f6f1520512f 100644
--- a/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrPublicSpec.scala
+++ b/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrPublicSpec.scala
@@ -1,5 +1,4 @@
 package cromwell.docker.registryv2.flows.aws
-
 import cats.effect.{IO, Resource}
 import cromwell.core.TestKitSuite
 import cromwell.docker.registryv2.DockerRegistryV2Abstract
@@ -12,20 +11,19 @@ import org.mockito.Mockito._
 import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
 import org.scalatest.flatspec.AnyFlatSpecLike
 import org.scalatest.matchers.should.Matchers
-import org.scalatestplus.mockito.MockitoSugar
 import software.amazon.awssdk.services.ecrpublic.model.{AuthorizationData, GetAuthorizationTokenRequest, GetAuthorizationTokenResponse}
 import software.amazon.awssdk.services.ecrpublic.EcrPublicClient
 
-class AmazonEcrPublicSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with MockitoSugar with BeforeAndAfter with PrivateMethodTester {
+class AmazonEcrPublicSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with BeforeAndAfter with PrivateMethodTester {
   behavior of "AmazonEcrPublic"
 
   val goodUri = "public.ecr.aws/amazonlinux/amazonlinux:latest"
   val otherUri = "ubuntu:latest"
-  val mediaType: MediaType = MediaType.parse(DockerRegistryV2Abstract.ManifestV2MediaType).right.get
+  val mediaType: MediaType = MediaType.parse(DockerRegistryV2Abstract.DockerManifestV2MediaType).getOrElse(fail("Can't parse media type"))
   val contentType: Header = `Content-Type`(mediaType)
-  val mockEcrClient: EcrPublicClient = mock[EcrPublicClient]
+  val mockEcrClient: EcrPublicClient = mock(classOf[EcrPublicClient])
   implicit val mockIOClient: Client[IO] = Client({ _: Request[IO] =>
     // This response will have an empty body, so we need to be explicit about the typing:
     Resource.pure[IO, Response[IO]](Response(headers = Headers.of(contentType))) : Resource[IO, Response[IO]]
   })
@@ -44,7 +42,7 @@ class AmazonEcrPublicSpec extends TestKitSuite with AnyFlatSpecLike with Matcher
   }
 
   it should "have public.ecr.aws as registryHostName" in {
-    val registryHostNameMethod = PrivateMethod[String]('registryHostName)
+    val registryHostNameMethod = PrivateMethod[String](Symbol("registryHostName"))
     registry invokePrivate registryHostNameMethod(DockerImageIdentifier.fromString(goodUri).get) shouldEqual "public.ecr.aws"
   }
 
@@ -63,7 +61,7 @@ class AmazonEcrPublicSpec extends TestKitSuite with AnyFlatSpecLike with Matcher
           .build())
       .build)
 
-    val getTokenMethod = 
PrivateMethod[IO[Option[String]]]('getToken) + val getTokenMethod = PrivateMethod[IO[Option[String]]](Symbol("getToken")) registry invokePrivate getTokenMethod(context, mockIOClient) ensuring(io => io.unsafeRunSync().get == token) } } diff --git a/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrSpec.scala b/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrSpec.scala index 5ddf98c7ffa..c4f121188db 100644 --- a/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrSpec.scala +++ b/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrSpec.scala @@ -11,19 +11,18 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.flatspec.AnyFlatSpecLike import org.scalatest.matchers.should.Matchers -import org.scalatestplus.mockito.MockitoSugar import software.amazon.awssdk.services.ecr.EcrClient import software.amazon.awssdk.services.ecr.model.{AuthorizationData, GetAuthorizationTokenResponse} -class AmazonEcrSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with MockitoSugar with BeforeAndAfter with PrivateMethodTester{ +class AmazonEcrSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with BeforeAndAfter with PrivateMethodTester{ behavior of "AmazonEcr" val goodUri = "123456789012.dkr.ecr.us-east-1.amazonaws.com/amazonlinux/amazonlinux:latest" val otherUri = "ubuntu:latest" - val mediaType: MediaType = MediaType.parse(DockerRegistryV2Abstract.ManifestV2MediaType).right.get + val mediaType: MediaType = MediaType.parse(DockerRegistryV2Abstract.DockerManifestV2MediaType).getOrElse(fail("Can't parse media type")) val contentType: Header = `Content-Type`(mediaType) - val mockEcrClient: EcrClient = mock[EcrClient] + val mockEcrClient: EcrClient = mock(classOf[EcrClient]) implicit val mockIOClient: Client[IO] = Client({ _: Request[IO] => // This response will have an empty body, so we need to be explicit about the typing: Resource.pure[IO, Response[IO]](Response(headers = Headers.of(contentType))) : Resource[IO, Response[IO]] @@ -42,12 +41,12 @@ class AmazonEcrSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with } it should "use Basic Auth Scheme" in { - val authSchemeMethod = PrivateMethod[AuthScheme]('authorizationScheme) + val authSchemeMethod = PrivateMethod[AuthScheme](Symbol("authorizationScheme")) registry invokePrivate authSchemeMethod() shouldEqual AuthScheme.Basic } it should "return 123456789012.dkr.ecr.us-east-1.amazonaws.com as registryHostName" in { - val registryHostNameMethod = PrivateMethod[String]('registryHostName) + val registryHostNameMethod = PrivateMethod[String](Symbol("registryHostName")) registry invokePrivate registryHostNameMethod(DockerImageIdentifier.fromString(goodUri).get) shouldEqual "123456789012.dkr.ecr.us-east-1.amazonaws.com" } @@ -66,7 +65,7 @@ class AmazonEcrSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with .build()) .build) - val getTokenMethod = PrivateMethod[IO[Option[String]]]('getToken) + val getTokenMethod = PrivateMethod[IO[Option[String]]](Symbol("getToken")) registry invokePrivate getTokenMethod(context, mockIOClient) ensuring(io => io.unsafeRunSync().get == token) } } diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala index 97302357175..d967e26659a 100755 --- 
a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala @@ -117,7 +117,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL */ lazy val reconfiguredScript: String = { //this is the location of the aws cli mounted into the container by the ec2 launch template - val awsCmd = "/usr/local/aws-cli/v2/current/bin/aws " + val awsCmd = "/usr/local/aws-cli/v2/current/bin/aws" //internal to the container, therefore not mounted val workDir = "/tmp/scratch" //working in a mount will cause collisions in long running workers @@ -125,7 +125,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL val insertionPoint = replaced.indexOf("\n", replaced.indexOf("#!")) +1 //just after the new line after the shebang! // load the config val conf : Config = ConfigFactory.load(); - + /* generate a series of s3 copy statements to copy any s3 files into the container. */ val inputCopyCommand = inputs.map { case input: AwsBatchFileInput if input.s3key.startsWith("s3://") && input.s3key.endsWith(".tmp") => @@ -135,14 +135,11 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL |sed -i 's#${AwsBatchWorkingDisk.MountPoint.pathAsString}#$workDir#g' $workDir/${input.local} |""".stripMargin - case input: AwsBatchFileInput if input.s3key.startsWith("s3://") => // regular s3 objects : download to working dir. s"_s3_localize_with_retry ${input.s3key} ${input.mount.mountPoint.pathAsString}/${input.local}" .replace(AwsBatchWorkingDisk.MountPoint.pathAsString, workDir) - - case input: AwsBatchFileInput if efsMntPoint.isDefined && input.s3key.startsWith(efsMntPoint.get) => // EFS located file : test for presence on provided path. Log.debug("EFS input file detected: "+ input.s3key + " / "+ input.local.pathAsString) @@ -187,11 +184,11 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL | echo "$$s3_path is not an S3 path with a bucket and key. aborting" | exit 1 | fi - | # copy + | # copy | $awsCmd s3 cp --no-progress "$$s3_path" "$$destination" || | { echo "attempt $$i to copy $$s3_path failed" && sleep $$((7 * "$$i")) && continue; } | # check data integrity - | _check_data_integrity $$destination $$s3_path || + | _check_data_integrity $$destination $$s3_path || | { echo "data content length difference detected in attempt $$i to copy $$local_path failed" && sleep $$((7 * "$$i")) && continue; } | # copy succeeded | break @@ -207,7 +204,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL | # get the multipart chunk size | chunk_size=$$(_get_multipart_chunk_size $$local_path) | local MP_THRESHOLD=${mp_threshold} - | # then set them + | # then set them | $awsCmd configure set default.s3.multipart_threshold $$MP_THRESHOLD | $awsCmd configure set default.s3.multipart_chunksize $$chunk_size | @@ -220,27 +217,27 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL | exit 2 | fi | # if destination is not a bucket : abort - | if ! [[ $$destination =~ s3://([^/]+)/(.+) ]]; then + | if ! [[ $$destination =~ s3://([^/]+)/(.+) ]]; then | echo "$$destination is not an S3 path with a bucket and key. aborting" | exit 1 | fi | # copy ok or try again. | if [[ -d "$$local_path" ]]; then - | # make sure to strip the trailing / in destination + | # make sure to strip the trailing / in destination | destination=$${destination%/} | # glob directory. 
do recursive copy - | $awsCmd s3 cp --no-progress $$local_path $$destination --recursive --exclude "cromwell_glob_control_file" || - | { echo "attempt $$i to copy globDir $$local_path failed" && sleep $$((7 * "$$i")) && continue; } + | $awsCmd s3 cp --no-progress $$local_path $$destination --recursive --exclude "cromwell_glob_control_file" || + | { echo "attempt $$i to copy globDir $$local_path failed" && sleep $$((7 * "$$i")) && continue; } | # check integrity for each of the files | for FILE in $$(cd $$local_path ; ls | grep -v cromwell_glob_control_file); do - | _check_data_integrity $$local_path/$$FILE $$destination/$$FILE || + | _check_data_integrity $$local_path/$$FILE $$destination/$$FILE || | { echo "data content length difference detected in attempt $$i to copy $$local_path/$$FILE failed" && sleep $$((7 * "$$i")) && continue 2; } | done - | else - | $awsCmd s3 cp --no-progress "$$local_path" "$$destination" || - | { echo "attempt $$i to copy $$local_path failed" && sleep $$((7 * "$$i")) && continue; } + | else + | $awsCmd s3 cp --no-progress "$$local_path" "$$destination" || + | { echo "attempt $$i to copy $$local_path failed" && sleep $$((7 * "$$i")) && continue; } | # check content length for data integrity - | _check_data_integrity $$local_path $$destination || + | _check_data_integrity $$local_path $$destination || | { echo "data content length difference detected in attempt $$i to copy $$local_path failed" && sleep $$((7 * "$$i")) && continue; } | fi | # copy succeeded @@ -251,9 +248,9 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL |function _get_multipart_chunk_size() { | local file_path=$$1 | # file size - | file_size=$$(stat --printf="%s" $$file_path) + | file_size=$$(stat --printf="%s" $$file_path) | # chunk_size : you can have at most 10K parts with at least one 5MB part - | # this reflects the formula in s3-copy commands of cromwell (S3FileSystemProvider.java) + | # this reflects the formula in s3-copy commands of cromwell (S3FileSystemProvider.java) | # => long partSize = Math.max((objectSize / 10000L) + 1, 5 * 1024 * 1024); | a=$$(( ( file_size / 10000) + 1 )) | b=$$(( 5 * 1024 * 1024 )) @@ -264,9 +261,9 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL |function _check_data_integrity() { | local local_path=$$1 | local s3_path=$$2 - | + | | # remote : use content_length - | if [[ $$s3_path =~ s3://([^/]+)/(.+) ]]; then + | if [[ $$s3_path =~ s3://([^/]+)/(.+) ]]; then | bucket="$${BASH_REMATCH[1]}" | key="$${BASH_REMATCH[2]}" | else @@ -274,16 +271,16 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL | echo "$$s3_path is not an S3 path with a bucket and key. aborting" | exit 1 | fi - | s3_content_length=$$($awsCmd s3api head-object --bucket "$$bucket" --key "$$key" --query 'ContentLength') || - | { echo "Attempt to get head of object failed for $$s3_path." && return 1 ; } + | s3_content_length=$$($awsCmd s3api head-object --bucket "$$bucket" --key "$$key" --query 'ContentLength') || + | { echo "Attempt to get head of object failed for $$s3_path." && return 1; } | # local - | local_content_length=$$(LC_ALL=C ls -dnL -- "$$local_path" | awk '{print $$5; exit}' ) || - | { echo "Attempt to get local content length failed for $$_local_path." && return 1; } + | local_content_length=$$(LC_ALL=C ls -dnL -- "$$local_path" | awk '{print $$5; exit}' ) || + | { echo "Attempt to get local content length failed for $$_local_path." 
&& return 1; } | # compare | if [[ "$$s3_content_length" -eq "$$local_content_length" ]]; then | true | else - | false + | false | fi |} | @@ -375,8 +372,8 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL s"""_s3_delocalize_with_retry $workDir/${output.local.pathAsString} ${output.s3key}""".stripMargin // file(name (full path), s3key (delocalized path), local (file basename), mount (disk details)) - // files on EFS mounts are optionally delocalized. - case output: AwsBatchFileOutput if efsMntPoint.isDefined && output.mount.mountPoint.pathAsString == efsMntPoint.get => + // files on EFS mounts are optionally delocalized. + case output: AwsBatchFileOutput if efsMntPoint.isDefined && output.mount.mountPoint.pathAsString == efsMntPoint.get => Log.debug("EFS output file detected: "+ output.s3key + s" / ${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}") // EFS located file : test existence or delocalize. var test_cmd = "" @@ -388,25 +385,25 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL // check file for existence test_cmd = s"test -e ${output.mount.mountPoint.pathAsString}/${output.local.pathAsString} || (echo 'output file: ${output.mount.mountPoint.pathAsString}/${output.local.pathAsString} does not exist' && exit 1)" } - // need to make md5sum? + // need to make md5sum? var md5_cmd = "" if (efsMakeMD5.isDefined && efsMakeMD5.getOrElse(false)) { Log.debug("Add cmd to create MD5 sibling.") md5_cmd = s""" - |if [[ ! -f '${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}.md5' ]] ; then - | md5sum '${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}' > '${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}.md5' || (echo 'Could not generate ${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}.md5' && exit 1 ); + |if [[ ! 
-f '${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}.md5' ]] ; then + | md5sum '${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}' > '${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}.md5' || (echo 'Could not generate ${output.mount.mountPoint.pathAsString}/${output.local.pathAsString}.md5' && exit 1 ); |fi - |""".stripMargin + |""".stripMargin } else { Log.debug("MD5 not enabled: "+efsMakeMD5.get.toString()) md5_cmd = "" - } + } // return combined result s""" |${test_cmd} |${md5_cmd} | """.stripMargin - + case output: AwsBatchFileOutput => //output on a different mount Log.debug("output data on other mount") diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchAsyncBackendJobExecutionActorSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchAsyncBackendJobExecutionActorSpec.scala index ba2e5580b3f..98af4f00282 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchAsyncBackendJobExecutionActorSpec.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchAsyncBackendJobExecutionActorSpec.scala @@ -32,7 +32,6 @@ package cromwell.backend.impl.aws import java.util.UUID - import akka.actor.{ActorRef, Props} import akka.testkit.{ImplicitSender, TestActorRef, TestDuration} import common.collections.EnhancedCollections._ @@ -41,7 +40,7 @@ import cromwell.backend._ import cromwell.backend.async.{ExecutionHandle, FailedNonRetryableExecutionHandle, PendingExecutionHandle} import cromwell.backend.impl.aws.AwsBatchAsyncBackendJobExecutionActor.AwsBatchPendingExecutionHandle import cromwell.backend.impl.aws.RunStatus.UnsuccessfulRunStatus -import cromwell.backend.impl.aws.io.AwsBatchWorkingDisk +import cromwell.backend.impl.aws.io.{AwsBatchWorkflowPaths, AwsBatchWorkingDisk} import cromwell.backend.io.JobPathsSpecHelper._ import cromwell.backend.standard.{DefaultStandardAsyncExecutionActorParams, StandardAsyncExecutionActorParams, StandardAsyncJob, StandardExpressionFunctions, StandardExpressionFunctionsParams} import cromwell.cloudsupport.aws.s3.S3Storage diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchCallPathsSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchCallPathsSpec.scala index b250ab12ed4..131114c8921 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchCallPathsSpec.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchCallPathsSpec.scala @@ -33,7 +33,8 @@ package cromwell.backend.impl.aws import common.collections.EnhancedCollections._ import cromwell.backend.BackendSpec -import cromwell.backend.io.JobPathsSpecHelper._ +import cromwell.backend.impl.aws.io.{AwsBatchJobPaths, AwsBatchWorkflowPaths} +import cromwell.backend.io.JobPathsSpecHelper.{EnhancedJobPaths, EnhancedCallContext} import org.scalatest.flatspec.AnyFlatSpecLike import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider import cromwell.core.Tags.AwsTest diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala index 0e7172343cd..3c8e0382974 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala @@ -34,7 +34,7 @@ package cromwell.backend.impl.aws import 
common.collections.EnhancedCollections._ import cromwell.backend.{BackendJobDescriptorKey, BackendWorkflowDescriptor} import cromwell.backend.BackendSpec._ -import cromwell.backend.impl.aws.io.AwsBatchWorkingDisk +import cromwell.backend.impl.aws.io.{AwsBatchJobPaths, AwsBatchWorkflowPaths, AwsBatchWorkingDisk} import cromwell.backend.validation.ContinueOnReturnCodeFlag import cromwell.core.path.DefaultPathBuilder import cromwell.core.TestKitSuite @@ -118,6 +118,8 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi scriptS3BucketName = "script-bucket", awsBatchRetryAttempts = 1, ulimits = Vector(Map.empty[String, String]), + efsDelocalize = false, + efsMakeMD5 = false, fileSystem = "s3") val containerDetail: ContainerDetail = ContainerDetail.builder().exitCode(0).build() @@ -127,21 +129,21 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi val job = AwsBatchJob(null, runtimeAttributes, "commandLine", script, "/cromwell_root/hello-rc.txt", "/cromwell_root/hello-stdout.log", "/cromwell_root/hello-stderr.log", Seq.empty[AwsBatchInput].toSet, Seq.empty[AwsBatchFileOutput].toSet, - jobPaths, Seq.empty[AwsBatchParameter], None, None) + jobPaths, Seq.empty[AwsBatchParameter], None, None, None, None, None, None) job } private def generateBasicJobForLocalFS: AwsBatchJob = { val job = AwsBatchJob(null, runtimeAttributes.copy(fileSystem="local"), "commandLine", script, "/cromwell_root/hello-rc.txt", "/cromwell_root/hello-stdout.log", "/cromwell_root/hello-stderr.log", Seq.empty[AwsBatchInput].toSet, Seq.empty[AwsBatchFileOutput].toSet, - jobPaths, Seq.empty[AwsBatchParameter], None, None) + jobPaths, Seq.empty[AwsBatchParameter], None, None, None, None, None, None) job } private def generateJobWithS3InOut: AwsBatchJob = { val job = AwsBatchJob(null, runtimeAttributes, "commandLine", script, "/cromwell_root/hello-rc.txt", "/cromwell_root/hello-stdout.log", "/cromwell_root/hello-stderr.log", s3Inputs, s3Outputs, - jobPaths, Seq.empty[AwsBatchParameter], None, None) + jobPaths, Seq.empty[AwsBatchParameter], None, None, None, None, None, None) job } @@ -180,109 +182,152 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi it should "add s3 localize with retry function to reconfigured script" in { val job = generateBasicJob - val retryFunctionText = s""" - |export AWS_METADATA_SERVICE_TIMEOUT=10 - |export AWS_METADATA_SERVICE_NUM_ATTEMPTS=10 - | - |function _s3_localize_with_retry() { - | local s3_path=$$1 - | # destination must be the path to a file and not just the directory you want the file in - | local destination=$$2 - | - | for i in {1..6}; - | do - | # abort if tries are exhausted - | if [ "$$i" -eq 6 ]; then - | echo "failed to copy $$s3_path after $$(( $$i - 1 )) attempts. aborting" - | exit 2 - | fi - | # check validity of source path - | if ! [[ $$s3_path =~ s3://([^/]+)/(.+) ]]; then - | echo "$$s3_path is not an S3 path with a bucket and key. 
aborting" - | exit 1 - | fi - | # copy - | /usr/local/aws-cli/v2/current/bin/aws s3 cp --no-progress "$$s3_path" "$$destination" || - | ( echo "attempt $$i to copy $$s3_path failed" sleep $$((7 * "$$i")) && continue) - | # check data integrity - | _check_data_integrity $$destination $$s3_path || - | (echo "data content length difference detected in attempt $$i to copy $$local_path failed" && sleep $$((7 * "$$i")) && continue) - | # copy succeeded - | break - | done - |} - | - |function _s3_delocalize_with_retry() { - | local local_path=$$1 - | # destination must be the path to a file and not just the directory you want the file in - | local destination=$$2 - | - | for i in {1..6}; - | do - | # if tries exceeded : abort - | if [ "$$i" -eq 6 ]; then - | echo "failed to delocalize $$local_path after $$(( $$i - 1 )) attempts. aborting" - | exit 2 - | fi - | # if destination is not a bucket : abort - | if ! [[ $$destination =~ s3://([^/]+)/(.+) ]]; then - | echo "$$destination is not an S3 path with a bucket and key. aborting" - | exit 1 - | fi - | # copy ok or try again. - | if [[ -d "$$local_path" ]]; then - | # make sure to strip the trailing / in destination - | destination=$${destination%/} - | # glob directory. do recursive copy - | /usr/local/aws-cli/v2/current/bin/aws s3 cp --no-progress $$local_path $$destination --recursive --exclude "cromwell_glob_control_file" || - | ( echo "attempt $$i to copy globDir $$local_path failed" && sleep $$((7 * "$$i")) && continue) - | # check integrity for each of the files - | for FILE in $$(cd $$local_path ; ls | grep -v cromwell_glob_control_file); do - | _check_data_integrity $$local_path/$$FILE $$destination/$$FILE || - | ( echo "data content length difference detected in attempt $$i to copy $$local_path/$$FILE failed" && sleep $$((7 * "$$i")) && continue 2) - | done - | else - | /usr/local/aws-cli/v2/current/bin/aws s3 cp --no-progress "$$local_path" "$$destination" || - | ( echo "attempt $$i to copy $$local_path failed" && sleep $$((7 * "$$i")) && continue) - | # check content length for data integrity - | _check_data_integrity $$local_path $$destination || - | ( echo "data content length difference detected in attempt $$i to copy $$local_path failed" && sleep $$((7 * "$$i")) && continue) - | fi - | # copy succeeded - | break - | done - |} - | - |function _check_data_integrity() { - | local local_path=$$1 - | local s3_path=$$2 - | - | # remote : use content_length - | if [[ $$s3_path =~ s3://([^/]+)/(.+) ]]; then - | bucket="$${BASH_REMATCH[1]}" - | key="$${BASH_REMATCH[2]}" - | else - | # this is already checked in the caller function - | echo "$$s3_path is not an S3 path with a bucket and key. aborting" - | exit 1 - | fi - | s3_content_length=$$(/usr/local/aws-cli/v2/current/bin/aws s3api head-object --bucket "$$bucket" --key "$$key" --query 'ContentLength') || - | ( echo "Attempt to get head of object failed for $$s3_path." && return 1 ) - | # local - | local_content_length=$$(LC_ALL=C ls -dn -- "$$local_path" | awk '{print $$5; exit}' ) || - | ( echo "Attempt to get local content length failed for $$_local_path." 
&& return 1 ) - | # compare - | if [[ "$$s3_content_length" -eq "$$local_content_length" ]]; then - | true - | else - | false - | fi - |} - |""".stripMargin + val retryFunctionText = + s""" + |export AWS_METADATA_SERVICE_TIMEOUT=10 + |export AWS_METADATA_SERVICE_NUM_ATTEMPTS=10 + | + |function _s3_localize_with_retry() { + | local s3_path=$$1 + | # destination must be the path to a file and not just the directory you want the file in + | local destination=$$2 + | + | for i in {1..6}; + | do + | # abort if tries are exhausted + | if [ "$$i" -eq 6 ]; then + | echo "failed to copy $$s3_path after $$(( $$i - 1 )) attempts. aborting" + | exit 2 + | fi + | # check validity of source path + | if ! [[ $$s3_path =~ s3://([^/]+)/(.+) ]]; then + | echo "$$s3_path is not an S3 path with a bucket and key. aborting" + | exit 1 + | fi + | # copy + | /usr/local/aws-cli/v2/current/bin/aws s3 cp --no-progress "$$s3_path" "$$destination" || + | { echo "attempt $$i to copy $$s3_path failed" && sleep $$((7 * "$$i")) && continue; } + | # check data integrity + | _check_data_integrity $$destination $$s3_path || + | { echo "data content length difference detected in attempt $$i to copy $$local_path failed" && sleep $$((7 * "$$i")) && continue; } + | # copy succeeded + | break + | done + |}""".stripMargin job.reconfiguredScript should include (retryFunctionText) } + it should "s3 delocalization with retry function in reconfigured script" in { + val job = generateBasicJob + val delocalizeText = s""" + | + |function _s3_delocalize_with_retry() { + | # input variables + | local local_path=$$1 + | # destination must be the path to a file and not just the directory you want the file in + | local destination=$$2 + | + | # get the multipart chunk size + | chunk_size=$$(_get_multipart_chunk_size $$local_path) + | local MP_THRESHOLD=5368709120 + | # then set them + | /usr/local/aws-cli/v2/current/bin/aws configure set default.s3.multipart_threshold $$MP_THRESHOLD + | /usr/local/aws-cli/v2/current/bin/aws configure set default.s3.multipart_chunksize $$chunk_size + | + | # try & validate upload 5 times + | for i in {1..6}; + | do + | # if tries exceeded : abort + | if [ "$$i" -eq 6 ]; then + | echo "failed to delocalize $$local_path after $$(( $$i - 1 )) attempts. aborting" + | exit 2 + | fi + | # if destination is not a bucket : abort + | if ! [[ $$destination =~ s3://([^/]+)/(.+) ]]; then + | echo "$$destination is not an S3 path with a bucket and key. aborting" + | exit 1 + | fi + | # copy ok or try again. + | if [[ -d "$$local_path" ]]; then + | # make sure to strip the trailing / in destination + | destination=$${destination%/} + | # glob directory. 
do recursive copy + | /usr/local/aws-cli/v2/current/bin/aws s3 cp --no-progress $$local_path $$destination --recursive --exclude "cromwell_glob_control_file" || + | { echo "attempt $$i to copy globDir $$local_path failed" && sleep $$((7 * "$$i")) && continue; } + | # check integrity for each of the files + | for FILE in $$(cd $$local_path ; ls | grep -v cromwell_glob_control_file); do + | _check_data_integrity $$local_path/$$FILE $$destination/$$FILE || + | { echo "data content length difference detected in attempt $$i to copy $$local_path/$$FILE failed" && sleep $$((7 * "$$i")) && continue 2; } + | done + | else + | /usr/local/aws-cli/v2/current/bin/aws s3 cp --no-progress "$$local_path" "$$destination" || + | { echo "attempt $$i to copy $$local_path failed" && sleep $$((7 * "$$i")) && continue; } + | # check content length for data integrity + | _check_data_integrity $$local_path $$destination || + | { echo "data content length difference detected in attempt $$i to copy $$local_path failed" && sleep $$((7 * "$$i")) && continue; } + | fi + | # copy succeeded + | break + | done + |}""".stripMargin + job.reconfiguredScript should include (delocalizeText) + } + + it should "generate check data integrity in reconfigured script" in { + val job = generateBasicJob + val checkDataIntegrityBlock = + s""" + |function _check_data_integrity() { + | local local_path=$$1 + | local s3_path=$$2 + | + | # remote : use content_length + | if [[ $$s3_path =~ s3://([^/]+)/(.+) ]]; then + | bucket="$${BASH_REMATCH[1]}" + | key="$${BASH_REMATCH[2]}" + | else + | # this is already checked in the caller function + | echo "$$s3_path is not an S3 path with a bucket and key. aborting" + | exit 1 + | fi + | s3_content_length=$$(/usr/local/aws-cli/v2/current/bin/aws s3api head-object --bucket "$$bucket" --key "$$key" --query 'ContentLength') || + | { echo "Attempt to get head of object failed for $$s3_path." && return 1; } + | # local + | local_content_length=$$(LC_ALL=C ls -dnL -- "$$local_path" | awk '{print $$5; exit}' ) || + | { echo "Attempt to get local content length failed for $$_local_path." && return 1; } + | # compare + | if [[ "$$s3_content_length" -eq "$$local_content_length" ]]; then + | true + | else + | false + | fi + |} + |""".stripMargin + job.reconfiguredScript should include (checkDataIntegrityBlock) + } + + it should "generate get multipart chunk size in script" in { + val job = generateBasicJob + val getMultiplePartChunkSize = + s""" + |function _get_multipart_chunk_size() { + | local file_path=$$1 + | # file size + | file_size=$$(stat --printf="%s" $$file_path) + | # chunk_size : you can have at most 10K parts with at least one 5MB part + | # this reflects the formula in s3-copy commands of cromwell (S3FileSystemProvider.java) + | # => long partSize = Math.max((objectSize / 10000L) + 1, 5 * 1024 * 1024); + | a=$$(( ( file_size / 10000) + 1 )) + | b=$$(( 5 * 1024 * 1024 )) + | chunk_size=$$(( a > b ? 
a : b )) + | echo $$chunk_size + |} + |""".stripMargin + + job.reconfiguredScript should include (getMultiplePartChunkSize) + } + it should "generate postscript with output copy command in reconfigured script" in { val job = generateJobWithS3InOut val postscript = @@ -290,13 +335,10 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi |{ |set -e |echo '*** DELOCALIZING OUTPUTS ***' - | |_s3_delocalize_with_retry /tmp/scratch/baa s3://bucket/somewhere/baa | - | |if [ -f /tmp/scratch/hello-rc.txt ]; then _s3_delocalize_with_retry /tmp/scratch/hello-rc.txt ${job.jobPaths.returnCode} ; fi - | - |if [ -f /tmp/scratch/hello-stderr.log ]; then _s3_delocalize_with_retrys /tmp/scratch/hello-stderr.log ${job.jobPaths.standardPaths.error}; fi + |if [ -f /tmp/scratch/hello-stderr.log ]; then _s3_delocalize_with_retry /tmp/scratch/hello-stderr.log ${job.jobPaths.standardPaths.error}; fi |if [ -f /tmp/scratch/hello-stdout.log ]; then _s3_delocalize_with_retry /tmp/scratch/hello-stdout.log ${job.jobPaths.standardPaths.output}; fi | |echo '*** COMPLETED DELOCALIZATION ***' @@ -309,6 +351,8 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi job.reconfiguredScript should include (postscript) } + + it should "generate preamble with input copy command in reconfigured script" in { val job = generateJobWithS3InOut val preamble = @@ -345,7 +389,4 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi val job = generateBasicJob job.rc(jobDetail) should be (0) } - - - } diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala index 3de9cac621b..7664c9dc748 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala @@ -66,7 +66,10 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout false, "my-stuff", 1, - Vector(Map.empty[String, String])) + Vector(Map.empty[String, String]), + false, + false + ) val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"), @@ -79,6 +82,8 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout "", 1, Vector(Map.empty[String, String]), + false, + false, "local") "AwsBatchRuntimeAttributes" should { diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchWorkflowPathsSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchWorkflowPathsSpec.scala index 2b2d0ea14d8..09a2223c0ac 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchWorkflowPathsSpec.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchWorkflowPathsSpec.scala @@ -33,6 +33,7 @@ package cromwell.backend.impl.aws import common.collections.EnhancedCollections._ import cromwell.backend.BackendSpec +import cromwell.backend.impl.aws.io.AwsBatchWorkflowPaths import org.scalatest.flatspec.AnyFlatSpecLike import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider import cromwell.core.Tags.AwsTest From 0fd2cc87def59a0db5151e74f575e07f642e568a Mon Sep 17 00:00:00 2001 From: xquek <46759870+xquek@users.noreply.github.com> Date: Wed, 1 Nov 2023 02:57:21 -0700 Subject: [PATCH 3/4] return 
bucket directly instead of listing and checking it (#38) Co-authored-by: quekx --- filesystems/s3/src/main/java/org/lerch/s3fs/S3FileStore.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileStore.java b/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileStore.java index 8e15999180a..5de7e89813b 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileStore.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileStore.java @@ -99,10 +99,7 @@ public Bucket getBucket() { } private Bucket getBucket(String bucketName) { - for (Bucket buck : getClient().listBuckets().buckets()) - if (buck.name().equals(bucketName)) - return buck; - return null; + return Bucket.builder().name(bucketName).build(); } private boolean hasBucket(String bucketName) { From 465d8f59e5209dd59ac7610d7217d3eda7a35bc3 Mon Sep 17 00:00:00 2001 From: xquek <46759870+xquek@users.noreply.github.com> Date: Mon, 13 Nov 2023 01:59:13 -0800 Subject: [PATCH 4/4] Add evaluteOnExit for aws batch retry (#40) Co-authored-by: quekx --- .../main/scala/cromwell/backend/backend.scala | 1 + .../backend/impl/aws/AwsBatchAttributes.scala | 1 + .../impl/aws/AwsBatchJobDefinition.scala | 29 +++- .../impl/aws/AwsBatchRuntimeAttributes.scala | 134 +++++++++++++++++- .../scala/cromwell/backend/impl/aws/README.md | 31 ++++ .../backend/impl/aws/AwsBatchJobSpec.scala | 48 ++++++- .../aws/AwsBatchRuntimeAttributesSpec.scala | 106 ++++++++++++-- .../backend/impl/aws/AwsBatchTestConfig.scala | 90 ++++++++++++ 8 files changed, 419 insertions(+), 21 deletions(-) diff --git a/backend/src/main/scala/cromwell/backend/backend.scala b/backend/src/main/scala/cromwell/backend/backend.scala index bf9f67e27a1..2f18f499f30 100644 --- a/backend/src/main/scala/cromwell/backend/backend.scala +++ b/backend/src/main/scala/cromwell/backend/backend.scala @@ -140,6 +140,7 @@ object CommonBackendConfigurationAttributes { "default-runtime-attributes.docker", "default-runtime-attributes.queueArn", "default-runtime-attributes.awsBatchRetryAttempts", + "default-runtime-attributes.awsBatchEvaluateOnExit", "default-runtime-attributes.ulimits", "default-runtime-attributes.efsDelocalize", "default-runtime-attributes.efsMakeMD5", diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala index ecb6c87a457..1119e192262 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala @@ -87,6 +87,7 @@ object AwsBatchAttributes { "numSubmitAttempts", "default-runtime-attributes.scriptBucketName", "awsBatchRetryAttempts", + "awsBatchEvaluateOnExit", "ulimits", "efsDelocalize", "efsMakeMD5", diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala index e1e5af42539..792fd6825d0 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala @@ -34,7 +34,7 @@ package cromwell.backend.impl.aws import scala.collection.mutable.ListBuffer import cromwell.backend.BackendJobDescriptor import cromwell.backend.io.JobPaths -import 
software.amazon.awssdk.services.batch.model.{ContainerProperties, Host, KeyValuePair, MountPoint, ResourceRequirement, ResourceType, RetryStrategy, Ulimit, Volume} +import software.amazon.awssdk.services.batch.model.{ContainerProperties, EvaluateOnExit, Host, KeyValuePair, MountPoint, ResourceRequirement, ResourceType, RetryAction, RetryStrategy, Ulimit, Volume} import cromwell.backend.impl.aws.io.AwsBatchVolume import scala.jdk.CollectionConverters._ @@ -183,9 +183,30 @@ trait AwsBatchJobDefinitionBuilder { def retryStrategyBuilder(context: AwsBatchJobDefinitionContext): (RetryStrategy.Builder, String) = { // We can add here the 'evaluateOnExit' statement - (RetryStrategy.builder() - .attempts(context.runtimeAttributes.awsBatchRetryAttempts), - context.runtimeAttributes.awsBatchRetryAttempts.toString) + var builder = RetryStrategy.builder() + .attempts(context.runtimeAttributes.awsBatchRetryAttempts) + + var evaluations: Seq[EvaluateOnExit] = Seq() + context.runtimeAttributes.awsBatchEvaluateOnExit.foreach( + (evaluate) => { + val evaluateBuilder = evaluate.foldLeft(EvaluateOnExit.builder()) { + case (acc, (k, v)) => (k.toLowerCase, v.toLowerCase) match { + case ("action", "retry") => acc.action(RetryAction.RETRY) + case ("action", "exit") => acc.action(RetryAction.EXIT) + case ("onexitcode", _) => acc.onExitCode(v) + case ("onreason", _) => acc.onReason(v) + case ("onstatusreason", _) => acc.onStatusReason(v) + case _ => acc + } + } + evaluations = evaluations :+ evaluateBuilder.build() + } + ) + + builder = builder.evaluateOnExit(evaluations.asJava) + + (builder, + s"${context.runtimeAttributes.awsBatchRetryAttempts.toString}${context.runtimeAttributes.awsBatchEvaluateOnExit.toString}") } diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala index 235927ab07c..13137718354 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala @@ -44,11 +44,15 @@ import wom.RuntimeAttributesKeys import wom.format.MemorySize import wom.types._ import wom.values._ -import com.typesafe.config.{ConfigException,ConfigValueFactory} +import com.typesafe.config.{ConfigException, ConfigValueFactory} import scala.util.matching.Regex import org.slf4j.{Logger, LoggerFactory} +import scala.util.{Failure, Success, Try} +import scala.jdk.CollectionConverters._ + + /** * Attributes that are provided to the job at runtime * @param cpu number of vCPU @@ -63,6 +67,7 @@ import org.slf4j.{Logger, LoggerFactory} * @param scriptS3BucketName the s3 bucket where the execution command or script will be written and, from there, fetched into the container and executed * @param fileSystem the filesystem type, default is "s3" * @param awsBatchRetryAttempts number of attempts that AWS Batch will retry the task if it fails + * @param awsBatchEvaluateOnExit Evaluate on exit strategy setting for AWS batch retry * @param ulimits ulimit values to be passed to the container * @param efsDelocalize should we delocalize efs files to s3 * @param efsMakeMD5 should we make a sibling md5 file as part of the job @@ -78,6 +83,7 @@ case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive, noAddress: Boolean, scriptS3BucketName: String, awsBatchRetryAttempts: Int, + awsBatchEvaluateOnExit: Vector[Map[String, String]], ulimits: Vector[Map[String, 
String]], efsDelocalize: Boolean, efsMakeMD5 : Boolean, @@ -91,6 +97,10 @@ object AwsBatchRuntimeAttributes { val awsBatchRetryAttemptsKey = "awsBatchRetryAttempts" + val awsBatchEvaluateOnExitKey = "awsBatchEvaluateOnExit" + private val awsBatchEvaluateOnExitDefault = WomArray(WomArrayType(WomMapType(WomStringType,WomStringType)), Vector(WomMap(Map.empty[WomValue, WomValue]))) + + val awsBatchefsDelocalizeKey = "efsDelocalize" val awsBatchefsMakeMD5Key = "efsMakeMD5" @@ -157,6 +167,11 @@ object AwsBatchRuntimeAttributes { .configDefaultWomValue(runtimeConfig).getOrElse(WomInteger(0))) } + def awsBatchEvaluateOnExitValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Vector[Map[String, String]]] = { + AwsBatchEvaluateOnExitValidation + .withDefault(AwsBatchEvaluateOnExitValidation.fromConfig(runtimeConfig).getOrElse(awsBatchEvaluateOnExitDefault)) + } + private def awsBatchefsDelocalizeValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Boolean] = { AwsBatchefsDelocalizeValidation(awsBatchefsDelocalizeKey).withDefault(AwsBatchefsDelocalizeValidation(awsBatchefsDelocalizeKey) .configDefaultWomValue(runtimeConfig).getOrElse(WomBoolean(false))) @@ -199,7 +214,6 @@ object AwsBatchRuntimeAttributes { def runtimeAttributesBuilder(configuration: AwsBatchConfiguration): StandardValidatedRuntimeAttributesBuilder = { val runtimeConfig = aggregateDisksInRuntimeConfig(configuration) - def validationsS3backend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation( cpuValidation(runtimeConfig), cpuMinValidation(runtimeConfig), @@ -212,6 +226,7 @@ object AwsBatchRuntimeAttributes { queueArnValidation(runtimeConfig), scriptS3BucketNameValidation(runtimeConfig), awsBatchRetryAttemptsValidation(runtimeConfig), + awsBatchEvaluateOnExitValidation(runtimeConfig), ulimitsValidation(runtimeConfig), awsBatchefsDelocalizeValidation(runtimeConfig), awsBatchefsMakeMD5Validation(runtimeConfig) @@ -227,6 +242,7 @@ object AwsBatchRuntimeAttributes { dockerValidation, queueArnValidation(runtimeConfig), awsBatchRetryAttemptsValidation(runtimeConfig), + awsBatchEvaluateOnExitValidation(runtimeConfig), ulimitsValidation(runtimeConfig), awsBatchefsDelocalizeValidation(runtimeConfig), awsBatchefsMakeMD5Validation(runtimeConfig) @@ -254,6 +270,8 @@ object AwsBatchRuntimeAttributes { case _ => "" } val awsBatchRetryAttempts: Int = RuntimeAttributesValidation.extract(awsBatchRetryAttemptsValidation(runtimeAttrsConfig), validatedRuntimeAttributes) + val awsBatchEvaluateOnExit: Vector[Map[String, String]] = RuntimeAttributesValidation.extract(awsBatchEvaluateOnExitValidation(runtimeAttrsConfig), validatedRuntimeAttributes) + val ulimits: Vector[Map[String, String]] = RuntimeAttributesValidation.extract(ulimitsValidation(runtimeAttrsConfig), validatedRuntimeAttributes) val efsDelocalize: Boolean = RuntimeAttributesValidation.extract(awsBatchefsDelocalizeValidation(runtimeAttrsConfig),validatedRuntimeAttributes) val efsMakeMD5: Boolean = RuntimeAttributesValidation.extract(awsBatchefsMakeMD5Validation(runtimeAttrsConfig),validatedRuntimeAttributes) @@ -270,6 +288,7 @@ object AwsBatchRuntimeAttributes { noAddress, scriptS3BucketName, awsBatchRetryAttempts, + awsBatchEvaluateOnExit, ulimits, efsDelocalize, efsMakeMD5, @@ -473,6 +492,115 @@ class AwsBatchRetryAttemptsValidation(key: String) extends IntRuntimeAttributesV override protected def missingValueMessage: String = s"Expecting $key runtime attribute to be an Integer" } +object AwsBatchEvaluateOnExitValidation 
+  extends RuntimeAttributesValidation[Vector[Map[String, String]]] {
+
+  val requiredKey = "action"
+  private val acceptedKeys = Set(requiredKey, "onExitCode", "onReason", "onStatusReason")
+
+  def fromConfig(runtimeConfig: Option[Config]): Option[WomValue] = {
+    val config = runtimeConfig match {
+      case Some(value) => Try(value.getObjectList(key)) match {
+        case Failure(_) => None
+        case Success(value) => Some(value.asScala.map {
+          _.unwrapped().asScala.toMap
+        }.toList)
+      }
+      case _ => None
+    }
+
+    config match {
+      case Some(value) => Some(AwsBatchEvaluateOnExitValidation
+        .coercion collectFirst {
+          case womType if womType.coerceRawValue(value).isSuccess => womType.coerceRawValue(value).get
+        } getOrElse {
+          BadDefaultAttribute(WomString(value.toString))
+        })
+      case None => None
+    }
+  }
+
+  override def coercion: Iterable[WomType] = {
+    Set(WomStringType, WomArrayType(WomMapType(WomStringType, WomStringType)))
+  }
+
+  override protected def validateValue: PartialFunction[WomValue, ErrorOr[Vector[Map[String, String]]]] = {
+    case WomArray(womType, value)
+      if womType.memberType == WomMapType(WomStringType, WomStringType) =>
+      check_maps(value.toVector)
+    case WomMap(_, _) => s"Expecting $key runtime attribute to be an Array[Map[String, String]]".invalidNel
+  }
+
+  private def check_maps(
+    maps: Vector[WomValue]
+  ): ErrorOr[Vector[Map[String, String]]] = {
+    val entryNels: Vector[ErrorOr[Map[String, String]]] = maps.map {
+      case WomMap(_, value) => check_keys(value)
+      case _ => s"Expecting each entry of $key runtime attribute to be a Map[String, String]".invalidNel
+    }
+    val sequenced: ErrorOr[Vector[Map[String, String]]] = sequenceNels(
+      entryNels
+    )
+    sequenced
+  }
+
+  private def validateActionKey(dict: Map[WomValue, WomValue]): ErrorOr[Map[String, String]] = {
+    val validCondition = Set("retry", "exit")
+    val convertedMap = dict
+      .map { case (WomString(k), WomString(v)) =>
+        (k, v)
+      }
+    if (convertedMap.exists {
+          case (key, value) => key.toLowerCase == requiredKey && validCondition.contains(value.toLowerCase)
+        }) {
+      convertedMap.validNel
+    } else {
+      s"Missing or invalid $requiredKey key/value for runtime attribute: $key. Refer to https://docs.aws.amazon.com/batch/latest/APIReference/API_RetryStrategy.html".invalidNel
+    }
+  }
+
+  private def check_keys(
+    dict: Map[WomValue, WomValue]
+  ): ErrorOr[Map[String, String]] = {
+    val map_keys = dict.keySet.map(_.valueString.toLowerCase)
+    val unrecognizedKeys =
+      map_keys.diff(acceptedKeys.map(x => x.toLowerCase))
+    if (dict.isEmpty) {
+      Map.empty[String, String].validNel
+    } else if (unrecognizedKeys.nonEmpty) {
+      s"Invalid keys in $key runtime attribute: $unrecognizedKeys. Only $acceptedKeys are accepted. Refer to https://docs.aws.amazon.com/batch/latest/APIReference/API_RetryStrategy.html".invalidNel
+    } else {
+      validateActionKey(dict)
+    }
+  }
+
+  private def sequenceNels(
+    nels: Vector[ErrorOr[Map[String, String]]]
+  ): ErrorOr[Vector[Map[String, String]]] = {
+    val emptyNel: ErrorOr[Vector[Map[String, String]]] =
+      Vector.empty[Map[String, String]].validNel
+    val seqNel: ErrorOr[Vector[Map[String, String]]] =
+      nels.foldLeft(emptyNel) { (acc, v) =>
+        (acc, v) mapN { (a, v) => a :+ v }
+      }
+    seqNel
+  }
+
+  override protected def missingValueMessage: String = s"Expecting $key runtime attribute to be defined"
+
+  /**
+    * Returns the key of the runtime attribute.
+    *
+    * @return The key of the runtime attribute.
+    */
+  override def key: String = AwsBatchRuntimeAttributes.awsBatchEvaluateOnExitKey
+}
+
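// Illustrative example: a WOM value that AwsBatchEvaluateOnExitValidation accepts,
// mirroring the coercion target declared above (the object name below is hypothetical).
// A map without the required "action" key, or with a key outside
// {action, onExitCode, onReason, onStatusReason}, is rejected by
// check_keys/validateActionKey with the messages defined there.
object AwsBatchEvaluateOnExitExamples {
  import wom.types.{WomArrayType, WomMapType, WomStringType}
  import wom.values.{WomArray, WomMap, WomString, WomValue}

  // Two rules: retry on Spot reclamation, exit on any other failure reason.
  val accepted: WomValue = WomArray(
    WomArrayType(WomMapType(WomStringType, WomStringType)),
    Vector(
      WomMap(Map(WomString("action") -> WomString("RETRY"), WomString("onStatusReason") -> WomString("Host EC2*"))),
      WomMap(Map(WomString("action") -> WomString("EXIT"), WomString("onReason") -> WomString("*")))
    )
  )
}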
+ */ + override def key: String = AwsBatchRuntimeAttributes.awsBatchEvaluateOnExitKey +} + object AwsBatchefsDelocalizeValidation { def apply(key: String): AwsBatchefsDelocalizeValidation = new AwsBatchefsDelocalizeValidation(key) } @@ -531,7 +659,7 @@ object UlimitsValidation accepted_keys.diff(map_keys) union map_keys.diff(accepted_keys) if (!dict.nonEmpty){ - Map.empty[String, String].validNel + Map.empty[String, String].validNel }else if (unrecognizedKeys.nonEmpty) { s"Invalid keys in $key runtime attribute. Refer to 'ulimits' section on https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html#containerProperties".invalidNel } else { diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md index 48bc3adc442..f9c3670f67b 100644 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md @@ -75,6 +75,37 @@ runtime { } ``` +### `awsBatchEvaluteOnExit` + +*Default: _[]_* - will always retry + +This runtime attribute sets the `evaluateOnExit` for [*AWS Batch Automated Job Retries*](https://docs.aws.amazon.com/batch/latest/userguide/job_retries.html) and specify the retry condition for a failed job. + +This configuration works with `awsBatchRetryAttempts` and is useful if you only want to retry on certain failures. + +For instance, if you will only like to retry during spot termination. + +``` +runtime { + awsBatchEvaluateOnExit: [ + { + Action: "RETRY", + onStatusReason: "Host EC2*" + }, + { + onReason : "*" + Action: "EXIT" + } + ] +} +``` + +For more information on the batch retry strategy, please refer to: + +* General Doc: [userguide/job_retries.html](https://docs.aws.amazon.com/batch/latest/userguide/job_retries.html) +* Blog: [Introducing retry strategies](https://aws.amazon.com/blogs/compute/introducing-retry-strategies-for-aws-batch/) + + ### `ulimits` *Default: _empty_* diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala index 3c8e0382974..0c202aa8946 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala @@ -32,7 +32,7 @@ package cromwell.backend.impl.aws import common.collections.EnhancedCollections._ -import cromwell.backend.{BackendJobDescriptorKey, BackendWorkflowDescriptor} +import cromwell.backend.{BackendJobDescriptor, BackendJobDescriptorKey, BackendWorkflowDescriptor} import cromwell.backend.BackendSpec._ import cromwell.backend.impl.aws.io.{AwsBatchJobPaths, AwsBatchWorkflowPaths, AwsBatchWorkingDisk} import cromwell.backend.validation.ContinueOnReturnCodeFlag @@ -46,7 +46,7 @@ import org.scalatest.PrivateMethodTester import org.scalatest.flatspec.AnyFlatSpecLike import org.scalatest.matchers.should.Matchers import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider -import software.amazon.awssdk.services.batch.model.{ContainerDetail, JobDetail, KeyValuePair} +import software.amazon.awssdk.services.batch.model.{ContainerDetail, EvaluateOnExit, JobDetail, KeyValuePair, RetryAction, RetryStrategy} import spray.json.{JsObject, JsString} import wdl4s.parser.MemoryUnit import wom.format.MemorySize @@ -100,6 +100,8 @@ class AwsBatchJobSpec extends TestKitSuite with 
 ### `ulimits`
 
 *Default: _empty_*
diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala
index 3c8e0382974..0c202aa8946 100644
--- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala
+++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala
@@ -32,7 +32,7 @@
 package cromwell.backend.impl.aws
 
 import common.collections.EnhancedCollections._
-import cromwell.backend.{BackendJobDescriptorKey, BackendWorkflowDescriptor}
+import cromwell.backend.{BackendJobDescriptor, BackendJobDescriptorKey, BackendWorkflowDescriptor}
 import cromwell.backend.BackendSpec._
 import cromwell.backend.impl.aws.io.{AwsBatchJobPaths, AwsBatchWorkflowPaths, AwsBatchWorkingDisk}
 import cromwell.backend.validation.ContinueOnReturnCodeFlag
@@ -46,7 +46,7 @@ import org.scalatest.PrivateMethodTester
 import org.scalatest.flatspec.AnyFlatSpecLike
 import org.scalatest.matchers.should.Matchers
 import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider
-import software.amazon.awssdk.services.batch.model.{ContainerDetail, JobDetail, KeyValuePair}
+import software.amazon.awssdk.services.batch.model.{ContainerDetail, EvaluateOnExit, JobDetail, KeyValuePair, RetryAction, RetryStrategy}
 import spray.json.{JsObject, JsString}
 import wdl4s.parser.MemoryUnit
 import wom.format.MemorySize
@@ -100,6 +100,8 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
   val call: CommandCallNode = workFlowDescriptor.callable.taskCallNodes.head
   val jobKey: BackendJobDescriptorKey = BackendJobDescriptorKey(call, None, 1)
 
+  val jobDescriptor: BackendJobDescriptor = BackendJobDescriptor(null, null, null, Map.empty, null, null, null)
+
   val jobPaths: AwsBatchJobPaths = AwsBatchJobPaths(workflowPaths, jobKey)
   val s3Inputs: Set[AwsBatchInput] = Set(AwsBatchFileInput("foo", "s3://bucket/foo", DefaultPathBuilder.get("foo"), AwsBatchWorkingDisk()))
   val s3Outputs: Set[AwsBatchFileOutput] = Set(AwsBatchFileOutput("baa", "s3://bucket/somewhere/baa", DefaultPathBuilder.get("baa"), AwsBatchWorkingDisk()))
@@ -117,11 +119,19 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
     noAddress = false,
     scriptS3BucketName = "script-bucket",
     awsBatchRetryAttempts = 1,
+    awsBatchEvaluateOnExit = Vector(Map.empty[String, String]),
     ulimits = Vector(Map.empty[String, String]),
     efsDelocalize = false,
     efsMakeMD5 = false,
     fileSystem = "s3")
 
+  val batchJobDefinition = AwsBatchJobDefinitionContext(
+    runtimeAttributes = runtimeAttributes,
+    commandText = "", dockerRcPath = "", dockerStdoutPath = "", dockerStderrPath = "",
+    jobDescriptor = jobDescriptor, jobPaths = jobPaths, inputs = Set(), outputs = Set(),
+    fsxMntPoint = None, None, None, None
+  )
+
   val containerDetail: ContainerDetail = ContainerDetail.builder().exitCode(0).build()
   val jobDetail: JobDetail = JobDetail.builder().container(containerDetail).build
 
@@ -351,8 +361,6 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
     job.reconfiguredScript should include (postscript)
   }
 
-
-
   it should "generate preamble with input copy command in reconfigured script" in {
     val job = generateJobWithS3InOut
     val preamble =
@@ -389,4 +397,36 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
     val job = generateBasicJob
     job.rc(jobDetail) should be (0)
   }
+
+  it should "use RetryStrategy" in {
+    val runtime = runtimeAttributes.copy(
+      awsBatchEvaluateOnExit = Vector(Map("action" -> "EXIT", "onStatusReason" -> "Failed"))
+    )
+
+    val builder = RetryStrategy.builder().attempts(1).evaluateOnExit(
+      EvaluateOnExit.builder().onStatusReason("Failed").action(RetryAction.EXIT).build()
+    ).build()
+
+    val jobDefinition = StandardAwsBatchJobDefinitionBuilder.build(batchJobDefinition.copy(runtimeAttributes = runtime))
+    val jobDefinitionName = jobDefinition.name
+    val actual = jobDefinition.retryStrategy
+    actual should equal (builder)
+    jobDefinitionName should equal ("cromwell_ubuntu_latest_656d5a7e7cd016d2360b27bc5ee75018d91a777a")
+  }
+
+  it should "treat RetryStrategy evaluateOnExit keys as case insensitive" in {
+    val runtime = runtimeAttributes.copy(
+      awsBatchEvaluateOnExit = Vector(Map("aCtIoN" -> "EXIT", "onStatusReason" -> "Failed"))
+    )
+
+    val builder = RetryStrategy.builder().attempts(1).evaluateOnExit(
+      EvaluateOnExit.builder().onStatusReason("Failed").action(RetryAction.EXIT).build()
+    ).build()
+
+    val jobDefinition = StandardAwsBatchJobDefinitionBuilder.build(batchJobDefinition.copy(runtimeAttributes = runtime))
+    val jobDefinitionName = jobDefinition.name
+    val actual = jobDefinition.retryStrategy
+    actual should equal (builder)
+    jobDefinitionName should equal ("cromwell_ubuntu_latest_66a335d761780e64e6b154339c5f1db2f0783f96")
+  }
 }
diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala 
b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala index 7664c9dc748..944577d229d 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala @@ -48,6 +48,8 @@ import wom.format.MemorySize import wom.types._ import wom.values._ +import scala.util.{Failure, Success, Try} + class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeoutSpec with Matchers { def workflowOptionsWithDefaultRA(defaults: Map[String, JsValue]): WorkflowOptions = { @@ -67,6 +69,7 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout "my-stuff", 1, Vector(Map.empty[String, String]), + Vector(Map.empty[String, String]), false, false ) @@ -82,6 +85,7 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout "", 1, Vector(Map.empty[String, String]), + Vector(Map.empty[String, String]), false, false, "local") @@ -191,9 +195,9 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout assertAwsBatchRuntimeAttributesSuccessfulCreation(runtimeAttributes, expectedRuntimeAttributes) } "validate a valid Filesystem string entry local Filesystem" in { - val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"),"scriptBucketName" -> WomString(""), "filesystem" -> WomString("local")) + val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "scriptBucketName" -> WomString(""), "filesystem" -> WomString("local")) val expectedRuntimeAttributes = expectedDefaultsLocalFS - assertAwsBatchRuntimeAttributesSuccessfulCreation(runtimeAttributes, expectedRuntimeAttributes,WorkflowOptions.fromMap(Map.empty).get, + assertAwsBatchRuntimeAttributesSuccessfulCreation(runtimeAttributes, expectedRuntimeAttributes, WorkflowOptions.fromMap(Map.empty).get, NonEmptyList.of("us-east-1a", "us-east-1b"), new AwsBatchConfiguration(AwsBatchTestConfigForLocalFS.AwsBatchBackendConfigurationDescriptor)) } @@ -307,7 +311,7 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout } "override config default attributes with default attributes declared in workflow options" in { - val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "scriptBucketName" -> WomString("my-stuff") ) + val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "scriptBucketName" -> WomString("my-stuff")) val workflowOptionsJson = """{ @@ -321,7 +325,7 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout } "override config default runtime attributes with task runtime attributes" in { - val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "scriptBucketName" -> WomString("my-stuff"), "cpu" -> WomInteger(4)) + val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "scriptBucketName" -> WomString("my-stuff"), "cpu" -> WomInteger(4)) val workflowOptionsJson = """{ @@ -335,7 +339,7 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout } "override invalid config default attributes with task runtime attributes" in { - val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"),"scriptBucketName" -> WomString("my-stuff"), "cpu" -> WomInteger(4)) + val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "scriptBucketName" -> WomString("my-stuff"), "cpu" -> WomInteger(4)) val workflowOptionsJson = """{ @@ -374,6 +378,89 @@ 
class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout
 val expectedRuntimeAttributes = expectedDefaults.copy(awsBatchRetryAttempts = 0)
 assertAwsBatchRuntimeAttributesSuccessfulCreation(runtimeAttributes, expectedRuntimeAttributes)
 }
+
+    "validate a valid awsBatchEvaluateOnExit" in {
+      val runtimeAttributes = Map(
+        "docker" -> WomString("ubuntu:latest"),
+        "awsBatchRetryAttempts" -> WomInteger(0),
+        "scriptBucketName" -> WomString("my-stuff"),
+        "awsBatchEvaluateOnExit" -> WomArray(
+          Seq(WomMap(Map(WomString("action") -> WomString("RETRY"), WomString("onStatusReason") -> WomString("Host EC2*")))
+          )
+        )
+      )
+
+      assertAwsBatchRuntimeAttributesSuccessfulCreation(runtimeAttributes, expectedDefaults.copy(
+        awsBatchRetryAttempts = 0,
+        awsBatchEvaluateOnExit = Vector(Map("action" -> "RETRY", "onStatusReason" -> "Host EC2*"))
+      ))
+    }
+
+    "if awsBatchEvaluateOnExit is empty, do not fail" in {
+      val runtimeAttributes = Map(
+        "docker" -> WomString("ubuntu:latest"),
+        "awsBatchRetryAttempts" -> WomInteger(0),
+        "scriptBucketName" -> WomString("my-stuff"),
+        "awsBatchEvaluateOnExit" -> WomArray(WomArrayType(WomMapType(WomStringType, WomStringType)), Vector(WomMap(Map.empty[WomValue, WomValue])))
+      )
+      assertAwsBatchRuntimeAttributesSuccessfulCreation(runtimeAttributes, expectedDefaults.copy(
+        awsBatchRetryAttempts = 0
+      ))
+    }
+
+    "a missing or invalid action key results in an invalid awsBatchEvaluateOnExit" in {
+      val invalidEvaluateOnExit = List(
+        // missing action key
+        WomArray(
+          Seq(WomMap(Map(WomString("onStatusReason") -> WomString("Host EC2*")))
+          )
+        ),
+        // invalid action value
+        WomArray(
+          Seq(WomMap(Map(WomString("action") -> WomString("TRYAGAIN"), WomString("onStatusReason") -> WomString("Host EC2*")))
+          )
+        )
+      )
+
+      invalidEvaluateOnExit foreach { invalidVal =>
+        val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "awsBatchEvaluateOnExit" -> invalidVal)
+        assertAwsBatchRuntimeAttributesFailedCreation(runtimeAttributes,
+          "Missing or invalid action key/value for runtime attribute: awsBatchEvaluateOnExit")
+      }
+    }
+  }
+
+  "Unrecognized keys for retry strategy should result in an invalid awsBatchEvaluateOnExit" in {
+    // invalid key
+    val invalidValue = WomArray(
+      Seq(WomMap(Map(WomString("action") -> WomString("RETRY"), WomString("onRandomStatus") -> WomString("Host EC2*")))
+      )
+    )
+    val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "awsBatchEvaluateOnExit" -> invalidValue)
+    assertAwsBatchRuntimeAttributesFailedCreation(runtimeAttributes,
+      s"""Invalid keys in awsBatchEvaluateOnExit runtime attribute: Set(onrandomstatus). 
+ | Only Set(action, onExitCode, onReason, onStatusReason) are accepted.""".stripMargin.replace(
+        "\n", ""))
+  }
+
+  "Config with defined awsBatchEvaluateOnExit works" in {
+    val runtimeAttributes = Map(
+      "docker" -> WomString("ubuntu:latest"), "scriptBucketName" -> WomString("my-stuff"))
+    val batchConfig = new AwsBatchConfiguration(AwsBatchTestWithRetryConfig.AwsBatchBackendConfigurationDescriptor)
+    val workflowOptions = WorkflowOptions.fromMap(Map.empty).get
+    val expectedRuntimeAttributes = expectedDefaults.copy(
+      awsBatchEvaluateOnExit = Vector(
+        Map("Action" -> "RETRY", "onStatusReason" -> "Host EC2*"), Map("Action" -> "EXIT", "onReason" -> "*"))
+    )
+
+    val runtimeAttributesBuilder = AwsBatchRuntimeAttributes.runtimeAttributesBuilder(batchConfig)
+    val defaultedAttributes = RuntimeAttributeDefinition.addDefaultsToAttributes(
+      runtimeAttributesBuilder.definitions.toSet, workflowOptions)(runtimeAttributes)
+
+    val validatedRuntimeAttributes = runtimeAttributesBuilder.build(defaultedAttributes, NOPLogger.NOP_LOGGER)
+    val actualRuntimeAttributes = AwsBatchRuntimeAttributes(
+      validatedRuntimeAttributes, batchConfig.runtimeConfig, batchConfig.fileSystem)
+    assert(actualRuntimeAttributes == expectedRuntimeAttributes)
 }
 
 private def assertAwsBatchRuntimeAttributesSuccessfulCreation(runtimeAttributes: Map[String, WomValue],
@@ -393,11 +480,10 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout
 
 private def assertAwsBatchRuntimeAttributesFailedCreation(runtimeAttributes: Map[String, WomValue],
                                                           exMsg: String,
                                                           workflowOptions: WorkflowOptions = emptyWorkflowOptions): Unit = {
-    try {
-      toAwsBatchRuntimeAttributes(runtimeAttributes, workflowOptions, configuration)
-      fail(s"A RuntimeException was expected with message: $exMsg")
-    } catch {
-      case ex: RuntimeException => assert(ex.getMessage.contains(exMsg))
+
+    Try(toAwsBatchRuntimeAttributes(runtimeAttributes, workflowOptions, configuration)) match {
+      case Failure(exception) => assert(exception.getMessage.contains(exMsg))
+      case Success(_) => fail(s"A RuntimeException was expected with message: $exMsg")
     }
     ()
   }
diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala
index 682714b225c..275f126765b 100644
--- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala
+++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala
@@ -187,6 +187,96 @@ object AwsBatchTestConfigForLocalFS {
 |
 |""".stripMargin
 
+  val AwsBatchBackendConfig = ConfigFactory.parseString(AwsBatchBackendConfigString)
+  val AwsBatchGlobalConfig = ConfigFactory.parseString(AwsBatchGlobalConfigString)
+  val AwsBatchBackendNoDefaultConfig = ConfigFactory.parseString(NoDefaultsConfigString)
+  val AwsBatchBackendConfigurationDescriptor = BackendConfigurationDescriptor(AwsBatchBackendConfig, AwsBatchGlobalConfig)
+  val NoDefaultsConfigurationDescriptor = BackendConfigurationDescriptor(AwsBatchBackendNoDefaultConfig, AwsBatchGlobalConfig)
+}
+
+object AwsBatchTestWithRetryConfig {
+
+  private val AwsBatchBackendConfigString =
+    """
+      |root = "s3://my-cromwell-workflows-bucket"
+      |
+      |filesystems {
+      |  s3 {
+      |    auth = "default"
+      |  }
+      |}
+      |
+      |auth = "default"
+      | numSubmitAttempts = 6
+      | numCreateDefinitionAttempts = 6
+      |
+      |default-runtime-attributes {
+      |    cpu: 1
+      |    failOnStderr: false
+      |    
continueOnReturnCode: 0 + | docker: "ubuntu:latest" + | memory: "2 GB" + | disks: "local-disk" + | noAddress: false + | zones:["us-east-1a", "us-east-1b"] + | queueArn: "arn:aws:batch:us-east-1:111222333444:job-queue/job-queue" + | scriptBucketName: "my-bucket" + | awsBatchRetryAttempts: 1 + | awsBatchEvaluateOnExit: [ + | { + | Action: "RETRY", + | onStatusReason: "Host EC2*" + | }, + | { + | onReason : "*" + | Action: "EXIT" + | } + | ] + |} + | + |""".stripMargin + + private val NoDefaultsConfigString = + """ + |root = "s3://my-cromwell-workflows-bucket" + | + |auth = "default" + | numSubmitAttempts = 6 + | numCreateDefinitionAttempts = 6 + | + |filesystems { + | s3 { + | auth = "default" + | } + |} + |""".stripMargin + + private val AwsBatchGlobalConfigString = + s""" + |aws { + | application-name = "cromwell" + | auths = [ + | { + | name = "default" + | scheme = "default" + | } + | ] + |} + | + |backend { + | default = "AWS" + | providers { + | AWS { + | actor-factory = "cromwell.backend.impl.aws.AwsBatchBackendLifecycleFactory" + | config { + | $AwsBatchBackendConfigString + | } + | } + | } + |} + | + |""".stripMargin + val AwsBatchBackendConfig = ConfigFactory.parseString(AwsBatchBackendConfigString) val AwsBatchGlobalConfig = ConfigFactory.parseString(AwsBatchGlobalConfigString) val AwsBatchBackendNoDefaultConfig = ConfigFactory.parseString(NoDefaultsConfigString)