From 6a3e8aef84c9879125ea4629569fefdc44a0da4a Mon Sep 17 00:00:00 2001 From: Carter Mak Date: Mon, 2 Dec 2024 18:30:08 -0800 Subject: [PATCH] CLI override for Aerie host version check --- src/aerie_cli/aerie_host.py | 17 ++++++++--- src/aerie_cli/app.py | 5 ++-- src/aerie_cli/utils/sessions.py | 6 ++-- tests/unit_tests/test_aerie_host.py | 46 +++++++++++++++++++++++++---- 4 files changed, 61 insertions(+), 13 deletions(-) diff --git a/src/aerie_cli/aerie_host.py b/src/aerie_cli/aerie_host.py index db399df3..e9409939 100644 --- a/src/aerie_cli/aerie_host.py +++ b/src/aerie_cli/aerie_host.py @@ -11,6 +11,9 @@ "2.18.0" ] +class AerieHostVersionError(RuntimeError): + pass + def process_gateway_response(resp: requests.Response) -> dict: """Throw a RuntimeError if the Gateway response is malformed or contains errors @@ -261,9 +264,15 @@ def is_auth_enabled(self) -> bool: return True - def authenticate(self, username: str, password: str = None): + def authenticate(self, username: str, password: str = None, override: bool = False): - self.check_aerie_version() + try: + self.check_aerie_version() + except AerieHostVersionError as e: + if override: + print("Warning: " + e.args[0]) + else: + raise resp = self.session.post( self.gateway_url + "/auth/login", @@ -295,13 +304,13 @@ def check_aerie_version(self) -> None: except (RuntimeError, KeyError): # If the Gateway responded, the route doesn't exist if resp.text and "Aerie Gateway" in resp.text: - raise RuntimeError("Incompatible Aerie version: host version unknown") + raise AerieHostVersionError("Incompatible Aerie version: host version unknown") # Otherwise, it could just be a failed connection raise if host_version not in COMPATIBLE_AERIE_VERSIONS: - raise RuntimeError(f"Incompatible Aerie version: {host_version}") + raise AerieHostVersionError(f"Incompatible Aerie version: {host_version}") @define diff --git a/src/aerie_cli/app.py b/src/aerie_cli/app.py index 6d188d18..0e24d60a 100644 --- a/src/aerie_cli/app.py +++ b/src/aerie_cli/app.py @@ -90,7 +90,8 @@ def activate_session( ), role: str = typer.Option( None, "--role", "-r", help="Specify a non-default role", metavar="ROLE" - ) + ), + override: bool = typer.Option(False, "--override", help="Override Aerie host version check") ): """ Activate a session with an Aerie host using a given configuration @@ -102,7 +103,7 @@ def activate_session( conf = PersistentConfigurationManager.get_configuration_by_name(name) - session = start_session_from_configuration(conf, username) + session = start_session_from_configuration(conf, username, override=override) if role is not None: if role in session.aerie_jwt.allowed_roles: diff --git a/src/aerie_cli/utils/sessions.py b/src/aerie_cli/utils/sessions.py index 6efd41f2..b1e90700 100644 --- a/src/aerie_cli/utils/sessions.py +++ b/src/aerie_cli/utils/sessions.py @@ -118,7 +118,8 @@ def start_session_from_configuration( configuration: AerieHostConfiguration, username: str = None, password: str = None, - secret_post_vars: Dict[str, str] = None + secret_post_vars: Dict[str, str] = None, + override: bool = False ): """Start and authenticate an Aerie Host session, with prompts if necessary @@ -136,6 +137,7 @@ def start_session_from_configuration( username (str, optional): Aerie username. password (str, optional): Aerie password. secret_post_vars (Dict[str, str], optional): Optionally provide values for some or all secret post request variable values. Defaults to None. + override (bool, optional): Override Aerie host version check. Defaults to False. Returns: AerieHost: @@ -162,6 +164,6 @@ def start_session_from_configuration( if password is None and hs.is_auth_enabled(): password = typer.prompt("Aerie Password", hide_input=True) - hs.authenticate(username, password) + hs.authenticate(username, password, override) return hs diff --git a/tests/unit_tests/test_aerie_host.py b/tests/unit_tests/test_aerie_host.py index 8fe7d2c1..fb28ed5b 100644 --- a/tests/unit_tests/test_aerie_host.py +++ b/tests/unit_tests/test_aerie_host.py @@ -2,11 +2,15 @@ import pytest import requests -from aerie_cli.aerie_host import AerieHost, COMPATIBLE_AERIE_VERSIONS +from aerie_cli.aerie_host import AerieHost, COMPATIBLE_AERIE_VERSIONS, AerieJWT +class MockJWT: + def __init__(self, *args, **kwargs): + self.default_role = 'viewer' + class MockResponse: - def __init__(self, json: Dict, text: str, ok: bool) -> None: + def __init__(self, json: Dict, text: str = None, ok: bool = True) -> None: self.json_data = json self.text = text self.ok = ok @@ -42,15 +46,47 @@ def test_check_aerie_version(): aerie_host.check_aerie_version() -def test_check_invalid_version(): - aerie_host = get_mock_aerie_host(json={"version": "1.0.0"}) +def test_authenticate_invalid_version(capsys, monkeypatch): + ah = AerieHost("", "") + + def mock_get(*_, **__): + return MockResponse({"version": "1.0.0"}) + def mock_post(*_, **__): + return MockResponse({"token": ""}) + def mock_check_auth(*_, **__): + return True + + monkeypatch.setattr(requests.Session, "get", mock_get) + monkeypatch.setattr(requests.Session, "post", mock_post) + monkeypatch.setattr(AerieHost, "check_auth", mock_check_auth) + monkeypatch.setattr(AerieJWT, "__init__", MockJWT.__init__) with pytest.raises(RuntimeError) as e: - aerie_host.check_aerie_version() + ah.authenticate("") assert "Incompatible Aerie version: 1.0.0" in str(e.value) +def test_authenticate_invalid_version_override(capsys, monkeypatch): + ah = AerieHost("", "") + + def mock_get(*_, **__): + return MockResponse({"version": "1.0.0"}) + def mock_post(*_, **__): + return MockResponse({"token": ""}) + def mock_check_auth(*_, **__): + return True + + monkeypatch.setattr(requests.Session, "get", mock_get) + monkeypatch.setattr(requests.Session, "post", mock_post) + monkeypatch.setattr(AerieHost, "check_auth", mock_check_auth) + monkeypatch.setattr(AerieJWT, "__init__", MockJWT.__init__) + + ah.authenticate("", override=True) + + assert capsys.readouterr().out == "Warning: Incompatible Aerie version: 1.0.0\n" + + def test_no_version_endpoint(): aerie_host = get_mock_aerie_host(text="blah Aerie Gateway blah", ok=True)