Skip to content

Commit

Permalink
dialParallel for dualstack support
Browse files Browse the repository at this point in the history
  • Loading branch information
cevatbarisyilmaz committed Nov 6, 2019
1 parent 4d67d5c commit 0c967d1
Showing 1 changed file with 147 additions and 12 deletions.
159 changes: 147 additions & 12 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ type Dialer struct {
// necessarily the ones passed to Dial. For example, passing "tcp" to Dial
// will cause the Control function to be called with "tcp4" or "tcp6".
Control func(network, address string, c syscall.RawConn) error

// Underlying dialer
d *net.Dialer
}

// DialContext connects to the address on the named network using
Expand Down Expand Up @@ -140,14 +143,37 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
addresses[i] = net.JoinHostPort(address, port)
}
}
t := net.Dialer{
Timeout: d.Timeout,
Deadline: d.Deadline,
LocalAddr: d.LocalAddr,
FallbackDelay: d.FallbackDelay,
KeepAlive: d.KeepAlive,
Control: d.Control,
var primaries, fallbacks []string
if d.dualStack() && network == "tcp" {
primaries, fallbacks = partition(addresses)
} else {
primaries = addresses
}

var c net.Conn
if len(fallbacks) > 0 {
c, err = d.dialParallel(ctx, network, primaries, fallbacks)
} else {
c, err = d.dialSerial(ctx, network, primaries)
}
if err != nil {
return nil, err
}
return c, nil
}

func (d *Dialer) resolver() Resolver {
if d.Resolver != nil {
return d.Resolver
}
return net.DefaultResolver
}

func (d *Dialer) dualStack() bool {
return d.FallbackDelay >= 0
}

func (d *Dialer) dialSerial(ctx context.Context, network string, addresses []string) (net.Conn, error) {
var firstErr error
for i, addr := range addresses {
saddr := simpleAddr{addr: addr, network: network}
Expand All @@ -171,7 +197,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
defer cancel()
}
c, err := t.DialContext(dialCtx, network, addr)
c, err := d.dialer().DialContext(dialCtx, network, addr)
if err == nil {
return c, nil
}
Expand All @@ -185,11 +211,99 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
return nil, firstErr
}

func (d *Dialer) resolver() Resolver {
if d.Resolver != nil {
return d.Resolver
// dialParallel races two copies of dialSerial, giving the first a
// head start. It returns the first established connection and
// closes the others. Otherwise it returns an error from the first
// primary address.
func (d *Dialer) dialParallel(ctx context.Context, network string, primaries, fallbacks []string) (net.Conn, error) {
returned := make(chan struct{})
defer close(returned)

type dialResult struct {
net.Conn
error
primary bool
done bool
}
results := make(chan dialResult) // unbuffered

startRacer := func(ctx context.Context, primary bool) {
ras := primaries
if !primary {
ras = fallbacks
}
c, err := d.dialSerial(ctx, network, ras)
select {
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
case <-returned:
if c != nil {
_ = c.Close()
}
}
}

var primary, fallback dialResult

// Start the main racer.
primaryCtx, primaryCancel := context.WithCancel(ctx)
defer primaryCancel()
go startRacer(primaryCtx, true)

// Start the timer for the fallback racer.
fallbackTimer := time.NewTimer(d.fallbackDelay())
defer fallbackTimer.Stop()

for {
select {
case <-fallbackTimer.C:
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()
go startRacer(fallbackCtx, false)

case res := <-results:
if res.error == nil {
return res.Conn, nil
}
if res.primary {
primary = res
} else {
fallback = res
}
if primary.done && fallback.done {
return nil, primary.error
}
if res.primary && fallbackTimer.Stop() {
// If we were able to stop the timer, that means it
// was running (hadn't yet started the fallback), but
// we just got an error on the primary path, so start
// the fallback immediately (in 0 nanoseconds).
fallbackTimer.Reset(0)
}
}
}
return net.DefaultResolver
}

func (d *Dialer) fallbackDelay() time.Duration {
if d.FallbackDelay > 0 {
return d.FallbackDelay
} else {
return 300 * time.Millisecond
}
}

func (d *Dialer) dialer() *net.Dialer {
if d.d != nil {
return d.d
}
d.d = &net.Dialer{
Timeout: d.Timeout,
Deadline: d.Deadline,
LocalAddr: d.LocalAddr,
FallbackDelay: d.FallbackDelay,
KeepAlive: d.KeepAlive,
Control: d.Control,
}
return d.d
}

// partialDeadline returns the deadline to use for a single address,
Expand All @@ -215,3 +329,24 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er
}
return now.Add(timeout), nil
}

// partition divides given address for dualstack usage
func partition(addresses []string) (primaries []string, fallbacks []string) {
var primaryLabel bool
for i, addr := range addresses {
label := isIPv4(addr)
if i == 0 || label == primaryLabel {
primaryLabel = label
primaries = append(primaries, addr)
} else {
fallbacks = append(fallbacks, addr)
}
}
return
}

// isIPv4 reports whether addr contains an IPv4 address.
func isIPv4(addr string) bool {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
return err == nil && tcpAddr.IP.To16() == nil
}

0 comments on commit 0c967d1

Please sign in to comment.