Skip to content

Commit

Permalink
refactor: task status file (#355)
Browse files Browse the repository at this point in the history
* test: add unit tests for pkg/client/state

* refactor: use wait.Poll function waiting for task state

* refactor: use consistent task status names

* fixup! test: add unit tests for pkg/client/state

* fix: revert to previous behaviod polling forever

The ctx passed into WaitForTaskToSucceed is only used to cancel HTTP reqests and not to cancel the wait.

* chore: add license headers

* fix: better function name
  • Loading branch information
dkoshkin authored and thunderboltsid committed Apr 29, 2024
1 parent f31c5f6 commit 2fc3e67
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 126 deletions.
2 changes: 1 addition & 1 deletion controllers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func GetImageUUID(ctx context.Context, client *nutanixClientV3.Client, imageName
// HasTaskInProgress returns true if the given task is in progress
func HasTaskInProgress(ctx context.Context, client *nutanixClientV3.Client, taskUUID string) (bool, error) {
log := ctrl.LoggerFrom(ctx)
taskStatus, err := nutanixClientHelper.GetTaskState(ctx, client, taskUUID)
taskStatus, err := nutanixClientHelper.GetTaskStatus(ctx, client, taskUUID)
if err != nil {
return false, err
}
Expand Down
2 changes: 1 addition & 1 deletion controllers/nutanixmachine_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ func (r *NutanixMachineReconciler) getOrCreateVM(rctx *nctx.MachineContext) (*nu
return nil, errorMsg
}
log.Info(fmt.Sprintf("Waiting for task %s to get completed for VM %s", lastTaskUUID, rctx.NutanixMachine.Name))
err = nutanixClient.WaitForTaskCompletion(ctx, nc, lastTaskUUID)
err = nutanixClient.WaitForTaskToSucceed(ctx, nc, lastTaskUUID)
if err != nil {
errorMsg := fmt.Errorf("error occurred while waiting for task %s to start: %v", lastTaskUUID, err)
rctx.SetFailureStatus(capierrors.CreateMachineError, errorMsg)
Expand Down
124 changes: 0 additions & 124 deletions pkg/client/state.go

This file was deleted.

61 changes: 61 additions & 0 deletions pkg/client/status.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
Copyright 2022 Nutanix
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package client

import (
"context"
"fmt"
"time"

"github.com/nutanix-cloud-native/prism-go-client/utils"
nutanixClientV3 "github.com/nutanix-cloud-native/prism-go-client/v3"
"k8s.io/apimachinery/pkg/util/wait"
ctrl "sigs.k8s.io/controller-runtime"
)

const (
pollingInterval = time.Second * 2
statusSucceeded = "SUCCEEDED"
)

// WaitForTaskToSucceed will poll indefinitely every 2 seconds for the task with uuid to have status of "SUCCEEDED".
// The polling will not stop if the ctx is cancelled, it's only used for HTTP requests in the client.
// WaitForTaskToSucceed will exit immediately on an error getting the task.
func WaitForTaskToSucceed(ctx context.Context, conn *nutanixClientV3.Client, uuid string) error {
return wait.PollImmediateInfinite(pollingInterval, func() (done bool, err error) {
status, getErr := GetTaskStatus(ctx, conn, uuid)
return status == statusSucceeded, getErr
})
}

func GetTaskStatus(ctx context.Context, client *nutanixClientV3.Client, uuid string) (string, error) {
log := ctrl.LoggerFrom(ctx)
log.V(1).Info(fmt.Sprintf("Getting task with UUID %s", uuid))
v, err := client.V3.GetTask(ctx, uuid)
if err != nil {
log.Error(err, fmt.Sprintf("error occurred while waiting for task with UUID %s", uuid))
return "", err
}

if *v.Status == "INVALID_UUID" || *v.Status == "FAILED" {
return *v.Status,
fmt.Errorf("error_detail: %s, progress_message: %s", utils.StringValue(v.ErrorDetail), utils.StringValue(v.ProgressMessage))
}
taskStatus := *v.Status
log.V(1).Info(fmt.Sprintf("Status for task with UUID %s: %s", uuid, taskStatus))
return taskStatus, nil
}
166 changes: 166 additions & 0 deletions pkg/client/status_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
Copyright 2024 Nutanix
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package client

import (
"context"
"fmt"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"

nutanixtestclient "github.com/nutanix-cloud-native/cluster-api-provider-nutanix/test/helpers/prism-go-client/v3"
)

func Test_GetTaskStatus(t *testing.T) {
client, err := nutanixtestclient.NewTestClient()
assert.NoError(t, err)
// use cleanup over defer as the connection gets closed before the tests run with t.Parallel()
t.Cleanup(func() {
client.Close()
})

t.Parallel()
tests := []struct {
name string
taskUUID string
handler func(w http.ResponseWriter, r *http.Request)
ctx context.Context
expectedStatus string
expectedErr error
}{
{
name: "succeeded",
taskUUID: "succeeded",
handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{"status": "SUCCEEDED"}`)
},
ctx: context.Background(),
expectedStatus: "SUCCEEDED",
},
{
name: "unauthorized",
taskUUID: "unauthorized",
handler: func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error_code": "401"}`, http.StatusUnauthorized)
},
ctx: context.Background(),
expectedErr: fmt.Errorf("invalid Nutanix credentials"),
},
{
name: "invalid",
taskUUID: "invalid",
handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{"status": "INVALID_UUID", "error_detail": "invalid UUID", "progress_message": "invalid UUID"}`)
},
ctx: context.Background(),
expectedStatus: "INVALID_UUID",
expectedErr: fmt.Errorf("error_detail: invalid UUID, progress_message: invalid UUID"),
},
{
name: "failed",
taskUUID: "failed",
handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{"status": "FAILED", "error_detail": "task failed", "progress_message": "will never succeed"}`)
},
ctx: context.Background(),
expectedStatus: "FAILED",
expectedErr: fmt.Errorf("error_detail: task failed, progress_message: will never succeed"),
},
}
for _, tt := range tests {
tt := tt // Capture range variable.
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client.AddMockHandler(nutanixtestclient.GetTaskURLPath(tt.taskUUID), tt.handler)

status, err := GetTaskStatus(tt.ctx, client.Client, tt.taskUUID)
assert.Equal(t, tt.expectedErr, err)
assert.Equal(t, tt.expectedStatus, status)
})
}
}

func Test_WaitForTaskCompletion(t *testing.T) {
client, err := nutanixtestclient.NewTestClient()
assert.NoError(t, err)
// use cleanup over defer as the connection gets closed before the tests run with t.Parallel()
t.Cleanup(func() {
client.Close()
})

const (
timeout = time.Second * 1
)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
t.Cleanup(func() {
cancel()
})

t.Parallel()
tests := []struct {
name string
taskUUID string
handler func(w http.ResponseWriter, r *http.Request)
ctx context.Context
expectedErr error
}{
{
name: "succeeded",
taskUUID: "succeeded",
handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{"status": "SUCCEEDED"}`)
},
ctx: ctx,
},
{
name: "invalid",
taskUUID: "invalid",
handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{"status": "INVALID_UUID", "error_detail": "invalid UUID", "progress_message": "invalid UUID"}`)
},
ctx: ctx,
expectedErr: fmt.Errorf("error_detail: invalid UUID, progress_message: invalid UUID"),
},
{
name: "timeout",
taskUUID: "timeout",
handler: func(w http.ResponseWriter, r *http.Request) {
// always wait 1 second longer than the timeout to force the context to cancel
time.Sleep(timeout + time.Second)
},
ctx: ctx,
expectedErr: context.DeadlineExceeded,
},
}
for _, tt := range tests {
tt := tt // Capture range variable.
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client.AddMockHandler(nutanixtestclient.GetTaskURLPath(tt.taskUUID), tt.handler)

err := WaitForTaskToSucceed(tt.ctx, client.Client, tt.taskUUID)
if tt.expectedErr != nil {
assert.ErrorContains(t, err, tt.expectedErr.Error())
} else {
assert.NoError(t, err)
}
})
}
}
Loading

0 comments on commit 2fc3e67

Please sign in to comment.