Skip to content

Commit

Permalink
add support for GPUs on GCP
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelStuchly committed Aug 31, 2021
1 parent a02074e commit 965906b
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 16 deletions.
36 changes: 23 additions & 13 deletions pkg/apis/gcpprovider/v1beta1/gcpmachineproviderconfig_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,26 @@ type GCPMachineProviderSpec struct {
// CredentialsSecret is a reference to the secret with GCP credentials.
CredentialsSecret *corev1.LocalObjectReference `json:"credentialsSecret,omitempty"`

CanIPForward bool `json:"canIPForward"`
DeletionProtection bool `json:"deletionProtection"`
Disks []*GCPDisk `json:"disks,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
Metadata []*GCPMetadata `json:"gcpMetadata,omitempty"`
NetworkInterfaces []*GCPNetworkInterface `json:"networkInterfaces,omitempty"`
ServiceAccounts []GCPServiceAccount `json:"serviceAccounts"`
Tags []string `json:"tags,omitempty"`
TargetPools []string `json:"targetPools,omitempty"`
MachineType string `json:"machineType"`
Region string `json:"region"`
Zone string `json:"zone"`
ProjectID string `json:"projectID,omitempty"`
CanIPForward bool `json:"canIPForward"`
DeletionProtection bool `json:"deletionProtection"`
Disks []*GCPDisk `json:"disks,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
Metadata []*GCPMetadata `json:"gcpMetadata,omitempty"`
NetworkInterfaces []*GCPNetworkInterface `json:"networkInterfaces,omitempty"`
ServiceAccounts []GCPServiceAccount `json:"serviceAccounts"`
Tags []string `json:"tags,omitempty"`
TargetPools []string `json:"targetPools,omitempty"`
MachineType string `json:"machineType"`
Region string `json:"region"`
Zone string `json:"zone"`
ProjectID string `json:"projectID,omitempty"`
GuestAccelerators []*GCPAcceleratorConfig `json:"guestAccelerators,omitempty"`

// Preemptible indicates if created instance is preemptible
Preemptible bool `json:"preemptible,omitempty"`

OnHostMaintenance string `json:"onHostMaintenance,omitempty"`
AutomaticRestart *bool `json:"automaticRestart,omitempty"`
}

// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object
Expand Down Expand Up @@ -104,3 +108,9 @@ type GCPKMSKeyReference struct {
// Location is the GCP location in which the Key Ring exists.
Location string `json:"location"`
}

// GCPAcceleratorConfig describes type and count of accelerator cards attached to the instance on GCP.
type GCPAcceleratorConfig struct {
// AcceleratorCount is the number of accelerator cards of the given type attached to the instance.
AcceleratorCount int64 `json:"acceleratorCount,omitempty"`
// AcceleratorType is the name of the accelerator type, e.g. "nvidia-tesla-t4".
AcceleratorType string `json:"acceleratorType,omitempty"`
}
31 changes: 31 additions & 0 deletions pkg/apis/gcpprovider/v1beta1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 102 additions & 3 deletions pkg/cloud/gcp/actuators/machine/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ package machine
import (
"context"
"fmt"
"time"

"strconv"
"strings"
"time"

"github.com/openshift/cluster-api-provider-gcp/pkg/apis/gcpprovider/v1beta1"
machinev1 "github.com/openshift/machine-api-operator/pkg/apis/machine/v1beta1"
Expand Down Expand Up @@ -39,6 +38,91 @@ func newReconciler(scope *machineScope) *Reconciler {
}
}

// supportedGpuTypes maps every accelerator type name accepted in the provider
// spec to the name of the regional quota metric that is consulted for it.
var supportedGpuTypes = map[string]string{
	"nvidia-tesla-a100": "NVIDIA_A100_GPUS",
	"nvidia-tesla-k80":  "NVIDIA_K80_GPUS",
	"nvidia-tesla-p100": "NVIDIA_P100_GPUS",
	"nvidia-tesla-p4":   "NVIDIA_P4_GPUS",
	"nvidia-tesla-t4":   "NVIDIA_T4_GPUS",
	"nvidia-tesla-v100": "NVIDIA_V100_GPUS",
}

// containsString reports whether str occurs in sli. A nil or empty slice
// never contains anything.
func containsString(sli []string, str string) bool {
	for i := 0; i < len(sli); i++ {
		if sli[i] == str {
			return true
		}
	}
	return false
}

// checkQuota verifies that the requested GPUs are offered in the configured
// zone and that adding them would not exceed the regional GPU quota.
//
// machineFamily is the fixed GPU count of an a2 machine type. When it is
// non-zero, a single "nvidia-tesla-a100" accelerator with that count is
// checked. When it is zero, the accelerators listed in the provider spec are
// checked instead.
//
// It returns an InvalidMachineConfiguration error for an unsupported
// accelerator type, a missing quota metric or an exceeded quota, and a plain
// error when the region or the accelerator type cannot be fetched.
func (r *Reconciler) checkQuota(machineFamily int64) error {
	region, err := r.computeService.RegionGet(r.projectID, r.providerSpec.Region)
	if err != nil {
		return fmt.Errorf("failed to get region via compute service: %w", err)
	}
	// A nil region without an error would otherwise panic on region.Quotas.
	if region == nil {
		return fmt.Errorf("failed to get region via compute service: no region returned for %q", r.providerSpec.Region)
	}

	guestAccelerators := r.providerSpec.GuestAccelerators
	if machineFamily != 0 {
		guestAccelerators = []*v1beta1.GCPAcceleratorConfig{
			{AcceleratorType: "nvidia-tesla-a100", AcceleratorCount: machineFamily},
		}
	}

	// validate zone and then quota
	for _, elem := range guestAccelerators {
		if _, err := r.computeService.AcceleratorTypesList(r.projectID, r.providerSpec.Zone, elem.AcceleratorType); err != nil {
			return fmt.Errorf("AcceleratorType not available in the zone: %w", err)
		}

		metric := supportedGpuTypes[elem.AcceleratorType]
		if metric == "" {
			return machinecontroller.InvalidMachineConfiguration("Unsupported accelerator type")
		}
		// preemptible instances have separate quota
		if r.providerSpec.Preemptible {
			metric = "PREEMPTIBLE_" + metric
		}

		// Track whether the metric was seen at all, so that an empty quota
		// list is reported as "No quota found" instead of passing silently.
		found := false
		for _, q := range region.Quotas {
			if q.Metric != metric {
				continue
			}
			found = true
			if int64(q.Usage)+elem.AcceleratorCount > int64(q.Limit) {
				return machinecontroller.InvalidMachineConfiguration(fmt.Sprintf("Quota exceeded. Metric: %s. Usage: %v. Limit: %v.", metric, q.Usage, q.Limit))
			}
			break
		}
		if !found {
			return machinecontroller.InvalidMachineConfiguration(fmt.Sprintf("No quota found. Metric: %s.", metric))
		}
	}
	return nil
}

// validateGuestAccelerators runs the GPU availability and quota check for
// machine types that can carry GPUs.
//
// For a machine type in the a2 family the GPU count reported by the machine
// type listing is checked. For a machine type in the n1 family the
// accelerators from the provider spec are checked. Any other machine type is
// accepted without a check.
func (r *Reconciler) validateGuestAccelerators() error {
	// Use the reconciler's project ID, as checkQuota does: the ProjectID
	// field of the provider spec is optional and may be empty.
	a2MachineFamily, n1MachineFamily := r.computeService.MachineTypesList(r.projectID, r.providerSpec.Zone, r.Context)
	machineType := r.providerSpec.MachineType

	// a2 family machine - has fixed type and count of GPUs
	if gpuCount := a2MachineFamily[machineType]; gpuCount != 0 {
		return r.checkQuota(gpuCount)
	}
	// n1 family machine
	if containsString(n1MachineFamily, machineType) {
		return r.checkQuota(0)
	}
	return nil
}

// Create creates machine if and only if machine exists, handled by cluster-api
func (r *Reconciler) create() error {
if err := validateMachine(*r.machine, *r.providerSpec); err != nil {
Expand All @@ -56,10 +140,25 @@ func (r *Reconciler) create() error {
Items: r.providerSpec.Tags,
},
Scheduling: &compute.Scheduling{
Preemptible: r.providerSpec.Preemptible,
Preemptible: r.providerSpec.Preemptible,
AutomaticRestart: r.providerSpec.AutomaticRestart,
OnHostMaintenance: r.providerSpec.OnHostMaintenance,
},
}

var guestAccelerators = []*compute.AcceleratorConfig{}
for index, ga := range r.providerSpec.GuestAccelerators {
guestAccelerators = append(guestAccelerators, &compute.AcceleratorConfig{
AcceleratorType: fmt.Sprintf("zones/%s/acceleratorTypes/%s", zone, r.providerSpec.GuestAccelerators[index].AcceleratorType),
AcceleratorCount: ga.AcceleratorCount,
})
}
instance.GuestAccelerators = guestAccelerators

if err := r.validateGuestAccelerators(); err != nil {
return err
}

if instance.Labels == nil {
instance.Labels = map[string]string{}
}
Expand Down
36 changes: 36 additions & 0 deletions pkg/cloud/gcp/actuators/services/compute/computeservice.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package computeservice

import (
"context"
"log"
"strings"

"github.com/openshift/cluster-api-provider-gcp/pkg/cloud/gcp/actuators/util"
"github.com/openshift/cluster-api-provider-gcp/pkg/version"
"google.golang.org/api/compute/v1"
Expand All @@ -19,6 +23,9 @@ type GCPComputeService interface {
TargetPoolsAddInstance(project string, region string, name string, instance string) (*compute.Operation, error)
TargetPoolsRemoveInstance(project string, region string, name string, instance string) (*compute.Operation, error)
MachineTypesGet(project string, machineType string, zone string) (*compute.MachineType, error)
RegionGet(project string, region string) (*compute.Region, error)
MachineTypesList(project string, zone string, ctx context.Context) (map[string]int64, []string)
AcceleratorTypesList(project string, zone string, acceleratorType string) (*compute.AcceleratorType, error)
}

type computeService struct {
Expand Down Expand Up @@ -101,3 +108,32 @@ func (c *computeService) TargetPoolsRemoveInstance(project string, region string
func (c *computeService) MachineTypesGet(project string, zone string, machineType string) (*compute.MachineType, error) {
return c.service.MachineTypes.Get(project, zone, machineType).Do()
}

// MachineTypesList pages through the machine types of the given project and
// zone and picks out the two families handled by the GPU validation.
//
// The returned map holds, for every machine type whose name starts with "a2",
// the guest accelerator count of its first accelerator entry. The returned
// slice holds the names of all machine types whose name starts with "n1".
//
// The signature has no error result, so a failed listing is logged and
// reported as empty results rather than terminating the process.
func (c *computeService) MachineTypesList(project string, zone string, ctx context.Context) (map[string]int64, []string) {
	a2MachineFamily := map[string]int64{}
	var n1MachineFamily []string

	req := c.service.MachineTypes.List(project, zone)
	err := req.Pages(ctx, func(page *compute.MachineTypeList) error {
		for _, machineType := range page.Items {
			switch {
			case strings.HasPrefix(machineType.Name, "a2"):
				// Guard the index: an entry without accelerators must not
				// panic the whole listing.
				if len(machineType.Accelerators) > 0 {
					a2MachineFamily[machineType.Name] = machineType.Accelerators[0].GuestAcceleratorCount
				}
			case strings.HasPrefix(machineType.Name, "n1"):
				n1MachineFamily = append(n1MachineFamily, machineType.Name)
			}
		}
		return nil
	})
	if err != nil {
		// Do not hand back a partially filled view of the zone.
		log.Printf("failed to list machine types in zone %q: %v", zone, err)
		return map[string]int64{}, nil
	}
	return a2MachineFamily, n1MachineFamily
}

// AcceleratorTypesList fetches the named accelerator type from the given
// project and zone. Despite its name it issues a single Get request rather
// than a List request.
func (c *computeService) AcceleratorTypesList(project string, zone string, acceleratorType string) (*compute.AcceleratorType, error) {
	call := c.service.AcceleratorTypes.Get(project, zone, acceleratorType)
	return call.Do()
}

// RegionGet fetches the named region of the given project.
func (c *computeService) RegionGet(project string, region string) (*compute.Region, error) {
	call := c.service.Regions.Get(project, region)
	return call.Do()
}
16 changes: 16 additions & 0 deletions pkg/cloud/gcp/actuators/services/compute/computeservice_mock.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package computeservice

import (
"context"

compute "google.golang.org/api/compute/v1"
"google.golang.org/api/googleapi"
)
Expand Down Expand Up @@ -129,3 +131,17 @@ func MockBuilderFuncTypeNotFound(serviceAccountJSON string) (GCPComputeService,
}
return computeSvc, nil
}

// Mock implementations of the region, machine type and accelerator type lookups.
// RegionGet is a stub: it ignores its arguments and returns a nil region and
// a nil error.
func (c *GCPComputeServiceMock) RegionGet(project string, region string) (*compute.Region, error) {
	var noRegion *compute.Region
	return noRegion, nil
}

// MachineTypesList is a stub: it ignores its arguments and returns a nil map
// and a nil slice.
func (c *GCPComputeServiceMock) MachineTypesList(project string, zone string, ctx context.Context) (map[string]int64, []string) {
	var (
		a2MachineFamily map[string]int64
		n1MachineFamily []string
	)
	return a2MachineFamily, n1MachineFamily
}
// AcceleratorTypesList is a stub: it ignores its arguments and returns a nil
// accelerator type and a nil error.
func (c *GCPComputeServiceMock) AcceleratorTypesList(project string, zone string, acceleratorType string) (*compute.AcceleratorType, error) {
	var noAcceleratorType *compute.AcceleratorType
	return noAcceleratorType, nil
}

// End of the region, machine type and accelerator type mock methods.

0 comments on commit 965906b

Please sign in to comment.