diff --git a/log/golog/golog.go b/log/golog/golog.go index 44921b9a5..e1481d169 100644 --- a/log/golog/golog.go +++ b/log/golog/golog.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "sync" + "sync/atomic" "time" terminal "golang.org/x/term" @@ -38,7 +39,7 @@ type Logger struct { timestamp bool quiet bool buf colorful.ColorBuffer - logLevel int + logLevel int32 } // Prefix struct define plain and color byte @@ -110,7 +111,7 @@ func New(out FdWriter) *Logger { func (l *Logger) SetLogLevel(level log.LogLevel) { l.mu.Lock() defer l.mu.Unlock() - l.logLevel = int(level) + atomic.StoreInt32(&l.logLevel, int32(level)) } func (l *Logger) SetOutput(w io.Writer) { @@ -295,7 +296,7 @@ func (l *Logger) Output(depth int, prefix Prefix, data string) error { // Fatal print fatal message to output and quit the application with status 1 func (l *Logger) Fatal(v ...interface{}) { - if l.logLevel <= 4 { + if atomic.LoadInt32(&l.logLevel) <= 4 { l.Output(1, FatalPrefix, fmt.Sprintln(v...)) } os.Exit(1) @@ -304,7 +305,7 @@ func (l *Logger) Fatal(v ...interface{}) { // Fatalf print formatted fatal message to output and quit the application // with status 1 func (l *Logger) Fatalf(format string, v ...interface{}) { - if l.logLevel <= 4 { + if atomic.LoadInt32(&l.logLevel) <= 4 { l.Output(1, FatalPrefix, fmt.Sprintf(format, v...)) } os.Exit(1) @@ -312,70 +313,70 @@ func (l *Logger) Fatalf(format string, v ...interface{}) { // Error print error message to output func (l *Logger) Error(v ...interface{}) { - if l.logLevel <= 3 { + if atomic.LoadInt32(&l.logLevel) <= 3 { l.Output(1, ErrorPrefix, fmt.Sprintln(v...)) } } // Errorf print formatted error message to output func (l *Logger) Errorf(format string, v ...interface{}) { - if l.logLevel <= 3 { + if atomic.LoadInt32(&l.logLevel) <= 3 { l.Output(1, ErrorPrefix, fmt.Sprintf(format, v...)) } } // Warn print warning message to output func (l *Logger) Warn(v ...interface{}) { - if l.logLevel <= 2 { + if atomic.LoadInt32(&l.logLevel) <= 2 { l.Output(1, WarnPrefix, fmt.Sprintln(v...)) } } // Warnf print formatted warning message to output func (l *Logger) Warnf(format string, v ...interface{}) { - if l.logLevel <= 2 { + if atomic.LoadInt32(&l.logLevel) <= 2 { l.Output(1, WarnPrefix, fmt.Sprintf(format, v...)) } } // Info print informational message to output func (l *Logger) Info(v ...interface{}) { - if l.logLevel <= 1 { + if atomic.LoadInt32(&l.logLevel) <= 1 { l.Output(1, InfoPrefix, fmt.Sprintln(v...)) } } // Infof print formatted informational message to output func (l *Logger) Infof(format string, v ...interface{}) { - if l.logLevel <= 1 { + if atomic.LoadInt32(&l.logLevel) <= 1 { l.Output(1, InfoPrefix, fmt.Sprintf(format, v...)) } } // Debug print debug message to output if debug output enabled func (l *Logger) Debug(v ...interface{}) { - if l.logLevel == 0 { + if atomic.LoadInt32(&l.logLevel) == 0 { l.Output(1, DebugPrefix, fmt.Sprintln(v...)) } } // Debugf print formatted debug message to output if debug output enabled func (l *Logger) Debugf(format string, v ...interface{}) { - if l.logLevel == 0 { + if atomic.LoadInt32(&l.logLevel) == 0 { l.Output(1, DebugPrefix, fmt.Sprintf(format, v...)) } } // Trace print trace message to output if debug output enabled func (l *Logger) Trace(v ...interface{}) { - if l.logLevel == 0 { + if atomic.LoadInt32(&l.logLevel) == 0 { l.Output(1, TracePrefix, fmt.Sprintln(v...)) } } // Tracef print formatted trace message to output if debug output enabled func (l *Logger) Tracef(format string, v ...interface{}) { - if l.logLevel == 0 { + if atomic.LoadInt32(&l.logLevel) == 0 { l.Output(1, TracePrefix, fmt.Sprintf(format, v...)) } } diff --git a/statistic/memory/memory.go b/statistic/memory/memory.go index 4de5b6ae6..d1b96115a 100644 --- a/statistic/memory/memory.go +++ b/statistic/memory/memory.go @@ -21,12 +21,11 @@ type User struct { recv uint64 lastSent uint64 lastRecv uint64 - speedLock sync.RWMutex sendSpeed uint64 recvSpeed uint64 hash string - ipTableLock sync.RWMutex - ipTable map[string]struct{} + ipTable sync.Map + ipNum int32 maxIPNum int limiterLock sync.RWMutex sendLimiter *rate.Limiter @@ -45,16 +44,15 @@ func (u *User) AddIP(ip string) bool { if u.maxIPNum <= 0 { return true } - u.ipTableLock.Lock() - defer u.ipTableLock.Unlock() - _, found := u.ipTable[ip] + _, found := u.ipTable.Load(ip) if found { return true } - if len(u.ipTable)+1 > u.maxIPNum { + if int(u.ipNum)+1 > u.maxIPNum { return false } - u.ipTable[ip] = struct{}{} + u.ipTable.Store(ip, true) + atomic.AddInt32(&u.ipNum, 1) return true } @@ -62,20 +60,17 @@ func (u *User) DelIP(ip string) bool { if u.maxIPNum <= 0 { return true } - u.ipTableLock.Lock() - defer u.ipTableLock.Unlock() - _, found := u.ipTable[ip] + _, found := u.ipTable.Load(ip) if !found { return false } - delete(u.ipTable, ip) + u.ipTable.Delete(ip) + atomic.AddInt32(&u.ipNum, -1) return true } func (u *User) GetIP() int { - u.ipTableLock.RLock() - defer u.ipTableLock.RUnlock() - return len(u.ipTable) + return int(u.ipNum) } func (u *User) SetIPLimit(n int) { @@ -87,8 +82,8 @@ func (u *User) GetIPLimit() int { } func (u *User) AddTraffic(sent, recv int) { - u.limiterLock.Lock() - defer u.limiterLock.Unlock() + u.limiterLock.RLock() + defer u.limiterLock.RUnlock() if u.sendLimiter != nil && sent >= 0 { u.sendLimiter.WaitN(u.ctx, sent) @@ -119,15 +114,13 @@ func (u *User) GetSpeedLimit() (send, recv int) { u.limiterLock.RLock() defer u.limiterLock.RUnlock() - sendLimit := 0 - recvLimit := 0 if u.sendLimiter != nil { - sendLimit = int(u.sendLimiter.Limit()) + send = int(u.sendLimiter.Limit()) } if u.recvLimiter != nil { - recvLimit = int(u.recvLimiter.Limit()) + recv = int(u.recvLimiter.Limit()) } - return sendLimit, recvLimit + return } func (u *User) Hash() string { @@ -152,83 +145,68 @@ func (u *User) ResetTraffic() (uint64, uint64) { } func (u *User) speedUpdater() { + ticker := time.NewTicker(time.Second) for { select { case <-u.ctx.Done(): return - case <-time.After(time.Second): - u.speedLock.Lock() + case <-ticker.C: sent, recv := u.GetTraffic() - u.sendSpeed = sent - u.lastSent - u.recvSpeed = recv - u.lastRecv - u.lastSent = sent - u.lastRecv = recv - u.speedLock.Unlock() + atomic.StoreUint64(&u.sendSpeed, sent-u.lastSent) + atomic.StoreUint64(&u.recvSpeed, recv-u.lastRecv) + atomic.StoreUint64(&u.lastSent, sent) + atomic.StoreUint64(&u.lastRecv, recv) } } } func (u *User) GetSpeed() (uint64, uint64) { - u.speedLock.RLock() - defer u.speedLock.RUnlock() - return u.sendSpeed, u.recvSpeed + return atomic.LoadUint64(&u.sendSpeed), atomic.LoadUint64(&u.recvSpeed) } type Authenticator struct { - sync.RWMutex - - users map[string]*User + users sync.Map ctx context.Context } func (a *Authenticator) AuthUser(hash string) (bool, statistic.User) { - a.RLock() - defer a.RUnlock() - if user, found := a.users[hash]; found { - return true, user + if user, found := a.users.Load(hash); found { + return true, user.(*User) } return false, nil } func (a *Authenticator) AddUser(hash string) error { - a.Lock() - defer a.Unlock() - if _, found := a.users[hash]; found { + if _, found := a.users.Load(hash); found { return common.NewError("hash " + hash + " is already exist") } ctx, cancel := context.WithCancel(a.ctx) meter := &User{ - hash: hash, - ctx: ctx, - cancel: cancel, - ipTable: make(map[string]struct{}), + hash: hash, + ctx: ctx, + cancel: cancel, } go meter.speedUpdater() - a.users[hash] = meter + a.users.Store(hash, meter) return nil } func (a *Authenticator) DelUser(hash string) error { - a.Lock() - defer a.Unlock() - meter, found := a.users[hash] + meter, found := a.users.Load(hash) if !found { return common.NewError("hash " + hash + " not found") } - meter.Close() - delete(a.users, hash) + meter.(*User).Close() + a.users.Delete(hash) return nil } func (a *Authenticator) ListUsers() []statistic.User { - a.RLock() - defer a.RUnlock() - result := make([]statistic.User, len(a.users)) - i := 0 - for _, u := range a.users { - result[i] = u - i++ - } + result := make([]statistic.User, 0) + a.users.Range(func(k, v interface{}) bool { + result = append(result, v.(*User)) + return true + }) return result } @@ -239,8 +217,7 @@ func (a *Authenticator) Close() error { func NewAuthenticator(ctx context.Context) (statistic.Authenticator, error) { cfg := config.FromContext(ctx, Name).(*Config) u := &Authenticator{ - ctx: ctx, - users: make(map[string]*User), + ctx: ctx, } for _, password := range cfg.Passwords { hash := common.SHA224String(password) diff --git a/tunnel/tls/server.go b/tunnel/tls/server.go index 79eac129c..1a600a84a 100644 --- a/tunnel/tls/server.go +++ b/tunnel/tls/server.go @@ -14,6 +14,7 @@ import ( "os" "strings" "sync" + "sync/atomic" "time" "github.com/huandu/go-clone" @@ -48,7 +49,8 @@ type Server struct { ctx context.Context cancel context.CancelFunc underlay tunnel.Server - nextHTTP bool + nextHTTP int32 + setNextHTTPOnce sync.Once portOverrider map[string]int } @@ -76,7 +78,7 @@ func (s *Server) acceptLoop() { select { case <-s.ctx.Done(): default: - log.Fatal(common.NewError("transport accept error")) + log.Fatal(common.NewError("transport accept error" + err.Error())) } return } @@ -161,7 +163,7 @@ func (s *Server) acceptLoop() { Conn: rewindConn, } } else { - if !s.nextHTTP { + if atomic.LoadInt32(&s.nextHTTP) != 1 { // there is no websocket layer waiting for connections, redirect it log.Error("incoming http request, but no websocket server is listening") s.redir.Redirect(&redirector.Redirection{ @@ -182,7 +184,7 @@ func (s *Server) acceptLoop() { func (s *Server) AcceptConn(overlay tunnel.Tunnel) (tunnel.Conn, error) { if _, ok := overlay.(*websocket.Tunnel); ok { - s.nextHTTP = true + atomic.StoreInt32(&s.nextHTTP, 1) log.Debug("next proto http") // websocket overlay select { diff --git a/tunnel/trojan/client.go b/tunnel/trojan/client.go index db171098f..778245d63 100644 --- a/tunnel/trojan/client.go +++ b/tunnel/trojan/client.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "net" + "sync" + "sync/atomic" "time" "github.com/p4gefau1t/trojan-go/api" @@ -27,11 +29,11 @@ const ( ) type OutboundConn struct { - metadata *tunnel.Metadata - sent uint64 - recv uint64 - user statistic.User - headerWritten bool + metadata *tunnel.Metadata + sent uint64 + recv uint64 + user statistic.User + headerWrittenOnce sync.Once net.Conn } @@ -39,8 +41,10 @@ func (c *OutboundConn) Metadata() *tunnel.Metadata { return c.metadata } -func (c *OutboundConn) WriteHeader(payload []byte) error { - if !c.headerWritten { +func (c *OutboundConn) WriteHeader(payload []byte) (bool, error) { + var err error + written := false + c.headerWrittenOnce.Do(func() { hash := c.user.Hash() buf := bytes.NewBuffer(make([]byte, 0, MaxPacketSize)) crlf := []byte{0x0d, 0x0a} @@ -51,36 +55,37 @@ func (c *OutboundConn) WriteHeader(payload []byte) error { if payload != nil { buf.Write(payload) } - _, err := c.Conn.Write(buf.Bytes()) - c.headerWritten = true - return err - } - return common.NewError("trojan header has been written") + _, err = c.Conn.Write(buf.Bytes()) + if err == nil { + written = true + } + }) + return written, err } func (c *OutboundConn) Write(p []byte) (int, error) { - if !c.headerWritten { - err := c.WriteHeader(p) - if err != nil { - return 0, common.NewError("trojan failed to flush header with payload").Base(err) - } + written, err := c.WriteHeader(p) + if err != nil { + return 0, common.NewError("trojan failed to flush header with payload").Base(err) + } + if written { return len(p), nil } n, err := c.Conn.Write(p) c.user.AddTraffic(n, 0) - c.sent += uint64(n) + atomic.AddUint64(&c.sent, uint64(n)) return n, err } func (c *OutboundConn) Read(p []byte) (int, error) { n, err := c.Conn.Read(p) c.user.AddTraffic(0, n) - c.recv += uint64(n) + atomic.AddUint64(&c.recv, uint64(n)) return n, err } func (c *OutboundConn) Close() error { - log.Info("connection to", c.metadata, "closed", "sent:", common.HumanFriendlyTraffic(c.sent), "recv:", common.HumanFriendlyTraffic(c.recv)) + log.Info("connection to", c.metadata, "closed", "sent:", common.HumanFriendlyTraffic(atomic.LoadUint64(&c.sent)), "recv:", common.HumanFriendlyTraffic(atomic.LoadUint64(&c.recv))) return c.Conn.Close() } diff --git a/tunnel/trojan/server.go b/tunnel/trojan/server.go index c6c3c8564..52e46cc38 100644 --- a/tunnel/trojan/server.go +++ b/tunnel/trojan/server.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "sync/atomic" "github.com/p4gefau1t/trojan-go/api" "github.com/p4gefau1t/trojan-go/common" @@ -36,20 +37,21 @@ func (c *InboundConn) Metadata() *tunnel.Metadata { func (c *InboundConn) Write(p []byte) (int, error) { n, err := c.Conn.Write(p) - c.sent += uint64(n) + atomic.AddUint64(&c.sent, uint64(n)) c.user.AddTraffic(n, 0) return n, err } func (c *InboundConn) Read(p []byte) (int, error) { n, err := c.Conn.Read(p) - c.recv += uint64(n) + atomic.AddUint64(&c.recv, uint64(n)) c.user.AddTraffic(0, n) return n, err } func (c *InboundConn) Close() error { - log.Info("user", c.hash, "from", c.Conn.RemoteAddr(), "tunneling to", c.metadata.Address, "closed", "sent:", common.HumanFriendlyTraffic(c.sent), "recv:", common.HumanFriendlyTraffic(c.recv)) + log.Info("user", c.hash, "from", c.Conn.RemoteAddr(), "tunneling to", c.metadata.Address, "closed", + "sent:", common.HumanFriendlyTraffic(atomic.LoadUint64(&c.sent)), "recv:", common.HumanFriendlyTraffic(atomic.LoadUint64(&c.recv))) c.user.DelIP(c.ip) return c.Conn.Close() }