diff --git a/p2p/host/peerstore/pstoreds/addr_book.go b/p2p/host/peerstore/pstoreds/addr_book.go index 33ff1f5e2c..a16053b593 100644 --- a/p2p/host/peerstore/pstoreds/addr_book.go +++ b/p2p/host/peerstore/pstoreds/addr_book.go @@ -276,7 +276,7 @@ func (ab *dsAddrBook) AddAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duratio if ttl <= 0 { return } - addrs = cleanAddrs(addrs) + addrs = cleanAddrs(addrs, p) ab.setAddrs(p, addrs, ttl, ttlExtend, false) } @@ -302,7 +302,7 @@ func (ab *dsAddrBook) ConsumePeerRecord(recordEnvelope *record.Envelope, ttl tim return false, nil } - addrs := cleanAddrs(rec.Addrs) + addrs := cleanAddrs(rec.Addrs, rec.PeerID) err = ab.setAddrs(rec.PeerID, addrs, ttl, ttlExtend, true) if err != nil { return false, err @@ -385,7 +385,7 @@ func (ab *dsAddrBook) SetAddr(p peer.ID, addr ma.Multiaddr, ttl time.Duration) { // SetAddrs will add or update the TTLs of addresses in the AddrBook. func (ab *dsAddrBook) SetAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) { - addrs = cleanAddrs(addrs) + addrs = cleanAddrs(addrs, p) if ttl <= 0 { ab.deleteAddrs(p, addrs) return @@ -598,10 +598,17 @@ func (ab *dsAddrBook) deleteAddrs(p peer.ID, addrs []ma.Multiaddr) (err error) { return pr.flush(ab.ds) } -func cleanAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { +func cleanAddrs(addrs []ma.Multiaddr, pid peer.ID) []ma.Multiaddr { clean := make([]ma.Multiaddr, 0, len(addrs)) for _, addr := range addrs { + // Remove suffix of /p2p/peer-id from address + addr, addrPid := peer.SplitAddr(addr) if addr == nil { + log.Warnw("Was passed a nil multiaddr", "peer", pid) + continue + } + if addrPid != "" && addrPid != pid { + log.Warnf("Was passed p2p address with a different peerId. found: %s, expected: %s", addrPid, pid) continue } clean = append(clean, addr) diff --git a/p2p/host/peerstore/pstoremem/addr_book.go b/p2p/host/peerstore/pstoremem/addr_book.go index fc4fd0675a..67f9f91462 100644 --- a/p2p/host/peerstore/pstoremem/addr_book.go +++ b/p2p/host/peerstore/pstoremem/addr_book.go @@ -238,11 +238,16 @@ func (mab *memoryAddrBook) addAddrsUnlocked(s *addrSegment, p peer.ID, addrs []m exp := mab.clock.Now().Add(ttl) for _, addr := range addrs { + // Remove suffix of /p2p/peer-id from address + addr, addrPid := peer.SplitAddr(addr) if addr == nil { - log.Warnw("was passed nil multiaddr", "peer", p) + log.Warnw("Was passed nil multiaddr", "peer", p) + continue + } + if addrPid != "" && addrPid != p { + log.Warnf("Was passed p2p address with a different peerId. found: %s, expected: %s", addrPid, p) continue } - // find the highest TTL and Expiry time between // existing records and function args a, found := amap[string(addr.Bytes())] // won't allocate. @@ -283,10 +288,15 @@ func (mab *memoryAddrBook) SetAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Du exp := mab.clock.Now().Add(ttl) for _, addr := range addrs { + addr, addrPid := peer.SplitAddr(addr) if addr == nil { log.Warnw("was passed nil multiaddr", "peer", p) continue } + if addrPid != "" && addrPid != p { + log.Warnf("was passed p2p address with a different peerId, found: %s wanted: %s", addrPid, p) + continue + } aBytes := addr.Bytes() key := string(aBytes) diff --git a/p2p/host/peerstore/test/addr_book_suite.go b/p2p/host/peerstore/test/addr_book_suite.go index 62327e318a..27a74d96e9 100644 --- a/p2p/host/peerstore/test/addr_book_suite.go +++ b/p2p/host/peerstore/test/addr_book_suite.go @@ -141,6 +141,24 @@ func testAddAddress(ab pstore.AddrBook, clk *mockClock.Mock) func(*testing.T) { ab.AddAddrs("", addrs, time.Hour) AssertAddressesEqual(t, addrs, ab.Addrs("")) }) + + t.Run("add a /p2p address with valid peerid", func(t *testing.T) { + peerId := GeneratePeerIDs(1)[0] + addr := GenerateAddrs(1) + p2pAddr := addr[0].Encapsulate(Multiaddr("/p2p/" + peerId.String())) + ab.AddAddr(peerId, p2pAddr, time.Hour) + AssertAddressesEqual(t, addr, ab.Addrs(peerId)) + }) + + t.Run("add a /p2p address with invalid peerid", func(t *testing.T) { + pids := GeneratePeerIDs(2) + pid1 := pids[0] + pid2 := pids[1] + addr := GenerateAddrs(1) + p2pAddr := addr[0].Encapsulate(Multiaddr("/p2p/" + pid1.String())) + ab.AddAddr(pid2, p2pAddr, time.Hour) + AssertAddressesEqual(t, nil, ab.Addrs(pid2)) + }) } }