diff --git a/tests/asteroid_agent_test.py b/tests/asteroid_agent_test.py index 861001e..1373c96 100644 --- a/tests/asteroid_agent_test.py +++ b/tests/asteroid_agent_test.py @@ -1,11 +1,14 @@ """Unit tests for AsteroidAgent.""" -from typing import Type, Iterator +from typing import Callable, Iterator, Type +import requests +import requests_mock from ostorlab.agent.message import message as m +from pytest_mock import plugin +from requests_mock.adapter import ANY -from agent import asteroid_agent -from agent import definitions +from agent import asteroid_agent, definitions def testAsteroidAgent_whenExploitCheckDetectVulnz_EmitsVulnerabilityReport( @@ -23,18 +26,35 @@ def testAsteroidAgent_whenExploitCheckDetectVulnz_EmitsVulnerabilityReport( def testAsteroidAgent_whenTooManyRedirects_doesNotCrash( - exploit_instance_with_report: Iterator[Type[definitions.Exploit]], asteroid_agent_instance: asteroid_agent.AsteroidAgent, agent_mock: list[m.Message], + mocker: plugin.MockerFixture, + requests_mock: requests_mock.Mocker, ) -> None: """Ensure that the agent does not crash when there are too many redirects.""" + + def response_callback(request: requests.Request, context: Callable) -> str: + context.headers = {"Location": request.url} + context.status_code = 302 + return "" + + requests_mock.register_uri( + ANY, + ANY, + text=response_callback, + ) + + mock_var_bind = mocker.MagicMock() + mock_var_bind.__getitem__.return_value.prettyPrint.return_value = ( + "ArubaOS (MODEL: 7005), Version 8.5.0.0" + ) + mock_iterator = mocker.MagicMock() + mock_iterator.__next__.return_value = (None, None, None, [mock_var_bind]) + mocker.patch("pysnmp.hlapi.getCmd", return_value=mock_iterator) + msg = m.Message( selector="v3.asset.link", - data={"url": "https://expediaagents.com", "method": "GET"}, - raw=b"\n\x19https://expediaagents.com\x12\x03GET", + data={"url": "https://example.com", "method": "GET"}, + raw=b"\n\x19https://example.com\x12\x03GET", ) - asteroid_agent_instance.process(msg) - - assert len(agent_mock) == 1 - assert agent_mock[0].selector == "v3.report.vulnerability"