diff --git a/cmd/marker/main.go b/cmd/marker/main.go index 581037a2..c145a8cf 100644 --- a/cmd/marker/main.go +++ b/cmd/marker/main.go @@ -17,6 +17,7 @@ package main import ( "flag" "fmt" + "net" "os" "reflect" "strings" @@ -29,6 +30,12 @@ import ( "github.com/k8snetworkplumbingwg/ovs-cni/pkg/marker" ) +const ( + UnixSocketType = "unix" + TcpSocketType = "tcp" + SocketConnectionTimeout = time.Minute +) + func main() { nodeName := flag.String("node-name", "", "name of kubernetes node") ovsSocket := flag.String("ovs-socket", "", "address of openvswitch database connection") @@ -51,41 +58,12 @@ func main() { glog.Fatal("node-name must be set") } - if *ovsSocket == "" { - glog.Fatal("ovs-socket must be set") - } - - var socketType, path string - ovsSocketTokens := strings.Split(*ovsSocket, ":") - if len(ovsSocketTokens) < 2 { - /* - * ovsSocket should consist of comma separated socket type and socket - * detail. If no socket type is specified, it is assumed to be a unix - * domain socket, for backwards compatibility. - */ - socketType = "unix" - path = *ovsSocket - } else { - socketType = ovsSocketTokens[0] - path = ovsSocketTokens[1] - } - - if socketType == "unix" { - for { - _, err := os.Stat(path) - if err == nil { - glog.Info("Found the OVS socket") - break - } else if os.IsNotExist(err) { - glog.Infof("Given ovs-socket %q was not found, waiting for the socket to appear", path) - time.Sleep(time.Minute) - } else { - glog.Fatalf("Failed opening the OVS socket with: %v", err) - } - } + endpoint, err := parseOvsSocket(ovsSocket) + if err != nil { + glog.Fatalf("Failed to parse ovs socket: %v", err) } - markerApp, err := marker.NewMarker(*nodeName, socketType+":"+path) + markerApp, err := marker.NewMarker(*nodeName, endpoint) if err != nil { glog.Fatalf("Failed to create a new marker object: %v", err) } @@ -137,3 +115,68 @@ func keepAlive(healthCheckFile string, healthCheckInterval int) { }, time.Duration(healthCheckInterval)*time.Second) } + +func parseOvsSocket(ovsSocket *string) (string, error) { + if *ovsSocket == "" { + return "", fmt.Errorf("ovs-socket must be set") + } + + var socketType, address string + ovsSocketTokens := strings.Split(*ovsSocket, ":") + if len(ovsSocketTokens) < 2 { + /* + * ovsSocket should consist of comma separated socket type and socket + * detail. If no socket type is specified, it is assumed to be a unix + * domain socket, for backwards compatibility. + */ + socketType = UnixSocketType + address = *ovsSocket + } else { + socketType = ovsSocketTokens[0] + if socketType == TcpSocketType { + if len(ovsSocketTokens) != 3 { + return "", fmt.Errorf("failed to parse OVS %s socket, must be in this format %s::", socketType, socketType) + } + address = fmt.Sprintf("%s:%s", ovsSocketTokens[1], ovsSocketTokens[2]) + } else { + // unix socket + address = ovsSocketTokens[1] + } + } + endpoint := fmt.Sprintf("%s:%s", socketType, address) + + if socketType == UnixSocketType { + for { + _, err := os.Stat(address) + if err == nil { + glog.Info("Found the OVS socket") + break + } else if os.IsNotExist(err) { + glog.Infof("Given ovs-socket %q was not found, waiting for the socket to appear", address) + time.Sleep(SocketConnectionTimeout) + } else { + return "", fmt.Errorf("failed opening the OVS socket with: %v", err) + } + } + } else if socketType == TcpSocketType { + conn, err := net.DialTimeout(socketType, address, SocketConnectionTimeout) + if err == nil { + glog.Info("Successfully connected to TCP socket") + conn.Close() + return endpoint, nil + } + + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return "", fmt.Errorf("connection to %s timed out", address) + } else if opErr, ok := err.(*net.OpError); ok { + if opErr.Op == "dial" { + return "", fmt.Errorf("connection to %s failed: %v", address, err) + } else { + return "", fmt.Errorf("unexpected error when connecting to %s: %v", address, err) + } + } else { + return "", fmt.Errorf("unexpected error when connecting to %s: %v", address, err) + } + } + return endpoint, nil +}