diff --git a/kwok/cloudprovider/cloudprovider.go b/kwok/cloudprovider/cloudprovider.go index 47d70bf947..5e100bb916 100644 --- a/kwok/cloudprovider/cloudprovider.go +++ b/kwok/cloudprovider/cloudprovider.go @@ -29,6 +29,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" @@ -118,6 +119,10 @@ func (c CloudProvider) Name() string { return "kwok" } +func (c CloudProvider) GetSupportedNodeClasses() []schema.GroupVersionKind { + return []schema.GroupVersionKind{} +} + func (c CloudProvider) getInstanceType(instanceTypeName string) (*cloudprovider.InstanceType, error) { it, found := lo.Find(c.instanceTypes, func(it *cloudprovider.InstanceType) bool { return it.Name == instanceTypeName diff --git a/pkg/apis/v1beta1/nodepool.go b/pkg/apis/v1beta1/nodepool.go index 79109d586d..68fb3ceb79 100644 --- a/pkg/apis/v1beta1/nodepool.go +++ b/pkg/apis/v1beta1/nodepool.go @@ -194,7 +194,7 @@ type NodePool struct { // 1. A field changes its default value for an existing field that is already hashed // 2. A field is added to the hash calculation with an already-set value // 3. A field is removed from the hash calculations -const NodePoolHashVersion = "v1" +const NodePoolHashVersion = "v2" func (in *NodePool) Hash() string { return fmt.Sprint(lo.Must(hashstructure.Hash(in.Spec.Template, hashstructure.FormatV2, &hashstructure.HashOptions{ diff --git a/pkg/cloudprovider/fake/cloudprovider.go b/pkg/cloudprovider/fake/cloudprovider.go index 9f81588317..2e0ab1155c 100644 --- a/pkg/cloudprovider/fake/cloudprovider.go +++ b/pkg/cloudprovider/fake/cloudprovider.go @@ -27,6 +27,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/sets" "sigs.k8s.io/karpenter/pkg/apis/v1beta1" @@ -51,8 +52,9 @@ type CloudProvider struct { NextCreateErr error DeleteCalls []*v1beta1.NodeClaim - CreatedNodeClaims map[string]*v1beta1.NodeClaim - Drifted cloudprovider.DriftReason + CreatedNodeClaims map[string]*v1beta1.NodeClaim + Drifted cloudprovider.DriftReason + NodeClassGroupVersionKind []schema.GroupVersionKind } func NewCloudProvider() *CloudProvider { @@ -77,6 +79,13 @@ func (c *CloudProvider) Reset() { c.NextCreateErr = nil c.DeleteCalls = []*v1beta1.NodeClaim{} c.Drifted = "drifted" + c.NodeClassGroupVersionKind = []schema.GroupVersionKind{ + { + Group: "", + Version: "", + Kind: "", + }, + } } func (c *CloudProvider) Create(ctx context.Context, nodeClaim *v1beta1.NodeClaim) (*v1beta1.NodeClaim, error) { @@ -237,3 +246,7 @@ func (c *CloudProvider) IsDrifted(context.Context, *v1beta1.NodeClaim) (cloudpro func (c *CloudProvider) Name() string { return "fake" } + +func (c *CloudProvider) GetSupportedNodeClasses() []schema.GroupVersionKind { + return c.NodeClassGroupVersionKind +} diff --git a/pkg/cloudprovider/types.go b/pkg/cloudprovider/types.go index dde363ca32..ee1fa1f496 100644 --- a/pkg/cloudprovider/types.go +++ b/pkg/cloudprovider/types.go @@ -26,6 +26,7 @@ import ( "github.com/samber/lo" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime/schema" "sigs.k8s.io/karpenter/pkg/apis/v1beta1" "sigs.k8s.io/karpenter/pkg/scheduling" @@ -55,6 +56,8 @@ type CloudProvider interface { IsDrifted(context.Context, *v1beta1.NodeClaim) (DriftReason, error) // Name returns the CloudProvider implementation name. Name() string + // GetSupportedNodeClass returns the group, version, and kind of the CloudProvider NodeClass + GetSupportedNodeClasses() []schema.GroupVersionKind } // InstanceType describes the properties of a potential node (either concrete attributes of an instance of this type diff --git a/pkg/controllers/controllers.go b/pkg/controllers/controllers.go index 4f46a32c8c..de6cccb14b 100644 --- a/pkg/controllers/controllers.go +++ b/pkg/controllers/controllers.go @@ -60,7 +60,7 @@ func NewControllers( disruption.NewController(clock, kubeClient, p, cloudProvider, recorder, cluster, disruptionQueue), provisioning.NewPodController(kubeClient, p, recorder), provisioning.NewNodeController(kubeClient, p, recorder), - nodepoolhash.NewController(kubeClient), + nodepoolhash.NewController(kubeClient, cloudProvider), informer.NewDaemonSetController(kubeClient, cluster), informer.NewNodeController(kubeClient, cluster), informer.NewPodController(kubeClient, cluster), diff --git a/pkg/controllers/nodeclaim/disruption/drift_test.go b/pkg/controllers/nodeclaim/disruption/drift_test.go index cea357668b..3420050fce 100644 --- a/pkg/controllers/nodeclaim/disruption/drift_test.go +++ b/pkg/controllers/nodeclaim/disruption/drift_test.go @@ -431,7 +431,7 @@ var _ = Describe("Drift", func() { var nodePoolController controller.Controller BeforeEach(func() { cp.Drifted = "" - nodePoolController = hash.NewController(env.Client) + nodePoolController = hash.NewController(env.Client, cp) nodePool = &v1beta1.NodePool{ ObjectMeta: nodePool.ObjectMeta, Spec: v1beta1.NodePoolSpec{ diff --git a/pkg/controllers/nodepool/hash/controller.go b/pkg/controllers/nodepool/hash/controller.go index cc8ecf60d7..4f1fb568e2 100644 --- a/pkg/controllers/nodepool/hash/controller.go +++ b/pkg/controllers/nodepool/hash/controller.go @@ -29,6 +29,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/karpenter/pkg/apis/v1beta1" + "sigs.k8s.io/karpenter/pkg/cloudprovider" operatorcontroller "sigs.k8s.io/karpenter/pkg/operator/controller" ) @@ -37,12 +38,14 @@ var _ operatorcontroller.TypedController[*v1beta1.NodePool] = (*Controller)(nil) // Controller is hash controller that constructs a hash based on the fields that are considered for static drift. // The hash is placed in the metadata for increased observability and should be found on each object. type Controller struct { - kubeClient client.Client + kubeClient client.Client + cloudProvider cloudprovider.CloudProvider } -func NewController(kubeClient client.Client) operatorcontroller.Controller { +func NewController(kubeClient client.Client, cloudProvider cloudprovider.CloudProvider) operatorcontroller.Controller { return operatorcontroller.Typed[*v1beta1.NodePool](kubeClient, &Controller{ - kubeClient: kubeClient, + kubeClient: kubeClient, + cloudProvider: cloudProvider, }) } @@ -50,6 +53,20 @@ func NewController(kubeClient client.Client) operatorcontroller.Controller { func (c *Controller) Reconcile(ctx context.Context, np *v1beta1.NodePool) (reconcile.Result, error) { stored := np.DeepCopy() + // To avoid a breaking change on the NodePool API, we will be setting a default APIVersion and Kind + // defined by the cloudprovider to each nodeClassRef on every NodePool. This will be removed once + // the NodePool API requires APIVersion and Kind to be set at NodePool creation. + // TODO: remove at v1 when APIVersion and Kind are required fields on NodePool + supportedNodeClass := c.cloudProvider.GetSupportedNodeClasses() + if len(supportedNodeClass) == 1 { + if np.Spec.Template.Spec.NodeClassRef.APIVersion == "" { + np.Spec.Template.Spec.NodeClassRef.APIVersion = supportedNodeClass[0].GroupVersion().String() + } + if np.Spec.Template.Spec.NodeClassRef.Kind == "" { + np.Spec.Template.Spec.NodeClassRef.Kind = supportedNodeClass[0].Kind + } + } + if np.Annotations[v1beta1.NodePoolHashVersionAnnotationKey] != v1beta1.NodePoolHashVersion { if err := c.updateNodeClaimHash(ctx, np); err != nil { return reconcile.Result{}, err diff --git a/pkg/controllers/nodepool/hash/suite_test.go b/pkg/controllers/nodepool/hash/suite_test.go index dcbad90cca..8cade4b7d3 100644 --- a/pkg/controllers/nodepool/hash/suite_test.go +++ b/pkg/controllers/nodepool/hash/suite_test.go @@ -26,6 +26,7 @@ import ( "github.com/samber/lo" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/apimachinery/pkg/runtime/schema" . "knative.dev/pkg/logging/testing" "knative.dev/pkg/ptr" "sigs.k8s.io/controller-runtime/pkg/client" @@ -33,6 +34,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/karpenter/pkg/apis" + "sigs.k8s.io/karpenter/pkg/cloudprovider/fake" . "sigs.k8s.io/karpenter/pkg/test/expectations" "sigs.k8s.io/karpenter/pkg/apis/v1beta1" @@ -45,6 +47,7 @@ import ( var nodePoolController controller.Controller var ctx context.Context var env *test.Environment +var cp *fake.CloudProvider func TestAPIs(t *testing.T) { ctx = TestContextWithLogger(t) @@ -54,7 +57,8 @@ func TestAPIs(t *testing.T) { var _ = BeforeSuite(func() { env = test.NewEnvironment(scheme.Scheme, test.WithCRDs(apis.CRDs...)) - nodePoolController = hash.NewController(env.Client) + cp = fake.NewCloudProvider() + nodePoolController = hash.NewController(env.Client, cp) }) var _ = AfterSuite(func() { @@ -64,6 +68,7 @@ var _ = AfterSuite(func() { var _ = Describe("Static Drift Hash", func() { var nodePool *v1beta1.NodePool BeforeEach(func() { + cp.Reset() nodePool = test.NodePool(v1beta1.NodePool{ Spec: v1beta1.NodePoolSpec{ Template: v1beta1.NodeClaimTemplate{ @@ -239,4 +244,62 @@ var _ = Describe("Static Drift Hash", func() { Expect(nodeClaim.Annotations).To(HaveKeyWithValue(v1beta1.NodePoolHashAnnotationKey, "123456")) Expect(nodeClaim.Annotations).To(HaveKeyWithValue(v1beta1.NodePoolHashVersionAnnotationKey, v1beta1.NodePoolHashVersion)) }) + Context("NodeClassRef Defaulting", func() { + BeforeEach(func() { + cp.NodeClassGroupVersionKind = []schema.GroupVersionKind{ + { + Group: "testgroup.sh", + Version: "v1test1", + Kind: "TestNodeClass", + }, + } + }) + It("should set a cloudprovider default apiversion on a nodeclassref when apiversion is not set", func() { + nodePool.Spec.Template.Spec.NodeClassRef.APIVersion = "" + ExpectApplied(ctx, env.Client, nodePool) + ExpectReconcileSucceeded(ctx, nodePoolController, client.ObjectKeyFromObject(nodePool)) + nodePool = ExpectExists(ctx, env.Client, nodePool) + Expect(nodePool.Spec.Template.Spec.NodeClassRef.APIVersion).To(Equal(cp.NodeClassGroupVersionKind[0].GroupVersion().String())) + }) + It("should not set a cloudprovider default apiversion on a nodeclassref when apiversion is set", func() { + nodePool.Spec.Template.Spec.NodeClassRef.APIVersion = "ExistingAPIVersion" + ExpectApplied(ctx, env.Client, nodePool) + ExpectReconcileSucceeded(ctx, nodePoolController, client.ObjectKeyFromObject(nodePool)) + nodePool = ExpectExists(ctx, env.Client, nodePool) + Expect(nodePool.Spec.Template.Spec.NodeClassRef.APIVersion).To(Equal("ExistingAPIVersion")) + }) + It("should set a cloudprovider default kind on a nodeclassref when kind is not set", func() { + nodePool.Spec.Template.Spec.NodeClassRef.Kind = "" + ExpectApplied(ctx, env.Client, nodePool) + ExpectReconcileSucceeded(ctx, nodePoolController, client.ObjectKeyFromObject(nodePool)) + nodePool = ExpectExists(ctx, env.Client, nodePool) + Expect(nodePool.Spec.Template.Spec.NodeClassRef.Kind).To(Equal(cp.NodeClassGroupVersionKind[0].Kind)) + }) + It("should not set a cloudprovider default kind on a nodeclassref when kind is set", func() { + nodePool.Spec.Template.Spec.NodeClassRef.Kind = "ExistingKind" + ExpectApplied(ctx, env.Client, nodePool) + ExpectReconcileSucceeded(ctx, nodePoolController, client.ObjectKeyFromObject(nodePool)) + nodePool = ExpectExists(ctx, env.Client, nodePool) + Expect(nodePool.Spec.Template.Spec.NodeClassRef.Kind).To(Equal("ExistingKind")) + }) + It("should not set cloudprovider default for kind and apiversion if there is more than one supported node class", func() { + cp.NodeClassGroupVersionKind = []schema.GroupVersionKind{ + { + Group: "testgroup.sh", + Version: "v1test1", + Kind: "TestNodeClass", + }, + { + Group: "testgroup2.sh", + Version: "v1test2", + Kind: "TestNodeClass2", + }, + } + ExpectApplied(ctx, env.Client, nodePool) + ExpectReconcileSucceeded(ctx, nodePoolController, client.ObjectKeyFromObject(nodePool)) + nodePool = ExpectExists(ctx, env.Client, nodePool) + Expect(nodePool.Spec.Template.Spec.NodeClassRef.APIVersion).To(Equal("")) + Expect(nodePool.Spec.Template.Spec.NodeClassRef.Kind).To(Equal("")) + }) + }) })