diff --git a/cmd/mock-driver/main.go b/cmd/mock-driver/main.go index e95783ce..8313630b 100644 --- a/cmd/mock-driver/main.go +++ b/cmd/mock-driver/main.go @@ -16,16 +16,16 @@ limitations under the License. package main import ( + "context" "flag" - "fmt" "io/ioutil" - "net" "os" "os/signal" - "strings" "syscall" "github.com/kubernetes-csi/csi-test/v4/driver" + "github.com/kubernetes-csi/csi-test/v4/internal/endpoint" + "github.com/kubernetes-csi/csi-test/v4/internal/proxy" "github.com/kubernetes-csi/csi-test/v4/mock/service" "gopkg.in/yaml.v2" "k8s.io/klog/v2" @@ -50,13 +50,37 @@ func main() { flag.BoolVar(&config.DisableOnlineExpansion, "disable-online-expansion", false, "Disables online volume expansion capability.") flag.BoolVar(&config.PermissiveTargetPath, "permissive-target-path", false, "Allows the CO to create PublishVolumeRequest.TargetPath, which violates the CSI spec.") flag.StringVar(&hooksFile, "hooks-file", "", "YAML file with hook scripts.") + proxyEndpoint := flag.String("proxy-endpoint", "", "Instead of running the CSI driver code, just proxy connections from $CSI_ENDPOINT to the given listening socket.") flag.Parse() - endpoint := os.Getenv("CSI_ENDPOINT") + csiEndpoint := os.Getenv("CSI_ENDPOINT") controllerEndpoint := os.Getenv("CSI_CONTROLLER_ENDPOINT") if len(controllerEndpoint) == 0 { // If empty, set to the common endpoint. - controllerEndpoint = endpoint + controllerEndpoint = csiEndpoint + } + + if *proxyEndpoint != "" { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + closer, err := proxy.Run(ctx, csiEndpoint, *proxyEndpoint) + if err != nil { + klog.Fatalf("failed to run proxy: %v", err) + } + defer closer.Close() + + // Wait for signal + sigc := make(chan os.Signal, 1) + sigs := []os.Signal{ + syscall.SIGTERM, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGQUIT, + } + signal.Notify(sigc, sigs...) + + <-sigc + return } if hooksFile != "" { @@ -71,7 +95,7 @@ func main() { // Create mock driver s := service.New(config) - if endpoint == controllerEndpoint { + if csiEndpoint == controllerEndpoint { servers := &driver.CSIDriverServers{ Controller: s, Identity: s, @@ -86,10 +110,10 @@ func main() { } // Listen - l, cleanup, err := listen(endpoint) + l, cleanup, err := endpoint.Listen(csiEndpoint) if err != nil { klog.Exitf("Error: Unable to listen on %s socket: %v\n", - endpoint, + csiEndpoint, err) } defer cleanup() @@ -134,7 +158,7 @@ func main() { } // Listen controller. - l, cleanupController, err := listen(controllerEndpoint) + l, cleanupController, err := endpoint.Listen(controllerEndpoint) if err != nil { klog.Exitf("Error: Unable to listen on %s socket: %v\n", controllerEndpoint, @@ -150,10 +174,10 @@ func main() { klog.Infof("mock controller driver started") // Listen node. - l, cleanupNode, err := listen(endpoint) + l, cleanupNode, err := endpoint.Listen(csiEndpoint) if err != nil { klog.Exitf("Error: Unable to listen on %s socket: %v\n", - endpoint, + csiEndpoint, err) } defer cleanupNode() @@ -182,39 +206,6 @@ func main() { } } -func parseEndpoint(ep string) (string, string, error) { - if strings.HasPrefix(strings.ToLower(ep), "unix://") || strings.HasPrefix(strings.ToLower(ep), "tcp://") { - s := strings.SplitN(ep, "://", 2) - if s[1] != "" { - return s[0], s[1], nil - } - return "", "", fmt.Errorf("Invalid endpoint: %v", ep) - } - // Assume everything else is a file path for a Unix Domain Socket. - return "unix", ep, nil -} - -func listen(endpoint string) (net.Listener, func(), error) { - proto, addr, err := parseEndpoint(endpoint) - if err != nil { - return nil, nil, err - } - - cleanup := func() {} - if proto == "unix" { - addr = "/" + addr - if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { //nolint: vetshadow - return nil, nil, fmt.Errorf("%s: %q", addr, err) - } - cleanup = func() { - os.Remove(addr) - } - } - - l, err := net.Listen(proto, addr) - return l, cleanup, err -} - func parseHooksFile(file string) (*service.Hooks, error) { var hooks service.Hooks diff --git a/internal/endpoint/endpoint.go b/internal/endpoint/endpoint.go new file mode 100644 index 00000000..45321079 --- /dev/null +++ b/internal/endpoint/endpoint.go @@ -0,0 +1,57 @@ +/* +Copyright 2020 Kubernetes Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package endpoint + +import ( + "fmt" + "net" + "os" + "strings" +) + +func Parse(ep string) (string, string, error) { + if strings.HasPrefix(strings.ToLower(ep), "unix://") || strings.HasPrefix(strings.ToLower(ep), "tcp://") { + s := strings.SplitN(ep, "://", 2) + if s[1] != "" { + return s[0], s[1], nil + } + return "", "", fmt.Errorf("Invalid endpoint: %v", ep) + } + // Assume everything else is a file path for a Unix Domain Socket. + return "unix", ep, nil +} + +func Listen(endpoint string) (net.Listener, func(), error) { + proto, addr, err := Parse(endpoint) + if err != nil { + return nil, nil, err + } + + cleanup := func() {} + if proto == "unix" { + addr = "/" + addr + if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { //nolint: vetshadow + return nil, nil, fmt.Errorf("%s: %q", addr, err) + } + cleanup = func() { + os.Remove(addr) + } + } + + l, err := net.Listen(proto, addr) + return l, cleanup, err +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go new file mode 100644 index 00000000..8d0d1408 --- /dev/null +++ b/internal/proxy/proxy.go @@ -0,0 +1,146 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package proxy makes it possible to forward a listening socket in +// situations where the proxy cannot connect to some other address. +// Instead, it creates two listening sockets, pairs two incoming +// connections and then moves data back and forth. This matches +// the behavior of the following socat command: +// socat -d -d -d UNIX-LISTEN:/tmp/socat,fork TCP-LISTEN:9000,reuseport +// +// The advantage over that command is that both listening +// sockets are always open, in contrast to the socat solution +// where the TCP port is only open when there actually is a connection +// available. +// +// To establish a connection, someone has to poll the proxy with a dialer. +package proxy + +import ( + "context" + "fmt" + "io" + "net" + + "k8s.io/klog/v2" + + "github.com/kubernetes-csi/csi-test/v4/internal/endpoint" +) + +// New listens on both endpoints and starts accepting connections +// until closed or the context is done. +func Run(ctx context.Context, endpoint1, endpoint2 string) (io.Closer, error) { + proxy := &proxy{} + failedProxy := proxy + defer func() { + if failedProxy != nil { + failedProxy.Close() + } + }() + + proxy.ctx, proxy.cancel = context.WithCancel(ctx) + + var err error + proxy.s1, proxy.cleanup1, err = endpoint.Listen(endpoint1) + if err != nil { + return nil, fmt.Errorf("listen %s: %v", endpoint1, err) + } + proxy.s2, proxy.cleanup2, err = endpoint.Listen(endpoint2) + if err != nil { + return nil, fmt.Errorf("listen %s: %v", endpoint2, err) + } + + klog.V(3).Infof("proxy listening on %s and %s", endpoint1, endpoint2) + + go func() { + for { + // We block on the first listening socket. + // The Linux kernel proactively accepts connections + // on the second one which we will take over below. + conn1 := accept(proxy.ctx, proxy.s1, endpoint1) + if conn1 == nil { + // Done, shut down. + klog.V(5).Infof("proxy endpoint %s closed, shutting down", endpoint1) + return + } + conn2 := accept(proxy.ctx, proxy.s2, endpoint2) + if conn2 == nil { + // Done, shut down. The already accepted + // connection gets closed. + klog.V(5).Infof("proxy endpoint %s closed, shutting down and close established connection", endpoint2) + conn1.Close() + return + } + + klog.V(3).Infof("proxy established a new connection between %s and %s", endpoint1, endpoint2) + go copy(conn1, conn2, endpoint1, endpoint2) + go copy(conn2, conn1, endpoint2, endpoint1) + } + }() + + failedProxy = nil + return proxy, nil +} + +type proxy struct { + ctx context.Context + cancel func() + s1, s2 net.Listener + cleanup1, cleanup2 func() +} + +func (p *proxy) Close() error { + if p.cancel != nil { + p.cancel() + } + if p.s1 != nil { + p.s1.Close() + } + if p.s2 != nil { + p.s2.Close() + } + if p.cleanup1 != nil { + p.cleanup1() + } + if p.cleanup2 != nil { + p.cleanup2() + } + return nil +} + +func copy(from, to net.Conn, fromEndpoint, toEndpoint string) { + klog.V(5).Infof("starting to copy %s -> %s", fromEndpoint, toEndpoint) + // Signal recipient that no more data is going to come. + // This also stops reading from it. + defer to.Close() + // Copy data until EOF. + cnt, err := io.Copy(to, from) + klog.V(5).Infof("done copying %s -> %s: %d bytes, %v", fromEndpoint, toEndpoint, cnt, err) +} + +func accept(ctx context.Context, s net.Listener, endpoint string) net.Conn { + for { + c, err := s.Accept() + if err == nil { + return c + } + // Ignore error if we are shutting down. + if ctx.Err() != nil { + return nil + } + klog.V(3).Infof("accept on %s failed: %v", endpoint, err) + } +} diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go new file mode 100644 index 00000000..61b4849d --- /dev/null +++ b/internal/proxy/proxy_test.go @@ -0,0 +1,109 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "bytes" + "context" + "io" + "net" + "testing" + + "k8s.io/klog/v2" +) + +func init() { + klog.InitFlags(nil) +} + +func TestProxy(t *testing.T) { + tmpdir := t.TempDir() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + endpoint1 := tmpdir + "/a.sock" + endpoint2 := tmpdir + "/b.sock" + + closer, err := Run(ctx, endpoint1, endpoint2) + if err != nil { + t.Fatalf("proxy error: %v", err) + } + defer closer.Close() + + t.Run("a-to-b", func(t *testing.T) { + sendReceive(t, endpoint1, endpoint2) + }) + t.Run("b-to-a", func(t *testing.T) { + sendReceive(t, endpoint2, endpoint1) + }) +} + +func sendReceive(t *testing.T, endpoint1, endpoint2 string) { + conn1, err := net.Dial("unix", endpoint1) + if err != nil { + t.Fatalf("error connecting to first endpoint %s: %v", endpoint1, err) + } + defer conn1.Close() + conn2, err := net.Dial("unix", endpoint2) + if err != nil { + t.Fatalf("error connecting to second endpoint %s: %v", endpoint2, err) + } + defer conn2.Close() + + req1 := "ping" + if _, err := conn1.Write([]byte(req1)); err != nil { + t.Fatalf("error writing %q: %v", req1, err) + } + buffer := make([]byte, 100) + len, err := conn2.Read(buffer) + if err != nil { + t.Fatalf("error reading %q: %v", req1, err) + } + if string(buffer[:len]) != req1 { + t.Fatalf("expected %q, got %q", req1, string(buffer[:len])) + } + + resp1 := "pong-pong" + if _, err := conn2.Write([]byte(resp1)); err != nil { + t.Fatalf("error writing %q: %v", resp1, err) + } + buffer = make([]byte, 100) + len, err = conn1.Read(buffer) + if err != nil { + t.Fatalf("error reading %q: %v", resp1, err) + } + if string(buffer[:len]) != resp1 { + t.Fatalf("expected %q, got %q", resp1, string(buffer[:len])) + } + + // Closing one side should be noticed at the other end. + err = conn1.Close() + if err != nil { + t.Fatalf("error closing connection to %s: %v", endpoint1, err) + } + len2, err := io.Copy(&bytes.Buffer{}, conn2) + if err != nil { + t.Fatalf("error reading from %s: %v", endpoint2, err) + } + if len2 != 0 { + t.Fatalf("unexpected data via %s: %d", endpoint2, len2) + } + err = conn2.Close() + if err != nil { + t.Fatalf("error closing connection to %s: %v", endpoint2, err) + } +}