Skip to content

Commit

Permalink
support for neuron
Browse files Browse the repository at this point in the history
  • Loading branch information
jacob committed Sep 17, 2021
1 parent 3839bee commit b3ef776
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pkg/cloudprovider/aws/ami.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (p *AMIProvider) Get(ctx context.Context, constraints *Constraints, instanc
if *constraints.Architecture == v1alpha3.ArchitectureArm64 {
amiNameSuffix = "-arm64"
}
if AnyNvidiaGPUs(instanceTypes) {
if NeedsGPUAmi(instanceTypes) {
if amiNameSuffix != "" {
return "", fmt.Errorf("no amazon-linux-2 ami available for both nvidia gpus and arm64 cpus")
}
Expand Down
18 changes: 16 additions & 2 deletions pkg/cloudprovider/aws/instancetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,29 @@ func (i *InstanceType) NvidiaGPUs() *resource.Quantity {
return resources.Quantity(fmt.Sprint(count))
}

func AnyNvidiaGPUs(is []cloudprovider.InstanceType) bool {
func NeedsGPUAmi(is []cloudprovider.InstanceType) bool {
for _, i := range is {
if !i.NvidiaGPUs().IsZero() {
if !i.NvidiaGPUs().IsZero() || !i.AWSNeurons().IsZero() {
return true
}
}
return false
}

// NeedsDocker returns true if the instance type is unable to use
// conatinerd directly
func NeedsDocker(is []cloudprovider.InstanceType) bool {
for _, i := range is {
// This function can be removed once containerd support for
// Neurons is in the EKS Optimized AMI
if !i.AWSNeurons().IsZero() {
return true
}
}
return false

}

func (i *InstanceType) AMDGPUs() *resource.Quantity {
count := int64(0)
if i.GpuInfo != nil {
Expand Down
13 changes: 9 additions & 4 deletions pkg/cloudprovider/aws/launchtemplate.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (p *LaunchTemplateProvider) Get(ctx context.Context, provisioner *v1alpha3.
}

// 3. Get userData for Node
userData, err := p.getUserData(ctx, provisioner, constraints)
userData, err := p.getUserData(ctx, provisioner, constraints, instanceTypes)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -192,13 +192,18 @@ func (p *LaunchTemplateProvider) getSecurityGroupIds(ctx context.Context, provis
return securityGroupIds, nil
}

func (p *LaunchTemplateProvider) getUserData(ctx context.Context, provisioner *v1alpha3.Provisioner, constraints *Constraints) (string, error) {
func (p *LaunchTemplateProvider) getUserData(ctx context.Context, provisioner *v1alpha3.Provisioner, constraints *Constraints, instanceTypes []cloudprovider.InstanceType) (string, error) {
var containerRuntimeArg string
if !NeedsDocker(instanceTypes) {
containerRuntimeArg = "--container-runtime containerd"
}

var userData bytes.Buffer
userData.WriteString(fmt.Sprintf(`#!/bin/bash
/etc/eks/bootstrap.sh '%s' \
--container-runtime containerd \
/etc/eks/bootstrap.sh '%s' %s \
--apiserver-endpoint '%s'`,
*provisioner.Spec.Cluster.Name,
containerRuntimeArg,
provisioner.Spec.Cluster.Endpoint))

caBundle, err := provisioner.Spec.Cluster.GetCABundle(ctx)
Expand Down

0 comments on commit b3ef776

Please sign in to comment.