From b3ef776eaa0fbec3349d60893562f40910809159 Mon Sep 17 00:00:00 2001 From: jacob Date: Fri, 17 Sep 2021 13:53:07 -0700 Subject: [PATCH] support for neuron --- pkg/cloudprovider/aws/ami.go | 2 +- pkg/cloudprovider/aws/instancetype.go | 18 ++++++++++++++++-- pkg/cloudprovider/aws/launchtemplate.go | 13 +++++++++---- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/pkg/cloudprovider/aws/ami.go b/pkg/cloudprovider/aws/ami.go index 13497c57e7ee..285976eb252d 100644 --- a/pkg/cloudprovider/aws/ami.go +++ b/pkg/cloudprovider/aws/ami.go @@ -54,7 +54,7 @@ func (p *AMIProvider) Get(ctx context.Context, constraints *Constraints, instanc if *constraints.Architecture == v1alpha3.ArchitectureArm64 { amiNameSuffix = "-arm64" } - if AnyNvidiaGPUs(instanceTypes) { + if NeedsGPUAmi(instanceTypes) { if amiNameSuffix != "" { return "", fmt.Errorf("no amazon-linux-2 ami available for both nvidia gpus and arm64 cpus") } diff --git a/pkg/cloudprovider/aws/instancetype.go b/pkg/cloudprovider/aws/instancetype.go index cf0f6ee94fe9..f08d01778f56 100644 --- a/pkg/cloudprovider/aws/instancetype.go +++ b/pkg/cloudprovider/aws/instancetype.go @@ -84,15 +84,29 @@ func (i *InstanceType) NvidiaGPUs() *resource.Quantity { return resources.Quantity(fmt.Sprint(count)) } -func AnyNvidiaGPUs(is []cloudprovider.InstanceType) bool { +func NeedsGPUAmi(is []cloudprovider.InstanceType) bool { for _, i := range is { - if !i.NvidiaGPUs().IsZero() { + if !i.NvidiaGPUs().IsZero() || !i.AWSNeurons().IsZero() { return true } } return false } +// NeedsDocker returns true if the instance type is unable to use +// containerd directly +func NeedsDocker(is []cloudprovider.InstanceType) bool { + for _, i := range is { + // This function can be removed once containerd support for + // Neurons is in the EKS Optimized AMI + if !i.AWSNeurons().IsZero() { + return true + } + } + return false + +} + func (i *InstanceType) AMDGPUs() *resource.Quantity { count := int64(0) if i.GpuInfo != nil { diff --git 
a/pkg/cloudprovider/aws/launchtemplate.go b/pkg/cloudprovider/aws/launchtemplate.go index c74c4303afad..f0280d71048d 100644 --- a/pkg/cloudprovider/aws/launchtemplate.go +++ b/pkg/cloudprovider/aws/launchtemplate.go @@ -93,7 +93,7 @@ func (p *LaunchTemplateProvider) Get(ctx context.Context, provisioner *v1alpha3. } // 3. Get userData for Node - userData, err := p.getUserData(ctx, provisioner, constraints) + userData, err := p.getUserData(ctx, provisioner, constraints, instanceTypes) if err != nil { return nil, err } @@ -192,13 +192,18 @@ func (p *LaunchTemplateProvider) getSecurityGroupIds(ctx context.Context, provis return securityGroupIds, nil } -func (p *LaunchTemplateProvider) getUserData(ctx context.Context, provisioner *v1alpha3.Provisioner, constraints *Constraints) (string, error) { +func (p *LaunchTemplateProvider) getUserData(ctx context.Context, provisioner *v1alpha3.Provisioner, constraints *Constraints, instanceTypes []cloudprovider.InstanceType) (string, error) { + var containerRuntimeArg string + if !NeedsDocker(instanceTypes) { + containerRuntimeArg = "--container-runtime containerd" + } + var userData bytes.Buffer userData.WriteString(fmt.Sprintf(`#!/bin/bash -/etc/eks/bootstrap.sh '%s' \ - --container-runtime containerd \ +/etc/eks/bootstrap.sh '%s' %s \ --apiserver-endpoint '%s'`, *provisioner.Spec.Cluster.Name, + containerRuntimeArg, provisioner.Spec.Cluster.Endpoint)) caBundle, err := provisioner.Spec.Cluster.GetCABundle(ctx)