From 2384484ed79f4c97a7dee8b90a2024f5ae3de7be Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 8 Sep 2022 14:32:14 +0300 Subject: [PATCH] Wss support (#3) * added wss support * changed entrypoints in Dockerfiles to include config from .ini files Signed-off-by: Daniel Soifer --- client/control.go | 28 ++++++++++++++++++++++++---- client/service.go | 28 ++++++++++++++++++++++++---- cmd/frpc/sub/root.go | 2 +- conf/frpc_full.ini | 2 +- dockerfiles/Dockerfile-for-frpc | 2 +- dockerfiles/Dockerfile-for-frps | 2 +- pkg/config/client.go | 6 +++++- pkg/util/net/dial.go | 12 +++++++++--- 8 files changed, 66 insertions(+), 16 deletions(-) diff --git a/client/control.go b/client/control.go index ef9a766fcda..ccec2feda7e 100644 --- a/client/control.go +++ b/client/control.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "io" + "math" "net" "runtime/debug" "strconv" @@ -242,9 +243,21 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) { } dialOptions := []libdial.DialOption{} protocol := ctl.clientCfg.Protocol - if protocol == "websocket" { + var websocketAfterHook *libdial.AfterHook + if protocol == "websocket" || protocol == "wss" { + if protocol == "wss" { + websocketAfterHook = &libdial.AfterHook{ + Priority: math.MaxUint64, // in case of wss, we first want to make the TLS handshake and then switch protocols from https to wss + Hook: frpNet.DialHookWebsocket(true), + } + } else { + websocketAfterHook = &libdial.AfterHook{ + Priority: 0, // in case of ws, we first want to switch protocols from http to ws, and only then make the TLS handshake in case TLS is enabled + Hook: frpNet.DialHookWebsocket(false), + } + } + protocol = "tcp" - dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: frpNet.DialHookWebsocket()})) } if ctl.clientCfg.ConnectServerLocalIP != "" { dialOptions = append(dialOptions, libdial.WithLocalAddr(ctl.clientCfg.ConnectServerLocalIP)) @@ -255,11 +268,18 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) { libdial.WithKeepAlive(time.Duration(ctl.clientCfg.DialServerKeepAlive)*time.Second), libdial.WithProxy(proxyType, addr), libdial.WithProxyAuth(auth), - libdial.WithTLSConfig(tlsConfig), + libdial.WithTLSConfig(tlsConfig), // TLS AfterHook has math.MaxUint64 priority libdial.WithAfterHook(libdial.AfterHook{ - Hook: frpNet.DialHookCustomTLSHeadByte(tlsConfig != nil, ctl.clientCfg.DisableCustomTLSFirstByte), + Priority: 1, // should be executed before TLS AfterHook but after the rest of the AfterHooks (except for wss) + Hook: frpNet.DialHookCustomTLSHeadByte(tlsConfig != nil, ctl.clientCfg.DisableCustomTLSFirstByte), }), ) + if websocketAfterHook != nil { + // websocketAfterHook must be appended after TLS AfterHook because they both might have the + // same priority of math.MaxUint64 in case of wss but TLS AfterHook must be executed first + dialOptions = append(dialOptions, libdial.WithAfterHook(*websocketAfterHook)) + } + conn, err = libdial.Dial( net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort)), dialOptions..., diff --git a/client/service.go b/client/service.go index 30bd3f8f5bc..8c753ff4d8e 100644 --- a/client/service.go +++ b/client/service.go @@ -19,6 +19,7 @@ import ( "crypto/tls" "fmt" "io" + "math" "math/rand" "net" "runtime" @@ -256,9 +257,21 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) { } dialOptions := []libdial.DialOption{} protocol := svr.cfg.Protocol - if protocol == "websocket" { + var websocketAfterHook *libdial.AfterHook + if protocol == "websocket" || protocol == "wss" { + if protocol == "wss" { + websocketAfterHook = &libdial.AfterHook{ + Priority: math.MaxUint64, // in case of wss, we first want to make the TLS handshake and then switch protocols from https to wss + Hook: frpNet.DialHookWebsocket(true), + } + } else { + websocketAfterHook = &libdial.AfterHook{ + Priority: 0, // in case of ws, we first want to switch protocols from http to ws, and only then make the TLS handshake in case TLS is enabled + Hook: frpNet.DialHookWebsocket(false), + } + } + protocol = "tcp" - dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: frpNet.DialHookWebsocket()})) } if svr.cfg.ConnectServerLocalIP != "" { dialOptions = append(dialOptions, libdial.WithLocalAddr(svr.cfg.ConnectServerLocalIP)) @@ -269,11 +282,18 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) { libdial.WithKeepAlive(time.Duration(svr.cfg.DialServerKeepAlive)*time.Second), libdial.WithProxy(proxyType, addr), libdial.WithProxyAuth(auth), - libdial.WithTLSConfig(tlsConfig), + libdial.WithTLSConfig(tlsConfig), // TLS AfterHook has math.MaxUint64 priority libdial.WithAfterHook(libdial.AfterHook{ - Hook: frpNet.DialHookCustomTLSHeadByte(tlsConfig != nil, svr.cfg.DisableCustomTLSFirstByte), + Priority: 1, // should be executed before TLS AfterHook but after the rest of the AfterHooks (except for wss) + Hook: frpNet.DialHookCustomTLSHeadByte(tlsConfig != nil, svr.cfg.DisableCustomTLSFirstByte), }), ) + if websocketAfterHook != nil { + // websocketAfterHook must be appended after TLS AfterHook because they both might have the + // same priority of math.MaxUint64 in case of wss but TLS AfterHook must be executed first + dialOptions = append(dialOptions, libdial.WithAfterHook(*websocketAfterHook)) + } + conn, err = libdial.Dial( net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort)), dialOptions..., diff --git a/cmd/frpc/sub/root.go b/cmd/frpc/sub/root.go index f8d7eb17708..3bde61fd9f4 100644 --- a/cmd/frpc/sub/root.go +++ b/cmd/frpc/sub/root.go @@ -85,7 +85,7 @@ func init() { func RegisterCommonFlags(cmd *cobra.Command) { cmd.PersistentFlags().StringVarP(&serverAddr, "server_addr", "s", "127.0.0.1:7000", "frp server's address") cmd.PersistentFlags().StringVarP(&user, "user", "u", "", "user") - cmd.PersistentFlags().StringVarP(&protocol, "protocol", "p", "tcp", "tcp or kcp or websocket") + cmd.PersistentFlags().StringVarP(&protocol, "protocol", "p", "tcp", "tcp or kcp or websocket or wss") cmd.PersistentFlags().StringVarP(&token, "token", "t", "", "auth token") cmd.PersistentFlags().StringVarP(&logLevel, "log_level", "", "info", "log level") cmd.PersistentFlags().StringVarP(&logFile, "log_file", "", "console", "console or file path") diff --git a/conf/frpc_full.ini b/conf/frpc_full.ini index 9c939678b1f..94c7ab4a2f1 100644 --- a/conf/frpc_full.ini +++ b/conf/frpc_full.ini @@ -87,7 +87,7 @@ user = your_name login_fail_exit = true # communication protocol used to connect to server -# now it supports tcp, kcp and websocket, default is tcp +# now it supports tcp, kcp, websocket and wss, default is tcp protocol = tcp # set client binding ip when connect server, default is empty. diff --git a/dockerfiles/Dockerfile-for-frpc b/dockerfiles/Dockerfile-for-frpc index ece950f2190..e8bfa478706 100644 --- a/dockerfiles/Dockerfile-for-frpc +++ b/dockerfiles/Dockerfile-for-frpc @@ -9,4 +9,4 @@ FROM alpine:3 COPY --from=building /building/bin/frpc /usr/bin/frpc -ENTRYPOINT ["/usr/bin/frpc"] +ENTRYPOINT /usr/bin/frpc -c /etc/frp/frpc.ini diff --git a/dockerfiles/Dockerfile-for-frps b/dockerfiles/Dockerfile-for-frps index d8ab8247a3a..65b8c92e344 100644 --- a/dockerfiles/Dockerfile-for-frps +++ b/dockerfiles/Dockerfile-for-frps @@ -9,4 +9,4 @@ FROM alpine:3 COPY --from=building /building/bin/frps /usr/bin/frps -ENTRYPOINT ["/usr/bin/frps"] +ENTRYPOINT /usr/bin/frps -c /etc/frp/frps.ini diff --git a/pkg/config/client.go b/pkg/config/client.go index f503711b07a..f3c2e41f88c 100644 --- a/pkg/config/client.go +++ b/pkg/config/client.go @@ -215,6 +215,10 @@ func (cfg *ClientCommonConf) Validate() error { } if cfg.TLSEnable == false { + if cfg.Protocol == "wss" { + return fmt.Errorf("tls_enable must be true for wss support") + } + if cfg.TLSCertFile != "" { fmt.Println("WARNING! tls_cert_file is invalid when tls_enable is false") } @@ -228,7 +232,7 @@ func (cfg *ClientCommonConf) Validate() error { } } - if cfg.Protocol != "tcp" && cfg.Protocol != "kcp" && cfg.Protocol != "websocket" { + if cfg.Protocol != "tcp" && cfg.Protocol != "kcp" && cfg.Protocol != "websocket" && cfg.Protocol != "wss" { return fmt.Errorf("invalid protocol") } diff --git a/pkg/util/net/dial.go b/pkg/util/net/dial.go index 251ebbff7ae..4f9bae3242e 100644 --- a/pkg/util/net/dial.go +++ b/pkg/util/net/dial.go @@ -21,15 +21,21 @@ func DialHookCustomTLSHeadByte(enableTLS bool, disableCustomTLSHeadByte bool) li } } -func DialHookWebsocket() libdial.AfterHookFunc { +func DialHookWebsocket(isSecure bool) libdial.AfterHookFunc { return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) { - addr = "ws://" + addr + FrpWebsocketPath + addrScheme := "ws" + originScheme := "http" + if isSecure { + addrScheme = "wss" + originScheme = "https" + } + addr = addrScheme + "://" + addr + FrpWebsocketPath uri, err := url.Parse(addr) if err != nil { return nil, nil, err } - origin := "http://" + uri.Host + origin := originScheme + "://" + uri.Host cfg, err := websocket.NewConfig(addr, origin) if err != nil { return nil, nil, err