diff --git a/controllers/helpers.go b/controllers/helpers.go index ede1634f3a..b6574cbd69 100644 --- a/controllers/helpers.go +++ b/controllers/helpers.go @@ -343,7 +343,7 @@ func GetImageUUID(ctx context.Context, client *nutanixClientV3.Client, imageName // HasTaskInProgress returns true if the given task is in progress func HasTaskInProgress(ctx context.Context, client *nutanixClientV3.Client, taskUUID string) (bool, error) { log := ctrl.LoggerFrom(ctx) - taskStatus, err := nutanixClientHelper.GetTaskState(ctx, client, taskUUID) + taskStatus, err := nutanixClientHelper.GetTaskStatus(ctx, client, taskUUID) if err != nil { return false, err } diff --git a/controllers/nutanixmachine_controller.go b/controllers/nutanixmachine_controller.go index c27404711d..124aed6a5a 100644 --- a/controllers/nutanixmachine_controller.go +++ b/controllers/nutanixmachine_controller.go @@ -740,7 +740,7 @@ func (r *NutanixMachineReconciler) getOrCreateVM(rctx *nctx.MachineContext) (*nu return nil, errorMsg } log.Info(fmt.Sprintf("Waiting for task %s to get completed for VM %s", lastTaskUUID, rctx.NutanixMachine.Name)) - err = nutanixClient.WaitForTaskCompletion(ctx, nc, lastTaskUUID) + err = nutanixClient.WaitForTaskToSucceed(ctx, nc, lastTaskUUID) if err != nil { errorMsg := fmt.Errorf("error occurred while waiting for task %s to start: %v", lastTaskUUID, err) rctx.SetFailureStatus(capierrors.CreateMachineError, errorMsg) diff --git a/pkg/client/state.go b/pkg/client/state.go deleted file mode 100644 index b3da3f9349..0000000000 --- a/pkg/client/state.go +++ /dev/null @@ -1,124 +0,0 @@ -/* -Copyright 2022 Nutanix - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package client - -import ( - "context" - "fmt" - "math" - "time" - - ctrl "sigs.k8s.io/controller-runtime" - - "github.com/nutanix-cloud-native/prism-go-client/utils" - nutanixClientV3 "github.com/nutanix-cloud-native/prism-go-client/v3" -) - -type stateRefreshFunc func() (string, error) - -func WaitForTaskCompletion(ctx context.Context, conn *nutanixClientV3.Client, uuid string) error { - errCh := make(chan error, 1) - go waitForState( - errCh, - "SUCCEEDED", - waitUntilTaskStateFunc(ctx, conn, uuid)) - - err := <-errCh - return err -} - -func waitForState(errCh chan<- error, target string, refresh stateRefreshFunc) { - err := Retry(2, 2, 0, func(_ uint) (bool, error) { - state, err := refresh() - if err != nil { - return false, err - } else if state == target { - return true, nil - } - return false, nil - }) - errCh <- err -} - -func waitUntilTaskStateFunc(ctx context.Context, conn *nutanixClientV3.Client, uuid string) stateRefreshFunc { - return func() (string, error) { - return GetTaskState(ctx, conn, uuid) - } -} - -func GetTaskState(ctx context.Context, client *nutanixClientV3.Client, taskUUID string) (string, error) { - log := ctrl.LoggerFrom(ctx) - log.V(1).Info(fmt.Sprintf("Getting task with UUID %s", taskUUID)) - v, err := client.V3.GetTask(ctx, taskUUID) - if err != nil { - log.Error(err, fmt.Sprintf("error occurred while waiting for task with UUID %s", taskUUID)) - return "", err - } - - if *v.Status == "INVALID_UUID" || *v.Status == "FAILED" { - return *v.Status, - fmt.Errorf("error_detail: %s, progress_message: %s", utils.StringValue(v.ErrorDetail), utils.StringValue(v.ProgressMessage)) - } - taskStatus := *v.Status - log.V(1).Info(fmt.Sprintf("Status for task with UUID %s: %s", taskUUID, taskStatus)) - return taskStatus, nil -} - -// RetryableFunc performs an action and returns a bool indicating whether the -// function is done, or if it should keep retrying, and an error which will -// abort the retry and be returned by the Retry function. The 0-indexed attempt -// is passed with each call. -type RetryableFunc func(uint) (bool, error) - -/* -Retry retries a function up to numTries times with exponential backoff. -If numTries == 0, retry indefinitely. -If interval == 0, Retry will not delay retrying and there will be no -exponential backoff. -If maxInterval == 0, maxInterval is set to +Infinity. -Intervals are in seconds. -Returns an error if initial > max intervals, if retries are exhausted, or if the passed function returns -an error. -*/ -func Retry(initialInterval float64, maxInterval float64, numTries uint, function RetryableFunc) error { - if maxInterval == 0 { - maxInterval = math.Inf(1) - } else if initialInterval < 0 || initialInterval > maxInterval { - return fmt.Errorf("invalid retry intervals (negative or initial < max). Initial: %f, Max: %f", initialInterval, maxInterval) - } - - var err error - done := false - interval := initialInterval - for i := uint(0); !done && (numTries == 0 || i < numTries); i++ { - done, err = function(i) - if err != nil { - return err - } - - if !done { - // Retry after delay. Calculate next delay. - time.Sleep(time.Duration(interval) * time.Second) - interval = math.Min(interval*2, maxInterval) - } - } - - if !done { - return fmt.Errorf("function never succeeded in Retry") - } - return nil -} diff --git a/pkg/client/status.go b/pkg/client/status.go new file mode 100644 index 0000000000..1bdc58e21c --- /dev/null +++ b/pkg/client/status.go @@ -0,0 +1,61 @@ +/* +Copyright 2022 Nutanix + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "fmt" + "time" + + "github.com/nutanix-cloud-native/prism-go-client/utils" + nutanixClientV3 "github.com/nutanix-cloud-native/prism-go-client/v3" + "k8s.io/apimachinery/pkg/util/wait" + ctrl "sigs.k8s.io/controller-runtime" +) + +const ( + pollingInterval = time.Second * 2 + statusSucceeded = "SUCCEEDED" +) + +// WaitForTaskToSucceed will poll indefinitely every 2 seconds for the task with uuid to have status of "SUCCEEDED". +// The polling will not stop if the ctx is cancelled, it's only used for HTTP requests in the client. +// WaitForTaskToSucceed will exit immediately on an error getting the task. +func WaitForTaskToSucceed(ctx context.Context, conn *nutanixClientV3.Client, uuid string) error { + return wait.PollImmediateInfinite(pollingInterval, func() (done bool, err error) { + status, getErr := GetTaskStatus(ctx, conn, uuid) + return status == statusSucceeded, getErr + }) +} + +func GetTaskStatus(ctx context.Context, client *nutanixClientV3.Client, uuid string) (string, error) { + log := ctrl.LoggerFrom(ctx) + log.V(1).Info(fmt.Sprintf("Getting task with UUID %s", uuid)) + v, err := client.V3.GetTask(ctx, uuid) + if err != nil { + log.Error(err, fmt.Sprintf("error occurred while waiting for task with UUID %s", uuid)) + return "", err + } + + if *v.Status == "INVALID_UUID" || *v.Status == "FAILED" { + return *v.Status, + fmt.Errorf("error_detail: %s, progress_message: %s", utils.StringValue(v.ErrorDetail), utils.StringValue(v.ProgressMessage)) + } + taskStatus := *v.Status + log.V(1).Info(fmt.Sprintf("Status for task with UUID %s: %s", uuid, taskStatus)) + return taskStatus, nil +} diff --git a/pkg/client/status_test.go b/pkg/client/status_test.go new file mode 100644 index 0000000000..d97da2d534 --- /dev/null +++ b/pkg/client/status_test.go @@ -0,0 +1,166 @@ +/* +Copyright 2024 Nutanix + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + nutanixtestclient "github.com/nutanix-cloud-native/cluster-api-provider-nutanix/test/helpers/prism-go-client/v3" +) + +func Test_GetTaskStatus(t *testing.T) { + client, err := nutanixtestclient.NewTestClient() + assert.NoError(t, err) + // use cleanup over defer as the connection gets closed before the tests run with t.Parallel() + t.Cleanup(func() { + client.Close() + }) + + t.Parallel() + tests := []struct { + name string + taskUUID string + handler func(w http.ResponseWriter, r *http.Request) + ctx context.Context + expectedStatus string + expectedErr error + }{ + { + name: "succeeded", + taskUUID: "succeeded", + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, `{"status": "SUCCEEDED"}`) + }, + ctx: context.Background(), + expectedStatus: "SUCCEEDED", + }, + { + name: "unauthorized", + taskUUID: "unauthorized", + handler: func(w http.ResponseWriter, r *http.Request) { + http.Error(w, `{"error_code": "401"}`, http.StatusUnauthorized) + }, + ctx: context.Background(), + expectedErr: fmt.Errorf("invalid Nutanix credentials"), + }, + { + name: "invalid", + taskUUID: "invalid", + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, `{"status": "INVALID_UUID", "error_detail": "invalid UUID", "progress_message": "invalid UUID"}`) + }, + ctx: context.Background(), + expectedStatus: "INVALID_UUID", + expectedErr: fmt.Errorf("error_detail: invalid UUID, progress_message: invalid UUID"), + }, + { + name: "failed", + taskUUID: "failed", + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, `{"status": "FAILED", "error_detail": "task failed", "progress_message": "will never succeed"}`) + }, + ctx: context.Background(), + expectedStatus: "FAILED", + expectedErr: fmt.Errorf("error_detail: task failed, progress_message: will never succeed"), + }, + } + for _, tt := range tests { + tt := tt // Capture range variable. + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + client.AddMockHandler(nutanixtestclient.GetTaskURLPath(tt.taskUUID), tt.handler) + + status, err := GetTaskStatus(tt.ctx, client.Client, tt.taskUUID) + assert.Equal(t, tt.expectedErr, err) + assert.Equal(t, tt.expectedStatus, status) + }) + } +} + +func Test_WaitForTaskCompletion(t *testing.T) { + client, err := nutanixtestclient.NewTestClient() + assert.NoError(t, err) + // use cleanup over defer as the connection gets closed before the tests run with t.Parallel() + t.Cleanup(func() { + client.Close() + }) + + const ( + timeout = time.Second * 1 + ) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(func() { + cancel() + }) + + t.Parallel() + tests := []struct { + name string + taskUUID string + handler func(w http.ResponseWriter, r *http.Request) + ctx context.Context + expectedErr error + }{ + { + name: "succeeded", + taskUUID: "succeeded", + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, `{"status": "SUCCEEDED"}`) + }, + ctx: ctx, + }, + { + name: "invalid", + taskUUID: "invalid", + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, `{"status": "INVALID_UUID", "error_detail": "invalid UUID", "progress_message": "invalid UUID"}`) + }, + ctx: ctx, + expectedErr: fmt.Errorf("error_detail: invalid UUID, progress_message: invalid UUID"), + }, + { + name: "timeout", + taskUUID: "timeout", + handler: func(w http.ResponseWriter, r *http.Request) { + // always wait 1 second longer than the timeout to force the context to cancel + time.Sleep(timeout + time.Second) + }, + ctx: ctx, + expectedErr: context.DeadlineExceeded, + }, + } + for _, tt := range tests { + tt := tt // Capture range variable. + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + client.AddMockHandler(nutanixtestclient.GetTaskURLPath(tt.taskUUID), tt.handler) + + err := WaitForTaskToSucceed(tt.ctx, client.Client, tt.taskUUID) + if tt.expectedErr != nil { + assert.ErrorContains(t, err, tt.expectedErr.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/test/helpers/prism-go-client/v3/client.go b/test/helpers/prism-go-client/v3/client.go new file mode 100644 index 0000000000..a904d7bba1 --- /dev/null +++ b/test/helpers/prism-go-client/v3/client.go @@ -0,0 +1,69 @@ +/* +Copyright 2024 Nutanix + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v3 + +import ( + "fmt" + + "net/http" + "net/http/httptest" + "path" + + prismgoclient "github.com/nutanix-cloud-native/prism-go-client" + nutanixClientV3 "github.com/nutanix-cloud-native/prism-go-client/v3" +) + +const ( + baseURLPath = "/api/nutanix/v3/" +) + +type TestClient struct { + *nutanixClientV3.Client + + mux *http.ServeMux + server *httptest.Server +} + +func NewTestClient() (*TestClient, error) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + + cred := prismgoclient.Credentials{ + URL: server.URL, + Username: "username", + Password: "password", + Endpoint: "0.0.0.0", + } + + client, err := nutanixClientV3.NewV3Client(cred) + if err != nil { + return nil, fmt.Errorf("error creating Nutanix test client: %w", err) + } + return &TestClient{client, mux, server}, nil +} + +func (c *TestClient) Close() { + c.server.Close() +} + +func (c *TestClient) AddMockHandler(pattern string, handler func(w http.ResponseWriter, r *http.Request)) { + c.mux.HandleFunc(pattern, handler) +} + +func GetTaskURLPath(uuid string) string { + return path.Join(baseURLPath, "tasks", uuid) +}