From 721005686276ffa6138405b3556710a4f6087e7e Mon Sep 17 00:00:00 2001 From: fregie Date: Wed, 26 May 2021 09:10:05 +0800 Subject: [PATCH] Fix: data race when updating TLS keypair (#338) Co-authored-by: loyalsoldier <10487845+Loyalsoldier@users.noreply.github.com> --- go.mod | 1 + go.sum | 5 +++++ tunnel/tls/server.go | 14 +++++++++++--- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index fe3bbb8b7..45d04dd0c 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.16 require ( github.com/go-sql-driver/mysql v1.6.0 + github.com/huandu/go-clone v1.2.1 github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/refraction-networking/utls v0.0.0-20201210053706-2179f286686b github.com/shadowsocks/go-shadowsocks2 v0.1.5 diff --git a/go.sum b/go.sum index 503b45736..8151f941a 100644 --- a/go.sum +++ b/go.sum @@ -157,6 +157,10 @@ github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0m github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/huandu/go-assert v1.1.5 h1:fjemmA7sSfYHJD7CUqs9qTwwfdNAx7/j2/ZlHXzNB3c= +github.com/huandu/go-assert v1.1.5/go.mod h1:yOLvuqZwmcHIC5rIzrBhT7D3Q9c3GFnd0JrPVhn/06U= +github.com/huandu/go-clone v1.2.1 h1:B5FcIXJGCZGBkiZkuC4fjmU/B9cxzuaZlvWWMjgMyHw= +github.com/huandu/go-clone v1.2.1/go.mod h1:bPJ9bAG8fjyAEBRFt6toaGUZcGFGL3f6g5u6yW+9W14= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jhump/protoreflect v1.8.2 h1:k2xE7wcUomeqwY0LDCYA16y4WWfyTcMx5mKhk0d4ua0= @@ -303,6 +307,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= diff --git a/tunnel/tls/server.go b/tunnel/tls/server.go index 25a1b6540..79eac129c 100644 --- a/tunnel/tls/server.go +++ b/tunnel/tls/server.go @@ -13,8 +13,11 @@ import ( "net/http" "os" "strings" + "sync" "time" + "github.com/huandu/go-clone" + "github.com/p4gefau1t/trojan-go/common" "github.com/p4gefau1t/trojan-go/config" "github.com/p4gefau1t/trojan-go/log" @@ -33,6 +36,7 @@ type Server struct { alpn []string PreferServerCipher bool keyPair []tls.Certificate + keyPairLock sync.RWMutex httpResp []byte cipherSuite []uint16 sessionTicket bool @@ -84,6 +88,8 @@ func (s *Server) acceptLoop() { NextProtos: s.alpn, KeyLogWriter: s.keyLogger, GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + s.keyPairLock.RLock() + defer s.keyPairLock.RUnlock() sni := s.keyPair[0].Leaf.Subject.CommonName dnsNames := s.keyPair[0].Leaf.DNSNames if s.sni != "" { @@ -99,7 +105,8 @@ func (s *Server) acceptLoop() { if s.verifySNI && !matched { return nil, common.NewError("sni mismatched: " + hello.ServerName + ", expected: " + s.sni) } - return &s.keyPair[0], nil + keyPairCopied := clone.Clone(&s.keyPair[0]).(*tls.Certificate) + return keyPairCopied, nil }, } @@ -201,7 +208,7 @@ func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) { func (s *Server) checkKeyPairLoop(checkRate time.Duration, keyPath string, certPath string, password string) { var lastKeyBytes, lastCertBytes []byte for { - log.Debug("checking cert..") + log.Debug("checking cert...") keyBytes, err := ioutil.ReadFile(keyPath) if err != nil { log.Error(common.NewError("tls failed to check key").Base(err)) @@ -219,8 +226,9 @@ func (s *Server) checkKeyPairLoop(checkRate time.Duration, keyPath string, certP log.Error(common.NewError("tls failed to load new key pair").Base(err)) continue } - // TODO fix race + s.keyPairLock.Lock() s.keyPair = []tls.Certificate{*keyPair} + s.keyPairLock.Unlock() lastKeyBytes = keyBytes lastCertBytes = certBytes }