diff --git a/client/client.go b/client/client.go index bdd974d..263c144 100644 --- a/client/client.go +++ b/client/client.go @@ -119,8 +119,5 @@ func (client *WebSocksClient) DialWebSocket(header map[string][]string) (ws *cor } ws = core.NewWebSocket(wsConn, client.Stats) - //client.connMutex.Lock() - //client.wsConns = append(client.wsConns, ws) - //client.connMutex.Unlock() return } diff --git a/core/mux/client.go b/core/mux/client.go index 6dfd009..8773445 100644 --- a/core/mux/client.go +++ b/core/mux/client.go @@ -10,13 +10,14 @@ func (group *Group) NewMuxConn(host string) (err error) { sendMessageID: new(uint32), } - mh := &MessageHead{ + m := &Message{ Method: MessageMethodDial, MessageID: 4294967295, ConnID: conn.ID, Length: uint32(len(host)), + Data: []byte(host), } - err = group.Send(mh, []byte(host)) + err = group.Send(m) return } diff --git a/core/mux/conn.go b/core/mux/conn.go index 79688f7..55f0cf6 100644 --- a/core/mux/conn.go +++ b/core/mux/conn.go @@ -21,14 +21,14 @@ type Conn struct { } func (conn *Conn) Write(p []byte) (n int, err error) { - mh := &MessageHead{ + mh := &Message{ Method: MessageMethodData, ConnID: conn.ID, MessageID: conn.SendMessageID(), Length: uint32(len(p)), } - err = conn.group.Send(mh, p) + err = conn.group.Send(mh) if err != nil { return 0, err } @@ -49,12 +49,12 @@ func (conn *Conn) Read(p []byte) (n int, err error) { return } -func (conn *Conn) HandleMessage(mh *MessageHead, data []byte) (err error) { +func (conn *Conn) HandleMessage(m *Message) (err error) { //log.Printf("handle message %d %d", mh.ConnID, mh.MessageID) for { - if conn.receiveMessageID == mh.MessageID { + if conn.receiveMessageID == m.MessageID { conn.mutex.Lock() - conn.buf = append(conn.buf, data...) + conn.buf = append(conn.buf, m.Data...) conn.receiveMessageID++ close(conn.wait) conn.wait = make(chan int) diff --git a/core/mux/group.go b/core/mux/group.go index c650e8b..85ea4df 100644 --- a/core/mux/group.go +++ b/core/mux/group.go @@ -2,9 +2,6 @@ package mux import ( "errors" - "log" - "math/rand" - "net" ) type Group struct { @@ -26,47 +23,22 @@ func NewGroup(client bool) (group *Group) { return } -func (group *Group) Send(mh *MessageHead, data []byte) (err error) { +func (group *Group) Send(m *Message) (err error) { //todo - err = group.MuxWSs[0].Send(mh, data) + err = group.MuxWSs[0].Send(m) return } -func (group *Group) Receive(mh *MessageHead, data []byte) (err error) { +func (group *Group) Receive(m *Message) (err error) { if !group.client { - //accept new conn - if mh.Method == MessageMethodDial { - host := string(data) - conn := &Conn{ - ID: rand.Uint32(), - wait: make(chan int), - sendMessageID: new(uint32), - } - - tcpAddr, err := net.ResolveTCPAddr("tcp", host) - if err != nil { - log.Printf(err.Error()) - return err - } - - tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) - if err != nil { - log.Printf(err.Error()) - return err - } - - log.Printf("Accepted mux conn %s", host) - - conn.Run(tcpConn) - return err - } + group.HandleMessage(m) } //get conn and send message //todo better way to find conn for _, conn := range group.Conns { - if conn.ID == mh.ConnID { - err = conn.HandleMessage(mh, data) + if conn.ID == m.ConnID { + err = conn.HandleMessage(m) if err != nil { return } @@ -77,11 +49,6 @@ func (group *Group) Receive(mh *MessageHead, data []byte) (err error) { return } -func (group *Group) Start() (err error) { - - return -} - func (group *Group) AddMuxWS(muxWS *MuxWebSocket) (err error) { muxWS.group = group group.MuxWSs = append(group.MuxWSs, muxWS) diff --git a/core/mux/message.go b/core/mux/message.go index b933660..39aca5c 100644 --- a/core/mux/message.go +++ b/core/mux/message.go @@ -1,13 +1,50 @@ package mux +import ( + "bytes" + "encoding/binary" + "io" +) + const ( MessageMethodData = iota MessageMethodDial ) -type MessageHead struct { +//MessageHeadLength = 13 +type Message struct { Method uint8 ConnID uint32 MessageID uint32 Length uint32 + Data []byte + + r io.Reader + buf []byte +} + +func (m *Message) Read(p []byte) (n int, err error) { + if m.r == nil { + h := make([]byte, 13) + h[0] = m.Method + binary.BigEndian.PutUint32(h[1:5], m.ConnID) + binary.BigEndian.PutUint32(h[5:9], m.MessageID) + binary.BigEndian.PutUint32(h[9:13], m.Length) + m.r = bytes.NewReader(append(h, m.Data...)) + } + + n, err = m.Read(p) + return len(p), nil +} + +func (m *Message) Write(p []byte) (n int, err error) { + m.buf = append(m.buf, p...) + if len(m.buf) >= 13 { + m.Method = m.buf[0] + m.ConnID = binary.BigEndian.Uint32(m.buf[1:5]) + m.MessageID = binary.BigEndian.Uint32(m.buf[5:9]) + m.Length = binary.BigEndian.Uint32(m.buf[9:13]) + m.Data = m.buf[13:] + } + return len(p), nil } diff --git a/core/mux/server.go b/core/mux/server.go new file mode 100644 index 0000000..1f52a89 --- /dev/null +++ b/core/mux/server.go @@ -0,0 +1,38 @@ +package mux + +import ( + "log" + "math/rand" + "net" +) + +//HandleMessage is a server group function +func (group *Group) HandleMessage(m *Message) (err error) { + //accept new conn + if m.Method == MessageMethodDial { + host := string(m.Data) + conn := &Conn{ + ID: rand.Uint32(), + wait: make(chan int), + sendMessageID: new(uint32), + } + + tcpAddr, err := net.ResolveTCPAddr("tcp", host) + if err != nil { + log.Printf(err.Error()) + return err + } + + tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) + if err != nil { + log.Printf(err.Error()) + return err + } + + log.Printf("Accepted mux conn %s", host) + + conn.Run(tcpConn) + return err + } + return +} diff --git a/core/mux/websocket.go b/core/mux/websocket.go index 9fd094f..b30f7f6 100644 --- a/core/mux/websocket.go +++ b/core/mux/websocket.go @@ -1,7 +1,7 @@ package mux import ( - "encoding/binary" + "io" "log" "sync" @@ -23,13 +23,8 @@ func NewMuxWebSocket(ws *core.WebSocket) (muxWS *MuxWebSocket) { return } -func (muxWS *MuxWebSocket) Send(m *MessageHead, data []byte) (err error) { - err = binary.Write(muxWS, binary.BigEndian, m) - if err != nil { - return - } - - _, err = muxWS.Write(data) +func (muxWS *MuxWebSocket) Send(m *Message) (err error) { + _, err = io.Copy(muxWS, m) if err != nil { return } @@ -38,13 +33,8 @@ func (muxWS *MuxWebSocket) Send(m *MessageHead, data []byte) (err error) { return } -func (muxWS *MuxWebSocket) Receive(m *MessageHead, data []byte) (err error) { - err = binary.Read(muxWS, binary.BigEndian, m) - if err != nil { - return - } - - _, err = muxWS.Read(data) +func (muxWS *MuxWebSocket) Receive(m *Message) (err error) { + _, err = io.Copy(m, muxWS) if err != nil { return } @@ -56,16 +46,15 @@ func (muxWS *MuxWebSocket) Receive(m *MessageHead, data []byte) (err error) { func (muxWS *MuxWebSocket) Listen() { go func() { for { - mh := &MessageHead{} - data := make([]byte, 0) - err := muxWS.Receive(mh, data) + m := &Message{} + err := muxWS.Receive(m) if err != nil { //todo log.Printf(err.Error()) continue } - muxWS.group.Receive(mh, data) + muxWS.group.Receive(m) } return }() diff --git a/server/server.go b/server/server.go index 8ce915a..accd892 100644 --- a/server/server.go +++ b/server/server.go @@ -42,7 +42,6 @@ func (server *WebSocksServer) HandleWebSocket(w http.ResponseWriter, r *http.Req defer wsConn.Close() ws := core.NewWebSocket(wsConn, server.Stats) - //todo conns //mux //todo multiple clients