From d057cd71ffe1aee1426fe60e97451776eaca1945 Mon Sep 17 00:00:00 2001 From: Musixal Date: Wed, 30 Oct 2024 23:37:43 +0330 Subject: [PATCH] discard timeouted (after 3000ms) local connections to improve tunnel stabillity --- internal/server/transport/shared.go | 5 +++-- internal/server/transport/tcp.go | 8 +++++++- internal/server/transport/tcpmux.go | 14 ++++++++++---- internal/server/transport/udp.go | 7 ++++++- internal/server/transport/ws.go | 8 +++++++- internal/server/transport/wsmux.go | 14 ++++++++++---- 6 files changed, 43 insertions(+), 13 deletions(-) diff --git a/internal/server/transport/shared.go b/internal/server/transport/shared.go index d91e8a0..4c38980 100644 --- a/internal/server/transport/shared.go +++ b/internal/server/transport/shared.go @@ -14,8 +14,9 @@ type TunnelChannel struct { // for websocket } type LocalTCPConn struct { - conn net.Conn - remoteAddr string + conn net.Conn + remoteAddr string + timeCreated int64 } type LocalAcceptUDPConn struct { diff --git a/internal/server/transport/tcp.go b/internal/server/transport/tcp.go index a3612c4..5600988 100644 --- a/internal/server/transport/tcp.go +++ b/internal/server/transport/tcp.go @@ -493,7 +493,7 @@ func (s *TcpTransport) acceptLocalConn(listener net.Listener, remoteAddr string) } select { - case s.localChannel <- LocalTCPConn{conn: conn, remoteAddr: remoteAddr}: + case s.localChannel <- LocalTCPConn{conn: conn, remoteAddr: remoteAddr, timeCreated: time.Now().UnixMilli()}: select { case s.reqNewConnChan <- struct{}{}: @@ -521,6 +521,12 @@ func (s *TcpTransport) handleLoop() { case localConn := <-s.localChannel: loop: for { + if time.Now().UnixMilli()-localConn.timeCreated > 3000 { // 3000ms + s.logger.Debugf("timeouted local connection: %d ms", time.Now().UnixMilli()-localConn.timeCreated) + localConn.conn.Close() + break loop + } + select { case <-s.ctx.Done(): return diff --git a/internal/server/transport/tcpmux.go b/internal/server/transport/tcpmux.go index d603817..8059add 100644 --- a/internal/server/transport/tcpmux.go +++ b/internal/server/transport/tcpmux.go @@ -502,7 +502,7 @@ func (s *TcpMuxTransport) acceptLocalConn(listener net.Listener, remoteAddr stri } select { - case s.localChannel <- LocalTCPConn{conn: conn, remoteAddr: remoteAddr}: + case s.localChannel <- LocalTCPConn{conn: conn, remoteAddr: remoteAddr, timeCreated: time.Now().UnixMilli()}: s.logger.Debugf("accepted incoming TCP connection from %s", tcpConn.RemoteAddr().String()) default: // channel is full, discard the connection @@ -549,15 +549,21 @@ func (s *TcpMuxTransport) handleSession(session *smux.Session, next chan struct{ } s.logger.Tracef("stream counter: %v, session counter: %v", atomic.LoadInt32(&s.streamCounter), atomic.LoadInt32(&s.sessionCounter)) - // +1 for Muxed connections counter - done <- struct{}{} - select { case <-s.ctx.Done(): session.Close() return case incomingConn := <-s.localChannel: + if time.Now().UnixMilli()-incomingConn.timeCreated > 3000 { // 3000ms + s.logger.Debugf("timeouted local connection: %d ms", time.Now().UnixMilli()-incomingConn.timeCreated) + incomingConn.conn.Close() + continue + } + + // +1 for mux connection counter + done <- struct{}{} + // +1 for stream counter atomic.AddInt32(&s.streamCounter, 1) diff --git a/internal/server/transport/udp.go b/internal/server/transport/udp.go index e87f573..f938696 100644 --- a/internal/server/transport/udp.go +++ b/internal/server/transport/udp.go @@ -522,7 +522,7 @@ func (s *UdpTransport) localListener(localAddr, remoteAddr string) { // Build the UDP connection object newUDPConn := LocalUDPConn{ - timeCreated: time.Now().UnixNano(), // Just for debugging + timeCreated: time.Now().UnixMilli(), // Just for debugging payload: payloadChan, remoteAddr: remoteAddr, listener: listener, @@ -568,6 +568,11 @@ func (s *UdpTransport) handleLoop(udpChan chan *LocalUDPConn, activeConnections case <-s.ctx.Done(): return case localConn := <-udpChan: + if time.Now().UnixMilli()-localConn.timeCreated > 3000 { // 3000ms + s.logger.Debugf("timeouted local connection: %d ms", time.Now().UnixMilli()-localConn.timeCreated) + continue + } + loop: for { select { diff --git a/internal/server/transport/ws.go b/internal/server/transport/ws.go index cbc7c97..5e8febc 100644 --- a/internal/server/transport/ws.go +++ b/internal/server/transport/ws.go @@ -454,7 +454,7 @@ func (s *WsTransport) acceptLocalConn(listener net.Listener, remoteAddr string) } select { - case s.localChannel <- LocalTCPConn{conn: conn, remoteAddr: remoteAddr}: + case s.localChannel <- LocalTCPConn{conn: conn, remoteAddr: remoteAddr, timeCreated: time.Now().UnixMilli()}: select { case s.reqNewConnChan <- struct{}{}: @@ -482,6 +482,12 @@ func (s *WsTransport) handleLoop() { case localConn := <-s.localChannel: loop: for { + if time.Now().UnixMilli()-localConn.timeCreated > 3000 { // 3000ms + s.logger.Debugf("timeouted local connection: %d ms", time.Now().UnixMilli()-localConn.timeCreated) + localConn.conn.Close() + break loop + } + select { case <-s.ctx.Done(): return diff --git a/internal/server/transport/wsmux.go b/internal/server/transport/wsmux.go index 04ae897..d100ed0 100644 --- a/internal/server/transport/wsmux.go +++ b/internal/server/transport/wsmux.go @@ -476,7 +476,7 @@ func (s *WsMuxTransport) acceptLocalConn(listener net.Listener, remoteAddr strin } select { - case s.localChannel <- LocalTCPConn{conn: conn, remoteAddr: remoteAddr}: + case s.localChannel <- LocalTCPConn{conn: conn, remoteAddr: remoteAddr, timeCreated: time.Now().UnixMilli()}: s.logger.Debugf("accepted incoming TCP connection from %s", tcpConn.RemoteAddr().String()) default: // channel is full, discard the connection @@ -522,15 +522,21 @@ func (s *WsMuxTransport) handleSession(session *smux.Session, next chan struct{} } s.logger.Tracef("stream counter: %v, session counter: %v", atomic.LoadInt32(&s.streamCounter), atomic.LoadInt32(&s.sessionCounter)) - // +1 for Muxed connections counter - done <- struct{}{} - select { case <-s.ctx.Done(): session.Close() return case incomingConn := <-s.localChannel: + if time.Now().UnixMilli()-incomingConn.timeCreated > 3000 { // 3000ms + s.logger.Debugf("timeouted local connection: %d ms", time.Now().UnixMilli()-incomingConn.timeCreated) + incomingConn.conn.Close() + continue + } + + // +1 for Muxed connections counter + done <- struct{}{} + // +1 for stream counter atomic.AddInt32(&s.streamCounter, 1)