Skip to content

Commit

Permalink
Add cloudprovider to the interface GetNodeClassGVK
Browse files Browse the repository at this point in the history
  • Loading branch information
engedaam committed Mar 28, 2024
1 parent c8eda9b commit 51a4027
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 9 deletions.
5 changes: 5 additions & 0 deletions kwok/cloudprovider/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"

Expand Down Expand Up @@ -118,6 +119,10 @@ func (c CloudProvider) Name() string {
return "kwok"
}

func (c CloudProvider) GetSupportedNodeClass() schema.GroupVersionKind {
return schema.GroupVersionKind{}
}

func (c CloudProvider) getInstanceType(instanceTypeName string) (*cloudprovider.InstanceType, error) {
it, found := lo.Find(c.instanceTypes, func(it *cloudprovider.InstanceType) bool {
return it.Name == instanceTypeName
Expand Down
2 changes: 1 addition & 1 deletion pkg/apis/v1beta1/nodepool.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ type NodePool struct {
// 1. A field changes its default value for an existing field that is already hashed
// 2. A field is added to the hash calculation with an already-set value
// 3. A field is removed from the hash calculations
const NodePoolHashVersion = "v1"
const NodePoolHashVersion = "v2"

func (in *NodePool) Hash() string {
return fmt.Sprint(lo.Must(hashstructure.Hash(in.Spec.Template, hashstructure.FormatV2, &hashstructure.HashOptions{
Expand Down
15 changes: 13 additions & 2 deletions pkg/cloudprovider/fake/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/sets"

"sigs.k8s.io/karpenter/pkg/apis/v1beta1"
Expand All @@ -51,8 +52,9 @@ type CloudProvider struct {
NextCreateErr error
DeleteCalls []*v1beta1.NodeClaim

CreatedNodeClaims map[string]*v1beta1.NodeClaim
Drifted cloudprovider.DriftReason
CreatedNodeClaims map[string]*v1beta1.NodeClaim
Drifted cloudprovider.DriftReason
NodeClassGroupVersionKind schema.GroupVersionKind
}

func NewCloudProvider() *CloudProvider {
Expand All @@ -77,6 +79,11 @@ func (c *CloudProvider) Reset() {
c.NextCreateErr = nil
c.DeleteCalls = []*v1beta1.NodeClaim{}
c.Drifted = "drifted"
c.NodeClassGroupVersionKind = schema.GroupVersionKind{
Group: "",
Version: "",
Kind: "",
}
}

func (c *CloudProvider) Create(ctx context.Context, nodeClaim *v1beta1.NodeClaim) (*v1beta1.NodeClaim, error) {
Expand Down Expand Up @@ -237,3 +244,7 @@ func (c *CloudProvider) IsDrifted(context.Context, *v1beta1.NodeClaim) (cloudpro
func (c *CloudProvider) Name() string {
return "fake"
}

func (c *CloudProvider) GetSupportedNodeClass() schema.GroupVersionKind {
return c.NodeClassGroupVersionKind
}
3 changes: 3 additions & 0 deletions pkg/cloudprovider/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/samber/lo"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime/schema"

"sigs.k8s.io/karpenter/pkg/apis/v1beta1"
"sigs.k8s.io/karpenter/pkg/scheduling"
Expand Down Expand Up @@ -55,6 +56,8 @@ type CloudProvider interface {
IsDrifted(context.Context, *v1beta1.NodeClaim) (DriftReason, error)
// Name returns the CloudProvider implementation name.
Name() string
// GetSupportedNodeClass returns the group, version, and kind of the CloudProvider NodeClass
GetSupportedNodeClass() schema.GroupVersionKind
}

// InstanceType describes the properties of a potential node (either concrete attributes of an instance of this type
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func NewControllers(
disruption.NewController(clock, kubeClient, p, cloudProvider, recorder, cluster, disruptionQueue),
provisioning.NewPodController(kubeClient, p, recorder),
provisioning.NewNodeController(kubeClient, p, recorder),
nodepoolhash.NewController(kubeClient),
nodepoolhash.NewController(kubeClient, cloudProvider),
informer.NewDaemonSetController(kubeClient, cluster),
informer.NewNodeController(kubeClient, cluster),
informer.NewPodController(kubeClient, cluster),
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodeclaim/disruption/drift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ var _ = Describe("Drift", func() {
var nodePoolController controller.Controller
BeforeEach(func() {
cp.Drifted = ""
nodePoolController = hash.NewController(env.Client)
nodePoolController = hash.NewController(env.Client, cp)
nodePool = &v1beta1.NodePool{
ObjectMeta: nodePool.ObjectMeta,
Spec: v1beta1.NodePoolSpec{
Expand Down
21 changes: 18 additions & 3 deletions pkg/controllers/nodepool/hash/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/reconcile"

"sigs.k8s.io/karpenter/pkg/apis/v1beta1"
"sigs.k8s.io/karpenter/pkg/cloudprovider"
operatorcontroller "sigs.k8s.io/karpenter/pkg/operator/controller"
)

Expand All @@ -37,19 +38,33 @@ var _ operatorcontroller.TypedController[*v1beta1.NodePool] = (*Controller)(nil)
// Controller is hash controller that constructs a hash based on the fields that are considered for static drift.
// The hash is placed in the metadata for increased observability and should be found on each object.
type Controller struct {
kubeClient client.Client
kubeClient client.Client
cloudProvider cloudprovider.CloudProvider
}

func NewController(kubeClient client.Client) operatorcontroller.Controller {
func NewController(kubeClient client.Client, cloudProvider cloudprovider.CloudProvider) operatorcontroller.Controller {
return operatorcontroller.Typed[*v1beta1.NodePool](kubeClient, &Controller{
kubeClient: kubeClient,
kubeClient: kubeClient,
cloudProvider: cloudProvider,
})
}

// Reconcile the resource
func (c *Controller) Reconcile(ctx context.Context, np *v1beta1.NodePool) (reconcile.Result, error) {
stored := np.DeepCopy()

// To avoid a breaking change on the NodePool API, we will be setting a default APIVersion and Kind
// defined by the cloudprovider to each nodeClassRef on every NodePool. This will be removed once
// the NodePool API requires APIVersion and Kind to be set at NodePool creation.
// TODO: remove at v1 when APIVersion and Kind are required fields on NodePool
supportedNodeClass := c.cloudProvider.GetSupportedNodeClass()
if np.Spec.Template.Spec.NodeClassRef.APIVersion == "" {
np.Spec.Template.Spec.NodeClassRef.APIVersion = supportedNodeClass.GroupVersion().String()
}
if np.Spec.Template.Spec.NodeClassRef.Kind == "" {
np.Spec.Template.Spec.NodeClassRef.Kind = supportedNodeClass.Kind
}

if np.Annotations[v1beta1.NodePoolHashVersionAnnotationKey] != v1beta1.NodePoolHashVersion {
if err := c.updateNodeClaimHash(ctx, np); err != nil {
return reconcile.Result{}, err
Expand Down
44 changes: 43 additions & 1 deletion pkg/controllers/nodepool/hash/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ import (
"github.com/samber/lo"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/runtime/schema"
. "knative.dev/pkg/logging/testing"
"knative.dev/pkg/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"sigs.k8s.io/karpenter/pkg/apis"
"sigs.k8s.io/karpenter/pkg/cloudprovider/fake"
. "sigs.k8s.io/karpenter/pkg/test/expectations"

"sigs.k8s.io/karpenter/pkg/apis/v1beta1"
Expand All @@ -45,6 +47,7 @@ import (
var nodePoolController controller.Controller
var ctx context.Context
var env *test.Environment
var cp *fake.CloudProvider

func TestAPIs(t *testing.T) {
ctx = TestContextWithLogger(t)
Expand All @@ -54,7 +57,8 @@ func TestAPIs(t *testing.T) {

var _ = BeforeSuite(func() {
env = test.NewEnvironment(scheme.Scheme, test.WithCRDs(apis.CRDs...))
nodePoolController = hash.NewController(env.Client)
cp = fake.NewCloudProvider()
nodePoolController = hash.NewController(env.Client, cp)
})

var _ = AfterSuite(func() {
Expand All @@ -64,6 +68,7 @@ var _ = AfterSuite(func() {
var _ = Describe("Static Drift Hash", func() {
var nodePool *v1beta1.NodePool
BeforeEach(func() {
cp.Reset()
nodePool = test.NodePool(v1beta1.NodePool{
Spec: v1beta1.NodePoolSpec{
Template: v1beta1.NodeClaimTemplate{
Expand Down Expand Up @@ -239,4 +244,41 @@ var _ = Describe("Static Drift Hash", func() {
Expect(nodeClaim.Annotations).To(HaveKeyWithValue(v1beta1.NodePoolHashAnnotationKey, "123456"))
Expect(nodeClaim.Annotations).To(HaveKeyWithValue(v1beta1.NodePoolHashVersionAnnotationKey, v1beta1.NodePoolHashVersion))
})
Context("NodeClassRef Defaulting", func() {
BeforeEach(func() {
cp.NodeClassGroupVersionKind = schema.GroupVersionKind{
Group: "testgroup.sh",
Version: "v1test1",
Kind: "TestNodeClass",
}
})
It("should set a cloudprovider default apiversion on a nodeclassref when apiversion is not set", func() {
nodePool.Spec.Template.Spec.NodeClassRef.APIVersion = ""
ExpectApplied(ctx, env.Client, nodePool)
ExpectReconcileSucceeded(ctx, nodePoolController, client.ObjectKeyFromObject(nodePool))
nodePool = ExpectExists(ctx, env.Client, nodePool)
Expect(nodePool.Spec.Template.Spec.NodeClassRef.APIVersion).To(Equal(cp.NodeClassGroupVersionKind.GroupVersion().String()))
})
It("should not set a cloudprovider default apiversion on a nodeclassref when apiversion is set", func() {
nodePool.Spec.Template.Spec.NodeClassRef.APIVersion = "ExistingAPIVersion"
ExpectApplied(ctx, env.Client, nodePool)
ExpectReconcileSucceeded(ctx, nodePoolController, client.ObjectKeyFromObject(nodePool))
nodePool = ExpectExists(ctx, env.Client, nodePool)
Expect(nodePool.Spec.Template.Spec.NodeClassRef.APIVersion).To(Equal("ExistingAPIVersion"))
})
It("should set a cloudprovider default kind on a nodeclassref when kind is not set", func() {
nodePool.Spec.Template.Spec.NodeClassRef.Kind = ""
ExpectApplied(ctx, env.Client, nodePool)
ExpectReconcileSucceeded(ctx, nodePoolController, client.ObjectKeyFromObject(nodePool))
nodePool = ExpectExists(ctx, env.Client, nodePool)
Expect(nodePool.Spec.Template.Spec.NodeClassRef.Kind).To(Equal(cp.NodeClassGroupVersionKind.Kind))
})
It("should not set a cloudprovider default kind on a nodeclassref when kind is set", func() {
nodePool.Spec.Template.Spec.NodeClassRef.Kind = "ExistingKind"
ExpectApplied(ctx, env.Client, nodePool)
ExpectReconcileSucceeded(ctx, nodePoolController, client.ObjectKeyFromObject(nodePool))
nodePool = ExpectExists(ctx, env.Client, nodePool)
Expect(nodePool.Spec.Template.Spec.NodeClassRef.Kind).To(Equal("ExistingKind"))
})
})
})

0 comments on commit 51a4027

Please sign in to comment.