From b9982ae00fd23a55382d56d1c23c1ca5455a61de Mon Sep 17 00:00:00 2001 From: Musixal Date: Fri, 18 Oct 2024 16:58:31 +0330 Subject: [PATCH] EXPERIMENTAL: Bidirectional heartbeat over control channel --- internal/server/transport/ws.go | 50 ++++++++++++++++++------------ internal/server/transport/wsmux.go | 44 +++++++++++++++++--------- 2 files changed, 60 insertions(+), 34 deletions(-) diff --git a/internal/server/transport/ws.go b/internal/server/transport/ws.go index a8c96f0..cbc7c97 100644 --- a/internal/server/transport/ws.go +++ b/internal/server/transport/ws.go @@ -1,7 +1,6 @@ package transport import ( - "bytes" "context" "fmt" "net" @@ -92,6 +91,10 @@ func (s *WsTransport) Restart() { defer s.restartMutex.Unlock() s.logger.Info("restarting server...") + + level := s.logger.Level + s.logger.SetLevel(logrus.FatalLevel) + if s.cancel != nil { s.cancel() } @@ -115,8 +118,10 @@ func (s *WsTransport) Restart() { s.usageMonitor = web.NewDataStore(fmt.Sprintf(":%v", s.config.WebPort), ctx, s.config.SnifferLog, s.config.Sniffer, &s.config.TunnelStatus, s.logger) s.config.TunnelStatus = "" - go s.Start() + // set the log level again + s.logger.SetLevel(level) + go s.Start() } func (s *WsTransport) channelHandler() { @@ -124,7 +129,7 @@ func (s *WsTransport) channelHandler() { defer ticker.Stop() // Channel to receive the message or error - messageChan := make(chan []byte, 1) + messageChan := make(chan byte, 10) // Separate goroutine to continuously listen for messages go func() { @@ -134,15 +139,16 @@ func (s *WsTransport) channelHandler() { return default: - _, message, err := s.controlChannel.ReadMessage() + _, msg, err := s.controlChannel.ReadMessage() // Exit if there's an error if err != nil { - s.logger.Error("failed to read from channel connection. ", err) - close(messageChan) // Closing the channel to signal completion - go s.Restart() + if s.cancel != nil { + s.logger.Error("failed to read from channel connection. ", err) + go s.Restart() + } return } - messageChan <- message + messageChan <- msg[0] } } }() @@ -169,18 +175,26 @@ func (s *WsTransport) channelHandler() { } s.logger.Debug("heartbeat signal sent successfully") - case message, ok := <-messageChan: + case msg, ok := <-messageChan: if !ok { s.logger.Error("channel closed, likely due to an error in WebSocket read") return } + switch msg { + case utils.SG_HB: + s.logger.Trace("heartbeat signal received successfully") - // Handle specific control channel messages - if bytes.Equal(message, []byte{utils.SG_Closed}) { + case utils.SG_Closed: s.logger.Warn("control channel has been closed by the client") s.Restart() return + + default: + s.logger.Errorf("unexpected response from channel: %v", msg) + go s.Restart() + return } + } } } @@ -188,8 +202,9 @@ func (s *WsTransport) channelHandler() { func (s *WsTransport) tunnelListener() { addr := s.config.BindAddr upgrader := websocket.Upgrader{ - ReadBufferSize: 16 * 1024, - WriteBufferSize: 16 * 1024, + ReadBufferSize: 16 * 1024, + WriteBufferSize: 16 * 1024, + HandshakeTimeout: 45 * time.Second, CheckOrigin: func(r *http.Request) bool { return true }, @@ -198,7 +213,7 @@ func (s *WsTransport) tunnelListener() { // Create an HTTP server server := &http.Server{ Addr: addr, - IdleTimeout: 600 * time.Second, // HERE + IdleTimeout: -1, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.logger.Tracef("received http request from %s", r.RemoteAddr) @@ -245,7 +260,6 @@ func (s *WsTransport) tunnelListener() { s.config.TunnelStatus = fmt.Sprintf("Connected (%s)", s.config.Mode) - return } else if r.URL.Path == "/tunnel" { wsConn := TunnelChannel{ conn: conn, @@ -300,8 +314,6 @@ func (s *WsTransport) tunnelListener() { } - - func (s *WsTransport) parsePortMappings() { for _, portMapping := range s.config.Ports { parts := strings.Split(portMapping, "=") @@ -335,7 +347,7 @@ func (s *WsTransport) parsePortMappings() { for port := startPort; port <= endPort; port++ { localAddr = fmt.Sprintf(":%d", port) go s.localListener(localAddr, strconv.Itoa(port)) // Use port as the remoteAddr - time.Sleep(1 * time.Millisecond) // for wide port ranges + time.Sleep(1 * time.Millisecond) // for wide port ranges } continue } else { @@ -393,8 +405,6 @@ func (s *WsTransport) parsePortMappings() { } } - - func (s *WsTransport) localListener(localAddr string, remoteAddr string) { portListener, err := net.Listen("tcp", localAddr) if err != nil { diff --git a/internal/server/transport/wsmux.go b/internal/server/transport/wsmux.go index 31becf4..a011477 100644 --- a/internal/server/transport/wsmux.go +++ b/internal/server/transport/wsmux.go @@ -1,7 +1,6 @@ package transport import ( - "bytes" "context" "fmt" "net" @@ -108,6 +107,11 @@ func (s *WsMuxTransport) Restart() { defer s.restartMutex.Unlock() s.logger.Info("restarting server...") + + // for removing timeout logs + level := s.logger.Level + s.logger.SetLevel(logrus.FatalLevel) + if s.cancel != nil { s.cancel() } @@ -131,8 +135,10 @@ func (s *WsMuxTransport) Restart() { s.usageMonitor = web.NewDataStore(fmt.Sprintf(":%v", s.config.WebPort), ctx, s.config.SnifferLog, s.config.Sniffer, &s.config.TunnelStatus, s.logger) s.config.TunnelStatus = "" - go s.Start() + // set the log level again + s.logger.SetLevel(level) + go s.Start() } func (s *WsMuxTransport) channelHandler() { @@ -140,7 +146,7 @@ func (s *WsMuxTransport) channelHandler() { defer ticker.Stop() // Channel to receive the message or error - messageChan := make(chan []byte, 1) + messageChan := make(chan byte, 10) // Separate goroutine to continuously listen for messages go func() { @@ -150,15 +156,16 @@ func (s *WsMuxTransport) channelHandler() { return default: - _, message, err := s.controlChannel.ReadMessage() + _, msg, err := s.controlChannel.ReadMessage() // Exit if there's an error if err != nil { - s.logger.Error("failed to read from channel connection. ", err) - close(messageChan) // Closing the channel to signal completion - go s.Restart() + if s.cancel != nil { + s.logger.Error("failed to read from channel connection. ", err) + go s.Restart() + } return } - messageChan <- message + messageChan <- msg[0] } } }() @@ -185,18 +192,26 @@ func (s *WsMuxTransport) channelHandler() { } s.logger.Debug("heartbeat signal sent successfully") - case message, ok := <-messageChan: + case msg, ok := <-messageChan: if !ok { s.logger.Error("channel closed, likely due to an error in WebSocket read") return } + switch msg { + case utils.SG_HB: + s.logger.Trace("heartbeat signal received successfully") - // Handle specific control channel messages - if bytes.Equal(message, []byte{utils.SG_Closed}) { + case utils.SG_Closed: s.logger.Warn("control channel has been closed by the client") s.Restart() return + + default: + s.logger.Errorf("unexpected response from channel: %v", msg) + go s.Restart() + return } + } } } @@ -204,8 +219,9 @@ func (s *WsMuxTransport) channelHandler() { func (s *WsMuxTransport) tunnelListener() { addr := s.config.BindAddr upgrader := websocket.Upgrader{ - ReadBufferSize: 16 * 1024, - WriteBufferSize: 16 * 1024, + ReadBufferSize: 16 * 1024, + WriteBufferSize: 16 * 1024, + HandshakeTimeout: 45 * time.Second, CheckOrigin: func(r *http.Request) bool { return true }, @@ -214,7 +230,7 @@ func (s *WsMuxTransport) tunnelListener() { // Create an HTTP server server := &http.Server{ Addr: addr, - IdleTimeout: 600 * time.Second, + IdleTimeout: -1, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.logger.Tracef("received http request from %s", r.RemoteAddr)