Skip to content

Commit

Permalink
add logout to funcx endpoint
Browse files Browse the repository at this point in the history
modify erorr message

lint formatting
  • Loading branch information
LeiGlobus committed Aug 31, 2022
1 parent 3e23cdc commit 824eaac
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. A new scriv changelog fragment.
..
.. Uncomment the header that is right (remove the leading dots).
..
New Functionality
-----------------

- Added logout command for funcx-endpoint to revoke cached tokens
30 changes: 29 additions & 1 deletion funcx_endpoint/funcx_endpoint/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import click

from funcx.sdk.login_manager.manager import LoginManager

from .endpoint.endpoint import Endpoint
from .logging_config import setup_logging

Expand Down Expand Up @@ -169,6 +171,32 @@ def start_endpoint(*, name: str, endpoint_uuid: str | None):
)


@app.command(name="logout", help="Logout from all endpoints")
@click.option(
"--force/--not-forced",
default=False,
help="--force revokes tokens even with currently running endpoints",
)
@common_options
def logout_endpoint(force: bool):
"""
Logout from all endpoints and remove cached authentication credentials
"""

running_endpoints = get_cli_endpoint().get_running_endpoints()
if running_endpoints and not force:
log.info(
"At least one endpoint is currently running.\n"
+ "Use the --force flag to proceed with logout"
)
else:
tokens_revoked = LoginManager().logout()
if tokens_revoked:
log.info("Logout succeeded and all cached credentials were revoked")
else:
log.info("No cached tokens were found, already logged out?")


def _do_start_endpoint(
*,
name: str,
Expand Down Expand Up @@ -244,7 +272,7 @@ def restart_endpoint(*, name: str, endpoint_uuid: str | None):
def list_endpoints():
"""List all available endpoints"""
endpoint = get_cli_endpoint()
endpoint.list_endpoints()
endpoint.print_endpoint_table()


@app.command("delete")
Expand Down
56 changes: 46 additions & 10 deletions funcx_endpoint/funcx_endpoint/endpoint/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def init_endpoint(self):
except Exception as e:
print(f"[FuncX] Caught exception during registration {e}")

def check_endpoint_json(self, endpoint_json, endpoint_uuid):
@staticmethod
def check_endpoint_json(endpoint_json, endpoint_uuid):
if os.path.exists(endpoint_json):
with open(endpoint_json) as fp:
log.debug("Connection info loaded from prior registration record")
Expand Down Expand Up @@ -433,12 +434,27 @@ def pidfile_cleanup(self, filepath):
os.remove(filepath)
log.info(f"Endpoint <{self.name}> has been cleaned up.")

def list_endpoints(self):
table = texttable.Texttable()

headings = ["Endpoint Name", "Status", "Endpoint ID"]
table.header(headings)

def get_endpoints(self, status_filter=None):
"""
Gets a dictionary that contains information about all locally
known endpoints.
"status" can be one of:
["Running", "Disconnected", "Stopped"]
Example output:
{
"default": {
"status": "Running",
"id": "123abcde-a393-4456-8de5-123456789abc"
},
"my_test_ep": {
"status": "Disconnected",
"id": "xxxxxxxx-xxxx-1234-abcd-xxxxxxxxxxxx"
}
}
"""
endpoint_dict = {}
config_files = glob.glob(os.path.join(self.funcx_dir, "*", "config.py"))
for config_file in config_files:
endpoint_dir = os.path.dirname(config_file)
Expand All @@ -459,7 +475,27 @@ def list_endpoints(self):
else:
status = "Stopped"

table.add_row([endpoint_name, status, endpoint_id])
if status_filter is None or status_filter == status:
endpoint_dict[endpoint_name] = {
"status": status,
"id": endpoint_id,
}
return endpoint_dict

def get_running_endpoints(self):
return self.get_endpoints(status_filter="Running")

def print_endpoint_table(self):
"""
Converts locally configured endpoint list to a text based table
and prints the output.
For example format, see the texttable module
"""
endpoints = self.get_endpoints()
table = texttable.Texttable()
headings = ["Endpoint Name", "Status", "Endpoint ID"]
table.header(headings)

s = table.draw()
print(s)
for endpoint_name, endpoint_info in endpoints.items():
table.add_row([endpoint_name, endpoint_info["status"], endpoint_info["id"]])
print(table.draw())
2 changes: 1 addition & 1 deletion funcx_endpoint/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ class FakeLoginManager:
def ensure_logged_in(self) -> None:
...

def logout(self) -> None:
def logout(self) -> bool:
...

def get_auth_client(self) -> globus_sdk.AuthClient:
Expand Down
14 changes: 10 additions & 4 deletions funcx_sdk/funcx/sdk/login_manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,20 @@ def run_login_flow(

do_link_auth_flow(self._token_storage, scopes)

def logout(self) -> None:
def logout(self) -> bool:
"""
Returns True if at least one set of tokens were found and revoked.
"""
auth_client = internal_auth_client()
for rs, tokendata in self._token_storage.get_by_resource_server().items():
tokens_revoked = False
for rs, token_data in self._token_storage.get_by_resource_server().items():
for tok_key in ("access_token", "refresh_token"):
token = tokendata[tok_key]
token = token_data[tok_key]
auth_client.oauth2_revoke_token(token)

self._token_storage.remove_tokens_for_resource_server(rs)
tokens_revoked = True

return tokens_revoked

def ensure_logged_in(self) -> None:
data = self._token_storage.get_by_resource_server()
Expand Down
2 changes: 1 addition & 1 deletion funcx_sdk/funcx/sdk/login_manager/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class LoginManagerProtocol(Protocol):
def ensure_logged_in(self) -> None:
...

def logout(self) -> None:
def logout(self) -> bool:
...

def get_auth_client(self) -> globus_sdk.AuthClient:
Expand Down

0 comments on commit 824eaac

Please sign in to comment.