Skip to content

Commit

Permalink
feat: job decorator (#729)
Browse files Browse the repository at this point in the history
Co-authored-by: Matthew Beach <[email protected]>
  • Loading branch information
ajberdy and mbeach-aws authored Oct 16, 2023
1 parent c4f07f5 commit 71ef0ec
Show file tree
Hide file tree
Showing 24 changed files with 1,439 additions and 161 deletions.
49 changes: 49 additions & 0 deletions examples/hybrid_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from braket.aws import AwsDevice
from braket.circuits import Circuit, FreeParameter, Observable
from braket.devices import Devices
from braket.jobs import get_job_device_arn, hybrid_job
from braket.jobs.metrics import log_metric


@hybrid_job(device=Devices.Amazon.SV1, wait_until_complete=True)
def run_hybrid_job(num_tasks=1):
# declare AwsDevice within the hybrid job
device = AwsDevice(get_job_device_arn())

# create a parametric circuit
circ = Circuit()
circ.rx(0, FreeParameter("theta"))
circ.cnot(0, 1)
circ.expectation(observable=Observable.X(), target=0)

# initial parameter
theta = 0.0

for i in range(num_tasks):
# run task, specifying input parameter
task = device.run(circ, shots=100, inputs={"theta": theta})
exp_val = task.result().values[0]

# modify the parameter (e.g. gradient descent)
theta += exp_val

log_metric(metric_name="exp_val", value=exp_val, iteration_number=i)

return {"final_theta": theta, "final_exp_val": exp_val}


job = run_hybrid_job(num_tasks=5)
print(job.result())
56 changes: 56 additions & 0 deletions examples/hybrid_job_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.


from braket.aws import AwsDevice, AwsQuantumJob
from braket.circuits import Circuit, FreeParameter, Observable
from braket.devices import Devices
from braket.jobs import get_job_device_arn, save_job_result
from braket.jobs.metrics import log_metric


def run_hybrid_job(num_tasks: int):
# use the device specified in the hybrid job
device = AwsDevice(get_job_device_arn())

# create a parametric circuit
circ = Circuit()
circ.rx(0, FreeParameter("theta"))
circ.cnot(0, 1)
circ.expectation(observable=Observable.X(), target=0)

# initial parameter
theta = 0.0

for i in range(num_tasks):
# run task, specifying input parameter
task = device.run(circ, shots=100, inputs={"theta": theta})
exp_val = task.result().values[0]

# modify the parameter (e.g. gradient descent)
theta += exp_val

log_metric(metric_name="exp_val", value=exp_val, iteration_number=i)

save_job_result({"final_theta": theta, "final_exp_val": exp_val})


if __name__ == "__main__":
job = AwsQuantumJob.create(
device=Devices.Amazon.SV1, # choose priority device
source_module="hybrid_job_script.py", # specify file or directory with code to run
entry_point="hybrid_job_script:run_hybrid_job", # specify function to run
hyperparameters={"num_tasks": 5},
wait_until_complete=True,
)
print(job.result())
43 changes: 0 additions & 43 deletions examples/job.py

This file was deleted.

1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"backoff",
"boltons",
"boto3>=1.28.53",
"cloudpickle==2.2.1",
"nest-asyncio",
"networkx",
"numpy",
Expand Down
51 changes: 25 additions & 26 deletions src/braket/aws/aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from enum import Enum
from logging import Logger, getLogger
from pathlib import Path
from typing import Any, Dict, List, Union
from typing import Any

import boto3
from botocore.exceptions import ClientError
Expand Down Expand Up @@ -70,26 +70,27 @@ def create(
code_location: str = None,
role_arn: str = None,
wait_until_complete: bool = False,
hyperparameters: Dict[str, Any] = None,
input_data: Union[str, Dict, S3DataSourceConfig] = None,
hyperparameters: dict[str, Any] = None,
input_data: str | dict | S3DataSourceConfig = None,
instance_config: InstanceConfig = None,
distribution: str = None,
stopping_condition: StoppingCondition = None,
output_data_config: OutputDataConfig = None,
copy_checkpoints_from_job: str = None,
checkpoint_config: CheckpointConfig = None,
aws_session: AwsSession = None,
tags: Dict[str, str] = None,
tags: dict[str, str] = None,
logger: Logger = getLogger(__name__),
) -> AwsQuantumJob:
"""Creates a hybrid job by invoking the Braket CreateJob API.
Args:
device (str): ARN for the AWS device which is primarily accessed for the execution
of this hybrid job. Alternatively, a string of the format
"local:<provider>/<simulator>" for using a local simulator for the hybrid job.
This string will be available as the environment variable `AMZN_BRAKET_DEVICE_ARN`
inside the hybrid job container when using a Braket container.
device (str): Device ARN of the QPU device that receives priority quantum
task queueing once the hybrid job begins running. Each QPU has a separate hybrid
jobs queue so that only one hybrid job is running at a time. The device string is
accessible in the hybrid job instance as the environment variable
"AMZN_BRAKET_DEVICE_ARN". When using embedded simulators, you may provide the device
argument as a string of the form: "local:<provider>/<simulator_name>".
source_module (str): Path (absolute, relative or an S3 URI) to a python module to be
tarred and uploaded. If `source_module` is an S3 URI, it must point to a
Expand Down Expand Up @@ -120,22 +121,22 @@ def create(
This would tail the hybrid job logs as it waits. Otherwise `False`.
Default: `False`.
hyperparameters (Dict[str, Any]): Hyperparameters accessible to the hybrid job.
The hyperparameters are made accessible as a Dict[str, str] to the hybrid job.
hyperparameters (dict[str, Any]): Hyperparameters accessible to the hybrid job.
The hyperparameters are made accessible as a dict[str, str] to the hybrid job.
For convenience, this accepts other types for keys and values, but `str()`
is called to convert them before being passed on. Default: None.
input_data (Union[str, Dict, S3DataSourceConfig]): Information about the training
input_data (str | dict | S3DataSourceConfig): Information about the training
data. Dictionary maps channel names to local paths or S3 URIs. Contents found
at any local paths will be uploaded to S3 at
f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}. If a local
path, S3 URI, or S3DataSourceConfig is provided, it will be given a default
channel name "input".
Default: {}.
instance_config (InstanceConfig): Configuration of the instances to be used
to execute the hybrid job. Default: InstanceConfig(instanceType='ml.m5.large',
instanceCount=1, volumeSizeInGB=30).
instance_config (InstanceConfig): Configuration of the instance(s) for running the
classical code for the hybrid job. Default:
`InstanceConfig(instanceType='ml.m5.large', instanceCount=1, volumeSizeInGB=30)`.
distribution (str): A str that specifies how the hybrid job should be distributed.
If set to "data_parallel", the hyperparameters for the hybrid job will be set
Expand Down Expand Up @@ -165,7 +166,7 @@ def create(
aws_session (AwsSession): AwsSession for connecting to AWS Services.
Default: AwsSession()
tags (Dict[str, str]): Dict specifying the key-value pairs for tagging this hybrid job.
tags (dict[str, str]): Dict specifying the key-value pairs for tagging this hybrid job.
Default: {}.
logger (Logger): Logger object with which to write logs, such as quantum task statuses
Expand Down Expand Up @@ -385,7 +386,7 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None:
elif self.state() in AwsQuantumJob.TERMINAL_STATES:
log_state = AwsQuantumJob.LogState.JOB_COMPLETE

def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
def metadata(self, use_cached_value: bool = False) -> dict[str, Any]:
"""Gets the hybrid job metadata defined in Amazon Braket.
Args:
Expand All @@ -394,7 +395,7 @@ def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
`GetJob` is called to retrieve the metadata. If `False`, always calls
`GetJob`, which also updates the cached value. Default: `False`.
Returns:
Dict[str, Any]: Dict that specifies the hybrid job metadata defined in Amazon Braket.
dict[str, Any]: Dict that specifies the hybrid job metadata defined in Amazon Braket.
"""
if not use_cached_value or not self._metadata:
self._metadata = self._aws_session.get_job(self._arn)
Expand All @@ -404,7 +405,7 @@ def metrics(
self,
metric_type: MetricType = MetricType.TIMESTAMP,
statistic: MetricStatistic = MetricStatistic.MAX,
) -> Dict[str, List[Any]]:
) -> dict[str, list[Any]]:
"""Gets all the metrics data, where the keys are the column names, and the values are a list
containing the values in each row. For example, the table:
timestamp energy
Expand All @@ -421,7 +422,7 @@ def metrics(
when there is a conflict. Default: MetricStatistic.MAX.
Returns:
Dict[str, List[Any]] : The metrics data.
dict[str, list[Any]] : The metrics data.
"""
fetcher = CwlInsightsMetricsFetcher(self._aws_session)
metadata = self.metadata(True)
Expand Down Expand Up @@ -450,7 +451,7 @@ def result(
self,
poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Retrieves the hybrid job result persisted using save_job_result() function.
Args:
Expand All @@ -460,7 +461,7 @@ def result(
Default: 5 seconds.
Returns:
Dict[str, Any]: Dict specifying the job results.
dict[str, Any]: Dict specifying the job results.
Raises:
RuntimeError: if hybrid job is in a FAILED or CANCELLED state.
Expand All @@ -480,7 +481,7 @@ def result(
return AwsQuantumJob._read_and_deserialize_results(temp_dir, job_name)

@staticmethod
def _read_and_deserialize_results(temp_dir: str, job_name: str) -> Dict[str, Any]:
def _read_and_deserialize_results(temp_dir: str, job_name: str) -> dict[str, Any]:
return load_job_result(Path(temp_dir, job_name, AwsQuantumJob.RESULTS_FILENAME))

def download_result(
Expand Down Expand Up @@ -565,9 +566,7 @@ def __hash__(self) -> int:
return hash(self.arn)

@staticmethod
def _initialize_session(
session_value: AwsSession, device: AwsDevice, logger: Logger
) -> AwsSession:
def _initialize_session(session_value: AwsSession, device: str, logger: Logger) -> AwsSession:
aws_session = session_value or AwsSession()
if device.startswith("local:"):
return aws_session
Expand Down
36 changes: 36 additions & 0 deletions src/braket/aws/aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import os.path
import re
from functools import cache
from pathlib import Path
from typing import Any, NamedTuple, Optional

Expand Down Expand Up @@ -825,3 +826,38 @@ def copy_session(
# Preserve user_agent information
copied_session._braket_user_agents = self._braket_user_agents
return copied_session

@cache
def get_full_image_tag(self, image_uri: str) -> str:
"""
Get verbose image tag from image uri.
Args:
image_uri (str): Image uri to get tag for.
Returns:
str: Verbose image tag for given image.
"""
registry = image_uri.split(".")[0]
repository, tag = image_uri.split("/")[-1].split(":")

# get image digest of latest image
digest = self.ecr_client.batch_get_image(
registryId=registry,
repositoryName=repository,
imageIds=[{"imageTag": tag}],
)["images"][0]["imageId"]["imageDigest"]

# get all images matching digest (same image, different tags)
images = self.ecr_client.batch_get_image(
registryId=registry,
repositoryName=repository,
imageIds=[{"imageDigest": digest}],
)["images"]

# find the tag with the python version info
for image in images:
if re.search(r"py\d\d+", tag := image["imageId"]["imageTag"]):
return tag

raise ValueError("Full image tag missing.")
2 changes: 2 additions & 0 deletions src/braket/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from braket.jobs.data_persistence import ( # noqa: F401
load_job_checkpoint,
load_job_result,
save_job_checkpoint,
save_job_result,
)
Expand All @@ -31,4 +32,5 @@
get_job_name,
get_results_dir,
)
from braket.jobs.hybrid_job import hybrid_job # noqa: F401
from braket.jobs.image_uris import Framework, retrieve_image # noqa: F401
Loading

0 comments on commit 71ef0ec

Please sign in to comment.