Skip to content

Commit

Permalink
fix: ignore becoming root on server search request (#355)
Browse files Browse the repository at this point in the history
* fix: ignore becoming root on server search request

* fix: remove unused test
  • Loading branch information
JurgenR authored Nov 5, 2024
1 parent 3f94f70 commit 17f1fa6
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 170 deletions.
43 changes: 9 additions & 34 deletions src/aioslsk/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,8 @@
class DistributedPeer:
"""Represents a distributed peer and its values in the distributed network"""
username: str
connection: Optional[PeerConnection] = None
"""Distributed connection (type=D) associated with the peer. A `None` value
indicates the peer represents the current client
"""
connection: PeerConnection
"""Distributed connection (type=D) associated with the peer"""
branch_level: Optional[int] = None
branch_root: Optional[str] = None

Expand Down Expand Up @@ -189,7 +187,7 @@ async def _set_parent(self, peer: DistributedPeer):
# be disconnected
distributed_connections = [
dpeer.connection for dpeer in self.distributed_peers
if dpeer in [self.parent, ] + self.children and dpeer.connection is not None
if dpeer in [self.parent, ] + self.children
]

disconnect_tasks = []
Expand All @@ -214,8 +212,7 @@ async def _check_if_new_parent(self, peer: DistributedPeer):
if not self.parent:
await self._set_parent(peer)
else:
if peer.connection:
await peer.connection.disconnect(reason=CloseReason.REQUESTED)
await peer.connection.disconnect(reason=CloseReason.REQUESTED)

async def _disconnect_children(self):
await asyncio.gather(
Expand All @@ -224,15 +221,11 @@ async def _disconnect_children(self):
)

async def _disconnect_child(self, peer: DistributedPeer):
if peer.connection:
await peer.connection.disconnect(CloseReason.REQUESTED)
await peer.connection.disconnect(CloseReason.REQUESTED)

async def _disconnect_parent(self):
if self.parent:
if self.parent.connection:
await self.parent.connection.disconnect(CloseReason.REQUESTED)
else:
await self._unset_parent()
await self.parent.connection.disconnect(CloseReason.REQUESTED)

async def _unset_parent(self):
logger.info("unset parent : %s", self.parent)
Expand Down Expand Up @@ -286,17 +279,15 @@ async def _check_if_new_child(self, peer: DistributedPeer):

if not self._accept_children:
logger.debug("not accepting children, rejecting peer as child : %s", peer)
if peer.connection: # Satisfy type checker
await peer.connection.disconnect(CloseReason.REQUESTED)
await peer.connection.disconnect(CloseReason.REQUESTED)
return

if len(self.children) >= self._max_children:
logger.debug(
"maximum amount of children reached (%d / %d), rejecting peer as child : %s",
len(self.children), self._max_children, peer
)
if peer.connection: # Satisfy type checker
await peer.connection.disconnect(CloseReason.REQUESTED)
await peer.connection.disconnect(CloseReason.REQUESTED)
return

await self._add_child(peer)
Expand Down Expand Up @@ -408,16 +399,6 @@ async def _on_server_search_request(self, message: ServerSearchRequest.Response,
if message.username == username:
return

if not self.parent:
# Set ourself as parent
parent = DistributedPeer(
username,
None,
branch_root=username,
branch_level=0
)
await self._set_parent(parent)

await self.send_messages_to_children(message)

@on_message(ResetDistributed.Response)
Expand Down Expand Up @@ -623,17 +604,11 @@ async def _on_state_changed(self, event: ConnectionStateChangedEvent):

elif isinstance(connection, ServerConnection):

# When the client itself is branch root it means that the server is
# our "parent" and must be unset here
if self.parent and self.parent.connection is None:
await self._unset_parent()

self._reset_server_values()

async def send_messages_to_children(self, *messages: Union[MessageDataclass, bytes]):
for child in self.children:
if child.connection:
child.connection.queue_messages(*messages)
child.connection.queue_messages(*messages)

async def stop(self) -> list[asyncio.Task]:
"""Cancels all pending tasks
Expand Down
15 changes: 10 additions & 5 deletions tests/e2e/mock/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@ class Peer:

def __init__(
self, hostname: str, port: int, server: MockServer,
reader: asyncio.StreamReader = None, writer: asyncio.StreamWriter = None):
reader: asyncio.StreamReader = None, writer: asyncio.StreamWriter = None,
user: Optional[User] = None):

self.hostname: str = hostname
self.port: int = port
self.server = server

self.user: User = None
self.user: Optional[User] = user

self.reader: asyncio.StreamReader = reader
self.writer: asyncio.StreamWriter = writer
self.reader_loop = None
self.reader: Optional[asyncio.StreamReader] = reader
self.writer: Optional[asyncio.StreamWriter] = writer
self.reader_loop: Optional[asyncio.Task] = None

self.should_close: bool = False
self.last_ping: float = 0.0
Expand Down Expand Up @@ -113,3 +115,6 @@ async def send_message(self, message: MessageDataclass):
except Exception:
logger.exception(f"failed to send message {message}")
await self.disconnect()

def __repr__(self) -> str:
return f"Peer({self.hostname=}, {self.port=}, {self.user=})"
4 changes: 2 additions & 2 deletions tests/e2e/mock/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ async def set_upload_speed(self, username: str, uploads: int, speed: int):
)

async def send_search_request(self, username: str, sender: str, query: str, ticket: int):
"""This is a utility method for testing. To make a peer a root the
server has to send an initial search message to that user
"""This is a utility method for testing. This sends a
:class:`.ServerSearchRequest` to a specific user
:param username: Username to send the query to
:param query: The query to send
Expand Down
153 changes: 70 additions & 83 deletions tests/e2e/test_e2e_distributed.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from aioslsk.client import SoulSeekClient
from .mock.server import MockServer
from .mock.distributed import ChainParentsStrategy
from .fixtures import mock_server, clients
from .utils import (
wait_until_clients_initialized,
wait_until_client_has_parent,
wait_until_client_has_no_parent,
wait_for_search_request,
wait_until_peer_has_parent,
)
Expand All @@ -29,65 +29,25 @@ async def set_upload_speed_for_clients(mock_server: MockServer, clients: list[So

class TestE2EDistributed:

@pytest.mark.asyncio
@pytest.mark.parametrize("clients", [2], indirect=True)
async def test_root_user(self, mock_server: MockServer, clients: list[SoulSeekClient]):
"""Tests when a user gets a search request directly from the server the
peer becomes root
"""
await wait_until_clients_initialized(mock_server, amount=len(clients))

client1, client2 = clients
client1_user = client1.settings.credentials.username
client2_user = client2.settings.credentials.username

# Send a search request to make client1 root
await mock_server.send_search_request(
username=client1_user,
sender=client2_user,
query='this should not match anything',
ticket=1
)

await wait_until_client_has_parent(client1)
await wait_until_peer_has_parent(mock_server, client1_user, 0, client1_user)

assert client1.distributed_network.parent is not None
assert client1.distributed_network.parent.username == client1_user
assert client1.distributed_network.parent.branch_root == client1_user
assert client1.distributed_network.parent.branch_level == 0

@pytest.mark.asyncio
@pytest.mark.parametrize("clients", [2], indirect=True)
async def test_level1_user(self, mock_server: MockServer, clients: list[SoulSeekClient]):
"""Tests when a user gets a search request directly from the server the
peer becomes root
"""
mock_server.set_distributed_strategy(ChainParentsStrategy)

await set_upload_speed_for_clients(mock_server, clients)
await wait_until_clients_initialized(mock_server, amount=len(clients))

client1, client2 = clients
client1_user = client1.settings.credentials.username
client2_user = client2.settings.credentials.username

# Send a search request to make client1 root
await mock_server.send_search_request(
username=client1_user,
sender=client2_user,
query='this should not match anything',
ticket=1
)

await wait_until_client_has_parent(client1)
await wait_until_peer_has_parent(mock_server, client1_user, 0, client1_user)

await mock_server.send_potential_parents(client2_user, [client1_user])

await wait_until_client_has_parent(client2)
await wait_until_peer_has_parent(mock_server, client2_user, 1, client1_user)

# Verify CLIENT 1
assert client1.distributed_network.parent is None

assert len(client1.distributed_network.children) == 1
assert client1.distributed_network.children[0].username == client2_user

Expand All @@ -100,10 +60,7 @@ async def test_level1_user(self, mock_server: MockServer, clients: list[SoulSeek
@pytest.mark.asyncio
@pytest.mark.parametrize("clients", [3], indirect=True)
async def test_level2_user(self, mock_server: MockServer, clients: list[SoulSeekClient]):
"""Tests when a user gets a search request directly from the server the
peer becomes root
"""
mock_server.set_distributed_strategy(ChainParentsStrategy)

await set_upload_speed_for_clients(mock_server, clients)
await wait_until_clients_initialized(mock_server, amount=len(clients))

Expand All @@ -112,17 +69,6 @@ async def test_level2_user(self, mock_server: MockServer, clients: list[SoulSeek
client2_user = client2.settings.credentials.username
client3_user = client3.settings.credentials.username

### Make CLIENT 1 root
await mock_server.send_search_request(
username=client1_user,
sender=client2_user,
query='this should not match anything',
ticket=1
)

await wait_until_client_has_parent(client1)
await wait_until_peer_has_parent(mock_server, client1_user, 0, client1_user)

### Make CLIENT 1 parent of CLIENT 2

await mock_server.send_potential_parents(client2_user, [client1_user])
Expand All @@ -138,10 +84,7 @@ async def test_level2_user(self, mock_server: MockServer, clients: list[SoulSeek
await wait_until_peer_has_parent(mock_server, client3_user, 2, client1_user)

# Verify CLIENT 1
assert client1.distributed_network.parent is not None
assert client1.distributed_network.parent.username == client1_user
assert client1.distributed_network.parent.branch_root == client1_user
assert client1.distributed_network.parent.branch_level == 0
assert client1.distributed_network.parent is None

assert len(client1.distributed_network.children) == 1
assert client1.distributed_network.children[0].username == client2_user
Expand All @@ -165,34 +108,19 @@ async def test_level2_user(self, mock_server: MockServer, clients: list[SoulSeek
assert len(client3.distributed_network.children) == 0

@pytest.mark.asyncio
@pytest.mark.parametrize("clients", [4], indirect=True)
@pytest.mark.parametrize("clients", [3], indirect=True)
async def test_level2_sendSearchRequest(self, mock_server: MockServer, clients: list[SoulSeekClient]):
"""Tests if clients on multiple levels in the network receive a search
request
"""
mock_server.set_distributed_strategy(ChainParentsStrategy)
await set_upload_speed_for_clients(mock_server, clients)
await wait_until_clients_initialized(mock_server, amount=len(clients))

client1, client2, client3, client4 = clients
client1, client2, client3 = clients
client1_user = client1.settings.credentials.username
client2_user = client2.settings.credentials.username
client3_user = client3.settings.credentials.username
client4_user = client4.settings.credentials.username

# Register mock event listeners

### Make CLIENT 1 root
await mock_server.send_search_request(
username=client1_user,
sender=client2_user,
query='this should not match anything',
ticket=1
)
await wait_until_client_has_parent(client1)
await wait_until_peer_has_parent(mock_server, client1_user, 0, client1_user)
# Remove the query made to make client1 root
client1.searches.received_searches.clear()
searching_user = 'user004'

### Make CLIENT 1 parent of CLIENT 2
await mock_server.send_potential_parents(client2_user, [client1_user])
Expand All @@ -205,13 +133,72 @@ async def test_level2_sendSearchRequest(self, mock_server: MockServer, clients:
await wait_until_peer_has_parent(mock_server, client3_user, 2, client1_user)

# Perform search
await client4.searches.search('bogus')
await mock_server.send_search_request(
username=client1_user,
sender=searching_user,
query='bogus',
ticket=1
)

await wait_for_search_request(client1)
await wait_for_search_request(client2)
await wait_for_search_request(client3)

for client in [client1, client2, client3]:
rec_search = client.searches.received_searches.pop()
assert rec_search.username == client4_user
assert rec_search.username == searching_user
assert rec_search.query == 'bogus'

@pytest.mark.asyncio
@pytest.mark.parametrize("clients", [2], indirect=True)
async def test_parent_disconnect(self, mock_server: MockServer, clients: list[SoulSeekClient]):

await set_upload_speed_for_clients(mock_server, clients)
await wait_until_clients_initialized(mock_server, amount=len(clients))

client1, client2 = clients
client1_user = client1.settings.credentials.username
client2_user = client2.settings.credentials.username

### Make CLIENT 1 parent of CLIENT 2
await mock_server.send_potential_parents(client2_user, [client1_user])
await wait_until_client_has_parent(client2)
await wait_until_peer_has_parent(mock_server, client2_user, 1, client1_user)

# Disconnect the parent
await client1.stop()

await wait_until_client_has_no_parent(client2)

# Verify CLIENT 1
assert len(client1.distributed_network.children) == 0

# Verify CLIENT 2
assert client2.distributed_network.parent is None

@pytest.mark.asyncio
@pytest.mark.parametrize("clients", [2], indirect=True)
async def test_child_disconnect(self, mock_server: MockServer, clients: list[SoulSeekClient]):

await set_upload_speed_for_clients(mock_server, clients)
await wait_until_clients_initialized(mock_server, amount=len(clients))

client1, client2 = clients
client1_user = client1.settings.credentials.username
client2_user = client2.settings.credentials.username

### Make CLIENT 1 parent of CLIENT 2
await mock_server.send_potential_parents(client2_user, [client1_user])
await wait_until_client_has_parent(client2)
await wait_until_peer_has_parent(mock_server, client2_user, 1, client1_user)

# Disconnect the child
await client2.stop()

await wait_until_client_has_no_parent(client2)

# Verify CLIENT 1
assert len(client1.distributed_network.children) == 0

# Verify CLIENT 2
assert client2.distributed_network.parent is None
Loading

0 comments on commit 17f1fa6

Please sign in to comment.