diff --git a/proxy.go b/proxy.go index ef2c12a..b23a260 100644 --- a/proxy.go +++ b/proxy.go @@ -23,9 +23,9 @@ func NewProxyConn(proxyUrl string, protocol clientProtocol) (ProxyConn, error) { } switch u.Scheme { case "socks5": - return &Socks5Client{u, protocol}, nil + return NewSocks5Client(u, protocol), nil case "http": - return &HttpClient{u, protocol, nil}, nil + return NewHttpClient(u, protocol), nil default: return &DefaultClient{}, nil } @@ -62,15 +62,20 @@ type clientProtocol struct { type Socks5Client struct { proxyUrl *url.URL clientProtocol + forward proxy.Dialer +} + +func NewSocks5Client(proxyUrl *url.URL, protocol clientProtocol) *Socks5Client { + c := &Socks5Client{proxyUrl, protocol, nil} + if c.transport == "quic" { + c.forward = NewQuicDialer([]string{c.quicProtocol}) + } + return c } // Socks5 implementation of ProxyConn func (s5 *Socks5Client) Dial(network string, address string, timeout time.Duration) (net.Conn, error) { - var forward proxy.Dialer - if s5.transport == "quic" { - forward = NewQuicDialer([]string{s5.quicProtocol}) - } - d, err := proxy.FromURL(s5.proxyUrl, forward) + d, err := proxy.FromURL(s5.proxyUrl, s5.forward) if err != nil { return nil, err } @@ -84,6 +89,14 @@ type HttpClient struct { qd *QuicDialer } +func NewHttpClient(proxyUrl *url.URL, protocol clientProtocol) *HttpClient { + c := &HttpClient{proxyUrl, protocol, nil} + if c.transport == "quic" { + c.qd = NewQuicDialer([]string{c.quicProtocol}) + } + return c +} + func SetHTTPProxyBasicAuth(req *http.Request, username, password string) { auth := username + ":" + password authEncoded := base64.StdEncoding.EncodeToString([]byte(auth)) @@ -100,9 +113,6 @@ func (hc *HttpClient) Dial(network string, address string, timeout time.Duration SetHTTPProxyBasicAuth(req, hc.proxyUrl.User.Username(), password) var proxyConn net.Conn if hc.transport == "quic" { - if hc.qd == nil { - hc.qd = NewQuicDialer([]string{hc.quicProtocol}) - } proxyConn, err = hc.qd.Dial(network, hc.proxyUrl.Host) } else { proxyConn, err = net.DialTimeout("tcp", hc.proxyUrl.Host, timeout) diff --git a/quic.go b/quic.go index b6d5dfa..b6a6263 100644 --- a/quic.go +++ b/quic.go @@ -5,7 +5,7 @@ import ( "github.com/quic-go/quic-go" "golang.org/x/net/context" "net" - "sync" + "runtime" "sync/atomic" ) @@ -13,7 +13,6 @@ type QuicDialer struct { NextProtos []string streams atomic.Uint32 c quic.Connection - dialing sync.RWMutex } func NewQuicDialer(nextProtos []string) *QuicDialer { @@ -28,23 +27,25 @@ func (d *QuicDialer) DialContext(ctx context.Context, network, address string) ( now := d.streams.Add(1) if now > maxStreams { // wait for dialing - d.dialing.RLock() - d.dialing.RUnlock() + for { + if d.streams.Load() < maxStreams { + break + } + runtime.Gosched() + } return d.DialContext(ctx, network, address) } - if now == maxStreams+1 || now == 1 { - d.dialing.Lock() + if now == maxStreams || now == 1 { c, err := quic.DialAddr(ctx, address, &tls.Config{ InsecureSkipVerify: true, NextProtos: d.NextProtos, }, nil) if err != nil { - d.dialing.Unlock() + d.streams.Store(0) return nil, err } d.c = c d.streams.Store(1) - d.dialing.Unlock() } if d.c == nil { // still in initial dialing