Skip to content

Commit

Permalink
improve logging of rate limits
Browse files Browse the repository at this point in the history
  • Loading branch information
arvidn committed Nov 20, 2024
1 parent 3bc2fbf commit 658475e
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 57 deletions.
4 changes: 2 additions & 2 deletions chia/_tests/core/server/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ async def get_block_path(full_node: FullNodeAPI):


class FakeRateLimiter:
    """Test double for RateLimiter that never rate-limits.

    Matches the updated RateLimiter.process_msg_and_check interface, which
    returns a (accepted, reason) pair instead of a bare bool.
    """

    def process_msg_and_check(self, msg, capa, capb) -> tuple[bool, str]:
        # Always accept; the empty string is the "no violation" reason,
        # mirroring the real limiter's success return.
        return (True, "")


class TestDos:
Expand Down
76 changes: 38 additions & 38 deletions chia/_tests/core/server/test_rate_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ async def test_too_many_messages(self):
r = RateLimiter(incoming=True)
new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
for i in range(4999):
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(4999):
response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect
Expand All @@ -47,11 +47,11 @@ async def test_too_many_messages(self):
r = RateLimiter(incoming=True)
new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
for i in range(200):
assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(200):
response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)
response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect
Expand All @@ -63,39 +63,39 @@ async def test_large_message(self):
large_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 3 * 1024 * 1024))

r = RateLimiter(incoming=True)
assert r.process_msg_and_check(small_tx_message, rl_v2, rl_v2)
assert not r.process_msg_and_check(large_tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(small_tx_message, rl_v2, rl_v2)[0]
assert not r.process_msg_and_check(large_tx_message, rl_v2, rl_v2)[0]

small_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 5 * 1024))
large_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 600 * 1024))
r = RateLimiter(incoming=True)
assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2)
assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2)
assert not r.process_msg_and_check(large_vdf_message, rl_v2, rl_v2)
assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2)[0]
assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2)[0]
assert not r.process_msg_and_check(large_vdf_message, rl_v2, rl_v2)[0]

@pytest.mark.anyio
async def test_too_much_data(self):
# Too much data
r = RateLimiter(incoming=True)
tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
for i in range(40):
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(300):
response = r.process_msg_and_check(tx_message, rl_v2, rl_v2)
response = r.process_msg_and_check(tx_message, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect

r = RateLimiter(incoming=True)
block_message = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
for i in range(10):
assert r.process_msg_and_check(block_message, rl_v2, rl_v2)
assert r.process_msg_and_check(block_message, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(40):
response = r.process_msg_and_check(block_message, rl_v2, rl_v2)
response = r.process_msg_and_check(block_message, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect
Expand All @@ -109,14 +109,14 @@ async def test_non_tx_aggregate_limits(self):
message_3 = make_msg(ProtocolMessageTypes.plot_sync_start, bytes([1] * 64))

for i in range(500):
assert r.process_msg_and_check(message_1, rl_v2, rl_v2)
assert r.process_msg_and_check(message_1, rl_v2, rl_v2)[0]

for i in range(500):
assert r.process_msg_and_check(message_2, rl_v2, rl_v2)
assert r.process_msg_and_check(message_2, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(500):
response = r.process_msg_and_check(message_3, rl_v2, rl_v2)
response = r.process_msg_and_check(message_3, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect
Expand All @@ -127,11 +127,11 @@ async def test_non_tx_aggregate_limits(self):
message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 49 * 1024 * 1024))

for i in range(2):
assert r.process_msg_and_check(message_4, rl_v2, rl_v2)
assert r.process_msg_and_check(message_4, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(2):
response = r.process_msg_and_check(message_5, rl_v2, rl_v2)
response = r.process_msg_and_check(message_5, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect
Expand All @@ -141,55 +141,55 @@ async def test_periodic_reset(self):
r = RateLimiter(True, 5)
tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
for i in range(10):
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(300):
response = r.process_msg_and_check(tx_message, rl_v2, rl_v2)
response = r.process_msg_and_check(tx_message, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect
assert not r.process_msg_and_check(tx_message, rl_v2, rl_v2)
assert not r.process_msg_and_check(tx_message, rl_v2, rl_v2)[0]
await asyncio.sleep(6)
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2)[0]

# Counts reset also
r = RateLimiter(True, 5)
new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
for i in range(4999):
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(4999):
response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect
await asyncio.sleep(6)
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)[0]

@pytest.mark.anyio
async def test_percentage_limits(self):
r = RateLimiter(True, 60, 40)
new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
for i in range(50):
assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(50):
response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)
response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect

r = RateLimiter(True, 60, 40)
block_message = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
for i in range(5):
assert r.process_msg_and_check(block_message, rl_v2, rl_v2)
assert r.process_msg_and_check(block_message, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(5):
response = r.process_msg_and_check(block_message, rl_v2, rl_v2)
response = r.process_msg_and_check(block_message, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect
Expand All @@ -201,13 +201,13 @@ async def test_percentage_limits(self):
message_3 = make_msg(ProtocolMessageTypes.plot_sync_start, bytes([1] * 32))

for i in range(180):
assert r.process_msg_and_check(message_1, rl_v2, rl_v2)
assert r.process_msg_and_check(message_1, rl_v2, rl_v2)[0]
for i in range(180):
assert r.process_msg_and_check(message_2, rl_v2, rl_v2)
assert r.process_msg_and_check(message_2, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(100):
response = r.process_msg_and_check(message_3, rl_v2, rl_v2)
response = r.process_msg_and_check(message_3, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect
Expand All @@ -218,11 +218,11 @@ async def test_percentage_limits(self):
message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 24 * 1024 * 1024))

for i in range(2):
assert r.process_msg_and_check(message_4, rl_v2, rl_v2)
assert r.process_msg_and_check(message_4, rl_v2, rl_v2)[0]

saw_disconnect = False
for i in range(2):
response = r.process_msg_and_check(message_5, rl_v2, rl_v2)
response = r.process_msg_and_check(message_5, rl_v2, rl_v2)[0]
if not response:
saw_disconnect = True
assert saw_disconnect
Expand All @@ -237,7 +237,7 @@ async def test_too_many_outgoing_messages(self):
passed = 0
blocked = 0
for i in range(non_tx_freq):
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2):
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2)[0]:
passed += 1
else:
blocked += 1
Expand All @@ -248,7 +248,7 @@ async def test_too_many_outgoing_messages(self):
# ensure that *another* message type is not blocked because of this

new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1]))
assert r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2)[0]

@pytest.mark.anyio
async def test_too_many_incoming_messages(self):
Expand All @@ -260,7 +260,7 @@ async def test_too_many_incoming_messages(self):
passed = 0
blocked = 0
for i in range(non_tx_freq):
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2):
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2)[0]:
passed += 1
else:
blocked += 1
Expand All @@ -271,7 +271,7 @@ async def test_too_many_incoming_messages(self):
# ensure that other message types *are* blocked because of this

new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1]))
assert not r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2)
assert not r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2)[0]

@pytest.mark.parametrize(
"node_with_params",
Expand Down
40 changes: 31 additions & 9 deletions chia/server/rate_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, incoming: bool, reset_seconds: int = 60, percentage_of_limit:

def process_msg_and_check(
self, message: Message, our_capabilities: list[Capability], peer_capabilities: list[Capability]
) -> bool:
) -> tuple[bool, str]:
"""
Returns True if message can be processed successfully, false if a rate limit is passed.
"""
Expand All @@ -59,7 +59,7 @@ def process_msg_and_check(
message_type = ProtocolMessageTypes(message.type)
except Exception as e:
log.warning(f"Invalid message: {message.type}, {e}")
return True
return (True, "")

new_message_counts: int = self.message_counts[message_type] + 1
new_cumulative_size: int = self.message_cumulative_sizes[message_type] + len(message.data)
Expand All @@ -81,25 +81,47 @@ def process_msg_and_check(
new_non_tx_count = self.non_tx_message_counts + 1
new_non_tx_size = self.non_tx_cumulative_size + len(message.data)
if new_non_tx_count > non_tx_freq * proportion_of_limit:
return False
return (
False,
f"non-tx count: {new_non_tx_count} "
f"> {non_tx_freq * proportion_of_limit} "
f"(scale factor: {proportion_of_limit})",
)
if new_non_tx_size > non_tx_max_total_size * proportion_of_limit:
return False
return (
False,
f"non-tx size: {new_non_tx_size} "
f"> {non_tx_max_total_size * proportion_of_limit}"
f"(scale factor: {proportion_of_limit})",
)
else:
log.warning(f"Message type {message_type} not found in rate limits")
log.warning(
f"Message type {message_type} not found in rate limits " f"(scale factor: {proportion_of_limit})",
)

if limits.max_total_size is None:
limits = dataclasses.replace(limits, max_total_size=limits.frequency * limits.max_size)
assert limits.max_total_size is not None

if new_message_counts > limits.frequency * proportion_of_limit:
return False
return (
False,
f"message count: {new_message_counts} "
f"> {limits.frequency * proportion_of_limit} "
f"(scale factor: {proportion_of_limit})",
)
if len(message.data) > limits.max_size:
return False
return (False, f"message size: {len(message.data)} > {limits.max_size}")
if new_cumulative_size > limits.max_total_size * proportion_of_limit:
return False
return (
False,
f"cumulative size: {new_cumulative_size} "
f"> {limits.max_total_size * proportion_of_limit} "
f"(scale factor: {proportion_of_limit})",
)

ret = True
return True
return (True, "")
finally:
if self.incoming or ret:
# now that we determined that it's OK to send the message, commit the
Expand Down
24 changes: 16 additions & 8 deletions chia/server/ws_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,17 +626,23 @@ async def _send_message(self, message: Message) -> None:
encoded: bytes = bytes(message)
size = len(encoded)
assert len(encoded) < (2 ** (LENGTH_BYTES * 8))
if not self.outbound_rate_limiter.process_msg_and_check(
accepted, limiter_msg = self.outbound_rate_limiter.process_msg_and_check(
message, self.local_capabilities, self.peer_capabilities
):
)
if not accepted:
if not is_localhost(self.peer_info.host):
message_type = ProtocolMessageTypes(message.type)
last_time = self.log_rate_limit_last_time[message_type]
now = time.monotonic()
self.log_rate_limit_last_time[message_type] = now
if now - last_time >= 60:
msg = f"Rate limiting ourselves. message type: {message_type.name}, peer: {self.peer_info.host}"
self.log.debug(msg)
if now - last_time >= 30:
self.log.info(
f"Rate limiting ourselves. Dropping outbound message: "
f"{message_type.name}, "
f"sz: {len(message.data) / 1000:0.2f} kB, "
f"peer: {self.peer_info.host}, "
f"{limiter_msg}"
)

# TODO: fix this special case. This function has rate limits which are too low.
if ProtocolMessageTypes(message.type) != ProtocolMessageTypes.respond_peers:
Expand Down Expand Up @@ -696,13 +702,15 @@ async def _read_one_message(self) -> Optional[Message]:
message_type = ProtocolMessageTypes(full_message_loaded.type).name
except Exception:
message_type = "Unknown"
if not self.inbound_rate_limiter.process_msg_and_check(
accepted, limiter_msg = self.inbound_rate_limiter.process_msg_and_check(
full_message_loaded, self.local_capabilities, self.peer_capabilities
):
)
if not accepted:
if self.local_type == NodeType.FULL_NODE and not is_localhost(self.peer_info.host):
self.log.error(
f"Peer has been rate limited and will be disconnected: {self.peer_info.host}, "
f"message: {message_type}"
f"message: {message_type}, "
f"{limiter_msg}"
)
# Only full node disconnects peers, to prevent abuse and crashing timelords, farmers, etc
# TODO: stop dropping tasks on the floor
Expand Down

0 comments on commit 658475e

Please sign in to comment.