Skip to content

Commit

Permalink
CLI override for Aerie host version check
Browse files Browse the repository at this point in the history
  • Loading branch information
cartermak committed Dec 3, 2024
1 parent 6a21b71 commit 6a3e8ae
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 13 deletions.
17 changes: 13 additions & 4 deletions src/aerie_cli/aerie_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
"2.18.0"
]

class AerieHostVersionError(RuntimeError):
pass


def process_gateway_response(resp: requests.Response) -> dict:
"""Throw a RuntimeError if the Gateway response is malformed or contains errors
Expand Down Expand Up @@ -261,9 +264,15 @@ def is_auth_enabled(self) -> bool:

return True

def authenticate(self, username: str, password: str = None):
def authenticate(self, username: str, password: str = None, override: bool = False):

self.check_aerie_version()
try:
self.check_aerie_version()
except AerieHostVersionError as e:
if override:
print("Warning: " + e.args[0])
else:
raise

resp = self.session.post(
self.gateway_url + "/auth/login",
Expand Down Expand Up @@ -295,13 +304,13 @@ def check_aerie_version(self) -> None:
except (RuntimeError, KeyError):
# If the Gateway responded, the route doesn't exist
if resp.text and "Aerie Gateway" in resp.text:
raise RuntimeError("Incompatible Aerie version: host version unknown")
raise AerieHostVersionError("Incompatible Aerie version: host version unknown")

# Otherwise, it could just be a failed connection
raise

if host_version not in COMPATIBLE_AERIE_VERSIONS:
raise RuntimeError(f"Incompatible Aerie version: {host_version}")
raise AerieHostVersionError(f"Incompatible Aerie version: {host_version}")


@define
Expand Down
5 changes: 3 additions & 2 deletions src/aerie_cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def activate_session(
),
role: str = typer.Option(
None, "--role", "-r", help="Specify a non-default role", metavar="ROLE"
)
),
override: bool = typer.Option(False, "--override", help="Override Aerie host version check")
):
"""
Activate a session with an Aerie host using a given configuration
Expand All @@ -102,7 +103,7 @@ def activate_session(

conf = PersistentConfigurationManager.get_configuration_by_name(name)

session = start_session_from_configuration(conf, username)
session = start_session_from_configuration(conf, username, override=override)

if role is not None:
if role in session.aerie_jwt.allowed_roles:
Expand Down
6 changes: 4 additions & 2 deletions src/aerie_cli/utils/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def start_session_from_configuration(
configuration: AerieHostConfiguration,
username: str = None,
password: str = None,
secret_post_vars: Dict[str, str] = None
secret_post_vars: Dict[str, str] = None,
override: bool = False
):
"""Start and authenticate an Aerie Host session, with prompts if necessary
Expand All @@ -136,6 +137,7 @@ def start_session_from_configuration(
username (str, optional): Aerie username.
password (str, optional): Aerie password.
secret_post_vars (Dict[str, str], optional): Optionally provide values for some or all secret post request variable values. Defaults to None.
override (bool, optional): Override Aerie host version check. Defaults to False.
Returns:
AerieHost:
Expand All @@ -162,6 +164,6 @@ def start_session_from_configuration(
if password is None and hs.is_auth_enabled():
password = typer.prompt("Aerie Password", hide_input=True)

hs.authenticate(username, password)
hs.authenticate(username, password, override)

return hs
46 changes: 41 additions & 5 deletions tests/unit_tests/test_aerie_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
import pytest
import requests

from aerie_cli.aerie_host import AerieHost, COMPATIBLE_AERIE_VERSIONS
from aerie_cli.aerie_host import AerieHost, COMPATIBLE_AERIE_VERSIONS, AerieJWT


class MockJWT:
def __init__(self, *args, **kwargs):
self.default_role = 'viewer'

class MockResponse:
def __init__(self, json: Dict, text: str, ok: bool) -> None:
def __init__(self, json: Dict, text: str = None, ok: bool = True) -> None:
self.json_data = json
self.text = text
self.ok = ok
Expand Down Expand Up @@ -42,15 +46,47 @@ def test_check_aerie_version():
aerie_host.check_aerie_version()


def test_check_invalid_version():
aerie_host = get_mock_aerie_host(json={"version": "1.0.0"})
def test_authenticate_invalid_version(capsys, monkeypatch):
ah = AerieHost("", "")

def mock_get(*_, **__):
return MockResponse({"version": "1.0.0"})
def mock_post(*_, **__):
return MockResponse({"token": ""})
def mock_check_auth(*_, **__):
return True

monkeypatch.setattr(requests.Session, "get", mock_get)
monkeypatch.setattr(requests.Session, "post", mock_post)
monkeypatch.setattr(AerieHost, "check_auth", mock_check_auth)
monkeypatch.setattr(AerieJWT, "__init__", MockJWT.__init__)

with pytest.raises(RuntimeError) as e:
aerie_host.check_aerie_version()
ah.authenticate("")

assert "Incompatible Aerie version: 1.0.0" in str(e.value)


def test_authenticate_invalid_version_override(capsys, monkeypatch):
ah = AerieHost("", "")

def mock_get(*_, **__):
return MockResponse({"version": "1.0.0"})
def mock_post(*_, **__):
return MockResponse({"token": ""})
def mock_check_auth(*_, **__):
return True

monkeypatch.setattr(requests.Session, "get", mock_get)
monkeypatch.setattr(requests.Session, "post", mock_post)
monkeypatch.setattr(AerieHost, "check_auth", mock_check_auth)
monkeypatch.setattr(AerieJWT, "__init__", MockJWT.__init__)

ah.authenticate("", override=True)

assert capsys.readouterr().out == "Warning: Incompatible Aerie version: 1.0.0\n"


def test_no_version_endpoint():
aerie_host = get_mock_aerie_host(text="blah Aerie Gateway blah", ok=True)

Expand Down

0 comments on commit 6a3e8ae

Please sign in to comment.