Skip to content

Commit

Permalink
feat: transfer management background task (#388)
Browse files Browse the repository at this point in the history
* feat: transfer management background task

* feat: clear queue when stopping the client

* feat: adjust delay based on management cycle duration

* feat: add missing constants

* feat: remove obsolute private method, correct function call

* feat: remove duplicate code
  • Loading branch information
JurgenR authored Dec 23, 2024
1 parent ba5e994 commit 5233123
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 41 deletions.
2 changes: 2 additions & 0 deletions src/aioslsk/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@
POTENTIAL_PARENTS_CACHE_SIZE: int = 20
"""Maximum amount of potential parents stored"""
DEFAULT_COMMAND_TIMEOUT: float = 10
MIN_TRANSFER_MGMT_INTERVAL: float = 0.05
MAX_TRANSFER_MGMT_INTERVAL: float = 0.25
95 changes: 59 additions & 36 deletions src/aioslsk/transfer/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
import logging
from operator import itemgetter
import os
import time
from typing import Optional, TYPE_CHECKING

from ..base_manager import BaseManager
from .cache import TransferNullCache, TransferCache
from ..constants import TRANSFER_REPLY_TIMEOUT
from ..constants import (
MAX_TRANSFER_MGMT_INTERVAL,
MIN_TRANSFER_MGMT_INTERVAL,
TRANSFER_REPLY_TIMEOUT,
)
from ..exceptions import (
AioSlskException,
ConnectionReadError,
Expand Down Expand Up @@ -89,12 +94,13 @@ def __init__(
user_manager: UserManager, shares_manager: SharesManager,
network: Network, cache: Optional[TransferCache] = None):

self.cache: TransferCache = cache if cache else TransferNullCache()
self._settings: Settings = settings
self._event_bus: EventBus = event_bus
self._user_manager: UserManager = user_manager
self._shares_manager: SharesManager = shares_manager
self._network: Network = network
self.cache: TransferCache = cache if cache else TransferNullCache()

self._ticket_generator = ticket_generator()

self._transfers: list[Transfer] = []
Expand All @@ -105,6 +111,13 @@ def __init__(
name='transfer-progress-task'
)

self._management_queue: asyncio.Queue = asyncio.Queue(maxsize=1)
self._management_task: BackgroundTask = BackgroundTask(
interval=MIN_TRANSFER_MGMT_INTERVAL,
task_coro=self._management_job,
name='transfer-management-task'
)

self._MESSAGE_MAP = build_message_map(self)

self.register_listeners()
Expand All @@ -129,9 +142,6 @@ async def read_cache(self):
for transfer in transfers:
# Analyze the current state of the stored transfers and set them to
# the correct state
# This needs to happen first: when calling _add_transfer the manager
# will be registering itself as listener. `manage_transfers` should
# only be called if everything is loaded
transfer.remotely_queued = False
if transfer.state.VALUE == TransferState.INITIALIZING:
await transfer.state.queue()
Expand All @@ -144,7 +154,7 @@ async def read_cache(self):
transfer.state = TransferState.init_from_state(state, transfer)
transfer.reset_time_vars()

await self._add_transfer(transfer)
await self.add(transfer)

def write_cache(self):
"""Write all currently stored transfers to the cache"""
Expand All @@ -157,6 +167,7 @@ async def store_data(self):
self.write_cache()

async def start(self):
self._management_task.start()
await self.start_progress_reporting_task()

async def stop(self) -> list[asyncio.Task]:
Expand All @@ -172,6 +183,11 @@ async def stop(self) -> list[asyncio.Task]:
if task := self._progress_reporting_task.cancel():
cancelled_tasks.append(task)

if task := self._management_task.cancel():
cancelled_tasks.append(task)

self._management_queue = asyncio.Queue()

return cancelled_tasks

async def start_progress_reporting_task(self):
Expand Down Expand Up @@ -303,8 +319,18 @@ async def add(self, transfer: Transfer) -> Transfer:
:return: either the transfer we have passed or the already existing
transfer
"""
transfer = await self._add_transfer(transfer)
await self.manage_transfers()
for queued_transfer in self._transfers:
if queued_transfer == transfer:
logger.info("skip adding transfer, already exists : %s", queued_transfer)
return queued_transfer

logger.info("adding transfer : %s", transfer)
transfer.state_listeners.append(self)
self._transfers.append(transfer)

self.request_management_cycle()
await self._event_bus.emit(TransferAddedEvent(transfer))

return transfer

async def remove(self, transfer: Transfer):
Expand All @@ -330,7 +356,7 @@ async def remove(self, transfer: Transfer):
self._transfers.remove(transfer)
await self._event_bus.emit(TransferRemovedEvent(transfer))

await self.manage_transfers()
self.request_management_cycle()

def get_uploads(self) -> list[Transfer]:
return [transfer for transfer in self._transfers if transfer.is_upload()]
Expand All @@ -346,11 +372,7 @@ def has_slots_free(self) -> bool:
return self.get_free_upload_slots() > 0

def get_free_upload_slots(self) -> int:
uploading_transfers = []
for transfer in self._transfers:
if transfer.is_upload() and transfer.is_processing():
uploading_transfers.append(transfer)

uploading_transfers = self.get_uploading()
available_slots = self.get_upload_slots() - len(uploading_transfers)
return max(0, available_slots)

Expand Down Expand Up @@ -482,12 +504,26 @@ async def manage_user_tracking(self):
for username in finished_users - unfinished_users:
await self._user_manager.untrack_user(username, TrackingFlag.TRANSFER)

async def manage_transfers(self):
async def _management_job(self) -> float:
await self._management_queue.get()

start = time.monotonic()
await self.manage_user_tracking()
self.manage_transfers()
duration = time.monotonic() - start

return min(MIN_TRANSFER_MGMT_INTERVAL + duration, MAX_TRANSFER_MGMT_INTERVAL)

def request_management_cycle(self):
try:
self._management_queue.put_nowait(None)
except asyncio.QueueFull:
pass

def manage_transfers(self):
"""This method analyzes the state of the current downloads/uploads and
starts them up in case there are free slots available
"""
await self.manage_user_tracking()

downloads, uploads = self._get_queued_transfers()
free_upload_slots = self.get_free_upload_slots()

Expand Down Expand Up @@ -593,19 +629,6 @@ def _prioritize_uploads(self, uploads: list[Transfer]) -> list[Transfer]:
ranking.sort(key=itemgetter(0))
return list(reversed([upload for _, upload in ranking]))

async def _add_transfer(self, transfer: Transfer) -> Transfer:
for queued_transfer in self._transfers:
if queued_transfer == transfer:
logger.info("skip adding transfer, already exists : %s", queued_transfer)
return queued_transfer

logger.info("adding transfer : %s", transfer)
transfer.state_listeners.append(self)
self._transfers.append(transfer)
await self._event_bus.emit(TransferAddedEvent(transfer))

return transfer

async def _prepare_download_path(self, transfer: Transfer):
if transfer.local_path is None:
download_path, file_path = self._shares_manager.calculate_download_path(transfer.remote_path)
Expand Down Expand Up @@ -657,7 +680,7 @@ async def _queue_remotely(self, transfer: Transfer):
else:
transfer.remotely_queued = True
transfer.reset_queue_attempts()
await self.manage_transfers()
self.request_management_cycle()

async def request_place_in_queue(self, transfer: Transfer) -> Optional[int]:
"""Requests the place in queue for the given transfer. The method will
Expand Down Expand Up @@ -1081,7 +1104,7 @@ async def _download_file(self, transfer: Transfer, connection: PeerConnection):
async def on_transfer_state_changed(
self, transfer: Transfer, old: TransferState.State, new: TransferState.State):

await self.manage_transfers()
self.request_management_cycle()

async def _on_message_received(self, event: MessageReceivedEvent):
message = event.message
Expand All @@ -1090,14 +1113,14 @@ async def _on_message_received(self, event: MessageReceivedEvent):

@on_message(AddUser.Response)
async def _on_add_user(self, message: AddUser.Response, connection: PeerConnection):
await self.manage_transfers()
self.request_management_cycle()

@on_message(GetUserStatus.Response)
async def _on_get_user_status(self, message: GetUserStatus.Response, connection: PeerConnection):
if message.status == UserStatus.OFFLINE.value:
self._reset_remotely_queued_flags(message.username)

await self.manage_transfers()
self.request_management_cycle()

@on_message(PeerTransferQueue.Request)
async def _on_peer_transfer_queue(self, message: PeerTransferQueue.Request, connection: PeerConnection):
Expand Down Expand Up @@ -1198,7 +1221,7 @@ async def _on_peer_initialized(self, event: PeerInitializedEvent):
await connection.disconnect(CloseReason.REQUESTED)

async def _on_session_initialized(self, event: SessionInitializedEvent):
await self.manage_transfers()
self.request_management_cycle()

@on_message(PeerTransferRequest.Request)
async def _on_peer_transfer_request(self, message: PeerTransferRequest.Request, connection: PeerConnection):
Expand Down Expand Up @@ -1409,7 +1432,7 @@ async def _on_peer_upload_failed(self, message: PeerUploadFailed.Request, connec
username = connection.username
if transfer := self.find_transfer(username, filename, TransferDirection.DOWNLOAD):
transfer.remotely_queued = False
await self.manage_transfers()
self.request_management_cycle()

else:
logger.warning(
Expand Down
4 changes: 3 additions & 1 deletion tests/e2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ async def wait_for_transfer_added(
else:
await asyncio.sleep(0.01)
else:
raise Exception(f"transfer {client} did not have a transfer added in {timeout}s")
raise Exception(
f"transfer {client} did not have a transfer added in {timeout}s "
f"(initial={initial_amount}, current={len(client.transfers.transfers)})")


async def wait_for_transfer_state(transfer: Transfer, state: TransferState.State, timeout: int = 15):
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/transfer/test_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,14 @@ def manager(tmpdir, user_manager: UserManager) -> TransferManager:
class TestTransferManager:

@pytest.mark.asyncio
async def test_whenAddTransfer_shouldAddTransferAndAddUser(self, manager: TransferManager):
async def test_whenAddTransfer_shouldAddTransfer(self, manager: TransferManager):
transfer = Transfer(DEFAULT_USERNAME, DEFAULT_FILENAME, TransferDirection.DOWNLOAD)
await manager.add(transfer)

assert transfer.state.VALUE == TransferState.VIRGIN
assert transfer in manager.transfers
manager._user_manager.track_user.assert_awaited_once_with(
DEFAULT_USERNAME, TrackingFlag.TRANSFER
)

assert manager._management_queue.qsize() == 1

@pytest.mark.asyncio
async def test_whenAddTransfer_alreadyExists_shouldNotAdd(self, manager: TransferManager):
Expand Down

0 comments on commit 5233123

Please sign in to comment.