Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mypy to agent virustotal #46

Merged
merged 1 commit into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
[mypy]
files = agent, tests
check_untyped_defs = True
follow_imports_for_stubs = True
disallow_any_decorated = True
disallow_any_generics = True
disallow_incomplete_defs = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_decorators = True
disallow_untyped_defs = True
implicit_reexport = False
no_implicit_optional = True
show_error_codes = True
strict_equality = True
warn_incomplete_stub = True
warn_redundant_casts = True
warn_unreachable = True
warn_unused_ignores = True
disallow_any_unimported = False
warn_return_any = True
exclude = .*_pb2.py

[mypy-virus_total_apis]
ignore_missing_imports = True
4 changes: 3 additions & 1 deletion agent/file.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Collection of functions to handle files."""
from typing import cast

import requests
import tenacity
from ostorlab.agent.message import message as m
Expand Down Expand Up @@ -41,7 +43,7 @@ def get_file_content(message: m.Message) -> bytes | None:
"""
content = message.data.get("content")
if content is not None and isinstance(content, bytes):
return content
return cast(bytes, content)
content_url: str | None = message.data.get("content_url")
if content_url is not None:
return _download_file(content_url)
Expand Down
5 changes: 3 additions & 2 deletions agent/process_scans.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

from ostorlab.agent.mixins import agent_report_vulnerability_mixin

from agent import markdown


Expand Down Expand Up @@ -30,9 +31,9 @@ def get_technical_details(scans: dict[str, Any], target: str | None) -> str:
Returns:
technical_detail : Markdown table of the scans results.
"""
scans = markdown.prepare_data_for_markdown_formatting(scans)
formatted_scans = markdown.prepare_data_for_markdown_formatting(scans)
technical_detail = ""
if target is not None:
technical_detail = f"Analysis of the target `{target}`:\n"
technical_detail += markdown.table_markdown(scans)
technical_detail += markdown.table_markdown(formatted_scans)
return technical_detail
10 changes: 7 additions & 3 deletions agent/virus_total_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""VirusTotal agent implementation : Agent responsible for scanning a file through the Virus Total DB."""
import hashlib
import ipaddress
import logging
from typing import Any
import hashlib
from typing import cast

import magic
from ostorlab.agent import agent, definitions as agent_definitions
Expand All @@ -16,7 +17,6 @@
from agent import process_scans
from agent import virustotal


logging.basicConfig(
format="%(message)s",
datefmt="[%X]",
Expand Down Expand Up @@ -47,7 +47,11 @@ def __init__(
agent_settings: Settings of running instance of the agent.
"""
super().__init__(agent_definition, agent_settings)
self.api_key = self.args.get("api_key")
api_key = self.args.get("api_key")
if api_key is None:
raise ValueError("Virustotal API Key is not set")
else:
self.api_key = cast(str, api_key)
self.whitelist_types = self.args.get("whitelist_types") or []

def process(self, message: msg.Message) -> None:
Expand Down
12 changes: 6 additions & 6 deletions agent/virustotal.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Module responsible for interacting with Virus Total public API."""
import hashlib
from typing import Any
import logging
from typing import Any, cast

import virus_total_apis

Expand All @@ -14,7 +14,7 @@ class Error(Exception):
"""Custom Error."""


def scan_file_from_message(file_content: bytes, api_key: str) -> dict:
def scan_file_from_message(file_content: bytes, api_key: str) -> dict[str, Any]:
"""Method responsible for scanning a file through the Virus Total public API.
Args:
file_content: Message containing the file to scan.
Expand All @@ -26,10 +26,10 @@ def scan_file_from_message(file_content: bytes, api_key: str) -> dict:
hash_hexa = file_md5_hash.hexdigest()
virustotal_client = virus_total_apis.PublicApi(api_key)
response = virustotal_client.get_file_report(hash_hexa)
return response
return cast(dict[str, Any], response)


def scan_url_from_message(target: str, api_key: str) -> dict:
def scan_url_from_message(target: str, api_key: str) -> dict[str, Any]:
"""Method responsible for scanning a file through the Virus Total public API.
Args:
target: url to scan.
Expand All @@ -39,7 +39,7 @@ def scan_url_from_message(target: str, api_key: str) -> dict:
"""
virustotal_client = virus_total_apis.PublicApi(api_key)
response = virustotal_client.get_url_report(target, timeout=TIMEOUT_REQUEST)
return response
return cast(dict[str, Any], response)


def get_scans(response: dict[str, Any]) -> dict[str, Any] | None:
Expand All @@ -52,6 +52,6 @@ def get_scans(response: dict[str, Any]) -> dict[str, Any] | None:
scans: Dictionary of the scans.
"""
if response.get("results", {}).get("response_code") == 1:
return response["results"]["scans"]
return cast(dict[str, Any], response["results"]["scans"])
else:
return None
6 changes: 3 additions & 3 deletions tests/virus_total_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import pathlib
import re
from typing import Any
import requests_mock as rq_mock

import pytest
import requests_mock as rq_mock
from ostorlab.agent.message import message as msg
from pytest_mock import plugin
import pytest

from agent import virus_total_agent
from agent import virustotal
Expand Down Expand Up @@ -275,7 +275,7 @@ def testVirusTotalAgent_whenFileIsWhitelisted_agentShouldScanFile(

assert virustotal_call.called is True
assert (
virustotal_call.last_request.query
virustotal_call.last_request.query # type: ignore
3asm marked this conversation as resolved.
Show resolved Hide resolved
== "apikey=some_api_key&resource=e29efc13355681a4aa23f0623c2316b9"
)

Expand Down
Loading