Skip to content

Commit

Permalink
Fix bugs relaed to ws/wsmux channel handler
Browse files Browse the repository at this point in the history
  • Loading branch information
Musixal committed Oct 13, 2024
1 parent 97f897c commit 6edfad7
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 63 deletions.
27 changes: 15 additions & 12 deletions internal/client/transport/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,16 @@ func (c *WsTransport) Restart() {
defer c.restartMutex.Unlock()

c.logger.Info("restarting client...")

if c.cancel != nil {
c.cancel()
}

// close control channel connection
if c.controlChannel != nil {
c.controlChannel.Close()
}

time.Sleep(2 * time.Second)

ctx, cancel := context.WithCancel(c.parentctx)
Expand Down Expand Up @@ -133,8 +139,6 @@ func (c *WsTransport) channelDialer() {
}
}



func (c *WsTransport) poolMaintainer() {
for i := 0; i < c.config.ConnPoolSize; i++ { //initial pool filling
go c.tunnelDialer()
Expand Down Expand Up @@ -188,7 +192,6 @@ func (c *WsTransport) poolMaintainer() {

func (c *WsTransport) channelHandler() {
msgChan := make(chan byte, 1000)
errChan := make(chan error, 1000)

// Goroutine to handle the blocking ReceiveBinaryString
go func() {
Expand All @@ -200,7 +203,8 @@ func (c *WsTransport) channelHandler() {
default:
_, msg, err := c.controlChannel.ReadMessage()
if err != nil {
errChan <- err
c.logger.Error("failed to read from channel connection. ", err)
go c.Restart()
return
}

Expand All @@ -215,6 +219,7 @@ func (c *WsTransport) channelHandler() {
case <-c.ctx.Done():
_ = c.controlChannel.WriteMessage(websocket.BinaryMessage, []byte{utils.SG_Closed})
return

case msg := <-msgChan:
switch msg {
case utils.SG_Chan:
Expand All @@ -226,22 +231,20 @@ func (c *WsTransport) channelHandler() {
c.logger.Debug("channel signal received, initiating tunnel dialer")
go c.tunnelDialer()
}

case utils.SG_HB:
c.logger.Debug("heartbeat signal received successfully")

case utils.SG_Closed:
c.logger.Info("control channel has been closed by the server")
go c.Restart()
return
case utils.SG_HB:
c.logger.Debug("heartbeat signal received successfully")

default:
c.logger.Errorf("unexpected response from channel: %v. Restarting client...", msg)
c.logger.Errorf("unexpected response from channel: %v", msg)
go c.Restart()
return
}
case err := <-errChan:
// Handle errors from the control channel
c.logger.Error("failed to read channel signal, restarting client: ", err)
go c.Restart()
return
}
}
}
Expand Down
20 changes: 12 additions & 8 deletions internal/client/transport/wsmux.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,16 @@ func (c *WsMuxTransport) Restart() {
defer c.restartMutex.Unlock()

c.logger.Info("restarting client...")

if c.cancel != nil {
c.cancel()
}

// close control channel connection
if c.controlChannel != nil {
c.controlChannel.Close()
}

time.Sleep(2 * time.Second)

ctx, cancel := context.WithCancel(c.parentctx)
Expand Down Expand Up @@ -199,7 +205,6 @@ func (c *WsMuxTransport) poolMaintainer() {

func (c *WsMuxTransport) channelHandler() {
msgChan := make(chan byte, 1000)
errChan := make(chan error, 1000)

// Goroutine to handle the blocking ReceiveBinaryString
go func() {
Expand All @@ -211,7 +216,8 @@ func (c *WsMuxTransport) channelHandler() {
default:
_, msg, err := c.controlChannel.ReadMessage()
if err != nil {
errChan <- err
c.logger.Error("failed to read from channel connection. ", err)
go c.Restart()
return
}
msgChan <- msg[0]
Expand All @@ -236,22 +242,20 @@ func (c *WsMuxTransport) channelHandler() {
c.logger.Debug("channel signal received, initiating tunnel dialer")
go c.tunnelDialer()
}

case utils.SG_HB:
c.logger.Debug("heartbeat received successfully")

case utils.SG_Closed:
c.logger.Info("control channel has been closed by the server")
go c.Restart()
return

default:
c.logger.Errorf("unexpected response from control channel: %v. Restarting client...", msg)
c.logger.Errorf("unexpected response from control channel: %v", msg)
go c.Restart()
return

}
case err := <-errChan:
c.logger.Error("failed to read channel signal, restarting client: ", err)
go c.Restart()
return

}
}
Expand Down
58 changes: 39 additions & 19 deletions internal/server/transport/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (s *WsTransport) Restart() {
s.cancel()
}

// Close open connection
// Close control channel connection
if s.controlChannel != nil {
s.controlChannel.Close()
}
Expand Down Expand Up @@ -124,16 +124,27 @@ func (s *WsTransport) channelHandler() {
defer ticker.Stop()

// Channel to receive the message or error
resultChan := make(chan struct {
message []byte
err error
})
messageChan := make(chan []byte, 1)

// Separate goroutine to continuously listen for messages
go func() {
_, message, err := s.controlChannel.ReadMessage()
resultChan <- struct {
message []byte
err error
}{message, err}
for {
select {
case <-s.ctx.Done():
return

default:
_, message, err := s.controlChannel.ReadMessage()
// Exit if there's an error
if err != nil {
s.logger.Error("failed to read from channel connection. ", err)
close(messageChan) // Closing the channel to signal completion
go s.Restart()
return
}
messageChan <- message
}
}
}()

for {
Expand All @@ -144,29 +155,30 @@ func (s *WsTransport) channelHandler() {
case <-s.reqNewConnChan:
err := s.controlChannel.WriteMessage(websocket.BinaryMessage, []byte{utils.SG_Chan})
if err != nil {
s.logger.Error("error sending channel signal, attempting to restart server...")
s.logger.Error("failed to send request new connection signal. ", err)
go s.Restart()
return
}

case <-ticker.C:
err := s.controlChannel.WriteMessage(websocket.BinaryMessage, []byte{utils.SG_HB})
if err != nil {
s.logger.Errorf("failed to send heartbeat signal. Error: %v. Restarting server...", err)
s.logger.Errorf("failed to send heartbeat signal. Error: %v.", err)
go s.Restart()
return
}
s.logger.Debug("heartbeat signal sent successfully")

case result := <-resultChan:
if result.err != nil {
s.logger.Errorf("failed to receive message from channel connection: %v", result.err)
go s.Restart()
case message, ok := <-messageChan:
if !ok {
s.logger.Error("channel closed, likely due to an error in WebSocket read")
return
}
if bytes.Equal(result.message, []byte{utils.SG_Closed}) {

// Handle specific control channel messages
if bytes.Equal(message, []byte{utils.SG_Closed}) {
s.logger.Info("control channel has been closed by the client")
go s.Restart()
s.Restart()
return
}
}
Expand Down Expand Up @@ -204,7 +216,15 @@ func (s *WsTransport) tunnelListener() {
return
}

if r.URL.Path == "/channel" && s.controlChannel == nil {
if r.URL.Path == "/channel" {
if s.controlChannel != nil {
s.logger.Info("new control channel requested. restarting server...")
s.controlChannel.Close()
conn.Close()
go s.Restart()
return
}

s.controlChannel = conn

s.logger.Info("control channel established successfully")
Expand Down
63 changes: 39 additions & 24 deletions internal/server/transport/wsmux.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (s *WsMuxTransport) Restart() {
s.cancel()
}

// Close tunnel channel connection
// Close control channel connection
if s.controlChannel != nil {
s.controlChannel.Close()
}
Expand Down Expand Up @@ -140,16 +140,27 @@ func (s *WsMuxTransport) channelHandler() {
defer ticker.Stop()

// Channel to receive the message or error
resultChan := make(chan struct {
message []byte
err error
})
messageChan := make(chan []byte, 1)

// Separate goroutine to continuously listen for messages
go func() {
_, message, err := s.controlChannel.ReadMessage()
resultChan <- struct {
message []byte
err error
}{message, err}
for {
select {
case <-s.ctx.Done():
return

default:
_, message, err := s.controlChannel.ReadMessage()
// Exit if there's an error
if err != nil {
s.logger.Error("failed to read from channel connection. ", err)
close(messageChan) // Closing the channel to signal completion
go s.Restart()
return
}
messageChan <- message
}
}
}()

for {
Expand All @@ -160,34 +171,30 @@ func (s *WsMuxTransport) channelHandler() {
case <-s.reqNewConnChan:
err := s.controlChannel.WriteMessage(websocket.BinaryMessage, []byte{utils.SG_Chan})
if err != nil {
s.logger.Error("error sending channel signal, attempting to restart server...")
s.logger.Error("failed to send request new connection signal. ", err)
go s.Restart()
return
}

case <-ticker.C:
if s.controlChannel == nil {
s.logger.Warn("control channel is nil. Restarting server to re-establish connection...")
go s.Restart()
return
}
err := s.controlChannel.WriteMessage(websocket.BinaryMessage, []byte{utils.SG_HB})
if err != nil {
s.logger.Errorf("Failed to send heartbeat signal. Error: %v. Restarting server...", err)
s.logger.Errorf("failed to send heartbeat signal. Error: %v.", err)
go s.Restart()
return
}
s.logger.Debug("heartbeat signal sent successfully")

case result := <-resultChan:
if result.err != nil {
s.logger.Errorf("failed to receive message from channel connection: %v", result.err)
go s.Restart()
case message, ok := <-messageChan:
if !ok {
s.logger.Error("channel closed, likely due to an error in WebSocket read")
return
}
if bytes.Equal(result.message, []byte{utils.SG_Closed}) {

// Handle specific control channel messages
if bytes.Equal(message, []byte{utils.SG_Closed}) {
s.logger.Info("control channel has been closed by the client")
go s.Restart()
s.Restart()
return
}
}
Expand Down Expand Up @@ -225,7 +232,15 @@ func (s *WsMuxTransport) tunnelListener() {
return
}

if r.URL.Path == "/channel" && s.controlChannel == nil {
if r.URL.Path == "/channel" {
if s.controlChannel != nil {
s.logger.Info("new control channel requested. restarting server...")
s.controlChannel.Close()
conn.Close()
go s.Restart()
return
}

s.controlChannel = conn

s.logger.Info("control channel established successfully")
Expand Down

0 comments on commit 6edfad7

Please sign in to comment.