diff --git a/lib/bridge.go b/lib/bridge.go index 79f4366f..dcca5562 100755 --- a/lib/bridge.go +++ b/lib/bridge.go @@ -25,7 +25,7 @@ func (l *list) Len() int { func newList() *list { l := new(list) - l.connList = make(chan *Conn, 100) + l.connList = make(chan *Conn, 1000) return l } @@ -34,7 +34,8 @@ type Tunnel struct { listener *net.TCPListener //server端监听 signalList map[string]*list //通信 tunnelList map[string]*list //隧道 - sync.Mutex + lock sync.Mutex + tunnelLock sync.Mutex } func newTunnel(tunnelPort int) *Tunnel { @@ -113,7 +114,7 @@ func (s *Tunnel) typeDeal(typeVal string, c *Conn, cFlag string) error { //加到对应的list中 func (s *Tunnel) addList(m map[string]*list, c *Conn, cFlag string) { - s.Lock() + s.lock.Lock() if v, ok := m[cFlag]; ok { v.Add(c) } else { @@ -121,7 +122,7 @@ func (s *Tunnel) addList(m map[string]*list, c *Conn, cFlag string) { l.Add(c) m[cFlag] = l } - s.Unlock() + s.lock.Unlock() } //新建隧道 @@ -142,6 +143,7 @@ retry: //得到一个tcp隧道 func (s *Tunnel) GetTunnel(cFlag string, en, de int, crypt, mux bool) (c *Conn, err error) { + s.tunnelLock.Lock() if v, ok := s.tunnelList[cFlag]; !ok || v.Len() < 3 { //新建通道 go s.newChan(cFlag) } @@ -155,6 +157,7 @@ retry: goto retry } c.WriteConnInfo(en, de, crypt, mux) + s.tunnelLock.Unlock() return } diff --git a/lib/client.go b/lib/client.go index ac958694..fdb2dd3f 100755 --- a/lib/client.go +++ b/lib/client.go @@ -72,7 +72,7 @@ func (s *TRPClient) process(c *Conn) error { return err } case WORK_CHAN: //隧道模式,每次开启10个,加快连接速度 - for i := 0; i < 10; i++ { + for i := 0; i < 5; i++ { go s.dealChan() } case RES_MSG: @@ -85,17 +85,18 @@ func (s *TRPClient) process(c *Conn) error { } //隧道模式处理 -func (s *TRPClient) dealChan() error { +func (s *TRPClient) dealChan() { + var err error //创建一个tcp连接 conn, err := net.Dial("tcp", s.svrAddr) if err != nil { log.Println("connect to ", s.svrAddr, "error:", err) - return err + return } //验证 if _, err := conn.Write([]byte(getverifyval(s.vKey))); err != nil { log.Println("connect to ", s.svrAddr, "error:", err) - return err + return } //默认长连接保持 c := NewConn(conn) @@ -107,31 +108,24 @@ re: typeStr, host, en, de, crypt, mux, err := c.GetHostFromConn() if err != nil { log.Println("get host info error:", err) - return err + c.Close() + return } //与目标建立连接,超时时间为3 server, err := net.DialTimeout(typeStr, host, time.Second*3) if err != nil { log.Println("connect to ", host, "error:", err, mux) - if mux { - s.sendEof(conn, de, crypt) - goto re - } - return err + c.wFail() + goto end } - go relay(NewConn(server), c, de, crypt, mux) - relay(c, NewConn(server), en, crypt, mux) + c.wSuccess() + go relay(server, c.conn, de, crypt, mux) + relay(c.conn, server, en, crypt, mux) +end: if mux { goto re - } - return nil -} -func (s *TRPClient) sendEof(c net.Conn, de int, crypt bool) { - switch de { - case COMPRESS_SNAPY_DECODE: - NewSnappyConn(c, crypt).Write([]byte(IO_EOF)) - case COMPRESS_NONE_DECODE: - NewCryptConn(c, crypt).Write([]byte(IO_EOF)) + } else { + c.Close() } } diff --git a/lib/conn.go b/lib/conn.go index b3f1c5e4..bf95467d 100755 --- a/lib/conn.go +++ b/lib/conn.go @@ -35,9 +35,9 @@ func (s *CryptConn) Write(b []byte) (n int, err error) { if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil { return } - if b, err = GetLenBytes(b); err != nil { - return - } + } + if b, err = GetLenBytes(b); err != nil { + return } _, err = s.conn.Write(b) return @@ -46,29 +46,29 @@ func (s *CryptConn) Write(b []byte) (n int, err error) { //解密读 func (s *CryptConn) Read(b []byte) (n int, err error) { defer func() { - if string(b[:n]) == IO_EOF { + if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF { err = io.EOF n = 0 } }() + var lens int + var buf, bs []byte + c := NewConn(s.conn) + if lens, err = c.GetLen(); err != nil { + return + } + if buf, err = c.ReadLen(lens); err != nil { + return + } if s.crypt { - var lens int - var buf, bs []byte - c := NewConn(s.conn) - if lens, err = c.GetLen(); err != nil { - return - } - if buf, err = c.ReadLen(lens); err != nil { - return - } if bs, err = AesDecrypt(buf, []byte(cryptKey)); err != nil { return } - n = len(bs) - copy(b, bs) - return + } else { + bs = buf } - n, err = s.conn.Read(b) + n = len(bs) + copy(b, bs) return } @@ -105,7 +105,7 @@ func (s *SnappyConn) Write(b []byte) (n int, err error) { //snappy压缩读 包含解密 func (s *SnappyConn) Read(b []byte) (n int, err error) { defer func() { - if string(b[:n]) == IO_EOF { + if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF { err = io.EOF n = 0 } @@ -137,9 +137,12 @@ func NewConn(conn net.Conn) *Conn { } //读取指定长度内容 -func (s *Conn) ReadLen(len int) ([]byte, error) { - buf := make([]byte, len) - if n, err := io.ReadFull(s, buf); err != nil || n != len { +func (s *Conn) ReadLen(cLen int) ([]byte, error) { + if cLen > 65535 { + return nil, errors.New("长度错误") + } + buf := bufPool.Get().([]byte)[:cLen] + if n, err := io.ReadFull(s, buf); err != nil || n != cLen { return buf, errors.New("读取指定长度错误" + err.Error()) } return buf, nil @@ -316,6 +319,16 @@ func (s *Conn) wTest() (int, error) { return s.Write([]byte(TEST_FLAG)) } +//write test +func (s *Conn) wSuccess() (int, error) { + return s.Write([]byte(CONN_SUCCESS)) +} + +//write test +func (s *Conn) wFail() (int, error) { + return s.Write([]byte(CONN_ERROR)) +} + //获取长度+内容 func GetLenBytes(buf []byte) (b []byte, err error) { raw := bytes.NewBuffer([]byte{}) diff --git a/lib/file.go b/lib/file.go index a6a6a25e..2fed1a2d 100644 --- a/lib/file.go +++ b/lib/file.go @@ -4,6 +4,7 @@ import ( "encoding/csv" "encoding/json" "errors" + "github.com/astaxie/beego" "io/ioutil" "log" "os" @@ -33,9 +34,8 @@ type HostList struct { Target string //目标 } -func NewCsv(path string, bridge *Tunnel, runList map[string]interface{}) *Csv { +func NewCsv(bridge *Tunnel, runList map[string]interface{}) *Csv { c := new(Csv) - c.Path = path c.Bridge = bridge c.RunList = runList return c @@ -56,7 +56,7 @@ func (s *Csv) Init() { func (s *Csv) StoreTasksToCsv() { // 创建文件 - csvFile, err := os.Create(s.Path + "tasks.csv") + csvFile, err := os.Create(beego.AppPath + "/conf/tasks.csv") if err != nil { log.Fatalf(err.Error()) } @@ -87,7 +87,7 @@ func (s *Csv) StoreTasksToCsv() { func (s *Csv) LoadTaskFromCsv() { // 打开文件 - file, err := os.Open(s.Path + "tasks.csv") + file, err := os.Open(beego.AppPath + "/conf/tasks.csv") if err != nil { panic(err) } @@ -156,7 +156,7 @@ func (s *Csv) StoreHostToCsv() { func (s *Csv) LoadHostFromCsv() { // 打开文件 - file, err := os.Open(s.Path + "hosts.csv") + file, err := os.Open(beego.AppPath + "/conf/hosts.csv") if err != nil { panic(err) } diff --git a/lib/init.go b/lib/init.go index 38185027..78ceacf1 100644 --- a/lib/init.go +++ b/lib/init.go @@ -4,6 +4,8 @@ import ( "errors" "flag" "log" + "net/http" + _ "net/http/pprof" "reflect" "strings" "sync" @@ -35,9 +37,30 @@ func init() { RunList = make(map[string]interface{}) } +func InitClient() { + flag.Parse() + if *rpMode == "client" { + go func() { + http.ListenAndServe("0.0.0.0:8899", nil) + }() + JsonParse := NewJsonStruct() + if config, err = JsonParse.Load(*configPath); err != nil { + log.Println("配置文件加载失败") + } + stop := make(chan int) + for _, v := range strings.Split(*verifyKey, ",") { + log.Println("客户端启动,连接:", *serverAddr, " 验证令牌:", v) + go NewRPClient(*serverAddr, 1, v).Start() + } + <-stop + } +} func InitMode() { flag.Parse() if *rpMode == "client" { + go func() { + http.ListenAndServe("0.0.0.0:8899", nil) + }() JsonParse := NewJsonStruct() if config, err = JsonParse.Load(*configPath); err != nil { log.Println("配置文件加载失败") @@ -45,7 +68,7 @@ func InitMode() { stop := make(chan int) for _, v := range strings.Split(*verifyKey, ",") { log.Println("客户端启动,连接:", *serverAddr, " 验证令牌:", v) - go NewRPClient(*serverAddr, 3, v).Start() + go NewRPClient(*serverAddr, 1, v).Start() } <-stop } else { @@ -171,7 +194,7 @@ func DelTask(vKey string) error { func InitCsvDb() *Csv { var once sync.Once once.Do(func() { - CsvDb = NewCsv("./conf/", bridge, RunList) + CsvDb = NewCsv( bridge, RunList) CsvDb.Init() }) return CsvDb diff --git a/lib/socks5.go b/lib/socks5.go index 963441ec..8c151feb 100755 --- a/lib/socks5.go +++ b/lib/socks5.go @@ -147,23 +147,29 @@ func (s *Sock5ModeServer) doConnect(c net.Conn, command uint8) (proxyConn *Conn, ltype = CONN_TCP } _, err = client.WriteHost(ltype, addr) - return client, nil + var flag string + if flag, err = client.ReadFlag(); err == nil { + if flag != CONN_SUCCESS { + err = errors.New("conn failed") + } + } + return client, err } //conn func (s *Sock5ModeServer) handleConnect(c net.Conn) { proxyConn, err := s.doConnect(c, connectMethod) - if err != nil { - log.Println(err) - c.Close() - } else { - go relay(proxyConn, NewConn(c), s.config.CompressEncode, s.config.Crypt, s.config.Mux) - relay(NewConn(c), proxyConn, s.config.CompressDecode, s.config.Crypt, s.config.Mux) + defer func() { if s.config.Mux { s.bridge.ReturnTunnel(proxyConn, getverifyval(s.config.VerifyKey)) } + }() + if err != nil { + c.Close() + } else { + go relay(proxyConn.conn, c, s.config.CompressEncode, s.config.Crypt, s.config.Mux) + relay(c, proxyConn.conn, s.config.CompressDecode, s.config.Crypt, s.config.Mux) } - } // passive mode @@ -191,14 +197,16 @@ func (s *Sock5ModeServer) handleUDP(c net.Conn) { } proxyConn, err := s.doConnect(c, associateMethod) - if err != nil { - c.Close() - } else { - go relay(proxyConn, NewConn(c), s.config.CompressEncode, s.config.Crypt, s.config.Mux) - relay(NewConn(c), proxyConn, s.config.CompressDecode, s.config.Crypt, s.config.Mux) + defer func() { if s.config.Mux { s.bridge.ReturnTunnel(proxyConn, getverifyval(s.config.VerifyKey)) } + }() + if err != nil { + c.Close() + } else { + go relay(proxyConn.conn, c, s.config.CompressEncode, s.config.Crypt, s.config.Mux) + relay(c, proxyConn.conn, s.config.CompressDecode, s.config.Crypt, s.config.Mux) } } diff --git a/lib/tcp.go b/lib/tcp.go index 3631b5de..468c3391 100755 --- a/lib/tcp.go +++ b/lib/tcp.go @@ -20,6 +20,8 @@ const ( WORK_CHAN = "chan" RES_SIGN = "sign" RES_MSG = "msg0" + CONN_SUCCESS = "sucs" + CONN_ERROR = "fail" TEST_FLAG = "tst" CONN_TCP = "tcp" CONN_UDP = "udp" @@ -201,13 +203,19 @@ func (s *TunnelModeServer) dealClient(c *Conn, cnf *ServerConfig, addr string, m log.Println(err) return err } - if method == "CONNECT" { - fmt.Fprint(c, "HTTP/1.1 200 Connection established\r\n") - } else { - link.WriteTo(rb, cnf.CompressEncode, cnf.Crypt) + if flag, err := link.ReadFlag(); err == nil { + if flag == CONN_SUCCESS { + if method == "CONNECT" { + fmt.Fprint(c, "HTTP/1.1 200 Connection established\r\n") + } else { + link.WriteTo(rb, cnf.CompressEncode, cnf.Crypt) + } + go relay(link.conn, c.conn, cnf.CompressEncode, cnf.Crypt, cnf.Mux) + relay(c.conn, link.conn, cnf.CompressDecode, cnf.Crypt, cnf.Mux) + } else { + c.Close() + } } - go relay(link, c, cnf.CompressEncode, cnf.Crypt, cnf.Mux) - relay(c, link, cnf.CompressDecode, cnf.Crypt, cnf.Mux) return nil } @@ -283,6 +291,8 @@ func (s *WebServer) Start() { AddTask(t) beego.BConfig.WebConfig.Session.SessionOn = true log.Println("web管理启动,访问端口为", beego.AppConfig.String("httpport")) + beego.SetViewsPath(beego.AppPath + "/views/") + beego.SetStaticPath("/static/", beego.AppPath+"/static/") beego.Run() } diff --git a/lib/udp.go b/lib/udp.go index 2405b253..2d3baf79 100755 --- a/lib/udp.go +++ b/lib/udp.go @@ -55,22 +55,26 @@ func (s *UdpModeServer) process(addr *net.UDPAddr, data []byte) { return } conn.WriteTo(data, s.config.CompressEncode, s.config.Crypt) - go func(addr *net.UDPAddr, conn *Conn) { - defer func() { - if s.config.Mux { - s.bridge.ReturnTunnel(conn, getverifyval(s.config.VerifyKey)) - } - }() - buf := make([]byte, 1024) - conn.conn.SetReadDeadline(time.Now().Add(time.Duration(time.Second * 3))) - n, err := conn.ReadFrom(buf, s.config.CompressDecode, s.config.Crypt) - if err != nil || err == io.EOF { - conn.Close() - return + if flag, err := conn.ReadFlag(); err == nil { + if flag == CONN_SUCCESS { + go func(addr *net.UDPAddr, conn *Conn) { + defer func() { + if s.config.Mux { + s.bridge.ReturnTunnel(conn, getverifyval(s.config.VerifyKey)) + } + }() + buf := make([]byte, 1024) + conn.conn.SetReadDeadline(time.Now().Add(time.Duration(time.Second * 3))) + n, err := conn.ReadFrom(buf, s.config.CompressDecode, s.config.Crypt) + if err != nil || err == io.EOF { + conn.Close() + return + } + s.listener.WriteToUDP(buf[:n], addr) + conn.Close() + }(addr, conn) } - s.listener.WriteToUDP(buf[:n], addr) - conn.Close() - }(addr, conn) + } } func (s *UdpModeServer) Close() error { diff --git a/lib/util.go b/lib/util.go index 17a0916a..82fc11b2 100755 --- a/lib/util.go +++ b/lib/util.go @@ -16,6 +16,7 @@ import ( "regexp" "strconv" "strings" + "sync" ) var ( @@ -27,7 +28,7 @@ const ( COMPRESS_NONE_DECODE COMPRESS_SNAPY_ENCODE COMPRESS_SNAPY_DECODE - IO_EOF = "EOF" + IO_EOF = "PROXYEOF" ) //error @@ -131,27 +132,27 @@ func replaceHost(resp []byte) []byte { } //copy -func relay(in, out *Conn, compressType int, crypt, mux bool) { +func relay(in, out net.Conn, compressType int, crypt, mux bool) { switch compressType { case COMPRESS_SNAPY_ENCODE: - copyBuffer(NewSnappyConn(in.conn, crypt), out) + copyBuffer(NewSnappyConn(in, crypt), out) if mux { - NewSnappyConn(in.conn, crypt).Write([]byte(IO_EOF)) out.Close() + NewSnappyConn(in, crypt).Write([]byte(IO_EOF)) } case COMPRESS_SNAPY_DECODE: - copyBuffer(in, NewSnappyConn(out.conn, crypt)) + copyBuffer(in, NewSnappyConn(out, crypt)) if mux { in.Close() } case COMPRESS_NONE_ENCODE: - copyBuffer(NewCryptConn(in.conn, crypt), out) + copyBuffer(NewCryptConn(in, crypt), out) if mux { - NewCryptConn(in.conn, crypt).Write([]byte(IO_EOF)) out.Close() + NewCryptConn(in, crypt).Write([]byte(IO_EOF)) } case COMPRESS_NONE_DECODE: - copyBuffer(in, NewCryptConn(out.conn, crypt)) + copyBuffer(in, NewCryptConn(out, crypt)) if mux { in.Close() } @@ -280,18 +281,15 @@ func GetIntNoerrByStr(str string) int { return i } +var bufPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 65535) + }, +} // io.copy的优化版,读取buffer长度原为32*1024,与snappy不同,导致读取出的内容存在差异,不利于解密,特此修改 func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) { - // If the reader has a WriteTo method, use it to do the copy. - // Avoids an allocation and a copy. - if wt, ok := src.(io.WriterTo); ok { - return wt.WriteTo(dst) - } - // Similarly, if the writer has a ReadFrom method, use it to do the copy. - if rt, ok := dst.(io.ReaderFrom); ok { - return rt.ReadFrom(src) - } - buf := make([]byte, 65535) + //TODO 回收问题 + buf := bufPool.Get().([]byte) for { nr, er := src.Read(buf) if nr > 0 {