diff --git a/p2p/host/peerstore/pstoreds/protobook.go b/p2p/host/peerstore/pstoreds/protobook.go index f4349b6503..fd5f54cb20 100644 --- a/p2p/host/peerstore/pstoreds/protobook.go +++ b/p2p/host/peerstore/pstoreds/protobook.go @@ -102,6 +102,28 @@ func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, return res, nil } +func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { + s := pb.segments.get(p) + s.RLock() + defer s.RUnlock() + + pmap, err := pb.getProtocolMap(p) + if err != nil { + return err + } + + if len(pmap) == 0 { + // nothing to do. + return nil + } + + for _, proto := range protos { + delete(pmap, proto) + } + + return pb.meta.Put(p, "protocols", pmap) +} + func (pb *dsProtoBook) getProtocolMap(p peer.ID) (map[string]struct{}, error) { iprotomap, err := pb.meta.Get(p, "protocols") switch err { diff --git a/p2p/host/peerstore/pstoremem/protobook.go b/p2p/host/peerstore/pstoremem/protobook.go index a1ce2311da..7a0c86fe1f 100644 --- a/p2p/host/peerstore/pstoremem/protobook.go +++ b/p2p/host/peerstore/pstoremem/protobook.go @@ -112,6 +112,23 @@ func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) { return out, nil } +func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { + s := pb.segments.get(p) + s.RLock() + defer s.RUnlock() + + protomap, ok := s.protocols[p] + if !ok { + // nothing to remove. + return nil + } + + for _, proto := range protos { + delete(protomap, pb.internProtocol(proto)) + } + return nil +} + func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) { s := pb.segments.get(p) s.RLock() diff --git a/p2p/host/peerstore/test/peerstore_suite.go b/p2p/host/peerstore/test/peerstore_suite.go index 020da4caa0..fb4a813769 100644 --- a/p2p/host/peerstore/test/peerstore_suite.go +++ b/p2p/host/peerstore/test/peerstore_suite.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math/rand" + "reflect" "sort" "testing" "time" @@ -240,7 +241,8 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) { t.Fatal("got wrong supported array: ", supported) } - err = ps.SetProtocols(p1, "other") + protos = []string{"other", "yet another", "one more"} + err = ps.SetProtocols(p1, protos...) if err != nil { t.Fatal(err) } @@ -253,6 +255,29 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) { if len(supported) != 0 { t.Fatal("none of those protocols should have been supported") } + + supported, err = ps.GetProtocols(p1) + if err != nil { + t.Fatal(err) + } + sort.Strings(supported) + sort.Strings(protos) + if !reflect.DeepEqual(supported, protos) { + t.Fatalf("expected previously set protos; expected: %v, have: %v", protos, supported) + } + + err = ps.RemoveProtocols(p1, protos[:2]...) + if err != nil { + t.Fatal(err) + } + + supported, err = ps.GetProtocols(p1) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(supported, protos[2:]) { + t.Fatal("expected only one protocol to remain") + } } }