diff --git a/agent/definitions.py b/agent/definitions.py index 6c6970a6..45801a86 100644 --- a/agent/definitions.py +++ b/agent/definitions.py @@ -38,9 +38,11 @@ def vulnerability_description(self) -> str: """Vulnerability description""" pass - def is_target_valid(self) -> bool: - return False + @abc.abstractmethod + def accept(self) -> bool: + pass + @abc.abstractmethod def check(self) -> list[Vulnerability] | None: """Rule to detect specific vulnerability on a specific target. @@ -50,4 +52,4 @@ def check(self) -> list[Vulnerability] | None: Returns: List of identified vulnerabilities. """ - return None + pass diff --git a/agent/exploits/cve_2021_22941.py b/agent/exploits/cve_2021_22941.py index f35a8435..791ab9ce 100644 --- a/agent/exploits/cve_2021_22941.py +++ b/agent/exploits/cve_2021_22941.py @@ -1,29 +1,14 @@ """Agent Asteroid implementation for CVE-2021-22941""" -import logging -import warnings import requests from ostorlab.agent.kb import kb from ostorlab.agent.mixins import agent_report_vulnerability_mixin from requests import exceptions as requests_exceptions -from rich import logging as rich_logging from agent import definitions -warnings.filterwarnings("ignore") - DEFAULT_TIMEOUT = 90 -logging.basicConfig( - format="%(message)s", - datefmt="[%X]", - handlers=[rich_logging.RichHandler(rich_tracebacks=True)], - level="INFO", - force=True, -) -logger = logging.getLogger(__name__) - - def _encode_multipart_formdata(files: dict[str, str]) -> tuple[str, str]: boundary = "boundary" body = "".join( @@ -38,7 +23,7 @@ def _encode_multipart_formdata(files: dict[str, str]) -> tuple[str, str]: class Exploit(definitions.BaseExploit): """ - CVE: CVE-2021-22941 + CVE-2021-22941: Improper Access Control in Citrix ShareFile storage zones controller """ def __init__(self, target: str): @@ -59,11 +44,10 @@ def vulnerability_description(self) -> str: "allow an unauthenticated attacker to remotely compromise the storage zones controller." ) - def is_target_valid(self) -> bool: + def accept(self) -> bool: try: req = requests.get(self.target, verify=False, timeout=DEFAULT_TIMEOUT) except requests_exceptions.RequestException: - logger.error("Failed to reach target %s", self.target) return False return "ShareFile" in req.text @@ -94,7 +78,6 @@ def check(self) -> list[definitions.Vulnerability]: timeout=DEFAULT_TIMEOUT, ) except requests_exceptions.RequestException as exc: - logger.error("Failed to send payload, error message: %s", exc) return [] if payload not in req.text: return [] diff --git a/tests/exploits_test.py b/tests/exploits_test.py index bc4069f6..21d74710 100644 --- a/tests/exploits_test.py +++ b/tests/exploits_test.py @@ -7,6 +7,7 @@ def testCVE_2021_22941_whenVulnerable_reportFinding( requests_mock: req_mock.mocker.Mocker, ) -> None: + """Unit test for CVE-2021-22941, case when target is vulnerable""" exploit_instance = cve_2021_22941.Exploit("https://75.162.65.52") requests_mock.post(re.compile("https://75.162.65.52")) requests_mock.get(