From 955b4a3e48d8354c482360e40496fe758e1c348d Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 4 May 2023 19:46:58 +0530 Subject: [PATCH 1/9] add addrDelay and dialRanker --- core/network/network.go | 10 ++ p2p/net/swarm/dial_ranker.go | 131 +++++++++++++++ p2p/net/swarm/dial_ranker_test.go | 270 ++++++++++++++++++++++++++++++ p2p/net/swarm/swarm.go | 20 +++ 4 files changed, 431 insertions(+) create mode 100644 p2p/net/swarm/dial_ranker.go create mode 100644 p2p/net/swarm/dial_ranker_test.go diff --git a/core/network/network.go b/core/network/network.go index 0beaac0f71..215b5373b3 100644 --- a/core/network/network.go +++ b/core/network/network.go @@ -184,3 +184,13 @@ type Dialer interface { Notify(Notifiee) StopNotify(Notifiee) } + +// AddrDelay provides an address along with the delay after which the address +// should be dialed +type AddrDelay struct { + Addr ma.Multiaddr + Delay time.Duration +} + +// DialRanker provides a schedule of dialing the provided addresses +type DialRanker func([]ma.Multiaddr) []AddrDelay diff --git a/p2p/net/swarm/dial_ranker.go b/p2p/net/swarm/dial_ranker.go new file mode 100644 index 0000000000..321781fdf8 --- /dev/null +++ b/p2p/net/swarm/dial_ranker.go @@ -0,0 +1,131 @@ +package swarm + +import ( + "time" + + "github.com/libp2p/go-libp2p/core/network" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +const ( + publicTCPDelay = 300 * time.Millisecond + privateTCPDelay = 30 * time.Millisecond + relayDelay = 500 * time.Millisecond +) + +func noDelayRanker(addrs []ma.Multiaddr) []network.AddrDelay { + res := make([]network.AddrDelay, len(addrs)) + for i, a := range addrs { + res[i] = network.AddrDelay{Addr: a, Delay: 0} + } + return res +} + +// defaultDialRanker is the default ranking logic. +// +// we consider private, public ip4, public ip6, relay addresses separately. +// +// In each group, if a quic address is present, we delay tcp addresses. +// +// private: 30 ms delay. +// public ip4: 300 ms delay. +// public ip6: 300 ms delay. +// +// If a quic-v1 address is present we don't dial quic or webtransport address on the same (ip,port) combination. +// If a tcp address is present we don't dial ws or wss address on the same (ip, port) combination. +// If direct addresses are present we delay all relay addresses by 500 millisecond +func defaultDialRanker(addrs []ma.Multiaddr) []network.AddrDelay { + ip4 := make([]ma.Multiaddr, 0, len(addrs)) + ip6 := make([]ma.Multiaddr, 0, len(addrs)) + pvt := make([]ma.Multiaddr, 0, len(addrs)) + relay := make([]ma.Multiaddr, 0, len(addrs)) + + res := make([]network.AddrDelay, 0, len(addrs)) + for _, a := range addrs { + switch { + case !manet.IsPublicAddr(a): + pvt = append(pvt, a) + case isRelayAddr(a): + relay = append(relay, a) + case isProtocolAddr(a, ma.P_IP4): + ip4 = append(ip4, a) + case isProtocolAddr(a, ma.P_IP6): + ip6 = append(ip6, a) + default: + res = append(res, network.AddrDelay{Addr: a, Delay: 0}) + } + } + var roffset time.Duration = 0 + if len(ip4) > 0 || len(ip6) > 0 { + roffset = relayDelay + } + + res = append(res, getAddrDelay(pvt, privateTCPDelay, 0)...) + res = append(res, getAddrDelay(ip4, publicTCPDelay, 0)...) + res = append(res, getAddrDelay(ip6, publicTCPDelay, 0)...) + res = append(res, getAddrDelay(relay, publicTCPDelay, roffset)...) + return res +} + +func getAddrDelay(addrs []ma.Multiaddr, tcpDelay time.Duration, offset time.Duration) []network.AddrDelay { + var hasQuic, hasQuicV1 bool + quicV1Addr := make(map[string]struct{}) + tcpAddr := make(map[string]struct{}) + for _, a := range addrs { + switch { + case isProtocolAddr(a, ma.P_WEBTRANSPORT): + case isProtocolAddr(a, ma.P_QUIC): + hasQuic = true + case isProtocolAddr(a, ma.P_QUIC_V1): + hasQuicV1 = true + quicV1Addr[addrPort(a, ma.P_UDP)] = struct{}{} + case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS): + case isProtocolAddr(a, ma.P_TCP): + tcpAddr[addrPort(a, ma.P_TCP)] = struct{}{} + } + } + + res := make([]network.AddrDelay, 0, len(addrs)) + for _, a := range addrs { + delay := offset + switch { + case isProtocolAddr(a, ma.P_WEBTRANSPORT): + if hasQuicV1 { + if _, ok := quicV1Addr[addrPort(a, ma.P_UDP)]; ok { + continue + } + } + case isProtocolAddr(a, ma.P_QUIC): + if hasQuicV1 { + if _, ok := quicV1Addr[addrPort(a, ma.P_UDP)]; ok { + continue + } + } + case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS): + if _, ok := tcpAddr[addrPort(a, ma.P_TCP)]; ok { + continue + } + if hasQuic || hasQuicV1 { + delay += tcpDelay + } + case isProtocolAddr(a, ma.P_TCP): + if hasQuic || hasQuicV1 { + delay += tcpDelay + } + } + res = append(res, network.AddrDelay{Addr: a, Delay: delay}) + } + return res +} + +func addrPort(a ma.Multiaddr, p int) string { + c, _ := ma.SplitFirst(a) + port, _ := a.ValueForProtocol(p) + return c.Value() + ":" + port +} + +func isProtocolAddr(a ma.Multiaddr, p int) bool { + _, err := a.ValueForProtocol(p) + return err == nil +} diff --git a/p2p/net/swarm/dial_ranker_test.go b/p2p/net/swarm/dial_ranker_test.go new file mode 100644 index 0000000000..8edf8c2216 --- /dev/null +++ b/p2p/net/swarm/dial_ranker_test.go @@ -0,0 +1,270 @@ +package swarm + +import ( + "fmt" + "sort" + "testing" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/test" + ma "github.com/multiformats/go-multiaddr" +) + +func TestNoDelayRanker(t *testing.T) { + addrs := []ma.Multiaddr{ + ma.StringCast("/ip4/1.2.3.4/tcp/1"), + ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1"), + } + addrDelays := noDelayRanker(addrs) + if len(addrs) != len(addrDelays) { + t.Errorf("addrDelay should have the same number of elements as addr") + } + + for _, a := range addrs { + for _, ad := range addrDelays { + if a.Equal(ad.Addr) { + if ad.Delay != 0 { + t.Errorf("expected 0 delay, got %s", ad.Delay) + } + } + } + } +} + +func TestDelayRankerTCPDelay(t *testing.T) { + pquicv1 := ma.StringCast("/ip4/192.168.0.100/udp/1/quic-v1") + ptcp := ma.StringCast("/ip4/192.168.0.100/tcp/1/") + + quic := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + quicv1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + tcp := ma.StringCast("/ip4/1.2.3.5/tcp/1/") + + tcp6 := ma.StringCast("/ip6/1::1/tcp/1") + quicv16 := ma.StringCast("/ip6/1::2/udp/1/quic-v1") + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []network.AddrDelay + }{ + { + name: "quic prioritised over tcp", + addrs: []ma.Multiaddr{quic, tcp}, + output: []network.AddrDelay{ + {Addr: quic, Delay: 0}, + {Addr: tcp, Delay: publicTCPDelay}, + }, + }, + { + name: "quic-v1 prioritised over tcp", + addrs: []ma.Multiaddr{quicv1, tcp}, + output: []network.AddrDelay{ + {Addr: quicv1, Delay: 0}, + {Addr: tcp, Delay: publicTCPDelay}, + }, + }, + { + name: "ip6 treated separately", + addrs: []ma.Multiaddr{quicv16, tcp6, quic}, + output: []network.AddrDelay{ + {Addr: quicv16, Delay: 0}, + {Addr: quic, Delay: 0}, + {Addr: tcp6, Delay: publicTCPDelay}, + }, + }, + { + name: "private addrs treated separately", + addrs: []ma.Multiaddr{pquicv1, ptcp}, + output: []network.AddrDelay{ + {Addr: pquicv1, Delay: 0}, + {Addr: ptcp, Delay: privateTCPDelay}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := defaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + for _, a := range res { + log.Errorf("%v", a) + } + for _, a := range tc.output { + log.Errorf("%v", a) + } + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sort.Slice(res, func(i, j int) bool { + if res[i].Delay == res[j].Delay { + return res[i].Addr.String() < res[j].Addr.String() + } + return res[i].Delay < res[j].Delay + }) + sort.Slice(tc.output, func(i, j int) bool { + if tc.output[i].Delay == tc.output[j].Delay { + return tc.output[i].Addr.String() < tc.output[j].Addr.String() + } + return tc.output[i].Delay < tc.output[j].Delay + }) + for i := 0; i < len(tc.output); i++ { + if !tc.output[i].Addr.Equal(res[i].Addr) || tc.output[i].Delay != res[i].Delay { + t.Errorf("expected %+v got %+v", tc.output[i], res[i]) + } + } + }) + } +} + +func TestDelayRankerAddrDropped(t *testing.T) { + pquic := ma.StringCast("/ip4/192.168.0.100/udp/1/quic") + pquicv1 := ma.StringCast("/ip4/192.168.0.100/udp/1/quic-v1") + + quicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + quicAddr2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic") + quicv1Addr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + wt := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport/") + wt2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic-v1/webtransport/") + + quic6 := ma.StringCast("/ip6/1::1/udp/1/quic") + quicv16 := ma.StringCast("/ip6/1::1/udp/1/quic-v1") + + tcp := ma.StringCast("/ip4/1.2.3.5/tcp/1/") + ws := ma.StringCast("/ip4/1.2.3.5/tcp/1/ws") + ws2 := ma.StringCast("/ip4/1.2.3.4/tcp/1/ws") + wss := ma.StringCast("/ip4/1.2.3.5/tcp/1/wss") + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []network.AddrDelay + }{ + { + name: "quic dropped when quic-v1 present", + addrs: []ma.Multiaddr{quicAddr, quicv1Addr, quicAddr2}, + output: []network.AddrDelay{ + {Addr: quicv1Addr, Delay: 0}, + {Addr: quicAddr2, Delay: 0}, + }, + }, + { + name: "webtransport dropped when quicv1 present", + addrs: []ma.Multiaddr{quicv1Addr, wt, wt2, quicAddr}, + output: []network.AddrDelay{ + {Addr: quicv1Addr, Delay: 0}, + {Addr: wt2, Delay: 0}, + }, + }, + { + name: "ip6 quic dropped when quicv1 present", + addrs: []ma.Multiaddr{quicv16, quic6}, + output: []network.AddrDelay{ + {Addr: quicv16, Delay: 0}, + }, + }, + { + name: "web socket removed when tcp present", + addrs: []ma.Multiaddr{quicAddr, tcp, ws, wss, ws2}, + output: []network.AddrDelay{ + {Addr: quicAddr, Delay: 0}, + {Addr: tcp, Delay: publicTCPDelay}, + {Addr: ws2, Delay: publicTCPDelay}, + }, + }, + { + name: "private quic dropped when quiv1 present", + addrs: []ma.Multiaddr{pquic, pquicv1}, + output: []network.AddrDelay{ + {Addr: pquicv1, Delay: 0}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := defaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + for _, a := range res { + log.Errorf("%v", a) + } + for _, a := range tc.output { + log.Errorf("%v", a) + } + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sort.Slice(res, func(i, j int) bool { + if res[i].Delay == res[j].Delay { + return res[i].Addr.String() < res[j].Addr.String() + } + return res[i].Delay < res[j].Delay + }) + sort.Slice(tc.output, func(i, j int) bool { + if tc.output[i].Delay == tc.output[j].Delay { + return tc.output[i].Addr.String() < tc.output[j].Addr.String() + } + return tc.output[i].Delay < tc.output[j].Delay + }) + + for i := 0; i < len(tc.output); i++ { + if !tc.output[i].Addr.Equal(res[i].Addr) || tc.output[i].Delay != res[i].Delay { + t.Errorf("expected %+v got %+v", tc.output[i], res[i]) + } + } + }) + } +} + +func TestDelayRankerRelay(t *testing.T) { + quicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + quicAddr2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic") + + pid := test.RandPeerIDFatal(t) + r1 := ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1/p2p-circuit/p2p/%s", pid)) + r2 := ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/udp/1/quic/p2p-circuit/p2p/%s", pid)) + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []network.AddrDelay + }{ + { + name: "relay address delayed", + addrs: []ma.Multiaddr{quicAddr, quicAddr2, r1, r2}, + output: []network.AddrDelay{ + {Addr: quicAddr, Delay: 0}, + {Addr: quicAddr2, Delay: 0}, + {Addr: r2, Delay: relayDelay}, + {Addr: r1, Delay: publicTCPDelay + relayDelay}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := defaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + for _, a := range res { + log.Errorf("%v", a) + } + for _, a := range tc.output { + log.Errorf("%v", a) + } + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sort.Slice(res, func(i, j int) bool { + if res[i].Delay == res[j].Delay { + return res[i].Addr.String() < res[j].Addr.String() + } + return res[i].Delay < res[j].Delay + }) + sort.Slice(tc.output, func(i, j int) bool { + if tc.output[i].Delay == tc.output[j].Delay { + return tc.output[i].Addr.String() < tc.output[j].Addr.String() + } + return tc.output[i].Delay < tc.output[j].Delay + }) + + for i := 0; i < len(tc.output); i++ { + if !tc.output[i].Addr.Equal(res[i].Addr) || tc.output[i].Delay != res[i].Delay { + t.Errorf("expected %+v got %+v", tc.output[i], res[i]) + } + } + }) + } +} diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index cd19e726ed..3f4444f356 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -100,6 +100,23 @@ func WithResourceManager(m network.ResourceManager) Option { } } +// WithNoDialDelay configures swarm to dial all addresses for a peer without +// any delay +func WithNoDialDelay() Option { + return func(s *Swarm) error { + s.dialRanker = noDelayRanker + return nil + } +} + +// WithDialRanker configures swarm to use d as the DialRanker +func WithDialRanker(d network.DialRanker) Option { + return func(s *Swarm) error { + s.dialRanker = d + return nil + } +} + // Swarm is a connection muxer, allowing connections to other peers to // be opened and closed, while still using the same Chan for all // communication. The Chan sends/receives Messages, which note the @@ -163,6 +180,8 @@ type Swarm struct { bwc metrics.Reporter metricsTracer MetricsTracer + + dialRanker network.DialRanker } // NewSwarm constructs a Swarm. @@ -181,6 +200,7 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts dialTimeout: defaultDialTimeout, dialTimeoutLocal: defaultDialTimeoutLocal, maResolver: madns.DefaultResolver, + dialRanker: defaultDialRanker, } s.conns.m = make(map[peer.ID][]*Conn) From 4a0d1f0a71b4dedff9f0d59cf25424ab49b54a3a Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 5 May 2023 11:13:31 +0530 Subject: [PATCH 2/9] dial scheduler interim --- p2p/net/swarm/dial_scheduler.go | 113 +++++++++++++++++++++++++++ p2p/net/swarm/dial_worker.go | 132 ++++++++++++-------------------- 2 files changed, 162 insertions(+), 83 deletions(-) create mode 100644 p2p/net/swarm/dial_scheduler.go diff --git a/p2p/net/swarm/dial_scheduler.go b/p2p/net/swarm/dial_scheduler.go new file mode 100644 index 0000000000..b61ed07d2c --- /dev/null +++ b/p2p/net/swarm/dial_scheduler.go @@ -0,0 +1,113 @@ +package swarm + +import ( + "math" + "sort" + "time" + + "github.com/libp2p/go-libp2p/core/network" + ma "github.com/multiformats/go-multiaddr" +) + +type dialScheduler struct { + q []network.AddrDelay + pos map[ma.Multiaddr]int + ranker network.DialRanker + dialCh chan ma.Multiaddr + triggerCh chan struct{} + reqCh chan dialSchedule + timer *time.Timer + timerRunning bool + st time.Time +} + +type dialSchedule struct { + addrs []ma.Multiaddr + simConnect bool +} + +func newDialScheduler() *dialScheduler { + return &dialScheduler{ + dialCh: make(chan ma.Multiaddr, 1), + reqCh: make(chan dialSchedule, 1), + pos: make(map[ma.Multiaddr]int), + triggerCh: make(chan struct{}), + } +} + +func (ds *dialScheduler) triggerNext() { + select { + case ds.triggerCh <- struct{}{}: + default: + } +} + +func (ds *dialScheduler) start() { + go ds.loop() +} + +func (ds *dialScheduler) close() { + close(ds.reqCh) +} + +func (ds *dialScheduler) loop() { + ds.st = time.Now() + ds.timer = time.NewTimer(math.MaxInt64) + ds.timerRunning = true + trigger := false + for { + select { + case <-ds.timer.C: + var i int + for i = 0; i < len(ds.q); i++ { + if ds.q[i].Delay == ds.q[0].Delay { + ds.dialCh <- ds.q[i].Addr + delete(ds.pos, ds.q[i].Addr) + } + } + ds.q = ds.q[i:] + case req, ok := <-ds.reqCh: + if !ok { + return + } + var ranking []network.AddrDelay + if req.simConnect { + ranking = noDelayRanker(req.addrs) + } else { + ranking = defaultDialRanker(req.addrs) + } + for _, ad := range ranking { + pos, ok := ds.pos[ad.Addr] + if !ok { + ds.q = append(ds.q, ad) + } + if ds.q[pos].Delay < ad.Delay { + ds.q[pos].Delay = ad.Delay + } + } + sort.Slice(ds.q, func(i, j int) bool { return ds.q[i].Delay < ds.q[j].Delay }) + for i, a := range ds.q { + ds.pos[a.Addr] = i + } + case <-ds.triggerCh: + trigger = true + } + ds.resetTimer(trigger) + trigger = false + } +} + +func (ds *dialScheduler) resetTimer(trigger bool) { + if ds.timerRunning && !ds.timer.Stop() { + <-ds.timer.C + } + ds.timerRunning = false + if len(ds.q) > 0 { + if trigger { + ds.timer.Reset(-1) + } else { + ds.timer.Reset(time.Until(ds.st.Add(ds.q[0].Delay))) + } + ds.timerRunning = true + } +} diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index f805371cc6..f4f7492d95 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -8,7 +8,6 @@ import ( "github.com/libp2p/go-libp2p/core/peer" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" ) // ///////////////////////////////////////////////////////////////////////////////// @@ -38,6 +37,7 @@ type addrDial struct { conn *Conn err error requests []int + dialed bool } type dialWorker struct { @@ -56,6 +56,9 @@ type dialWorker struct { // ready when we have more addresses to dial (nextDial is not empty) triggerDial <-chan struct{} + // ds schedules dials for a request + ds *dialScheduler + // for testing wg sync.WaitGroup } @@ -68,6 +71,7 @@ func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest) *dialWorker { requests: make(map[int]*pendRequest), pending: make(map[ma.Multiaddr]*addrDial), resch: make(chan dialResult), + ds: newDialScheduler(), } } @@ -79,12 +83,14 @@ func (w *dialWorker) loop() { // used to signal readiness to dial and completion of the dial ready := make(chan struct{}) close(ready) - + currDials := 0 + w.ds.start() loop: for { select { case req, ok := <-w.reqch: if !ok { + w.ds.close() return } @@ -100,10 +106,6 @@ loop: continue loop } - // at this point, len(addrs) > 0 or else it would be error from addrsForDial - // ranke them to process in order - addrs = w.rankAddrs(addrs) - // create the pending request object pr := &pendRequest{ req: req, @@ -114,15 +116,10 @@ loop: pr.addrs[a] = struct{}{} } - // check if any of the addrs has been successfully dialed and accumulate - // errors from complete dials while collecting new addrs to dial/join - var todial []ma.Multiaddr - var tojoin []*addrDial - - for _, a := range addrs { + // check if any of the address have a completed dial already + for a := range pr.addrs { ad, ok := w.pending[a] - if !ok { - todial = append(todial, a) + if !ok || !ad.dialed { continue } @@ -138,13 +135,10 @@ loop: delete(pr.addrs, a) continue } - - // dial is still pending, add to the join list - tojoin = append(tojoin, ad) } - if len(todial) == 0 && len(tojoin) == 0 { - // all request applicable addrs have been dialed, we must have errored + // all request applicable addrs have been dialed, we must have errored + if len(pr.addrs) == 0 { req.resch <- dialResponse{err: pr.err} continue loop } @@ -153,45 +147,54 @@ loop: w.reqno++ w.requests[w.reqno] = pr - for _, ad := range tojoin { - if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { - if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { - ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) + var toschedule []ma.Multiaddr + for a := range pr.addrs { + ad, ok := w.pending[a] + if !ok { + w.pending[a] = &addrDial{ + addr: a, + ctx: req.ctx, + requests: []int{w.reqno}, } + } else { + if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { + if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { + ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) + } + } + ad.requests = append(ad.requests, w.reqno) } - ad.requests = append(ad.requests, w.reqno) + toschedule = append(toschedule, a) } - if len(todial) > 0 { - for _, a := range todial { - w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}} - } - - w.nextDial = append(w.nextDial, todial...) - w.nextDial = w.rankAddrs(w.nextDial) - - // trigger a new dial now to account for the new addrs we added - w.triggerDial = ready + simConnect, _, _ := network.GetSimultaneousConnect(req.ctx) + w.ds.schedule(toschedule, simConnect) + if currDials == 0 { + w.ds.triggerNext() } - case <-w.triggerDial: - for _, addr := range w.nextDial { - // spawn the dial - ad := w.pending[addr] - err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch) - if err != nil { - w.dispatchError(ad, err) - } + case addr := <-w.ds.C: + // spawn the dial + ad := w.pending[addr] + if ad.dialed { + continue loop + } + ad.dialed = true + err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch) + if err != nil { + w.dispatchError(ad, err) + } else { + currDials++ } - - w.nextDial = nil - w.triggerDial = nil case res := <-w.resch: if res.Conn != nil { w.connected = true } - + currDials-- + if currDials == 0 { + w.ds.triggerNext() + } ad := w.pending[res.Addr] if res.Conn != nil { @@ -274,40 +277,3 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) { delete(w.pending, ad.addr) } } - -// ranks addresses in descending order of preference for dialing, with the following rules: -// NonRelay > Relay -// NonWS > WS -// Private > Public -// UDP > TCP -func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - addrTier := func(a ma.Multiaddr) (tier int) { - if isRelayAddr(a) { - tier |= 0b1000 - } - if isExpensiveAddr(a) { - tier |= 0b0100 - } - if !manet.IsPrivateAddr(a) { - tier |= 0b0010 - } - if isFdConsumingAddr(a) { - tier |= 0b0001 - } - - return tier - } - - tiers := make([][]ma.Multiaddr, 16) - for _, a := range addrs { - tier := addrTier(a) - tiers[tier] = append(tiers[tier], a) - } - - result := make([]ma.Multiaddr, 0, len(addrs)) - for _, tier := range tiers { - result = append(result, tier...) - } - - return result -} From da37478f0a80bb56d24c393d4300f60a0647d2f7 Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 5 May 2023 15:28:25 +0530 Subject: [PATCH 3/9] scheduler done --- p2p/net/swarm/dial_scheduler.go | 150 ++++++++++++++++++-------------- p2p/net/swarm/dial_worker.go | 117 +++++++++++++------------ 2 files changed, 146 insertions(+), 121 deletions(-) diff --git a/p2p/net/swarm/dial_scheduler.go b/p2p/net/swarm/dial_scheduler.go index b61ed07d2c..bca29097b7 100644 --- a/p2p/net/swarm/dial_scheduler.go +++ b/p2p/net/swarm/dial_scheduler.go @@ -1,44 +1,46 @@ package swarm import ( + "context" "math" "sort" "time" - "github.com/libp2p/go-libp2p/core/network" ma "github.com/multiformats/go-multiaddr" ) type dialScheduler struct { - q []network.AddrDelay - pos map[ma.Multiaddr]int - ranker network.DialRanker - dialCh chan ma.Multiaddr - triggerCh chan struct{} - reqCh chan dialSchedule - timer *time.Timer - timerRunning bool - st time.Time + q []dialTask + tasks map[ma.Multiaddr]*taskState + reqCh chan dialTask + st time.Time } -type dialSchedule struct { - addrs []ma.Multiaddr - simConnect bool +type dialTask struct { + addr ma.Multiaddr + delay time.Duration + dialFunc func() + isSimConnect bool } -func newDialScheduler() *dialScheduler { - return &dialScheduler{ - dialCh: make(chan ma.Multiaddr, 1), - reqCh: make(chan dialSchedule, 1), - pos: make(map[ma.Multiaddr]int), - triggerCh: make(chan struct{}), - } +type taskStatus int + +const ( + scheduled taskStatus = iota + dialed + completed +) + +type taskState struct { + status taskStatus + delay time.Duration + isSimConnect bool } -func (ds *dialScheduler) triggerNext() { - select { - case ds.triggerCh <- struct{}{}: - default: +func newDialScheduler() *dialScheduler { + return &dialScheduler{ + reqCh: make(chan dialTask, 1), + tasks: make(map[ma.Multiaddr]*taskState), } } @@ -52,62 +54,82 @@ func (ds *dialScheduler) close() { func (ds *dialScheduler) loop() { ds.st = time.Now() - ds.timer = time.NewTimer(math.MaxInt64) - ds.timerRunning = true + timer := time.NewTimer(math.MaxInt64) + timerRunning := true trigger := false + ctx, cancel := context.WithCancel(context.Background()) + doneCh := make(chan ma.Multiaddr, 1) + currDials := 0 for { select { - case <-ds.timer.C: + case <-timer.C: var i int for i = 0; i < len(ds.q); i++ { - if ds.q[i].Delay == ds.q[0].Delay { - ds.dialCh <- ds.q[i].Addr - delete(ds.pos, ds.q[i].Addr) + if ds.q[i].delay != ds.q[0].delay { + break + } + st, ok := ds.tasks[ds.q[i].addr] + if !ok { + // shouldn't happen but for safety + log.Errorf("no dial scheduled for %s", ds.q[i].addr) + continue } + st.status = dialed + currDials++ + go func(task dialTask) { + task.dialFunc() + select { + case doneCh <- task.addr: + case <-ctx.Done(): + } + }(ds.q[i]) } ds.q = ds.q[i:] - case req, ok := <-ds.reqCh: + timerRunning = false + case task, ok := <-ds.reqCh: if !ok { + cancel() return } - var ranking []network.AddrDelay - if req.simConnect { - ranking = noDelayRanker(req.addrs) - } else { - ranking = defaultDialRanker(req.addrs) - } - for _, ad := range ranking { - pos, ok := ds.pos[ad.Addr] - if !ok { - ds.q = append(ds.q, ad) + st, ok := ds.tasks[task.addr] + if !ok { + ds.q = append(ds.q, task) + ds.tasks[task.addr] = &taskState{ + status: scheduled, + delay: task.delay, + isSimConnect: task.isSimConnect, } - if ds.q[pos].Delay < ad.Delay { - ds.q[pos].Delay = ad.Delay + } else if !st.isSimConnect && task.isSimConnect && st.status == scheduled { + st.isSimConnect = true + st.delay = task.delay + st.status = scheduled + for i, a := range ds.q { + if a.addr.Equal(task.addr) { + ds.q[i] = task + break + } } } - sort.Slice(ds.q, func(i, j int) bool { return ds.q[i].Delay < ds.q[j].Delay }) - for i, a := range ds.q { - ds.pos[a.Addr] = i + sort.Slice(ds.q, func(i, j int) bool { return ds.q[i].delay < ds.q[j].delay }) + case a := <-doneCh: + currDials-- + if currDials == 0 { + trigger = true } - case <-ds.triggerCh: - trigger = true + delete(ds.tasks, a) } - ds.resetTimer(trigger) - trigger = false - } -} - -func (ds *dialScheduler) resetTimer(trigger bool) { - if ds.timerRunning && !ds.timer.Stop() { - <-ds.timer.C - } - ds.timerRunning = false - if len(ds.q) > 0 { - if trigger { - ds.timer.Reset(-1) - } else { - ds.timer.Reset(time.Until(ds.st.Add(ds.q[0].Delay))) + if timerRunning && !timer.Stop() { + <-timer.C } - ds.timerRunning = true + timerRunning = false + if len(ds.q) > 0 { + if trigger { + timer.Reset(-1) + } else { + timer.Reset(time.Until(ds.st.Add(ds.q[0].delay))) + } + timerRunning = true + } + trigger = false } } diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index f4f7492d95..20ab6bd7d6 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -3,6 +3,7 @@ package swarm import ( "context" "sync" + "time" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" @@ -37,7 +38,6 @@ type addrDial struct { conn *Conn err error requests []int - dialed bool } type dialWorker struct { @@ -48,17 +48,10 @@ type dialWorker struct { requests map[int]*pendRequest pending map[ma.Multiaddr]*addrDial resch chan dialResult + ds *dialScheduler connected bool // true when a connection has been successfully established - nextDial []ma.Multiaddr - - // ready when we have more addresses to dial (nextDial is not empty) - triggerDial <-chan struct{} - - // ds schedules dials for a request - ds *dialScheduler - // for testing wg sync.WaitGroup } @@ -83,14 +76,12 @@ func (w *dialWorker) loop() { // used to signal readiness to dial and completion of the dial ready := make(chan struct{}) close(ready) - currDials := 0 w.ds.start() loop: for { select { case req, ok := <-w.reqch: if !ok { - w.ds.close() return } @@ -106,20 +97,32 @@ loop: continue loop } + // at this point, len(addrs) > 0 or else it would be error from addrsForDial + // ranke them to process in order + simConnect, _, _ := network.GetSimultaneousConnect(req.ctx) + addrRanking := w.rankAddrs(addrs, simConnect) + addrDelay := make(map[ma.Multiaddr]time.Duration) + // create the pending request object pr := &pendRequest{ req: req, err: &DialError{Peer: w.peer}, addrs: make(map[ma.Multiaddr]struct{}), } - for _, a := range addrs { - pr.addrs[a] = struct{}{} + for _, adelay := range addrRanking { + pr.addrs[adelay.Addr] = struct{}{} + addrDelay[adelay.Addr] = adelay.Delay } - // check if any of the address have a completed dial already - for a := range pr.addrs { + // check if any of the addrs has been successfully dialed and accumulate + // errors from complete dials while collecting new addrs to dial/join + var todial []ma.Multiaddr + var tojoin []*addrDial + + for _, a := range addrs { ad, ok := w.pending[a] - if !ok || !ad.dialed { + if !ok { + todial = append(todial, a) continue } @@ -135,10 +138,13 @@ loop: delete(pr.addrs, a) continue } + + // dial is still pending, add to the join list + tojoin = append(tojoin, ad) } - // all request applicable addrs have been dialed, we must have errored - if len(pr.addrs) == 0 { + if len(todial) == 0 && len(tojoin) == 0 { + // all request applicable addrs have been dialed, we must have errored req.resch <- dialResponse{err: pr.err} continue loop } @@ -147,54 +153,44 @@ loop: w.reqno++ w.requests[w.reqno] = pr - var toschedule []ma.Multiaddr - for a := range pr.addrs { - ad, ok := w.pending[a] - if !ok { - w.pending[a] = &addrDial{ - addr: a, - ctx: req.ctx, - requests: []int{w.reqno}, - } - } else { - if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { - if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { - ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) - } - } - ad.requests = append(ad.requests, w.reqno) - } - toschedule = append(toschedule, a) + for _, ad := range tojoin { + ad.requests = append(ad.requests, w.reqno) } - simConnect, _, _ := network.GetSimultaneousConnect(req.ctx) - w.ds.schedule(toschedule, simConnect) - if currDials == 0 { - w.ds.triggerNext() + if len(todial) > 0 { + for _, a := range todial { + w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}} + tojoin = append(tojoin, w.pending[a]) + } } + for _, ad := range tojoin { + addr := ad.addr + delay := addrDelay[addr] + w.ds.reqCh <- dialTask{ + addr: addr, + delay: delay, + dialFunc: func() { + err := w.s.dialNextAddr(req.ctx, w.peer, addr, w.resch) + if err != nil { + select { + case w.resch <- dialResult{ + Conn: nil, + Addr: addr, + Err: err, + }: + case <-req.ctx.Done(): + } + } - case addr := <-w.ds.C: - // spawn the dial - ad := w.pending[addr] - if ad.dialed { - continue loop - } - ad.dialed = true - err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch) - if err != nil { - w.dispatchError(ad, err) - } else { - currDials++ + }, + isSimConnect: simConnect, + } } - case res := <-w.resch: if res.Conn != nil { w.connected = true } - currDials-- - if currDials == 0 { - w.ds.triggerNext() - } + ad := w.pending[res.Addr] if res.Conn != nil { @@ -277,3 +273,10 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) { delete(w.pending, ad.addr) } } + +func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr, isSimConnect bool) []network.AddrDelay { + if isSimConnect { + return noDelayRanker(addrs) + } + return w.s.dialRanker(addrs) +} From 2daa3242fe2ee9b51d86b992cd1c688d2a223ad9 Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 5 May 2023 16:34:02 +0530 Subject: [PATCH 4/9] fix bug --- p2p/net/swarm/dial_scheduler.go | 37 +++++++++++++++++++++++---------- p2p/net/swarm/dial_worker.go | 13 ++++++------ 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/p2p/net/swarm/dial_scheduler.go b/p2p/net/swarm/dial_scheduler.go index bca29097b7..57e5323b7b 100644 --- a/p2p/net/swarm/dial_scheduler.go +++ b/p2p/net/swarm/dial_scheduler.go @@ -10,10 +10,11 @@ import ( ) type dialScheduler struct { - q []dialTask - tasks map[ma.Multiaddr]*taskState - reqCh chan dialTask - st time.Time + q []dialTask + tasks map[ma.Multiaddr]*taskState + reqCh chan dialTask + st time.Time + triggerCh chan struct{} } type dialTask struct { @@ -39,8 +40,9 @@ type taskState struct { func newDialScheduler() *dialScheduler { return &dialScheduler{ - reqCh: make(chan dialTask, 1), - tasks: make(map[ma.Multiaddr]*taskState), + reqCh: make(chan dialTask, 1), + tasks: make(map[ma.Multiaddr]*taskState), + triggerCh: make(chan struct{}), } } @@ -52,6 +54,13 @@ func (ds *dialScheduler) close() { close(ds.reqCh) } +func (ds *dialScheduler) maybeTrigger() { + select { + case ds.triggerCh <- struct{}{}: + default: + } +} + func (ds *dialScheduler) loop() { ds.st = time.Now() timer := time.NewTimer(math.MaxInt64) @@ -99,6 +108,10 @@ func (ds *dialScheduler) loop() { delay: task.delay, isSimConnect: task.isSimConnect, } + for x := range ds.tasks { + log.Errorf("state %s", x) + } + log.Errorf("\n") } else if !st.isSimConnect && task.isSimConnect && st.status == scheduled { st.isSimConnect = true st.delay = task.delay @@ -109,21 +122,23 @@ func (ds *dialScheduler) loop() { break } } + } else { + log.Errorf("dropping %s", task.addr) } sort.Slice(ds.q, func(i, j int) bool { return ds.q[i].delay < ds.q[j].delay }) case a := <-doneCh: currDials-- - if currDials == 0 { - trigger = true - } - delete(ds.tasks, a) + log.Errorf("completed %s", a) + case <-ds.triggerCh: + trigger = true + log.Errorf("triggering dials") } if timerRunning && !timer.Stop() { <-timer.C } timerRunning = false if len(ds.q) > 0 { - if trigger { + if trigger && currDials == 0 { timer.Reset(-1) } else { timer.Reset(time.Until(ds.st.Add(ds.q[0].delay))) diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 20ab6bd7d6..39be179788 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -125,7 +125,6 @@ loop: todial = append(todial, a) continue } - if ad.conn != nil { // dial to this addr was successful, complete the request req.resch <- dialResponse{conn: ad.conn} @@ -157,15 +156,14 @@ loop: ad.requests = append(ad.requests, w.reqno) } - if len(todial) > 0 { - for _, a := range todial { - w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}} - tojoin = append(tojoin, w.pending[a]) - } + for _, a := range todial { + w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}} + tojoin = append(tojoin, w.pending[a]) } for _, ad := range tojoin { addr := ad.addr delay := addrDelay[addr] + log.Errorf("messaging: %s %s", w.s.LocalPeer(), addr) w.ds.reqCh <- dialTask{ addr: addr, delay: delay, @@ -222,13 +220,14 @@ loop: } // it must be an error -- add backoff if applicable and dispatch - if res.Err != context.Canceled && !w.connected { + if res.Err != context.Canceled && res.Err != ErrDialBackoff && !w.connected { // we only add backoff if there has not been a successful connection // for consistency with the old dialer behavior. w.s.backf.AddBackoff(w.peer, res.Addr) } w.dispatchError(ad, res.Err) + w.ds.maybeTrigger() } } } From 67a20ab8d92a539ca52b94e0dafa56cf75486f69 Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 5 May 2023 16:52:43 +0530 Subject: [PATCH 5/9] fix backoff test --- p2p/net/swarm/dial_scheduler.go | 11 ++--------- p2p/net/swarm/dial_worker.go | 2 +- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/p2p/net/swarm/dial_scheduler.go b/p2p/net/swarm/dial_scheduler.go index 57e5323b7b..521622aa76 100644 --- a/p2p/net/swarm/dial_scheduler.go +++ b/p2p/net/swarm/dial_scheduler.go @@ -108,11 +108,7 @@ func (ds *dialScheduler) loop() { delay: task.delay, isSimConnect: task.isSimConnect, } - for x := range ds.tasks { - log.Errorf("state %s", x) - } - log.Errorf("\n") - } else if !st.isSimConnect && task.isSimConnect && st.status == scheduled { + } else if !st.isSimConnect && task.isSimConnect && st.status != dialed { st.isSimConnect = true st.delay = task.delay st.status = scheduled @@ -122,16 +118,13 @@ func (ds *dialScheduler) loop() { break } } - } else { - log.Errorf("dropping %s", task.addr) } sort.Slice(ds.q, func(i, j int) bool { return ds.q[i].delay < ds.q[j].delay }) case a := <-doneCh: currDials-- - log.Errorf("completed %s", a) + delete(ds.tasks, a) case <-ds.triggerCh: trigger = true - log.Errorf("triggering dials") } if timerRunning && !timer.Stop() { <-timer.C diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 39be179788..a2e2c7a9f3 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -82,6 +82,7 @@ loop: select { case req, ok := <-w.reqch: if !ok { + w.ds.close() return } @@ -163,7 +164,6 @@ loop: for _, ad := range tojoin { addr := ad.addr delay := addrDelay[addr] - log.Errorf("messaging: %s %s", w.s.LocalPeer(), addr) w.ds.reqCh <- dialTask{ addr: addr, delay: delay, From e485f2807d82d4589991dfc570d0509d27edcbcc Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 5 May 2023 20:24:32 +0530 Subject: [PATCH 6/9] fix race condition --- p2p/net/swarm/dial_scheduler.go | 47 +++++-- p2p/net/swarm/dial_worker.go | 40 +++--- p2p/net/swarm/dial_worker_test.go | 219 ++++++++++++++++++++++++++++++ 3 files changed, 277 insertions(+), 29 deletions(-) diff --git a/p2p/net/swarm/dial_scheduler.go b/p2p/net/swarm/dial_scheduler.go index 521622aa76..331c965dd8 100644 --- a/p2p/net/swarm/dial_scheduler.go +++ b/p2p/net/swarm/dial_scheduler.go @@ -4,8 +4,10 @@ import ( "context" "math" "sort" + "sync/atomic" "time" + "github.com/libp2p/go-libp2p/core/peer" ma "github.com/multiformats/go-multiaddr" ) @@ -15,6 +17,7 @@ type dialScheduler struct { reqCh chan dialTask st time.Time triggerCh chan struct{} + s *Swarm } type dialTask struct { @@ -22,6 +25,9 @@ type dialTask struct { delay time.Duration dialFunc func() isSimConnect bool + peer peer.ID + ctx context.Context + resCh chan dialResult } type taskStatus int @@ -38,11 +44,12 @@ type taskState struct { isSimConnect bool } -func newDialScheduler() *dialScheduler { +func newDialScheduler(s *Swarm) *dialScheduler { return &dialScheduler{ reqCh: make(chan dialTask, 1), tasks: make(map[ma.Multiaddr]*taskState), triggerCh: make(chan struct{}), + s: s, } } @@ -65,10 +72,10 @@ func (ds *dialScheduler) loop() { ds.st = time.Now() timer := time.NewTimer(math.MaxInt64) timerRunning := true - trigger := false ctx, cancel := context.WithCancel(context.Background()) - doneCh := make(chan ma.Multiaddr, 1) - currDials := 0 + doneCh := make(chan ma.Multiaddr) + var currDials atomic.Int32 + trigger := true for { select { case <-timer.C: @@ -84,13 +91,33 @@ func (ds *dialScheduler) loop() { continue } st.status = dialed - currDials++ + currDials.Add(-1) go func(task dialTask) { - task.dialFunc() + respCh := make(chan dialResult) + err := ds.s.dialNextAddr(task.ctx, task.peer, task.addr, respCh) + var r dialResult + if err != nil { + r = dialResult{ + Conn: nil, + Addr: task.addr, + Err: err, + } + } else { + select { + case r = <-respCh: + case <-ctx.Done(): + return + } + } + currDials.Add(1) select { - case doneCh <- task.addr: + case task.resCh <- r: case <-ctx.Done(): + return + case <-task.ctx.Done(): + return } + }(ds.q[i]) } ds.q = ds.q[i:] @@ -121,23 +148,23 @@ func (ds *dialScheduler) loop() { } sort.Slice(ds.q, func(i, j int) bool { return ds.q[i].delay < ds.q[j].delay }) case a := <-doneCh: - currDials-- delete(ds.tasks, a) case <-ds.triggerCh: trigger = true } + if timerRunning && !timer.Stop() { <-timer.C } timerRunning = false if len(ds.q) > 0 { - if trigger && currDials == 0 { + if trigger && currDials.Load() == 0 { timer.Reset(-1) } else { timer.Reset(time.Until(ds.st.Add(ds.q[0].delay))) } timerRunning = true + trigger = false } - trigger = false } } diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index a2e2c7a9f3..39bc28968b 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -38,6 +38,7 @@ type addrDial struct { conn *Conn err error requests []int + delay time.Duration } type dialWorker struct { @@ -64,7 +65,7 @@ func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest) *dialWorker { requests: make(map[int]*pendRequest), pending: make(map[ma.Multiaddr]*addrDial), resch: make(chan dialResult), - ds: newDialScheduler(), + ds: newDialScheduler(s), } } @@ -158,30 +159,31 @@ loop: } for _, a := range todial { - w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}} - tojoin = append(tojoin, w.pending[a]) + w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}, delay: addrDelay[a]} + addr := a + delay := addrDelay[a] + w.ds.reqCh <- dialTask{ + addr: addr, + delay: delay, + peer: w.peer, + resCh: w.resch, + isSimConnect: simConnect, + ctx: req.ctx, + } } for _, ad := range tojoin { + if ad.delay == addrDelay[ad.addr] { + continue + } addr := ad.addr delay := addrDelay[addr] w.ds.reqCh <- dialTask{ - addr: addr, - delay: delay, - dialFunc: func() { - err := w.s.dialNextAddr(req.ctx, w.peer, addr, w.resch) - if err != nil { - select { - case w.resch <- dialResult{ - Conn: nil, - Addr: addr, - Err: err, - }: - case <-req.ctx.Done(): - } - } - - }, + addr: addr, + delay: delay, + peer: w.peer, + resCh: w.resch, isSimConnect: simConnect, + ctx: req.ctx, } } case res := <-w.resch: diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index 2c441106b1..75b823b3f7 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -5,11 +5,13 @@ import ( "crypto/rand" "errors" "fmt" + "net" "sync" "testing" "time" "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/sec" @@ -24,6 +26,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/transport/tcp" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/require" ) @@ -88,6 +91,19 @@ func makeUpgrader(t *testing.T, n *Swarm) transport.Upgrader { return u } +func makeTcpListener(t *testing.T) (net.Listener, ma.Multiaddr) { + t.Helper() + lst, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Error(err) + } + addr, err := manet.FromNetAddr(lst.Addr()) + if err != nil { + t.Error(err) + } + return lst, addr +} + func TestDialWorkerLoopBasic(t *testing.T) { s1 := makeSwarm(t) s2 := makeSwarm(t) @@ -342,3 +358,206 @@ func TestDialWorkerLoopConcurrentFailureStress(t *testing.T) { close(reqch) worker.wg.Wait() } + +func TestDialWorkerLoopRanking(t *testing.T) { + s1 := makeSwarm(t) + s2 := makeSwarm(t) + defer s1.Close() + defer s2.Close() + + var quicAddr, tcpAddr ma.Multiaddr + for _, a := range s2.ListenAddresses() { + if _, err := a.ValueForProtocol(ma.P_QUIC); err == nil { + quicAddr = a + } + if _, err := a.ValueForProtocol(ma.P_TCP); err == nil { + tcpAddr = a + } + } + + tcpL1, silAddr1 := makeTcpListener(t) + ch1 := make(chan struct{}) + defer tcpL1.Close() + tcpL2, silAddr2 := makeTcpListener(t) + ch2 := make(chan struct{}) + defer tcpL2.Close() + tcpL3, silAddr3 := makeTcpListener(t) + ch3 := make(chan struct{}) + defer tcpL3.Close() + + acceptAndIgnore := func(ch chan struct{}, l net.Listener) func() { + return func() { + for { + _, err := l.Accept() + if err != nil { + break + } + ch <- struct{}{} + } + } + } + go acceptAndIgnore(ch1, tcpL1)() + go acceptAndIgnore(ch2, tcpL2)() + go acceptAndIgnore(ch3, tcpL3)() + + ranker := func(addrs []ma.Multiaddr) []network.AddrDelay { + res := make([]network.AddrDelay, 0) + for _, a := range addrs { + switch { + case a.Equal(silAddr1): + res = append(res, network.AddrDelay{Addr: a, Delay: 0}) + case a.Equal(silAddr2): + res = append(res, network.AddrDelay{Addr: a, Delay: 1 * time.Second}) + case a.Equal(tcpAddr): + res = append(res, network.AddrDelay{Addr: a, Delay: 2 * time.Second}) + case a.Equal(silAddr3): + res = append(res, network.AddrDelay{Addr: a, Delay: 3 * time.Second}) + default: + t.Errorf("unexpected address %s", a) + } + } + return res + } + + // should connect to quic with both tcp and quic address + s1.dialRanker = ranker + s2addrs := []ma.Multiaddr{tcpAddr, silAddr1, silAddr2, silAddr3} + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2addrs, peerstore.PermanentAddrTTL) + reqch := make(chan dialRequest) + resch := make(chan dialResponse) + worker1 := newDialWorker(s1, s2.LocalPeer(), reqch) + go worker1.loop() + defer worker1.wg.Wait() + + reqch <- dialRequest{ctx: context.Background(), resch: resch} + select { + case <-ch1: + case <-time.After(1 * time.Second): + t.Fatal("expected dial to tcp1") + case <-resch: + t.Fatalf("didn't expect connection to succeed") + } + select { + case <-ch2: + case <-time.After(2 * time.Second): + t.Fatalf("expected dial to tcp2") + case <-resch: + t.Fatalf("didn't expect connection to succeed") + } + select { + case res := <-resch: + if !res.conn.RemoteMultiaddr().Equal(tcpAddr) { + log.Errorf("invalid connection address. expected %s got %s", tcpAddr, res.conn.RemoteMultiaddr()) + } + case <-time.After(2 * time.Second): + t.Fatalf("expected dial to succeed") + } + close(reqch) + s1.ClosePeer(s2.LocalPeer()) + s1.peers.ClearAddrs(s2.LocalPeer()) + select { + case <-ch3: + t.Errorf("didn't expect tcp call") + case <-time.After(2 * time.Second): + } + + quicFirstRanker := func(addrs []ma.Multiaddr) []network.AddrDelay { + m := make([]network.AddrDelay, 0) + for _, a := range addrs { + if _, err := a.ValueForProtocol(ma.P_TCP); err == nil { + m = append(m, network.AddrDelay{Addr: a, Delay: 500 * time.Millisecond}) + } else { + m = append(m, network.AddrDelay{Addr: a, Delay: 0}) + } + } + return m + } + + // tcp should connect after delay + s1.dialRanker = quicFirstRanker + s2.ListenClose(quicAddr) + s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{quicAddr, tcpAddr}, peerstore.PermanentAddrTTL) + reqch = make(chan dialRequest) + resch = make(chan dialResponse) + worker2 := newDialWorker(s1, s2.LocalPeer(), reqch) + go worker2.loop() + defer worker2.wg.Wait() + + reqch <- dialRequest{ctx: context.Background(), resch: resch} + select { + case res := <-resch: + t.Fatalf("expected a delay before connecting %s", res.conn.LocalMultiaddr()) + case <-time.After(400 * time.Millisecond): + } + select { + case res := <-resch: + require.NoError(t, res.err) + if _, err := res.conn.LocalMultiaddr().ValueForProtocol(ma.P_TCP); err != nil { + t.Fatalf("expected tcp connection %s", res.conn.LocalMultiaddr()) + } + case <-time.After(1 * time.Second): + t.Fatal("dial didn't complete") + } + close(reqch) + s1.ClosePeer(s2.LocalPeer()) + s1.peers.ClearAddrs(s2.LocalPeer()) + s2.Listen(quicAddr) + log.Errorf("hello world") + // should dial tcp immediately if there's no quic address available + s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{tcpAddr}, peerstore.PermanentAddrTTL) + reqch = make(chan dialRequest) + resch = make(chan dialResponse) + worker3 := newDialWorker(s1, s2.LocalPeer(), reqch) + go worker3.loop() + defer worker3.wg.Wait() + + reqch <- dialRequest{ctx: context.Background(), resch: resch} + select { + case res := <-resch: + require.NoError(t, res.err) + if _, err := res.conn.LocalMultiaddr().ValueForProtocol(ma.P_TCP); err != nil { + t.Fatalf("expected tcp connection, got: %s", res.conn.LocalMultiaddr()) + } + case <-time.After(500 * time.Millisecond): + t.Fatal("dial didn't complete") + } + close(reqch) + s1.ClosePeer(s2.LocalPeer()) + s1.peers.ClearAddrs(s2.LocalPeer()) + + // should dial next immediately when one connection errors after timeout + quicFirstLargeDelayRanker := func(addrs []ma.Multiaddr) []network.AddrDelay { + m := make([]network.AddrDelay, 0) + for _, a := range addrs { + if _, err := a.ValueForProtocol(ma.P_TCP); err == nil { + m = append(m, network.AddrDelay{Addr: a, Delay: 10 * time.Second}) + } else { + m = append(m, network.AddrDelay{Addr: a, Delay: 0}) + } + } + return m + } + + s1.dialRanker = quicFirstLargeDelayRanker + s2.ListenClose(quicAddr) + s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{tcpAddr, quicAddr}, peerstore.PermanentAddrTTL) + reqch = make(chan dialRequest) + resch = make(chan dialResponse) + worker4 := newDialWorker(s1, s2.LocalPeer(), reqch) + go worker4.loop() + defer worker4.wg.Wait() + + reqch <- dialRequest{ctx: context.Background(), resch: resch} + select { + case res := <-resch: + require.NoError(t, res.err) + if _, err := res.conn.LocalMultiaddr().ValueForProtocol(ma.P_TCP); err != nil { + t.Fatal("expected tcp connection") + } + case <-time.After(2 * time.Second): + t.Fatal("dial didn't complete") + } + close(reqch) + s1.ClosePeer(s2.LocalPeer()) + s1.peers.ClearAddrs(s2.LocalPeer()) +} From cab3d225d9a24ddf6595493b2e498121879d7fa8 Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 5 May 2023 21:31:38 +0530 Subject: [PATCH 7/9] implement scheduler differently --- p2p/net/swarm/dial_worker.go | 130 ++++++++++++++++++++++-------- p2p/net/swarm/dial_worker_test.go | 2 +- 2 files changed, 99 insertions(+), 33 deletions(-) diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 39bc28968b..23484edcef 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -2,6 +2,8 @@ package swarm import ( "context" + "math" + "sort" "sync" "time" @@ -38,7 +40,7 @@ type addrDial struct { conn *Conn err error requests []int - delay time.Duration + dialed bool } type dialWorker struct { @@ -49,7 +51,6 @@ type dialWorker struct { requests map[int]*pendRequest pending map[ma.Multiaddr]*addrDial resch chan dialResult - ds *dialScheduler connected bool // true when a connection has been successfully established @@ -65,7 +66,6 @@ func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest) *dialWorker { requests: make(map[int]*pendRequest), pending: make(map[ma.Multiaddr]*addrDial), resch: make(chan dialResult), - ds: newDialScheduler(s), } } @@ -77,13 +77,30 @@ func (w *dialWorker) loop() { // used to signal readiness to dial and completion of the dial ready := make(chan struct{}) close(ready) - w.ds.start() + dq := dialQueue{} + currDials := 0 + timer := time.NewTimer(math.MaxInt64) + timerRunning := false + st := time.Now() + scheduleNext := func() { + if timerRunning && !timer.Stop() { + <-timer.C + } + timerRunning = false + if dq.Len() > 0 { + if currDials == 0 { + timer = time.NewTimer(-1) + } else { + timer = time.NewTimer(time.Until(st.Add(dq.Top().Delay))) + } + timerRunning = true + } + } loop: for { select { case req, ok := <-w.reqch: if !ok { - w.ds.close() return } @@ -99,6 +116,8 @@ loop: continue loop } + // at this point, len(addrs) > 0 or else it would be error from addrsForDial + // ranke them to process in order // at this point, len(addrs) > 0 or else it would be error from addrsForDial // ranke them to process in order simConnect, _, _ := network.GetSimultaneousConnect(req.ctx) @@ -127,6 +146,7 @@ loop: todial = append(todial, a) continue } + if ad.conn != nil { // dial to this addr was successful, complete the request req.resch <- dialResponse{conn: ad.conn} @@ -155,42 +175,45 @@ loop: w.requests[w.reqno] = pr for _, ad := range tojoin { + if !ad.dialed { + if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { + if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { + ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) + dq.Add(network.AddrDelay{Addr: ad.addr, Delay: addrDelay[ad.addr]}) + } + } + } ad.requests = append(ad.requests, w.reqno) } - for _, a := range todial { - w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}, delay: addrDelay[a]} - addr := a - delay := addrDelay[a] - w.ds.reqCh <- dialTask{ - addr: addr, - delay: delay, - peer: w.peer, - resCh: w.resch, - isSimConnect: simConnect, - ctx: req.ctx, + if len(todial) > 0 { + for _, a := range todial { + w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}} + dq.Add(network.AddrDelay{Addr: a, Delay: addrDelay[a]}) } } - for _, ad := range tojoin { - if ad.delay == addrDelay[ad.addr] { - continue - } - addr := ad.addr - delay := addrDelay[addr] - w.ds.reqCh <- dialTask{ - addr: addr, - delay: delay, - peer: w.peer, - resCh: w.resch, - isSimConnect: simConnect, - ctx: req.ctx, + scheduleNext() + + case <-timer.C: + for _, adelay := range dq.NextBatch() { + // spawn the dial + ad := w.pending[adelay.Addr] + ad.dialed = true + err := w.s.dialNextAddr(ad.ctx, w.peer, ad.addr, w.resch) + if err != nil { + w.dispatchError(ad, err) + } else { + currDials++ } } + timerRunning = false + scheduleNext() + case res := <-w.resch: if res.Conn != nil { w.connected = true } - + currDials-- ad := w.pending[res.Addr] if res.Conn != nil { @@ -222,14 +245,14 @@ loop: } // it must be an error -- add backoff if applicable and dispatch - if res.Err != context.Canceled && res.Err != ErrDialBackoff && !w.connected { + if res.Err != context.Canceled && !w.connected { // we only add backoff if there has not been a successful connection // for consistency with the old dialer behavior. w.s.backf.AddBackoff(w.peer, res.Addr) } w.dispatchError(ad, res.Err) - w.ds.maybeTrigger() + scheduleNext() } } } @@ -281,3 +304,46 @@ func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr, isSimConnect bool) []networ } return w.s.dialRanker(addrs) } + +type dialQueue struct { + q []network.AddrDelay +} + +func (dq *dialQueue) Len() int { + return len(dq.q) +} + +func (dq *dialQueue) Top() network.AddrDelay { + return dq.q[0] +} + +func (dq *dialQueue) NextBatch() []network.AddrDelay { + if dq.Len() == 0 { + return nil + } + res := make([]network.AddrDelay, 0) + i := 0 + for i = 0; i < len(dq.q); i++ { + if dq.q[i].Delay != dq.q[0].Delay { + break + } + res = append(res, dq.q[i]) + } + dq.q = dq.q[i:] + return res +} + +func (dq *dialQueue) Add(adelay network.AddrDelay) { + updated := false + for i := 0; i < len(dq.q); i++ { + if dq.q[i].Addr.Equal(adelay.Addr) { + dq.q[i] = adelay + updated = true + break + } + } + if !updated { + dq.q = append(dq.q, adelay) + } + sort.Slice(dq.q, func(i, j int) bool { return dq.q[i].Delay < dq.q[j].Delay }) +} diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index 75b823b3f7..4080ec8836 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -502,7 +502,7 @@ func TestDialWorkerLoopRanking(t *testing.T) { s1.ClosePeer(s2.LocalPeer()) s1.peers.ClearAddrs(s2.LocalPeer()) s2.Listen(quicAddr) - log.Errorf("hello world") + // should dial tcp immediately if there's no quic address available s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{tcpAddr}, peerstore.PermanentAddrTTL) reqch = make(chan dialRequest) From 4f5c805094668607c60d5f7b3e32c0e8aab3c900 Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 5 May 2023 21:32:51 +0530 Subject: [PATCH 8/9] remove unused file --- p2p/net/swarm/dial_scheduler.go | 170 -------------------------------- 1 file changed, 170 deletions(-) delete mode 100644 p2p/net/swarm/dial_scheduler.go diff --git a/p2p/net/swarm/dial_scheduler.go b/p2p/net/swarm/dial_scheduler.go deleted file mode 100644 index 331c965dd8..0000000000 --- a/p2p/net/swarm/dial_scheduler.go +++ /dev/null @@ -1,170 +0,0 @@ -package swarm - -import ( - "context" - "math" - "sort" - "sync/atomic" - "time" - - "github.com/libp2p/go-libp2p/core/peer" - ma "github.com/multiformats/go-multiaddr" -) - -type dialScheduler struct { - q []dialTask - tasks map[ma.Multiaddr]*taskState - reqCh chan dialTask - st time.Time - triggerCh chan struct{} - s *Swarm -} - -type dialTask struct { - addr ma.Multiaddr - delay time.Duration - dialFunc func() - isSimConnect bool - peer peer.ID - ctx context.Context - resCh chan dialResult -} - -type taskStatus int - -const ( - scheduled taskStatus = iota - dialed - completed -) - -type taskState struct { - status taskStatus - delay time.Duration - isSimConnect bool -} - -func newDialScheduler(s *Swarm) *dialScheduler { - return &dialScheduler{ - reqCh: make(chan dialTask, 1), - tasks: make(map[ma.Multiaddr]*taskState), - triggerCh: make(chan struct{}), - s: s, - } -} - -func (ds *dialScheduler) start() { - go ds.loop() -} - -func (ds *dialScheduler) close() { - close(ds.reqCh) -} - -func (ds *dialScheduler) maybeTrigger() { - select { - case ds.triggerCh <- struct{}{}: - default: - } -} - -func (ds *dialScheduler) loop() { - ds.st = time.Now() - timer := time.NewTimer(math.MaxInt64) - timerRunning := true - ctx, cancel := context.WithCancel(context.Background()) - doneCh := make(chan ma.Multiaddr) - var currDials atomic.Int32 - trigger := true - for { - select { - case <-timer.C: - var i int - for i = 0; i < len(ds.q); i++ { - if ds.q[i].delay != ds.q[0].delay { - break - } - st, ok := ds.tasks[ds.q[i].addr] - if !ok { - // shouldn't happen but for safety - log.Errorf("no dial scheduled for %s", ds.q[i].addr) - continue - } - st.status = dialed - currDials.Add(-1) - go func(task dialTask) { - respCh := make(chan dialResult) - err := ds.s.dialNextAddr(task.ctx, task.peer, task.addr, respCh) - var r dialResult - if err != nil { - r = dialResult{ - Conn: nil, - Addr: task.addr, - Err: err, - } - } else { - select { - case r = <-respCh: - case <-ctx.Done(): - return - } - } - currDials.Add(1) - select { - case task.resCh <- r: - case <-ctx.Done(): - return - case <-task.ctx.Done(): - return - } - - }(ds.q[i]) - } - ds.q = ds.q[i:] - timerRunning = false - case task, ok := <-ds.reqCh: - if !ok { - cancel() - return - } - st, ok := ds.tasks[task.addr] - if !ok { - ds.q = append(ds.q, task) - ds.tasks[task.addr] = &taskState{ - status: scheduled, - delay: task.delay, - isSimConnect: task.isSimConnect, - } - } else if !st.isSimConnect && task.isSimConnect && st.status != dialed { - st.isSimConnect = true - st.delay = task.delay - st.status = scheduled - for i, a := range ds.q { - if a.addr.Equal(task.addr) { - ds.q[i] = task - break - } - } - } - sort.Slice(ds.q, func(i, j int) bool { return ds.q[i].delay < ds.q[j].delay }) - case a := <-doneCh: - delete(ds.tasks, a) - case <-ds.triggerCh: - trigger = true - } - - if timerRunning && !timer.Stop() { - <-timer.C - } - timerRunning = false - if len(ds.q) > 0 { - if trigger && currDials.Load() == 0 { - timer.Reset(-1) - } else { - timer.Reset(time.Until(ds.st.Add(ds.q[0].delay))) - } - timerRunning = true - trigger = false - } - } -} From 331ed97e976388f7533a473dc627d1e007730034 Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 5 May 2023 21:43:46 +0530 Subject: [PATCH 9/9] remove unused function --- p2p/net/swarm/swarm_dial.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 5423a199b7..256cff9dea 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -542,13 +542,6 @@ func isFdConsumingAddr(addr ma.Multiaddr) bool { return err1 == nil || err2 == nil } -func isExpensiveAddr(addr ma.Multiaddr) bool { - _, wsErr := addr.ValueForProtocol(ma.P_WS) - _, wssErr := addr.ValueForProtocol(ma.P_WSS) - _, wtErr := addr.ValueForProtocol(ma.P_WEBTRANSPORT) - return wsErr == nil || wssErr == nil || wtErr == nil -} - func isRelayAddr(addr ma.Multiaddr) bool { _, err := addr.ValueForProtocol(ma.P_CIRCUIT) return err == nil