diff --git a/controllers/nodenetworkconfigurationpolicy_controller.go b/controllers/nodenetworkconfigurationpolicy_controller.go index 1361069bc7..13ef4b8816 100644 --- a/controllers/nodenetworkconfigurationpolicy_controller.go +++ b/controllers/nodenetworkconfigurationpolicy_controller.go @@ -49,7 +49,6 @@ import ( enactmentconditions "github.com/nmstate/kubernetes-nmstate/pkg/enactmentstatus/conditions" "github.com/nmstate/kubernetes-nmstate/pkg/environment" nmstate "github.com/nmstate/kubernetes-nmstate/pkg/helper" - "github.com/nmstate/kubernetes-nmstate/pkg/node" "github.com/nmstate/kubernetes-nmstate/pkg/policyconditions" "github.com/nmstate/kubernetes-nmstate/pkg/selectors" "k8s.io/apimachinery/pkg/types" @@ -160,16 +159,12 @@ func (r *NodeNetworkConfigurationPolicyReconciler) initializeEnactment(policy nm }) } -func (r *NodeNetworkConfigurationPolicyReconciler) maxUnavailableNodeCount(policy *nmstatev1beta1.NodeNetworkConfigurationPolicy) (int, error) { - nmstateNodes, err := node.NodesRunningNmstate(r.Client) - if err != nil { - return 0, err - } +func (r *NodeNetworkConfigurationPolicyReconciler) maxUnavailableNodeCount(policy *nmstatev1beta1.NodeNetworkConfigurationPolicy, matchingNodes int) (int, error) { intOrPercent := intstr.FromString(DEFAULT_MAXUNAVAILABLE) if policy.Spec.MaxUnavailable != nil { intOrPercent = *policy.Spec.MaxUnavailable } - maxUnavailable, err := intstr.GetScaledValueFromIntOrPercent(&intOrPercent, len(nmstateNodes), true) + maxUnavailable, err := intstr.GetScaledValueFromIntOrPercent(&intOrPercent, matchingNodes, true) if err != nil { return 0, err } @@ -190,13 +185,13 @@ func (r *NodeNetworkConfigurationPolicyReconciler) enactmentsCountByPolicy(polic return enactmentCount, nil } -func (r *NodeNetworkConfigurationPolicyReconciler) incrementUnavailableNodeCount(policy *nmstatev1beta1.NodeNetworkConfigurationPolicy) error { +func (r *NodeNetworkConfigurationPolicyReconciler) incrementUnavailableNodeCount(policy *nmstatev1beta1.NodeNetworkConfigurationPolicy, matchingNodes int) error { policyKey := types.NamespacedName{Name: policy.GetName(), Namespace: policy.GetNamespace()} err := r.Client.Get(context.TODO(), policyKey, policy) if err != nil { return err } - maxUnavailable, err := r.maxUnavailableNodeCount(policy) + maxUnavailable, err := r.maxUnavailableNodeCount(policy, matchingNodes) if err != nil { return err } @@ -294,7 +289,7 @@ func (r *NodeNetworkConfigurationPolicyReconciler) Reconcile(ctx context.Context return ctrl.Result{}, nil } - err = r.incrementUnavailableNodeCount(instance) + err = r.incrementUnavailableNodeCount(instance, enactmentCount.Matching()) if err != nil { if apierrors.IsConflict(err) { return ctrl.Result{RequeueAfter: nodeRunningUpdateRetryTime}, err diff --git a/test/e2e/handler/nncp_parallel_test.go b/test/e2e/handler/nncp_parallel_test.go index 220dd760d3..65eb534111 100644 --- a/test/e2e/handler/nncp_parallel_test.go +++ b/test/e2e/handler/nncp_parallel_test.go @@ -26,6 +26,8 @@ func enactmentsInProgress(policy string) int { var _ = Describe("NNCP with maxUnavailable", func() { Context("when applying a policy to matching nodes", func() { + duration := 10 * time.Second + interval := 1 * time.Second BeforeEach(func() { By("Create a policy") updateDesiredState(linuxBrUp(bridge1)) @@ -41,15 +43,13 @@ var _ = Describe("NNCP with maxUnavailable", func() { It("[parallel] should be progressing on multiple nodes", func() { Eventually(func() int { return enactmentsInProgress(TestPolicy) - }).Should(BeNumerically("==", maxUnavailable)) + }, duration, interval).Should(BeNumerically("==", maxUnavailableNodeCount())) waitForAvailablePolicy(TestPolicy) }) It("[parallel] should never exceed maxUnavailable nodes", func() { - duration := 10 * time.Second - interval := 1 * time.Second Consistently(func() int { return enactmentsInProgress(TestPolicy) - }, duration, interval).Should(BeNumerically("<=", maxUnavailable)) + }, duration, interval).Should(BeNumerically("<=", maxUnavailableNodeCount())) waitForAvailablePolicy(TestPolicy) }) }) diff --git a/test/e2e/handler/upgrade_test.go b/test/e2e/handler/upgrade_test.go index 3978adab68..b54d058aa9 100644 --- a/test/e2e/handler/upgrade_test.go +++ b/test/e2e/handler/upgrade_test.go @@ -18,7 +18,7 @@ import ( var _ = Describe("NodeNetworkConfigurationPolicy upgrade", func() { Context("when v1alpha1 is populated", func() { BeforeEach(func() { - maxUnavailableIntOrString := intstr.FromInt(maxUnavailable) + maxUnavailableIntOrString := intstr.FromString(maxUnavailable) policy := nmstatev1alpha1.NodeNetworkConfigurationPolicy{ ObjectMeta: metav1.ObjectMeta{ Name: TestPolicy, diff --git a/test/e2e/handler/utils.go b/test/e2e/handler/utils.go index c430a21916..f6e3b964c1 100644 --- a/test/e2e/handler/utils.go +++ b/test/e2e/handler/utils.go @@ -37,7 +37,7 @@ const TestPolicy = "test-policy" var ( bridgeCounter = 0 bondConunter = 0 - maxUnavailable = environment.GetIntVarWithDefault("NMSTATE_MAX_UNAVAILABLE", 1) + maxUnavailable = environment.GetVarWithDefault("NMSTATE_MAX_UNAVAILABLE", "50%") ) func interfacesName(interfaces []interface{}) []string { @@ -70,7 +70,7 @@ func setDesiredStateWithPolicyAndNodeSelector(name string, desiredState nmstate. err := testenv.Client.Get(context.TODO(), key, &policy) policy.Spec.DesiredState = desiredState policy.Spec.NodeSelector = nodeSelector - maxUnavailableIntOrString := intstr.FromInt(maxUnavailable) + maxUnavailableIntOrString := intstr.FromString(maxUnavailable) policy.Spec.MaxUnavailable = &maxUnavailableIntOrString if err != nil { if apierrors.IsNotFound(err) { @@ -539,3 +539,15 @@ func skipIfNotKubernetes() { Skip("Tutorials use interface naming that is available only on Kubernetes providers") } } + +func maxUnavailableNodeCount() int { + intOrPercent := intstr.FromString(maxUnavailable) + maxUnavailableScaled, err := intstr.GetScaledValueFromIntOrPercent(&intOrPercent, len(nodes), true) + if err != nil { + return 0 + } + if maxUnavailableScaled < 1 { + maxUnavailableScaled = 1 + } + return maxUnavailableScaled +} diff --git a/test/environment/environment.go b/test/environment/environment.go index c743090f12..3537a54cb2 100644 --- a/test/environment/environment.go +++ b/test/environment/environment.go @@ -2,7 +2,6 @@ package environment import ( "os" - "strconv" ) func GetVarWithDefault(name string, defaultValue string) string { @@ -12,12 +11,3 @@ func GetVarWithDefault(name string, defaultValue string) string { } return value } - -func GetIntVarWithDefault(name string, defaultValue int) int { - value := os.Getenv(name) - intValue, err := strconv.Atoi(value) - if err != nil { - intValue = defaultValue - } - return intValue -}