diff --git a/internal/core/client.go b/internal/core/client.go index d0e7215..982a9d4 100644 --- a/internal/core/client.go +++ b/internal/core/client.go @@ -190,7 +190,15 @@ func (client *Client) dialWSS() (net.Conn, error) { } func (client *Client) dialTLS() (net.Conn, error) { - return tls.DialWithDialer(client.netDialer, "tcp", client.conf.RemoteAddr, client.tlsConf) + conn, err := tls.DialWithDialer(client.netDialer, "tcp", client.conf.RemoteAddr, client.tlsConf) + if err != nil { + return nil, err + } + if err := conn.Handshake(); err != nil { + conn.Close() + return nil, err + } + return conn, nil } func (client *Client) newServerConn() (net.Conn, error) { diff --git a/internal/core/server.go b/internal/core/server.go index 5a7c9ab..8ef491c 100644 --- a/internal/core/server.go +++ b/internal/core/server.go @@ -126,12 +126,11 @@ func NewServer(c *ServerConfig) (*Server, error) { func (server *Server) Start() error { listenConfig := net.ListenConfig{Control: getControlFunc(server.tcpConfig)} - innerListener, err := listenConfig.Listen(context.Background(), "tcp", server.conf.BindAddr) + listener, err := listenConfig.Listen(context.Background(), "tcp", server.conf.BindAddr) if err != nil { return fmt.Errorf("tls inner Listener: %v", err) } - listener := tls.NewListener(innerListener, server.tlsConf) server.listenerLocker.Lock() server.listener = listener server.listenerLocker.Unlock() @@ -146,17 +145,28 @@ func (server *Server) Start() error { } } else { for { - leftConn, err := listener.Accept() + leftRawConn, err := listener.Accept() if err != nil { return fmt.Errorf("listener.Accept: %v", err) } - server.log.Debugf("leftConn from %s accepted", leftConn.RemoteAddr()) - if server.conf.EnableMux { - server.handleClientMuxConn(leftConn) - } else { - server.handleClientConn(leftConn) - } + go func() { + server.log.Debugf("leftConn from %s accepted", leftRawConn.RemoteAddr()) + + leftConn := tls.Server(leftRawConn, server.tlsConf) + defer leftConn.Close() + if err := leftConn.Handshake(); err != nil { + server.log.Errorf("leftConn tls handshake: %v", err) + return + } + + if server.conf.EnableMux { + server.handleClientMuxConn(leftConn) + } else { + server.handleClientConn(leftConn) + } + }() + } } return nil