Skip to content

Commit

Permalink
switch to new attach/exec api (#2246)
Browse files Browse the repository at this point in the history
  • Loading branch information
zubenkoivan authored Aug 16, 2021
1 parent d056b04 commit c612e63
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 345 deletions.
237 changes: 96 additions & 141 deletions neuro-cli/src/neuro_cli/ael.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import codecs
import enum
import functools
import logging
import signal
import sys
Expand All @@ -23,10 +22,11 @@
from prompt_toolkit.shortcuts import PromptSession
from typing_extensions import NoReturn

from neuro_sdk import IllegalArgumentError, JobDescription, JobStatus, StdStream
from neuro_sdk import JobDescription, JobStatus, StdStream
from neuro_sdk.jobs import StdStreamError

from .const import EX_IOERR, EX_PLATFORMERROR
from .formatters.jobs import ExecStopProgress, JobStopProgress
from .formatters.jobs import JobStopProgress
from .root import Root
from .utils import AsyncExitStack

Expand Down Expand Up @@ -105,132 +105,100 @@ async def process_logs(
async def process_exec(
root: Root, job: str, cmd: str, tty: bool, *, cluster_name: Optional[str]
) -> NoReturn:
exec_id = await root.client.jobs.exec_create(
job, cmd, tty=tty, cluster_name=cluster_name
)
try:
if tty:
await _exec_tty(root, job, exec_id, cluster_name=cluster_name)
exit_code = await _exec_tty(root, job, cmd, cluster_name=cluster_name)
else:
await _exec_non_tty(root, job, exec_id, cluster_name=cluster_name)
exit_code = await _exec_non_tty(root, job, cmd, cluster_name=cluster_name)
finally:
root.soft_reset_tty()

info = await root.client.jobs.exec_inspect(job, exec_id, cluster_name=cluster_name)
with ExecStopProgress.create(
console=root.console, quiet=root.quiet, job_id=job
) as progress:
while info.running:
await asyncio.sleep(0.2)
info = await root.client.jobs.exec_inspect(
job, exec_id, cluster_name=cluster_name
)
if not progress(info.running):
sys.exit(EX_IOERR)
sys.exit(info.exit_code)
sys.exit(exit_code)


async def _exec_tty(
root: Root, job: str, exec_id: str, *, cluster_name: Optional[str]
) -> None:
root: Root, job: str, cmd: str, *, cluster_name: Optional[str]
) -> int:
loop = asyncio.get_event_loop()
helper = AttachHelper(quiet=True)

stdout = create_output()
h, w = stdout.get_size()

async with root.client.jobs.exec_start(
job, exec_id, cluster_name=cluster_name
async with root.client.jobs.exec(
job,
cmd,
tty=True,
stdin=True,
stdout=True,
stderr=False,
cluster_name=cluster_name,
) as stream:
try:
await root.client.jobs.exec_resize(
job, exec_id, w=w, h=h, cluster_name=cluster_name
)
except IllegalArgumentError:
pass
info = await root.client.jobs.exec_inspect(
job, exec_id, cluster_name=cluster_name
)
if not info.running:
# Exec session is finished
sys.exit(info.exit_code)

tasks = []
tasks.append(loop.create_task(_process_stdin_tty(stream, helper)))
tasks.append(
loop.create_task(_process_stdout_tty(root, stream, stdout, helper))
)
tasks.append(
loop.create_task(
_process_resizing(
functools.partial(
root.client.jobs.exec_resize,
job,
exec_id,
cluster_name=cluster_name,
),
stdout,
)
)
)
tasks.append(
loop.create_task(
_exec_watcher(root, job, exec_id, cluster_name=cluster_name)
)
status = await root.client.jobs.status(job)

if status.status is not JobStatus.RUNNING:
raise ValueError(f"Job {job!r} is not running")

await stream.resize(h=h, w=w)

resize_task = loop.create_task(_process_resizing(stream.resize, stdout))
input_task = loop.create_task(_process_stdin_tty(stream, helper))
output_task = loop.create_task(
_process_stdout_tty(root, stream, stdout, helper)
)

try:
tasks = [resize_task, input_task, output_task]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
finally:
for task in tasks:
await root.cancel_with_logging(task)
await root.cancel_with_logging(resize_task)
await root.cancel_with_logging(input_task)
return await _cancel_exec_output(root, output_task)


async def _exec_non_tty(
root: Root, job: str, exec_id: str, *, cluster_name: Optional[str]
) -> None:
root: Root, job: str, cmd: str, *, cluster_name: Optional[str]
) -> int:
loop = asyncio.get_event_loop()
helper = AttachHelper(quiet=True)

async with root.client.jobs.exec_start(
job, exec_id, cluster_name=cluster_name
async with root.client.jobs.exec(
job,
cmd,
tty=False,
stdin=True,
stdout=True,
stderr=True,
cluster_name=cluster_name,
) as stream:
info = await root.client.jobs.exec_inspect(
job, exec_id, cluster_name=cluster_name
)
if not info.running:
sys.exit(info.exit_code)
status = await root.client.jobs.status(job)

if status.status is not JobStatus.RUNNING:
raise ValueError(f"Job {job!r} is not running")

tasks = []
input_task = None
if root.tty:
tasks.append(loop.create_task(_process_stdin_non_tty(root, stream)))
tasks.append(loop.create_task(_process_stdout_non_tty(root, stream, helper)))
tasks.append(
loop.create_task(
_exec_watcher(root, job, exec_id, cluster_name=cluster_name)
)
)
input_task = loop.create_task(_process_stdin_non_tty(root, stream))
output_task = loop.create_task(_process_stdout_non_tty(root, stream, helper))

try:
tasks = [output_task]
if input_task:
tasks.append(input_task)
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
finally:
for task in tasks:
await root.cancel_with_logging(task)
if input_task:
await root.cancel_with_logging(input_task)
return await _cancel_exec_output(root, output_task)


async def _exec_watcher(
    root: Root, job: str, exec_id: str, *, cluster_name: Optional[str]
) -> None:
    """Poll an exec session and return once it is reported as not running.

    Exceptions raised by the inspect call are swallowed; the next attempt is
    made after a five-second pause.
    """
    while True:
        try:
            info = await root.client.jobs.exec_inspect(
                job, exec_id, cluster_name=cluster_name
            )
        except Exception:
            # Inspection failed: treat the session as still alive and retry.
            still_running = True
        else:
            still_running = info.running
        if not still_running:
            return
        await asyncio.sleep(5)
async def _cancel_exec_output(root: Root, output_task: "asyncio.Task[Any]") -> int:
    """Settle the exec output task and return the exit code to report.

    If the task has already finished by raising ``StdStreamError``, that
    error's ``exit_code`` is returned.  In every other case the task is handed
    to ``root.cancel_with_logging`` and ``EX_PLATFORMERROR`` is returned.
    """
    # Task.exception() raises CancelledError for a cancelled task, so only
    # query it when the task finished without being cancelled.
    if output_task.done() and not output_task.cancelled():
        ex = output_task.exception()
        if isinstance(ex, StdStreamError):
            return ex.exit_code
    await root.cancel_with_logging(output_task)
    return EX_PLATFORMERROR


class RetryAttach(Exception):
Expand Down Expand Up @@ -372,16 +340,10 @@ async def _attach_tty(
logs_printer = loop.create_task(asyncio.sleep(0))

async with root.client.jobs.attach(
job, stdin=True, stdout=True, stderr=True, logs=True, cluster_name=cluster_name
job, tty=True, stdin=True, stdout=True, stderr=False, cluster_name=cluster_name
) as stream:
try:
await root.client.jobs.resize(job, w=w, h=h, cluster_name=cluster_name)
except IllegalArgumentError:
# Job may be finished at this moment.
# Need to check job's status and print logs
# for finished job
pass
status = await root.client.jobs.status(job)

if status.status is not JobStatus.RUNNING:
# Job is finished
await logs_printer
Expand All @@ -390,42 +352,24 @@ async def _attach_tty(
else:
sys.exit(status.history.exit_code)

tasks = []
tasks.append(loop.create_task(_process_stdin_tty(stream, helper)))
tasks.append(
loop.create_task(_process_stdout_tty(root, stream, stdout, helper))
)
tasks.append(
loop.create_task(
_process_resizing(
functools.partial(
root.client.jobs.resize, job, cluster_name=cluster_name
),
stdout,
)
)
await stream.resize(h=h, w=w)

resize_task = loop.create_task(_process_resizing(stream.resize, stdout))
input_task = loop.create_task(_process_stdin_tty(stream, helper))
output_task = loop.create_task(
_process_stdout_tty(root, stream, stdout, helper)
)
tasks.append(loop.create_task(_attach_watcher(root, job)))

try:
tasks = [resize_task, input_task, output_task]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
finally:
for task in tasks:
await root.cancel_with_logging(task)

await root.cancel_with_logging(resize_task)
await root.cancel_with_logging(input_task)
await _cancel_attach_output(root, output_task)
await root.cancel_with_logging(logs_printer)
return helper.action


async def _attach_watcher(root: Root, job: str) -> None:
    """Poll the job status and return once it is no longer ``RUNNING``.

    Exceptions raised by the status call are swallowed; the next attempt is
    made after a five-second pause.
    """
    while True:
        try:
            status = await root.client.jobs.status(job)
        except Exception:
            # Status lookup failed: assume nothing changed and retry.
            finished = False
        else:
            finished = status.status != JobStatus.RUNNING
        if finished:
            return
        await asyncio.sleep(5)
return helper.action


async def _process_resizing(
Expand Down Expand Up @@ -565,27 +509,30 @@ async def _attach_non_tty(
logs_printer = loop.create_task(asyncio.sleep(0))

async with root.client.jobs.attach(
job, stdin=True, stdout=True, stderr=True, logs=True, cluster_name=cluster_name
job, stdin=True, stdout=True, stderr=True, cluster_name=cluster_name
) as stream:
status = await root.client.jobs.status(job)
if status.history.exit_code is not None:
# Wait for logs printing finish before exit
await logs_printer
sys.exit(status.history.exit_code)

tasks = []
input_task = None
if root.tty:
tasks.append(loop.create_task(_process_stdin_non_tty(root, stream)))
tasks.append(loop.create_task(_process_stdout_non_tty(root, stream, helper)))
tasks.append(loop.create_task(_process_ctrl_c(root, job, helper)))
tasks.append(loop.create_task(_attach_watcher(root, job)))
input_task = loop.create_task(_process_stdin_non_tty(root, stream))
output_task = loop.create_task(_process_stdout_non_tty(root, stream, helper))
ctrl_c_task = loop.create_task(_process_ctrl_c(root, job, helper))

try:
tasks = [output_task, ctrl_c_task]
if input_task:
tasks.append(input_task)
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
finally:
for task in tasks:
await root.cancel_with_logging(task)

if input_task:
await root.cancel_with_logging(input_task)
await _cancel_attach_output(root, output_task)
await root.cancel_with_logging(ctrl_c_task)
await root.cancel_with_logging(logs_printer)
return helper.action

Expand Down Expand Up @@ -734,3 +681,11 @@ def on_signal(signum: int, frame: Any) -> None:
return
finally:
signal.signal(signal.SIGINT, prev_signal)


async def _cancel_attach_output(root: Root, output_task: "asyncio.Task[Any]") -> None:
    """Settle the attach output task.

    Nothing more is done when the task has already finished by raising
    ``StdStreamError``.  In every other case the task is handed to
    ``root.cancel_with_logging``.
    """
    # Task.exception() raises CancelledError for a cancelled task, so only
    # query it when the task finished without being cancelled.
    if output_task.done() and not output_task.cancelled():
        if isinstance(output_task.exception(), StdStreamError):
            return
    await root.cancel_with_logging(output_task)
Loading

0 comments on commit c612e63

Please sign in to comment.