From 6edfad7d2646daeaceff5757ce45ddd0078e1238 Mon Sep 17 00:00:00 2001 From: Musixal Date: Mon, 14 Oct 2024 01:03:52 +0330 Subject: [PATCH] Fix bugs relaed to ws/wsmux channel handler --- internal/client/transport/ws.go | 27 +++++++------ internal/client/transport/wsmux.go | 20 ++++++---- internal/server/transport/ws.go | 58 ++++++++++++++++++--------- internal/server/transport/wsmux.go | 63 ++++++++++++++++++------------ 4 files changed, 105 insertions(+), 63 deletions(-) diff --git a/internal/client/transport/ws.go b/internal/client/transport/ws.go index b135b57..6b8d8a5 100644 --- a/internal/client/transport/ws.go +++ b/internal/client/transport/ws.go @@ -84,10 +84,16 @@ func (c *WsTransport) Restart() { defer c.restartMutex.Unlock() c.logger.Info("restarting client...") + if c.cancel != nil { c.cancel() } + // close control channel connection + if c.controlChannel != nil { + c.controlChannel.Close() + } + time.Sleep(2 * time.Second) ctx, cancel := context.WithCancel(c.parentctx) @@ -133,8 +139,6 @@ func (c *WsTransport) channelDialer() { } } - - func (c *WsTransport) poolMaintainer() { for i := 0; i < c.config.ConnPoolSize; i++ { //initial pool filling go c.tunnelDialer() @@ -188,7 +192,6 @@ func (c *WsTransport) poolMaintainer() { func (c *WsTransport) channelHandler() { msgChan := make(chan byte, 1000) - errChan := make(chan error, 1000) // Goroutine to handle the blocking ReceiveBinaryString go func() { @@ -200,7 +203,8 @@ func (c *WsTransport) channelHandler() { default: _, msg, err := c.controlChannel.ReadMessage() if err != nil { - errChan <- err + c.logger.Error("failed to read from channel connection. ", err) + go c.Restart() return } @@ -215,6 +219,7 @@ func (c *WsTransport) channelHandler() { case <-c.ctx.Done(): _ = c.controlChannel.WriteMessage(websocket.BinaryMessage, []byte{utils.SG_Closed}) return + case msg := <-msgChan: switch msg { case utils.SG_Chan: @@ -226,22 +231,20 @@ func (c *WsTransport) channelHandler() { c.logger.Debug("channel signal received, initiating tunnel dialer") go c.tunnelDialer() } + + case utils.SG_HB: + c.logger.Debug("heartbeat signal received successfully") + case utils.SG_Closed: c.logger.Info("control channel has been closed by the server") go c.Restart() return - case utils.SG_HB: - c.logger.Debug("heartbeat signal received successfully") + default: - c.logger.Errorf("unexpected response from channel: %v. Restarting client...", msg) + c.logger.Errorf("unexpected response from channel: %v", msg) go c.Restart() return } - case err := <-errChan: - // Handle errors from the control channel - c.logger.Error("failed to read channel signal, restarting client: ", err) - go c.Restart() - return } } } diff --git a/internal/client/transport/wsmux.go b/internal/client/transport/wsmux.go index 37cd7aa..6433ca4 100644 --- a/internal/client/transport/wsmux.go +++ b/internal/client/transport/wsmux.go @@ -96,10 +96,16 @@ func (c *WsMuxTransport) Restart() { defer c.restartMutex.Unlock() c.logger.Info("restarting client...") + if c.cancel != nil { c.cancel() } + // close control channel connection + if c.controlChannel != nil { + c.controlChannel.Close() + } + time.Sleep(2 * time.Second) ctx, cancel := context.WithCancel(c.parentctx) @@ -199,7 +205,6 @@ func (c *WsMuxTransport) poolMaintainer() { func (c *WsMuxTransport) channelHandler() { msgChan := make(chan byte, 1000) - errChan := make(chan error, 1000) // Goroutine to handle the blocking ReceiveBinaryString go func() { @@ -211,7 +216,8 @@ func (c *WsMuxTransport) channelHandler() { default: _, msg, err := c.controlChannel.ReadMessage() if err != nil { - errChan <- err + c.logger.Error("failed to read from channel connection. ", err) + go c.Restart() return } msgChan <- msg[0] @@ -236,22 +242,20 @@ func (c *WsMuxTransport) channelHandler() { c.logger.Debug("channel signal received, initiating tunnel dialer") go c.tunnelDialer() } + case utils.SG_HB: c.logger.Debug("heartbeat received successfully") + case utils.SG_Closed: c.logger.Info("control channel has been closed by the server") go c.Restart() return + default: - c.logger.Errorf("unexpected response from control channel: %v. Restarting client...", msg) + c.logger.Errorf("unexpected response from control channel: %v", msg) go c.Restart() return - } - case err := <-errChan: - c.logger.Error("failed to read channel signal, restarting client: ", err) - go c.Restart() - return } } diff --git a/internal/server/transport/ws.go b/internal/server/transport/ws.go index 59a3c12..d39be81 100644 --- a/internal/server/transport/ws.go +++ b/internal/server/transport/ws.go @@ -96,7 +96,7 @@ func (s *WsTransport) Restart() { s.cancel() } - // Close open connection + // Close control channel connection if s.controlChannel != nil { s.controlChannel.Close() } @@ -124,16 +124,27 @@ func (s *WsTransport) channelHandler() { defer ticker.Stop() // Channel to receive the message or error - resultChan := make(chan struct { - message []byte - err error - }) + messageChan := make(chan []byte, 1) + + // Separate goroutine to continuously listen for messages go func() { - _, message, err := s.controlChannel.ReadMessage() - resultChan <- struct { - message []byte - err error - }{message, err} + for { + select { + case <-s.ctx.Done(): + return + + default: + _, message, err := s.controlChannel.ReadMessage() + // Exit if there's an error + if err != nil { + s.logger.Error("failed to read from channel connection. ", err) + close(messageChan) // Closing the channel to signal completion + go s.Restart() + return + } + messageChan <- message + } + } }() for { @@ -144,7 +155,7 @@ func (s *WsTransport) channelHandler() { case <-s.reqNewConnChan: err := s.controlChannel.WriteMessage(websocket.BinaryMessage, []byte{utils.SG_Chan}) if err != nil { - s.logger.Error("error sending channel signal, attempting to restart server...") + s.logger.Error("failed to send request new connection signal. ", err) go s.Restart() return } @@ -152,21 +163,22 @@ func (s *WsTransport) channelHandler() { case <-ticker.C: err := s.controlChannel.WriteMessage(websocket.BinaryMessage, []byte{utils.SG_HB}) if err != nil { - s.logger.Errorf("failed to send heartbeat signal. Error: %v. Restarting server...", err) + s.logger.Errorf("failed to send heartbeat signal. Error: %v.", err) go s.Restart() return } s.logger.Debug("heartbeat signal sent successfully") - case result := <-resultChan: - if result.err != nil { - s.logger.Errorf("failed to receive message from channel connection: %v", result.err) - go s.Restart() + case message, ok := <-messageChan: + if !ok { + s.logger.Error("channel closed, likely due to an error in WebSocket read") return } - if bytes.Equal(result.message, []byte{utils.SG_Closed}) { + + // Handle specific control channel messages + if bytes.Equal(message, []byte{utils.SG_Closed}) { s.logger.Info("control channel has been closed by the client") - go s.Restart() + s.Restart() return } } @@ -204,7 +216,15 @@ func (s *WsTransport) tunnelListener() { return } - if r.URL.Path == "/channel" && s.controlChannel == nil { + if r.URL.Path == "/channel" { + if s.controlChannel != nil { + s.logger.Info("new control channel requested. restarting server...") + s.controlChannel.Close() + conn.Close() + go s.Restart() + return + } + s.controlChannel = conn s.logger.Info("control channel established successfully") diff --git a/internal/server/transport/wsmux.go b/internal/server/transport/wsmux.go index c7164ce..b21a60e 100644 --- a/internal/server/transport/wsmux.go +++ b/internal/server/transport/wsmux.go @@ -112,7 +112,7 @@ func (s *WsMuxTransport) Restart() { s.cancel() } - // Close tunnel channel connection + // Close control channel connection if s.controlChannel != nil { s.controlChannel.Close() } @@ -140,16 +140,27 @@ func (s *WsMuxTransport) channelHandler() { defer ticker.Stop() // Channel to receive the message or error - resultChan := make(chan struct { - message []byte - err error - }) + messageChan := make(chan []byte, 1) + + // Separate goroutine to continuously listen for messages go func() { - _, message, err := s.controlChannel.ReadMessage() - resultChan <- struct { - message []byte - err error - }{message, err} + for { + select { + case <-s.ctx.Done(): + return + + default: + _, message, err := s.controlChannel.ReadMessage() + // Exit if there's an error + if err != nil { + s.logger.Error("failed to read from channel connection. ", err) + close(messageChan) // Closing the channel to signal completion + go s.Restart() + return + } + messageChan <- message + } + } }() for { @@ -160,34 +171,30 @@ func (s *WsMuxTransport) channelHandler() { case <-s.reqNewConnChan: err := s.controlChannel.WriteMessage(websocket.BinaryMessage, []byte{utils.SG_Chan}) if err != nil { - s.logger.Error("error sending channel signal, attempting to restart server...") + s.logger.Error("failed to send request new connection signal. ", err) go s.Restart() return } case <-ticker.C: - if s.controlChannel == nil { - s.logger.Warn("control channel is nil. Restarting server to re-establish connection...") - go s.Restart() - return - } err := s.controlChannel.WriteMessage(websocket.BinaryMessage, []byte{utils.SG_HB}) if err != nil { - s.logger.Errorf("Failed to send heartbeat signal. Error: %v. Restarting server...", err) + s.logger.Errorf("failed to send heartbeat signal. Error: %v.", err) go s.Restart() return } s.logger.Debug("heartbeat signal sent successfully") - case result := <-resultChan: - if result.err != nil { - s.logger.Errorf("failed to receive message from channel connection: %v", result.err) - go s.Restart() + case message, ok := <-messageChan: + if !ok { + s.logger.Error("channel closed, likely due to an error in WebSocket read") return } - if bytes.Equal(result.message, []byte{utils.SG_Closed}) { + + // Handle specific control channel messages + if bytes.Equal(message, []byte{utils.SG_Closed}) { s.logger.Info("control channel has been closed by the client") - go s.Restart() + s.Restart() return } } @@ -225,7 +232,15 @@ func (s *WsMuxTransport) tunnelListener() { return } - if r.URL.Path == "/channel" && s.controlChannel == nil { + if r.URL.Path == "/channel" { + if s.controlChannel != nil { + s.logger.Info("new control channel requested. restarting server...") + s.controlChannel.Close() + conn.Close() + go s.Restart() + return + } + s.controlChannel = conn s.logger.Info("control channel established successfully")