diff --git a/core/client/client.go b/core/client/client.go index 572c0ff..a3278ad 100644 --- a/core/client/client.go +++ b/core/client/client.go @@ -38,7 +38,9 @@ type WebSocksClient struct { stopC chan int //statistics - CreatedAt time.Time + CreatedAt time.Time + Uploaded int64 + Downloaded int64 } func NewWebSocksClient(config *WebSocksClientConfig) (client *WebSocksClient) { @@ -112,7 +114,7 @@ func (client *WebSocksClient) Listen() (err error) { break } - go client.handleConn(conn) + go client.HandleConn(conn) } return nil } @@ -122,39 +124,23 @@ func (client *WebSocksClient) Stop() { return } -func (client *WebSocksClient) handleConn(conn *net.TCPConn) { - defer conn.Close() - - conn.SetLinger(0) - - err := handShake(conn) - if err != nil { - log.Debugf(err.Error()) - return - } - - _, host, err := getRequest(conn) +func (client *WebSocksClient) HandleConn(conn *net.TCPConn) { + lc, err := NewLocalConn(conn) if err != nil { - log.Debugf(err.Error()) - return - } - - _, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x08, 0x43}) - if err != nil { - log.Debugf(err.Error()) + log.Debug(err.Error()) return } if client.Mux { - client.DialMuxConn(host, conn) + client.DialMuxConn(lc.Host, conn) } else { - client.DialWSConn(host, conn) + client.DialWSConn(lc.Host, lc) } return } -func (client *WebSocksClient) DialWSConn(host string, conn *net.TCPConn) { +func (client *WebSocksClient) DialWSConn(host string, conn io.ReadWriter) { wsConn, _, err := client.Dialer.Dial(client.ServerURL.String(), map[string][]string{ "WebSocks-Host": {host}, }) diff --git a/core/client/conn.go b/core/client/conn.go new file mode 100644 index 0000000..7d9c3f4 --- /dev/null +++ b/core/client/conn.go @@ -0,0 +1,73 @@ +package client + +import ( + "errors" + "net" + "sync/atomic" + "time" +) + +type LocalConn struct { + Host string + + conn *net.TCPConn + + //stats + createdAt time.Time + closed bool + readed uint64 + written uint64 +} + +func NewLocalConn(conn *net.TCPConn) (lc *LocalConn, err error) { + conn.SetLinger(0) + err = handShake(conn) + if err != nil { + return + } + + _, host, err := getRequest(conn) + if err != nil { + return + } + + _, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x08, 0x43}) + if err != nil { + log.Debugf(err.Error()) + return + } + + lc = &LocalConn{ + Host: host, + + conn: conn, + createdAt: time.Now(), + } + return +} + +func (lc *LocalConn) Read(p []byte) (n int, err error) { + if lc.closed { + return 0, errors.New("local conn closed") + } + + n, err = lc.conn.Read(p) + if err != nil { + lc.closed = true + } + atomic.AddUint64(&lc.readed, uint64(n)) + return +} + +func (lc *LocalConn) Write(p []byte) (n int, err error) { + if lc.closed { + return 0, errors.New("local conn closed") + } + + n, err = lc.conn.Write(p) + if err != nil { + lc.closed = true + } + atomic.AddUint64(&lc.written, uint64(n)) + return +}