From 369e69db65b0a317d46d023c117abebb8fea1592 Mon Sep 17 00:00:00 2001 From: umagnus Date: Thu, 14 Nov 2024 08:40:29 +0000 Subject: [PATCH] use util wait func --- pkg/smb/nodeserver.go | 15 +++------ pkg/util/util.go | 48 +++++++++++++++++++++++++++ pkg/util/util_test.go | 75 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 10 deletions(-) create mode 100644 pkg/util/util.go create mode 100644 pkg/util/util_test.go diff --git a/pkg/smb/nodeserver.go b/pkg/smb/nodeserver.go index d57c1861554..422f33256ab 100644 --- a/pkg/smb/nodeserver.go +++ b/pkg/smb/nodeserver.go @@ -28,7 +28,6 @@ import ( "github.com/container-storage-interface/spec/lib/go/csi" - "k8s.io/apimachinery/pkg/util/wait" "k8s.io/klog/v2" "k8s.io/kubernetes/pkg/volume" @@ -37,6 +36,7 @@ import ( "golang.org/x/net/context" + volumehelper "github.com/kubernetes-csi/csi-driver-smb/pkg/util" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" ) @@ -232,16 +232,11 @@ func (d *Driver) NodeStageVolume(_ context.Context, req *csi.NodeStageVolumeRequ source = strings.TrimRight(source, "/") source = fmt.Sprintf("%s/%s", source, subDir) } - mountComplete := false - err = wait.PollImmediate(1*time.Second, 2*time.Minute, func() (bool, error) { - err := Mount(d.mounter, source, targetPath, "cifs", mountOptions, sensitiveMountOptions, volumeID) - mountComplete = true - return true, err - }) - if !mountComplete { - return nil, status.Error(codes.Internal, fmt.Sprintf("volume(%s) mount %q on %q failed with timeout(10m)", volumeID, source, targetPath)) + execFunc := func() error { + return Mount(d.mounter, source, targetPath, "cifs", mountOptions, sensitiveMountOptions, volumeID) } - if err != nil { + timeoutFunc := func() error { return fmt.Errorf("time out") } + if err := volumehelper.WaitUntilTimeout(90*time.Second, execFunc, timeoutFunc); err != nil { return nil, status.Error(codes.Internal, fmt.Sprintf("volume(%s) mount %q on %q failed with %v", volumeID, source, targetPath, err)) } klog.V(2).Infof("volume(%s) mount %q on %q succeeded", volumeID, source, targetPath) diff --git a/pkg/util/util.go b/pkg/util/util.go new file mode 100644 index 00000000000..38e2f4cd850 --- /dev/null +++ b/pkg/util/util.go @@ -0,0 +1,48 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package util + +import ( + "time" +) + +// ExecFunc returns a exec function's output and error +type ExecFunc func() (err error) + +// TimeoutFunc returns output and error if an ExecFunc timeout +type TimeoutFunc func() (err error) + +// WaitUntilTimeout waits for the exec function to complete or return timeout error +func WaitUntilTimeout(timeout time.Duration, execFunc ExecFunc, timeoutFunc TimeoutFunc) error { + // Create a channel to receive the result of the azcopy exec function + done := make(chan bool) + var err error + + // Start the azcopy exec function in a goroutine + go func() { + err = execFunc() + done <- true + }() + + // Wait for the function to complete or time out + select { + case <-done: + return err + case <-time.After(timeout): + return timeoutFunc() + } +} diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go new file mode 100644 index 00000000000..718b59a563f --- /dev/null +++ b/pkg/util/util_test.go @@ -0,0 +1,75 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package util + +import ( + "fmt" + "testing" + "time" +) + +func TestWaitUntilTimeout(t *testing.T) { + tests := []struct { + desc string + timeout time.Duration + execFunc ExecFunc + timeoutFunc TimeoutFunc + expectedErr error + }{ + { + desc: "execFunc returns error", + timeout: 1 * time.Second, + execFunc: func() error { + return fmt.Errorf("execFunc error") + }, + timeoutFunc: func() error { + return fmt.Errorf("timeout error") + }, + expectedErr: fmt.Errorf("execFunc error"), + }, + { + desc: "execFunc timeout", + timeout: 1 * time.Second, + execFunc: func() error { + time.Sleep(2 * time.Second) + return nil + }, + timeoutFunc: func() error { + return fmt.Errorf("timeout error") + }, + expectedErr: fmt.Errorf("timeout error"), + }, + { + desc: "execFunc completed successfully", + timeout: 1 * time.Second, + execFunc: func() error { + return nil + }, + timeoutFunc: func() error { + return fmt.Errorf("timeout error") + }, + expectedErr: nil, + }, + } + + for _, test := range tests { + err := WaitUntilTimeout(test.timeout, test.execFunc, test.timeoutFunc) + if err != nil && (err.Error() != test.expectedErr.Error()) { + t.Errorf("unexpected error: %v, expected error: %v", err, test.expectedErr) + } + } +}