From d4415eae17008b556a9f8fd52f62938c4e5d7012 Mon Sep 17 00:00:00 2001 From: jgomezve Date: Wed, 18 Dec 2024 14:41:50 -0500 Subject: [PATCH] Add testing & Fix black formatting --- .gitignore | 3 + .../eda/plugins/event_source/subscription.py | 22 ++---- requirements.txt | 3 +- tests/unit/event_source/__init__.py | 0 tests/unit/event_source/test_subscription.py | 79 +++++++++++++++++++ 5 files changed, 89 insertions(+), 18 deletions(-) create mode 100644 tests/unit/event_source/__init__.py create mode 100644 tests/unit/event_source/test_subscription.py diff --git a/.gitignore b/.gitignore index de73c3215..021b860d4 100644 --- a/.gitignore +++ b/.gitignore @@ -390,5 +390,8 @@ $RECYCLE.BIN/ # vsCode .vscode +# Python venv +venv/ + # Ansible Collection tarball cisco-aci-*.tar.gz \ No newline at end of file diff --git a/extensions/eda/plugins/event_source/subscription.py b/extensions/eda/plugins/event_source/subscription.py index 8780155ae..2450ddefb 100644 --- a/extensions/eda/plugins/event_source/subscription.py +++ b/extensions/eda/plugins/event_source/subscription.py @@ -55,9 +55,7 @@ def login(hostname: str, username: str, password: str) -> str: return token -def subscribe( - hostname: str, token: str, rf_timeout: int, sub_urls: list[str] -) -> list[str]: +def subscribe(hostname: str, token: str, rf_timeout: int, sub_urls: list[str]) -> list[str]: """ subscribe to a websocket @@ -70,9 +68,7 @@ def subscribe( sub_ids = [] for sub in sub_urls: - sub_url = ( - f"https://{hostname}{sub}&subscription=yes&refresh-timeout={rf_timeout}" - ) + sub_url = f"https://{hostname}{sub}&subscription=yes&refresh-timeout={rf_timeout}" cookie = {"APIC-cookie": token} sub_response = requests.get(sub_url, verify=False, cookies=cookie, timeout=60) if sub_response.ok: @@ -81,9 +77,7 @@ def subscribe( return sub_ids -async def refresh( - hostname: str, token: str, refresh_timeout: int, sub_ids: list[str] -) -> NoReturn: +async def refresh(hostname: str, token: str, refresh_timeout: int, sub_ids: list[str]) -> NoReturn: """ refresh subscriptions @@ -109,16 +103,10 @@ async def main(queue: asyncio.Queue, args: Dict[str, Any]): subscriptions = args.get("subscriptions") if "" in [hostname, username, password]: - print( - f"hostname, username and password can't be empty:{hostname}, {username}, *****" - ) + print(f"hostname, username and password can't be empty:{hostname}, {username}, *****") sys.exit(1) - if ( - not isinstance(subscriptions, list) - or subscriptions == [] - or subscriptions is None - ): + if not isinstance(subscriptions, list) or subscriptions == [] or subscriptions is None: print(f"subscriptions is empty or not a list: {subscriptions}") sys.exit(1) diff --git a/requirements.txt b/requirements.txt index ef4ffcb2e..a706c79d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ pyOpenSSL python_dateutil xmljson requests -websockets \ No newline at end of file +websockets +asyncmock \ No newline at end of file diff --git a/tests/unit/event_source/__init__.py b/tests/unit/event_source/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/event_source/test_subscription.py b/tests/unit/event_source/test_subscription.py new file mode 100644 index 000000000..d83f61c20 --- /dev/null +++ b/tests/unit/event_source/test_subscription.py @@ -0,0 +1,79 @@ +from unittest.mock import patch +from typing import Any, List +from asyncmock import AsyncMock +from extensions.eda.plugins.event_source.subscription import main as subscription_main +import pytest +import json +import asyncio + + +# Refresh mock method +def refresh_patch(hostname: str, token: str, rf_timeout: int, sub_urls: List[str]) -> None: + pass + + +# Login mock method +def login_patch(hostname: str, username: str, password: str) -> str: + return f"{hostname}{username}{password}" + + +# Subscribe mock method +def subscribe_patch(hostname, token, rf_timeout, sub_urls) -> List[str]: + return [f"{hostname}{token}{rf_timeout}{url}" for url in sub_urls] + + +# Mock iterator +class AsyncIterator: + def __init__(self) -> None: + self.count = 0 + + async def __anext__(self) -> str: + if self.count < 2: + self.count += 1 + return json.dumps({"eventid": f"00{self.count}"}) + else: + raise StopAsyncIteration + + +# Mock Async Websocket +class MockWebSocket(AsyncMock): # type: ignore[misc] + def __aiter__(self) -> AsyncIterator: + return AsyncIterator() + + async def close(self) -> None: + pass + + +# Mock AsyncQueue +class MockQueue(asyncio.Queue[Any]): + def __init__(self) -> None: + self.queue: list[Any] = [] + + async def put(self, item: Any) -> None: + self.queue.append(item) + + +def test_websocket_subscription() -> None: + + with patch( + "websockets.connect", + return_value=MockWebSocket(), + ), patch("unit.event_source.tmp_subscription.login", return_value=login_patch), patch( + "unit.event_source.tmp_subscription.subscribe", return_value=subscribe_patch + ), patch("unit.event_source.tmp_subscription.refresh", return_value=refresh_patch): + + my_queue = MockQueue() + asyncio.run( + subscription_main( + my_queue, + { + "hostname": "my-apic.com", + "username": "admin", + "password": "admin", + "subscriptions": ['/api/node/class/faultInst.json?query-target-filter=and(eq(faultInst.code,"F1386"))'], + }, + ) + ) + + assert my_queue.queue[0] == {"eventid": "001"} + assert len(my_queue.queue) == 2