Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add logout to funcx endpoint #909

Merged
merged 1 commit into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. A new scriv changelog fragment.
..
.. Uncomment the header that is right (remove the leading dots).
..
New Functionality
-----------------

- Added logout command for funcx-endpoint to revoke cached tokens
55 changes: 54 additions & 1 deletion funcx_endpoint/funcx_endpoint/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from importlib.machinery import SourceFileLoader

import click
from click import ClickException

from funcx.sdk.login_manager import LoginManager

from .endpoint.endpoint import Endpoint
from .logging_config import setup_logging
Expand Down Expand Up @@ -169,6 +172,56 @@ def start_endpoint(*, name: str, endpoint_uuid: str | None):
)


@app.command(name="logout", help="Logout from all endpoints")
@click.option(
"--force",
is_flag=True,
help="Revokes tokens even with currently running endpoints",
)
def logout_endpoints(force: bool):
success, msg = _do_logout_endpoints(force=force)
if not success:
# Raising ClickException is apparently the way to do sys.exit(1)
# and return a non-zero value to the command line
# See https://click.palletsprojects.com/en/8.1.x/exceptions/
if not isinstance(msg, str) or msg is None:
# Generic unsuccessful if no reason was given
msg = "Logout unsuccessful"
raise ClickException(msg)


def _do_logout_endpoints(
force: bool, running_endpoints: dict | None = None
) -> tuple[bool, str | None]:
"""
Logout from all endpoints and remove cached authentication credentials

Returns True, None if logout was successful and tokens were found and revoked
Returns False, error_msg if token revocation was not done
"""

if running_endpoints is None:
running_endpoints = get_cli_endpoint().get_running_endpoints()
tokens_revoked = False
error_msg = None
if running_endpoints and not force:
running_list = ", ".join(running_endpoints.keys())
log.info(
"The following endpoints are currently running: "
+ running_list
+ "\nPlease use logout --force to proceed"
)
error_msg = "Not logging out with running endpoints without --force"
else:
tokens_revoked = LoginManager().logout()
if tokens_revoked:
log.info("Logout succeeded and all cached credentials were revoked")
else:
error_msg = "No cached tokens were found, already logged out?"
log.info(error_msg)
return tokens_revoked, error_msg


def _do_start_endpoint(
*,
name: str,
Expand Down Expand Up @@ -244,7 +297,7 @@ def restart_endpoint(*, name: str, endpoint_uuid: str | None):
def list_endpoints():
"""List all available endpoints"""
endpoint = get_cli_endpoint()
endpoint.list_endpoints()
endpoint.print_endpoint_table()


@app.command("delete")
Expand Down
56 changes: 46 additions & 10 deletions funcx_endpoint/funcx_endpoint/endpoint/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def init_endpoint(self):
except Exception as e:
print(f"[FuncX] Caught exception during registration {e}")

def check_endpoint_json(self, endpoint_json, endpoint_uuid):
@staticmethod
def check_endpoint_json(endpoint_json, endpoint_uuid):
if os.path.exists(endpoint_json):
with open(endpoint_json) as fp:
log.debug("Connection info loaded from prior registration record")
Expand Down Expand Up @@ -433,12 +434,27 @@ def pidfile_cleanup(self, filepath):
os.remove(filepath)
log.info(f"Endpoint <{self.name}> has been cleaned up.")

def list_endpoints(self):
table = texttable.Texttable()

headings = ["Endpoint Name", "Status", "Endpoint ID"]
table.header(headings)

def get_endpoints(self, status_filter=None):
"""
Gets a dictionary that contains information about all locally
known endpoints.

"status" can be one of:
["Running", "Disconnected", "Stopped"]

Example output:
{
"default": {
"status": "Running",
"id": "123abcde-a393-4456-8de5-123456789abc"
},
"my_test_ep": {
"status": "Disconnected",
"id": "xxxxxxxx-xxxx-1234-abcd-xxxxxxxxxxxx"
}
}
"""
endpoint_dict = {}
config_files = glob.glob(os.path.join(self.funcx_dir, "*", "config.py"))
for config_file in config_files:
endpoint_dir = os.path.dirname(config_file)
Expand All @@ -459,7 +475,27 @@ def list_endpoints(self):
else:
status = "Stopped"

table.add_row([endpoint_name, status, endpoint_id])
if status_filter is None or status_filter == status:
endpoint_dict[endpoint_name] = {
"status": status,
"id": endpoint_id,
}
return endpoint_dict

def get_running_endpoints(self):
return self.get_endpoints(status_filter="Running")

def print_endpoint_table(self):
"""
Converts locally configured endpoint list to a text based table
and prints the output.
For example format, see the texttable module
"""
endpoints = self.get_endpoints()
table = texttable.Texttable()
headings = ["Endpoint Name", "Status", "Endpoint ID"]
table.header(headings)

s = table.draw()
print(s)
for endpoint_name, endpoint_info in endpoints.items():
table.add_row([endpoint_name, endpoint_info["status"], endpoint_info["id"]])
print(table.draw())
2 changes: 1 addition & 1 deletion funcx_endpoint/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class FakeLoginManager:
def ensure_logged_in(self) -> None:
...

def logout(self) -> None:
def logout(self) -> bool:
...

def get_auth_client(self) -> globus_sdk.AuthClient:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from unittest.mock import Mock

import pytest
from click.testing import CliRunner

from funcx_endpoint.cli import app
import funcx.sdk.login_manager
from funcx_endpoint.cli import _do_logout_endpoints, app

runner = CliRunner()

Expand Down Expand Up @@ -30,3 +33,44 @@ def test_non_configured_endpoint(mocker):
result = runner.invoke(app, ["start", "newendpoint"])
assert "newendpoint" in result.stdout
assert "not configured" in result.stdout


def test_endpoint_logout(monkeypatch):
# not forced, and no running endpoints
logout_true = Mock(return_value=True)
logout_false = Mock(return_value=False)
monkeypatch.setattr(funcx.sdk.login_manager.LoginManager, "logout", logout_true)
success, msg = _do_logout_endpoints(
False,
running_endpoints={},
)
logout_true.assert_called_once()
assert success

logout_true.reset_mock()

# forced, and no running endpoints
success, msg = _do_logout_endpoints(
True,
running_endpoints={},
)
logout_true.assert_called_once()
assert success

one_running = {
"default": {"status": "Running", "id": "123abcde-a393-4456-8de5-123456789abc"}
}

monkeypatch.setattr(funcx.sdk.login_manager.LoginManager, "logout", logout_false)
# not forced, with running endpoint
success, msg = _do_logout_endpoints(False, running_endpoints=one_running)
logout_false.assert_not_called()
assert not success

logout_true.reset_mock()

monkeypatch.setattr(funcx.sdk.login_manager.LoginManager, "logout", logout_true)
# forced, with running endpoint
success, msg = _do_logout_endpoints(True, running_endpoints=one_running)
logout_true.assert_called_once()
assert success
14 changes: 10 additions & 4 deletions funcx_sdk/funcx/sdk/login_manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,20 @@ def run_login_flow(

do_link_auth_flow(self._token_storage, scopes)

def logout(self) -> None:
def logout(self) -> bool:
"""
Returns True if at least one set of tokens were found and revoked.
"""
khk-globus marked this conversation as resolved.
Show resolved Hide resolved
auth_client = internal_auth_client()
for rs, tokendata in self._token_storage.get_by_resource_server().items():
tokens_revoked = False
for rs, token_data in self._token_storage.get_by_resource_server().items():
for tok_key in ("access_token", "refresh_token"):
token = tokendata[tok_key]
token = token_data[tok_key]
auth_client.oauth2_revoke_token(token)

self._token_storage.remove_tokens_for_resource_server(rs)
tokens_revoked = True

return tokens_revoked

def ensure_logged_in(self) -> None:
data = self._token_storage.get_by_resource_server()
Expand Down
2 changes: 1 addition & 1 deletion funcx_sdk/funcx/sdk/login_manager/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class LoginManagerProtocol(Protocol):
def ensure_logged_in(self) -> None:
...

def logout(self) -> None:
def logout(self) -> bool:
...

def get_auth_client(self) -> globus_sdk.AuthClient:
Expand Down