diff --git a/agent/definitions.py b/agent/definitions.py new file mode 100644 index 00000000..49046af3 --- /dev/null +++ b/agent/definitions.py @@ -0,0 +1,55 @@ +"""Agent Asteriod definitions""" +import abc +import dataclasses + +from ostorlab.agent.kb import kb +from ostorlab.agent.mixins import agent_report_vulnerability_mixin as vuln_mixin + + +@dataclasses.dataclass +class Target: + scheme: str + host: str + port: int + + +@dataclasses.dataclass +class Vulnerability: + """Vulnerability entry with technical details, custom risk rating, DNA for unique identification and location.""" + + entry: kb.Entry + technical_detail: str + risk_rating: vuln_mixin.RiskRating + dna: str | None = None + vulnerability_location: vuln_mixin.VulnerabilityLocation | None = None + + +class Exploit(abc.ABC): + """Base Exploit""" + + @abc.abstractmethod + def accept(self, target: Target) -> bool: + """Rule: heuristically detect if a specific target is valid. + Args: + target: Target to verify + Returns: + True if the target is valid; false otherwise. + """ + pass + + @abc.abstractmethod + def check(self, target: Target) -> list[Vulnerability]: + """Rule to detect specific vulnerability on a specific target. + + Args: + target: target to scan + + Returns: + List of identified vulnerabilities. + """ + pass + + @property + def __key__(self) -> str: + """Unique key for the class, mainly useful for registering the exploits.""" + return self.__class__.__name__ diff --git a/agent/exploits_registry.py b/agent/exploits_registry.py new file mode 100644 index 00000000..952655df --- /dev/null +++ b/agent/exploits_registry.py @@ -0,0 +1,44 @@ +"""Register for exploits.""" +from collections import defaultdict +from typing import Type, Any + +from agent import definitions + + +class ExploitsRegistry: + """Registry class, This class provides a way to store and retrieve callables that generate lists of + `definitions.Exploit` objects from a given file name and bytes object. + """ + + registry: dict[Any, Any] = defaultdict(dict) + + @classmethod + def register_ref( + cls, + obj: definitions.Exploit, + key: str = "__key__", + ) -> definitions.Exploit: + cls.registry[cls.__name__][getattr(obj, key)] = obj + return obj + + @classmethod + def values( + cls, + ) -> list[Any]: + return list(cls.registry[cls.__name__].values()) + + +def register( + f: Type[definitions.Exploit], +) -> Type[definitions.Exploit]: + """ + To be used as a decorator on the exploit class + + Args: + f: The class which its object will be registered. + + Returns: + The input callable. + """ + ExploitsRegistry.register_ref(obj=f()) + return f diff --git a/tests/conftest.py b/tests/conftest.py index d0a7eb1d..f1919a4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,20 @@ -""" - Dummy conftest.py for template_agent. +"""Pytest fixtures for agent Asteroid""" +from typing import Type - If you don't know what this is for, just leave it empty. - Read more about conftest.py under: - - https://docs.pytest.org/en/stable/fixture.html - - https://docs.pytest.org/en/stable/writing_plugins.html -""" +import pytest -# import pytest +from agent import definitions + + +@pytest.fixture() +def exploit_instance() -> Type[definitions.Exploit]: + class TestExploit(definitions.Exploit): + """test class Exploit.""" + + def accept(self, target: definitions.Target) -> bool: + return False + + def check(self, target: definitions.Target) -> list[definitions.Vulnerability]: + return [] + + return TestExploit diff --git a/tests/exploits_registry_test.py b/tests/exploits_registry_test.py new file mode 100644 index 00000000..a8d491b3 --- /dev/null +++ b/tests/exploits_registry_test.py @@ -0,0 +1,14 @@ +"""Unit tests for the exploits' registry.""" +from typing import Type +from agent import definitions +from agent import exploits_registry + + +def testExploitsRegistry_whenRegisteringClassDirectly_shouldLoadClass( + exploit_instance: Type[definitions.Exploit], +) -> None: + """Ensure the exploits_registry registers instances of the exploits.""" + exploits_registry.register(exploit_instance) + registered_exploits = exploits_registry.ExploitsRegistry.values() + + assert len(registered_exploits) >= 1