Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/Remove the handling case of the API rate limit #40

Merged
merged 6 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions agent/virus_total_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def process(self, message: msg.Message) -> None:
message: Message containing the file to scan.

Raises:
VirusTotalApiError: In case the Virus Total api encountered problems.
NameError: In case the scans were not defined.
"""
file_content = file.get_file_content(message)
Expand All @@ -66,12 +65,7 @@ def process(self, message: msg.Message) -> None:
self._process_response(response, target)

def _process_response(self, response: dict[str, Any], target: str | None) -> None:
try:
scans = virustotal.get_scans(response)
except virustotal.VirusTotalApiError:
logger.error("Virus Total API encountered some problems. Please try again.")
return None

scans = virustotal.get_scans(response)
try:
if scans is not None:
technical_detail = process_scans.get_technical_details(scans, target)
Expand Down
13 changes: 1 addition & 12 deletions agent/virustotal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ class Error(Exception):
"""Custom Error."""


class VirusTotalApiError(Error):
"""VirtualTotalApiError."""


def scan_file_from_message(file_content: bytes, api_key: str) -> dict:
"""Method responsible for scanning a file through the Virus Total public API.
Args:
Expand Down Expand Up @@ -54,15 +50,8 @@ def get_scans(response: dict[str, Any]) -> dict[str, Any] | None:

Returns:
scans: Dictionary of the scans.

Raises:
VirusTotalApiError: In case the API request encountered problems.
"""
if response.get("response_code") == 204 and response.get("error") is not None:
raise VirusTotalApiError()
elif response.get("response_code") == 0 or "results" not in response:
raise VirusTotalApiError()
elif response["results"]["response_code"] == 1:
if response.get("results", {}).get("response_code") == 1:
return response["results"]["scans"]
else:
return None
21 changes: 7 additions & 14 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any
import requests_mock as rq_mock

import pytest
from ostorlab.agent.message import message as msg
from pytest_mock import plugin

Expand Down Expand Up @@ -41,7 +40,7 @@ def virustotal_url_valid_response(url: str, timeout: int) -> dict[str, Any]:
return response


def testVirusTotalAgent_whenVirusTotalApiReturnsValidResponse_noRaiseVirusTotalApiError(
def testVirusTotalAgent_whenVirusTotalApiReturnsValidResponse_noExceptionRaised(
mocker: plugin.MockerFixture,
agent_mock: list[msg.Message],
virustotal_agent: virus_total_agent.VirusTotalAgent,
Expand Down Expand Up @@ -85,13 +84,8 @@ def virustotal_valid_response(message: msg.Message) -> dict[str, Any]:
"virus_total_apis.PublicApi.get_file_report",
side_effect=virustotal_valid_response,
)
virustotal_agent.process(message)

try:
virustotal_agent.process(message)
except virustotal.VirusTotalApiError:
pytest.fail(
"Unexpected VirusTotalApiError because response is returned with status 200."
)
assert len(agent_mock) == 1
assert agent_mock[0].selector == "v3.report.vulnerability"
assert agent_mock[0].data["risk_rating"] == "HIGH"
Expand Down Expand Up @@ -130,9 +124,7 @@ def virustotal_invalid_response(message: msg.Message) -> dict[str, Any]:
"virus_total_apis.PublicApi.get_file_report",
side_effect=virustotal_invalid_response,
)
get_scans_mocker = mocker.patch(
"agent.virustotal.get_scans", side_effect=virustotal.VirusTotalApiError
)
get_scans_mocker = mocker.patch("agent.virustotal.get_scans")

virustotal_agent.process(message)

Expand Down Expand Up @@ -286,7 +278,7 @@ def testVirusTotalAgent_whenFileIsWhitelisted_agentShouldScanFile(
)


def testVirusTotalAgent_whenVirusTotalReachesApiRateLimit_raiseVirusTotalApiError(
def testVirusTotalAgent_whenVirusTotalReachesApiRateLimit_returnNone(
virustotal_agent: virus_total_agent.VirusTotalAgent,
message: msg.Message,
) -> None:
Expand All @@ -299,8 +291,9 @@ def testVirusTotalAgent_whenVirusTotalReachesApiRateLimit_raiseVirusTotalApiErro
"response_code": 204,
}

with pytest.raises(virustotal.VirusTotalApiError):
virustotal.get_scans(response)
scans = virustotal.get_scans(response)

assert scans is None


def testVirusTotalAgent_whenWhiteListTypesAreNotProvided_shouldNotCrash(
Expand Down
Loading