diff --git a/agent/virus_total_agent.py b/agent/virus_total_agent.py index c6fd0b5..966de76 100644 --- a/agent/virus_total_agent.py +++ b/agent/virus_total_agent.py @@ -44,7 +44,6 @@ def process(self, message: msg.Message) -> None: message: Message containing the file to scan. Raises: - VirusTotalApiError: In case the Virus Total api encountered problems. NameError: In case the scans were not defined. """ file_content = file.get_file_content(message) @@ -66,12 +65,7 @@ def process(self, message: msg.Message) -> None: self._process_response(response, target) def _process_response(self, response: dict[str, Any], target: str | None) -> None: - try: - scans = virustotal.get_scans(response) - except virustotal.VirusTotalApiError: - logger.error("Virus Total API encountered some problems. Please try again.") - return None - + scans = virustotal.get_scans(response) try: if scans is not None: technical_detail = process_scans.get_technical_details(scans, target) diff --git a/agent/virustotal.py b/agent/virustotal.py index 28e7ca6..ab99bc6 100644 --- a/agent/virustotal.py +++ b/agent/virustotal.py @@ -14,10 +14,6 @@ class Error(Exception): """Custom Error.""" -class VirusTotalApiError(Error): - """VirtualTotalApiError.""" - - def scan_file_from_message(file_content: bytes, api_key: str) -> dict: """Method responsible for scanning a file through the Virus Total public API. Args: @@ -54,15 +50,8 @@ def get_scans(response: dict[str, Any]) -> dict[str, Any] | None: Returns: scans: Dictionary of the scans. - - Raises: - VirusTotalApiError: In case the API request encountered problems. """ - if response.get("response_code") == 204 and response.get("error") is not None: - raise VirusTotalApiError() - elif response.get("response_code") == 0 or "results" not in response: - raise VirusTotalApiError() - elif response["results"]["response_code"] == 1: + if response.get("results", {}).get("response_code") == 1: return response["results"]["scans"] else: return None diff --git a/tests/test_agent.py b/tests/test_agent.py index bcc5221..d37f683 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -4,7 +4,6 @@ from typing import Any import requests_mock as rq_mock -import pytest from ostorlab.agent.message import message as msg from pytest_mock import plugin @@ -41,7 +40,7 @@ def virustotal_url_valid_response(url: str, timeout: int) -> dict[str, Any]: return response -def testVirusTotalAgent_whenVirusTotalApiReturnsValidResponse_noRaiseVirusTotalApiError( +def testVirusTotalAgent_whenVirusTotalApiReturnsValidResponse_noExceptionRaised( mocker: plugin.MockerFixture, agent_mock: list[msg.Message], virustotal_agent: virus_total_agent.VirusTotalAgent, @@ -85,13 +84,8 @@ def virustotal_valid_response(message: msg.Message) -> dict[str, Any]: "virus_total_apis.PublicApi.get_file_report", side_effect=virustotal_valid_response, ) + virustotal_agent.process(message) - try: - virustotal_agent.process(message) - except virustotal.VirusTotalApiError: - pytest.fail( - "Unexpected VirusTotalApiError because response is returned with status 200." - ) assert len(agent_mock) == 1 assert agent_mock[0].selector == "v3.report.vulnerability" assert agent_mock[0].data["risk_rating"] == "HIGH" @@ -130,9 +124,7 @@ def virustotal_invalid_response(message: msg.Message) -> dict[str, Any]: "virus_total_apis.PublicApi.get_file_report", side_effect=virustotal_invalid_response, ) - get_scans_mocker = mocker.patch( - "agent.virustotal.get_scans", side_effect=virustotal.VirusTotalApiError - ) + get_scans_mocker = mocker.patch("agent.virustotal.get_scans") virustotal_agent.process(message) @@ -286,7 +278,7 @@ def testVirusTotalAgent_whenFileIsWhitelisted_agentShouldScanFile( ) -def testVirusTotalAgent_whenVirusTotalReachesApiRateLimit_raiseVirusTotalApiError( +def testVirusTotalAgent_whenVirusTotalReachesApiRateLimit_returnNone( virustotal_agent: virus_total_agent.VirusTotalAgent, message: msg.Message, ) -> None: @@ -299,8 +291,9 @@ def testVirusTotalAgent_whenVirusTotalReachesApiRateLimit_raiseVirusTotalApiErro "response_code": 204, } - with pytest.raises(virustotal.VirusTotalApiError): - virustotal.get_scans(response) + scans = virustotal.get_scans(response) + + assert scans is None def testVirusTotalAgent_whenWhiteListTypesAreNotProvided_shouldNotCrash(