diff --git a/lib/conn/conn.go b/lib/conn/conn.go index cf29acb3..4e511d19 100755 --- a/lib/conn/conn.go +++ b/lib/conn/conn.go @@ -14,6 +14,7 @@ import ( "net/url" "strconv" "strings" + "sync" "time" "ehang.io/nps/lib/common" @@ -371,7 +372,10 @@ func CopyWaitGroup(conn1, conn2 net.Conn, crypt bool, snappy bool, rate *rate.Ra //if flow != nil { // flow.Add(in, out) //} - err := goroutine.CopyConnsPool.Invoke(goroutine.NewConns(connHandle, conn2, flow)) + wg := new(sync.WaitGroup) + wg.Add(1) + err := goroutine.CopyConnsPool.Invoke(goroutine.NewConns(connHandle, conn2, flow, wg)) + wg.Wait() if err != nil { logs.Error(err) } diff --git a/lib/goroutine/pool.go b/lib/goroutine/pool.go index ca91d6d5..60717e68 100644 --- a/lib/goroutine/pool.go +++ b/lib/goroutine/pool.go @@ -44,13 +44,15 @@ type Conns struct { conn1 io.ReadWriteCloser // mux connection conn2 net.Conn // outside connection flow *file.Flow + wg *sync.WaitGroup } -func NewConns(c1 io.ReadWriteCloser, c2 net.Conn, flow *file.Flow) Conns { +func NewConns(c1 io.ReadWriteCloser, c2 net.Conn, flow *file.Flow, wg *sync.WaitGroup) Conns { return Conns{ conn1: c1, conn2: c2, flow: flow, + wg: wg, } } @@ -67,6 +69,7 @@ func copyConns(group interface{}) { if conns.flow != nil { conns.flow.Add(in, out) } + conns.wg.Done() } var connCopyPool, _ = ants.NewPoolWithFunc(200000, copyConnGroup, ants.WithNonblocking(false))