Skip to content

Commit

Permalink
feat(cli): session pause and resume (#3633)
Browse files Browse the repository at this point in the history
  • Loading branch information
m-alisafaee authored Oct 10, 2023
1 parent b31ade0 commit f4b6480
Show file tree
Hide file tree
Showing 13 changed files with 378 additions and 35 deletions.
14 changes: 14 additions & 0 deletions docs/_static/cheatsheet/cheatsheet.json
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,20 @@
"rp"
]
},
{
"command": "$ renku session pause <name>",
"description": "Pause the specified session.",
"target": [
"rp"
]
},
{
"command": "$ renku session resume <name>",
"description": "Resume the specified paused session.",
"target": [
"rp"
]
},
{
"command": "$ renku session stop <name>",
"description": "Stop the specified session.",
Expand Down
Binary file modified docs/_static/cheatsheet/cheatsheet.pdf
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/cheatsheet_hash
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
ad86ac1d0614ccb692c96e893db4d20d cheatsheet.tex
5316163d742bdb6792ed8bcb35031f6c cheatsheet.tex
c70c179e07f04186ec05497564165f11 sdsc_cheatsheet.cls
2 changes: 1 addition & 1 deletion docs/cheatsheet_json_hash
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1ac51267cefdf4976c29c9d7657063b8 cheatsheet.json
1856fb451165d013777c7c4cdd56e575 cheatsheet.json
22 changes: 20 additions & 2 deletions renku/command/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

from renku.command.command_builder.command import Command
from renku.core.session.session import (
search_hibernating_session_providers,
search_session_providers,
search_sessions,
session_list,
session_open,
session_pause,
session_resume,
session_start,
session_stop,
ssh_setup,
Expand All @@ -37,6 +40,11 @@ def search_session_providers_command():
return Command().command(search_session_providers).require_migration().with_database(write=False)


def search_hibernating_session_providers_command():
"""Get all the session provider names that support hibernation and match a pattern."""
return Command().command(search_hibernating_session_providers).require_migration().with_database(write=False)


def session_list_command():
"""List all the running interactive sessions."""
return Command().command(session_list).with_database(write=False)
Expand All @@ -49,14 +57,24 @@ def session_start_command():

def session_stop_command():
"""Stop a running an interactive session."""
return Command().command(session_stop)
return Command().command(session_stop).with_database(write=False)


def session_open_command():
"""Open a running interactive session."""
return Command().command(session_open)
return Command().command(session_open).with_database(write=False)


def ssh_setup_command():
"""Setup SSH keys for SSH connections to sessions."""
return Command().command(ssh_setup)


def session_pause_command():
"""Pause a running interactive session."""
return Command().command(session_pause).with_database(write=False)


def session_resume_command():
"""Resume a paused session."""
return Command().command(session_resume).with_database(write=False)
8 changes: 7 additions & 1 deletion renku/core/plugin/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import pluggy

from renku.domain_model.session import ISessionProvider
from renku.domain_model.session import IHibernatingSessionProvider, ISessionProvider

hookspec = pluggy.HookspecMarker("renku")

Expand All @@ -41,3 +41,9 @@ def get_supported_session_providers() -> List[ISessionProvider]:
providers = pm.hook.session_provider()

return sorted(providers, key=lambda p: p.priority)


def get_supported_hibernating_session_providers() -> List[IHibernatingSessionProvider]:
"""Returns the currently available interactive session providers that support hibernation."""
providers = get_supported_session_providers()
return [p for p in providers if isinstance(p, IHibernatingSessionProvider)]
82 changes: 74 additions & 8 deletions renku/core/session/renkulab.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@
from renku.core.util.jwt import is_token_expired
from renku.core.util.ssh import SystemSSHConfig
from renku.domain_model.project_context import project_context
from renku.domain_model.session import ISessionProvider, Session, SessionStopStatus
from renku.domain_model.session import IHibernatingSessionProvider, Session, SessionStopStatus

if TYPE_CHECKING:
from renku.core.dataset.providers.models import ProviderParameter


class RenkulabSessionProvider(ISessionProvider):
class RenkulabSessionProvider(IHibernatingSessionProvider):
"""A session provider that uses the notebook service API to launch sessions."""

DEFAULT_TIMEOUT_SECONDS = 300
Expand Down Expand Up @@ -118,7 +118,7 @@ def _wait_for_session_status(
)
if res.status_code == 404 and status == "stopping":
return
if res.status_code == 200 and status != "stopping":
if res.status_code in [200, 204] and status != "stopping":
if res.json().get("status", {}).get("state") == status:
return
sleep(5)
Expand Down Expand Up @@ -210,9 +210,9 @@ def _remote_head_hexsha():

return remote.head

def _send_renku_request(self, req_type: str, *args, **kwargs):
res = getattr(requests, req_type)(*args, **kwargs)
if res.status_code == 401:
def _send_renku_request(self, verb: str, *args, **kwargs):
response = getattr(requests, verb)(*args, **kwargs)
if response.status_code == 401:
# NOTE: Check if logged in to KC but not the Renku UI
token = read_renku_token(endpoint=self._renku_url())
if token and not is_token_expired(token):
Expand All @@ -222,7 +222,7 @@ def _send_renku_request(self, req_type: str, *args, **kwargs):
raise errors.AuthenticationError(
"Please run the renku login command to authenticate with Renku or to refresh your expired credentials."
)
return res
return response

@staticmethod
def _project_name_from_full_project_name(project_name: str) -> str:
Expand Down Expand Up @@ -262,7 +262,7 @@ def find_image(self, image_name: str, config: Optional[Dict[str, Any]]) -> bool:
)

@hookimpl
def session_provider(self) -> ISessionProvider:
def session_provider(self) -> IHibernatingSessionProvider:
"""Supported session provider.
Returns:
Expand Down Expand Up @@ -511,3 +511,69 @@ def session_url(self, session_name: str) -> str:
def force_build_image(self, **kwargs) -> bool:
"""Whether we should force build the image directly or check for an existing image first."""
return self._force_build

def session_pause(self, project_name: str, session_name: Optional[str], **_) -> SessionStopStatus:
"""Pause all sessions (for the given project) or a specific interactive session."""

def pause(session_name: str):
result = self._send_renku_request(
"patch",
f"{self._notebooks_url()}/servers/{session_name}",
headers=self._auth_header(),
json={"state": "hibernated"},
)

self._wait_for_session_status(session_name, "hibernated")

return result

sessions = self.session_list(project_name=project_name)
n_sessions = len(sessions)

if n_sessions == 0:
return SessionStopStatus.NO_ACTIVE_SESSION

if session_name:
response = pause(session_name)
elif n_sessions == 1:
response = pause(sessions[0].name)
else:
return SessionStopStatus.NAME_NEEDED

return SessionStopStatus.SUCCESSFUL if response.status_code == 204 else SessionStopStatus.FAILED

def session_resume(self, project_name: str, session_name: Optional[str], **kwargs) -> bool:
"""Resume a paused session.
Args:
project_name(str): Renku project name.
session_name(Optional[str]): The unique id of the interactive session.
"""
sessions = self.session_list(project_name="")
system_config = SystemSSHConfig()
name = self._project_name_from_full_project_name(project_name)
ssh_prefix = f"{system_config.renku_host}-{name}-"

if not session_name:
if len(sessions) == 1:
session_name = sessions[0].name
else:
return False
else:
if session_name.startswith(ssh_prefix):
# NOTE: User passed in ssh connection name instead of session id by accident
session_name = session_name.replace(ssh_prefix, "", 1)

if not any(s.name == session_name for s in sessions):
return False

self._send_renku_request(
"patch",
f"{self._notebooks_url()}/servers/{session_name}",
headers=self._auth_header(),
json={"state": "running"},
)

self._wait_for_session_status(session_name, "running")

return True
111 changes: 109 additions & 2 deletions renku/core/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@

from renku.core import errors
from renku.core.config import get_value
from renku.core.plugin.session import get_supported_session_providers
from renku.core.plugin.session import get_supported_hibernating_session_providers, get_supported_session_providers
from renku.core.session.utils import get_image_repository_host, get_renku_project_name
from renku.core.util import communication
from renku.core.util.os import safe_read_yaml
from renku.core.util.ssh import SystemSSHConfig, generate_ssh_keys
from renku.domain_model.session import ISessionProvider, Session, SessionStopStatus
from renku.domain_model.session import IHibernatingSessionProvider, ISessionProvider, Session, SessionStopStatus


def _safe_get_provider(provider: str) -> ISessionProvider:
Expand Down Expand Up @@ -80,6 +80,22 @@ def search_session_providers(name: str) -> List[str]:
return [p.name for p in get_supported_session_providers() if p.name.lower().startswith(name)]


@validate_arguments(config=dict(arbitrary_types_allowed=True))
def search_hibernating_session_providers(name: str) -> List[str]:
"""Get all session providers that support hibernation and their name starts with the given name.
Args:
name(str): The name to search for.
Returns:
All session providers whose name starts with ``name``.
"""
from renku.core.plugin.session import get_supported_hibernating_session_providers

name = name.lower()
return [p.name for p in get_supported_hibernating_session_providers() if p.name.lower().startswith(name)]


@validate_arguments(config=dict(arbitrary_types_allowed=True))
def session_list(*, provider: Optional[str] = None) -> SessionList:
"""List interactive sessions.
Expand Down Expand Up @@ -358,3 +374,94 @@ def ssh_setup(existing_key: Optional[Path] = None, force: bool = False):
"This command does not add any public SSH keys to your project. "
"Keys have to be added manually or by using the 'renku session start' command with the '--ssh' flag."
)


@validate_arguments(config=dict(arbitrary_types_allowed=True))
def session_pause(session_name: Optional[str], provider: Optional[str] = None, **kwargs):
"""Pause an interactive session.
Args:
session_name(Optional[str]): Name of the session.
provider(Optional[str]): Name of the session provider to use.
"""

def pause(session_provider: IHibernatingSessionProvider) -> SessionStopStatus:
try:
return session_provider.session_pause(project_name=project_name, session_name=session_name)
except errors.RenkulabSessionGetUrlError:
if provider:
raise
return SessionStopStatus.FAILED

project_name = get_renku_project_name()

if provider:
session_provider = _safe_get_provider(provider)
if session_provider is None:
raise errors.ParameterError(f"Provider '{provider}' not found")
elif not isinstance(session_provider, IHibernatingSessionProvider):
raise errors.ParameterError(f"Provider '{provider}' doesn't support pausing sessions")
providers = [session_provider]
else:
providers = get_supported_hibernating_session_providers()

session_message = f"session {session_name}" if session_name else "session"
statues = []
warning_messages = []
with communication.busy(msg=f"Waiting for {session_message} to pause..."):
for session_provider in sorted(providers, key=lambda p: p.priority):
try:
status = pause(session_provider) # type: ignore
except errors.RenkuException as e:
warning_messages.append(f"Cannot pause sessions in provider '{session_provider.name}': {e}")
else:
statues.append(status)

# NOTE: The given session name was stopped; don't continue
if session_name and status == SessionStopStatus.SUCCESSFUL:
break

if warning_messages:
for message in warning_messages:
communication.warn(message)

if not statues:
return
elif all(s == SessionStopStatus.NO_ACTIVE_SESSION for s in statues):
raise errors.ParameterError("There are no running sessions.")
elif session_name and not any(s == SessionStopStatus.SUCCESSFUL for s in statues):
raise errors.ParameterError(f"Could not find '{session_name}' among the running sessions.")
elif not session_name and not any(s == SessionStopStatus.SUCCESSFUL for s in statues):
raise errors.ParameterError("Session name is missing")


@validate_arguments(config=dict(arbitrary_types_allowed=True))
def session_resume(session_name: Optional[str], provider: Optional[str] = None, **kwargs):
"""Resume a paused session.
Args:
session_name(Optional[str]): Name of the session.
provider(Optional[str]): Name of the session provider to use.
"""
project_name = get_renku_project_name()

if provider:
session_provider = _safe_get_provider(provider)
if session_provider is None:
raise errors.ParameterError(f"Provider '{provider}' not found")
elif not isinstance(session_provider, IHibernatingSessionProvider):
raise errors.ParameterError(f"Provider '{provider}' doesn't support pausing/resuming sessions")
providers = [session_provider]
else:
providers = get_supported_hibernating_session_providers()

session_message = f"session {session_name}" if session_name else "session"
with communication.busy(msg=f"Waiting for {session_message} to resume..."):
for session_provider in providers:
if session_provider.session_resume(project_name, session_name, **kwargs): # type: ignore
return

if session_name:
raise errors.ParameterError(f"Could not find '{session_name}' among the sessions.")
else:
raise errors.ParameterError("Session name is missing")
5 changes: 5 additions & 0 deletions renku/core/util/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def put(url, *, data=None, files=None, headers=None, params=None):
return _request("put", url=url, data=data, files=files, headers=headers, params=params)


def patch(url, *, json=None, files=None, headers=None, params=None):
"""Send a PATCH request."""
return _request("patch", url=url, json=json, files=files, headers=headers, params=params)


def _request(verb: str, url: str, *, allow_redirects=True, data=None, files=None, headers=None, json=None, params=None):
try:
with _retry() as session:
Expand Down
Loading

0 comments on commit f4b6480

Please sign in to comment.