diff --git a/client/client.go b/client/client.go index 19e6fa7..6efb4bc 100644 --- a/client/client.go +++ b/client/client.go @@ -88,7 +88,8 @@ func (client *WebSocksClient) Stop() { } func (client *WebSocksClient) HandleConn(conn *net.TCPConn) { - log.Println("new socks5 conn") + //debug log + //log.Println("new socks5 conn") lc, err := NewLocalConn(conn) if err != nil { @@ -105,7 +106,8 @@ func (client *WebSocksClient) HandleConn(conn *net.TCPConn) { return } - log.Printf("created #%v", muxConn) + //debug log + log.Printf("created new mux conn: %x %s", muxConn.ID, host) muxConn.Run(conn) return diff --git a/core/mux/client.go b/core/mux/client.go index c67500e..b415673 100644 --- a/core/mux/client.go +++ b/core/mux/client.go @@ -1,11 +1,9 @@ package mux -import "math/rand" - //NewMuxConn creates a new mux connection for client func (group *Group) NewMuxConn(host string) (conn *Conn, err error) { conn = &Conn{ - ID: rand.Uint32(), + ID: group.NextConnID(), wait: make(chan int), sendMessageID: new(uint32), group: group, @@ -24,6 +22,6 @@ func (group *Group) NewMuxConn(host string) (conn *Conn, err error) { return } - group.Conns = append(group.Conns, conn) + group.AddConn(conn) return } diff --git a/core/mux/conn.go b/core/mux/conn.go index d9ea4b7..ada08bd 100644 --- a/core/mux/conn.go +++ b/core/mux/conn.go @@ -1,6 +1,7 @@ package mux import ( + "errors" "io" "log" "net" @@ -8,6 +9,8 @@ import ( "sync/atomic" ) +var ErrConnClosed = errors.New("mux conn closed") + type Conn struct { ID uint32 @@ -17,11 +20,17 @@ type Conn struct { buf []byte wait chan int + closed bool + receiveMessageID uint32 sendMessageID *uint32 } func (conn *Conn) Write(p []byte) (n int, err error) { + if conn.closed { + return 0, ErrConnClosed + } + m := &Message{ Method: MessageMethodData, ConnID: conn.ID, @@ -38,6 +47,10 @@ func (conn *Conn) Write(p []byte) (n int, err error) { } func (conn *Conn) Read(p []byte) (n int, err error) { + if conn.closed { + return 0, ErrConnClosed + } + if len(conn.buf) == 0 { //log.Printf("%d buf is 0, waiting", conn.ID) <-conn.wait @@ -52,7 +65,13 @@ func (conn *Conn) Read(p []byte) (n int, err error) { } func (conn *Conn) HandleMessage(m *Message) (err error) { + if conn.closed { + return ErrConnClosed + } + + //debug log //log.Printf("handle message %d %d", m.ConnID, m.MessageID) + for { if conn.receiveMessageID == m.MessageID { conn.mutex.Lock() @@ -61,7 +80,8 @@ func (conn *Conn) HandleMessage(m *Message) (err error) { close(conn.wait) conn.wait = make(chan int) conn.mutex.Unlock() - log.Printf("handled message %d %d", m.ConnID, m.MessageID) + //debug log + //log.Printf("handled message %d %d", m.ConnID, m.MessageID) return } <-conn.wait @@ -79,14 +99,23 @@ func (conn *Conn) Run(c *net.TCPConn) { go func() { _, err := io.Copy(c, conn) if err != nil { + conn.Close() log.Printf(err.Error()) } }() _, err := io.Copy(conn, c) if err != nil { + conn.Close() log.Printf(err.Error()) } return } + +func (conn *Conn) Close() (err error) { + conn.group.DeleteConn(conn.ID) + //close(conn.wait) + conn.closed = true + return +} diff --git a/core/mux/group.go b/core/mux/group.go index 8580be5..e8a876d 100644 --- a/core/mux/group.go +++ b/core/mux/group.go @@ -2,7 +2,9 @@ package mux import ( "errors" + "fmt" "log" + "sync" "time" ) @@ -11,14 +13,19 @@ type Group struct { MuxWSs []*MuxWebSocket - Conns []*Conn + connMap map[uint32]*Conn + connMapMutex sync.RWMutex + + connID uint32 + connIDMutex sync.Mutex } //true: client group //false: server group func NewGroup(client bool) (group *Group) { group = &Group{ - client: client, + client: client, + connMap: make(map[uint32]*Conn), } return } @@ -41,22 +48,19 @@ func (group *Group) Handle(m *Message) { } //get conn and send message - //todo better way to find conn for { - t := time.Now() - for _, conn := range group.Conns { - if conn.ID == m.ConnID { - log.Printf("find conn id %x", conn.ID) - err := conn.HandleMessage(m) - if err != nil { - log.Println(err.Error()) - return - } - return - } + conn := group.GetConn(m.ConnID) + if conn == nil { + //debug log + err := errors.New(fmt.Sprintf("conn does not exist: %x", m.ConnID)) + log.Println(err.Error()) + log.Printf("%X %X %X %d", m.Method, m.ConnID, m.MessageID, m.Length) + return } - if time.Now().After(t.Add(time.Second * 3)) { - err := errors.New("conn does not exist") + + //this err should be nil or ErrConnClosed + err := conn.HandleMessage(m) + if err != nil { log.Println(err.Error()) return } @@ -64,6 +68,47 @@ func (group *Group) Handle(m *Message) { return } +func (group *Group) AddConn(conn *Conn) { + group.connMapMutex.Lock() + group.connMap[conn.ID] = conn + group.connMapMutex.Unlock() + return +} + +func (group *Group) DeleteConn(id uint32) { + group.connMapMutex.Lock() + delete(group.connMap, id) + group.connMapMutex.Unlock() + return +} + +func (group *Group) GetConn(id uint32) (conn *Conn) { + group.connMapMutex.RLock() + conn = group.connMap[id] + group.connMapMutex.RUnlock() + + if conn == nil { + t := time.Now() + for time.Now().Before(t.Add(time.Second)) { + group.connMapMutex.RLock() + conn = group.connMap[id] + group.connMapMutex.RUnlock() + if conn != nil { + return conn + } + } + } + return +} + +func (group *Group) NextConnID() (id uint32) { + group.connIDMutex.Lock() + group.connID++ + id = group.connID + group.connIDMutex.Unlock() + return +} + func (group *Group) AddMuxWS(muxWS *MuxWebSocket) (err error) { muxWS.group = group group.MuxWSs = append(group.MuxWSs, muxWS) @@ -74,15 +119,13 @@ func (group *Group) AddMuxWS(muxWS *MuxWebSocket) (err error) { func (group *Group) Listen(muxWS *MuxWebSocket) { go func() { for { - log.Println("ready to receive") m, err := muxWS.Receive() if err != nil { - log.Printf(err.Error()) + log.Println(err.Error()) return } go group.Handle(m) } - return }() } diff --git a/core/mux/server.go b/core/mux/server.go index 14ae369..2e8a5c7 100644 --- a/core/mux/server.go +++ b/core/mux/server.go @@ -10,7 +10,10 @@ func (group *Group) ServerHandleMessage(m *Message) (err error) { //accept new conn if m.Method == MessageMethodDial { host := string(m.Data) - log.Printf("start to dial %s", host) + + //debug log + //log.Printf("start to dial %s", host) + conn := &Conn{ ID: m.ConnID, wait: make(chan int), @@ -19,21 +22,24 @@ func (group *Group) ServerHandleMessage(m *Message) (err error) { } //add to group before receive data - group.Conns = append(group.Conns, conn) + group.AddConn(conn) tcpAddr, err := net.ResolveTCPAddr("tcp", host) if err != nil { + conn.Close() log.Printf(err.Error()) return err } tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) if err != nil { + conn.Close() log.Printf(err.Error()) return err } - log.Printf("Accepted mux conn %s", host) + //debug log + log.Printf("Accepted mux conn: %x, %s", conn.ID, host) conn.Run(tcpConn) return err diff --git a/core/mux/websocket.go b/core/mux/websocket.go index b9a84d0..cd06aaa 100644 --- a/core/mux/websocket.go +++ b/core/mux/websocket.go @@ -1,7 +1,9 @@ package mux import ( + "bytes" "io" + "log" "sync" "github.com/lzjluzijie/websocks/core" @@ -12,7 +14,8 @@ type MuxWebSocket struct { group *Group - mutex sync.Mutex + sMutex sync.Mutex + rMutex sync.Mutex } func NewMuxWebSocket(ws *core.WebSocket) (muxWS *MuxWebSocket) { @@ -23,33 +26,62 @@ func NewMuxWebSocket(ws *core.WebSocket) (muxWS *MuxWebSocket) { } func (muxWS *MuxWebSocket) Send(m *Message) (err error) { - muxWS.mutex.Lock() + muxWS.sMutex.Lock() _, err = io.Copy(muxWS, m) if err != nil { + e := muxWS.Close() + if e != nil { + log.Println(e.Error()) + } return } + muxWS.sMutex.Unlock() + //debug log //log.Printf("sent %#v", m) - muxWS.mutex.Unlock() return } func (muxWS *MuxWebSocket) Receive() (m *Message, err error) { + muxWS.rMutex.Lock() + h := make([]byte, 13) + _, err = muxWS.Read(h) if err != nil { + e := muxWS.Close() + if e != nil { + log.Println(e.Error()) + } return } + //debug log + //log.Printf("%d %x",n, h) + m = LoadMessage(h) - data := make([]byte, m.Length) + buf := &bytes.Buffer{} + r := io.LimitReader(muxWS, int64(m.Length)) - _, err = muxWS.Read(data) + _, err = io.Copy(buf, r) if err != nil { + e := muxWS.Close() + if e != nil { + log.Println(e.Error()) + } return } + muxWS.rMutex.Unlock() + + m.Data = buf.Bytes() - m.Data = data + ////debug log //log.Printf("received %#v", m) return } + +func (muxWS *MuxWebSocket) Close() (err error) { + muxWS.group.MuxWSs = nil + err = muxWS.WebSocket.Close() + return +} diff --git a/core/websocket.go b/core/websocket.go index f2f4ad2..f0b8cdd 100644 --- a/core/websocket.go +++ b/core/websocket.go @@ -2,12 +2,13 @@ package core import ( "errors" - "log" "time" "github.com/gorilla/websocket" ) +var ErrWebSocketClosed = errors.New("websocket closed") + type WebSocket struct { conn *websocket.Conn buf []byte @@ -29,11 +30,12 @@ func NewWebSocket(conn *websocket.Conn, stats *Stats) (ws *WebSocket) { func (ws *WebSocket) Read(p []byte) (n int, err error) { if ws.closed == true { - return 0, errors.New("websocket closed") + return 0, ErrWebSocketClosed } if len(ws.buf) == 0 { - log.Println("empty buf, waiting") + //debug log + //log.Println("empty buf, waiting") _, ws.buf, err = ws.conn.ReadMessage() if err != nil { return @@ -51,7 +53,7 @@ func (ws *WebSocket) Read(p []byte) (n int, err error) { func (ws *WebSocket) Write(p []byte) (n int, err error) { if ws.closed == true { - return 0, errors.New("websocket closed") + return 0, ErrWebSocketClosed } err = ws.conn.WriteMessage(websocket.BinaryMessage, p) @@ -68,7 +70,7 @@ func (ws *WebSocket) Write(p []byte) (n int, err error) { } func (ws *WebSocket) Close() (err error) { - ws.conn.Close() ws.closed = true + err = ws.conn.Close() return } diff --git a/websocks.go b/websocks.go index bedafc8..93c9a88 100644 --- a/websocks.go +++ b/websocks.go @@ -3,12 +3,11 @@ package main import ( "errors" "io/ioutil" + "log" "os" "os/exec" "runtime" - "log" - "github.com/lzjluzijie/websocks/client" "github.com/lzjluzijie/websocks/core" "github.com/lzjluzijie/websocks/server" @@ -23,7 +22,7 @@ func main() { todo better log todo better stats */ - Version: "0.15.0", + Version: "0.15.1", Usage: "A secure proxy based on WebSocket.", Description: "websocks.org", Author: "Halulu", @@ -208,6 +207,23 @@ func main() { }, } + ////pprof debug + //go func() { + // f, err := os.Create(fmt.Sprintf("%d.prof", time.Now().Unix())) + // if err != nil { + // panic(err) + // } + // + // err = pprof.StartCPUProfile(f) + // if err != nil { + // panic(err) + // } + // + // time.Sleep(time.Second * 30) + // pprof.StopCPUProfile() + // os.Exit(0) + //}() + err := app.Run(os.Args) if err != nil { log.Printf(err.Error())