Skip to content

Commit

Permalink
WX-1318 gcp batch: Add GPU driver install (#7235)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Nichols <[email protected]>
  • Loading branch information
dspeck1 and aednichols authored Oct 13, 2023
1 parent b6aae14 commit aeacb3a
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -829,12 +829,7 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
_ <- evaluateRuntimeAttributes
_ <- uploadScriptFile()
customLabels <- Future.fromTry(GcpLabel.fromWorkflowOptions(workflowDescriptor.workflowOptions))
_ = customLabels.foreach(x => println(s"ZZZ Custom Labels - $x"))
batchParameters <- generateInputOutputParameters
_ = batchParameters.fileInputParameters.foreach(x => println(s"ZZZ File InputParameters - $x"))
_ = batchParameters.jobInputParameters.foreach(x => println(s"ZZZ InputParameters - $x"))
_ = batchParameters.fileOutputParameters.foreach(x => println(s"ZZZ File OutputParameters - $x"))
_ = batchParameters.jobOutputParameters.foreach(x => println(s"ZZZ OutputParameters - $x"))
createParameters = createBatchParameters(batchParameters, customLabels)
drsLocalizationManifestCloudPath = jobPaths.callExecutionRoot / GcpBatchJobPaths.DrsLocalizationManifestName
_ <- uploadDrsLocalizationManifest(createParameters, drsLocalizationManifestCloudPath)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
package cromwell.backend.google.batch.api

import com.google.cloud.batch.v1.AllocationPolicy.Accelerator
import com.google.cloud.batch.v1.{DeleteJobRequest, GetJobRequest, JobName}
import cromwell.backend.google.batch.models.GcpBatchConfigurationAttributes.GcsTransferConfiguration
import cromwell.backend.google.batch.models.GcpBatchRequest
import cromwell.backend.google.batch.runnable._
import cromwell.backend.google.batch.util.BatchUtilityConversions
import com.google.cloud.batch.v1.AllocationPolicy.{AttachedDisk, InstancePolicy, InstancePolicyOrTemplate, LocationPolicy, NetworkInterface, NetworkPolicy, ProvisioningModel}
import com.google.cloud.batch.v1.AllocationPolicy._
import com.google.cloud.batch.v1.LogsPolicy.Destination
import com.google.cloud.batch.v1.{AllocationPolicy, ComputeResource, CreateJobRequest, Job, LogsPolicy, Runnable, ServiceAccount, TaskGroup, TaskSpec, Volume}
import com.google.cloud.batch.v1.{AllocationPolicy, ComputeResource, CreateJobRequest, DeleteJobRequest, GetJobRequest, Job, JobName, LogsPolicy, Runnable, ServiceAccount, TaskGroup, TaskSpec, Volume}
import com.google.protobuf.Duration
import cromwell.backend.google.batch.io.GcpBatchAttachedDisk
import cromwell.backend.google.batch.models.VpcAndSubnetworkProjectLabelValues
import cromwell.backend.google.batch.models.GcpBatchConfigurationAttributes.GcsTransferConfiguration
import cromwell.backend.google.batch.models.{GcpBatchRequest, VpcAndSubnetworkProjectLabelValues}
import cromwell.backend.google.batch.runnable._
import cromwell.backend.google.batch.util.BatchUtilityConversions

import scala.jdk.CollectionConverters._

Expand Down Expand Up @@ -61,10 +58,11 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe
.build
}

private def createInstancePolicy(cpuPlatform: String, spotModel: ProvisioningModel, accelerators: Option[Accelerator.Builder], attachedDisks: List[AttachedDisk]) = {
private def createInstancePolicy(cpuPlatform: String, spotModel: ProvisioningModel, accelerators: Option[Accelerator.Builder], attachedDisks: List[AttachedDisk]): InstancePolicy.Builder = {

//set GPU count to 0 if not included in workflow
val gpuAccelerators = accelerators.getOrElse(Accelerator.newBuilder.setCount(0).setType(""))
val gpuAccelerators = accelerators.getOrElse(Accelerator.newBuilder.setCount(0).setType("")) // TODO: Driver version

val instancePolicy = InstancePolicy
.newBuilder
.setProvisioningModel(spotModel)
Expand All @@ -83,7 +81,6 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe

}


private def createNetworkPolicy(networkInterface: NetworkInterface): NetworkPolicy = {
NetworkPolicy
.newBuilder
Expand Down Expand Up @@ -113,22 +110,29 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe

}

private def createAllocationPolicy(data: GcpBatchRequest, locationPolicy: LocationPolicy, instancePolicy: InstancePolicy, networkPolicy: NetworkPolicy, serviceAccount: ServiceAccount) = {
AllocationPolicy
private def createAllocationPolicy(data: GcpBatchRequest, locationPolicy: LocationPolicy, instancePolicy: InstancePolicy, networkPolicy: NetworkPolicy, serviceAccount: ServiceAccount, accelerators: Option[Accelerator.Builder]) = {

val allocationPolicy = AllocationPolicy
.newBuilder
.setLocation(locationPolicy)
.setNetwork(networkPolicy)
.putLabels("cromwell-workflow-id", toLabel(data.workflowId.toString)) //label for workflow from WDL
.putLabels("goog-batch-worker", "true")
.putAllLabels((data.createParameters.googleLabels.map(label => label.key -> label.value).toMap.asJava))
.setServiceAccount(serviceAccount)
.addInstances(InstancePolicyOrTemplate
.newBuilder
.setPolicy(instancePolicy)
.build)
.build
.buildPartial()

val gpuAccelerators = accelerators.getOrElse(Accelerator.newBuilder.setCount(0).setType(""))

//add GPUs if GPU count is greater than or equal to 1
if (gpuAccelerators.getCount >= 1) {
allocationPolicy.toBuilder.addInstances(InstancePolicyOrTemplate.newBuilder.setPolicy(instancePolicy).setInstallGpuDrivers(true).build)
} else {
allocationPolicy.toBuilder.addInstances(InstancePolicyOrTemplate.newBuilder.setPolicy(instancePolicy).build)
}
}


override def submitRequest(data: GcpBatchRequest): CreateJobRequest = {

val batchAttributes = data.gcpBatchParameters.batchAttributes
Expand Down Expand Up @@ -160,10 +164,6 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe
// Batch defaults to 1 task
val taskCount: Long = 1

println(f"command script container path ${data.createParameters.commandScriptContainerPath}")
println(f"cloud workflow root ${data.createParameters.cloudWorkflowRoot}")
println(f"all parameters:\n ${data.createParameters.allParameters.mkString("\n")}")

// parse preemption value and set value for Spot. Spot is replacement for preemptible
val spotModel = toProvisioningModel(runtimeAttributes.preemptible)

Expand Down Expand Up @@ -205,11 +205,11 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe
val taskGroup: TaskGroup = createTaskGroup(taskCount, taskSpec)
val instancePolicy = createInstancePolicy(cpuPlatform, spotModel, accelerators, allDisks)
val locationPolicy = LocationPolicy.newBuilder.addAllowedLocations(zones).build
val allocationPolicy = createAllocationPolicy(data, locationPolicy, instancePolicy.build, networkPolicy, gcpSa)
val allocationPolicy = createAllocationPolicy(data, locationPolicy, instancePolicy.build, networkPolicy, gcpSa, accelerators)
val job = Job
.newBuilder
.addTaskGroups(taskGroup)
.setAllocationPolicy(allocationPolicy)
.setAllocationPolicy(allocationPolicy.build())
.putLabels("submitter", "cromwell") // label to signify job submitted by cromwell for larger tracking purposes within GCP batch
.putLabels("goog-batch-worker", "true")
.putAllLabels((data.createParameters.googleLabels.map(label => label.key -> label.value).toMap.asJava))
Expand All @@ -218,9 +218,6 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe
.setDestination(Destination.CLOUD_LOGGING)
.build)

println(f"job shell ${data.createParameters.jobShell}")
println(f"script container path ${data.createParameters.commandScriptContainerPath}")
println(f"labels ${data.createParameters.googleLabels}")

CreateJobRequest
.newBuilder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import wom.values.{WomArray, WomBoolean, WomInteger, WomString, WomValue}

object GpuResource {

val DefaultNvidiaDriverVersion = "418.87.00"

final case class GpuType(name: String) {
override def toString: String = name
}
Expand Down Expand Up @@ -99,6 +101,9 @@ object GcpBatchRuntimeAttributes {
private def cpuPlatformValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[String] = cpuPlatformValidationInstance
private def gpuTypeValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[GpuType] = GpuTypeValidation.optional

val GpuDriverVersionKey = "nvidiaDriverVersion"
private def gpuDriverValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[String] = new StringRuntimeAttributesValidation(GpuDriverVersionKey).optional

private def gpuCountValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[Int Refined Positive] = GpuValidation.optional
private def gpuMinValidation(runtimeConfig: Option[Config]):OptionalRuntimeAttributesValidation[Int Refined Positive] = GpuValidation.optionalMin

Expand Down Expand Up @@ -159,6 +164,7 @@ object GcpBatchRuntimeAttributes {
StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation(
gpuCountValidation(runtimeConfig),
gpuTypeValidation(runtimeConfig),
gpuDriverValidation(runtimeConfig),
cpuValidation(runtimeConfig),
cpuPlatformValidation(runtimeConfig),
cpuMinValidation(runtimeConfig),
Expand Down Expand Up @@ -189,8 +195,9 @@ object GcpBatchRuntimeAttributes {
.extractOption(gpuTypeValidation(runtimeAttrsConfig).key, validatedRuntimeAttributes)
lazy val gpuCount: Option[Int Refined Positive] = RuntimeAttributesValidation
.extractOption(gpuCountValidation(runtimeAttrsConfig).key, validatedRuntimeAttributes)
lazy val gpuDriver: Option[String] = RuntimeAttributesValidation.extractOption(gpuDriverValidation(runtimeAttrsConfig).key, validatedRuntimeAttributes)

val gpuResource: Option[GpuResource] = if (gpuType.isDefined || gpuCount.isDefined) {
val gpuResource: Option[GpuResource] = if (gpuType.isDefined || gpuCount.isDefined || gpuDriver.isDefined) {
Option(GpuResource(gpuType.getOrElse(GpuType.DefaultGpuType), gpuCount
.getOrElse(GpuType.DefaultGpuCount)))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ trait UserRunnable {

def userRunnables(createParameters: CreateBatchJobParameters, volumes: List[Volume]): List[Runnable] = {

println(f"job shell ${createParameters.jobShell}")
println(f"script container path ${createParameters.commandScriptContainerPath}")

val userRunnable = RunnableBuilder.userRunnable(
docker = createParameters.dockerImage,
scriptContainerPath = createParameters.commandScriptContainerPath.pathAsString,
Expand Down

0 comments on commit aeacb3a

Please sign in to comment.