diff --git a/src/ostorlab/cli/docker_requirements_checker.py b/src/ostorlab/cli/docker_requirements_checker.py index 3b1a5e909..a28ad4357 100644 --- a/src/ostorlab/cli/docker_requirements_checker.py +++ b/src/ostorlab/cli/docker_requirements_checker.py @@ -1,11 +1,18 @@ """Check if requirements for running docker are satisfied.""" -import docker import platform import sys + +import docker +import tenacity from docker import errors +from ostorlab import exceptions + _SUPPORTED_ARCH_TYPES = ["x86_64", "AMD64"] +RETRY_ATTEMPTS = 10 +WAIT_TIME = 2 + # The architecture is checked with a return value that's based on the kernel implementation of the uname(2) # system call. So it might be necesarry to handle the same arch with various strings e.g. linux returns x86_64 # or AMD64 on windows. @@ -97,7 +104,31 @@ def is_swarm_initialized() -> bool: def init_swarm() -> None: + """Initializes Docker Swarm. + + This function attempts to initialize Docker Swarm. If the initialization fails, + it retries 10 times with a 2-second delay between each attempt. + If it still fails after 10 attempts, it raises an OstorlabError. + + Raises: + OstorlabError: If the user does not have permission to run Docker, + or if the initialization fails after 10 attempts. + """ + if is_user_permitted() is False: + raise errors.DockerException("User does not have permission to run docker.") + try: + _init_swarm() + except errors.DockerException as e: + raise exceptions.OstorlabError("Error while initializing swarm.") from e + + +@tenacity.retry( + stop=tenacity.stop_after_attempt(RETRY_ATTEMPTS), + wait=tenacity.wait_fixed(WAIT_TIME), + retry=tenacity.retry_if_exception_type(errors.DockerException), + reraise=True, +) +def _init_swarm() -> None: """Initialize docker swarm""" - if is_user_permitted(): - docker_client = docker.from_env() - docker_client.swarm.init() + docker_client = docker.from_env() + docker_client.swarm.init() diff --git a/src/ostorlab/cli/scan/run/run.py b/src/ostorlab/cli/scan/run/run.py index ecec4c732..9401c13b7 100644 --- a/src/ostorlab/cli/scan/run/run.py +++ b/src/ostorlab/cli/scan/run/run.py @@ -4,18 +4,18 @@ - ostorlab scan run --agent=agent1 --agent=agent2 --title=test_scan [asset] [options].""" import io import logging -import requests from typing import List import click +import requests +from ostorlab import exceptions +from ostorlab.agent.schema import validator +from ostorlab.cli import console as cli_console from ostorlab.cli import install_agent from ostorlab.cli.scan import scan from ostorlab.runtimes import definitions from ostorlab.runtimes import runtime -from ostorlab.cli import console as cli_console -from ostorlab.agent.schema import validator - console = cli_console.Console() @@ -96,8 +96,12 @@ def run( runtime_instance: runtime.Runtime = ctx.obj["runtime"] # set list of log follow. runtime_instance.follow = follow - - if runtime_instance.can_run(agent_group_definition=agent_group): + try: + can_run_scan = runtime_instance.can_run(agent_group_definition=agent_group) + except exceptions.OstorlabError as e: + console.error(f"{e}") + return None + if can_run_scan is True: ctx.obj["agent_group_definition"] = agent_group ctx.obj["title"] = title if install is True: diff --git a/tests/cli/docker_requirements_checker_test.py b/tests/cli/docker_requirements_checker_test.py new file mode 100644 index 000000000..a75b3ac4b --- /dev/null +++ b/tests/cli/docker_requirements_checker_test.py @@ -0,0 +1,27 @@ +"""Tests for the docker_requirements_checker module.""" +import pytest +from pytest_mock import plugin +from requests_mock import mocker as req_mocker + +from ostorlab import exceptions +from ostorlab.cli import docker_requirements_checker + + +@pytest.mark.docker +def testRuntime_WhenCantInitSwarm_shouldRetry( + mocker: plugin.MockerFixture, requests_mock: req_mocker.Mocker +) -> None: + """Ensure the runtime retries to init swarm if it fails the first time.""" + mocker.patch("time.sleep") + requests_mock.get( + "http+docker://localhost/version", [{"json": {"ApiVersion": "1.35"}}] + ) + requests_mock.get("http+docker://localhost/v1.35/swarm", json={"ID": "1234"}) + mock_swarm_init = requests_mock.post( + "http+docker://localhost/v1.35/swarm/init", status_code=400 + ) + + with pytest.raises(exceptions.OstorlabError): + docker_requirements_checker.init_swarm() + + assert mock_swarm_init.call_count == 3 diff --git a/tests/runtimes/local/runtime_test.py b/tests/runtimes/local/runtime_test.py index 47a6cfbf9..b0654dc4d 100644 --- a/tests/runtimes/local/runtime_test.py +++ b/tests/runtimes/local/runtime_test.py @@ -2,11 +2,13 @@ from typing import Any import docker -from docker.models import services as services_model import pytest +from docker.models import services as services_model from pytest_mock import plugin +from requests_mock import mocker as req_mocker import ostorlab +from ostorlab import exceptions from ostorlab.assets import android_apk from ostorlab.runtimes import definitions from ostorlab.runtimes.local import runtime as local_runtime @@ -280,3 +282,41 @@ def testScanInLocalRuntime_whenScanIdIsPassed_shouldUseTheScanIdAsUniverseLabelI assert session.query(models.Scan).count() == 1 scan = session.query(models.Scan).first() assert scan.id != 42 + + +@pytest.mark.docker +def testRuntime_WhenCantInitSwarm_shouldShowUserFriendlyMessage( + mocker: plugin.MockerFixture, + requests_mock: req_mocker.Mocker, +) -> None: + """Ensure the runtime retries to init swarm if it fails the first time.""" + mocker.patch( + "ostorlab.cli.docker_requirements_checker.is_docker_working", return_value=True + ) + mocker.patch( + "ostorlab.cli.docker_requirements_checker.is_user_permitted", return_value=True + ) + mocker.patch( + "ostorlab.cli.docker_requirements_checker.is_sys_arch_supported", + return_value=True, + ) + mocker.patch( + "ostorlab.cli.docker_requirements_checker.is_swarm_initialized", + return_value=False, + ) + mocker.patch("time.sleep") + requests_mock.get( + "http+docker://localhost/version", [{"json": {"ApiVersion": "1.35"}}] + ) + requests_mock.get("http+docker://localhost/v1.35/swarm", json={"ID": "1234"}) + mock_swarm_init = requests_mock.post( + "http+docker://localhost/v1.35/swarm/init", status_code=400 + ) + local_runtime_instance = local_runtime.LocalRuntime(run_default_agents=False) + + with pytest.raises(exceptions.OstorlabError): + local_runtime_instance.can_run( + agent_group_definition=definitions.AgentGroupDefinition(agents=[]) + ) + + assert mock_swarm_init.call_count == 3