Skip to content

Commit

Permalink
Merge pull request #7 from Ostorlab/implement_registry_pattern
Browse files Browse the repository at this point in the history
Feature / implement Registry pattern
  • Loading branch information
benyissa authored Nov 17, 2023
2 parents d268a9c + 4e66089 commit a9de23a
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 8 deletions.
57 changes: 57 additions & 0 deletions agent/definitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Agent Asteriod definitions"""
import abc
import dataclasses

from ostorlab.agent.kb import kb
from ostorlab.agent.mixins import agent_report_vulnerability_mixin as vuln_mixin


@dataclasses.dataclass
class Target:
scheme: str
host: str
port: int


@dataclasses.dataclass
class Vulnerability:
"""Vulnerability entry with technical details, custom risk rating, DNA for unique identification and location."""

entry: kb.Entry
technical_detail: str
risk_rating: vuln_mixin.RiskRating
dna: str | None = None
vulnerability_location: vuln_mixin.VulnerabilityLocation | None = None


class Exploit(abc.ABC):
"""Base Exploit"""

@abc.abstractmethod
def accept(self, target: Target) -> bool:
"""Rule: heuristically detect if a specific target is valid.
Args:
target: Target to verify
Returns:
True if the target is valid; otherwise False.
"""
pass

@abc.abstractmethod
def check(self, target: Target) -> list[Vulnerability]:
"""Rule to detect specific vulnerability on a specific target.
Args:
target: target to scan
Returns:
List of identified vulnerabilities.
"""
pass

@property
def __key__(self) -> str:
"""Unique key for the class, mainly useful for registering the exploits."""
return self.__class__.__name__
44 changes: 44 additions & 0 deletions agent/exploits_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Register for exploits."""
from collections import defaultdict
from typing import Type, Any

from agent import definitions


class ExploitsRegistry:
"""Registry class, This class provides a way to store and retrieve callables that generate lists of
`definitions.Exploit` objects from a given file name and bytes object.
"""

registry: dict[Any, Any] = defaultdict(dict)

@classmethod
def register_ref(
cls,
obj: definitions.Exploit,
key: str = "__key__",
) -> definitions.Exploit:
cls.registry[cls.__name__][getattr(obj, key)] = obj
return obj

@classmethod
def values(
cls,
) -> list[Any]:
return list(cls.registry[cls.__name__].values())


def register(
f: Type[definitions.Exploit],
) -> Type[definitions.Exploit]:
"""
To be used as a decorator on the exploit class
Args:
f: The class which its object will be registered.
Returns:
The input callable.
"""
ExploitsRegistry.register_ref(obj=f())
return f
26 changes: 18 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
"""
Dummy conftest.py for template_agent.
"""Pytest fixtures for agent Asteroid"""
from typing import Type

If you don't know what this is for, just leave it empty.
Read more about conftest.py under:
- https://docs.pytest.org/en/stable/fixture.html
- https://docs.pytest.org/en/stable/writing_plugins.html
"""
import pytest

# import pytest
from agent import definitions


@pytest.fixture()
def exploit_instance() -> Type[definitions.Exploit]:
class TestExploit(definitions.Exploit):
"""test class Exploit."""

def accept(self, target: definitions.Target) -> bool:
return False

def check(self, target: definitions.Target) -> list[definitions.Vulnerability]:
return []

return TestExploit
14 changes: 14 additions & 0 deletions tests/exploits_registry_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Unit tests for the exploits' registry."""
from typing import Type
from agent import definitions
from agent import exploits_registry


def testExploitsRegistry_whenRegisteringClassDirectly_shouldLoadClass(
exploit_instance: Type[definitions.Exploit],
) -> None:
"""Ensure the exploits_registry registers instances of the exploits."""
exploits_registry.register(exploit_instance)
registered_exploits = exploits_registry.ExploitsRegistry.values()

assert len(registered_exploits) >= 1

0 comments on commit a9de23a

Please sign in to comment.