Skip to content
This repository has been archived by the owner on Aug 27, 2020. It is now read-only.

Commit

Permalink
rewrite mux
Browse files Browse the repository at this point in the history
  • Loading branch information
lzjluzijie committed Dec 1, 2018
1 parent cc7dc28 commit 7447db0
Show file tree
Hide file tree
Showing 14 changed files with 328 additions and 350 deletions.
39 changes: 26 additions & 13 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"net"
"time"

"github.com/lzjluzijie/websocks/core/mux"

"net/url"

"github.com/gorilla/websocket"
Expand All @@ -16,13 +18,12 @@ type WebSocksClient struct {
ListenAddr *net.TCPAddr

dialer *websocket.Dialer
//connMutex sync.Mutex
//wsConns []*core.WebSocket
muxWS *core.MuxWebSocket

//todo enable mux
Mux bool

muxGroup *mux.Group

//todo
//control
stopC chan int

Expand All @@ -39,12 +40,19 @@ func (client *WebSocksClient) Run() (err error) {
log.Printf("Start to listen at %s", client.ListenAddr.String())

if client.Mux {
err := client.OpenMux()
if err != nil {
return err
}

go client.ListenMuxWS(client.muxWS)
group := mux.NewGroup(true)
go func() {
//todo
for {
if len(group.MuxWSs) == 0 {
err := client.OpenMux()
if err != nil {
log.Printf(err.Error())
continue
}
}
}
}()
}

go func() {
Expand Down Expand Up @@ -83,13 +91,18 @@ func (client *WebSocksClient) HandleConn(conn *net.TCPConn) {
return
}

//todo mux
host := lc.Host

if client.Mux {
client.DialMuxConn(lc.Host, conn)
err = client.muxGroup.NewMuxConn(host)
if err != nil {
log.Printf(err.Error())
return
}
return
}

ws, err := client.DialWebSocket(core.NewHostHeader(lc.Host))
ws, err := client.DialWebSocket(core.NewHostHeader(host))
if err != nil {
log.Printf(err.Error())
return
Expand Down
48 changes: 4 additions & 44 deletions client/mux.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package client

import (
"net"

"github.com/lzjluzijie/websocks/core"
"github.com/lzjluzijie/websocks/core/mux"
)

func (client *WebSocksClient) OpenMux() (err error) {
wsConn, _, err := client.dialer.Dial(client.ServerURL.String(), map[string][]string{
"WebSocks-Mux": {"mux"},
"WebSocks-Mux": {"v0.15"},
})

if err != nil {
Expand All @@ -17,46 +16,7 @@ func (client *WebSocksClient) OpenMux() (err error) {

ws := core.NewWebSocket(wsConn, client.Stats)

muxWS := core.NewMuxWebSocket(ws)
client.muxWS = muxWS
muxWS := mux.NewMuxWebSocket(ws)
client.muxGroup.AddMuxWS(muxWS)
return
}

func (client *WebSocksClient) DialMuxConn(host string, conn *net.TCPConn) {
muxConn := core.CreateMuxConn(client.muxWS)

err := muxConn.DialMessage(host)
if err != nil {
//log.Printf(err.Error())
err = client.OpenMux()
if err != nil {
//log.Printf(err.Error())
}
return
}

muxConn.MuxWS.PutMuxConn(muxConn)

//log.Printf("dialed mux for %s", host)

muxConn.Run(conn)
return
}

func (client *WebSocksClient) ListenMuxWS(muxWS *core.MuxWebSocket) {
for {
m, err := muxWS.ReceiveMessage()
if err != nil {
//log.Printf(err.Error())
return
}

//get conn and send message
conn := muxWS.GetMuxConn(m.ConnID)
err = conn.HandleMessage(m)
if err != nil {
//log.Printf(err.Error())
continue
}
}
}
101 changes: 0 additions & 101 deletions core/mux.go

This file was deleted.

22 changes: 22 additions & 0 deletions core/mux/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package mux

import "math/rand"

//NewMuxConn creates a new mux connection for client
func (group *Group) NewMuxConn(host string) (err error) {
conn := &Conn{
ID: rand.Uint32(),
wait: make(chan int),
sendMessageID: new(uint32),
}

mh := &MessageHead{
Method: MessageMethodDial,
MessageID: 4294967295,
ConnID: conn.ID,
Length: uint32(len(host)),
}

err = group.Send(mh, []byte(host))
return
}
90 changes: 90 additions & 0 deletions core/mux/conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package mux

import (
"io"
"net"
"sync"
"sync/atomic"
)

type Conn struct {
ID uint32

group *Group

mutex sync.Mutex
buf []byte
wait chan int

receiveMessageID uint32
sendMessageID *uint32
}

func (conn *Conn) Write(p []byte) (n int, err error) {
mh := &MessageHead{
Method: MessageMethodData,
ConnID: conn.ID,
MessageID: conn.SendMessageID(),
Length: uint32(len(p)),
}

err = conn.group.Send(mh, p)
if err != nil {
return 0, err
}
return len(p), nil
}

func (conn *Conn) Read(p []byte) (n int, err error) {
if len(conn.buf) == 0 {
//log.Printf("%d buf is 0, waiting", conn.ID)
<-conn.wait
}

conn.mutex.Lock()
//log.Printf("%d buf: %v", conn.buf)
n = copy(p, conn.buf)
conn.buf = conn.buf[n:]
conn.mutex.Unlock()
return
}

func (conn *Conn) HandleMessage(mh *MessageHead, data []byte) (err error) {
//log.Printf("handle message %d %d", mh.ConnID, mh.MessageID)
for {
if conn.receiveMessageID == mh.MessageID {
conn.mutex.Lock()
conn.buf = append(conn.buf, data...)
conn.receiveMessageID++
close(conn.wait)
conn.wait = make(chan int)
conn.mutex.Unlock()
//log.Printf("handled message %d %d", mh.ConnID, mh.MessageID)
return
}
<-conn.wait
}
return
}

func (conn *Conn) SendMessageID() (id uint32) {
id = atomic.LoadUint32(conn.sendMessageID)
atomic.AddUint32(conn.sendMessageID, 1)
return
}

func (conn *Conn) Run(c *net.TCPConn) {
go func() {
_, err := io.Copy(c, conn)
if err != nil {
//log.Printf(err.Error())
}
}()

_, err := io.Copy(conn, c)
if err != nil {
//log.Printf(err.Error())
}

return
}
Loading

0 comments on commit 7447db0

Please sign in to comment.