Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ refactor RedisClientsManager and RedisClientSDK #5888

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from faststream.rabbit import ExchangeType, RabbitBroker, RabbitExchange, RabbitRouter
from pydantic import NonNegativeInt
from servicelib.logging_utils import log_catch, log_context
from servicelib.redis import RedisClientSDKHealthChecked
from servicelib.redis import RedisClientSDK
from settings_library.rabbit import RabbitSettings

from ._base_deferred_handler import (
Expand Down Expand Up @@ -116,7 +116,7 @@ class DeferredManager: # pylint:disable=too-many-instance-attributes
def __init__(
self,
rabbit_settings: RabbitSettings,
scheduler_redis_sdk: RedisClientSDKHealthChecked,
scheduler_redis_sdk: RedisClientSDK,
*,
globals_context: GlobalsContext,
max_workers: NonNegativeInt = _DEFAULT_DEFERRED_MANAGER_WORKER_SLOTS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import NonNegativeInt

from ..redis import RedisClientSDKHealthChecked
from ..redis import RedisClientSDK
from ..utils import logged_gather
from ._base_task_tracker import BaseTaskTracker
from ._models import TaskUID
Expand All @@ -18,36 +18,36 @@ def _get_key(task_uid: TaskUID) -> str:


class RedisTaskTracker(BaseTaskTracker):
def __init__(self, redis_sdk: RedisClientSDKHealthChecked) -> None:
self.redis_sdk = redis_sdk
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
self.redis_client_sdk = redis_client_sdk

async def get_new_unique_identifier(self) -> TaskUID:
candidate_already_exists = True
while candidate_already_exists:
candidate = f"{uuid4()}"
candidate_already_exists = (
await self.redis_sdk.redis.get(_get_key(candidate)) is not None
await self.redis_client_sdk.redis.get(_get_key(candidate)) is not None
)
return TaskUID(candidate)

async def _get_raw(self, redis_key: str) -> TaskScheduleModel | None:
found_data = await self.redis_sdk.redis.get(redis_key)
found_data = await self.redis_client_sdk.redis.get(redis_key)
return None if found_data is None else TaskScheduleModel.parse_raw(found_data)

async def get(self, task_uid: TaskUID) -> TaskScheduleModel | None:
return await self._get_raw(_get_key(task_uid))

async def save(self, task_uid: TaskUID, task_schedule: TaskScheduleModel) -> None:
await self.redis_sdk.redis.set(_get_key(task_uid), task_schedule.json())
await self.redis_client_sdk.redis.set(_get_key(task_uid), task_schedule.json())

async def remove(self, task_uid: TaskUID) -> None:
await self.redis_sdk.redis.delete(_get_key(task_uid))
await self.redis_client_sdk.redis.delete(_get_key(task_uid))

async def all(self) -> list[TaskScheduleModel]:
return await logged_gather(
*[
self._get_raw(x)
async for x in self.redis_sdk.redis.scan_iter(
async for x in self.redis_client_sdk.redis.scan_iter(
match=f"{_TASK_TRACKER_PREFIX}*"
)
],
Expand Down
107 changes: 56 additions & 51 deletions packages/service-library/src/servicelib/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@
from .background_task import periodic_task, start_periodic_task, stop_periodic_task
from .logging_utils import log_catch, log_context
from .retry_policies import RedisRetryPolicyUponInitialization
from .utils import logged_gather

_DEFAULT_LOCK_TTL: Final[datetime.timedelta] = datetime.timedelta(seconds=10)
_DEFAULT_SOCKET_TIMEOUT: Final[datetime.timedelta] = datetime.timedelta(seconds=30)


_DEFAULT_DECODE_RESPONSES: Final[bool] = True
_DEFAULT_HEALTH_CHECK_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(
seconds=5
)


_logger = logging.getLogger(__name__)


Expand All @@ -44,7 +49,12 @@ class CouldNotConnectToRedisError(BaseRedisError):
@dataclass
class RedisClientSDK:
redis_dsn: str
decode_responses: bool = _DEFAULT_DECODE_RESPONSES
health_check_interval: datetime.timedelta = _DEFAULT_HEALTH_CHECK_INTERVAL

_client: aioredis.Redis = field(init=False)
_health_check_task: Task | None = None
_is_healthy: bool = False

@property
def redis(self) -> aioredis.Redis:
Expand All @@ -63,21 +73,32 @@ def __post_init__(self):
socket_timeout=_DEFAULT_SOCKET_TIMEOUT.total_seconds(),
socket_connect_timeout=_DEFAULT_SOCKET_TIMEOUT.total_seconds(),
encoding="utf-8",
decode_responses=True,
decode_responses=self.decode_responses,
)

@retry(**RedisRetryPolicyUponInitialization(_logger).kwargs)
async def setup(self) -> None:
if not await self._client.ping():
await self.shutdown()
raise CouldNotConnectToRedisError(dsn=self.redis_dsn)

self._is_healthy = True
self._health_check_task = start_periodic_task(
self._check_health,
interval=self.health_check_interval,
task_name=f"redis_service_health_check_{self.redis_dsn}",
)

_logger.info(
"Connection to %s succeeded with %s",
f"redis at {self.redis_dsn=}",
f"{self._client=}",
)

async def shutdown(self) -> None:
if self._health_check_task:
await stop_periodic_task(self._health_check_task)

# NOTE: redis-py does not yet completely fill all the needed types for mypy
await self._client.aclose(close_connection_pool=True) # type: ignore[attr-defined]

Expand All @@ -87,6 +108,21 @@ async def ping(self) -> bool:
return True
return False

async def _check_health(self) -> None:
self._is_healthy = await self.ping()

@property
def is_healthy(self) -> bool:
"""Returns the result of the last health check.
If redis becomes available, after being not available,
it will once more return ``True``

Returns:
``False``: if the service is no longer reachable
``True``: when service is reachable
"""
return self._is_healthy

@contextlib.asynccontextmanager
async def lock_context(
self,
Expand Down Expand Up @@ -169,49 +205,11 @@ async def lock_value(self, lock_name: str) -> str | None:
return output


class RedisClientSDKHealthChecked(RedisClientSDK):
"""
Provides access to ``is_healthy`` property, to be used for defining
health check handlers.
"""

def __init__(
self,
redis_dsn: str,
health_check_interval: datetime.timedelta = datetime.timedelta(seconds=5),
) -> None:
super().__init__(redis_dsn)
self.health_check_interval: datetime.timedelta = health_check_interval
self._health_check_task: Task | None = None
self._is_healthy: bool = True

@property
def is_healthy(self) -> bool:
"""Provides the status of Redis.
If redis becomes available, after being not available,
it will once more return ``True``

Returns:
``False``: if the service is no longer reachable
``True``: when service is reachable
"""
return self._is_healthy

async def _check_health(self) -> None:
self._is_healthy = await self.ping()

async def setup(self) -> None:
await super().setup()
self._health_check_task = start_periodic_task(
self._check_health,
interval=self.health_check_interval,
task_name="redis_service_health_check",
)

async def shutdown(self) -> None:
if self._health_check_task:
await stop_periodic_task(self._health_check_task)
await super().shutdown()
@dataclass(frozen=True)
class RedisManagerDBConfig:
database: RedisDatabase
decode_responses: bool = _DEFAULT_DECODE_RESPONSES
health_check_interval: datetime.timedelta = _DEFAULT_HEALTH_CHECK_INTERVAL


@dataclass
Expand All @@ -220,20 +218,27 @@ class RedisClientsManager:
Manages the lifetime of redis client sdk connections
"""

databases: set[RedisDatabase]
databases_configs: set[RedisManagerDBConfig]
settings: RedisSettings

_client_sdks: dict[RedisDatabase, RedisClientSDK] = field(default_factory=dict)

async def setup(self) -> None:
for db in self.databases:
self._client_sdks[db] = client_sdk = RedisClientSDK(
redis_dsn=self.settings.build_redis_dsn(db)
for config in self.databases_configs:
self._client_sdks[config.database] = RedisClientSDK(
redis_dsn=self.settings.build_redis_dsn(config.database),
decode_responses=config.decode_responses,
health_check_interval=config.health_check_interval,
)
await client_sdk.setup()

for client in self._client_sdks.values():
await client.setup()

async def shutdown(self) -> None:
await logged_gather(*(c.shutdown() for c in self._client_sdks.values()))
# NOTE: somehow using logged_gather is not an option
# doing so will make the shutdown procedure hang
for client in self._client_sdks.values():
await client.shutdown()

def client(self, database: RedisDatabase) -> RedisClientSDK:
return self._client_sdks[database]
12 changes: 6 additions & 6 deletions packages/service-library/src/servicelib/retry_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ class PostgresRetryPolicyUponInitialization:
def __init__(self, logger: logging.Logger | None = None):
logger = logger or log

self.kwargs = dict(
wait=wait_fixed(self.WAIT_SECS),
stop=stop_after_attempt(self.ATTEMPTS_COUNT),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
self.kwargs = {
"wait": wait_fixed(self.WAIT_SECS),
"stop": stop_after_attempt(self.ATTEMPTS_COUNT),
"before_sleep": before_sleep_log(logger, logging.WARNING),
"reraise": True,
}


class RedisRetryPolicyUponInitialization(PostgresRetryPolicyUponInitialization):
Expand Down
14 changes: 7 additions & 7 deletions packages/service-library/tests/deferred_tasks/example_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
StartContext,
TaskUID,
)
from servicelib.redis import RedisClientSDK, RedisClientSDKHealthChecked
from servicelib.redis import RedisClientSDK
from settings_library.rabbit import RabbitSettings
from settings_library.redis import RedisDatabase, RedisSettings

Expand Down Expand Up @@ -54,7 +54,7 @@ async def on_result(cls, result: str, context: DeferredContext) -> None:

class InMemoryLists:
def __init__(self, redis_settings: RedisSettings, port: int) -> None:
self.redis_sdk = RedisClientSDK(
self.redis_client_sdk = RedisClientSDK(
redis_settings.build_redis_dsn(RedisDatabase.DEFERRED_TASKS)
)
self.port = port
Expand All @@ -63,10 +63,10 @@ def _get_queue_name(self, queue_name: str) -> str:
return f"in_memory_lists::{queue_name}.{self.port}"

async def append_to(self, queue_name: str, value: Any) -> None:
await self.redis_sdk.redis.rpush(self._get_queue_name(queue_name), value) # type: ignore
await self.redis_client_sdk.redis.rpush(self._get_queue_name(queue_name), value) # type: ignore

async def get_all_from(self, queue_name: str) -> list:
return await self.redis_sdk.redis.lrange(
return await self.redis_client_sdk.redis.lrange(
self._get_queue_name(queue_name), 0, -1
) # type: ignore

Expand All @@ -79,18 +79,18 @@ def __init__(
in_memory_lists: InMemoryLists,
max_workers: NonNegativeInt,
) -> None:
self._redis_client = RedisClientSDKHealthChecked(
self._redis_client_sdk = RedisClientSDK(
redis_settings.build_redis_dsn(RedisDatabase.DEFERRED_TASKS)
)
self._manager = DeferredManager(
rabbit_settings,
self._redis_client,
self._redis_client_sdk,
globals_context={"in_memory_lists": in_memory_lists},
max_workers=max_workers,
)

async def setup(self) -> None:
await self._redis_client.setup()
await self._redis_client_sdk.setup()
await self._manager.setup()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)
from servicelib.deferred_tasks._models import TaskResultError, TaskUID
from servicelib.deferred_tasks._task_schedule import TaskState
from servicelib.redis import RedisClientSDKHealthChecked
from servicelib.redis import RedisClientSDK
from settings_library.rabbit import RabbitSettings
from settings_library.redis import RedisDatabase, RedisSettings
from tenacity._asyncio import AsyncRetrying
Expand All @@ -49,12 +49,10 @@ class MockKeys(StrAutoEnum):


@pytest.fixture
async def redis_sdk(
async def redis_client_sdk(
redis_service: RedisSettings,
) -> AsyncIterable[RedisClientSDKHealthChecked]:
sdk = RedisClientSDKHealthChecked(
redis_service.build_redis_dsn(RedisDatabase.DEFERRED_TASKS)
)
) -> AsyncIterable[RedisClientSDK]:
sdk = RedisClientSDK(redis_service.build_redis_dsn(RedisDatabase.DEFERRED_TASKS))
await sdk.setup()
yield sdk
await sdk.shutdown()
Expand All @@ -68,12 +66,12 @@ def mocked_deferred_globals() -> dict[str, Any]:
@pytest.fixture
async def deferred_manager(
rabbit_service: RabbitSettings,
redis_sdk: RedisClientSDKHealthChecked,
redis_client_sdk: RedisClientSDK,
mocked_deferred_globals: dict[str, Any],
) -> AsyncIterable[DeferredManager]:
manager = DeferredManager(
rabbit_service,
redis_sdk,
redis_client_sdk,
globals_context=mocked_deferred_globals,
max_workers=10,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from servicelib.deferred_tasks._models import TaskUID
from servicelib.deferred_tasks._redis_task_tracker import RedisTaskTracker
from servicelib.deferred_tasks._task_schedule import TaskScheduleModel, TaskState
from servicelib.redis import RedisClientSDKHealthChecked
from servicelib.redis import RedisClientSDK
from servicelib.utils import logged_gather

pytest_simcore_core_services_selection = [
Expand All @@ -33,7 +33,7 @@ def task_schedule() -> TaskScheduleModel:


async def test_task_tracker_workflow(
redis_client_sdk_deferred_tasks: RedisClientSDKHealthChecked,
redis_client_sdk_deferred_tasks: RedisClientSDK,
task_schedule: TaskScheduleModel,
):
task_tracker = RedisTaskTracker(redis_client_sdk_deferred_tasks)
Expand All @@ -51,7 +51,7 @@ async def test_task_tracker_workflow(

@pytest.mark.parametrize("count", [0, 1, 10, 100])
async def test_task_tracker_list_all_entries(
redis_client_sdk_deferred_tasks: RedisClientSDKHealthChecked,
redis_client_sdk_deferred_tasks: RedisClientSDK,
task_schedule: TaskScheduleModel,
count: int,
):
Expand Down
Loading
Loading