Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
benyissa committed Nov 20, 2023
1 parent 00aee9c commit ee7ebfb
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 46 deletions.
31 changes: 25 additions & 6 deletions tests/asteroid_agent_test.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,39 @@
"""Unit tests for AsteroidAgent."""
from typing import Type
from agent import asteroid_agent
from agent import definitions
import datetime


from ostorlab.agent.message import message as m
import requests
from pytest_mock import plugin

from agent import asteroid_agent

seed: int = 0


def testAsteroidAgent_whenExploitCheckDetectVulnz_EmitsVulnerabilityReport(
exploit_instance_with_report: Type[definitions.Exploit],
asteroid_agent_instance: asteroid_agent.AsteroidAgent,
agent_mock: list[m.Message],
scan_message_domain_name: m.Message,
scan_message_ipv4_for_cve_2023_27997: m.Message,
mocker: plugin.MockerFixture,
) -> None:
"""Unit test for agent AsteroidAgent exploits check. case Exploit emits vulnerability report"""

asteroid_agent_instance.process(scan_message_domain_name)
def side_effect(*args, **kwargs): # type: ignore[no-untyped-def]
global seed
mock_response = mocker.Mock(spec=requests.Response)
if seed % 2 == 0:
elapsed = datetime.timedelta(microseconds=2500)
else:
elapsed = datetime.timedelta(microseconds=1)

mock_response.elapsed = elapsed
seed += 1
return mock_response

mocker.patch("requests.sessions.Session.post", side_effect=side_effect)

asteroid_agent_instance.process(scan_message_ipv4_for_cve_2023_27997)

assert len(agent_mock) == 1
assert agent_mock[0].selector == "v3.report.vulnerability"
54 changes: 15 additions & 39 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,13 @@
"""Pytest fixtures for agent Asteroid"""
import pathlib
import random
from typing import Type

import pytest
from ostorlab.agent import definitions as agent_definitions
from ostorlab.agent.message import message
from ostorlab.runtimes import definitions as runtime_definitions
from ostorlab.agent.mixins import agent_report_vulnerability_mixin as vuln_mixin
from ostorlab.agent.kb import kb
from agent import asteroid_agent
from agent import exploits_registry
from agent import definitions


@pytest.fixture()
def exploit_instance_with_report() -> Type[definitions.Exploit]:
@exploits_registry.register
class TestExploit(definitions.Exploit):
"""test class Exploit."""

def accept(self, target: definitions.Target) -> bool:
return True

def check(self, target: definitions.Target) -> list[definitions.Vulnerability]:
return [
definitions.Vulnerability(
technical_detail="test",
entry=kb.Entry(
title="test",
risk_rating="INFO",
short_description="test purposes",
description="test purposes",
recommendation="",
references={},
security_issue=False,
privacy_issue=False,
has_public_exploit=False,
targeted_by_malware=False,
targeted_by_ransomware=False,
targeted_by_nation_state=False,
),
risk_rating=vuln_mixin.RiskRating.HIGH,
)
]

return TestExploit
from agent import asteroid_agent


@pytest.fixture()
Expand Down Expand Up @@ -84,6 +46,20 @@ def scan_message_ipv4() -> message.Message:
return message.Message.from_data(selector, data=msg_data)


@pytest.fixture()
def scan_message_ipv4_for_cve_2023_27997() -> message.Message:
"""Creates a message of type v3.asset.ip.v4 to be used by the agent for testing purposes."""
selector = "v3.asset.ip.v4.port"
msg_data = {
"host": "91.135.170.42",
"mask": "32",
"version": 4,
"port": 8443,
"protocol": "https",
}
return message.Message.from_data(selector, data=msg_data)


@pytest.fixture()
def asteroid_agent_instance() -> asteroid_agent.AsteroidAgent:
with (pathlib.Path(__file__).parent.parent / "ostorlab.yaml").open() as yaml_o:
Expand Down
2 changes: 1 addition & 1 deletion tests/exploits_registry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def testExploitsRegistry_importingAllExploits_registerAll() -> None:

registered_exploits = exploits_registry.ExploitsRegistry.values()

assert len(registered_exploits) == 3
assert len(registered_exploits) == 2


def testExploitsRegistry_allExploits_mustBeRegisteredOnce() -> None:
Expand Down

0 comments on commit ee7ebfb

Please sign in to comment.