Skip to content

Commit

Permalink
Merge pull request #2347 from guardicore/2261-move-agent-signals-to-c…
Browse files Browse the repository at this point in the history
…ommon

2261 move agent signals to common
  • Loading branch information
mssalvatore committed Sep 23, 2022
2 parents feb8288 + f7198ea commit a49ddf7
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 24 deletions.
1 change: 1 addition & 0 deletions monkey/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from . import types
from . import base_models
from .agent_registration_data import AgentRegistrationData
from .agent_signals import AgentSignals
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Optional

from common.base_models import InfectionMonkeyBaseModel
from .base_models import InfectionMonkeyBaseModel


class AgentSignals(InfectionMonkeyBaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import functools
import json
import logging
from datetime import datetime
from pprint import pformat
from typing import List, Optional, Sequence
from typing import List, Sequence

import requests

from common import AgentRegistrationData, OperatingSystem
from common import AgentRegistrationData, AgentSignals, OperatingSystem
from common.agent_configuration import AgentConfiguration
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
from common.agent_events import AbstractAgentEvent
Expand Down Expand Up @@ -189,15 +188,15 @@ def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSeriali

@handle_island_errors
@convert_json_error_to_island_api_error
def get_agent_signals(self, agent_id: str) -> Optional[datetime]:
def get_agent_signals(self, agent_id: str) -> AgentSignals:
url = f"{self._api_url}/agent-signals/{agent_id}"
response = requests.get( # noqa: DUO123
url,
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
return response.json()["terminate"]
return AgentSignals(**response.json())


class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional, Sequence

from common import AgentRegistrationData, OperatingSystem
from common import AgentRegistrationData, AgentSignals, OperatingSystem
from common.agent_configuration import AgentConfiguration
from common.agent_events import AbstractAgentEvent
from common.credentials import Credentials
Expand Down Expand Up @@ -133,7 +132,7 @@ def get_credentials_for_propagation(self) -> Sequence[Credentials]:
"""

@abstractmethod
def get_agent_signals(self, agent_id: str) -> Optional[datetime]:
def get_agent_signals(self, agent_id: str) -> AgentSignals:
"""
Gets an agent's signals from the island
Expand All @@ -142,5 +141,5 @@ def get_agent_signals(self, agent_id: str) -> Optional[datetime]:
:raises IslandAPIRequestError: If there was a problem with the client request
:raises IslandAPIRequestFailedError: If the server experienced an error
:raises IslandAPITimeoutError: If the command timed out
:return: The relevant agent's terminate signal's timestamp
:return: The relevant agent's signals
"""
3 changes: 2 additions & 1 deletion monkey/infection_monkey/master/control_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def should_agent_stop(self) -> bool:
if not self._control_channel_server:
logger.error("Agent should stop because it can't connect to the C&C server.")
return True
return self._island_api_client.get_agent_signals(self._agent_id) is not None
agent_signals = self._island_api_client.get_agent_signals(self._agent_id)
return agent_signals.terminate is not None

@handle_island_api_errors
def get_config(self) -> AgentConfiguration:
Expand Down
1 change: 0 additions & 1 deletion monkey/monkey_island/cc/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,3 @@
from .node import Node
from common.types import AgentID
from .agent import Agent
from .agent_signals import AgentSignals
3 changes: 2 additions & 1 deletion monkey/monkey_island/cc/services/agent_signals_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from datetime import datetime
from typing import Optional

from common.agent_signals import AgentSignals
from common.types import AgentID
from monkey_island.cc.models import AgentSignals, Simulation
from monkey_island.cc.models import Simulation
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import requests
import requests_mock

from common import OperatingSystem
from common import AgentSignals, OperatingSystem
from common.agent_event_serializers import (
AgentEventSerializerRegistry,
PydanticAgentEventSerializer,
Expand Down Expand Up @@ -456,16 +456,17 @@ def test_island_api_client_get_agent_signals__status_code(
island_api_client.get_agent_signals(agent_id=AGENT_ID)


@pytest.mark.parametrize("expected_timestamp", [TIMESTAMP, None])
def test_island_api_client_get_agent_signals(island_api_client, expected_timestamp):
@pytest.mark.parametrize("timestamp", [TIMESTAMP, None])
def test_island_api_client_get_agent_signals(island_api_client, timestamp):
expected_agent_signals = AgentSignals(terminate=timestamp)
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)

m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": expected_timestamp})
actual_terminate_timestamp = island_api_client.get_agent_signals(agent_id=AGENT_ID)
m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": timestamp})
actual_agent_signals = island_api_client.get_agent_signals(agent_id=AGENT_ID)

assert actual_terminate_timestamp == expected_timestamp
assert actual_agent_signals == expected_agent_signals


def test_island_api_client_get_agent_signals__bad_json(island_api_client):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Optional
from unittest.mock import MagicMock

import pytest

from infection_monkey.i_control_channel import IslandCommunicationError
from common import AgentSignals
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.island_api_client import (
IIslandAPIClient,
IslandAPIConnectionError,
Expand Down Expand Up @@ -33,9 +35,17 @@ def control_channel(island_api_client) -> ControlChannel:
return ControlChannel(SERVER, AGENT_ID, island_api_client)


def test_control_channel__should_agent_stop(control_channel, island_api_client):
control_channel.should_agent_stop()
assert island_api_client.get_agent_signals.called_once()
@pytest.mark.parametrize("signal_time,expected_should_stop", [(1663950115, True), (None, False)])
def test_control_channel__should_agent_stop(
control_channel: IControlChannel,
island_api_client: IIslandAPIClient,
signal_time: Optional[int],
expected_should_stop: bool,
):
island_api_client.get_agent_signals = MagicMock(
return_value=AgentSignals(terminate=signal_time)
)
assert control_channel.should_agent_stop() is expected_should_stop


@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from tests.common import StubDIContainer

from monkey_island.cc.models import AgentSignals as Signals
from common.agent_signals import AgentSignals as Signals
from monkey_island.cc.repository import RetrievalError, StorageError
from monkey_island.cc.services import AgentSignalsService

Expand Down

0 comments on commit a49ddf7

Please sign in to comment.