Skip to content

Commit

Permalink
EXPERIMENTAL: Bidirectional heartbeat over control channel
Browse files Browse the repository at this point in the history
  • Loading branch information
Musixal committed Oct 18, 2024
1 parent 65c7260 commit b9982ae
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 34 deletions.
50 changes: 30 additions & 20 deletions internal/server/transport/ws.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package transport

import (
"bytes"
"context"
"fmt"
"net"
Expand Down Expand Up @@ -92,6 +91,10 @@ func (s *WsTransport) Restart() {
defer s.restartMutex.Unlock()

s.logger.Info("restarting server...")

level := s.logger.Level
s.logger.SetLevel(logrus.FatalLevel)

if s.cancel != nil {
s.cancel()
}
Expand All @@ -115,16 +118,18 @@ func (s *WsTransport) Restart() {
s.usageMonitor = web.NewDataStore(fmt.Sprintf(":%v", s.config.WebPort), ctx, s.config.SnifferLog, s.config.Sniffer, &s.config.TunnelStatus, s.logger)
s.config.TunnelStatus = ""

go s.Start()
// set the log level again
s.logger.SetLevel(level)

go s.Start()
}

func (s *WsTransport) channelHandler() {
ticker := time.NewTicker(s.config.Heartbeat)
defer ticker.Stop()

// Channel to receive the message or error
messageChan := make(chan []byte, 1)
messageChan := make(chan byte, 10)

// Separate goroutine to continuously listen for messages
go func() {
Expand All @@ -134,15 +139,16 @@ func (s *WsTransport) channelHandler() {
return

default:
_, message, err := s.controlChannel.ReadMessage()
_, msg, err := s.controlChannel.ReadMessage()
// Exit if there's an error
if err != nil {
s.logger.Error("failed to read from channel connection. ", err)
close(messageChan) // Closing the channel to signal completion
go s.Restart()
if s.cancel != nil {
s.logger.Error("failed to read from channel connection. ", err)
go s.Restart()
}
return
}
messageChan <- message
messageChan <- msg[0]
}
}
}()
Expand All @@ -169,27 +175,36 @@ func (s *WsTransport) channelHandler() {
}
s.logger.Debug("heartbeat signal sent successfully")

case message, ok := <-messageChan:
case msg, ok := <-messageChan:
if !ok {
s.logger.Error("channel closed, likely due to an error in WebSocket read")
return
}
switch msg {
case utils.SG_HB:
s.logger.Trace("heartbeat signal received successfully")

// Handle specific control channel messages
if bytes.Equal(message, []byte{utils.SG_Closed}) {
case utils.SG_Closed:
s.logger.Warn("control channel has been closed by the client")
s.Restart()
return

default:
s.logger.Errorf("unexpected response from channel: %v", msg)
go s.Restart()
return
}

}
}
}

func (s *WsTransport) tunnelListener() {
addr := s.config.BindAddr
upgrader := websocket.Upgrader{
ReadBufferSize: 16 * 1024,
WriteBufferSize: 16 * 1024,
ReadBufferSize: 16 * 1024,
WriteBufferSize: 16 * 1024,
HandshakeTimeout: 45 * time.Second,
CheckOrigin: func(r *http.Request) bool {
return true
},
Expand All @@ -198,7 +213,7 @@ func (s *WsTransport) tunnelListener() {
// Create an HTTP server
server := &http.Server{
Addr: addr,
IdleTimeout: 600 * time.Second, // HERE
IdleTimeout: -1,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.logger.Tracef("received http request from %s", r.RemoteAddr)

Expand Down Expand Up @@ -245,7 +260,6 @@ func (s *WsTransport) tunnelListener() {

s.config.TunnelStatus = fmt.Sprintf("Connected (%s)", s.config.Mode)

return
} else if r.URL.Path == "/tunnel" {
wsConn := TunnelChannel{
conn: conn,
Expand Down Expand Up @@ -300,8 +314,6 @@ func (s *WsTransport) tunnelListener() {

}



func (s *WsTransport) parsePortMappings() {
for _, portMapping := range s.config.Ports {
parts := strings.Split(portMapping, "=")
Expand Down Expand Up @@ -335,7 +347,7 @@ func (s *WsTransport) parsePortMappings() {
for port := startPort; port <= endPort; port++ {
localAddr = fmt.Sprintf(":%d", port)
go s.localListener(localAddr, strconv.Itoa(port)) // Use port as the remoteAddr
time.Sleep(1 * time.Millisecond) // for wide port ranges
time.Sleep(1 * time.Millisecond) // for wide port ranges
}
continue
} else {
Expand Down Expand Up @@ -393,8 +405,6 @@ func (s *WsTransport) parsePortMappings() {
}
}



func (s *WsTransport) localListener(localAddr string, remoteAddr string) {
portListener, err := net.Listen("tcp", localAddr)
if err != nil {
Expand Down
44 changes: 30 additions & 14 deletions internal/server/transport/wsmux.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package transport

import (
"bytes"
"context"
"fmt"
"net"
Expand Down Expand Up @@ -108,6 +107,11 @@ func (s *WsMuxTransport) Restart() {
defer s.restartMutex.Unlock()

s.logger.Info("restarting server...")

// for removing timeout logs
level := s.logger.Level
s.logger.SetLevel(logrus.FatalLevel)

if s.cancel != nil {
s.cancel()
}
Expand All @@ -131,16 +135,18 @@ func (s *WsMuxTransport) Restart() {
s.usageMonitor = web.NewDataStore(fmt.Sprintf(":%v", s.config.WebPort), ctx, s.config.SnifferLog, s.config.Sniffer, &s.config.TunnelStatus, s.logger)
s.config.TunnelStatus = ""

go s.Start()
// set the log level again
s.logger.SetLevel(level)

go s.Start()
}

func (s *WsMuxTransport) channelHandler() {
ticker := time.NewTicker(s.config.Heartbeat)
defer ticker.Stop()

// Channel to receive the message or error
messageChan := make(chan []byte, 1)
messageChan := make(chan byte, 10)

// Separate goroutine to continuously listen for messages
go func() {
Expand All @@ -150,15 +156,16 @@ func (s *WsMuxTransport) channelHandler() {
return

default:
_, message, err := s.controlChannel.ReadMessage()
_, msg, err := s.controlChannel.ReadMessage()
// Exit if there's an error
if err != nil {
s.logger.Error("failed to read from channel connection. ", err)
close(messageChan) // Closing the channel to signal completion
go s.Restart()
if s.cancel != nil {
s.logger.Error("failed to read from channel connection. ", err)
go s.Restart()
}
return
}
messageChan <- message
messageChan <- msg[0]
}
}
}()
Expand All @@ -185,27 +192,36 @@ func (s *WsMuxTransport) channelHandler() {
}
s.logger.Debug("heartbeat signal sent successfully")

case message, ok := <-messageChan:
case msg, ok := <-messageChan:
if !ok {
s.logger.Error("channel closed, likely due to an error in WebSocket read")
return
}
switch msg {
case utils.SG_HB:
s.logger.Trace("heartbeat signal received successfully")

// Handle specific control channel messages
if bytes.Equal(message, []byte{utils.SG_Closed}) {
case utils.SG_Closed:
s.logger.Warn("control channel has been closed by the client")
s.Restart()
return

default:
s.logger.Errorf("unexpected response from channel: %v", msg)
go s.Restart()
return
}

}
}
}

func (s *WsMuxTransport) tunnelListener() {
addr := s.config.BindAddr
upgrader := websocket.Upgrader{
ReadBufferSize: 16 * 1024,
WriteBufferSize: 16 * 1024,
ReadBufferSize: 16 * 1024,
WriteBufferSize: 16 * 1024,
HandshakeTimeout: 45 * time.Second,
CheckOrigin: func(r *http.Request) bool {
return true
},
Expand All @@ -214,7 +230,7 @@ func (s *WsMuxTransport) tunnelListener() {
// Create an HTTP server
server := &http.Server{
Addr: addr,
IdleTimeout: 600 * time.Second,
IdleTimeout: -1,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.logger.Tracef("received http request from %s", r.RemoteAddr)

Expand Down

0 comments on commit b9982ae

Please sign in to comment.