diff --git a/src/ostorlab/cli/scan/stop/stop.py b/src/ostorlab/cli/scan/stop/stop.py index 180772baa..7eab12998 100644 --- a/src/ostorlab/cli/scan/stop/stop.py +++ b/src/ostorlab/cli/scan/stop/stop.py @@ -3,6 +3,8 @@ Example of usage: - ostorlab scan list --source=source.""" +from typing import Tuple + import click from ostorlab.cli.scan import scan from ostorlab.cli import console as cli_console @@ -11,14 +13,36 @@ @scan.command() -@click.argument("scan_id", required=True) +@click.argument("scan_ids", nargs=-1, type=int, required=False) +@click.option( + "--all", + "-a", + "stop_all", + is_flag=True, + help="Stop all running scans", + default=False, +) @click.pass_context -def stop(ctx: click.core.Context, scan_id: int) -> None: - """Stop a scan.\n +def stop(ctx: click.core.Context, scan_ids: Tuple[int, ...], stop_all: bool) -> None: + """Stop one or multiple scans.\n Usage:\n - - ostorlab scan --runtime=local stop --id=id + - ostorlab scan --runtime=local stop 4 + - ostorlab scan --runtime=local stop 4 5 6 + - ostorlab scan --runtime=local stop --all """ + if len(scan_ids) == 0 and stop_all is False: + raise click.UsageError("Either provide scan IDs or use --all flag") runtime_instance = ctx.obj["runtime"] - with console.status("Stopping scan"): + if stop_all is True: + scans_list = runtime_instance.list() + ids_to_stop = [s.id for s in scans_list] + if len(ids_to_stop) == 0: + console.warning("No running scans found.") + return + else: + ids_to_stop = list(scan_ids) + + console.info(f"Stopping {len(ids_to_stop)} scan(s).") + for scan_id in ids_to_stop: runtime_instance.stop(scan_id=scan_id) diff --git a/src/ostorlab/runtimes/local/runtime.py b/src/ostorlab/runtimes/local/runtime.py index 0c567854a..dece11a37 100644 --- a/src/ostorlab/runtimes/local/runtime.py +++ b/src/ostorlab/runtimes/local/runtime.py @@ -300,7 +300,9 @@ def stop(self, scan_id: str) -> None: logger.info( "comparing %s and %s", service_labels.get("ostorlab.universe"), scan_id ) - if service_labels.get("ostorlab.universe") == scan_id: + if service_labels.get("ostorlab.universe") is not None and int( + service_labels.get("ostorlab.universe") + ) == int(scan_id): stopped_services.append(service) service.remove() @@ -309,7 +311,7 @@ def stop(self, scan_id: str) -> None: network_labels = network.attrs["Labels"] if ( network_labels is not None - and network_labels.get("ostorlab.universe") == scan_id + and int(network_labels.get("ostorlab.universe")) == scan_id ): logger.info("removing network %s", network_labels) stopped_network.append(network) @@ -318,7 +320,10 @@ def stop(self, scan_id: str) -> None: configs = self._docker_client.configs.list() for config in configs: config_labels = config.attrs["Spec"]["Labels"] - if config_labels.get("ostorlab.universe") == scan_id: + if ( + config_labels.get("ostorlab.universe") is not None + and int(config_labels.get("ostorlab.universe")) == scan_id + ): logger.info("removing config %s", config_labels) stopped_configs.append(config) config.remove() diff --git a/tests/cli/scan/stop/test_scan_stop.py b/tests/cli/scan/stop/test_scan_stop.py index 697baa9f3..d35bb0b8f 100644 --- a/tests/cli/scan/stop/test_scan_stop.py +++ b/tests/cli/scan/stop/test_scan_stop.py @@ -1,12 +1,14 @@ """Tests for scan stop command.""" +from unittest import mock + from click.testing import CliRunner -from ostorlab.cli import rootcli +from pytest_mock import plugin + from ostorlab.apis.runners import authenticated_runner +from ostorlab.cli import rootcli from ostorlab.runtimes.local import runtime as local_runtime -from unittest import mock - def testOstorlabScanStopCLI_whenRuntimeIsRemoteAndScanIdIsValid_stopsScan( httpx_mock, @@ -70,4 +72,106 @@ def testOstorlabScanStopCLI_whenRuntimeIsLocal_callsStopMethodWithProvidedId( runner.invoke(rootcli.rootcli, ["scan", "--runtime=local", "stop", "123456"]) - mock_scan_stop.assert_called_once_with(scan_id="123456") + mock_scan_stop.assert_called_once_with(scan_id=123456) + + +@mock.patch.object(local_runtime.LocalRuntime, "stop") +def testOstorlabScanStopCLI_whenMultipleScanIdsAreProvided_stopsAllProvidedScans( + mock_scan_stop: mock.Mock, mocker: plugin.MockerFixture +) -> None: + """Test ostorlab scan stop command with multiple scan ids. + Should call stop method for each provided scan id. + """ + + mock_scan_stop.return_value = None + mocker.patch("ostorlab.runtimes.local.LocalRuntime.__init__", return_value=None) + runner = CliRunner() + + result = runner.invoke( + rootcli.rootcli, ["scan", "--runtime=local", "stop", "1", "2", "3"] + ) + + assert result.exception is None + assert "Stopping 3 scan(s)" in result.output + assert mock_scan_stop.call_count == 3 + mock_scan_stop.assert_any_call(scan_id=1) + mock_scan_stop.assert_any_call(scan_id=2) + mock_scan_stop.assert_any_call(scan_id=3) + + +@mock.patch.object(local_runtime.LocalRuntime, "stop") +@mock.patch.object(local_runtime.LocalRuntime, "list") +def testOstorlabScanStopCLI_whenStopAllIsUsedAndScansExist_stopsAllScans( + mock_list_scans: mock.Mock, mock_scan_stop: mock.Mock, mocker: plugin.MockerFixture +) -> None: + """Test ostorlab scan stop command with --all flag. + Should stop all running scans. + """ + + mock_list_scans.return_value = [ + mock.Mock(id=101), + mock.Mock(id=102), + mock.Mock(id=103), + ] + mock_scan_stop.return_value = None + mocker.patch("ostorlab.runtimes.local.LocalRuntime.__init__", return_value=None) + runner = CliRunner() + + result = runner.invoke( + rootcli.rootcli, ["scan", "--runtime=local", "stop", "--all"] + ) + + assert result.exception is None + assert "Stopping 3 scan(s)" in result.output + assert mock_scan_stop.call_count == 3 + mock_scan_stop.assert_any_call(scan_id=101) + mock_scan_stop.assert_any_call(scan_id=102) + mock_scan_stop.assert_any_call(scan_id=103) + + +@mock.patch.object(local_runtime.LocalRuntime, "list") +def testOstorlabScanStopCLI_whenStopAllIsUsedAndNoScansExist_showsWarning( + mock_list_scans: mock.Mock, mocker: plugin.MockerFixture +) -> None: + """Test ostorlab scan stop command with --all flag. + Should show warning message when no scans are running. + """ + + mock_list_scans.return_value = [] + mocker.patch("ostorlab.runtimes.local.LocalRuntime.__init__", return_value=None) + runner = CliRunner() + + result = runner.invoke( + rootcli.rootcli, ["scan", "--runtime=local", "stop", "--all"] + ) + + assert result.exception is None + assert "No running scans found" in result.output + + +@mock.patch.object(local_runtime.LocalRuntime, "stop") +@mock.patch.object(local_runtime.LocalRuntime, "list") +def testOstorlabScanStopCLI_whenStopAllWithShorthandIsUsedAndScansExist_stopsAllScans( + mock_list_scans: mock.Mock, mock_scan_stop: mock.Mock, mocker: plugin.MockerFixture +) -> None: + """Test ostorlab scan stop command with --all flag. + Should stop all running scans. + """ + + mock_list_scans.return_value = [ + mock.Mock(id=101), + mock.Mock(id=102), + mock.Mock(id=103), + ] + mock_scan_stop.return_value = None + mocker.patch("ostorlab.runtimes.local.LocalRuntime.__init__", return_value=None) + runner = CliRunner() + + result = runner.invoke(rootcli.rootcli, ["scan", "--runtime=local", "stop", "-a"]) + + assert result.exception is None + assert "Stopping 3 scan(s)" in result.output + assert mock_scan_stop.call_count == 3 + mock_scan_stop.assert_any_call(scan_id=101) + mock_scan_stop.assert_any_call(scan_id=102) + mock_scan_stop.assert_any_call(scan_id=103)