Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discv5 Protocol: Add support for banning nodes #769

Merged
merged 16 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 72 additions & 15 deletions eth/p2p/discoveryv5/protocol.nim
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ const
defaultResponseTimeout* = 4.seconds ## timeout for the response of a request-response
## call

## Ban durations for banned nodes in the routing table
NodeBanDurationInvalidResponse = 15.minutes

type
OptAddress* = object
ip*: Opt[IpAddress]
Expand All @@ -142,6 +145,7 @@ type
bindAddress: OptAddress ## UDP binding address
pendingRequests: Table[AESGCMNonce, PendingRequest]
routingTable*: RoutingTable
banNodes: bool
codec*: Codec
awaitedMessages: Table[(NodeId, RequestId), Future[Opt[Message]]]
refreshLoop: Future[void]
Expand All @@ -157,6 +161,7 @@ type
responseTimeout: Duration
rng*: ref HmacDrbgContext


PendingRequest = object
node: Node
message: seq[byte]
Expand Down Expand Up @@ -192,10 +197,13 @@ proc addNode*(d: Protocol, node: Node): bool =
##
## Returns true only when `Node` was added as a new entry to a bucket in the
## routing table.
if d.routingTable.addNode(node) == Added:
let r = d.routingTable.addNode(node)
if r == Added:
return true
else:
return false

if r == Banned:
debug "Banned node not added to routing table", nodeId = node.id
return false

proc addNode*(d: Protocol, r: Record): bool =
## Add `Node` from a `Record` to discovery routing table.
Expand Down Expand Up @@ -429,6 +437,30 @@ proc sendWhoareyou(d: Protocol, toId: NodeId, a: Address,
else:
debug "Node with this id already has ongoing handshake, ignoring packet"

proc replaceNode(d: Protocol, n: Node) =
if n.record notin d.bootstrapRecords:
d.routingTable.replaceNode(n)
else:
# For now we never remove bootstrap nodes. It might make sense to actually
# do so and to retry them only in case we drop to a really low amount of
# peers in the routing table.
debug "Message request to bootstrap node failed", enr = toURI(n.record)

proc banNode*(d: Protocol, n: Node, banPeriod: chronos.Duration) =
if n.record notin d.bootstrapRecords:
if d.banNodes:
d.routingTable.banNode(n.id, banPeriod) # banNode also replaces the node
else:
d.routingTable.replaceNode(n)
else:
# For now we never remove bootstrap nodes. It might make sense to actually
# do so and to retry them only in case we drop to a really low amount of
# peers in the routing table.
debug "Message request to bootstrap node failed", enr = toURI(n.record)

proc isBanned*(d: Protocol, nodeId: NodeId): bool =
d.banNodes and d.routingTable.isBanned(nodeId)

proc receive*(d: Protocol, a: Address, packet: openArray[byte]) =
discv5_network_bytes.inc(packet.len.int64, labelValues = [$Direction.In])

Expand All @@ -437,6 +469,10 @@ proc receive*(d: Protocol, a: Address, packet: openArray[byte]) =
let packet = decoded[]
case packet.flag
of OrdinaryMessage:
if d.isBanned(packet.srcId):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An improvement for the future here could be to ban without actually doing the decryption of the message (only the header). But they way the decodePacket call is currently designed this is not really possible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I guess we could use the src-id in the authdata section of the packet header for ordinary messages and handshake messages.

trace "Ignoring received OrdinaryMessage from banned node", nodeId = packet.srcId
return

if packet.messageOpt.isSome():
let message = packet.messageOpt.get()
trace "Received message packet", srcId = packet.srcId, address = a,
Expand Down Expand Up @@ -464,6 +500,10 @@ proc receive*(d: Protocol, a: Address, packet: openArray[byte]) =
else:
debug "Timed out or unrequested whoareyou packet", address = a
of HandshakeMessage:
if d.isBanned(packet.srcIdHs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Idem above.

trace "Ignoring received HandshakeMessage from banned node", nodeId = packet.srcIdHs
return

trace "Received handshake message packet", srcId = packet.srcIdHs,
address = a, kind = packet.message.kind
d.handleMessage(packet.srcIdHs, a, packet.message, packet.node)
Expand Down Expand Up @@ -494,14 +534,7 @@ proc processClient(transp: DatagramTransport, raddr: TransportAddress):

proto.receive(Address(ip: raddr.toIpAddress(), port: raddr.port), buf)

proc replaceNode(d: Protocol, n: Node) =
if n.record notin d.bootstrapRecords:
d.routingTable.replaceNode(n)
else:
# For now we never remove bootstrap nodes. It might make sense to actually
# do so and to retry them only in case we drop to a really low amount of
# peers in the routing table.
debug "Message request to bootstrap node failed", enr = toURI(n.record)


# TODO: This could be improved to do the clean-up immediately in case a non
# whoareyou response does arrive, but we would need to store the AuthTag
Expand Down Expand Up @@ -546,9 +579,11 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId):
break
return ok(res)
else:
d.banNode(fromNode, NodeBanDurationInvalidResponse)
discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"])
return err("Invalid response to find node message")
else:
d.replaceNode(fromNode)
discovery_message_requests_outgoing.inc(labelValues = ["no_response"])
return err("Nodes message not received in time")

Expand All @@ -574,6 +609,10 @@ proc ping*(d: Protocol, toNode: Node):
## Send a discovery ping message.
##
## Returns the received pong message or an error.

if d.isBanned(toNode.id):
return err("toNode is banned")

let reqId = d.sendMessage(toNode,
PingMessage(enrSeq: d.localNode.record.seqNum))
let resp = await d.waitMessage(toNode, reqId)
Expand All @@ -583,7 +622,7 @@ proc ping*(d: Protocol, toNode: Node):
d.routingTable.setJustSeen(toNode)
return ok(resp.get().pong)
else:
d.replaceNode(toNode)
d.banNode(toNode, NodeBanDurationInvalidResponse)
discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"])
return err("Invalid response to ping message")
else:
Expand All @@ -597,22 +636,29 @@ proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]):
##
## Returns the received nodes or an error.
## Received ENRs are already validated and converted to `Node`.

if d.isBanned(toNode.id):
return err("toNode is banned")

let reqId = d.sendMessage(toNode, FindNodeMessage(distances: distances))
let nodes = await d.waitNodes(toNode, reqId)

if nodes.isOk:
let res = verifyNodesRecords(nodes.get(), toNode, findNodeResultLimit, distances)
d.routingTable.setJustSeen(toNode)
return ok(res)
return ok(res.filterIt(not d.isBanned(it.id)))
else:
d.replaceNode(toNode)
return err(nodes.error)

proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]):
Future[DiscResult[seq[byte]]] {.async: (raises: [CancelledError]).} =
## Send a discovery talkreq message.
##
## Returns the received talkresp message or an error.

if d.isBanned(toNode.id):
return err("toNode is banned")

let reqId = d.sendMessage(toNode,
TalkReqMessage(protocol: protocol, request: request))
let resp = await d.waitMessage(toNode, reqId)
Expand All @@ -622,7 +668,7 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]):
d.routingTable.setJustSeen(toNode)
return ok(resp.get().talkResp.response)
else:
d.replaceNode(toNode)
d.banNode(toNode, NodeBanDurationInvalidResponse)
discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"])
return err("Invalid response to talk request message")
else:
Expand Down Expand Up @@ -797,6 +843,12 @@ proc resolve*(d: Protocol, id: NodeId): Future[Opt[Node]] {.async: (raises: [Can
if id == d.localNode.id:
return Opt.some(d.localNode)

# No point in trying to resolve a banned node because it won't exist in the
# routing table and it will be filtered out of any respones in the lookup call
if d.isBanned(id):
debug "Not resolving banned node", nodeId = id
return Opt.none(Node)

let node = d.getNode(id)
if node.isSome():
let request = await d.findNode(node.get(), @[0'u16])
Expand Down Expand Up @@ -882,6 +934,9 @@ proc refreshLoop(d: Protocol) {.async: (raises: []).} =
trace "Discovered nodes in random target query", nodes = randomQuery.len
debug "Total nodes in discv5 routing table", total = d.routingTable.len()

# Remove the expired bans from routing table to limit memory usage
d.routingTable.cleanupExpiredBans()

await sleepAsync(refreshInterval)
except CancelledError:
trace "refreshLoop canceled"
Expand Down Expand Up @@ -985,6 +1040,7 @@ proc newProtocol*(
bindPort: Port,
bindIp = IPv4_any(),
enrAutoUpdate = false,
banNodes = false,
config = defaultDiscoveryConfig,
rng = newRng()):
Protocol =
Expand Down Expand Up @@ -1034,6 +1090,7 @@ proc newProtocol*(
enrAutoUpdate: enrAutoUpdate,
routingTable: RoutingTable.init(
node, config.bitsPerHop, config.tableIpLimits, rng),
banNodes: banNodes,
handshakeTimeout: config.handshakeTimeout,
responseTimeout: config.responseTimeout,
rng: rng)
Expand Down
2 changes: 1 addition & 1 deletion eth/p2p/discoveryv5/routing_table.nim
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func ipLimitDec(r: var RoutingTable, b: KBucket, n: Node) =
r.ipLimits.dec(ip)

func getNode*(r: RoutingTable, id: NodeId): Opt[Node]
proc replaceNode*(r: var RoutingTable, n: Node)
proc replaceNode*(r: var RoutingTable, n: Node) {.gcsafe.}

proc banNode*(r: var RoutingTable, nodeId: NodeId, period: chronos.Duration) =
## Ban a node from the routing table for the given period. The node is removed
Expand Down
6 changes: 4 additions & 2 deletions tests/p2p/discv5_test_helper.nim
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ proc initDiscoveryNode*(
address: Address,
bootstrapRecords: openArray[Record] = [],
localEnrFields: openArray[(string, seq[byte])] = [],
previousRecord = Opt.none(enr.Record)):
previousRecord = Opt.none(enr.Record),
banNodes = false):
discv5_protocol.Protocol =
# set bucketIpLimit to allow bucket split
let config = DiscoveryConfig.init(1000, 24, 5)
Expand All @@ -36,7 +37,8 @@ proc initDiscoveryNode*(
localEnrFields = localEnrFields,
previousRecord = previousRecord,
config = config,
rng = rng)
rng = rng,
banNodes = banNodes)

protocol.open()

Expand Down
113 changes: 113 additions & 0 deletions tests/p2p/test_discoveryv5.nim
Original file line number Diff line number Diff line change
Expand Up @@ -926,3 +926,116 @@ suite "Discovery v5 Tests":

await node1.closeWait()
await node2.closeWait()

asyncTest "Banned nodes are removed and cannot be added":
let
node = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302), banNodes = true)
targetNode = generateNode(PrivateKey.random(rng[]))

# add the node
check:
node.addNode(targetNode) == true
node.getNode(targetNode.id).isSome()

# banning the node should remove it from the routing table
node.banNode(targetNode, 1.minutes)
check node.getNode(targetNode.id).isNone()

# cannot add a banned node
check:
node.addNode(targetNode) == false
node.getNode(targetNode.id).isNone()

await node.closeWait()

asyncTest "FindNode filters out banned nodes":
let
mainNode = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20301),
banNodes = true)
testNode = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302),
@[mainNode.localNode.record], banNodes = true)

# Generate 100 random nodes and add to our main node's routing table
for i in 0 ..< 100:
discard mainNode.addSeenNode(generateNode(PrivateKey.random(rng[])))

let
neighbours = mainNode.neighbours(mainNode.localNode.id)
closest = neighbours[0]
closestDistance = logDistance(closest.id, mainNode.localNode.id)

block:
# the closest node is returned
let discovered = await testNode.findNode(mainNode.localNode, @[closestDistance])
check discovered.isOk
check closest in discovered[]

# ban the closest node
mainNode.banNode(closest, 1.minutes)

block:
# the banned node is not returned
let discovered = await testNode.findNode(mainNode.localNode, @[closestDistance])
check discovered.isOk
check closest notin discovered[]

await mainNode.closeWait()
await testNode.closeWait()

asyncTest "Cannot send messages to banned nodes":
let
node1 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302),
banNodes = true)
node2 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20301),
banNodes = true)

# ban node2 in node1's routing table
node1.banNode(node2.localNode, 1.minutes)

block:
let pong = await node1.ping(node2.localNode)
check:
pong.isErr()
pong.error() == "toNode is banned"

block:
let nodes = await node1.findNode(node2.localNode, @[0.uint16])
check:
nodes.isErr()
nodes.error() == "toNode is banned"

block:
let node = await node1.resolve(node2.localNode.id)
check node.isNone()

await node2.closeWait()
await node1.closeWait()

asyncTest "Ignore messages from banned nodes":
let
node1 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20302),
banNodes = true)
node2 = initDiscoveryNode(rng, PrivateKey.random(rng[]), localAddress(20301),
banNodes = true)

# ban node1 in node2's routing table
node2.banNode(node1.localNode, 1.minutes)

block:
let pong = await node1.ping(node2.localNode)
check:
pong.isErr()
pong.error() == "Pong message not received in time"

block:
let nodes = await node1.findNode(node2.localNode, @[0.uint16])
check:
nodes.isErr()
nodes.error() == "Nodes message not received in time"

block:
let node = await node1.resolve(node2.localNode.id)
check node.isNone()

await node2.closeWait()
await node1.closeWait()