diff --git a/connection/connection.go b/connection/connection.go index 5e66ae78..9fa76b19 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -24,6 +24,9 @@ import ( "strings" "time" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/container-storage-interface/spec/lib/go/csi" "github.com/kubernetes-csi/csi-lib-utils/protosanitizer" "google.golang.org/grpc" @@ -33,6 +36,9 @@ import ( const ( // Interval of logging connection errors connectionLoggingInterval = 10 * time.Second + + // Interval of trying to call Probe() until it succeeds + probeInterval = 1 * time.Second ) // Connect opens insecure gRPC connection to a CSI driver. Address must be either absolute path to UNIX domain socket @@ -163,6 +169,7 @@ func LogGRPC(ctx context.Context, method string, req, reply interface{}, cc *grp return err } +// GetDriverName returns name of CSI driver. func GetDriverName(ctx context.Context, conn *grpc.ClientConn) (string, error) { client := csi.NewIdentityClient(conn) @@ -177,3 +184,111 @@ func GetDriverName(ctx context.Context, conn *grpc.ClientConn) (string, error) { } return name, nil } + +// PluginCapabilitySet is set of CSI plugin capabilities. Only supported capabilities are in the map. +type PluginCapabilitySet map[csi.PluginCapability_Service_Type]bool + +// GetPluginCapabilities returns set of supported capabilities of CSI driver. +func GetPluginCapabilities(ctx context.Context, conn *grpc.ClientConn) (PluginCapabilitySet, error) { + client := csi.NewIdentityClient(conn) + req := csi.GetPluginCapabilitiesRequest{} + rsp, err := client.GetPluginCapabilities(ctx, &req) + if err != nil { + return nil, err + } + caps := PluginCapabilitySet{} + for _, cap := range rsp.GetCapabilities() { + if cap == nil { + continue + } + srv := cap.GetService() + if srv == nil { + continue + } + t := srv.GetType() + caps[t] = true + } + return caps, nil +} + +// ControllerCapabilitySet is set of CSI controller capabilities. Only supported capabilities are in the map. +type ControllerCapabilitySet map[csi.ControllerServiceCapability_RPC_Type]bool + +// GetControllerCapabilities returns set of supported controller capabilities of CSI driver. +func GetControllerCapabilities(ctx context.Context, conn *grpc.ClientConn) (ControllerCapabilitySet, error) { + client := csi.NewControllerClient(conn) + req := csi.ControllerGetCapabilitiesRequest{} + rsp, err := client.ControllerGetCapabilities(ctx, &req) + if err != nil { + return nil, err + } + + caps := ControllerCapabilitySet{} + for _, cap := range rsp.GetCapabilities() { + if cap == nil { + continue + } + rpc := cap.GetRpc() + if rpc == nil { + continue + } + t := rpc.GetType() + caps[t] = true + } + return caps, nil +} + +// ProbeForever calls Probe() of a CSI driver and waits until the driver becomes ready. +// Any error other than timeout is returned. +func ProbeForever(conn *grpc.ClientConn, singleProbeTimeout time.Duration) error { + for { + klog.Info("Probing CSI driver for readiness") + ready, err := probeOnce(conn, singleProbeTimeout) + if err != nil { + st, ok := status.FromError(err) + if !ok { + // This is not gRPC error. The probe must have failed before gRPC + // method was called, otherwise we would get gRPC error. + return fmt.Errorf("CSI driver probe failed: %s", err) + } + if st.Code() != codes.DeadlineExceeded { + return fmt.Errorf("CSI driver probe failed: %s", err) + } + // Timeout -> driver is not ready. Fall through to sleep() below. + klog.Warning("CSI driver probe timed out") + } else { + if ready { + return nil + } + klog.Warning("CSI driver is not ready") + } + // Timeout was returned or driver is not ready. + time.Sleep(probeInterval) + } +} + +// probeOnce is a helper to simplify defer cancel() +func probeOnce(conn *grpc.ClientConn, timeout time.Duration) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return Probe(ctx, conn) +} + +// Probe calls driver Probe() just once and returns its result without any processing. +func Probe(ctx context.Context, conn *grpc.ClientConn) (ready bool, err error) { + client := csi.NewIdentityClient(conn) + + req := csi.ProbeRequest{} + rsp, err := client.Probe(ctx, &req) + + if err != nil { + return false, err + } + + r := rsp.GetReady() + if r == nil { + // "If not present, the caller SHALL assume that the plugin is in a ready state" + return true, nil + } + return r.GetValue(), nil +} diff --git a/connection/connection_test.go b/connection/connection_test.go index c87fb0aa..dbaf70ee 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -23,6 +23,7 @@ import ( "net" "os" "path" + "reflect" "sync" "testing" "time" @@ -32,6 +33,7 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/status" + "github.com/golang/protobuf/ptypes/wrappers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -51,7 +53,7 @@ const ( // startServer creates a gRPC server without any registered services. // The returned address can be used to connect to it. The cleanup // function stops it. It can be called multiple times. -func startServer(t *testing.T, tmp string, identity csi.IdentityServer) (string, func()) { +func startServer(t *testing.T, tmp string, identity csi.IdentityServer, controller csi.ControllerServer) (string, func()) { addr := path.Join(tmp, serverSock) listener, err := net.Listen("unix", addr) require.NoError(t, err, "listening on %s", addr) @@ -59,6 +61,9 @@ func startServer(t *testing.T, tmp string, identity csi.IdentityServer) (string, if identity != nil { csi.RegisterIdentityServer(server, identity) } + if controller != nil { + csi.RegisterControllerServer(server, controller) + } var wg sync.WaitGroup wg.Add(1) go func() { @@ -79,7 +84,7 @@ func startServer(t *testing.T, tmp string, identity csi.IdentityServer) (string, func TestConnect(t *testing.T) { tmp := tmpDir(t) defer os.RemoveAll(tmp) - addr, stopServer := startServer(t, tmp, nil) + addr, stopServer := startServer(t, tmp, nil, nil) defer stopServer() conn, err := Connect(addr) @@ -94,7 +99,7 @@ func TestConnect(t *testing.T) { func TestConnectUnix(t *testing.T) { tmp := tmpDir(t) defer os.RemoveAll(tmp) - addr, stopServer := startServer(t, tmp, nil) + addr, stopServer := startServer(t, tmp, nil, nil) defer stopServer() conn, err := Connect("unix:///" + addr) @@ -135,7 +140,7 @@ func TestWaitForServer(t *testing.T) { t.Logf("sleeping %s before starting server", delay) time.Sleep(delay) startTimeServer = time.Now() - _, stopServer = startServer(t, tmp, nil) + _, stopServer = startServer(t, tmp, nil, nil) }() conn, err := Connect(path.Join(tmp, serverSock)) if assert.NoError(t, err, "connect via absolute path") { @@ -169,7 +174,7 @@ func TestTimout(t *testing.T) { func TestReconnect(t *testing.T) { tmp := tmpDir(t) defer os.RemoveAll(tmp) - addr, stopServer := startServer(t, tmp, nil) + addr, stopServer := startServer(t, tmp, nil, nil) defer func() { stopServer() }() @@ -196,7 +201,7 @@ func TestReconnect(t *testing.T) { } // No reconnection either when the server comes back. - _, stopServer = startServer(t, tmp, nil) + _, stopServer = startServer(t, tmp, nil, nil) // We need to give gRPC some time. It does not attempt to reconnect // immediately. If we send the method call too soon, the test passes // even though a later method call will go through again. @@ -214,7 +219,7 @@ func TestReconnect(t *testing.T) { func TestDisconnect(t *testing.T) { tmp := tmpDir(t) defer os.RemoveAll(tmp) - addr, stopServer := startServer(t, tmp, nil) + addr, stopServer := startServer(t, tmp, nil, nil) defer func() { stopServer() }() @@ -245,7 +250,7 @@ func TestDisconnect(t *testing.T) { } // No reconnection either when the server comes back. - _, stopServer = startServer(t, tmp, nil) + _, stopServer = startServer(t, tmp, nil, nil) // We need to give gRPC some time. It does not attempt to reconnect // immediately. If we send the method call too soon, the test passes // even though a later method call will go through again. @@ -265,7 +270,7 @@ func TestDisconnect(t *testing.T) { func TestExplicitReconnect(t *testing.T) { tmp := tmpDir(t) defer os.RemoveAll(tmp) - addr, stopServer := startServer(t, tmp, nil) + addr, stopServer := startServer(t, tmp, nil, nil) defer func() { stopServer() }() @@ -296,7 +301,7 @@ func TestExplicitReconnect(t *testing.T) { } // No reconnection either when the server comes back. - _, stopServer = startServer(t, tmp, nil) + _, stopServer = startServer(t, tmp, nil, nil) // We need to give gRPC some time. It does not attempt to reconnect // immediately. If we send the method call too soon, the test passes // even though a later method call will go through again. @@ -356,43 +361,490 @@ func TestGetDriverName(t *testing.T) { tmp := tmpDir(t) defer os.RemoveAll(tmp) - identity := &identityServer{out, injectedErr} - addr, stopServer := startServer(t, tmp, identity) + identity := &identityServer{ + pluginInfoResponse: out, + err: injectedErr, + } + addr, stopServer := startServer(t, tmp, identity, nil) defer func() { stopServer() }() conn, err := Connect(addr) + if err != nil { + t.Fatalf("Failed to connect to CSI driver: %s", err) + } name, err := GetDriverName(context.Background(), conn) if test.expectError && err == nil { - t.Errorf("test %q: Expected error, got none", test.name) + t.Errorf("Expected error, got none") } if !test.expectError && err != nil { - t.Errorf("test %q: got error: %v", test.name, err) + t.Errorf("Got error: %v", err) } if err == nil && name != "csi/example" { - t.Errorf("got unexpected name: %q", name) + t.Errorf("Got unexpected name: %q", name) + } + }) + } +} + +func TestGetPluginCapabilities(t *testing.T) { + tests := []struct { + name string + output *csi.GetPluginCapabilitiesResponse + injectError bool + expectCapabilities PluginCapabilitySet + expectError bool + }{ + { + name: "success", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{ + { + Type: &csi.PluginCapability_Service_{ + Service: &csi.PluginCapability_Service{ + Type: csi.PluginCapability_Service_CONTROLLER_SERVICE, + }, + }, + }, + { + Type: &csi.PluginCapability_Service_{ + Service: &csi.PluginCapability_Service{ + Type: csi.PluginCapability_Service_UNKNOWN, + }, + }, + }, + }, + }, + expectCapabilities: PluginCapabilitySet{ + csi.PluginCapability_Service_CONTROLLER_SERVICE: true, + csi.PluginCapability_Service_UNKNOWN: true, + }, + expectError: false, + }, + { + name: "gRPC error", + output: nil, + injectError: true, + expectError: true, + }, + { + name: "no controller service", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{ + { + Type: &csi.PluginCapability_Service_{ + Service: &csi.PluginCapability_Service{ + Type: csi.PluginCapability_Service_UNKNOWN, + }, + }, + }, + }, + }, + expectCapabilities: PluginCapabilitySet{ + csi.PluginCapability_Service_UNKNOWN: true, + }, + expectError: false, + }, + { + name: "empty capability", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{ + { + Type: nil, + }, + }, + }, + expectCapabilities: PluginCapabilitySet{}, + expectError: false, + }, + { + name: "no capabilities", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{}, + }, + expectCapabilities: PluginCapabilitySet{}, + expectError: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var injectedErr error + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + tmp := tmpDir(t) + defer os.RemoveAll(tmp) + identity := &identityServer{ + getPluginCapabilitiesResponse: test.output, + err: injectedErr, + } + addr, stopServer := startServer(t, tmp, identity, nil) + defer func() { + stopServer() + }() + + conn, err := Connect(addr) + if err != nil { + t.Fatalf("Failed to connect to CSI driver: %s", err) + } + + caps, err := GetPluginCapabilities(context.Background(), conn) + if test.expectError && err == nil { + t.Errorf("Expected error, got none") + } + if !test.expectError && err != nil { + t.Errorf("Got error: %v", err) + } + if !reflect.DeepEqual(test.expectCapabilities, caps) { + t.Errorf("expected capabilities %+v, got %+v", test.expectCapabilities, caps) + } + }) + } +} + +func TestGetControllerCapabilities(t *testing.T) { + tests := []struct { + name string + output *csi.ControllerGetCapabilitiesResponse + injectError bool + expectCapabilities ControllerCapabilitySet + expectError bool + }{ + { + name: "success", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, + }, + }, + }, + }, + }, + expectCapabilities: ControllerCapabilitySet{ + csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME: true, + csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME: true, + }, + expectError: false, + }, + { + name: "supports read only", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_PUBLISH_READONLY, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, + }, + }, + }, + }, + }, + expectCapabilities: ControllerCapabilitySet{ + csi.ControllerServiceCapability_RPC_PUBLISH_READONLY: true, + csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME: true, + }, + expectError: false, + }, + { + name: "gRPC error", + output: nil, + injectError: true, + expectError: true, + }, + { + name: "empty capability", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: nil, + }, + }, + }, + expectCapabilities: ControllerCapabilitySet{}, + expectError: false, + }, + { + name: "no capabilities", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{}, + }, + expectCapabilities: ControllerCapabilitySet{}, + expectError: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var injectedErr error + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + tmp := tmpDir(t) + defer os.RemoveAll(tmp) + controller := &controllerServer{ + controllerGetCapabilitiesResponse: test.output, + err: injectedErr, + } + addr, stopServer := startServer(t, tmp, nil, controller) + defer func() { + stopServer() + }() + + conn, err := Connect(addr) + if err != nil { + t.Fatalf("Failed to connect to CSI driver: %s", err) + } + + caps, err := GetControllerCapabilities(context.Background(), conn) + if test.expectError && err == nil { + t.Errorf("Expected error, got none") + } + if !test.expectError && err != nil { + t.Errorf("Got error: %v", err) + } + if !reflect.DeepEqual(test.expectCapabilities, caps) { + t.Errorf("expected capabilities %+v, got %+v", test.expectCapabilities, caps) + } + }) + } +} + +func TestProbeForever(t *testing.T) { + tests := []struct { + name string + probeCalls []probeCall + expectError bool + }{ + { + name: "success", + probeCalls: []probeCall{ + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: true}, + }, + }, + }, + expectError: false, + }, + { + name: "success with empty Ready field (true is assumed)", + probeCalls: []probeCall{ + { + response: &csi.ProbeResponse{ + Ready: nil, + }, + }, + }, + expectError: false, + }, + { + name: "error", + probeCalls: []probeCall{ + { + err: fmt.Errorf("mock error"), + }, + }, + expectError: true, + }, + { + name: "timeout + failure", + probeCalls: []probeCall{ + { + err: status.Error(codes.DeadlineExceeded, "timeout"), + }, + { + err: fmt.Errorf("mock error"), + }, + }, + expectError: true, + }, + { + name: "timeout + success", + probeCalls: []probeCall{ + { + err: status.Error(codes.DeadlineExceeded, "timeout"), + }, + { + err: status.Error(codes.DeadlineExceeded, "timeout"), + }, + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: true}, + }, + }, + }, + expectError: false, + }, + { + name: "unready + failure", + probeCalls: []probeCall{ + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: false}, + }, + }, + { + err: fmt.Errorf("mock error"), + }, + }, + expectError: true, + }, + { + name: "unready + success", + probeCalls: []probeCall{ + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: false}, + }, + }, + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: false}, + }, + }, + { + response: &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: true}, + }, + }, + }, + expectError: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tmp := tmpDir(t) + defer os.RemoveAll(tmp) + identity := &identityServer{ + probeCalls: test.probeCalls, + } + addr, stopServer := startServer(t, tmp, identity, nil) + defer func() { + stopServer() + }() + + conn, err := Connect(addr) + if err != nil { + t.Fatalf("Failed to connect to CSI driver: %s", err) + } + + err = ProbeForever(conn, time.Second) + if test.expectError && err == nil { + t.Errorf("Expected error, got none") + } + if !test.expectError && err != nil { + t.Errorf("Got error: %v", err) + } + if len(identity.probeCalls) != identity.probeCallCount { + t.Errorf("Expected %d probe calls, got %d", len(identity.probeCalls), identity.probeCallCount) } }) } } type identityServer struct { - response *csi.GetPluginInfoResponse + pluginInfoResponse *csi.GetPluginInfoResponse + getPluginCapabilitiesResponse *csi.GetPluginCapabilitiesResponse + err error + + probeCalls []probeCall + probeCallCount int +} + +type probeCall struct { + response *csi.ProbeResponse err error } var _ csi.IdentityServer = &identityServer{} func (i *identityServer) GetPluginCapabilities(context.Context, *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) { - return nil, fmt.Errorf("Not implemented") + return i.getPluginCapabilitiesResponse, i.err } func (i *identityServer) GetPluginInfo(context.Context, *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) { - return i.response, i.err + return i.pluginInfoResponse, i.err } func (i *identityServer) Probe(context.Context, *csi.ProbeRequest) (*csi.ProbeResponse, error) { - return nil, fmt.Errorf("Not implemented") + if i.probeCallCount >= len(i.probeCalls) { + return nil, fmt.Errorf("Unexpected Probe() call") + } + call := i.probeCalls[i.probeCallCount] + i.probeCallCount++ + return call.response, call.err +} + +type controllerServer struct { + controllerGetCapabilitiesResponse *csi.ControllerGetCapabilitiesResponse + err error +} + +var _ csi.ControllerServer = &controllerServer{} + +func (c *controllerServer) CreateVolume(context.Context, *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *controllerServer) DeleteVolume(context.Context, *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *controllerServer) ControllerPublishVolume(context.Context, *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *controllerServer) ControllerUnpublishVolume(context.Context, *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *controllerServer) ValidateVolumeCapabilities(context.Context, *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *controllerServer) ListVolumes(context.Context, *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *controllerServer) GetCapacity(context.Context, *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *controllerServer) ControllerGetCapabilities(context.Context, *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) { + return c.controllerGetCapabilitiesResponse, c.err +} + +func (c *controllerServer) CreateSnapshot(context.Context, *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *controllerServer) DeleteSnapshot(context.Context, *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *controllerServer) ListSnapshots(context.Context, *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { + return nil, fmt.Errorf("unimplemented") }