From 472c27c1651408cf7fd0c296e2f3820471865b9d Mon Sep 17 00:00:00 2001
From: kpharasi <kartik.230@gmail.com>
Date: Tue, 10 Dec 2024 15:33:10 -0800
Subject: [PATCH] Adding tests

---
 .../pkg/controller/admiral/deployment_test.go | 116 ++++++++++++++++++
 .../pkg/controller/admiral/rollouts_test.go   | 116 ++++++++++++++++++
 2 files changed, 232 insertions(+)

diff --git a/admiral/pkg/controller/admiral/deployment_test.go b/admiral/pkg/controller/admiral/deployment_test.go
index 115518725..20c0cf40d 100644
--- a/admiral/pkg/controller/admiral/deployment_test.go
+++ b/admiral/pkg/controller/admiral/deployment_test.go
@@ -323,6 +323,122 @@ func TestDeploymentControlle_DoesGenerationMatch(t *testing.T) {
 
 }
 
+func TestDeploymentController_IsOnlyReplicaCountChanged(t *testing.T) {
+	dc := DeploymentController{}
+	replicaNewCount := int32(1)
+	replicaOldCount := int32(2)
+
+	admiralParams := common.AdmiralParams{}
+
+	testCases := []struct {
+		name                                 string
+		deploymentNew                        interface{}
+		deploymentOld                        interface{}
+		enableIsOnlyReplicaCountChangedCheck bool
+		expectedValue                        bool
+		expectedError                        error
+	}{
+		{
+			name: "Given context, new deploy and old deploy object " +
+				"When new deploy is not of type *v1.Deployment " +
+				"Then func should return an error",
+			deploymentNew:                        struct{}{},
+			deploymentOld:                        struct{}{},
+			enableIsOnlyReplicaCountChangedCheck: true,
+			expectedError:                        fmt.Errorf("type assertion failed, {} is not of type *v1.Deployment"),
+		},
+		{
+			name: "Given context, new deploy and old deploy object " +
+				"When old deploy is not of type *v1.Deployment  " +
+				"Then func should return an error",
+			deploymentNew:                        struct{}{},
+			deploymentOld:                        struct{}{},
+			enableIsOnlyReplicaCountChangedCheck: true,
+			expectedError:                        fmt.Errorf("type assertion failed, {} is not of type *v1.Deployment"),
+		},
+		{
+			name: "Given context, new deploy and old deploy object " +
+				"When is replica count changed check is enabled " +
+				"And everything in the spec expect the count is the same " +
+				"Then func should return true ",
+			deploymentNew: &k8sAppsV1.Deployment{
+				Spec: k8sAppsV1.DeploymentSpec{
+					Replicas: &replicaNewCount,
+				},
+			},
+			deploymentOld: &k8sAppsV1.Deployment{
+				Spec: k8sAppsV1.DeploymentSpec{
+					Replicas: &replicaOldCount,
+				},
+			},
+			expectedValue:                        true,
+			expectedError:                        nil,
+			enableIsOnlyReplicaCountChangedCheck: true,
+		},
+		{
+			name: "Given context, new deploy and old deploy object " +
+				"When deploy is replica count changed check is disabled " +
+				"Then func should return false",
+			deploymentNew: &k8sAppsV1.Deployment{
+				Spec: k8sAppsV1.DeploymentSpec{
+					Replicas: &replicaNewCount,
+				},
+			},
+			deploymentOld: &k8sAppsV1.Deployment{
+				Spec: k8sAppsV1.DeploymentSpec{
+					Replicas: &replicaOldCount,
+				},
+			},
+			expectedValue:                        false,
+			expectedError:                        nil,
+			enableIsOnlyReplicaCountChangedCheck: false,
+		},
+		{
+			name: "Given context, new deploy and old deploy object " +
+				"When is replica count changed check is enabled " +
+				"And something in the spec expect the count is different " +
+				"Then func should return false ",
+			deploymentNew: &k8sAppsV1.Deployment{
+				Spec: k8sAppsV1.DeploymentSpec{
+					Replicas: &replicaNewCount,
+					Paused:   false,
+				},
+			},
+			deploymentOld: &k8sAppsV1.Deployment{
+				Spec: k8sAppsV1.DeploymentSpec{
+					Replicas: &replicaOldCount,
+					Paused:   true,
+				},
+			},
+			expectedValue:                        false,
+			enableIsOnlyReplicaCountChangedCheck: true,
+			expectedError:                        nil,
+		},
+	}
+
+	ctxLogger := log.WithFields(log.Fields{
+		"txId": "abc",
+	})
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			admiralParams.EnableIsOnlyReplicaCountChangedCheck = tc.enableIsOnlyReplicaCountChangedCheck
+			common.ResetSync()
+			common.InitializeConfig(admiralParams)
+			actual, err := dc.IsOnlyReplicaCountChanged(ctxLogger, tc.deploymentNew, tc.deploymentOld)
+			if !ErrorEqualOrSimilar(err, tc.expectedError) {
+				t.Errorf("expected: %v, got: %v", tc.expectedError, err)
+			}
+			if err == nil {
+				if tc.expectedValue != actual {
+					t.Errorf("expected: %v, got: %v", tc.expectedValue, actual)
+				}
+			}
+		})
+	}
+
+}
+
 func TestNewDeploymentController(t *testing.T) {
 	config, err := clientcmd.BuildConfigFromFlags("", "../../test/resources/admins@fake-cluster.k8s.local")
 	if err != nil {
diff --git a/admiral/pkg/controller/admiral/rollouts_test.go b/admiral/pkg/controller/admiral/rollouts_test.go
index 5ab8fa1e8..263908db5 100644
--- a/admiral/pkg/controller/admiral/rollouts_test.go
+++ b/admiral/pkg/controller/admiral/rollouts_test.go
@@ -148,6 +148,122 @@ func TestRolloutController_DoesGenerationMatch(t *testing.T) {
 
 }
 
+func TestRolloutController_IsOnlyReplicaCountChanged(t *testing.T) {
+	rc := RolloutController{}
+	replicaNewCount := int32(1)
+	replicaOldCount := int32(2)
+
+	admiralParams := common.AdmiralParams{}
+
+	testCases := []struct {
+		name                                 string
+		rolloutNew                           interface{}
+		rolloutOld                           interface{}
+		enableIsOnlyReplicaCountChangedCheck bool
+		expectedValue                        bool
+		expectedError                        error
+	}{
+		{
+			name: "Given context, new rollout and old rollout object " +
+				"When new rollout is not of type *argo.Rollout " +
+				"Then func should return an error",
+			rolloutNew:                           struct{}{},
+			rolloutOld:                           struct{}{},
+			enableIsOnlyReplicaCountChangedCheck: true,
+			expectedError:                        fmt.Errorf("type assertion failed, {} is not of type *argo.Rollout"),
+		},
+		{
+			name: "Given context, new rollout and old rollout object " +
+				"When old rollout is not of type *argo.Rollout " +
+				"Then func should return an error",
+			rolloutNew:                           &argo.Rollout{},
+			rolloutOld:                           struct{}{},
+			enableIsOnlyReplicaCountChangedCheck: true,
+			expectedError:                        fmt.Errorf("type assertion failed, {} is not of type *argo.Rollout"),
+		},
+		{
+			name: "Given context, new rollout and old rollout object " +
+				"When is replica count changed check is enabled " +
+				"And everything in the spec expect the count is the same " +
+				"Then func should return true ",
+			rolloutNew: &argo.Rollout{
+				Spec: argo.RolloutSpec{
+					Replicas: &replicaNewCount,
+				},
+			},
+			rolloutOld: &argo.Rollout{
+				Spec: argo.RolloutSpec{
+					Replicas: &replicaOldCount,
+				},
+			},
+			expectedValue:                        true,
+			expectedError:                        nil,
+			enableIsOnlyReplicaCountChangedCheck: true,
+		},
+		{
+			name: "Given context, new rollout and old rollout object " +
+				"When rollout is replica count changed check is disabled " +
+				"Then func should return false",
+			rolloutNew: &argo.Rollout{
+				Spec: argo.RolloutSpec{
+					Replicas: &replicaNewCount,
+				},
+			},
+			rolloutOld: &argo.Rollout{
+				Spec: argo.RolloutSpec{
+					Replicas: &replicaOldCount,
+				},
+			},
+			expectedValue:                        false,
+			expectedError:                        nil,
+			enableIsOnlyReplicaCountChangedCheck: false,
+		},
+		{
+			name: "Given context, new rollout and old rollout object " +
+				"When is replica count changed check is enabled " +
+				"And something in the spec expect the count is different " +
+				"Then func should return false ",
+			rolloutNew: &argo.Rollout{
+				Spec: argo.RolloutSpec{
+					Replicas: &replicaNewCount,
+					Paused:   false,
+				},
+			},
+			rolloutOld: &argo.Rollout{
+				Spec: argo.RolloutSpec{
+					Replicas: &replicaOldCount,
+					Paused:   true,
+				},
+			},
+			expectedValue:                        false,
+			enableIsOnlyReplicaCountChangedCheck: true,
+			expectedError:                        nil,
+		},
+	}
+
+	ctxLogger := log.WithFields(log.Fields{
+		"txId": "abc",
+	})
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			admiralParams.EnableIsOnlyReplicaCountChangedCheck = tc.enableIsOnlyReplicaCountChangedCheck
+			common.ResetSync()
+			common.InitializeConfig(admiralParams)
+			actual, err := rc.IsOnlyReplicaCountChanged(ctxLogger, tc.rolloutNew, tc.rolloutOld)
+			if !ErrorEqualOrSimilar(err, tc.expectedError) {
+				t.Errorf("expected: %v, got: %v", tc.expectedError, err)
+			}
+			if err == nil {
+				if tc.expectedValue != actual {
+					t.Errorf("expected: %v, got: %v", tc.expectedValue, actual)
+				}
+			}
+		})
+	}
+
+}
+
 func TestRolloutController_Added(t *testing.T) {
 	common.ResetSync()
 	admiralParams := common.AdmiralParams{