Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WX-1318 gcp batch: Add GPU driver install #7235

Merged
merged 3 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -829,12 +829,7 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
_ <- evaluateRuntimeAttributes
_ <- uploadScriptFile()
customLabels <- Future.fromTry(GcpLabel.fromWorkflowOptions(workflowDescriptor.workflowOptions))
_ = customLabels.foreach(x => println(s"ZZZ Custom Labels - $x"))
batchParameters <- generateInputOutputParameters
_ = batchParameters.fileInputParameters.foreach(x => println(s"ZZZ File InputParameters - $x"))
_ = batchParameters.jobInputParameters.foreach(x => println(s"ZZZ InputParameters - $x"))
_ = batchParameters.fileOutputParameters.foreach(x => println(s"ZZZ File OutputParameters - $x"))
_ = batchParameters.jobOutputParameters.foreach(x => println(s"ZZZ OutputParameters - $x"))
createParameters = createBatchParameters(batchParameters, customLabels)
drsLocalizationManifestCloudPath = jobPaths.callExecutionRoot / GcpBatchJobPaths.DrsLocalizationManifestName
_ <- uploadDrsLocalizationManifest(createParameters, drsLocalizationManifestCloudPath)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
package cromwell.backend.google.batch.api

import com.google.cloud.batch.v1.AllocationPolicy.Accelerator
import com.google.cloud.batch.v1.{DeleteJobRequest, GetJobRequest, JobName}
import cromwell.backend.google.batch.models.GcpBatchConfigurationAttributes.GcsTransferConfiguration
import cromwell.backend.google.batch.models.GcpBatchRequest
import cromwell.backend.google.batch.runnable._
import cromwell.backend.google.batch.util.BatchUtilityConversions
import com.google.cloud.batch.v1.AllocationPolicy.{AttachedDisk, InstancePolicy, InstancePolicyOrTemplate, LocationPolicy, NetworkInterface, NetworkPolicy, ProvisioningModel}
import com.google.cloud.batch.v1.AllocationPolicy._
import com.google.cloud.batch.v1.LogsPolicy.Destination
import com.google.cloud.batch.v1.{AllocationPolicy, ComputeResource, CreateJobRequest, Job, LogsPolicy, Runnable, ServiceAccount, TaskGroup, TaskSpec, Volume}
import com.google.cloud.batch.v1.{AllocationPolicy, ComputeResource, CreateJobRequest, DeleteJobRequest, GetJobRequest, Job, JobName, LogsPolicy, Runnable, ServiceAccount, TaskGroup, TaskSpec, Volume}
import com.google.protobuf.Duration
import cromwell.backend.google.batch.io.GcpBatchAttachedDisk
import cromwell.backend.google.batch.models.VpcAndSubnetworkProjectLabelValues
import cromwell.backend.google.batch.models.GcpBatchConfigurationAttributes.GcsTransferConfiguration
import cromwell.backend.google.batch.models.{GcpBatchRequest, VpcAndSubnetworkProjectLabelValues}
import cromwell.backend.google.batch.runnable._
import cromwell.backend.google.batch.util.BatchUtilityConversions

import scala.jdk.CollectionConverters._

Expand Down Expand Up @@ -61,10 +58,11 @@
.build
}

private def createInstancePolicy(cpuPlatform: String, spotModel: ProvisioningModel, accelerators: Option[Accelerator.Builder], attachedDisks: List[AttachedDisk]) = {
private def createInstancePolicy(cpuPlatform: String, spotModel: ProvisioningModel, accelerators: Option[Accelerator.Builder], attachedDisks: List[AttachedDisk]): InstancePolicy.Builder = {

//set GPU count to 0 if not included in workflow
val gpuAccelerators = accelerators.getOrElse(Accelerator.newBuilder.setCount(0).setType(""))
val gpuAccelerators = accelerators.getOrElse(Accelerator.newBuilder.setCount(0).setType("")) // TODO: Driver version

val instancePolicy = InstancePolicy
.newBuilder
.setProvisioningModel(spotModel)
Expand All @@ -83,7 +81,6 @@

}


private def createNetworkPolicy(networkInterface: NetworkInterface): NetworkPolicy = {
NetworkPolicy
.newBuilder
Expand Down Expand Up @@ -113,22 +110,29 @@

}

private def createAllocationPolicy(data: GcpBatchRequest, locationPolicy: LocationPolicy, instancePolicy: InstancePolicy, networkPolicy: NetworkPolicy, serviceAccount: ServiceAccount) = {
AllocationPolicy
private def createAllocationPolicy(data: GcpBatchRequest, locationPolicy: LocationPolicy, instancePolicy: InstancePolicy, networkPolicy: NetworkPolicy, serviceAccount: ServiceAccount, accelerators: Option[Accelerator.Builder]) = {

val allocationPolicy = AllocationPolicy
.newBuilder
.setLocation(locationPolicy)
.setNetwork(networkPolicy)
.putLabels("cromwell-workflow-id", toLabel(data.workflowId.toString)) //label for workflow from WDL
.putLabels("goog-batch-worker", "true")
.putAllLabels((data.createParameters.googleLabels.map(label => label.key -> label.value).toMap.asJava))
.setServiceAccount(serviceAccount)
.addInstances(InstancePolicyOrTemplate
.newBuilder
.setPolicy(instancePolicy)
.build)
.build
.buildPartial()

val gpuAccelerators = accelerators.getOrElse(Accelerator.newBuilder.setCount(0).setType(""))

//request GPU driver installation when the GPU count is greater than or equal to 1; the instance policy is added either way
if (gpuAccelerators.getCount >= 1) {
allocationPolicy.toBuilder.addInstances(InstancePolicyOrTemplate.newBuilder.setPolicy(instancePolicy).setInstallGpuDrivers(true).build)

Check warning on line 129 in supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/api/GcpBatchRequestFactoryImpl.scala

View check run for this annotation

Codecov / codecov/patch

supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/api/GcpBatchRequestFactoryImpl.scala#L129

Added line #L129 was not covered by tests
} else {
allocationPolicy.toBuilder.addInstances(InstancePolicyOrTemplate.newBuilder.setPolicy(instancePolicy).build)
}
}


override def submitRequest(data: GcpBatchRequest): CreateJobRequest = {

val batchAttributes = data.gcpBatchParameters.batchAttributes
Expand Down Expand Up @@ -160,10 +164,6 @@
// Batch defaults to 1 task
val taskCount: Long = 1

println(f"command script container path ${data.createParameters.commandScriptContainerPath}")
println(f"cloud workflow root ${data.createParameters.cloudWorkflowRoot}")
println(f"all parameters:\n ${data.createParameters.allParameters.mkString("\n")}")

// parse the preemptible value and derive the provisioning model from it. Spot is the replacement for preemptible
val spotModel = toProvisioningModel(runtimeAttributes.preemptible)

Expand Down Expand Up @@ -205,11 +205,11 @@
val taskGroup: TaskGroup = createTaskGroup(taskCount, taskSpec)
val instancePolicy = createInstancePolicy(cpuPlatform, spotModel, accelerators, allDisks)
val locationPolicy = LocationPolicy.newBuilder.addAllowedLocations(zones).build
val allocationPolicy = createAllocationPolicy(data, locationPolicy, instancePolicy.build, networkPolicy, gcpSa)
val allocationPolicy = createAllocationPolicy(data, locationPolicy, instancePolicy.build, networkPolicy, gcpSa, accelerators)
val job = Job
.newBuilder
.addTaskGroups(taskGroup)
.setAllocationPolicy(allocationPolicy)
.setAllocationPolicy(allocationPolicy.build())
.putLabels("submitter", "cromwell") // label to signify job submitted by cromwell for larger tracking purposes within GCP batch
.putLabels("goog-batch-worker", "true")
.putAllLabels((data.createParameters.googleLabels.map(label => label.key -> label.value).toMap.asJava))
Expand All @@ -218,9 +218,6 @@
.setDestination(Destination.CLOUD_LOGGING)
.build)

println(f"job shell ${data.createParameters.jobShell}")
println(f"script container path ${data.createParameters.commandScriptContainerPath}")
println(f"labels ${data.createParameters.googleLabels}")

CreateJobRequest
.newBuilder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

object GpuResource {

val DefaultNvidiaDriverVersion = "418.87.00"

Check warning on line 22 in supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/models/GcpBatchRuntimeAttributes.scala

View check run for this annotation

Codecov / codecov/patch

supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/models/GcpBatchRuntimeAttributes.scala#L22

Added line #L22 was not covered by tests

final case class GpuType(name: String) {
override def toString: String = name
}
Expand Down Expand Up @@ -99,6 +101,9 @@
private def cpuPlatformValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[String] = cpuPlatformValidationInstance
private def gpuTypeValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[GpuType] = GpuTypeValidation.optional

val GpuDriverVersionKey = "nvidiaDriverVersion"
private def gpuDriverValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[String] = new StringRuntimeAttributesValidation(GpuDriverVersionKey).optional

private def gpuCountValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[Int Refined Positive] = GpuValidation.optional
private def gpuMinValidation(runtimeConfig: Option[Config]):OptionalRuntimeAttributesValidation[Int Refined Positive] = GpuValidation.optionalMin

Expand Down Expand Up @@ -159,6 +164,7 @@
StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation(
gpuCountValidation(runtimeConfig),
gpuTypeValidation(runtimeConfig),
gpuDriverValidation(runtimeConfig),
cpuValidation(runtimeConfig),
cpuPlatformValidation(runtimeConfig),
cpuMinValidation(runtimeConfig),
Expand Down Expand Up @@ -189,8 +195,9 @@
.extractOption(gpuTypeValidation(runtimeAttrsConfig).key, validatedRuntimeAttributes)
lazy val gpuCount: Option[Int Refined Positive] = RuntimeAttributesValidation
.extractOption(gpuCountValidation(runtimeAttrsConfig).key, validatedRuntimeAttributes)
lazy val gpuDriver: Option[String] = RuntimeAttributesValidation.extractOption(gpuDriverValidation(runtimeAttrsConfig).key, validatedRuntimeAttributes)

val gpuResource: Option[GpuResource] = if (gpuType.isDefined || gpuCount.isDefined) {
val gpuResource: Option[GpuResource] = if (gpuType.isDefined || gpuCount.isDefined || gpuDriver.isDefined) {
Option(GpuResource(gpuType.getOrElse(GpuType.DefaultGpuType), gpuCount
.getOrElse(GpuType.DefaultGpuCount)))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ trait UserRunnable {

def userRunnables(createParameters: CreateBatchJobParameters, volumes: List[Volume]): List[Runnable] = {

println(f"job shell ${createParameters.jobShell}")
println(f"script container path ${createParameters.commandScriptContainerPath}")

val userRunnable = RunnableBuilder.userRunnable(
docker = createParameters.dockerImage,
scriptContainerPath = createParameters.commandScriptContainerPath.pathAsString,
Expand Down
Loading