Skip to content

Commit

Permalink
feat: add queue position to the logs for tasks and jobs (#821)
Browse files Browse the repository at this point in the history
* feat: add logging for queue depth
  • Loading branch information
AbeCoull authored Jan 22, 2024
1 parent 4b5f690 commit 7df9aa7
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 7 deletions.
19 changes: 15 additions & 4 deletions src/braket/aws/aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def create(
aws_session: AwsSession | None = None,
tags: dict[str, str] | None = None,
logger: Logger = getLogger(__name__),
quiet: bool = False,
reservation_arn: str | None = None,
) -> AwsQuantumJob:
"""Creates a hybrid job by invoking the Braket CreateJob API.
Expand Down Expand Up @@ -176,6 +177,9 @@ def create(
while waiting for quantum task to be in a terminal state. Default is
`getLogger(__name__)`
quiet (bool): Sets the verbosity of the logger to low and does not report queue
position. Default is `False`.
reservation_arn (str | None): the reservation window arn provided by Braket
Direct to reserve exclusive usage for the device to run the hybrid job on.
Default: None.
Expand Down Expand Up @@ -210,23 +214,26 @@ def create(
)

job_arn = aws_session.create_job(**create_job_kwargs)
job = AwsQuantumJob(job_arn, aws_session)
job = AwsQuantumJob(job_arn, aws_session, quiet)

if wait_until_complete:
print(f"Initializing Braket Job: {job_arn}")
job.logs(wait=True)

return job

def __init__(self, arn: str, aws_session: AwsSession | None = None):
def __init__(self, arn: str, aws_session: AwsSession | None = None, quiet: bool = False):
"""
Args:
arn (str): The ARN of the hybrid job.
aws_session (AwsSession | None): The `AwsSession` for connecting to AWS services.
Default is `None`, in which case an `AwsSession` object will be created with the
region of the hybrid job.
quiet (bool): Sets the verbosity of the logger to low and does not report queue
position. Default is `False`.
"""
self._arn: str = arn
self._quiet = quiet
if aws_session:
if not self._is_valid_aws_session_region_for_job_arn(aws_session, arn):
raise ValueError(
Expand Down Expand Up @@ -371,10 +378,11 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None:
instance_count = self.metadata(use_cached_value=True)["instanceConfig"]["instanceCount"]
has_streams = False
color_wrap = logs.ColorWrap()
previous_state = self.state()

while True:
time.sleep(poll_interval_seconds)

current_state = self.state()
has_streams = logs.flush_log_streams(
self._aws_session,
log_group,
Expand All @@ -384,14 +392,17 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None:
instance_count,
has_streams,
color_wrap,
[previous_state, current_state],
self.queue_position().queue_position if not self._quiet else None,
)
previous_state = current_state

if log_state == AwsQuantumJob.LogState.COMPLETE:
break

if log_state == AwsQuantumJob.LogState.JOB_COMPLETE:
log_state = AwsQuantumJob.LogState.COMPLETE
elif self.state() in AwsQuantumJob.TERMINAL_STATES:
elif current_state in AwsQuantumJob.TERMINAL_STATES:
log_state = AwsQuantumJob.LogState.JOB_COMPLETE

def metadata(self, use_cached_value: bool = False) -> dict[str, Any]:
Expand Down
14 changes: 14 additions & 0 deletions src/braket/aws/aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def create(
tags: dict[str, str] | None = None,
inputs: dict[str, float] | None = None,
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None = None,
quiet: bool = False,
reservation_arn: str | None = None,
*args,
**kwargs,
Expand Down Expand Up @@ -152,6 +153,9 @@ def create(
a `PulseSequence`.
Default: None.
quiet (bool): Sets the verbosity of the logger to low and does not report queue
position. Default is `False`.
reservation_arn (str | None): The reservation ARN provided by Braket Direct
to reserve exclusive usage for the device to run the quantum task on.
Note: If you are creating tasks in a job that itself was created with a reservation ARN,
Expand Down Expand Up @@ -215,6 +219,7 @@ def create(
disable_qubit_rewiring,
inputs,
gate_definitions=gate_definitions,
quiet=quiet,
*args,
**kwargs,
)
Expand All @@ -226,6 +231,7 @@ def __init__(
poll_timeout_seconds: float = DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
logger: Logger = getLogger(__name__),
quiet: bool = False,
):
"""
Args:
Expand All @@ -238,6 +244,8 @@ def __init__(
logger (Logger): Logger object with which to write logs, such as quantum task statuses
while waiting for quantum task to be in a terminal state. Default is
`getLogger(__name__)`
quiet (bool): Sets the verbosity of the logger to low and does not report queue
position. Default is `False`.
Examples:
>>> task = AwsQuantumTask(arn='task_arn')
Expand All @@ -259,6 +267,7 @@ def __init__(
self._poll_interval_seconds = poll_interval_seconds

self._logger = logger
self._quiet = quiet

self._metadata: dict[str, Any] = {}
self._result: Union[
Expand Down Expand Up @@ -477,6 +486,11 @@ async def _wait_for_completion(
while (time.time() - start_time) < self._poll_timeout_seconds:
# Used cached metadata if cached status is terminal
task_status = self._update_status_if_nonterminal()
if not self._quiet and task_status == "QUEUED":
queue = self.queue_position()
self._logger.debug(
f"Task is in {queue.queue_type} queue position: {queue.queue_position}"
)
self._logger.debug(f"Task {self._arn}: task status {task_status}")
if task_status in AwsQuantumTask.RESULTS_READY_STATES:
return self._download_result()
Expand Down
7 changes: 6 additions & 1 deletion src/braket/jobs/hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def hybrid_job(
aws_session: AwsSession | None = None,
tags: dict[str, str] | None = None,
logger: Logger = getLogger(__name__),
quiet: bool | None = None,
reservation_arn: str | None = None,
) -> Callable:
"""Defines a hybrid job by decorating the entry point function. The job will be created
Expand All @@ -71,7 +72,7 @@ def hybrid_job(
The job created will be a `LocalQuantumJob` when `local` is set to `True`, otherwise an
`AwsQuantumJob`. The following parameters will be ignored when running a job with
`local` set to `True`: `wait_until_complete`, `instance_config`, `distribution`,
`copy_checkpoints_from_job`, `stopping_condition`, `tags`, and `logger`.
`copy_checkpoints_from_job`, `stopping_condition`, `tags`, `logger`, and `quiet`.
Args:
device (str | None): Device ARN of the QPU device that receives priority quantum
Expand Down Expand Up @@ -153,6 +154,9 @@ def hybrid_job(
logger (Logger): Logger object with which to write logs, such as task statuses
while waiting for task to be in a terminal state. Default: `getLogger(__name__)`
quiet (bool | None): Sets the verbosity of the logger to low and does not report queue
position. Default is `False`.
reservation_arn (str | None): the reservation window arn provided by Braket
Direct to reserve exclusive usage for the device to run the hybrid job on.
Default: None.
Expand Down Expand Up @@ -210,6 +214,7 @@ def job_wrapper(*args, **kwargs) -> Callable:
"output_data_config": output_data_config,
"aws_session": aws_session,
"tags": tags,
"quiet": quiet,
"reservation_arn": reservation_arn,
}
for key, value in optional_args.items():
Expand Down
13 changes: 11 additions & 2 deletions src/braket/jobs/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# Support for reading logs
#
##############################################################################
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

from botocore.exceptions import ClientError

Expand Down Expand Up @@ -155,7 +155,7 @@ def log_stream(
yield ev


def flush_log_streams(
def flush_log_streams( # noqa C901
aws_session: AwsSession,
log_group: str,
stream_prefix: str,
Expand All @@ -164,6 +164,8 @@ def flush_log_streams(
stream_count: int,
has_streams: bool,
color_wrap: ColorWrap,
state: list[str],
queue_position: Optional[str] = None,
) -> bool:
"""Flushes log streams to stdout.
Expand All @@ -183,6 +185,9 @@ def flush_log_streams(
been found. This value is possibly updated and returned at the end of execution.
color_wrap (ColorWrap): An instance of ColorWrap to potentially color-wrap print statements
from different streams.
state (list[str]): The previous and current state of the job.
queue_position (Optional[str]): The current queue position. This is not passed in if the job
is run with `quiet=True`
Returns:
bool: Returns 'True' if any streams have been flushed.
Expand Down Expand Up @@ -225,6 +230,10 @@ def flush_log_streams(
positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
else:
positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1)
elif queue_position is not None and state[1] == "QUEUED":
print(f"Job queue position: {queue_position}", end="\n", flush=True)
elif state[0] != state[1] and state[1] == "RUNNING" and queue_position is not None:
print("Running:", end="\n", flush=True)
else:
print(".", end="", flush=True)
return has_streams
61 changes: 61 additions & 0 deletions test/unit_tests/braket/aws/test_aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def _get_job_response(**kwargs):
"jobArn": "arn:aws:braket:us-west-2:875981177017:job/job-test-20210628140446",
"jobName": "job-test-20210628140446",
"outputDataConfig": {"s3Path": "s3://amazon-braket-jobs/job-path/data"},
"queueInfo": {"position": "1", "queue": "JOBS_QUEUE"},
"roleArn": "arn:aws:iam::875981177017:role/AmazonBraketJobRole",
"status": "RUNNING",
"stoppingCondition": {"maxRuntimeInSeconds": 1200},
Expand Down Expand Up @@ -720,6 +721,14 @@ def test_logs(
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
)
quantum_job._aws_session.describe_log_streams.side_effect = log_stream_responses
Expand All @@ -740,6 +749,48 @@ def test_logs(
)


def test_logs_queue_progress(
    quantum_job,
    generate_get_job_response,
    log_events_responses,
    log_stream_responses,
    capsys,
):
    """The queue position and the 'Running:' banner are printed while streaming job logs."""
    queue_info = {"queue": "JOBS_QUEUE", "position": "1"}
    # Three polls per QUEUED/RUNNING state, then enough COMPLETED responses
    # to drain the remaining status checks made by logs().
    statuses = ["QUEUED"] * 3 + ["RUNNING"] * 3 + ["COMPLETED"] * 6
    quantum_job._aws_session.get_job.side_effect = tuple(
        generate_get_job_response(status=status, queue_info=queue_info)
        if status == "QUEUED"
        else generate_get_job_response(status=status)
        for status in statuses
    )
    quantum_job._aws_session.describe_log_streams.side_effect = log_stream_responses
    quantum_job._aws_session.get_log_events.side_effect = log_events_responses

    quantum_job.logs(wait=True, poll_interval_seconds=0)

    expected_lines = [
        f"Job queue position: {queue_info['position']}",
        "Running:",
        "",
        "hi there #1",
        "hi there #2",
        "hi there #2a",
        "hi there #3",
        "",
    ]
    assert capsys.readouterr().out == "\n".join(expected_lines)


@patch.dict("os.environ", {"JPY_PARENT_PID": "True"})
def test_logs_multiple_instances(
quantum_job,
Expand All @@ -753,6 +804,15 @@ def test_logs_multiple_instances(
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
)
log_stream_responses[-1]["logStreams"].append({"logStreamName": "stream-2"})
Expand Down Expand Up @@ -818,6 +878,7 @@ def get_log_events(log_group, log_stream, start_time, start_from_head, next_toke

def test_logs_error(quantum_job, generate_get_job_response, capsys):
quantum_job._aws_session.get_job.side_effect = (
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="COMPLETED"),
Expand Down
Loading

0 comments on commit 7df9aa7

Please sign in to comment.