Skip to content
This repository has been archived by the owner on Aug 27, 2020. It is now read-only.

Commit

Permalink
manually merge develop
Browse files Browse the repository at this point in the history
  • Loading branch information
lzjluzijie committed Dec 3, 2018
1 parent 108e0ff commit 2cda345
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 43 deletions.
6 changes: 4 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ func (client *WebSocksClient) Stop() {
}

func (client *WebSocksClient) HandleConn(conn *net.TCPConn) {
log.Println("new socks5 conn")
//debug log
//log.Println("new socks5 conn")

lc, err := NewLocalConn(conn)
if err != nil {
Expand All @@ -105,7 +106,8 @@ func (client *WebSocksClient) HandleConn(conn *net.TCPConn) {
return
}

log.Printf("created #%v", muxConn)
//debug log
log.Printf("created new mux conn: %x %s", muxConn.ID, host)

muxConn.Run(conn)
return
Expand Down
6 changes: 2 additions & 4 deletions core/mux/client.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package mux

import "math/rand"

//NewMuxConn creates a new mux connection for client
func (group *Group) NewMuxConn(host string) (conn *Conn, err error) {
conn = &Conn{
ID: rand.Uint32(),
ID: group.NextConnID(),
wait: make(chan int),
sendMessageID: new(uint32),
group: group,
Expand All @@ -24,6 +22,6 @@ func (group *Group) NewMuxConn(host string) (conn *Conn, err error) {
return
}

group.Conns = append(group.Conns, conn)
group.AddConn(conn)
return
}
31 changes: 30 additions & 1 deletion core/mux/conn.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package mux

import (
"errors"
"io"
"log"
"net"
"sync"
"sync/atomic"
)

var ErrConnClosed = errors.New("mux conn closed")

type Conn struct {
ID uint32

Expand All @@ -17,11 +20,17 @@ type Conn struct {
buf []byte
wait chan int

closed bool

receiveMessageID uint32
sendMessageID *uint32
}

func (conn *Conn) Write(p []byte) (n int, err error) {
if conn.closed {
return 0, ErrConnClosed
}

m := &Message{
Method: MessageMethodData,
ConnID: conn.ID,
Expand All @@ -38,6 +47,10 @@ func (conn *Conn) Write(p []byte) (n int, err error) {
}

func (conn *Conn) Read(p []byte) (n int, err error) {
if conn.closed {
return 0, ErrConnClosed
}

if len(conn.buf) == 0 {
//log.Printf("%d buf is 0, waiting", conn.ID)
<-conn.wait
Expand All @@ -52,7 +65,13 @@ func (conn *Conn) Read(p []byte) (n int, err error) {
}

func (conn *Conn) HandleMessage(m *Message) (err error) {
if conn.closed {
return ErrConnClosed
}

//debug log
//log.Printf("handle message %d %d", m.ConnID, m.MessageID)

for {
if conn.receiveMessageID == m.MessageID {
conn.mutex.Lock()
Expand All @@ -61,7 +80,8 @@ func (conn *Conn) HandleMessage(m *Message) (err error) {
close(conn.wait)
conn.wait = make(chan int)
conn.mutex.Unlock()
log.Printf("handled message %d %d", m.ConnID, m.MessageID)
//debug log
//log.Printf("handled message %d %d", m.ConnID, m.MessageID)
return
}
<-conn.wait
Expand All @@ -79,14 +99,23 @@ func (conn *Conn) Run(c *net.TCPConn) {
go func() {
_, err := io.Copy(c, conn)
if err != nil {
conn.Close()
log.Printf(err.Error())
}
}()

_, err := io.Copy(conn, c)
if err != nil {
conn.Close()
log.Printf(err.Error())
}

return
}

func (conn *Conn) Close() (err error) {
conn.group.DeleteConn(conn.ID)
//close(conn.wait)
conn.closed = true
return
}
81 changes: 62 additions & 19 deletions core/mux/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package mux

import (
"errors"
"fmt"
"log"
"sync"
"time"
)

Expand All @@ -11,14 +13,19 @@ type Group struct {

MuxWSs []*MuxWebSocket

Conns []*Conn
connMap map[uint32]*Conn
connMapMutex sync.RWMutex

connID uint32
connIDMutex sync.Mutex
}

//true: client group
//false: server group
func NewGroup(client bool) (group *Group) {
group = &Group{
client: client,
client: client,
connMap: make(map[uint32]*Conn),
}
return
}
Expand All @@ -41,29 +48,67 @@ func (group *Group) Handle(m *Message) {
}

//get conn and send message
//todo better way to find conn
for {
t := time.Now()
for _, conn := range group.Conns {
if conn.ID == m.ConnID {
log.Printf("find conn id %x", conn.ID)
err := conn.HandleMessage(m)
if err != nil {
log.Println(err.Error())
return
}
return
}
conn := group.GetConn(m.ConnID)
if conn == nil {
//debug log
err := errors.New(fmt.Sprintf("conn does not exist: %x", m.ConnID))
log.Println(err.Error())
log.Printf("%X %X %X %d", m.Method, m.ConnID, m.MessageID, m.Length)
return
}
if time.Now().After(t.Add(time.Second * 3)) {
err := errors.New("conn does not exist")

//this err should be nil or ErrConnClosed
err := conn.HandleMessage(m)
if err != nil {
log.Println(err.Error())
return
}
}
return
}

func (group *Group) AddConn(conn *Conn) {
group.connMapMutex.Lock()
group.connMap[conn.ID] = conn
group.connMapMutex.Unlock()
return
}

func (group *Group) DeleteConn(id uint32) {
group.connMapMutex.Lock()
delete(group.connMap, id)
group.connMapMutex.Unlock()
return
}

func (group *Group) GetConn(id uint32) (conn *Conn) {
group.connMapMutex.RLock()
conn = group.connMap[id]
group.connMapMutex.RUnlock()

if conn == nil {
t := time.Now()
for time.Now().Before(t.Add(time.Second)) {
group.connMapMutex.RLock()
conn = group.connMap[id]
group.connMapMutex.RUnlock()
if conn != nil {
return conn
}
}
}
return
}

func (group *Group) NextConnID() (id uint32) {
group.connIDMutex.Lock()
group.connID++
id = group.connID
group.connIDMutex.Unlock()
return
}

func (group *Group) AddMuxWS(muxWS *MuxWebSocket) (err error) {
muxWS.group = group
group.MuxWSs = append(group.MuxWSs, muxWS)
Expand All @@ -74,15 +119,13 @@ func (group *Group) AddMuxWS(muxWS *MuxWebSocket) (err error) {
func (group *Group) Listen(muxWS *MuxWebSocket) {
go func() {
for {
log.Println("ready to receive")
m, err := muxWS.Receive()
if err != nil {
log.Printf(err.Error())
log.Println(err.Error())
return
}

go group.Handle(m)
}
return
}()
}
12 changes: 9 additions & 3 deletions core/mux/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ func (group *Group) ServerHandleMessage(m *Message) (err error) {
//accept new conn
if m.Method == MessageMethodDial {
host := string(m.Data)
log.Printf("start to dial %s", host)

//debug log
//log.Printf("start to dial %s", host)

conn := &Conn{
ID: m.ConnID,
wait: make(chan int),
Expand All @@ -19,21 +22,24 @@ func (group *Group) ServerHandleMessage(m *Message) (err error) {
}

//add to group before receive data
group.Conns = append(group.Conns, conn)
group.AddConn(conn)

tcpAddr, err := net.ResolveTCPAddr("tcp", host)
if err != nil {
conn.Close()
log.Printf(err.Error())
return err
}

tcpConn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
conn.Close()
log.Printf(err.Error())
return err
}

log.Printf("Accepted mux conn %s", host)
//debug log
log.Printf("Accepted mux conn: %x, %s", conn.ID, host)

conn.Run(tcpConn)
return err
Expand Down
44 changes: 38 additions & 6 deletions core/mux/websocket.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package mux

import (
"bytes"
"io"
"log"
"sync"

"github.com/lzjluzijie/websocks/core"
Expand All @@ -12,7 +14,8 @@ type MuxWebSocket struct {

group *Group

mutex sync.Mutex
sMutex sync.Mutex
rMutex sync.Mutex
}

func NewMuxWebSocket(ws *core.WebSocket) (muxWS *MuxWebSocket) {
Expand All @@ -23,33 +26,62 @@ func NewMuxWebSocket(ws *core.WebSocket) (muxWS *MuxWebSocket) {
}

func (muxWS *MuxWebSocket) Send(m *Message) (err error) {
muxWS.mutex.Lock()
muxWS.sMutex.Lock()
_, err = io.Copy(muxWS, m)
if err != nil {
e := muxWS.Close()
if e != nil {
log.Println(e.Error())
}
return
}
muxWS.sMutex.Unlock()

//debug log
//log.Printf("sent %#v", m)
muxWS.mutex.Unlock()
return
}

func (muxWS *MuxWebSocket) Receive() (m *Message, err error) {
muxWS.rMutex.Lock()

h := make([]byte, 13)

_, err = muxWS.Read(h)
if err != nil {
e := muxWS.Close()
if e != nil {
log.Println(e.Error())
}
return
}

//debug log
//log.Printf("%d %x",n, h)

m = LoadMessage(h)
data := make([]byte, m.Length)
buf := &bytes.Buffer{}
r := io.LimitReader(muxWS, int64(m.Length))

_, err = muxWS.Read(data)
_, err = io.Copy(buf, r)
if err != nil {
e := muxWS.Close()
if e != nil {
log.Println(e.Error())
}
return
}
muxWS.rMutex.Unlock()

m.Data = buf.Bytes()

m.Data = data
////debug log
//log.Printf("received %#v", m)
return
}

func (muxWS *MuxWebSocket) Close() (err error) {
muxWS.group.MuxWSs = nil
err = muxWS.WebSocket.Close()
return
}
Loading

0 comments on commit 2cda345

Please sign in to comment.