diff --git a/.github/workflows/dependent-tests.yml b/.github/workflows/dependent-tests.yml index 9744ad777..e9da969b1 100644 --- a/.github/workflows/dependent-tests.yml +++ b/.github/workflows/dependent-tests.yml @@ -25,8 +25,8 @@ jobs: - name: Install dependencies run: | pip install --upgrade pip - pip install --upgrade git+https://github.com/aws/amazon-braket-schemas-python@main - pip install --upgrade git+https://github.com/aws/amazon-braket-default-simulator-python@main + pip install --upgrade git+https://github.com/aws/amazon-braket-schemas-python.git@main + pip install --upgrade git+https://github.com/aws/amazon-braket-default-simulator-python.git@main pip install -e . cd .. git clone https://github.com/aws/${{ matrix.dependent }}.git diff --git a/.gitignore b/.gitignore index b554563ef..090c26b8b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ *.swp *.idea *.DS_Store +.vscode/* build_files.tar.gz .ycm_extra_conf.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 35ecb71ba..5a2e9c062 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,19 +10,19 @@ ### Bug Fixes and Other Changes - * Pin Coverage 5.5 +- Pin Coverage 5.5 ## v1.9.4 (2021-10-04) ### Bug Fixes and Other Changes - * fixed a spelling nit +- fixed a spelling nit ## v1.9.3 (2021-10-01) ### Bug Fixes and Other Changes - * rigetti typo +- rigetti typo ## v1.9.2 (2021-09-30) @@ -30,316 +30,316 @@ ### Bug Fixes and Other Changes - * Have tasks that are failed output the failure reason from tas… +- Have tasks that are failed output the failure reason from tas… ## v1.9.0 (2021-09-09) ### Features - * Verbatim boxes +- Verbatim boxes ## v1.8.0 (2021-08-23) ### Features - * Calculate arbitrary observables when `shots=0` +- Calculate arbitrary observables when `shots=0` ### Bug Fixes and Other Changes - * Remove immutable default args +- Remove immutable default args ## v1.7.5 (2021-08-18) ### Bug Fixes and Other Changes - * Add test for local simulator device names +- Add test for local simulator device names ### Documentation Changes - * Add documentation for support +- Add documentation for support ### Testing and Release Infrastructure - * Update copyright notice +- Update copyright notice ## v1.7.4 (2021-08-06) ### Bug Fixes and Other Changes - * Flatten Tensor Products +- Flatten Tensor Products ## v1.7.3.post0 (2021-08-05) ### Documentation Changes - * Modify README.md to include update instructions +- Modify README.md to include update instructions ## v1.7.3 (2021-07-22) ### Bug Fixes and Other Changes - * Add json schema validation for dwave device schemas. +- Add json schema validation for dwave device schemas. 
## v1.7.2 (2021-07-14) ### Bug Fixes and Other Changes - * add json validation for device schema in unit tests +- add json validation for device schema in unit tests ## v1.7.1 (2021-07-02) ### Bug Fixes and Other Changes - * Result Type syntax in IR - * Update test_circuit.py +- Result Type syntax in IR +- Update test_circuit.py ## v1.7.0 (2021-06-25) ### Features - * code Circuit.as_unitary() +- code Circuit.as_unitary() ### Bug Fixes and Other Changes - * allow integral number types that aren't type int +- allow integral number types that aren't type int ## v1.6.5 (2021-06-23) ### Bug Fixes and Other Changes - * Get qubit count without instantiating op - * Require qubit indices to be integers +- Get qubit count without instantiating op +- Require qubit indices to be integers ## v1.6.4 (2021-06-10) ### Bug Fixes and Other Changes - * fallback on empty dict for device level parameters +- fallback on empty dict for device level parameters ## v1.6.3 (2021-06-04) ### Bug Fixes and Other Changes - * use device data to create device level parameter data when creating a… +- use device data to create device level parameter data when creating a… ## v1.6.2 (2021-05-28) ### Bug Fixes and Other Changes - * exclude null values from device parameters for annealing tasks +- exclude null values from device parameters for annealing tasks ## v1.6.1 (2021-05-25) ### Bug Fixes and Other Changes - * copy the boto3 session using the default botocore session +- copy the boto3 session using the default botocore session ## v1.6.0.post0 (2021-05-24) ### Documentation Changes - * Add reference to the noise simulation example notebook +- Add reference to the noise simulation example notebook ## v1.6.0 (2021-05-24) ### Features - * Noise operators +- Noise operators ### Testing and Release Infrastructure - * Use GitHub source for tox tests +- Use GitHub source for tox tests ## v1.5.16 (2021-05-05) ### Bug Fixes and Other Changes - * Added /taskArn to id field in AwsQuantumTask __repr__ +- Added /taskArn to id field in AwsQuantumTask **repr** ### Documentation Changes - * Fix link, typos, order +- Fix link, typos, order ## v1.5.15 (2021-04-08) ### Bug Fixes and Other Changes - * stop manually managing waiting treads in quantum task batch requests +- stop manually managing waiting treads in quantum task batch requests ## v1.5.14 (2021-04-07) ### Bug Fixes and Other Changes - * roll back dwave change - * Dwave roll back - * use device data to create device level parameter data when creating a quantum annealing task +- roll back dwave change +- Dwave roll back +- use device data to create device level parameter data when creating a quantum annealing task ## v1.5.13 (2021-03-26) ### Bug Fixes and Other Changes - * check for task completion before entering async event loop - * remove unneeded get_quantum_task calls +- check for task completion before entering async event loop +- remove unneeded get_quantum_task calls ## v1.5.12 (2021-03-25) ### Bug Fixes and Other Changes - * Update user_agent for AwsSession +- Update user_agent for AwsSession ## v1.5.11 (2021-03-22) ### Bug Fixes and Other Changes - * Fix broken repository links +- Fix broken repository links ## v1.5.10.post2 (2021-03-19) ### Testing and Release Infrastructure - * Run unit tests for dependent packages +- Run unit tests for dependent packages ## v1.5.10.post1 (2021-03-16) ### Documentation Changes - * Remove STS calls from examples +- Remove STS calls from examples ## v1.5.10.post0 (2021-03-11) ### Testing and Release Infrastructure - * Add Python 3.9 +- 
Add Python 3.9 ## v1.5.10 (2021-03-03) ### Bug Fixes and Other Changes - * Don't return NotImplemented for boolean - * Use np.eye for identity - * AngledGate equality checks angles - * Unitary equality checks matrix - * Remove hardcoded device ARNs +- Don't return NotImplemented for boolean +- Use np.eye for identity +- AngledGate equality checks angles +- Unitary equality checks matrix +- Remove hardcoded device ARNs ### Documentation Changes - * Wording changes - * Add note about AWS region in README +- Wording changes +- Add note about AWS region in README ### Testing and Release Infrastructure - * Use main instead of PyPi for build dependencies - * very minor test changes +- Use main instead of PyPi for build dependencies +- very minor test changes ## v1.5.9.post0 (2021-02-22) ### Documentation Changes - * remove unneeded calls to sts from the README - * adjust s3_folder naming in README to clarify which bucket to use +- remove unneeded calls to sts from the README +- adjust s3_folder naming in README to clarify which bucket to use ## v1.5.9 (2021-02-06) ### Bug Fixes and Other Changes - * Search for unknown QPUs +- Search for unknown QPUs ## v1.5.8 (2021-01-29) ### Bug Fixes and Other Changes - * Remove redundant statement, boost coverage - * convert measurements to indices without allocating a high-dimens… +- Remove redundant statement, boost coverage +- convert measurements to indices without allocating a high-dimens… ### Testing and Release Infrastructure - * Raise coverage to 100% +- Raise coverage to 100% ## v1.5.7 (2021-01-27) ### Bug Fixes and Other Changes - * More scalable eigenvalue calculation +- More scalable eigenvalue calculation ## v1.5.6 (2021-01-21) ### Bug Fixes and Other Changes - * ensure AngledGate casts its angle argument to float so it can be… +- ensure AngledGate casts its angle argument to float so it can be… ## v1.5.5 (2021-01-15) ### Bug Fixes and Other Changes - * get correct event loop for task results after running a batch over multiple threads +- get correct event loop for task results after running a batch over multiple threads ## v1.5.4 (2021-01-12) ### Bug Fixes and Other Changes - * remove window check for polling-- revert to polling at all times - * update result_types to use hashing +- remove window check for polling-- revert to polling at all times +- update result_types to use hashing ### Testing and Release Infrastructure - * Enable Codecov +- Enable Codecov ## v1.5.3 (2020-12-31) ### Bug Fixes and Other Changes - * Update range of random qubit in test_qft_iqft_h +- Update range of random qubit in test_qft_iqft_h ## v1.5.2.post0 (2020-12-30) ### Testing and Release Infrastructure - * Add build badge - * Use GitHub Actions for CI +- Add build badge +- Use GitHub Actions for CI ## v1.5.2 (2020-12-22) ### Bug Fixes and Other Changes - * Get regions for QPUs instead of providers - * Do not search for simulators in wrong region +- Get regions for QPUs instead of providers +- Do not search for simulators in wrong region ## v1.5.1 (2020-12-10) ### Bug Fixes and Other Changes - * Use current region for simulators in get_devices +- Use current region for simulators in get_devices ## v1.5.0 (2020-12-04) ### Features - * Always accept identity observable factors +- Always accept identity observable factors ### Documentation Changes - * backticks for batching tasks - * add punctuation to aws_session.py - * fix backticks, missing periods in quantum task docs - * fix backticks, grammar for aws_device.py +- backticks for batching tasks +- add punctuation to 
aws_session.py +- fix backticks, missing periods in quantum task docs +- fix backticks, grammar for aws_device.py ## v1.4.1 (2020-12-04) ### Bug Fixes and Other Changes - * Correct integ test bucket +- Correct integ test bucket ## v1.4.0.post0 (2020-12-03) ### Documentation Changes - * Point README to developer guide +- Point README to developer guide ## v1.4.0 (2020-11-26) ### Features - * Enable retries when retrieving results from AwsQuantumTaskBatch. +- Enable retries when retrieving results from AwsQuantumTaskBatch. ## v1.3.1 (2020-11-25) @@ -347,99 +347,102 @@ ### Features - * Enable explicit qubit allocation - * Add support for batch execution +- Enable explicit qubit allocation +- Add support for batch execution ### Bug Fixes and Other Changes - * Correctly cache status +- Correctly cache status ## v1.2.0 (2020-11-02) ### Features - * support tags parameter for create method in AwsQuantumTask +- support tags parameter for create method in AwsQuantumTask ## v1.1.4.post0 (2020-10-30) ### Testing and Release Infrastructure - * update codeowners +- update codeowners ## v1.1.4 (2020-10-29) ### Bug Fixes and Other Changes - * Enable simultaneous measurement of observables with shared factors - * Add optimization to only poll during execution window +- Enable simultaneous measurement of observables with shared factors +- Add optimization to only poll during execution window ## v1.1.3 (2020-10-20) ### Bug Fixes and Other Changes - * add observable targets not in instructions to circuit qubit count and qubits +- add observable targets not in instructions to circuit qubit count and qubits ## v1.1.2.post1 (2020-10-15) ### Documentation Changes - * add sample notebooks link +- add sample notebooks link ## v1.1.2.post0 (2020-10-05) ### Testing and Release Infrastructure - * change check for s3 bucket exists - * change bucket creation setup for integ tests +- change check for s3 bucket exists +- change bucket creation setup for integ tests ## v1.1.2 (2020-10-02) ### Bug Fixes and Other Changes - * Add error for target qubit set size not equal to operator qubit size in instruction - * Add error message for running a circuit without instructions +- Add error for target qubit set size not equal to operator qubit size in instruction +- Add error message for running a circuit without instructions ### Documentation Changes - * Update docstring for measurement_counts +- Update docstring for measurement_counts ## v1.1.1.post2 (2020-09-29) ### Documentation Changes - * Add D-Wave Advantage_system1 arn +- Add D-Wave Advantage_system1 arn ## v1.1.1.post1 (2020-09-10) ### Testing and Release Infrastructure - * fix black formatting +- fix black formatting ## v1.1.1.post0 (2020-09-09) ### Testing and Release Infrastructure - * Add CHANGELOG.md +- Add CHANGELOG.md ## v1.1.1 (2020-09-09) ### Bug Fixes -* Add handling for solution_counts=[] for annealing result + +- Add handling for solution_counts=[] for annealing result ## v1.1.0 (2020-09-08) ### Features -* Add get_devices to search devices based on criteria + +- Add get_devices to search devices based on criteria ### Bug Fixes -* Call async_result() before calling result() -* Convert amplitude result to a complex number + +- Call async_result() before calling result() +- Convert amplitude result to a complex number ## v1.0.0.post1 (2020-08-14) ### Documentation -* add readthedocs link to README +- add readthedocs link to README ## v1.0.0 (2020-08-13) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 88373f012..3ae63fc25 100644 --- a/CONTRIBUTING.md +++ 
b/CONTRIBUTING.md @@ -60,7 +60,7 @@ Before sending us a pull request, please ensure that: ### Run the Unit Tests 1. Install tox using `pip install tox` -1. Install coverage using `pip install .[test]` +1. Install coverage using `pip install '.[test]'` 1. cd into the amazon-braket-sdk-python folder: `cd amazon-braket-sdk-python` or `cd /environment/amazon-braket-sdk-python` 1. Run the following tox command and verify that all unit tests pass: `tox -e unit-tests` diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..e44b0e7bc --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,14 @@ +include *.md +include *.yaml +include *.yml +include .coveragerc +include CODEOWNERS +include tox.ini +recursive-include src *.json +recursive-include bin *.py +recursive-include doc *.py +recursive-include doc *.rst +recursive-include doc *.txt +recursive-include doc Makefile +recursive-include examples *.py +recursive-include test *.py diff --git a/README.md b/README.md index 87a114672..dd3b1549f 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,24 @@ batch = device.run_batch(circuits, s3_folder, shots=100) print(batch.results()[0].measurement_counts) # The result of the first task in the batch ``` +### Running a hybrid job + +```python +from braket.aws import AwsQuantumJob + +job = AwsQuantumJob.create( + device="arn:aws:braket:::device/quantum-simulator/amazon/sv1", + source_module="job.py", + entry_point="job:run_job", + wait_until_complete=True, +) +print(job.result()) +``` +where `run_job` is a function in the file `job.py`. + + +The code sample imports the Amazon Braket framework, then creates a hybrid job with the entry point being the `run_job` function. The hybrid job creates quantum tasks against the SV1 AWS Simulator. The job runs synchronously, and prints logs until it completes. The complete example can be found in `../examples/job.py`. + ### Available Simulators Amazon Braket provides access to two types of simulators: fully managed simulators, available through the Amazon Braket service, and the local simulators that are part of the Amazon Braket SDK. @@ -202,6 +220,7 @@ After you create a profile, use the following command to set the `AWS_PROFILE` s ```bash export AWS_PROFILE=YOUR_PROFILE_NAME ``` +To run the integration tests for local jobs, you need to have Docker installed and running. To install Docker follow these instructions: [Install Docker](https://docs.docker.com/get-docker/) Run the tests: diff --git a/doc/examples-hybrid-jobs.rst b/doc/examples-hybrid-jobs.rst new file mode 100644 index 000000000..c1b6c6a9d --- /dev/null +++ b/doc/examples-hybrid-jobs.rst @@ -0,0 +1,33 @@ +################################ +Amazon Braket Hybrid Jobs +################################ + +Learn more about hybrid jobs on Amazon Braket. + +.. toctree:: + :maxdepth: 2 + +************************** +`Getting Started `_ +************************** + +This tutorial shows how to run your first Amazon Braket Hybrid Job. + +************************** +`Hyperparameter Tuning `_ +************************** + +This notebook demonstrates a typical quantum machine learning workflow, including uploading data, monitoring training, and tuning hyperparameters. + +************************** +`Using Pennylane with Braket Jobs `_ +************************** + +In this tutorial, we use PennyLane within Amazon Braket Hybrid Jobs to run the Quantum Approximate Optimization Algorithm (QAOA) on a Max-Cut problem. 
+ +************************** +`Bring your own container `_ +************************** + +Amazon Braket has pre-configured containers for executing Amazon Braket Hybrid Jobs, which are sufficient for many use cases involving the Braket SDK and PennyLane. +However, if we want to use custom packages outside the scope of pre-configured containers, we have the ability to supply a custom-built container. In this tutorial, we show how to use Braket Hybrid Jobs to train a quantum machine learning model using BYOC (Bring Your Own Container). diff --git a/doc/examples.rst b/doc/examples.rst index e0126c2f1..cc65a5a8e 100644 --- a/doc/examples.rst +++ b/doc/examples.rst @@ -14,5 +14,6 @@ https://github.com/aws/amazon-braket-examples. examples-hybrid-quantum.rst examples-ml-pennylane.rst examples-quantum-annealing-dwave.rst + examples-hybrid-jobs.rst \ No newline at end of file diff --git a/examples/bell.py b/examples/bell.py index dfa81e521..7a2c43332 100644 --- a/examples/bell.py +++ b/examples/bell.py @@ -16,10 +16,7 @@ device = AwsDevice("arn:aws:braket:::device/quantum-simulator/amazon/sv1") -# Use the S3 bucket you created during onboarding -s3_folder = ("amazon-braket-Your-Bucket-Name", "folder-name") - # https://wikipedia.org/wiki/Bell_state bell = Circuit().h(0).cnot(0, 1) -task = device.run(bell, s3_folder, shots=100) +task = device.run(bell, shots=100) print(task.result().measurement_counts) diff --git a/examples/debug_bell.py b/examples/debug_bell.py index f1ff33fcd..cd492e45c 100644 --- a/examples/debug_bell.py +++ b/examples/debug_bell.py @@ -23,15 +23,11 @@ device = AwsDevice("arn:aws:braket:::device/quantum-simulator/amazon/sv1") -# Use the S3 bucket you created during onboarding -s3_folder = ("amazon-braket-Your-Bucket-Name", "folder-name") - bell = Circuit().h(0).cnot(0, 1) # pass in logger to device.run, enabling debugging logs to print to console logger.info( device.run( bell, - s3_folder, shots=100, poll_timeout_seconds=120, poll_interval_seconds=0.25, diff --git a/examples/job.py b/examples/job.py new file mode 100644 index 000000000..ab9cf8e88 --- /dev/null +++ b/examples/job.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ +import os + +from braket.aws import AwsDevice, AwsQuantumJob +from braket.circuits import Circuit +from braket.jobs import save_job_result + + +def run_job(): + device = AwsDevice(os.environ.get("AMZN_BRAKET_DEVICE_ARN")) + + bell = Circuit().h(0).cnot(0, 1) + num_tasks = 10 + results = [] + + for i in range(num_tasks): + task = device.run(bell, shots=100) + result = task.result().measurement_counts + results.append(result) + print(f"iter {i}: {result}") + + save_job_result({"results": results}) + + +if __name__ == "__main__": + job = AwsQuantumJob.create( + device="arn:aws:braket:::device/quantum-simulator/amazon/sv1", + source_module="job.py", + entry_point="job:run_job", + wait_until_complete=True, + ) + print(job.result()) diff --git a/setup.py b/setup.py index cd82e27e1..c7420189d 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ extras_require={ "test": [ "black", + "botocore", "flake8", "isort", "jsonschema==3.2.0", @@ -54,6 +55,7 @@ "tox", ] }, + include_package_data=True, url="https://github.com/aws/amazon-braket-sdk-python", author="Amazon Web Services", description=( diff --git a/src/braket/aws/__init__.py b/src/braket/aws/__init__.py index 08217d9a5..d0b3a3411 100644 --- a/src/braket/aws/__init__.py +++ b/src/braket/aws/__init__.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from braket.aws.aws_device import AwsDevice, AwsDeviceType # noqa: F401 +from braket.aws.aws_quantum_job import AwsQuantumJob # noqa: F401 from braket.aws.aws_quantum_task import AwsQuantumTask # noqa: F401 from braket.aws.aws_quantum_task_batch import AwsQuantumTaskBatch # noqa: F401 from braket.aws.aws_session import AwsSession # noqa: F401 diff --git a/src/braket/aws/aws_device.py b/src/braket/aws/aws_device.py index 0e694870b..59d461794 100644 --- a/src/braket/aws/aws_device.py +++ b/src/braket/aws/aws_device.py @@ -13,11 +13,11 @@ from __future__ import annotations +import os from enum import Enum from typing import List, Optional, Union -import boto3 -from botocore.config import Config +from botocore.errorfactory import ClientError from networkx import Graph, complete_graph, from_edgelist from braket.annealing.problem import Problem @@ -79,7 +79,7 @@ def __init__(self, arn: str, aws_session: Optional[AwsSession] = None): def run( self, task_specification: Union[Circuit, Problem], - s3_destination_folder: AwsSession.S3DestinationFolder, + s3_destination_folder: Optional[AwsSession.S3DestinationFolder] = None, shots: Optional[int] = None, poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, @@ -93,7 +93,10 @@ def run( Args: task_specification (Union[Circuit, Problem]): Specification of task (circuit or annealing problem) to run on device. - s3_destination_folder: The S3 location to save the task's results to. + s3_destination_folder (AwsSession.S3DestinationFolder, optional): The S3 location to + save the task's results to. Default is `/tasks` if evoked + outside of a Braket Job, `/jobs//tasks` if evoked inside of + a Braket Job. shots (int, optional): The number of times to run the circuit or annealing problem. Default is 1000 for QPUs and 0 for simulators. 
poll_timeout_seconds (float): The polling timeout for `AwsQuantumTask.result()`, @@ -141,7 +144,13 @@ def run( self._aws_session, self._arn, task_specification, - s3_destination_folder, + s3_destination_folder + or ( + AwsSession.parse_s3_uri(os.environ.get("AMZN_BRAKET_TASK_RESULTS_S3_URI")) + if "AMZN_BRAKET_TASK_RESULTS_S3_URI" in os.environ + else None + ) + or (self._aws_session.default_bucket(), "tasks"), shots if shots is not None else self._default_shots, poll_timeout_seconds=poll_timeout_seconds, poll_interval_seconds=poll_interval_seconds, @@ -152,7 +161,7 @@ def run( def run_batch( self, task_specifications: List[Union[Circuit, Problem]], - s3_destination_folder: AwsSession.S3DestinationFolder, + s3_destination_folder: Optional[AwsSession.S3DestinationFolder] = None, shots: Optional[int] = None, max_parallel: Optional[int] = None, max_connections: int = AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT, @@ -166,7 +175,10 @@ def run_batch( Args: task_specifications (List[Union[Circuit, Problem]]): List of circuits or annealing problems to run on device. - s3_destination_folder: The S3 location to save the tasks' results to. + s3_destination_folder (AwsSession.S3DestinationFolder, optional): The S3 location to + save the tasks' results to. Default is `/tasks` if evoked + outside of a Braket Job, `/jobs//tasks` if evoked inside of + a Braket Job. shots (int, optional): The number of times to run the circuit or annealing problem. Default is 1000 for QPUs and 0 for simulators. max_parallel (int, optional): The maximum number of tasks to run on AWS in parallel. @@ -190,10 +202,16 @@ def run_batch( `braket.aws.aws_quantum_task_batch.AwsQuantumTaskBatch` """ return AwsQuantumTaskBatch( - AwsDevice._copy_aws_session(self._aws_session, max_connections=max_connections), + AwsSession.copy_session(self._aws_session, max_connections=max_connections), self._arn, task_specifications, - s3_destination_folder, + s3_destination_folder + or ( + AwsSession.parse_s3_uri(os.environ.get("AMZN_BRAKET_TASK_RESULTS_S3_URI")) + if "AMZN_BRAKET_TASK_RESULTS_S3_URI" in os.environ + else None + ) + or (self._aws_session.default_bucket(), "tasks"), shots if shots is not None else self._default_shots, max_parallel=max_parallel if max_parallel is not None else self._default_max_parallel, max_workers=max_connections, @@ -210,22 +228,26 @@ def refresh_metadata(self) -> None: self._populate_properties(self._aws_session) def _get_session_and_initialize(self, session): - current_region = session.boto_session.region_name + current_region = session.region try: self._populate_properties(session) return session - except Exception: - if "qpu" not in self._arn: - raise ValueError(f"Simulator {self._arn} not found in {current_region}") + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + if "qpu" not in self._arn: + raise ValueError(f"Simulator '{self._arn}' not found in '{current_region}'") + else: + raise e # Search remaining regions for QPU for region in frozenset(AwsDevice.REGIONS) - {current_region}: - region_session = AwsDevice._copy_aws_session(session, region) + region_session = AwsSession.copy_session(session, region) try: self._populate_properties(region_session) return region_session - except Exception: - pass - raise ValueError(f"QPU {self._arn} not found") + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + raise e + raise ValueError(f"QPU '{self._arn}' not found") def _populate_properties(self, session): metadata = 
session.get_device(self._arn) @@ -315,27 +337,6 @@ def _default_shots(self): def _default_max_parallel(self): return AwsDevice.DEFAULT_MAX_PARALLEL - @staticmethod - def _copy_aws_session( - aws_session: AwsSession, - region: Optional[str] = None, - max_connections: Optional[int] = None, - ) -> AwsSession: - config = Config(max_pool_connections=max_connections) if max_connections else None - session_region = aws_session.boto_session.region_name - new_region = region or session_region - creds = aws_session.boto_session.get_credentials() - if creds.method == "explicit": - boto_session = boto3.Session( - aws_access_key_id=creds.access_key, - aws_secret_access_key=creds.secret_key, - aws_session_token=creds.token, - region_name=new_region, - ) - else: - boto_session = boto3.Session(region_name=new_region) - return AwsSession(boto_session=boto_session, config=config) - def __repr__(self): return "Device('name': {}, 'arn': {})".format(self.name, self.arn) @@ -397,7 +398,7 @@ def get_devices( session_for_region = ( aws_session if region == session_region - else AwsDevice._copy_aws_session(aws_session, region) + else AwsSession.copy_session(aws_session, region) ) # Simulators are only instantiated in the same region as the AWS session types_for_region = sorted( diff --git a/src/braket/aws/aws_quantum_job.py b/src/braket/aws/aws_quantum_job.py new file mode 100644 index 000000000..48f9a6984 --- /dev/null +++ b/src/braket/aws/aws_quantum_job.py @@ -0,0 +1,561 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from __future__ import annotations + +import math +import tarfile +import tempfile +import time +from enum import Enum +from logging import Logger, getLogger +from pathlib import Path +from typing import Any, Dict, List, Union + +import boto3 +from botocore.exceptions import ClientError + +from braket.aws import AwsDevice +from braket.aws.aws_session import AwsSession +from braket.jobs import logs +from braket.jobs.config import ( + CheckpointConfig, + InstanceConfig, + OutputDataConfig, + S3DataSourceConfig, + StoppingCondition, +) +from braket.jobs.metrics_data.cwl_insights_metrics_fetcher import CwlInsightsMetricsFetcher + +# TODO: Have added metric file in metrics folder, but have to decide on the name for keep +# for the files, since all those metrics are retrieved from the CW. 
+from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType +from braket.jobs.quantum_job import QuantumJob +from braket.jobs.quantum_job_creation import prepare_quantum_job +from braket.jobs.serialization import deserialize_values +from braket.jobs_data import PersistedJobData + + +class AwsQuantumJob(QuantumJob): + """Amazon Braket implementation of a quantum job.""" + + TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELLED"} + RESULTS_FILENAME = "results.json" + RESULTS_TAR_FILENAME = "model.tar.gz" + LOG_GROUP = "/aws/braket/jobs" + + class LogState(Enum): + TAILING = "tailing" + JOB_COMPLETE = "job_complete" + COMPLETE = "complete" + + @classmethod + def create( + cls, + device: str, + source_module: str, + entry_point: str = None, + image_uri: str = None, + job_name: str = None, + code_location: str = None, + role_arn: str = None, + wait_until_complete: bool = False, + hyperparameters: Dict[str, Any] = None, + input_data: Union[str, Dict, S3DataSourceConfig] = None, + instance_config: InstanceConfig = None, + stopping_condition: StoppingCondition = None, + output_data_config: OutputDataConfig = None, + copy_checkpoints_from_job: str = None, + checkpoint_config: CheckpointConfig = None, + aws_session: AwsSession = None, + tags: Dict[str, str] = None, + logger: Logger = getLogger(__name__), + ) -> AwsQuantumJob: + """Creates a job by invoking the Braket CreateJob API. + + Args: + device (str): ARN for the AWS device which is primarily + accessed for the execution of this job. + + source_module (str): Path (absolute, relative or an S3 URI) to a python module to be + tarred and uploaded. If `source_module` is an S3 URI, it must point to a + tar.gz file. Otherwise, source_module may be a file or directory. + + entry_point (str): A str that specifies the entry point of the job, relative to + the source module. The entry point must be in the format + `importable.module` or `importable.module:callable`. For example, + `source_module.submodule:start_here` indicates the `start_here` function + contained in `source_module.submodule`. If source_module is an S3 URI, + entry point must be given. Default: source_module's name + + image_uri (str): A str that specifies the ECR image to use for executing the job. + `image_uris.retrieve_image()` function may be used for retrieving the ECR image URIs + for the containers supported by Braket. Default = ``. + + job_name (str): A str that specifies the name with which the job is created. + Default: f'{image_uri_type}-{timestamp}'. + + code_location (str): The S3 prefix URI where custom code will be uploaded. + Default: f's3://{default_bucket_name}/jobs/{job_name}/script'. + + role_arn (str): A str providing the IAM role ARN used to execute the + script. Default: IAM role returned by AwsSession's `get_default_jobs_role()`. + + wait_until_complete (bool): `True` if we should wait until the job completes. + This would tail the job logs as it waits. Otherwise `False`. Default: `False`. + + hyperparameters (Dict[str, Any]): Hyperparameters accessible to the job. + The hyperparameters are made accessible as a Dict[str, str] to the job. + For convenience, this accepts other types for keys and values, but `str()` + is called to convert them before being passed on. Default: None. + + input_data (Union[str, S3DataSourceConfig, dict]): Information about the training + data. Dictionary maps channel names to local paths or S3 URIs. 
Contents found + at any local paths will be uploaded to S3 at + f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}. If a local + path, S3 URI, or S3DataSourceConfig is provided, it will be given a default + channel name "input". + Default: {}. + + instance_config (InstanceConfig): Configuration of the instances to be used + to execute the job. Default: InstanceConfig(instanceType='ml.m5.large', + instanceCount=1, volumeSizeInGB=30, volumeKmsKey=None). + + stopping_condition (StoppingCondition): The maximum length of time, in seconds, + and the maximum number of tasks that a job can run before being forcefully stopped. + Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60). + + output_data_config (OutputDataConfig): Specifies the location for the output of the job. + Default: OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data', + kmsKeyId=None). + + copy_checkpoints_from_job (str): A str that specifies the job ARN whose checkpoint you + want to use in the current job. Specifying this value will copy over the checkpoint + data from `use_checkpoints_from_job`'s checkpoint_config s3Uri to the current job's + checkpoint_config s3Uri, making it available at checkpoint_config.localPath during + the job execution. Default: None + + checkpoint_config (CheckpointConfig): Configuration that specifies the location where + checkpoint data is stored. + Default: CheckpointConfig(localPath='/opt/jobs/checkpoints', + s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints'). + + aws_session (AwsSession): AwsSession for connecting to AWS Services. + Default: AwsSession() + + tags (Dict[str, str]): Dict specifying the key-value pairs for tagging this job. + Default: {}. + + logger (Logger): Logger object with which to write logs, such as task statuses + while waiting for task to be in a terminal state. Default is `getLogger(__name__)` + + Returns: + AwsQuantumJob: Job tracking the execution on Amazon Braket. + + Raises: + ValueError: Raises ValueError if the parameters are not valid. + """ + aws_session = AwsQuantumJob._initialize_session(aws_session, device, logger) + + create_job_kwargs = prepare_quantum_job( + device=device, + source_module=source_module, + entry_point=entry_point, + image_uri=image_uri, + job_name=job_name, + code_location=code_location, + role_arn=role_arn, + hyperparameters=hyperparameters, + input_data=input_data, + instance_config=instance_config, + stopping_condition=stopping_condition, + output_data_config=output_data_config, + copy_checkpoints_from_job=copy_checkpoints_from_job, + checkpoint_config=checkpoint_config, + aws_session=aws_session, + tags=tags, + ) + + job_arn = aws_session.create_job(**create_job_kwargs) + job = AwsQuantumJob(job_arn, aws_session) + + if wait_until_complete: + print(f"Initializing Braket Job: {job_arn}") + job.logs(wait=True) + + return job + + def __init__(self, arn: str, aws_session: AwsSession = None): + """ + Args: + arn (str): The ARN of the job. + aws_session (AwsSession, optional): The `AwsSession` for connecting to AWS services. + Default is `None`, in which case an `AwsSession` object will be created with the + region of the job. + """ + self._arn: str = arn + if aws_session: + if not self._is_valid_aws_session_region_for_job_arn(aws_session, arn): + raise ValueError( + "The aws session region does not match the region for the supplied arn." 
+ ) + self._aws_session = aws_session + else: + self._aws_session = AwsQuantumJob._default_session_for_job_arn(arn) + self._metadata = {} + + @staticmethod + def _is_valid_aws_session_region_for_job_arn(aws_session: AwsSession, job_arn: str) -> bool: + """ + bool: `True` when the aws_session region matches the job_arn region; otherwise `False`. + """ + job_region = job_arn.split(":")[3] + return job_region == aws_session.braket_client.meta.region_name + + @staticmethod + def _default_session_for_job_arn(job_arn: str) -> AwsSession: + """Get an AwsSession for the Job ARN. The AWS session should be in the region of the job. + + Args: + job_arn (str): The ARN for the quantum job. + + Returns: + AwsSession: `AwsSession` object with default `boto_session` in job's region. + """ + job_region = job_arn.split(":")[3] + boto_session = boto3.Session(region_name=job_region) + return AwsSession(boto_session=boto_session) + + @property + def arn(self) -> str: + """str: The ARN (Amazon Resource Name) of the quantum job.""" + return self._arn + + @property + def name(self) -> str: + """str: The name of the quantum job.""" + return self._arn.partition("job/")[-1] + + def state(self, use_cached_value: bool = False) -> str: + """The state of the quantum job. + + Args: + use_cached_value (bool, optional): If `True`, uses the value most recently retrieved + value from the Amazon Braket `GetJob` operation. If `False`, calls the + `GetJob` operation to retrieve metadata, which also updates the cached + value. Default = `False`. + Returns: + str: The value of `status` in `metadata()`. This is the value of the `status` key + in the Amazon Braket `GetJob` operation. + + See Also: + `metadata()` + """ + return self.metadata(use_cached_value).get("status") + + def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: + """Display logs for a given job, optionally tailing them until job is complete. + + If the output is a tty or a Jupyter cell, it will be color-coded + based on which instance the log entry is from. + + Args: + wait (bool): `True` to keep looking for new log entries until the job completes; + otherwise `False`. Default: `False`. + + poll_interval_seconds (int): The interval of time, in seconds, between polling for + new log entries and job completion (default: 5). + + Raises: + exceptions.UnexpectedStatusException: If waiting and the training job fails. + """ + # The loop below implements a state machine that alternates between checking the job status + # and reading whatever is available in the logs at this point. Note, that if we were + # called with wait == False, we never check the job status. + # + # If wait == TRUE and job is not completed, the initial state is TAILING + # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is + # complete). + # + # The state table: + # + # STATE ACTIONS CONDITION NEW STATE + # ---------------- ---------------- ----------------- ---------------- + # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE + # Else TAILING + # JOB_COMPLETE Read logs, Pause Any COMPLETE + # COMPLETE Read logs, Exit N/A + # + # Notes: + # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to + # Cloudwatch after the job was marked complete. 
+ + job_already_completed = self.state() in AwsQuantumJob.TERMINAL_STATES + log_state = ( + AwsQuantumJob.LogState.TAILING + if wait and not job_already_completed + else AwsQuantumJob.LogState.COMPLETE + ) + + log_group = AwsQuantumJob.LOG_GROUP + stream_prefix = f"{self.name}/" + stream_names = [] # The list of log streams + positions = {} # The current position in each stream, map of stream name -> position + instance_count = 1 # currently only support a single instance + has_streams = False + color_wrap = logs.ColorWrap() + + while True: + time.sleep(poll_interval_seconds) + + has_streams = logs.flush_log_streams( + self._aws_session, + log_group, + stream_prefix, + stream_names, + positions, + instance_count, + has_streams, + color_wrap, + ) + + if log_state == AwsQuantumJob.LogState.COMPLETE: + break + + if log_state == AwsQuantumJob.LogState.JOB_COMPLETE: + log_state = AwsQuantumJob.LogState.COMPLETE + elif self.state() in AwsQuantumJob.TERMINAL_STATES: + log_state = AwsQuantumJob.LogState.JOB_COMPLETE + + def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]: + """Gets the job metadata defined in Amazon Braket. + + Args: + use_cached_value (bool, optional): If `True`, uses the value most recently retrieved + from the Amazon Braket `GetJob` operation, if it exists; if does not exist, + `GetJob` is called to retrieve the metadata. If `False`, always calls + `GetJob`, which also updates the cached value. Default: `False`. + Returns: + Dict[str, Any]: Dict that specifies the job metadata defined in Amazon Braket. + """ + if not use_cached_value or not self._metadata: + self._metadata = self._aws_session.get_job(self._arn) + return self._metadata + + def metrics( + self, + metric_type: MetricType = MetricType.TIMESTAMP, + statistic: MetricStatistic = MetricStatistic.MAX, + ) -> Dict[str, List[Any]]: + """Gets all the metrics data, where the keys are the column names, and the values are a list + containing the values in each row. For example, the table: + timestamp energy + 0 0.1 + 1 0.2 + would be represented as: + { "timestamp" : [0, 1], "energy" : [0.1, 0.2] } + values may be integers, floats, strings or None. + + Args: + metric_type (MetricType): The type of metrics to get. Default: MetricType.TIMESTAMP. + + statistic (MetricStatistic): The statistic to determine which metric value to use + when there is a conflict. Default: MetricStatistic.MAX. + + Returns: + Dict[str, List[Union[str, float, int]]] : The metrics data. + """ + fetcher = CwlInsightsMetricsFetcher(self._aws_session) + metadata = self.metadata(True) + job_name = metadata["jobName"] + job_start = None + job_end = None + if "startedAt" in metadata: + job_start = int(metadata["startedAt"].timestamp()) + if self.state() in AwsQuantumJob.TERMINAL_STATES and "endedAt" in metadata: + job_end = int(math.ceil(metadata["endedAt"].timestamp())) + return fetcher.get_metrics_for_job(job_name, metric_type, statistic, job_start, job_end) + + def cancel(self) -> str: + """Cancels the job. + + Returns: + str: Indicates the status of the job. + + Raises: + ClientError: If there are errors invoking the CancelJob API. + """ + cancellation_response = self._aws_session.cancel_job(self._arn) + return cancellation_response["cancellationStatus"] + + def result( + self, + poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT, + poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL, + ) -> Dict[str, Any]: + """Retrieves the job result persisted using save_job_result() function. 
+ + Args: + poll_timeout_seconds (float): The polling timeout, in seconds, for `result()`. + Default: 10 days. + + poll_interval_seconds (float): The polling interval, in seconds, for `result()`. + Default: 5 seconds. + + + Returns: + Dict[str, Any]: Dict specifying the job results. + + Raises: + RuntimeError: if job is in a FAILED or CANCELLED state. + TimeoutError: if job execution exceeds the polling timeout period. + """ + + with tempfile.TemporaryDirectory() as temp_dir: + job_name = self.metadata(True)["jobName"] + + try: + self.download_result(temp_dir, poll_timeout_seconds, poll_interval_seconds) + except ClientError as e: + if e.response["Error"]["Code"] == "404": + return {} + else: + raise e + return AwsQuantumJob._read_and_deserialize_results(temp_dir, job_name) + + @staticmethod + def _read_and_deserialize_results(temp_dir, job_name): + try: + with open(f"{temp_dir}/{job_name}/{AwsQuantumJob.RESULTS_FILENAME}", "r") as f: + persisted_data = PersistedJobData.parse_raw(f.read()) + deserialized_data = deserialize_values( + persisted_data.dataDictionary, persisted_data.dataFormat + ) + return deserialized_data + except FileNotFoundError: + return {} + + def download_result( + self, + extract_to=None, + poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT, + poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL, + ) -> None: + """Downloads the results from the job output S3 bucket and extracts the tar.gz + bundle to the location specified by `extract_to`. If no location is specified, + the results are extracted to the current directory. + + Args: + extract_to (str): The directory to which the results are extracted. The results + are extracted to a folder titled with the job name within this directory. + Default= `Current working directory`. + + poll_timeout_seconds: (float): The polling timeout, in seconds, for `download_result()`. + Default: 10 days. + + poll_interval_seconds: (float): The polling interval, in seconds, for + `download_result()`.Default: 5 seconds. + + Raises: + RuntimeError: if job is in a FAILED or CANCELLED state. + TimeoutError: if job execution exceeds the polling timeout period. + """ + + extract_to = extract_to or Path.cwd() + + timeout_time = time.time() + poll_timeout_seconds + job_response = self.metadata(True) + + while time.time() < timeout_time: + job_response = self.metadata(True) + job_state = self.state() + + if job_state in AwsQuantumJob.TERMINAL_STATES: + output_s3_path = job_response["outputDataConfig"]["s3Path"] + output_s3_uri = f"{output_s3_path}/output/model.tar.gz" + AwsQuantumJob._attempt_results_download(self, output_s3_uri, output_s3_path) + AwsQuantumJob._extract_tar_file(f"{extract_to}/{self.name}") + return + else: + time.sleep(poll_interval_seconds) + + raise TimeoutError( + f"{job_response['jobName']}: Polling for job completion " + f"timed out after {poll_timeout_seconds} seconds." 
+ ) + + def _attempt_results_download(self, output_bucket_uri, output_s3_path): + try: + self._aws_session.download_from_s3( + s3_uri=output_bucket_uri, filename=AwsQuantumJob.RESULTS_TAR_FILENAME + ) + except ClientError as e: + if e.response["Error"]["Code"] == "404": + exception_response = { + "Error": { + "Code": "404", + "Message": f"Error retrieving results, " + f"could not find results at '{output_s3_path}'", + } + } + raise ClientError(exception_response, "HeadObject") from e + else: + raise e + + @staticmethod + def _extract_tar_file(extract_path): + with tarfile.open(AwsQuantumJob.RESULTS_TAR_FILENAME, "r:gz") as tar: + tar.extractall(extract_path) + + def __repr__(self) -> str: + return f"AwsQuantumJob('arn':'{self.arn}')" + + def __eq__(self, other) -> bool: + if isinstance(other, AwsQuantumJob): + return self.arn == other.arn + return False + + def __hash__(self) -> int: + return hash(self.arn) + + @staticmethod + def _initialize_session(session_value, device, logger): + aws_session = session_value or AwsSession() + current_region = aws_session.region + + try: + aws_session.get_device(device) + return aws_session + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + if "qpu" not in device: + raise ValueError(f"Simulator '{device}' not found in '{current_region}'") + else: + raise e + + return AwsQuantumJob._find_device_session(aws_session, device, current_region, logger) + + @staticmethod + def _find_device_session(aws_session, device, original_region, logger): + for region in frozenset(AwsDevice.REGIONS) - {original_region}: + device_session = aws_session.copy_session(region=region) + try: + device_session.get_device(device) + logger.info( + f"Changed session region from '{original_region}' to '{device_session.region}'" + ) + return device_session + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + raise e + raise ValueError(f"QPU '{device}' not found.") diff --git a/src/braket/aws/aws_quantum_task.py b/src/braket/aws/aws_quantum_task.py index 2602ecf9f..7b7293b24 100644 --- a/src/braket/aws/aws_quantum_task.py +++ b/src/braket/aws/aws_quantum_task.py @@ -391,7 +391,7 @@ def __repr__(self) -> str: def __eq__(self, other) -> bool: if isinstance(other, AwsQuantumTask): return self.id == other.id - return NotImplemented + return False def __hash__(self) -> int: return hash(self.id) diff --git a/src/braket/aws/aws_session.py b/src/braket/aws/aws_session.py index ec154245c..2d7cf3a2f 100644 --- a/src/braket/aws/aws_session.py +++ b/src/braket/aws/aws_session.py @@ -11,11 +11,18 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from __future__ import annotations + +import itertools +import os import os.path +import re +from pathlib import Path from typing import Any, Dict, List, NamedTuple, Optional import backoff import boto3 +from botocore.config import Config from botocore.exceptions import ClientError import braket._schemas as braket_schemas @@ -27,22 +34,82 @@ class AwsSession(object): S3DestinationFolder = NamedTuple("S3DestinationFolder", [("bucket", str), ("key", str)]) - def __init__(self, boto_session=None, braket_client=None, config=None): + def __init__(self, boto_session=None, braket_client=None, config=None, default_bucket=None): """ Args: boto_session: A boto3 session object. braket_client: A boto3 Braket client. config: A botocore Config object. 
""" + if ( + boto_session + and braket_client + and boto_session.region_name != braket_client.meta.region_name + ): + raise ValueError( + "Boto Session region and Braket Client region must match and currently " + f"they do not: Boto Session region is '{boto_session.region_name}', but " + f"Braket Client region is '{braket_client.meta.region_name}'." + ) - self.boto_session = boto_session or boto3.Session() self._config = config if braket_client: + self.boto_session = boto_session or boto3.Session( + region_name=braket_client.meta.region_name + ) self.braket_client = braket_client else: + self.boto_session = boto_session or boto3.Session() self.braket_client = self.boto_session.client("braket", config=self._config) + self._update_user_agent() + self._custom_default_bucket = bool(default_bucket) + self._default_bucket = default_bucket or os.environ.get("AMZN_BRAKET_OUT_S3_BUCKET") + + self._iam = None + self._s3 = None + self._sts = None + self._logs = None + self._ecr = None + + @property + def region(self): + return self.boto_session.region_name + + @property + def account_id(self): + return self.sts_client.get_caller_identity()["Account"] + + @property + def iam_client(self): + if not self._iam: + self._iam = self.boto_session.client("iam", region_name=self.region) + return self._iam + + @property + def s3_client(self): + if not self._s3: + self._s3 = self.boto_session.client("s3", region_name=self.region) + return self._s3 + + @property + def sts_client(self): + if not self._sts: + self._sts = self.boto_session.client("sts", region_name=self.region) + return self._sts + + @property + def logs_client(self): + if not self._logs: + self._logs = self.boto_session.client("logs", region_name=self.region) + return self._logs + + @property + def ecr_client(self): + if not self._ecr: + self._ecr = self.boto_session.client("ecr", region_name=self.region) + return self._ecr def _update_user_agent(self): """ @@ -89,9 +156,26 @@ def create_quantum_task(self, **boto3_kwargs) -> str: Returns: str: The ARN of the quantum task. """ + # Add job token to request, if available. + job_token = os.getenv("AMZN_BRAKET_JOB_TOKEN") + if job_token: + boto3_kwargs.update({"jobToken": job_token}) response = self.braket_client.create_quantum_task(**boto3_kwargs) return response["quantumTaskArn"] + def create_job(self, **boto3_kwargs) -> str: + """ + Create a quantum job. + + Args: + **boto3_kwargs: Keyword arguments for the Amazon Braket `CreateJob` operation. + + Returns: + str: The ARN of the job. + """ + response = self.braket_client.create_job(**boto3_kwargs) + return response["jobArn"] + @staticmethod def _should_giveup(err): return not ( @@ -122,6 +206,60 @@ def get_quantum_task(self, arn: str) -> Dict[str, Any]: """ return self.braket_client.get_quantum_task(quantumTaskArn=arn) + def get_default_jobs_role(self) -> str: + """ + Returns the role ARN for the default jobs role created in the Amazon Braket Console. + It will pick the first role it finds with the `RoleName` prefix + `AmazonBraketJobsExecutionRole`. + + Returns: + (str): The ARN for the default IAM role for jobs execution created in the Amazon + Braket console. + + Raises: + RuntimeError: If no roles can be found with the prefix `AmazonBraketJobsExecutionRole`. + """ + roles_paginator = self.iam_client.get_paginator("list_roles") + for page in roles_paginator.paginate(): + for role in page.get("Roles", []): + if role["RoleName"].startswith("AmazonBraketJobsExecutionRole"): + return role["Arn"] + raise RuntimeError( + "No default jobs roles found. 
Please create a role using the " + "Amazon Braket console or supply a custom role." + ) + + @backoff.on_exception( + backoff.expo, + ClientError, + max_tries=3, + jitter=backoff.full_jitter, + giveup=_should_giveup.__func__, + ) + def get_job(self, arn: str) -> Dict[str, Any]: + """ + Gets the quantum job. + + Args: + arn (str): The ARN of the quantum job to get. + + Returns: + Dict[str, Any]: The response from the Amazon Braket `GetQuantumJob` operation. + """ + return self.braket_client.get_job(jobArn=arn) + + def cancel_job(self, arn: str) -> Dict[str, Any]: + """ + Cancel the quantum job. + + Args: + arn (str): The ARN of the quantum job to cancel. + + Returns: + Dict[str, Any]: The response from the Amazon Braket `CancelJob` operation. + """ + return self.braket_client.cancel_job(jobArn=arn) + def retrieve_s3_object_body(self, s3_bucket: str, s3_object_key: str) -> str: """ Retrieve the S3 object body. @@ -137,6 +275,251 @@ def retrieve_s3_object_body(self, s3_bucket: str, s3_object_key: str) -> str: obj = s3.Object(s3_bucket, s3_object_key) return obj.get()["Body"].read().decode("utf-8") + def upload_to_s3(self, filename: str, s3_uri: str) -> None: + """ + Upload file to S3 + + Args: + filename (str): local file to be uploaded. + s3_uri (str): The S3 URI where the file will be uploaded. + + Returns: + None + """ + bucket, key = self.parse_s3_uri(s3_uri) + self.s3_client.upload_file(filename, bucket, key) + + def upload_local_data(self, local_prefix: str, s3_prefix: str): + """ + Upload local data matching a prefix to a corresponding location in S3 + + Args: + local_prefix (str): a prefix designating files to be uploaded to S3. All files + beginning with local_prefix will be uploaded. + s3_prefix (str): the corresponding S3 prefix that will replace the local prefix + when the data is uploaded. This will be an S3 URI and should include the bucket + (i.e. 's3://my-bucket/my/prefix-') + + For example, local_prefix = "input", s3_prefix = "s3://my-bucket/dir/input" will upload: + * 'input.csv' to 's3://my-bucket/dir/input.csv' + * 'input-2.csv' to 's3://my-bucket/dir/input-2.csv' + * 'input/data.txt' to 's3://my-bucket/dir/input/data.txt' + * 'input-dir/data.csv' to 's3://my-bucket/dir/input-dir/data.csv' + but will not upload: + * 'my-input.csv' + * 'my-dir/input.csv' + To match all files within the directory "input" and upload them into + "s3://my-bucket/input", provide local_prefix = "input/" and + s3_prefix = "s3://my-bucket/input/" + """ + # support absolute paths + if Path(local_prefix).is_absolute(): + base_dir = Path(Path(local_prefix).anchor) + relative_prefix = str(Path(local_prefix).relative_to(base_dir)) + else: + base_dir = Path() + relative_prefix = str(local_prefix) + for file in itertools.chain( + # files that match the prefix + base_dir.glob(f"{relative_prefix}*"), + # files inside of directories that match the prefix + base_dir.glob(f"{relative_prefix}*/**/*"), + ): + if file.is_file(): + s3_uri = str(file.as_posix()).replace(str(Path(local_prefix).as_posix()), s3_prefix) + self.upload_to_s3(str(file), s3_uri) + + def download_from_s3(self, s3_uri: str, filename: str) -> None: + """ + Download file from S3 + + Args: + s3_uri (str): The S3 uri from where the file will be downloaded. + filename (str): filename to save the file to. 
+ + Returns: + None + """ + bucket, key = self.parse_s3_uri(s3_uri) + self.s3_client.download_file(bucket, key, filename) + + def copy_s3_object(self, source_s3_uri: str, destination_s3_uri: str) -> None: + """ + Copy object from another location in s3. Does nothing if source and + destination URIs are the same. + + Args: + source_s3_uri (str): S3 URI pointing to the object to be copied. + destination_s3_uri (str): S3 URI where the object will be copied to. + """ + if source_s3_uri == destination_s3_uri: + return + + source_bucket, source_key = self.parse_s3_uri(source_s3_uri) + destination_bucket, destination_key = self.parse_s3_uri(destination_s3_uri) + + self.s3_client.copy( + { + "Bucket": source_bucket, + "Key": source_key, + }, + destination_bucket, + destination_key, + ) + + def copy_s3_directory(self, source_s3_path: str, destination_s3_path: str) -> None: + """ + Copy all objects from a specified directory in S3. Does nothing if source and + destination URIs are the same. Preserves nesting structure, will not overwrite + other files in the destination location unless they share a name with a file + being copied. + + Args: + source_s3_path (str): S3 URI pointing to the directory to be copied. + destination_s3_path (str): S3 URI where the contents of the source_s3_path + directory will be copied to. + """ + if source_s3_path == destination_s3_path: + return + + source_bucket, source_prefix = AwsSession.parse_s3_uri(source_s3_path) + destination_bucket, destination_prefix = AwsSession.parse_s3_uri(destination_s3_path) + + source_keys = self.list_keys(source_bucket, source_prefix) + + for key in source_keys: + self.s3_client.copy( + { + "Bucket": source_bucket, + "Key": key, + }, + destination_bucket, + key.replace(source_prefix, destination_prefix, 1), + ) + + def list_keys(self, bucket: str, prefix: str) -> List[str]: + """ + Lists keys matching prefix in bucket. + + Args: + bucket (str): Bucket to be queried. + prefix (str): The S3 path prefix to be matched + + Returns: + List[str]: A list of all keys matching the prefix in + the bucket. + """ + list_objects = self.s3_client.list_objects_v2( + Bucket=bucket, + Prefix=prefix, + ) + keys = [obj["Key"] for obj in list_objects["Contents"]] + while list_objects["IsTruncated"]: + list_objects = self.s3_client.list_objects_v2( + Bucket=bucket, + Prefix=prefix, + ContinuationToken=list_objects["NextContinuationToken"], + ) + keys += [obj["Key"] for obj in list_objects["Contents"]] + return keys + + def default_bucket(self): + """ + Returns the name of the default bucket of the AWS Session. In the following order + of priority, it will return either the parameter `default_bucket` set during + initialization of the AwsSession (if not None), the bucket being used by the + currently running Braket Job (if evoked inside of a Braket Job), or a default value of + "amazon-braket--. Except in the case of a user- + specified bucket name, this method will create the default bucket if it does not + exist. + + Returns: + str: Name of the default bucket. + """ + if self._default_bucket: + return self._default_bucket + default_bucket = f"amazon-braket-{self.region}-{self.account_id}" + + self._create_s3_bucket_if_it_does_not_exist(bucket_name=default_bucket, region=self.region) + + self._default_bucket = default_bucket + return self._default_bucket + + def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): + """Creates an S3 Bucket if it does not exist. 
+ Also swallows a few common exceptions that indicate that the bucket already exists or + that it is being created. + + Args: + bucket_name (str): Name of the S3 bucket to be created. + region (str): The region in which to create the bucket. + + Raises: + botocore.exceptions.ClientError: If S3 throws an unexpected exception during bucket + creation. + If the exception is due to the bucket already existing or + already being created, no exception is raised. + """ + try: + if region == "us-east-1": + # 'us-east-1' cannot be specified because it is the default region: + # https://github.com/boto/boto3/issues/125 + self.s3_client.create_bucket(Bucket=bucket_name) + else: + self.s3_client.create_bucket( + Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region} + ) + self.s3_client.put_public_access_block( + Bucket=bucket_name, + PublicAccessBlockConfiguration={ + "BlockPublicAcls": True, + "IgnorePublicAcls": True, + "BlockPublicPolicy": True, + "RestrictPublicBuckets": True, + }, + ) + self.s3_client.put_bucket_policy( + Bucket=bucket_name, + Policy=f"""{{ + "Version": "2012-10-17", + "Statement": [ + {{ + "Effect": "Allow", + "Principal": {{ + "Service": [ + "braket.amazonaws.com" + ] + }}, + "Action": "s3:*", + "Resource": [ + "arn:aws:s3:::{bucket_name}", + "arn:aws:s3:::{bucket_name}/*" + ] + }} + ] + }}""", + ) + except ClientError as e: + error_code = e.response["Error"]["Code"] + message = e.response["Error"]["Message"] + + if error_code == "BucketAlreadyOwnedByYou": + pass + elif error_code == "BucketAlreadyExists": + raise ValueError( + f"Provided default bucket '{bucket_name}' already exists " + f"for another account. Please supply alternative " + f"bucket name via AwsSession constructor `AwsSession()`." + ) from None + elif ( + error_code == "OperationAborted" and "conflicting conditional operation" in message + ): + # If this bucket is already being concurrently created, we don't need to create + # it again. + pass + else: + raise + def get_device(self, arn: str) -> Dict[str, Any]: """ Calls the Amazon Braket `get_device` API to @@ -190,3 +573,159 @@ def search_devices( continue results.append(result) return results + + @staticmethod + def is_s3_uri(string: str): + try: + AwsSession.parse_s3_uri(string) + except ValueError: + return False + return True + + @staticmethod + def parse_s3_uri(s3_uri: str) -> (str, str): + """ + Parse S3 URI to get bucket and key + + Args: + s3_uri (str): S3 URI. + + Returns: + (str, str): Bucket and Key tuple. + + Raises: + ValueError: Raises a ValueError if the provided string is not + a valid S3 URI. + """ + try: + # Object URL e.g. https://my-bucket.s3.us-west-2.amazonaws.com/my/key + # S3 URI e.g. s3://my-bucket/my/key + s3_uri_match = re.match("^https://([^./]+).[sS]3.[^/]+/(.*)$", s3_uri) or re.match( + "^[sS]3://([^./]+)/(.*)$", s3_uri + ) + assert s3_uri_match + bucket, key = s3_uri_match.groups() + assert bucket and key + return bucket, key + except (AssertionError, ValueError): + raise ValueError(f"Not a valid S3 uri: {s3_uri}") + + @staticmethod + def construct_s3_uri(bucket: str, *dirs: str): + """ + Args: + bucket (str): S3 URI. 
+ *dirs (str): directories to be appended in the resulting S3 URI + + Returns: + str: S3 URI + + Raises: + ValueError: Raises a ValueError if the provided arguments are not + valid to generate an S3 URI + """ + if not dirs: + raise ValueError(f"Not a valid S3 location: s3://{bucket}") + return f"s3://{bucket}/{'/'.join(dirs)}" + + def describe_log_streams( + self, + log_group: str, + log_stream_prefix: str, + limit: int = None, + next_token: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Describes CloudWatch log streams in a log group with a given prefix. + + Args: + log_group (str): Name of the log group. + log_stream_prefix (str): Prefix for log streams to include. + limit (int, optional): Limit for number of log streams returned. + default is 50. + next_token (optional, str): The token for the next set of items to return. + Would have been received in a previous call. + + Returns: + dict: Dicionary containing logStreams and nextToken + """ + log_stream_args = { + "logGroupName": log_group, + "logStreamNamePrefix": log_stream_prefix, + "orderBy": "LogStreamName", + } + + if limit: + log_stream_args.update({"limit": limit}) + + if next_token: + log_stream_args.update({"nextToken": next_token}) + + return self.logs_client.describe_log_streams(**log_stream_args) + + def get_log_events( + self, + log_group: str, + log_stream: str, + start_time: int, + start_from_head: bool = True, + next_token: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Gets CloudWatch log events from a given log stream. + + Args: + log_group (str): Name of the log group. + log_stream (str): Name of the log stream. + start_time (int): Timestamp that indicates a start time to include log events. + start_from_head (bool): Bool indicating to return oldest events first. default + is True. + next_token (optional, str): The token for the next set of items to return. + Would have been received in a previous call. + + Returns: + dict: Dicionary containing events, nextForwardToken, and nextBackwardToken + """ + log_events_args = { + "logGroupName": log_group, + "logStreamName": log_stream, + "startTime": start_time, + "startFromHead": start_from_head, + } + + if next_token: + log_events_args.update({"nextToken": next_token}) + + return self.logs_client.get_log_events(**log_events_args) + + def copy_session( + self, + region: Optional[str] = None, + max_connections: Optional[int] = None, + ) -> AwsSession: + """ + Creates a new AwsSession based on the region. + + Args: + region (str): Name of the region. Default = `None`. + max_connections (int): The maximum number of connections in the + Boto3 connection pool. Default = `None`. + + Returns: + AwsSession: based on the region and boto config parameters. 
+ """ + config = Config(max_pool_connections=max_connections) if max_connections else None + session_region = self.boto_session.region_name + new_region = region or session_region + creds = self.boto_session.get_credentials() + default_bucket = self._default_bucket if self._custom_default_bucket else None + if creds.method == "explicit": + boto_session = boto3.Session( + aws_access_key_id=creds.access_key, + aws_secret_access_key=creds.secret_key, + aws_session_token=creds.token, + region_name=new_region, + ) + else: + boto_session = boto3.Session(region_name=new_region) + return AwsSession(boto_session=boto_session, config=config, default_bucket=default_bucket) diff --git a/src/braket/jobs/__init__.py b/src/braket/jobs/__init__.py new file mode 100644 index 000000000..e58ba16eb --- /dev/null +++ b/src/braket/jobs/__init__.py @@ -0,0 +1,26 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from braket.jobs.config import ( # noqa: F401 + CheckpointConfig, + InstanceConfig, + OutputDataConfig, + S3DataSourceConfig, + StoppingCondition, +) +from braket.jobs.data_persistence import ( # noqa: F401 + load_job_checkpoint, + save_job_checkpoint, + save_job_result, +) +from braket.jobs.image_uris import Framework, retrieve_image # noqa: F401 diff --git a/src/braket/jobs/config.py b/src/braket/jobs/config.py new file mode 100644 index 000000000..bbe1f2cbd --- /dev/null +++ b/src/braket/jobs/config.py @@ -0,0 +1,81 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
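# Illustrative usage sketch for AwsSession.copy_session (added above). The session
# construction and the region/pool values below are assumptions for the example,
# not part of this change.
from braket.aws.aws_session import AwsSession

session = AwsSession()
# Reuse the caller's credentials and default bucket, but target another region
# with a larger boto connection pool.
east_session = session.copy_session(region="us-east-1", max_connections=50)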
+ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class CheckpointConfig: + """Configuration that specifies the location where checkpoint data is stored.""" + + localPath: str = "/opt/jobs/checkpoints" + s3Uri: Optional[str] = None + + +@dataclass +class InstanceConfig: + """Configuration of the instances used to execute the job.""" + + instanceType: str = "ml.m5.large" + volumeSizeInGb: int = 30 + + +@dataclass +class OutputDataConfig: + """Configuration that specifies the location for the output of the job.""" + + s3Path: Optional[str] = None + kmsKeyId = None + + +@dataclass +class StoppingCondition: + """Conditions that specify when the job should be forcefully stopped.""" + + maxRuntimeInSeconds: int = 5 * 24 * 60 * 60 + + +@dataclass +class DeviceConfig: + device: str + + +class S3DataSourceConfig: + """ + Data source for data that lives on S3 + Attributes: + config (dict[str, dict]): config passed to the Braket API + """ + + def __init__( + self, + s3_data, + content_type=None, + ): + """Create a definition for input data used by a Braket job. + + Args: + s3_data (str): Defines the location of s3 data to train on. + content_type (str): MIME type of the input data (default: None). + """ + self.config = { + "dataSource": { + "s3DataSource": { + "s3Uri": s3_data, + } + } + } + + if content_type is not None: + self.config["contentType"] = content_type diff --git a/src/braket/jobs/data_persistence.py b/src/braket/jobs/data_persistence.py new file mode 100644 index 000000000..f969a5984 --- /dev/null +++ b/src/braket/jobs/data_persistence.py @@ -0,0 +1,136 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import os +from typing import Any, Dict + +from braket.jobs.serialization import deserialize_values, serialize_values +from braket.jobs_data import PersistedJobData, PersistedJobDataFormat + + +def save_job_checkpoint( + checkpoint_data: Dict[str, Any], + checkpoint_file_suffix: str = "", + data_format: PersistedJobDataFormat = PersistedJobDataFormat.PLAINTEXT, +) -> None: + """ + Saves the specified `checkpoint_data` to the local output directory, specified by the container + environment variable `CHECKPOINT_DIR`, with the filename + `f"{job_name}(_{checkpoint_file_suffix}).json"`. The `job_name` refers to the name of the + current job and is retrieved from the container environment variable `JOB_NAME`. The + `checkpoint_data` values are serialized to the specified `data_format`. + + Note: This function for storing the checkpoints is only for use inside the job container + as it writes data to directories and references env variables set in the containers. + + + Args: + checkpoint_data (Dict[str, Any]): Dict that specifies the checkpoint data to be persisted. + checkpoint_file_suffix (str): str that specifies the file suffix to be used for + the checkpoint filename. The resulting filename + `f"{job_name}(_{checkpoint_file_suffix}).json"` is used to save the checkpoints. 
+ Default: "" + data_format (PersistedJobDataFormat): The data format used to serialize the + values. Note that for `PICKLED` data formats, the values are base64 encoded + after serialization. Default: PersistedJobDataFormat.PLAINTEXT + + Raises: + ValueError: If the supplied `checkpoint_data` is `None` or empty. + """ + if not checkpoint_data: + raise ValueError("The checkpoint_data argument cannot be empty.") + checkpoint_directory = os.environ["AMZN_BRAKET_CHECKPOINT_DIR"] + job_name = os.environ["AMZN_BRAKET_JOB_NAME"] + checkpoint_file_path = ( + f"{checkpoint_directory}/{job_name}_{checkpoint_file_suffix}.json" + if checkpoint_file_suffix + else f"{checkpoint_directory}/{job_name}.json" + ) + with open(checkpoint_file_path, "w") as f: + serialized_data = serialize_values(checkpoint_data or {}, data_format) + persisted_data = PersistedJobData(dataDictionary=serialized_data, dataFormat=data_format) + f.write(persisted_data.json()) + + +def load_job_checkpoint(job_name: str, checkpoint_file_suffix: str = "") -> Dict[str, Any]: + """ + Loads the job checkpoint data stored for the job named 'job_name', with the checkpoint + file that ends with the `checkpoint_file_suffix`. The `job_name` can refer to any job whose + checkpoint data you expect to be available in the file path specified by the `CHECKPOINT_DIR` + container environment variable. + + Note: This function for loading job checkpoints is only for use inside the job container + as it writes data to directories and references env variables set in the containers. + + + Args: + job_name (str): str that specifies the name of the job whose checkpoints + are to be loaded. + checkpoint_file_suffix (str): str specifying the file suffix that is used to + locate the checkpoint file to load. The resulting file name + `f"{job_name}(_{checkpoint_file_suffix}).json"` is used to locate the + checkpoint file. Default: "" + + Returns: + Dict[str, Any]: Dict that contains the checkpoint data persisted in the checkpoint file. + + Raises: + FileNotFoundError: If the file `f"{job_name}(_{checkpoint_file_suffix})"` could not be found + in the directory specified by the container environment variable `CHECKPOINT_DIR`. + ValueError: If the data stored in the checkpoint file can't be deserialized (possibly due to + corruption). + """ + checkpoint_directory = os.environ["AMZN_BRAKET_CHECKPOINT_DIR"] + checkpoint_file_path = ( + f"{checkpoint_directory}/{job_name}_{checkpoint_file_suffix}.json" + if checkpoint_file_suffix + else f"{checkpoint_directory}/{job_name}.json" + ) + with open(checkpoint_file_path, "r") as f: + persisted_data = PersistedJobData.parse_raw(f.read()) + deserialized_data = deserialize_values( + persisted_data.dataDictionary, persisted_data.dataFormat + ) + return deserialized_data + + +def save_job_result( + result_data: Dict[str, Any], + data_format: PersistedJobDataFormat = PersistedJobDataFormat.PLAINTEXT, +) -> None: + """ + Saves the `result_data` to the local output directory that is specified by the container + environment variable `OUTPUT_DIR`, with the filename 'results.json'. The `result_data` + values are serialized to the specified `data_format`. + + Note: This function for storing the results is only for use inside the job container + as it writes data to directories and references env variables set in the containers. + + + Args: + result_data (Dict[str, Any]): Dict that specifies the result data to be persisted. + data_format (PersistedJobDataFormat): The data format used to serialize the + values. 
Note that for `PICKLED` data formats, the values are base64 encoded + after serialization. Default: PersistedJobDataFormat.PLAINTEXT. + + Raises: + ValueError: If the supplied `result_data` is `None` or empty. + """ + if not result_data: + raise ValueError("The result_data argument cannot be empty.") + result_directory = os.environ["AMZN_BRAKET_JOB_RESULTS_DIR"] + result_path = f"{result_directory}/results.json" + with open(result_path, "w") as f: + serialized_data = serialize_values(result_data or {}, data_format) + persisted_data = PersistedJobData(dataDictionary=serialized_data, dataFormat=data_format) + f.write(persisted_data.json()) diff --git a/src/braket/jobs/image_uri_config/base.json b/src/braket/jobs/image_uri_config/base.json new file mode 100644 index 000000000..b941f5dc9 --- /dev/null +++ b/src/braket/jobs/image_uri_config/base.json @@ -0,0 +1,12 @@ +{ + "versions": { + "1.0": { + "registries": { + "us-east-1": "292282985366", + "us-west-1": "292282985366", + "us-west-2": "292282985366" + }, + "repository": "amazon-braket-base-jobs" + } + } +} diff --git a/src/braket/jobs/image_uri_config/pl_pytorch.json b/src/braket/jobs/image_uri_config/pl_pytorch.json new file mode 100644 index 000000000..d5b0943fa --- /dev/null +++ b/src/braket/jobs/image_uri_config/pl_pytorch.json @@ -0,0 +1,12 @@ +{ + "versions": { + "1.8.1": { + "registries": { + "us-east-1": "292282985366", + "us-west-1": "292282985366", + "us-west-2": "292282985366" + }, + "repository": "amazon-braket-pytorch-jobs" + } + } +} diff --git a/src/braket/jobs/image_uri_config/pl_tensorflow.json b/src/braket/jobs/image_uri_config/pl_tensorflow.json new file mode 100644 index 000000000..758b2107e --- /dev/null +++ b/src/braket/jobs/image_uri_config/pl_tensorflow.json @@ -0,0 +1,12 @@ +{ + "versions": { + "2.4.1": { + "registries": { + "us-east-1": "292282985366", + "us-west-1": "292282985366", + "us-west-2": "292282985366" + }, + "repository": "amazon-braket-tensorflow-jobs" + } + } +} diff --git a/src/braket/jobs/image_uris.py b/src/braket/jobs/image_uris.py new file mode 100644 index 000000000..0a9f27bbd --- /dev/null +++ b/src/braket/jobs/image_uris.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import json +import os +from enum import Enum +from typing import Dict + + +class Framework(str, Enum): + """Supported Frameworks for pre-built containers""" + + BASE = "BASE" + PL_TENSORFLOW = "PL_TENSORFLOW" + PL_PYTORCH = "PL_PYTORCH" + + +def retrieve_image(framework: Framework, region: str): + """Retrieves the ECR URI for the Docker image matching the specified arguments. + + Args: + framework (str): The name of the framework. + region (str): The AWS region for the Docker image. + + Returns: + str: The ECR URI for the corresponding Amazon Braket Docker image. + + Raises: + ValueError: If any of the supplied values are invalid or the combination of inputs + specified is not supported. 
+ """ + # Validate framework + framework = Framework(framework) + config = _config_for_framework(framework) + framework_version = max(version for version in config["versions"]) + version_config = config["versions"][framework_version] + registry = _registry_for_region(version_config, region) + tag = f"{version_config['repository']}:{framework_version}-cpu-py37-ubuntu18.04" + return f"{registry}.dkr.ecr.{region}.amazonaws.com/{tag}" + + +def _config_for_framework(framework: Framework) -> Dict[str, str]: + """Loads the JSON config for the given framework. + + Args: + framework (Framework): The framework whose config needs to be loaded. + + Returns: + Dict[str, str]: Dict that contains the configuration for the specified framework. + """ + fname = os.path.join(os.path.dirname(__file__), "image_uri_config", f"{framework.lower()}.json") + with open(fname) as f: + return json.load(f) + + +def _registry_for_region(config: Dict[str, str], region: str) -> str: + """Retrieves the registry for the specified region from the configuration. + + Args: + config (Dict[str, str]): Dict containing the framework configuration. + region (str): str that specifies the region for which the registry is retrieved. + + Returns: + str: str that specifies the registry for the supplied region. + + Raises: + ValueError: If the supplied region is invalid or not supported. + """ + registry_config = config["registries"] + if region not in registry_config: + raise ValueError( + f"Unsupported region: {region}. You may need to upgrade your SDK version for newer " + f"regions. Supported region(s): {list(registry_config.keys())}" + ) + return registry_config[region] diff --git a/src/braket/jobs/local/__init__.py b/src/braket/jobs/local/__init__.py new file mode 100644 index 000000000..6fe353f7c --- /dev/null +++ b/src/braket/jobs/local/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from braket.jobs.local.local_job import LocalQuantumJob # noqa: F401 diff --git a/src/braket/jobs/local/local_job.py b/src/braket/jobs/local/local_job.py new file mode 100644 index 000000000..8b7394261 --- /dev/null +++ b/src/braket/jobs/local/local_job.py @@ -0,0 +1,258 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
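# Illustrative usage sketch for braket.jobs.image_uris.retrieve_image (added above),
# assuming a supported region such as "us-west-2".
from braket.jobs.image_uris import Framework, retrieve_image

image_uri = retrieve_image(Framework.BASE, "us-west-2")
# Given the base.json config above, this resolves to something like:
# 292282985366.dkr.ecr.us-west-2.amazonaws.com/amazon-braket-base-jobs:1.0-cpu-py37-ubuntu18.04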
+ +from __future__ import annotations + +import os +import time +from typing import Any, Dict, List, Union + +from braket.aws.aws_session import AwsSession +from braket.jobs.config import CheckpointConfig, OutputDataConfig, S3DataSourceConfig +from braket.jobs.image_uris import Framework, retrieve_image +from braket.jobs.local.local_job_container import _LocalJobContainer +from braket.jobs.local.local_job_container_setup import setup_container +from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType +from braket.jobs.metrics_data.log_metrics_parser import LogMetricsParser +from braket.jobs.quantum_job import QuantumJob +from braket.jobs.quantum_job_creation import prepare_quantum_job +from braket.jobs.serialization import deserialize_values +from braket.jobs_data import PersistedJobData + + +class LocalQuantumJob(QuantumJob): + """Amazon Braket implementation of a quantum job that runs locally.""" + + @classmethod + def create( + cls, + device: str, + source_module: str, + entry_point: str = None, + image_uri: str = None, + job_name: str = None, + code_location: str = None, + role_arn: str = None, + hyperparameters: Dict[str, Any] = None, + input_data: Union[str, Dict, S3DataSourceConfig] = None, + output_data_config: OutputDataConfig = None, + checkpoint_config: CheckpointConfig = None, + aws_session: AwsSession = None, + ) -> LocalQuantumJob: + """Creates and runs job by setting up and running the customer script in a local + docker container. + + Args: + device (str): ARN for the AWS device which is primarily + accessed for the execution of this job. + + source_module (str): Path (absolute, relative or an S3 URI) to a python module to be + tarred and uploaded. If `source_module` is an S3 URI, it must point to a + tar.gz file. Otherwise, source_module may be a file or directory. + + entry_point (str): A str that specifies the entry point of the job, relative to + the source module. The entry point must be in the format + `importable.module` or `importable.module:callable`. For example, + `source_module.submodule:start_here` indicates the `start_here` function + contained in `source_module.submodule`. If source_module is an S3 URI, + entry point must be given. Default: source_module's name + + image_uri (str): A str that specifies the ECR image to use for executing the job. + `image_uris.retrieve_image()` function may be used for retrieving the ECR image URIs + for the containers supported by Braket. Default = ``. + + job_name (str): A str that specifies the name with which the job is created. + Default: f'{image_uri_type}-{timestamp}'. + + code_location (str): The S3 prefix URI where custom code will be uploaded. + Default: f's3://{default_bucket_name}/jobs/{job_name}/script'. + + role_arn (str): This field is currently not used for local jobs. Local jobs will use + the current role's credentials. This may be subject to change. + + hyperparameters (Dict[str, Any]): Hyperparameters accessible to the job. + The hyperparameters are made accessible as a Dict[str, str] to the job. + For convenience, this accepts other types for keys and values, but `str()` + is called to convert them before being passed on. Default: None. + + input_data (Union[str, S3DataSourceConfig, dict]): Information about the training + data. Dictionary maps channel names to local paths or S3 URIs. Contents found + at any local paths will be uploaded to S3 at + f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}. 
If a local + path, S3 URI, or S3DataSourceConfig is provided, it will be given a default + channel name "input". + Default: {}. + + output_data_config (OutputDataConfig): Specifies the location for the output of the job. + Default: OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data', + kmsKeyId=None). + + checkpoint_config (CheckpointConfig): Configuration that specifies the location where + checkpoint data is stored. + Default: CheckpointConfig(localPath='/opt/jobs/checkpoints', + s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints'). + + aws_session (AwsSession): AwsSession for connecting to AWS Services. + Default: AwsSession() + + Returns: + LocalQuantumJob: The representation of a local Braket Job. + """ + create_job_kwargs = prepare_quantum_job( + device=device, + source_module=source_module, + entry_point=entry_point, + image_uri=image_uri, + job_name=job_name, + code_location=code_location, + role_arn=role_arn, + hyperparameters=hyperparameters, + input_data=input_data, + output_data_config=output_data_config, + checkpoint_config=checkpoint_config, + aws_session=aws_session, + ) + + job_name = create_job_kwargs["jobName"] + if os.path.isdir(job_name): + raise ValueError( + f"A local directory called {job_name} already exists. " + f"Please use a different job name." + ) + + session = aws_session or AwsSession() + algorithm_specification = create_job_kwargs["algorithmSpecification"] + if "containerImage" in algorithm_specification: + image_uri = algorithm_specification["containerImage"]["uri"] + else: + image_uri = retrieve_image(Framework.BASE, session.region) + + with _LocalJobContainer(image_uri) as container: + env_variables = setup_container(container, session, **create_job_kwargs) + container.run_local_job(env_variables) + container.copy_from("/opt/ml/model", job_name) + with open(os.path.join(job_name, "log.txt"), "w") as log_file: + log_file.write(container.run_log) + if "checkpointConfig" in create_job_kwargs: + checkpoint_config = create_job_kwargs["checkpointConfig"] + if "localPath" in checkpoint_config: + checkpoint_path = checkpoint_config["localPath"] + container.copy_from(checkpoint_path, os.path.join(job_name, "checkpoints")) + run_log = container.run_log + return LocalQuantumJob(f"local:job/{job_name}", run_log) + + def __init__(self, arn: str, run_log: str = None): + """ + Args: + arn (str): The ARN of the job. + run_log (str, Optional): The container output log of running the job with the given arn. 
+ """ + if not arn.startswith("local:job/"): + raise ValueError(f"Arn {arn} is not a valid local job arn") + self._arn = arn + self._run_log = run_log + self._name = arn.partition("job/")[-1] + if not run_log and not os.path.isdir(self.name): + raise ValueError(f"Unable to find local job results for {self.name}") + + @property + def arn(self) -> str: + """str: The ARN (Amazon Resource Name) of the quantum job.""" + return self._arn + + @property + def name(self) -> str: + """str: The name of the quantum job.""" + return self._name + + @property + def run_log(self) -> str: + """str: The container output log from running the job.""" + if not self._run_log: + try: + with open(os.path.join(self.name, "log.txt"), "r") as log_file: + self._run_log = log_file.read() + except FileNotFoundError: + raise ValueError(f"Unable to find logs in the local job directory {self.name}.") + return self._run_log + + def state(self, use_cached_value: bool = False) -> str: + """The state of the quantum job.""" + return "COMPLETED" + + def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]: + """When running the quantum job in local mode, the metadata is not available.""" + pass + + def cancel(self) -> str: + """When running the quantum job in local mode, the cancelling a running is not possible.""" + pass + + def download_result( + self, + extract_to=None, + poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT, + poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL, + ) -> None: + """When running the quantum job in local mode, results are automatically stored locally.""" + pass + + def result( + self, + poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT, + poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL, + ) -> Dict[str, Any]: + """Retrieves the job result persisted using save_job_result() function.""" + try: + with open(os.path.join(self.name, "results.json"), "r") as f: + persisted_data = PersistedJobData.parse_raw(f.read()) + deserialized_data = deserialize_values( + persisted_data.dataDictionary, persisted_data.dataFormat + ) + return deserialized_data + except FileNotFoundError: + raise ValueError(f"Unable to find results in the local job directory {self.name}.") + + def metrics( + self, + metric_type: MetricType = MetricType.TIMESTAMP, + statistic: MetricStatistic = MetricStatistic.MAX, + ) -> Dict[str, List[Any]]: + """Gets all the metrics data, where the keys are the column names, and the values are a list + containing the values in each row. For example, the table: + timestamp energy + 0 0.1 + 1 0.2 + would be represented as: + { "timestamp" : [0, 1], "energy" : [0.1, 0.2] } + values may be integers, floats, strings or None. + + Args: + metric_type (MetricType): The type of metrics to get. Default: MetricType.TIMESTAMP. + + statistic (MetricStatistic): The statistic to determine which metric value to use + when there is a conflict. Default: MetricStatistic.MAX. + + Returns: + Dict[str, List[Union[str, float, int]]] : The metrics data. 
+ """ + parser = LogMetricsParser() + current_time = str(time.time()) + for line in self.run_log.splitlines(): + if line.startswith("Metrics -"): + parser.parse_log_message(current_time, line) + return parser.get_parsed_metrics(metric_type, statistic) + + def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: + """Display container logs for a given job""" + return print(self.run_log) diff --git a/src/braket/jobs/local/local_job_container.py b/src/braket/jobs/local/local_job_container.py new file mode 100644 index 000000000..fc76e8d2b --- /dev/null +++ b/src/braket/jobs/local/local_job_container.py @@ -0,0 +1,245 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import base64 +import re +import subprocess +from logging import Logger, getLogger +from pathlib import Path +from typing import Dict, List + +from braket.aws.aws_session import AwsSession + + +class _LocalJobContainer(object): + """Uses docker CLI to run Braket Jobs on a local docker container.""" + + ECR_URI_PATTERN = r"^((\d+)\.dkr\.ecr\.([^.]+)\.[^/]*)/([^:]*):(.*)$" + CONTAINER_CODE_PATH = "/opt/ml/code/" + + def __init__( + self, image_uri: str, aws_session: AwsSession = None, logger: Logger = getLogger(__name__) + ): + """Represents and provides functions for interacting with a Braket Jobs docker container. + + The function "end_session" must be called when the container is no longer needed. + Args: + image_uri (str): The URI of the container image to run. + aws_session (AwsSession, Optional): AwsSession for connecting to AWS Services. + Default: AwsSession() + logger (Logger): Logger object with which to write logs. + Default: `getLogger(__name__)` + """ + self._aws_session = aws_session or AwsSession() + self.image_uri = image_uri + self.run_log = None + self._container_name = None + self._logger = logger + + def __enter__(self): + """Creates and starts the local docker container.""" + self._container_name = self._start_container(self.image_uri) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stops and removes the local docker container.""" + self._end_session() + + @staticmethod + def _envs_to_list(environment_variables: Dict[str, str]) -> List[str]: + """Converts a dictionary environment variables to a list of parameters that can be + passed to the container exec/run commands to ensure those env variables are available + in the container. + + Args: + environment_variables (Dict[str, str]): A dictionary of environment variables and + their values. + Returns: + List[str]: The list of parameters to use when running a job that will include the + provided environment variables as part of the runtime. + """ + env_list = [] + for key in environment_variables: + env_list.append("-e") + env_list.append(f"{key}={environment_variables[key]}") + return env_list + + @staticmethod + def _check_output_formatted(command: List[str]) -> str: + """This is a wrapper around the subprocess.check_output command that decodes the output + to UTF-8 encoding. 
+ + Args: + command(List[str]): The command to run. + + Returns: + (str): The UTF-8 encoded output of running the command. + """ + output = subprocess.check_output(command) + return output.decode("utf-8").strip() + + def _login_to_ecr(self, account_id: str, ecr_url: str) -> None: + """Logs in docker to an ECR repository using the client AWS credentials. + + Args: + account_id(str): The customer account ID. + ecr_url(str): The URL of the ECR repo to log into. + """ + ecr_client = self._aws_session.ecr_client + authorization_data_result = ecr_client.get_authorization_token(registryIds=[account_id]) + if not authorization_data_result: + raise ValueError( + "Unable to get permissions to access to log in to docker. " + "Please pull down the container before proceeding." + ) + authorization_data = authorization_data_result["authorizationData"][0] + raw_token = base64.b64decode(authorization_data["authorizationToken"]) + token = raw_token.decode("utf-8").strip("AWS:") + subprocess.run(["docker", "login", "-u", "AWS", "-p", token, ecr_url]) + + def _pull_image(self, image_uri: str) -> None: + """Pulls an image from ECR. + + Args: + image_uri(str): The URI of the ECR image to pull. + """ + ecr_pattern = re.compile(self.ECR_URI_PATTERN) + ecr_pattern_match = ecr_pattern.match(image_uri) + if not ecr_pattern_match: + raise ValueError( + f"The URL {image_uri} is not available locally and does not seem to " + f"be a valid AWS ECR URL." + "Please pull down the container, or specify a valid ECR URL, " + "before proceeding." + ) + ecr_url = ecr_pattern_match.group(1) + account_id = ecr_pattern_match.group(2) + self._login_to_ecr(account_id, ecr_url) + self._logger.warning("Pulling docker container image. This may take a while.") + subprocess.run(["docker", "pull", image_uri]) + + def _start_container(self, image_uri: str) -> str: + """Runs a docker container in a busy loop so that it will accept further commands. The + call to this function must be matched with end_session to stop the container. + + Args: + image_uri(str): The URI of the ECR image to run. + + Returns: + (str): The name of the running container, which can be used to execute further commands. + """ + image_name = self._check_output_formatted(["docker", "images", "-q", image_uri]) + if not image_name: + self._pull_image(image_uri) + image_name = self._check_output_formatted(["docker", "images", "-q", image_uri]) + if not image_name: + raise ValueError( + f"The URL {image_uri} is not available locally and can not be pulled from ECR." + " Please pull down the container before proceeding." + ) + return self._check_output_formatted( + ["docker", "run", "-d", "--rm", image_name, "tail", "-f", "/dev/null"] + ) + + def makedir(self, dir_path: str) -> None: + """Creates a directory path in the container. + + Args: + dir_path(str): The directory path to create. + + Raises: + subprocess.CalledProcessError: If unable to make the directory. + """ + try: + subprocess.check_output( + ["docker", "exec", self._container_name, "mkdir", "-p", dir_path] + ) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8").strip() + self._logger.error(output) + raise e + + def copy_to(self, source: str, destination: str) -> None: + """Copies a local file or directory to the container. + + Args: + source(str): The local file or directory to copy. + destination(str): The path to the file or directory where the source should be copied. + + Raises: + subprocess.CalledProcessError: If unable to copy. 
+ """ + dirname = str(Path(destination).parent) + try: + subprocess.check_output( + ["docker", "exec", self._container_name, "mkdir", "-p", dirname] + ) + subprocess.check_output( + ["docker", "cp", source, f"{self._container_name}:{destination}"] + ) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8").strip() + self._logger.error(output) + raise e + + def copy_from(self, source: str, destination: str) -> None: + """Copies a file or directory from the container locally. + + Args: + source(str): The container file or directory to copy. + destination(str): The path to the file or directory where the source should be copied. + + Raises: + subprocess.CalledProcessError: If unable to copy. + """ + try: + subprocess.check_output( + ["docker", "cp", f"{self._container_name}:{source}", destination] + ) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8").strip() + self._logger.error(output) + raise e + + def run_local_job(self, environment_variables: Dict[str, str]) -> None: + """Runs a Braket job in a local container. + + Args: + environment_variables (Dict[str, str]): The environment variables to make available + as part of running the job. + """ + start_program_name = self._check_output_formatted( + ["docker", "exec", self._container_name, "printenv", "SAGEMAKER_PROGRAM"] + ) + if not start_program_name: + raise ValueError( + "Start program not found. " + "The specified container is not setup to run Braket Jobs. " + "Please see setup instructions for creating your own containers." + ) + + command = ["docker", "exec", "-w", self.CONTAINER_CODE_PATH] + command.extend(self._envs_to_list(environment_variables)) + command.append(self._container_name) + command.append("python") + command.append(start_program_name) + + try: + self.run_log = self._check_output_formatted(command) + print(self.run_log) + except subprocess.CalledProcessError as e: + self.run_log = e.output.decode("utf-8").strip() + self._logger.error(self.run_log) + + def _end_session(self): + """Stops and removes the local container.""" + subprocess.run(["docker", "stop", self._container_name]) diff --git a/src/braket/jobs/local/local_job_container_setup.py b/src/braket/jobs/local/local_job_container_setup.py new file mode 100644 index 000000000..0f7eac021 --- /dev/null +++ b/src/braket/jobs/local/local_job_container_setup.py @@ -0,0 +1,274 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import json +import tempfile +from logging import Logger, getLogger +from pathlib import Path +from typing import Any, Dict, Iterable + +from braket.aws.aws_session import AwsSession +from braket.jobs.local.local_job_container import _LocalJobContainer + + +def setup_container( + container: _LocalJobContainer, aws_session: AwsSession, **creation_kwargs +) -> Dict[str, str]: + """Sets up a container with prerequisites for running a Braket Job. The prerequisites are + based on the options the customer has chosen for the job. 
Similarly, any environment variables + that are needed during runtime will be returned by this function. + + Args: + container(_LocalJobContainer): The container that will run the braket job. + aws_session (AwsSession): AwsSession for connecting to AWS Services. + **creation_kwargs: Keyword arguments for the boto3 Amazon Braket `CreateJob` operation. + + Returns: + (Dict[str, str]): A dictionary of environment variables that reflect Braket Jobs options + requested by the customer. + """ + logger = getLogger(__name__) + _create_expected_paths(container, **creation_kwargs) + run_environment_variables = {} + run_environment_variables.update(_get_env_credentials(aws_session, logger)) + run_environment_variables.update( + _get_env_script_mode_config(creation_kwargs["algorithmSpecification"]["scriptModeConfig"]) + ) + run_environment_variables.update(_get_env_additional_lib()) + run_environment_variables.update(_get_env_default_vars(aws_session, **creation_kwargs)) + if _copy_hyperparameters(container, **creation_kwargs): + run_environment_variables.update(_get_env_hyperparameters()) + if _copy_input_data_list(container, aws_session, **creation_kwargs): + run_environment_variables.update(_get_env_input_data()) + return run_environment_variables + + +def _create_expected_paths(container: _LocalJobContainer, **creation_kwargs) -> None: + """Creates the basic paths required for Braket Jobs to run. + + Args: + container(_LocalJobContainer): The container that will run the braket job. + **creation_kwargs: Keyword arguments for the boto3 Amazon Braket `CreateJob` operation. + """ + container.makedir("/opt/ml/model") + container.makedir(creation_kwargs["checkpointConfig"]["localPath"]) + + +def _get_env_credentials(aws_session: AwsSession, logger: Logger) -> Dict[str, str]: + """Gets the account credentials from boto so they can be added as environment variables to + the running container. + + Args: + aws_session (AwsSession): AwsSession for connecting to AWS Services. + logger (Logger): Logger object with which to write logs. Default is `getLogger(__name__)` + + Returns: + (Dict[str, str]): The set of key/value pairs that should be added as environment variables + to the running container. + """ + credentials = aws_session.boto_session.get_credentials() + if credentials.token is None: + logger.info("Using the long-lived AWS credentials found in session") + return { + "AWS_ACCESS_KEY_ID": str(credentials.access_key), + "AWS_SECRET_ACCESS_KEY": str(credentials.secret_key), + } + logger.warning( + "Using the short-lived AWS credentials found in session. They might expire while running." + ) + return { + "AWS_ACCESS_KEY_ID": str(credentials.access_key), + "AWS_SECRET_ACCESS_KEY": str(credentials.secret_key), + "AWS_SESSION_TOKEN": str(credentials.token), + } + + +def _get_env_script_mode_config(script_mode_config: Dict[str, str]) -> Dict[str, str]: + """Gets the environment variables related to the customer script mode config. + + Args: + script_mode_config (Dict[str, str]): The values for scriptModeConfig in the boto3 input + parameters for running a Braket Job. + + Returns: + (Dict[str, str]): The set of key/value pairs that should be added as environment variables + to the running container. 
+ """ + result = { + "AMZN_BRAKET_SCRIPT_S3_URI": script_mode_config["s3Uri"], + "AMZN_BRAKET_SCRIPT_ENTRY_POINT": script_mode_config["entryPoint"], + } + if "compressionType" in script_mode_config: + result["AMZN_BRAKET_SCRIPT_COMPRESSION_TYPE"] = script_mode_config["compressionType"] + return result + + +def _get_env_additional_lib() -> Dict[str, str]: + """For preview, we have some libraries that are not available publicly (yet). The container + will install these libraries if we set this env variable. + + Returns: + (Dict[str, str]): The set of key/value pairs that should be added as environment variables + to the running container. + """ + return { + "AMZN_BRAKET_IMAGE_SETUP_SCRIPT": "s3://amazon-braket-external-assets-preview-us-west-2/" + "HybridJobsAccess/scripts/setup-container.sh", + } + + +def _get_env_default_vars(aws_session: AwsSession, **creation_kwargs) -> Dict[str, str]: + """This function gets the remaining 'simple' env variables, that don't require any + additional logic to determine what they are or when they should be added as env variables. + + Returns: + (Dict[str, str]): The set of key/value pairs that should be added as environment variables + to the running container. + """ + job_name = creation_kwargs["jobName"] + bucket, location = AwsSession.parse_s3_uri(creation_kwargs["outputDataConfig"]["s3Path"]) + return { + "AWS_DEFAULT_REGION": aws_session.region, + "AMZN_BRAKET_JOB_NAME": job_name, + "AMZN_BRAKET_DEVICE_ARN": creation_kwargs["deviceConfig"]["device"], + "AMZN_BRAKET_JOB_RESULTS_DIR": "/opt/braket/model", + "AMZN_BRAKET_CHECKPOINT_DIR": creation_kwargs["checkpointConfig"]["localPath"], + "AMZN_BRAKET_OUT_S3_BUCKET": bucket, + "AMZN_BRAKET_TASK_RESULTS_S3_URI": f"s3://{bucket}/jobs/{job_name}/tasks", + "AMZN_BRAKET_JOB_RESULTS_S3_PATH": str(Path(location, job_name, "output").as_posix()), + } + + +def _get_env_hyperparameters() -> Dict[str, str]: + """Gets the env variable for hyperparameters. This should only be added if the customer has + provided hyperpameters to the job. + + Returns: + (Dict[str, str]): The set of key/value pairs that should be added as environment variables + to the running container. + """ + return { + "AMZN_BRAKET_HP_FILE": "/opt/braket/input/config/hyperparameters.json", + } + + +def _get_env_input_data() -> Dict[str, str]: + """Gets the env variable for input data. This should only be added if the customer has + provided input data to the job. + + Returns: + (Dict[str, str]): The set of key/value pairs that should be added as environment variables + to the running container. + """ + return { + "AMZN_BRAKET_INPUT_DIR": "/opt/braket/input/data", + } + + +def _copy_hyperparameters(container: _LocalJobContainer, **creation_kwargs) -> bool: + """If hyperpameters are present, this function will store them as a JSON object in the + container in the appropriate location on disk. + + Args: + container(_LocalJobContainer): The container to save hyperparameters to. + **creation_kwargs: Keyword arguments for the boto3 Amazon Braket `CreateJob` operation. + + Returns: + (bool): True if any hyperparameters were copied to the container. 
+ """ + if "hyperParameters" not in creation_kwargs: + return False + hyperparameters = creation_kwargs["hyperParameters"] + with tempfile.TemporaryDirectory() as temp_dir: + file_path = Path(temp_dir, "hyperparameters.json") + with open(file_path, "w") as write_file: + json.dump(hyperparameters, write_file) + container.copy_to(str(file_path), "/opt/ml/input/config/hyperparameters.json") + return True + + +def _download_input_data( + aws_session: AwsSession, + download_dir: str, + input_data: Dict[str, Any], +) -> None: + """Downloads input data for a job. + + Args: + aws_session (AwsSession): AwsSession for connecting to AWS Services. + download_dir (str): The directory path to download to. + input_data (Dict[str, Any]): One of the input data in the boto3 input parameters for + running a Braket Job. + """ + # If s3 prefix is the full name of a directory and all keys are inside + # that directory, the contents of said directory will be copied into a + # directory with the same name as the channel. This behavior is the same + # whether or not s3 prefix ends with a "/". Moreover, if s3 prefix ends + # with a "/", this is certainly the behavior to expect, since it can only + # match a directory. + # If s3 prefix matches any files exactly, or matches as a prefix of any + # files or directories, then all files and directories matching s3 prefix + # will be copied into a directory with the same name as the channel. + channel_name = input_data["channelName"] + s3_uri_prefix = input_data["dataSource"]["s3DataSource"]["s3Uri"] + bucket, prefix = AwsSession.parse_s3_uri(s3_uri_prefix) + s3_keys = aws_session.list_keys(bucket, prefix) + top_level = prefix if _is_dir(prefix, s3_keys) else str(Path(prefix).parent) + found_item = False + try: + Path(download_dir, channel_name).mkdir() + except FileExistsError: + raise ValueError(f"Duplicate channel names not allowed for input data: {channel_name}") + for s3_key in s3_keys: + relative_key = Path(s3_key).relative_to(top_level) + download_path = Path(download_dir, channel_name, relative_key) + if not s3_key.endswith("/"): + download_path.parent.mkdir(parents=True, exist_ok=True) + aws_session.download_from_s3( + AwsSession.construct_s3_uri(bucket, s3_key), str(download_path) + ) + found_item = True + if not found_item: + raise RuntimeError(f"No data found for channel '{channel_name}'") + + +def _is_dir(prefix: str, keys: Iterable[str]) -> bool: + """determine whether the prefix refers to a directory""" + if prefix.endswith("/"): + return True + return all(key.startswith(f"{prefix}/") for key in keys) + + +def _copy_input_data_list( + container: _LocalJobContainer, aws_session: AwsSession, **creation_kwargs +) -> bool: + """If the input data list is not empty, this function will download the input files and + store them in the container. + + Args: + container(_LocalJobContainer): The container to save input data to. + aws_session (AwsSession): AwsSession for connecting to AWS Services. + **creation_kwargs: Keyword arguments for the boto3 Amazon Braket `CreateJob` operation. + + Returns: + (bool): True if any input data was copied to the container. 
+ """ + if "inputDataConfig" not in creation_kwargs: + return False + + input_data_list = creation_kwargs["inputDataConfig"] + with tempfile.TemporaryDirectory() as temp_dir: + for input_data in input_data_list: + _download_input_data(aws_session, temp_dir, input_data) + container.copy_to(temp_dir, "/opt/ml/input/data/") + return bool(input_data_list) diff --git a/src/braket/jobs/logs.py b/src/braket/jobs/logs.py new file mode 100644 index 000000000..7f5ff65df --- /dev/null +++ b/src/braket/jobs/logs.py @@ -0,0 +1,231 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import collections +import os +import sys + +############################################################################## +# +# Support for reading logs +# +############################################################################## +from typing import Dict, List + +from botocore.exceptions import ClientError + + +class ColorWrap(object): + """A callable that prints text in a different color depending on the instance. + Up to 5 if the standard output is a terminal or a Jupyter notebook cell. + """ + + # For what color each number represents, see + # https://misc.flogisoft.com/bash/tip_colors_and_formatting#colors + _stream_colors = [34, 35, 32, 36, 33] + + def __init__(self, force=False): + """Initialize the class. + + Args: + force (bool): If True, the render output is colorized wherever the + output is. Default: False. + """ + self.colorize = force or sys.stdout.isatty() or os.environ.get("JPY_PARENT_PID", None) + + def __call__(self, index, s): + """Prints the string, colorized or not, depending on the environment. + + Args: + index (int): The instance number. + s (str): The string to print. + """ + if self.colorize: + self._color_wrap(index, s) + else: + print(s) + + def _color_wrap(self, index, s): + """Prints the string in a color determined by the index. + + Args: + index (int): The instance number. + s (str): The string to print (color-wrapped). + """ + print(f"\x1b[{self._stream_colors[index % len(self._stream_colors)]}m{s}\x1b[0m") + + +# Position is a tuple that includes the last read timestamp and the number of items that were read +# at that time. This is used to figure out which event to start with on the next read. +Position = collections.namedtuple("Position", ["timestamp", "skip"]) + + +def multi_stream_iter(aws_session, log_group, streams, positions): + """Iterates over the available events coming from a set of log streams. + Log streams are in a single log group interleaving the events from each stream, + so they yield in timestamp order. + + Args: + aws_session (AwsSession): The AwsSession for interfacing with CloudWatch. + + log_group (str): The name of the log group. + + streams (list of str): A list of the log stream names. The the stream number is + the position of the stream in this list. + + positions: (list of Positions): A list of (timestamp, skip) pairs which represent + the last record read from each stream. + + Yields: + A tuple of (stream number, cloudwatch log event). 
+ """ + event_iters = [ + log_stream(aws_session, log_group, s, positions[s].timestamp, positions[s].skip) + for s in streams + ] + events = [] + for s in event_iters: + try: + events.append(next(s)) + except StopIteration: + events.append(None) + + while any(events): + i = events.index(min(events, key=lambda x: x["timestamp"] if x else float("inf"))) + yield i, events[i] + try: + events[i] = next(event_iters[i]) + except StopIteration: + events[i] = None + + +def log_stream(aws_session, log_group, stream_name, start_time=0, skip=0): + """A generator for log items in a single stream. + This yields all the items that are available at the current moment. + + Args: + aws_session (AwsSession): The AwsSession for interfacing with CloudWatch. + + log_group (str): The name of the log group. + + stream_name (str): The name of the specific stream. + + start_time (int): The time stamp value to start reading the logs from. Default: 0. + + skip (int): The number of log entries to skip at the start. Default: 0 (This is for + when there are multiple entries at the same timestamp.) + + Yields: + Dict: A CloudWatch log event with the following key-value pairs: + 'timestamp' (int): The time of the event. + 'message' (str): The log event data. + 'ingestionTime' (int): The time the event was ingested. + """ + + next_token = None + + event_count = 1 + while event_count > 0: + response = aws_session.get_log_events( + log_group, + stream_name, + start_time, + start_from_head=True, + next_token=next_token, + ) + next_token = response["nextForwardToken"] + events = response["events"] + event_count = len(events) + if event_count > skip: + events = events[skip:] + skip = 0 + else: + skip = skip - event_count + events = [] + for ev in events: + yield ev + + +def flush_log_streams( + aws_session, + log_group: str, + stream_prefix: str, + stream_names: List[str], + positions: Dict[str, Position], + stream_count: int, + has_streams: bool, + color_wrap: ColorWrap, +): + """Flushes log streams to stdout. + + Args: + aws_session (AwsSession): The AwsSession for interfacing with CloudWatch. + log_group (str): The name of the log group. + stream_prefix (str): The prefix for log streams to flush. + stream_names (List[str]): A list of the log stream names. The position of the stream in + this list is the stream number. If incomplete, the function will check for remaining + streams and mutate this list to add stream names when available, up to the + `stream_count` limit. + positions: (dict of Positions): A dict mapping stream numbers to (timestamp, skip) pairs + which represent the last record read from each stream. The function will update this + list after being called to represent the new last record read from each stream. + stream_count (int): The number of streams expected. + has_streams (bool): Whether the function has already been called once all streams have + been found. This value is possibly updated and returned at the end of execution. + color_wrap (ColorWrap): An instance of ColorWrap to potentially color-wrap print statements + from different streams. + + Yields: + A tuple of (stream number, cloudwatch log event). + """ + if len(stream_names) < stream_count: + # Log streams are created whenever a container starts writing to stdout/err, + # so this list may be dynamic until we have a stream for every instance. + try: + streams = aws_session.describe_log_streams( + log_group, + stream_prefix, + limit=stream_count, + ) + # stream_names = [...] wouldn't modify the list by reference. 
+ new_streams = [ + s["logStreamName"] + for s in streams["logStreams"] + if s["logStreamName"] not in stream_names + ] + stream_names.extend(new_streams) + positions.update( + [(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions] + ) + except ClientError as e: + # On the very first training job run on an account, there's no + # log group until the container starts logging, so ignore any + # errors thrown about that until logging begins. + err = e.response.get("Error", {}) + if err.get("Code") != "ResourceNotFoundException": + raise + + if len(stream_names) > 0: + if not has_streams: + print() + has_streams = True + for idx, event in multi_stream_iter(aws_session, log_group, stream_names, positions): + color_wrap(idx, event["message"]) + ts, count = positions[stream_names[idx]] + if event["timestamp"] == ts: + positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1) + else: + positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1) + else: + print(".", end="", flush=True) + return has_streams diff --git a/src/braket/jobs/metrics.py b/src/braket/jobs/metrics.py new file mode 100644 index 000000000..cd8626282 --- /dev/null +++ b/src/braket/jobs/metrics.py @@ -0,0 +1,43 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import time +from typing import Optional, Union + + +def log_metric( + metric_name: str, + value: Union[float, int], + timestamp: Optional[float] = None, + iteration_number: Optional[int] = None, +) -> None: + """ + Records Braket Job metrics. + + Args: + metric_name (str) : The name of the metric. + + value (Union[float, int]) : The value of the metric. + + timestamp (Optional[float]) : The time the metric data was received, expressed + as the number of seconds + since the epoch. Default: Current system time. + + iteration_number (Optional[int]) : The iteration number of the metric. + """ + logged_timestamp = timestamp or time.time() + metric_list = [f"Metrics - timestamp={logged_timestamp}; {metric_name}={value};"] + if iteration_number is not None: + metric_list.append(f" iteration_number={iteration_number};") + metric_line = "".join(metric_list) + print(metric_line) diff --git a/src/braket/jobs/metrics_data/__init__.py b/src/braket/jobs/metrics_data/__init__.py new file mode 100644 index 000000000..273b68004 --- /dev/null +++ b/src/braket/jobs/metrics_data/__init__.py @@ -0,0 +1,17 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
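# A minimal sketch of the metric-line format emitted by log_metric() in
# braket/jobs/metrics.py above; the metric name and values are illustrative.
from braket.jobs.metrics import log_metric

log_metric(metric_name="energy", value=-0.2, timestamp=1634731200.0, iteration_number=3)
# prints: Metrics - timestamp=1634731200.0; energy=-0.2; iteration_number=3;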
+ +from braket.jobs.metrics_data.cwl_metrics_fetcher import CwlMetricsFetcher # noqa: F401 +from braket.jobs.metrics_data.definitions import MetricPeriod, MetricStatistic # noqa: F401 +from braket.jobs.metrics_data.exceptions import MetricsRetrievalError # noqa: F401 +from braket.jobs.metrics_data.log_metrics_parser import LogMetricsParser # noqa: F401 diff --git a/src/braket/jobs/metrics_data/cwl_insights_metrics_fetcher.py b/src/braket/jobs/metrics_data/cwl_insights_metrics_fetcher.py new file mode 100644 index 000000000..251d17ded --- /dev/null +++ b/src/braket/jobs/metrics_data/cwl_insights_metrics_fetcher.py @@ -0,0 +1,185 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import time +from logging import Logger, getLogger +from typing import Any, Dict, List, Optional, Union + +from braket.aws.aws_session import AwsSession +from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType +from braket.jobs.metrics_data.exceptions import MetricsRetrievalError +from braket.jobs.metrics_data.log_metrics_parser import LogMetricsParser + + +class CwlInsightsMetricsFetcher(object): + LOG_GROUP_NAME = "/aws/braket/jobs" + QUERY_DEFAULT_JOB_DURATION = 3 * 60 * 60 + + def __init__( + self, + aws_session: AwsSession, + poll_timeout_seconds: float = 10, + poll_interval_seconds: float = 1, + logger: Logger = getLogger(__name__), + ): + """ + Args: + aws_session (AwsSession): AwsSession to connect to AWS with. + poll_timeout_seconds (float): The polling timeout for retrieving the metrics, + in seconds. Default: 10 seconds. + poll_interval_seconds (float): The interval of time, in seconds, between polling + for results. Default: 1 second. + logger (Logger): Logger object with which to write logs, such as task statuses + while waiting for a task to be in a terminal state. Default is `getLogger(__name__)` + """ + self._poll_timeout_seconds = poll_timeout_seconds + self._poll_interval_seconds = poll_interval_seconds + self._logger = logger + self._logs_client = aws_session.logs_client + + @staticmethod + def _get_element_from_log_line( + element_name: str, log_line: List[Dict[str, Any]] + ) -> Optional[str]: + """ + Finds and returns an element of a log line from CloudWatch Insights results. + + Args: + element_name (str): The element to find. + log_line (List[Dict[str, Any]]): An iterator for RegEx matches on a log line. + + Returns: + Optional[str] : The value of the element with the element name, or None if no such + element is found. + """ + return next( + (element["value"] for element in log_line if element["field"] == element_name), None + ) + + def _get_metrics_results_sync(self, query_id: str) -> List[Any]: + """ + Waits for the CloudWatch Insights query to complete and then returns all the results. + + Args: + query_id (str): CloudWatch Insights query ID. + + Returns: + List[Any]: The results from CloudWatch insights 'GetQueryResults' operation. 
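        Example of a single entry in the returned results list, as later consumed by
        `_parse_log_line` (field names follow the CloudWatch Logs Insights API; the
        message content is illustrative):

            [
                {"field": "@timestamp", "value": "2021-10-20 12:00:00.000"},
                {"field": "@message", "value": "Metrics - timestamp=1634731200.0; energy=-0.2;"},
            ]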
+ """ + timeout_time = time.time() + self._poll_timeout_seconds + while time.time() < timeout_time: + response = self._logs_client.get_query_results(queryId=query_id) + query_status = response["status"] + if query_status in ["Failed", "Cancelled"]: + raise MetricsRetrievalError(f"Query {query_id} failed with status {query_status}.") + elif query_status == "Complete": + return response["results"] + else: + time.sleep(self._poll_interval_seconds) + self._logger.warning(f"Timed out waiting for query {query_id}.") + return [] + + def _parse_log_line(self, result_entry: List[Dict[str, Any]], parser: LogMetricsParser) -> None: + """ + Parses the single entry from CloudWatch Insights results and adds any metrics it finds + to 'all_metrics' along with the timestamp for the entry. + + Args: + result_entry (List[Dict[str, Any]]): A structured result from calling CloudWatch + Insights to get logs that contain metrics. A single entry contains the message + (the actual line logged to output), the timestamp (generated by CloudWatch Logs), + and other metadata that we (currently) do not use. + parser (LogMetricsParser) : The CWL metrics parser. + """ + message = self._get_element_from_log_line("@message", result_entry) + if message: + timestamp = self._get_element_from_log_line("@timestamp", result_entry) + parser.parse_log_message(timestamp, message) + + def _parse_log_query_results( + self, results: List[Any], metric_type: MetricType, statistic: MetricStatistic + ) -> Dict[str, List[Union[str, float, int]]]: + """ + Parses CloudWatch Insights results and returns all found metrics. + + Args: + results (List[Any]): A structured result from calling CloudWatch Insights to get + logs that contain metrics. + metric_type (MetricType): The type of metrics to get. + statistic (MetricStatistic): The statistic to determine which metric value to use + when there is a conflict. + + Returns: + Dict[str, List[Union[str, float, int]]] : The metrics data. + """ + parser = LogMetricsParser() + for result in results: + self._parse_log_line(result, parser) + return parser.get_parsed_metrics(metric_type, statistic) + + def get_metrics_for_job( + self, + job_name: str, + metric_type: MetricType = MetricType.TIMESTAMP, + statistic: MetricStatistic = MetricStatistic.MAX, + job_start_time: int = None, + job_end_time: int = None, + ) -> Dict[str, List[Union[str, float, int]]]: + """ + Synchronously retrieves all the algorithm metrics logged by a given Job. + + Args: + job_name (str): The name of the Job. The name must be exact to ensure only the relevant + metrics are retrieved. + metric_type (MetricType): The type of metrics to get. Default is MetricType.TIMESTAMP. + statistic (MetricStatistic): The statistic to determine which metric value to use + when there is a conflict. Default is MetricStatistic.MAX. + job_start_time (int): The time when the job started. + Default: 3 hours before job_end_time. + job_end_time (int): If the job is complete, this should be the time at which the + job finished. Default: current time. + + Returns: + Dict[str, List[Union[str, float, int]]] : The metrics data, where the keys + are the column names and the values are a list containing the values in each row. + For example, the table: + timestamp energy + 0 0.1 + 1 0.2 + would be represented as: + { "timestamp" : [0, 1], "energy" : [0.1, 0.2] } + The values may be integers, floats, strings or None. 
+ """ + query_end_time = job_end_time or int(time.time()) + query_start_time = job_start_time or query_end_time - self.QUERY_DEFAULT_JOB_DURATION + + # The job name needs to be unique to prevent jobs with similar names from being conflated. + query = ( + f"fields @timestamp, @message " + f"| filter @logStream like /^{job_name}\\// " + f"| filter @message like /^Metrics - /" + ) + + response = self._logs_client.start_query( + logGroupName=self.LOG_GROUP_NAME, + startTime=query_start_time, + endTime=query_end_time, + queryString=query, + limit=10000, + ) + + query_id = response["queryId"] + + results = self._get_metrics_results_sync(query_id) + + return self._parse_log_query_results(results, metric_type, statistic) diff --git a/src/braket/jobs/metrics_data/cwl_metrics_fetcher.py b/src/braket/jobs/metrics_data/cwl_metrics_fetcher.py new file mode 100644 index 000000000..632376162 --- /dev/null +++ b/src/braket/jobs/metrics_data/cwl_metrics_fetcher.py @@ -0,0 +1,164 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import time +from logging import Logger, getLogger +from typing import Dict, List, Union + +from braket.aws.aws_session import AwsSession +from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType +from braket.jobs.metrics_data.log_metrics_parser import LogMetricsParser + + +class CwlMetricsFetcher(object): + LOG_GROUP_NAME = "/aws/braket/jobs" + + def __init__( + self, + aws_session: AwsSession, + poll_timeout_seconds: float = 10, + logger: Logger = getLogger(__name__), + ): + """ + Args: + aws_session (AwsSession): AwsSession to connect to AWS with. + poll_timeout_seconds (float): The polling timeout for retrieving the metrics, + in seconds. Default: 10 seconds. + logger (Logger): Logger object with which to write logs, such as task statuses + while waiting for task to be in a terminal state. Default is `getLogger(__name__)` + """ + self._poll_timeout_seconds = poll_timeout_seconds + self._logger = logger + self._logs_client = aws_session.logs_client + + @staticmethod + def _is_metrics_message(message): + """ + Returns true if a given message is designated as containing Metrics. + + Args: + message (str): The message to check. + + Returns: + True if the given message is designated as containing Metrics; False otherwise. + """ + if message: + return "Metrics -" in message + return False + + def _parse_metrics_from_log_stream( + self, + stream_name: str, + timeout_time: float, + parser: LogMetricsParser, + ) -> None: + """ + Synchronously retrieves the algorithm metrics logged in a given job log stream. + + Args: + stream_name (str): The name of the log stream. + timeout_time (float) : We stop getting metrics if the current time is beyond + the timeout time. + parser (LogMetricsParser) : The CWL metrics parser. 
+ + Returns: + None + """ + kwargs = { + "logGroupName": self.LOG_GROUP_NAME, + "logStreamName": stream_name, + "startFromHead": True, + "limit": 10000, + } + + previous_token = None + while time.time() < timeout_time: + response = self._logs_client.get_log_events(**kwargs) + for event in response.get("events"): + message = event.get("message") + if self._is_metrics_message(message): + parser.parse_log_message(event.get("timestamp"), message) + next_token = response.get("nextForwardToken") + if not next_token or next_token == previous_token: + return + previous_token = next_token + kwargs["nextToken"] = next_token + self._logger.warning("Timed out waiting for all metrics. Data may be incomplete.") + + def _get_log_streams_for_job(self, job_name: str, timeout_time: float) -> List[str]: + """ + Retrieves the list of log streams relevant to a job. + + Args: + job_name (str): The name of the job. + timeout_time (float) : Metrics cease getting streamed if the current time exceeds + the timeout time. + Returns: + List[str] : A list of log stream names for the given job. + """ + kwargs = { + "logGroupName": self.LOG_GROUP_NAME, + "logStreamNamePrefix": job_name + "/algo-", + } + log_streams = [] + while time.time() < timeout_time: + response = self._logs_client.describe_log_streams(**kwargs) + streams = response.get("logStreams") + if streams: + for stream in streams: + name = stream.get("logStreamName") + if name: + log_streams.append(name) + next_token = response.get("nextToken") + if not next_token: + return log_streams + kwargs["nextToken"] = next_token + self._logger.warning("Timed out waiting for all metrics. Data may be incomplete.") + return log_streams + + def get_metrics_for_job( + self, + job_name: str, + metric_type: MetricType = MetricType.TIMESTAMP, + statistic: MetricStatistic = MetricStatistic.MAX, + ) -> Dict[str, List[Union[str, float, int]]]: + """ + Synchronously retrieves all the algorithm metrics logged by a given Job. + + Args: + job_name (str): The name of the Job. The name must be exact to ensure only the relevant + metrics are retrieved. + metric_type (MetricType): The type of metrics to get. Default is MetricType.TIMESTAMP. + statistic (MetricStatistic): The statistic to determine which metric value to use + when there is a conflict. Default is MetricStatistic.MAX. + + Returns: + Dict[str, List[Union[str, float, int]]] : The metrics data, where the keys + are the column names and the values are a list containing the values in each row. + For example, the table: + timestamp energy + 0 0.1 + 1 0.2 + would be represented as: + { "timestamp" : [0, 1], "energy" : [0.1, 0.2] } + values may be integers, floats, strings or None. + """ + timeout_time = time.time() + self._poll_timeout_seconds + + parser = LogMetricsParser() + + log_streams = self._get_log_streams_for_job(job_name, timeout_time) + for log_stream in log_streams: + self._parse_metrics_from_log_stream(log_stream, timeout_time, parser) + + return parser.get_parsed_metrics(metric_type, statistic) diff --git a/src/braket/jobs/metrics_data/definitions.py b/src/braket/jobs/metrics_data/definitions.py new file mode 100644 index 000000000..a3e77a783 --- /dev/null +++ b/src/braket/jobs/metrics_data/definitions.py @@ -0,0 +1,36 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from enum import Enum, unique + + +@unique +class MetricPeriod(Enum): + """Period over which the cloudwatch metric is aggregated.""" + + ONE_MINUTE: int = 60 + + +@unique +class MetricStatistic(Enum): + """Metric data aggregation to use over the specified period.""" + + MIN: str = "Min" + MAX: str = "Max" + + +@unique +class MetricType(Enum): + """Metric type.""" + + TIMESTAMP: str = "Timestamp" + ITERATION_NUMBER: str = "IterationNumber" diff --git a/src/braket/jobs/metrics_data/exceptions.py b/src/braket/jobs/metrics_data/exceptions.py new file mode 100644 index 000000000..677a3a447 --- /dev/null +++ b/src/braket/jobs/metrics_data/exceptions.py @@ -0,0 +1,18 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + + +class MetricsRetrievalError(Exception): + """Raised when retrieving metrics fails.""" + + pass diff --git a/src/braket/jobs/metrics_data/log_metrics_parser.py b/src/braket/jobs/metrics_data/log_metrics_parser.py new file mode 100644 index 000000000..76ef319b4 --- /dev/null +++ b/src/braket/jobs/metrics_data/log_metrics_parser.py @@ -0,0 +1,198 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import re +from logging import Logger, getLogger +from typing import Dict, Iterator, List, Optional, Tuple, Union + +from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType + + +class LogMetricsParser(object): + """ + This class is used to parse metrics from log lines, and return them in a more + convenient format. + """ + + METRICS_DEFINITIONS = re.compile(r"(\w+)\s*=\s*([^;]+)\s*;") + TIMESTAMP = "timestamp" + ITERATION_NUMBER = "iteration_number" + + def __init__( + self, + logger: Logger = getLogger(__name__), + ): + self._logger = logger + self.all_metrics = [] + + @staticmethod + def _get_value( + current_value: Optional[Union[str, float, int]], + new_value: Union[str, float, int], + statistic: MetricStatistic, + ) -> Union[str, float, int]: + """ + Gets the value based on a statistic. + + Args: + current_value (Optional[Union[str, float, int]]): The current value. + + new_value: (Union[str, float, int]) The new value. 
+ + statistic (MetricStatistic): The statistic to determine which value to use. + + Returns: + Union[str, float, int]: the value. + """ + if current_value is None: + return new_value + if statistic == MetricStatistic.MAX: + return max(current_value, new_value) + return min(current_value, new_value) + + def _get_metrics_from_log_line_matches( + self, all_matches: Iterator + ) -> Dict[str, Union[str, float, int]]: + """ + Converts matches from a RegEx to a set of metrics. + + Args: + all_matches (Iterator): An iterator for RegEx matches on a log line. + + Returns: + Dict[str, Union[str, float, int]]: The set of metrics found by the RegEx. The result + is in the format { : }. This implies that multiple metrics + with the same name are deduped to the last instance of that metric. + """ + metrics = {} + for match in all_matches: + subgroup = match.groups() + value = subgroup[1] + try: + metrics[subgroup[0]] = float(value) + except ValueError: + self._logger.warning(f"Unable to convert value {value} to a float.") + return metrics + + def parse_log_message(self, timestamp: str, message: str) -> None: + """ + Parses a line from logs, adding all the metrics that have been logged + on that line. The timestamp is also added to match the corresponding values. + + Args: + timestamp (str): A formatted string representing the timestamp for any found metrics. + + message (str): A log line from a log. + """ + if not message: + return + all_matches = self.METRICS_DEFINITIONS.finditer(message) + parsed_metrics = self._get_metrics_from_log_line_matches(all_matches) + if not parsed_metrics: + return + if timestamp and self.TIMESTAMP not in parsed_metrics: + parsed_metrics[self.TIMESTAMP] = timestamp + self.all_metrics.append(parsed_metrics) + + def get_columns_and_pivot_indices( + self, pivot: str + ) -> Tuple[Dict[str, List[Union[str, float, int]]], Dict[int, int]]: + """ + Parses the metrics to find all the metrics that have the pivot column. The values of + the pivot column are assigned a row index, so that all metrics with the same pivot value + are stored in the same row. + Args: + pivot (str): The name of the pivot column. Must be TIMESTAMP or ITERATION_NUMBER. + + Returns: + Tuple[Dict[str, List[Any]], Dict[int, int]]: + The Dict[str, List[Any]] the result table with all the metrics values initialized + to None + The Dict[int, int] is the list of pivot indices, where the value of a pivot column + is mapped to a row index. + """ + row_count = 0 + pivot_indices: dict[int, int] = {} + table: dict[str, list[Optional[Union[str, float, int]]]] = {} + for metric in self.all_metrics: + if pivot in metric: + if metric[pivot] not in pivot_indices: + pivot_indices[metric[pivot]] = row_count + row_count += 1 + for column_name in metric: + table[column_name] = [None] + for column_name in table: + table[column_name] = [None] * row_count + return table, pivot_indices + + def get_metric_data_with_pivot( + self, pivot: str, statistic: MetricStatistic + ) -> Dict[str, List[Union[str, float, int]]]: + """ + Gets the metric data for a given pivot column name. Metrics without the pivot column + are not included in the results. Metrics that have the same value in the pivot column + are returned in the same row. If the a metric has multiple values for the pivot value, + the statistic is used to determine which value is returned. 
+ For example, for the metrics: + "iteration_number" : 0, "metricA" : 2, "metricB" : 1, + "iteration_number" : 0, "metricA" : 1, + "no_pivot_column" : 0, "metricA" : 0, + "iteration_number" : 1, "metricA" : 2, + + The result with iteration_number as the pivot, statistic of MIN the result will be: + iteration_number metricA metricB + 0 1 1 + 1 2 None + + Args: + pivot (str): The name of the pivot column. Must be TIMESTAMP or ITERATION_NUMBER. + statistic (MetricStatistic): The statistic to determine which value to use. + + Returns: + Dict[str, List[Union[str, float, int]]] : The metrics data. + """ + table, pivot_indices = self.get_columns_and_pivot_indices(pivot) + for metric in self.all_metrics: + if pivot in metric: + row = pivot_indices[metric[pivot]] + for column_name in metric: + table[column_name][row] = self._get_value( + table[column_name][row], metric[column_name], statistic + ) + return table + + def get_parsed_metrics( + self, metric_type: MetricType, statistic: MetricStatistic + ) -> Dict[str, List[Union[str, float, int]]]: + """ + Gets all the metrics data, where the keys are the column names and the values are a list + containing the values in each row. For example, the table: + timestamp energy + 0 0.1 + 1 0.2 + would be represented as: + { "timestamp" : [0, 1], "energy" : [0.1, 0.2] } + values may be integers, floats, strings or None. + + Args: + metric_type (MetricType): The type of metrics to get. + + statistic (MetricStatistic): The statistic to determine which metric value to use + when there is a conflict. + + Returns: + Dict[str, List[Union[str, float, int]]] : The metrics data. + """ + if metric_type == MetricType.ITERATION_NUMBER: + return self.get_metric_data_with_pivot(self.ITERATION_NUMBER, statistic) + return self.get_metric_data_with_pivot(self.TIMESTAMP, statistic) diff --git a/src/braket/jobs/quantum_job.py b/src/braket/jobs/quantum_job.py new file mode 100644 index 000000000..d17fa637e --- /dev/null +++ b/src/braket/jobs/quantum_job.py @@ -0,0 +1,186 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType + + +class QuantumJob(ABC): + DEFAULT_RESULTS_POLL_TIMEOUT = 864000 + DEFAULT_RESULTS_POLL_INTERVAL = 5 + + @property + @abstractmethod + def arn(self) -> str: + """str: The ARN (Amazon Resource Name) of the quantum job.""" + + @property + @abstractmethod + def name(self) -> str: + """str: The name of the quantum job.""" + + @abstractmethod + def state(self, use_cached_value: bool = False) -> str: + """The state of the quantum job. + + Args: + use_cached_value (bool, optional): If `True`, uses the value most recently retrieved + value from the Amazon Braket `GetJob` operation. If `False`, calls the + `GetJob` operation to retrieve metadata, which also updates the cached + value. Default = `False`. + Returns: + str: The value of `status` in `metadata()`. 
This is the value of the `status` key + in the Amazon Braket `GetJob` operation. + + See Also: + `metadata()` + """ + + @abstractmethod + def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: + """Display logs for a given job, optionally tailing them until job is complete. + + If the output is a tty or a Jupyter cell, it will be color-coded + based on which instance the log entry is from. + + Args: + wait (bool): `True` to keep looking for new log entries until the job completes; + otherwise `False`. Default: `False`. + + poll_interval_seconds (int): The interval of time, in seconds, between polling for + new log entries and job completion (default: 5). + + Raises: + RuntimeError: If waiting and the job fails. + """ + # The loop below implements a state machine that alternates between checking the job status + # and reading whatever is available in the logs at this point. Note, that if we were + # called with wait == False, we never check the job status. + # + # If wait == TRUE and job is not completed, the initial state is TAILING + # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is + # complete). + # + # The state table: + # + # STATE ACTIONS CONDITION NEW STATE + # ---------------- ---------------- ----------------- ---------------- + # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE + # Else TAILING + # JOB_COMPLETE Read logs, Pause Any COMPLETE + # COMPLETE Read logs, Exit N/A + # + # Notes: + # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to + # Cloudwatch after the job was marked complete. + + @abstractmethod + def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]: + """Gets the job metadata defined in Amazon Braket. + + Args: + use_cached_value (bool, optional): If `True`, uses the value most recently retrieved + from the Amazon Braket `GetJob` operation, if it exists; if does not exist, + `GetJob` is called to retrieve the metadata. If `False`, always calls + `GetJob`, which also updates the cached value. Default: `False`. + Returns: + Dict[str, Any]: Dict that specifies the job metadata defined in Amazon Braket. + """ + + @abstractmethod + def metrics( + self, + metric_type: MetricType = MetricType.TIMESTAMP, + statistic: MetricStatistic = MetricStatistic.MAX, + ) -> Dict[str, List[Any]]: + """Gets all the metrics data, where the keys are the column names, and the values are a list + containing the values in each row. For example, the table: + timestamp energy + 0 0.1 + 1 0.2 + would be represented as: + { "timestamp" : [0, 1], "energy" : [0.1, 0.2] } + values may be integers, floats, strings or None. + + Args: + metric_type (MetricType): The type of metrics to get. Default: MetricType.TIMESTAMP. + + statistic (MetricStatistic): The statistic to determine which metric value to use + when there is a conflict. Default: MetricStatistic.MAX. + + Returns: + Dict[str, List[Union[str, float, int]]] : The metrics data. + """ + + @abstractmethod + def cancel(self) -> str: + """Cancels the job. + + Returns: + str: Indicates the status of the job. + + Raises: + ClientError: If there are errors invoking the CancelJob API. + """ + + @abstractmethod + def result( + self, + poll_timeout_seconds: float = DEFAULT_RESULTS_POLL_TIMEOUT, + poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL, + ) -> Dict[str, Any]: + """Retrieves the job result persisted using save_job_result() function. 
+ + Args: + poll_timeout_seconds (float): The polling timeout, in seconds, for `result()`. + Default: 10 days. + + poll_interval_seconds (float): The polling interval, in seconds, for `result()`. + Default: 5 seconds. + + + Returns: + Dict[str, Any]: Dict specifying the job results. + + Raises: + RuntimeError: if job is in a FAILED or CANCELLED state. + TimeoutError: if job execution exceeds the polling timeout period. + """ + + @abstractmethod + def download_result( + self, + extract_to=None, + poll_timeout_seconds: float = DEFAULT_RESULTS_POLL_TIMEOUT, + poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL, + ) -> None: + """Downloads the results from the job output S3 bucket and extracts the tar.gz + bundle to the location specified by `extract_to`. If no location is specified, + the results are extracted to the current directory. + + Args: + extract_to (str): The directory to which the results are extracted. The results + are extracted to a folder titled with the job name within this directory. + Default= `Current working directory`. + + poll_timeout_seconds: (float): The polling timeout, in seconds, for `download_result()`. + Default: 10 days. + + poll_interval_seconds: (float): The polling interval, in seconds, for + `download_result()`.Default: 5 seconds. + + Raises: + RuntimeError: if job is in a FAILED or CANCELLED state. + TimeoutError: if job execution exceeds the polling timeout period. + """ diff --git a/src/braket/jobs/quantum_job_creation.py b/src/braket/jobs/quantum_job_creation.py new file mode 100644 index 000000000..db159eccf --- /dev/null +++ b/src/braket/jobs/quantum_job_creation.py @@ -0,0 +1,409 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import annotations + +import importlib.util +import re +import sys +import tarfile +import tempfile +import time +from dataclasses import asdict +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +from braket.aws.aws_session import AwsSession +from braket.jobs.config import ( + CheckpointConfig, + DeviceConfig, + InstanceConfig, + OutputDataConfig, + S3DataSourceConfig, + StoppingCondition, +) + + +def prepare_quantum_job( + device: str, + source_module: str, + entry_point: str = None, + image_uri: str = None, + job_name: str = None, + code_location: str = None, + role_arn: str = None, + hyperparameters: Dict[str, Any] = None, + input_data: Union[str, Dict, S3DataSourceConfig] = None, + instance_config: InstanceConfig = None, + stopping_condition: StoppingCondition = None, + output_data_config: OutputDataConfig = None, + copy_checkpoints_from_job: str = None, + checkpoint_config: CheckpointConfig = None, + aws_session: AwsSession = None, + tags: Dict[str, str] = None, +): + """Creates a job by invoking the Braket CreateJob API. + + Args: + device (str): ARN for the AWS device which is primarily + accessed for the execution of this job. + + source_module (str): Path (absolute, relative or an S3 URI) to a python module to be + tarred and uploaded. 
If `source_module` is an S3 URI, it must point to a + tar.gz file. Otherwise, source_module may be a file or directory. + + entry_point (str): A str that specifies the entry point of the job, relative to + the source module. The entry point must be in the format + `importable.module` or `importable.module:callable`. For example, + `source_module.submodule:start_here` indicates the `start_here` function + contained in `source_module.submodule`. If source_module is an S3 URI, + entry point must be given. Default: source_module's name + + image_uri (str): A str that specifies the ECR image to use for executing the job. + `image_uris.retrieve_image()` function may be used for retrieving the ECR image URIs + for the containers supported by Braket. Default = ``. + + job_name (str): A str that specifies the name with which the job is created. + Default: f'{image_uri_type}-{timestamp}'. + + code_location (str): The S3 prefix URI where custom code will be uploaded. + Default: f's3://{default_bucket_name}/jobs/{job_name}/script'. + + role_arn (str): A str providing the IAM role ARN used to execute the + script. Default: IAM role returned by AwsSession's `get_default_jobs_role()`. + + hyperparameters (Dict[str, Any]): Hyperparameters accessible to the job. + The hyperparameters are made accessible as a Dict[str, str] to the job. + For convenience, this accepts other types for keys and values, but `str()` + is called to convert them before being passed on. Default: None. + + input_data (Union[str, S3DataSourceConfig, dict]): Information about the training + data. Dictionary maps channel names to local paths or S3 URIs. Contents found + at any local paths will be uploaded to S3 at + f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}. If a local + path, S3 URI, or S3DataSourceConfig is provided, it will be given a default + channel name "input". + Default: {}. + + instance_config (InstanceConfig): Configuration of the instances to be used + to execute the job. Default: InstanceConfig(instanceType='ml.m5.large', + instanceCount=1, volumeSizeInGB=30, volumeKmsKey=None). + + stopping_condition (StoppingCondition): The maximum length of time, in seconds, + and the maximum number of tasks that a job can run before being forcefully stopped. + Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60). + + output_data_config (OutputDataConfig): Specifies the location for the output of the job. + Default: OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data', + kmsKeyId=None). + + copy_checkpoints_from_job (str): A str that specifies the job ARN whose checkpoint you + want to use in the current job. Specifying this value will copy over the checkpoint + data from `use_checkpoints_from_job`'s checkpoint_config s3Uri to the current job's + checkpoint_config s3Uri, making it available at checkpoint_config.localPath during + the job execution. Default: None + + checkpoint_config (CheckpointConfig): Configuration that specifies the location where + checkpoint data is stored. + Default: CheckpointConfig(localPath='/opt/jobs/checkpoints', + s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints'). + + aws_session (AwsSession): AwsSession for connecting to AWS Services. + Default: AwsSession() + + tags (Dict[str, str]): Dict specifying the key-value pairs for tagging this job. + Default: {}. + + Returns: + AwsQuantumJob: Job tracking the execution on Amazon Braket. + + Raises: + ValueError: Raises ValueError if the parameters are not valid. 
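    Example (illustrative; assumes ambient AWS credentials and a local script named
    `algorithm_script.py` containing a `start_here` callable, both hypothetical):

        from braket.jobs.quantum_job_creation import prepare_quantum_job

        create_job_kwargs = prepare_quantum_job(
            device="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
            source_module="algorithm_script.py",
            entry_point="algorithm_script:start_here",
            hyperparameters={"shots": 1000},
        )
        # The returned dict maps directly onto the Braket CreateJob request parameters.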
+ """ + param_datatype_map = { + "instance_config": (instance_config, InstanceConfig), + "stopping_condition": (stopping_condition, StoppingCondition), + "output_data_config": (output_data_config, OutputDataConfig), + "checkpoint_config": (checkpoint_config, CheckpointConfig), + } + + _validate_params(param_datatype_map) + aws_session = aws_session or AwsSession() + device_config = DeviceConfig(device) + job_name = job_name or _generate_default_job_name(image_uri) + role_arn = role_arn or aws_session.get_default_jobs_role() + hyperparameters = hyperparameters or {} + input_data = input_data or {} + tags = tags or {} + default_bucket = aws_session.default_bucket() + input_data_list = _process_input_data(input_data, job_name, aws_session) + instance_config = instance_config or InstanceConfig() + stopping_condition = stopping_condition or StoppingCondition() + output_data_config = output_data_config or OutputDataConfig() + checkpoint_config = checkpoint_config or CheckpointConfig() + code_location = code_location or AwsSession.construct_s3_uri( + default_bucket, + "jobs", + job_name, + "script", + ) + if AwsSession.is_s3_uri(source_module): + _process_s3_source_module(source_module, entry_point, aws_session, code_location) + else: + # if entry point is None, it will be set to default here + entry_point = _process_local_source_module( + source_module, entry_point, aws_session, code_location + ) + algorithm_specification = { + "scriptModeConfig": { + "entryPoint": entry_point, + "s3Uri": f"{code_location}/source.tar.gz", + "compressionType": "GZIP", + } + } + if image_uri: + algorithm_specification["containerImage"] = {"uri": image_uri} + if not output_data_config.s3Path: + output_data_config.s3Path = AwsSession.construct_s3_uri( + default_bucket, + "jobs", + job_name, + "data", + ) + if not checkpoint_config.s3Uri: + checkpoint_config.s3Uri = AwsSession.construct_s3_uri( + default_bucket, + "jobs", + job_name, + "checkpoints", + ) + if copy_checkpoints_from_job: + checkpoints_to_copy = aws_session.get_job(copy_checkpoints_from_job)["checkpointConfig"][ + "s3Uri" + ] + aws_session.copy_s3_directory(checkpoints_to_copy, checkpoint_config.s3Uri) + + create_job_kwargs = { + "jobName": job_name, + "roleArn": role_arn, + "algorithmSpecification": algorithm_specification, + "inputDataConfig": input_data_list, + "instanceConfig": asdict(instance_config), + "outputDataConfig": asdict(output_data_config), + "checkpointConfig": asdict(checkpoint_config), + "deviceConfig": asdict(device_config), + "hyperParameters": hyperparameters, + "stoppingCondition": asdict(stopping_condition), + "tags": tags, + } + + return create_job_kwargs + + +def _generate_default_job_name(image_uri: Optional[str]) -> str: + """ + Generate default job name using the image uri and a timestamp + Args: + image_uri (str, optional): URI for the image container. + + Returns: + str: Job name. + """ + if not image_uri: + job_type = "-default" + else: + job_type_match = re.search("/amazon-braket-(.*)-jobs:", image_uri) or re.search( + "/amazon-braket-([^:/]*)", image_uri + ) + job_type = f"-{job_type_match.groups()[0]}" if job_type_match else "" + + return f"braket-job{job_type}-{time.time() * 1000:.0f}" + + +def _process_s3_source_module( + source_module: str, entry_point: str, aws_session: AwsSession, code_location: str +) -> None: + """ + Check that the source module is an S3 URI of the correct type and that entry point is + provided. + + Args: + source_module (str): S3 URI pointing to the tarred source module. 
+ entry_point (str): Entry point for the job. + aws_session (AwsSession): AwsSession to copy source module to code location. + code_location (str): S3 URI pointing to the location where the code will be + copied to. + """ + if entry_point is None: + raise ValueError("If source_module is an S3 URI, entry_point must be provided.") + if not source_module.lower().endswith(".tar.gz"): + raise ValueError( + "If source_module is an S3 URI, it must point to a tar.gz file. " + f"Not a valid S3 URI for parameter `source_module`: {source_module}" + ) + aws_session.copy_s3_object(source_module, f"{code_location}/source.tar.gz") + + +def _process_local_source_module( + source_module: str, entry_point: str, aws_session: AwsSession, code_location: str +) -> str: + """ + Check that entry point is valid with respect to source module, or provide a default + value if entry point is not given. Tar and upload source module to code location in S3. + Args: + source_module (str): Local path pointing to the source module. + entry_point (str): Entry point relative to the source module. + aws_session (AwsSession): AwsSession for uploading tarred source module. + code_location (str): S3 URI pointing to the location where the code will + be uploaded to. + + Returns: + str: Entry point. + """ + try: + # raises FileNotFoundError if not found + abs_path_source_module = Path(source_module).resolve(strict=True) + except FileNotFoundError: + raise ValueError(f"Source module not found: {source_module}") + + entry_point = entry_point or abs_path_source_module.stem + _validate_entry_point(abs_path_source_module, entry_point) + _tar_and_upload_to_code_location(abs_path_source_module, aws_session, code_location) + return entry_point + + +def _validate_entry_point(source_module_path: Path, entry_point: str) -> None: + """ + Confirm that a valid entry point relative to source module is given. + + Args: + source_module_path (Path): Path to source module. + entry_point (str): Entry point relative to source module. + """ + importable, _, _method = entry_point.partition(":") + sys.path.append(str(source_module_path.parent)) + try: + # second argument allows relative imports + module = importlib.util.find_spec(importable, source_module_path.stem) + assert module is not None + # if entry point is nested (ie contains '.'), parent modules are imported + except (ModuleNotFoundError, AssertionError): + raise ValueError(f"Entry point module was not found: {importable}") + finally: + sys.path.pop() + + +def _tar_and_upload_to_code_location( + source_module_path: Path, aws_session: AwsSession, code_location: str +) -> None: + """ + Tar and upload source module to code location. + + Args: + source_module_path (Path): Path to source module. + aws_session (AwsSession): AwsSession for uploading source module. + code_location (str): S3 URI pointing to the location where the tarred + source module will be uploaded to. + """ + with tempfile.TemporaryDirectory() as temp_dir: + with tarfile.open(f"{temp_dir}/source.tar.gz", "w:gz", dereference=True) as tar: + tar.add(source_module_path, arcname=source_module_path.name) + aws_session.upload_to_s3(f"{temp_dir}/source.tar.gz", f"{code_location}/source.tar.gz") + + +def _validate_params(dict_arr: Dict[str, Tuple[any, any]]) -> None: + """ + Validate that config parameters are of the right type. + + Args: + dict_arr (Dict[str, Tuple[any, any]]): dict mapping parameter names to + a tuple containing the provided value and expected type. 
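    Example (illustrative) of the expected mapping; `None` values are skipped, and a
    value of the wrong type raises ValueError:

        from braket.jobs.config import InstanceConfig, OutputDataConfig

        _validate_params(
            {
                "instance_config": (InstanceConfig(), InstanceConfig),
                "output_data_config": (None, OutputDataConfig),
            }
        )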
+ """ + for parameter_name, value_tuple in dict_arr.items(): + user_input, expected_datatype = value_tuple + + if user_input and not isinstance(user_input, expected_datatype): + raise ValueError( + f"'{parameter_name}' should be of '{expected_datatype}' " + f"but user provided {type(user_input)}." + ) + + +def _process_input_data( + input_data: Union[str, Dict, S3DataSourceConfig], job_name: str, aws_session: AwsSession +) -> List[Dict[str, Any]]: + """ + Convert input data into a list of dicts compatible with the Braket API. + Args: + input_data (Union[str, Dict, S3DataSourceConfig]): Either a channel definition or a + dictionary mapping channel names to channel definitions, where a channel definition + can be an S3DataSourceConfig or a str corresponding to a local prefix or S3 prefix. + job_name (str): Job name. + aws_session (AwsSession): AwsSession for possibly uploading local data. + + Returns: + List[Dict[str, Any]]: A list of channel configs. + """ + if not isinstance(input_data, dict): + input_data = {"input": input_data} + for channel_name, data in input_data.items(): + if not isinstance(data, S3DataSourceConfig): + input_data[channel_name] = _process_channel(data, job_name, aws_session, channel_name) + return _convert_input_to_config(input_data) + + +def _process_channel( + location: str, job_name: str, aws_session: AwsSession, channel_name: str +) -> S3DataSourceConfig: + """ + Convert a location to an S3DataSourceConfig, uploading local data to S3, if necessary. + Args: + location (str): Local prefix or S3 prefix. + job_name (str): Job name. + aws_session (AwsSession): AwsSession to be used for uploading local data. + channel_name (str): Name of the channel. + + Returns: + S3DataSourceConfig: S3DataSourceConfig for the channel. + """ + if AwsSession.is_s3_uri(location): + return S3DataSourceConfig(location) + else: + # local prefix "path/to/prefix" will be mapped to + # s3://bucket/jobs/job-name/data/input/prefix + location_name = Path(location).name + s3_prefix = AwsSession.construct_s3_uri( + aws_session.default_bucket(), "jobs", job_name, "data", channel_name, location_name + ) + aws_session.upload_local_data(location, s3_prefix) + return S3DataSourceConfig(s3_prefix) + + +def _convert_input_to_config(input_data: Dict[str, S3DataSourceConfig]) -> List[Dict[str, Any]]: + """ + Convert a dictionary mapping channel names to S3DataSourceConfigs into a list of channel + configs compatible with the Braket API. + + Args: + input_data (Dict[str, S3DataSourceConfig]): A dictionary mapping channel names to + S3DataSourceConfig objects. + + Returns: + List[Dict[str, Any]]: A list of channel configs. + """ + return [ + { + "channelName": channel_name, + **data_config.config, + } + for channel_name, data_config in input_data.items() + ] diff --git a/src/braket/jobs/serialization.py b/src/braket/jobs/serialization.py new file mode 100644 index 000000000..f8c854d03 --- /dev/null +++ b/src/braket/jobs/serialization.py @@ -0,0 +1,67 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. + +import codecs +import pickle +from typing import Any, Dict + +from braket.jobs_data import PersistedJobDataFormat + + +def serialize_values( + data_dictionary: Dict[str, Any], data_format: PersistedJobDataFormat +) -> Dict[str, Any]: + """ + Serializes the `data_dictionary` values to the format specified by `data_format`. + + Args: + data_dictionary (Dict[str, Any]): Dict whose values are to be serialized. + data_format (PersistedJobDataFormat): The data format used to serialize the + values. Note that for `PICKLED` data formats, the values are base64 encoded + after serialization, so that they represent valid UTF-8 text and are compatible + with `PersistedJobData.json()`. + + Returns: + Dict[str, Any]: Dict with same keys as `data_dictionary` and values serialized to + the specified `data_format`. + """ + return ( + { + k: codecs.encode(pickle.dumps(v, protocol=4), "base64").decode() + for k, v in data_dictionary.items() + } + if data_format == PersistedJobDataFormat.PICKLED_V4 + else data_dictionary + ) + + +def deserialize_values( + data_dictionary: Dict[str, Any], data_format: PersistedJobDataFormat +) -> Dict[str, Any]: + """ + Deserializes the `data_dictionary` values from the format specified by `data_format`. + + Args: + data_dictionary (Dict[str, Any]): Dict whose values are to be deserialized. + data_format (PersistedJobDataFormat): The data format that the `data_dictionary` values + are currently serialized with. + + Returns: + Dict[str, Any]: Dict with same keys as `data_dictionary` and values deserialized from + the specified `data_format` to plaintext. + """ + return ( + {k: pickle.loads(codecs.decode(v.encode(), "base64")) for k, v in data_dictionary.items()} + if data_format == PersistedJobDataFormat.PICKLED_V4 + else data_dictionary + ) diff --git a/test/integ_tests/job_test_script.py b/test/integ_tests/job_test_script.py new file mode 100644 index 000000000..7071ffd62 --- /dev/null +++ b/test/integ_tests/job_test_script.py @@ -0,0 +1,53 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
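# A round-trip sketch of the persisted-data helpers in braket/jobs/serialization.py
# above; the payload is illustrative.
from braket.jobs.serialization import deserialize_values, serialize_values
from braket.jobs_data import PersistedJobDataFormat

data = {"some_data": "abc"}
encoded = serialize_values(data, PersistedJobDataFormat.PICKLED_V4)  # base64-encoded pickle
assert deserialize_values(encoded, PersistedJobDataFormat.PICKLED_V4) == data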
+ +import json +import os + +from braket.aws import AwsDevice +from braket.circuits import Circuit +from braket.jobs import save_job_checkpoint, save_job_result +from braket.jobs_data import PersistedJobDataFormat + + +def start_here(): + hp_file = os.environ["AMZN_BRAKET_HP_FILE"] + with open(hp_file, "r") as f: + hyperparameters = json.load(f) + + if hyperparameters["test_case"] == "completed": + completed_job_script() + else: + failed_job_script() + + +def failed_job_script(): + print("Test job started!!!!!") + assert 0 + + +def completed_job_script(): + print("Test job started!!!!!") + + # Use the device declared in the Orchestration Script + device = AwsDevice(os.environ["AMZN_BRAKET_DEVICE_ARN"]) + + bell = Circuit().h(0).cnot(0, 1) + for count in range(5): + task = device.run(bell, shots=100) + print(task.result().measurement_counts) + save_job_result({"converged": True, "energy": -0.2}) + save_job_checkpoint({"some_data": "abc"}, checkpoint_file_suffix="plain_data") + save_job_checkpoint({"some_data": "abc"}, data_format=PersistedJobDataFormat.PICKLED_V4) + + print("Test job completed!!!!!") diff --git a/test/integ_tests/test_create_local_quantum_job.py b/test/integ_tests/test_create_local_quantum_job.py new file mode 100644 index 000000000..838a01329 --- /dev/null +++ b/test/integ_tests/test_create_local_quantum_job.py @@ -0,0 +1,151 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import json +import os +import re +import tempfile +from pathlib import Path + +import pytest + +from braket.jobs.local import LocalQuantumJob + + +def test_completed_local_job(aws_session, capsys): + """Asserts the job is completed with the respective files and folders for logs, + results and checkpoints. Validate the results are what we expect. Also, + assert that logs contains all the necessary steps for setup and running + the job is displayed to the user. + """ + absolute_source_module = str(Path("test/integ_tests/job_test_script.py").resolve()) + current_dir = Path.cwd() + + with tempfile.TemporaryDirectory() as temp_dir: + os.chdir(temp_dir) + job = LocalQuantumJob.create( + "arn:aws:braket:::device/quantum-simulator/amazon/sv1", + source_module=absolute_source_module, + entry_point="job_test_script:start_here", + hyperparameters={"test_case": "completed"}, + aws_session=aws_session, + ) + + job_name = job.name + pattern = f"^local:job/{job_name}$" + re.match(pattern=pattern, string=job.arn) + + assert job.state() == "COMPLETED" + assert Path(job_name).is_dir() + + # Check results match the expectations. 
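        # (results.json holds the dict written by save_job_result() in
        # job_test_script.completed_job_script above.)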
+ assert Path(f"{job_name}/results.json").exists() + assert job.result() == {"converged": True, "energy": -0.2} + + # Validate checkpoint files and data + assert Path(f"{job_name}/checkpoints/{job_name}.json").exists() + assert Path(f"{job_name}/checkpoints/{job_name}_plain_data.json").exists() + + for file_name, expected_data in [ + ( + f"{job_name}/checkpoints/{job_name}_plain_data.json", + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": {"some_data": "abc"}, + "dataFormat": "plaintext", + }, + ), + ( + f"{job_name}/checkpoints/{job_name}.json", + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": {"some_data": "gASVBwAAAAAAAACMA2FiY5Qu\n"}, + "dataFormat": "pickled_v4", + }, + ), + ]: + with open(file_name, "r") as f: + assert json.loads(f.read()) == expected_data + + # Capture logs + assert Path(f"{job_name}/log.txt").exists() + job.logs() + log_data, errors = capsys.readouterr() + + logs_to_validate = [ + "Beginning Setup", + "Running Code As Process", + "Test job started!!!!!", + "Test job completed!!!!!", + "Code Run Finished", + ] + + for data in logs_to_validate: + assert data in log_data + + os.chdir(current_dir) + + +def test_failed_local_job(aws_session, capsys): + """Asserts the job is failed with the output, checkpoints not created in bucket + and only logs are populated. Validate the calling result function raises + the ValueError. Also, check if the logs displays the required error message. + """ + absolute_source_module = str(Path("test/integ_tests/job_test_script.py").resolve()) + current_dir = Path.cwd() + + with tempfile.TemporaryDirectory() as temp_dir: + os.chdir(temp_dir) + job = LocalQuantumJob.create( + "arn:aws:braket:::device/quantum-simulator/amazon/sv1", + source_module=absolute_source_module, + entry_point="job_test_script:start_here", + hyperparameters={"test_case": "failed"}, + aws_session=aws_session, + ) + + job_name = job.name + pattern = f"^local:job/{job_name}$" + re.match(pattern=pattern, string=job.arn) + + assert Path(job_name).is_dir() + + # Check no files are populated in checkpoints folder. + assert not any(Path(f"{job_name}/checkpoints").iterdir()) + + # Check results match the expectations. + error_message = f"Unable to find results in the local job directory {job_name}." + with pytest.raises(ValueError, match=error_message): + job.result() + + assert Path(f"{job_name}/log.txt").exists() + job.logs() + log_data, errors = capsys.readouterr() + + logs_to_validate = [ + "Beginning Setup", + "Running Code As Process", + "Test job started!!!!!", + "Code Run Finished", + ] + + for data in logs_to_validate: + assert data in log_data + + os.chdir(current_dir) diff --git a/test/integ_tests/test_create_quantum_job.py b/test/integ_tests/test_create_quantum_job.py new file mode 100644 index 000000000..b88405e60 --- /dev/null +++ b/test/integ_tests/test_create_quantum_job.py @@ -0,0 +1,180 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. + +import json +import os.path +import re +import tempfile +from pathlib import Path + +from braket.aws.aws_quantum_job import AwsQuantumJob + + +def test_failed_quantum_job(aws_session, capsys): + """Asserts the job is failed with the output, checkpoints, + tasks not created in bucket and only input is uploaded to s3. Validate the + results/download results have the response raising RuntimeError. Also, + check if the logs displays the Assertion Error. + """ + + job = AwsQuantumJob.create( + "arn:aws:braket:::device/quantum-simulator/amazon/sv1", + source_module="test/integ_tests/job_test_script.py", + entry_point="job_test_script:start_here", + aws_session=aws_session, + wait_until_complete=True, + hyperparameters={"test_case": "failed"}, + ) + + job_name = job.name + pattern = f"^arn:aws:braket:{aws_session.region}:\\d12:job/{job_name}$" + re.match(pattern=pattern, string=job.arn) + + # Check job is in failed state. + assert job.state() == "FAILED" + + # Check whether the respective folder with files are created for script, + # output, tasks and checkpoints. + keys = aws_session.list_keys( + bucket=f"amazon-braket-{aws_session.region}-{aws_session.account_id}", + prefix=f"jobs/{job_name}", + ) + assert keys == [f"jobs/{job_name}/script/source.tar.gz"] + + # no results saved + assert job.result() == {} + + job.logs() + log_data, errors = capsys.readouterr() + assert errors == "" + logs_to_validate = [ + "Invoking script with the following command:", + "/usr/local/bin/python3.7 braket_container.py", + "Running Code As Process", + "Test job started!!!!!", + "AssertionError", + "Code Run Finished", + '"user_entry_point": "braket_container.py"', + ] + + for data in logs_to_validate: + assert data in log_data + + assert job.metadata()["failureReason"] == ( + "AlgorithmError: Job at job_test_script:start_here exited with exit code: 1" + ) + + +def test_completed_quantum_job(aws_session, capsys): + """Asserts the job is completed with the output, checkpoints, tasks and + script folder created in S3 for respective job. Validate the results are + downloaded and results are what we expect. Also, assert that logs contains all the + necessary steps for setup and running the job and is displayed to the user. + """ + + job = AwsQuantumJob.create( + "arn:aws:braket:::device/quantum-simulator/amazon/sv1", + source_module="test/integ_tests/job_test_script.py", + entry_point="job_test_script:start_here", + wait_until_complete=True, + aws_session=aws_session, + hyperparameters={"test_case": "completed"}, + ) + + job_name = job.name + pattern = f"^arn:aws:braket:{aws_session.region}:\\d12:job/{job_name}$" + re.match(pattern=pattern, string=job.arn) + + # check job is in completed state. + assert job.state() == "COMPLETED" + + # Check whether the respective folder with files are created for script, + # output, tasks and checkpoints. + s3_bucket = f"amazon-braket-{aws_session.region}-{aws_session.account_id}" + keys = aws_session.list_keys( + bucket=s3_bucket, + prefix=f"jobs/{job_name}", + ) + for expected_key in [ + f"jobs/{job_name}/script/source.tar.gz", + f"jobs/{job_name}/data/output/model.tar.gz", + f"jobs/{job_name}/tasks/[^/]*/results.json", + f"jobs/{job_name}/checkpoints/{job_name}_plain_data.json", + f"jobs/{job_name}/checkpoints/{job_name}.json", + ]: + assert any(re.match(expected_key, key) for key in keys) + + # Check if checkpoint is uploaded in requested format. 
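The pickled_v4 checkpoint value asserted below (and in the local-job test earlier) appears to be the base64 encoding of the pickle-protocol-4 serialization of the string "abc", i.e. each entry in dataDictionary is serialized individually. A standard-library-only check of that reading (the trailing newline comes from base64.encodebytes):

import base64
import pickle

# "abc" pickled with protocol 4, then base64-encoded; encodebytes appends "\n".
encoded = base64.encodebytes(pickle.dumps("abc", protocol=4)).decode()
assert encoded == "gASVBwAAAAAAAACMA2FiY5Qu\n"

# Decoding recovers the original {"some_data": "abc"} checkpoint value.
assert pickle.loads(base64.b64decode(encoded)) == "abc"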
+ for s3_key, expected_data in [ + ( + f"jobs/{job_name}/checkpoints/{job_name}_plain_data.json", + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": {"some_data": "abc"}, + "dataFormat": "plaintext", + }, + ), + ( + f"jobs/{job_name}/checkpoints/{job_name}.json", + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": {"some_data": "gASVBwAAAAAAAACMA2FiY5Qu\n"}, + "dataFormat": "pickled_v4", + }, + ), + ]: + assert ( + json.loads( + aws_session.retrieve_s3_object_body(s3_bucket=s3_bucket, s3_object_key=s3_key) + ) + == expected_data + ) + + # Check downloaded results exists in the file system after the call. + downloaded_result = f"{job_name}/{AwsQuantumJob.RESULTS_FILENAME}" + current_dir = Path.cwd() + + with tempfile.TemporaryDirectory() as temp_dir: + os.chdir(temp_dir) + job.download_result() + assert ( + Path(AwsQuantumJob.RESULTS_TAR_FILENAME).exists() and Path(downloaded_result).exists() + ) + + # Check results match the expectations. + assert job.result() == {"converged": True, "energy": -0.2} + os.chdir(current_dir) + + # Check the logs and validate it contains required output. + job.logs(wait=True) + log_data, errors = capsys.readouterr() + assert errors == "" + logs_to_validate = [ + "Invoking script with the following command:", + "/usr/local/bin/python3.7 braket_container.py", + "Running Code As Process", + "Test job started!!!!!", + "Test job completed!!!!!", + "Code Run Finished", + '"user_entry_point": "braket_container.py"', + "Reporting training SUCCESS", + ] + + for data in logs_to_validate: + assert data in log_data diff --git a/test/integ_tests/test_device_creation.py b/test/integ_tests/test_device_creation.py index a480e12cd..63b7a3716 100644 --- a/test/integ_tests/test_device_creation.py +++ b/test/integ_tests/test_device_creation.py @@ -16,7 +16,7 @@ from braket.aws import AwsDevice DWAVE_ARN = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6" -RIGETTI_ARN = "arn:aws:braket:::device/qpu/rigetti/Aspen-8" +RIGETTI_ARN = "arn:aws:braket:::device/qpu/rigetti/Aspen-10" IONQ_ARN = "arn:aws:braket:::device/qpu/ionq/ionQdevice" SIMULATOR_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/sv1" diff --git a/test/unit_tests/braket/aws/common_test_utils.py b/test/unit_tests/braket/aws/common_test_utils.py index 846bf165a..8eaaa367c 100644 --- a/test/unit_tests/braket/aws/common_test_utils.py +++ b/test/unit_tests/braket/aws/common_test_utils.py @@ -17,7 +17,7 @@ from braket.aws import AwsQuantumTaskBatch DWAVE_ARN = "arn:aws:braket:::device/qpu/d-wave/Advantage_system1" -RIGETTI_ARN = "arn:aws:braket:::device/qpu/rigetti/Aspen-9" +RIGETTI_ARN = "arn:aws:braket:::device/qpu/rigetti/Aspen-10" IONQ_ARN = "arn:aws:braket:::device/qpu/ionq/ionQdevice" SV1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/sv1" TN1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/tn1" @@ -131,11 +131,12 @@ class MockS3: def run_and_assert( aws_quantum_task_mock, device, + default_s3_folder, default_shots, default_poll_timeout, default_poll_interval, circuit, - s3_destination_folder, + s3_destination_folder, # Treated as positional arg shots, # Treated as positional arg poll_timeout_seconds, # Treated as positional arg poll_interval_seconds, # Treated as positional arg @@ -146,6 +147,8 @@ def run_and_assert( aws_quantum_task_mock.return_value = task_mock run_args = [] + if s3_destination_folder is not None: + run_args.append(s3_destination_folder) if 
shots is not None: run_args.append(shots) if poll_timeout_seconds is not None: @@ -155,13 +158,15 @@ def run_and_assert( run_args += extra_args if extra_args else [] run_kwargs = extra_kwargs or {} - task = device.run(circuit, s3_destination_folder, *run_args, **run_kwargs) + task = device.run(circuit, *run_args, **run_kwargs) assert task == task_mock create_args, create_kwargs = _create_task_args_and_kwargs( + default_s3_folder, default_shots, default_poll_timeout, default_poll_interval, + s3_destination_folder, shots, poll_timeout_seconds, poll_interval_seconds, @@ -170,12 +175,7 @@ def run_and_assert( ) aws_quantum_task_mock.assert_called_with( - device._aws_session, - device.arn, - circuit, - s3_destination_folder, - *create_args, - **create_kwargs + device._aws_session, device.arn, circuit, *create_args, **create_kwargs ) @@ -183,6 +183,7 @@ def run_batch_and_assert( aws_quantum_task_mock, aws_session_mock, device, + default_s3_folder, default_shots, default_poll_timeout, default_poll_interval, @@ -203,6 +204,8 @@ def run_batch_and_assert( aws_session_mock.return_value = new_session_mock run_args = [] + if s3_destination_folder is not None: + run_args.append(s3_destination_folder) if shots is not None: run_args.append(shots) if max_parallel is not None: @@ -216,13 +219,15 @@ def run_batch_and_assert( run_args += extra_args if extra_args else [] run_kwargs = extra_kwargs or {} - batch = device.run_batch(circuits, s3_destination_folder, *run_args, **run_kwargs) + batch = device.run_batch(circuits, *run_args, **run_kwargs) assert batch.tasks == [task_mock for _ in range(len(circuits))] create_args, create_kwargs = _create_task_args_and_kwargs( + default_s3_folder, default_shots, default_poll_timeout, default_poll_interval, + s3_destination_folder, shots, poll_timeout_seconds, poll_interval_seconds, @@ -235,26 +240,26 @@ def run_batch_and_assert( # aws_session_mock.call_args.kwargs syntax is newer than Python 3.7 assert aws_session_mock.call_args[1]["config"].max_pool_connections == max_pool_connections aws_quantum_task_mock.assert_called_with( - new_session_mock, - device.arn, - circuits[0], - s3_destination_folder, - *create_args, - **create_kwargs + new_session_mock, device.arn, circuits[0], *create_args, **create_kwargs ) def _create_task_args_and_kwargs( + default_s3_folder, default_shots, default_poll_timeout, default_poll_interval, + s3_folder, shots, poll_timeout_seconds, poll_interval_seconds, extra_args, extra_kwargs, ): - create_args = [shots if shots is not None else default_shots] + create_args = [ + s3_folder if s3_folder is not None else default_s3_folder, + shots if shots is not None else default_shots, + ] create_args += extra_args if extra_args else [] create_kwargs = extra_kwargs or {} create_kwargs.update( diff --git a/test/unit_tests/braket/aws/test_aws_device.py b/test/unit_tests/braket/aws/test_aws_device.py index b4cd59ed5..911ad932c 100644 --- a/test/unit_tests/braket/aws/test_aws_device.py +++ b/test/unit_tests/braket/aws/test_aws_device.py @@ -10,10 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
- +import os from unittest.mock import Mock, patch import pytest +from botocore.exceptions import ClientError from common_test_utils import ( DWAVE_ARN, IONQ_ARN, @@ -73,7 +74,7 @@ def test_mock_rigetti_schema_1(): MOCK_GATE_MODEL_QPU_1 = { - "deviceName": "Aspen-9", + "deviceName": "Aspen-10", "deviceType": "QPU", "providerName": "provider1", "deviceStatus": "OFFLINE", @@ -231,6 +232,11 @@ def test_gate_model_sim_schema(): "deviceCapabilities": MOCK_GATE_MODEL_SIMULATOR_CAPABILITIES.json(), } +MOCK_DEFAULT_S3_DESTINATION_FOLDER = ( + "amazon-braket-us-test-1-00000000", + "tasks", +) + @pytest.fixture def arn(): @@ -254,23 +260,6 @@ def boto_session(): return _boto_session -@pytest.fixture -def aws_explicit_session(): - _boto_session = Mock() - _boto_session.region_name = RIGETTI_REGION - - creds = Mock() - creds.access_key = "access key" - creds.secret_key = "secret key" - creds.token = "token" - creds.method = "explicit" - _boto_session.get_credentials.return_value = creds - - _aws_session = Mock() - _aws_session.boto_session = _boto_session - return _aws_session - - @pytest.fixture def aws_session(): _boto_session = Mock() @@ -282,6 +271,11 @@ def aws_session(): _aws_session = Mock() _aws_session.boto_session = _boto_session + _aws_session._default_bucket = MOCK_DEFAULT_S3_DESTINATION_FOLDER[0] + _aws_session.default_bucket.return_value = _aws_session._default_bucket + _aws_session._custom_default_bucket = False + _aws_session.account_id = "00000000" + _aws_session.region = RIGETTI_REGION return _aws_session @@ -320,43 +314,41 @@ def test_device_simulator_no_aws_session(aws_session_init, aws_session): aws_session.get_device.assert_called_with(arn) -@patch("boto3.Session") -def test_copy_session(boto_session_init, aws_session): - boto_session_init.return_value = Mock() - AwsDevice._copy_aws_session(aws_session, RIGETTI_REGION) - boto_session_init.assert_called_with(region_name=RIGETTI_REGION) - - -@patch("boto3.Session") -def test_copy_explicit_session(boto_session_init, aws_explicit_session): - boto_session_init.return_value = Mock() - AwsDevice._copy_aws_session(aws_explicit_session, RIGETTI_REGION) - boto_session_init.assert_called_with( - aws_access_key_id="access key", - aws_secret_access_key="secret key", - aws_session_token="token", - region_name=RIGETTI_REGION, - ) - - -@patch("braket.aws.aws_device.AwsDevice._copy_aws_session") +@patch("braket.aws.aws_device.AwsSession.copy_session") @patch("braket.aws.aws_device.AwsSession") @pytest.mark.parametrize( "get_device_side_effect", [ [MOCK_GATE_MODEL_QPU_1], - [ValueError(), MOCK_GATE_MODEL_QPU_1], + [ + ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + } + }, + "getDevice", + ), + MOCK_GATE_MODEL_QPU_1, + ], ], ) def test_device_qpu_no_aws_session( - aws_session_init, mock_copy_aws_session, get_device_side_effect, aws_session + aws_session_init, mock_copy_session, get_device_side_effect, aws_session ): arn = RIGETTI_ARN mock_session = Mock() mock_session.get_device.side_effect = get_device_side_effect - aws_session.get_device.side_effect = ValueError() + aws_session.get_device.side_effect = ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + } + }, + "getDevice", + ) aws_session_init.return_value = aws_session - mock_copy_aws_session.return_value = mock_session + mock_copy_session.return_value = mock_session device = AwsDevice(arn) _assert_device_fields(device, MOCK_GATE_MODEL_QPU_CAPABILITIES_1, MOCK_GATE_MODEL_QPU_1) @@ -394,29 +386,118 @@ def test_repr(arn): assert repr(device) == expected 
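The device tests above and below simulate service failures by constructing botocore ClientError values by hand. Shown standalone for clarity (the message string here is illustrative, not taken from the service):

import pytest
from botocore.exceptions import ClientError

not_found = ClientError(
    {"Error": {"Code": "ResourceNotFoundException", "Message": "no such device"}},
    "getDevice",
)

# str() of a ClientError renders as
# "An error occurred (ResourceNotFoundException) when calling the getDevice operation: no such device",
# which is what the pytest.raises(..., match=...) patterns in these tests match against.
with pytest.raises(ClientError, match="ResourceNotFoundException"):
    raise not_found

# The structured error code is also available for programmatic branching.
assert not_found.response["Error"]["Code"] == "ResourceNotFoundException"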
-@pytest.mark.xfail(raises=ValueError) def test_device_simulator_not_found(): mock_session = Mock() - mock_session.get_device.side_effect = ValueError() - AwsDevice("arn:aws:braket:::device/simulator/a/b", mock_session) + mock_session.region = "test-region-1" + mock_session.get_device.side_effect = ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + "Message": ( + "Braket device 'arn:aws:braket:::device/quantum-simulator/amazon/tn1' " + "not found in us-west-1. You can find a list of all supported device " + "ARNs and the regions in which they are available in the documentation: " + "https://docs.aws.amazon.com/braket/latest/developerguide/braket-devices.html" + ), + } + }, + "getDevice", + ) + simulator_not_found = ( + "Simulator 'arn:aws:braket:::device/simulator/a/b' not found in 'test-region-1'" + ) + with pytest.raises(ValueError, match=simulator_not_found): + AwsDevice("arn:aws:braket:::device/simulator/a/b", mock_session) -@pytest.mark.xfail(raises=ValueError) -@patch("braket.aws.aws_device.AwsDevice._copy_aws_session") -def test_device_qpu_not_found(mock_copy_aws_session): +@patch("braket.aws.aws_device.AwsSession.copy_session") +def test_device_qpu_not_found(mock_copy_session): + mock_session = Mock() + mock_session.get_device.side_effect = ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + "Message": ( + "Braket device 'arn:aws:braket:::device/quantum-simulator/amazon/tn1' " + "not found in us-west-1. You can find a list of all supported device " + "ARNs and the regions in which they are available in the documentation: " + "https://docs.aws.amazon.com/braket/latest/developerguide/braket-devices.html" + ), + } + }, + "getDevice", + ) + mock_copy_session.return_value = mock_session + qpu_not_found = "QPU 'arn:aws:braket:::device/qpu/a/b' not found" + with pytest.raises(ValueError, match=qpu_not_found): + AwsDevice("arn:aws:braket:::device/qpu/a/b", mock_session) + + +@patch("braket.aws.aws_device.AwsSession.copy_session") +def test_device_qpu_exception(mock_copy_session): + mock_session = Mock() + mock_session.get_device.side_effect = ( + ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + "Message": ( + "Braket device 'arn:aws:braket:::device/quantum-simulator/amazon/tn1' " + "not found in us-west-1. 
You can find a list of all supported device " + "ARNs and the regions in which they are available in the documentation: " + "https://docs.aws.amazon.com/braket/latest/developerguide/braket-" + "devices.html" + ), + } + }, + "getDevice", + ), + ClientError( + { + "Error": { + "Code": "OtherException", + "Message": "Some other message", + } + }, + "getDevice", + ), + ) + mock_copy_session.return_value = mock_session + qpu_exception = ( + "An error occurred \\(OtherException\\) when calling the " + "getDevice operation: Some other message" + ) + with pytest.raises(ClientError, match=qpu_exception): + AwsDevice("arn:aws:braket:::device/qpu/a/b", mock_session) + + +@patch("braket.aws.aws_device.AwsSession.copy_session") +def test_device_non_qpu_region_error(mock_copy_session): mock_session = Mock() - mock_session.get_device.side_effect = ValueError() - mock_copy_aws_session.return_value = mock_session - AwsDevice("arn:aws:braket:::device/qpu/a/b", mock_session) + mock_session.get_device.side_effect = ClientError( + { + "Error": { + "Code": "ExpiredTokenError", + "Message": ("Some other error that isn't ResourceNotFoundException"), + } + }, + "getDevice", + ) + mock_copy_session.return_value = mock_session + expired_token = ( + "An error occurred \\(ExpiredTokenError\\) when calling the getDevice operation: " + "Some other error that isn't ResourceNotFoundException" + ) + with pytest.raises(ClientError, match=expired_token): + AwsDevice("arn:aws:braket:::device/qpu/a/b", mock_session) @patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") -def test_run_no_extra(aws_quantum_task_mock, device, circuit, s3_destination_folder): +def test_run_no_extra(aws_quantum_task_mock, device, circuit): _run_and_assert( aws_quantum_task_mock, device, circuit, - s3_destination_folder, ) @@ -460,6 +541,7 @@ def test_run_with_qpu_no_shots(aws_quantum_task_mock, device, circuit, s3_destin run_and_assert( aws_quantum_task_mock, device(RIGETTI_ARN), + MOCK_DEFAULT_S3_DESTINATION_FOLDER, AwsDevice.DEFAULT_SHOTS_QPU, AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, @@ -473,6 +555,27 @@ def test_run_with_qpu_no_shots(aws_quantum_task_mock, device, circuit, s3_destin ) +@patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") +def test_default_bucket_not_called(aws_quantum_task_mock, device, circuit, s3_destination_folder): + device = device(RIGETTI_ARN) + run_and_assert( + aws_quantum_task_mock, + device, + MOCK_DEFAULT_S3_DESTINATION_FOLDER, + AwsDevice.DEFAULT_SHOTS_QPU, + AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, + AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, + circuit, + s3_destination_folder, + None, + None, + None, + None, + None, + ) + device._aws_session.default_bucket.assert_not_called() + + @patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") def test_run_with_shots_poll_timeout_kwargs( aws_quantum_task_mock, device, circuit, s3_destination_folder @@ -505,21 +608,28 @@ def test_run_with_positional_args_and_kwargs( ) -@patch("braket.aws.aws_device.AwsSession") +@patch.dict( + os.environ, + {"AMZN_BRAKET_TASK_RESULTS_S3_URI": "s3://env_bucket/env/path"}, +) @patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") -def test_run_batch_no_extra( - aws_quantum_task_mock, aws_session_mock, device, circuit, s3_destination_folder -): +def test_run_env_variables(aws_quantum_task_mock, device, circuit): + device("foo:bar").run(circuit) + assert aws_quantum_task_mock.call_args_list[0][0][3] == ("env_bucket", "env/path") + + 
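test_run_env_variables above (and test_run_batch_env_variables further down) pins down the behavior this refactor introduces: when AMZN_BRAKET_TASK_RESULTS_S3_URI is set and no destination folder is passed explicitly, run() forwards that URI's bucket and prefix to AwsQuantumTask.create instead of the session's default bucket. A small parser showing the URI-to-tuple mapping those assertions rely on (parse_s3_uri is a hypothetical helper for illustration, not the SDK's own):

from typing import Tuple
from urllib.parse import urlparse


def parse_s3_uri(uri: str) -> Tuple[str, str]:
    """Split an s3:// URI into (bucket, prefix); illustrative helper only."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3":
        raise ValueError(f"Not an S3 URI: {uri}")
    return parsed.netloc, parsed.path.lstrip("/")


# The value patched into the environment above maps to the tuple the test asserts on.
assert parse_s3_uri("s3://env_bucket/env/path") == ("env_bucket", "env/path")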
+@patch("braket.aws.aws_session.AwsSession") +@patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") +def test_run_batch_no_extra(aws_quantum_task_mock, aws_session_mock, device, circuit): _run_batch_and_assert( aws_quantum_task_mock, aws_session_mock, device, [circuit for _ in range(10)], - s3_destination_folder, ) -@patch("braket.aws.aws_device.AwsSession") +@patch("braket.aws.aws_session.AwsSession") @patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") def test_run_batch_with_shots( aws_quantum_task_mock, aws_session_mock, device, circuit, s3_destination_folder @@ -534,7 +644,7 @@ def test_run_batch_with_shots( ) -@patch("braket.aws.aws_device.AwsSession") +@patch("braket.aws.aws_session.AwsSession") @patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") def test_run_batch_with_max_parallel_and_kwargs( aws_quantum_task_mock, aws_session_mock, device, circuit, s3_destination_folder @@ -552,11 +662,21 @@ def test_run_batch_with_max_parallel_and_kwargs( ) +@patch.dict( + os.environ, + {"AMZN_BRAKET_TASK_RESULTS_S3_URI": "s3://env_bucket/env/path"}, +) +@patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") +def test_run_batch_env_variables(aws_quantum_task_mock, device, circuit): + device("foo:bar").run_batch([circuit]) + assert aws_quantum_task_mock.call_args_list[0][0][3] == ("env_bucket", "env/path") + + def _run_and_assert( aws_quantum_task_mock, device_factory, circuit, - s3_destination_folder, + s3_destination_folder=None, # Treated as positional arg shots=None, # Treated as positional arg poll_timeout_seconds=None, # Treated as positional arg poll_interval_seconds=None, # Treated as positional arg @@ -566,6 +686,7 @@ def _run_and_assert( run_and_assert( aws_quantum_task_mock, device_factory("foo_bar"), + MOCK_DEFAULT_S3_DESTINATION_FOLDER, AwsDevice.DEFAULT_SHOTS_SIMULATOR, AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, @@ -584,7 +705,7 @@ def _run_batch_and_assert( aws_session_mock, device_factory, circuits, - s3_destination_folder, + s3_destination_folder=None, # Treated as positional arg shots=None, # Treated as positional arg max_parallel=None, # Treated as positional arg max_connections=None, # Treated as positional arg @@ -597,6 +718,7 @@ def _run_batch_and_assert( aws_quantum_task_mock, aws_session_mock, device_factory("foo_bar"), + MOCK_DEFAULT_S3_DESTINATION_FOLDER, AwsDevice.DEFAULT_SHOTS_SIMULATOR, AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, @@ -622,8 +744,8 @@ def _assert_device_fields(device, expected_properties, expected_device_data): assert device.topology_graph.edges == device._construct_topology_graph().edges -@patch("braket.aws.aws_device.AwsDevice._copy_aws_session") -def test_get_devices(mock_copy_aws_session, aws_session): +@patch("braket.aws.aws_device.AwsSession.copy_session") +def test_get_devices(mock_copy_session, aws_session): aws_session.search_devices.side_effect = [ # us-west-1 [ @@ -679,7 +801,7 @@ def test_get_devices(mock_copy_aws_session, aws_session): MOCK_GATE_MODEL_QPU_2, ValueError("should not be reachable"), ] - mock_copy_aws_session.return_value = session_for_region + mock_copy_session.return_value = session_for_region # Search order: us-east-1, us-west-1, us-west-2 results = AwsDevice.get_devices( arns=[SV1_ARN, DWAVE_ARN, IONQ_ARN], @@ -690,8 +812,8 @@ def test_get_devices(mock_copy_aws_session, aws_session): assert [result.name for result in results] == ["Advantage_system1.1", "Blah", "SV1"] 
-@patch("braket.aws.aws_device.AwsDevice._copy_aws_session") -def test_get_devices_simulators_only(mock_copy_aws_session, aws_session): +@patch("braket.aws.aws_device.AwsSession.copy_session") +def test_get_devices_simulators_only(mock_copy_session, aws_session): aws_session.search_devices.side_effect = [ [ { @@ -711,7 +833,7 @@ def test_get_devices_simulators_only(mock_copy_aws_session, aws_session): session_for_region = Mock() session_for_region.search_devices.side_effect = ValueError("should not be reachable") session_for_region.get_device.side_effect = ValueError("should not be reachable") - mock_copy_aws_session.return_value = session_for_region + mock_copy_session.return_value = session_for_region results = AwsDevice.get_devices( arns=[SV1_ARN, TN1_ARN], types=["SIMULATOR"], diff --git a/test/unit_tests/braket/aws/test_aws_quantum_job.py b/test/unit_tests/braket/aws/test_aws_quantum_job.py new file mode 100644 index 000000000..2265fd3c0 --- /dev/null +++ b/test/unit_tests/braket/aws/test_aws_quantum_job.py @@ -0,0 +1,910 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import datetime +import json +import logging +import os +import tarfile +import tempfile +from unittest.mock import Mock, patch + +import pytest +from botocore.exceptions import ClientError + +from braket.aws import AwsQuantumJob, AwsSession + + +@pytest.fixture +def aws_session(quantum_job_arn, job_region): + _aws_session = Mock(spec=AwsSession) + _aws_session.create_job.return_value = quantum_job_arn + _aws_session.default_bucket.return_value = "default-bucket-name" + _aws_session.get_default_jobs_role.return_value = "default-role-arn" + _aws_session.construct_s3_uri.side_effect = ( + lambda bucket, *dirs: f"s3://{bucket}/{'/'.join(dirs)}" + ) + + def fake_copy_session(region): + _aws_session.region = region + return _aws_session + + _aws_session.copy_session.side_effect = fake_copy_session + _aws_session.list_keys.return_value = ["job-path/output/model.tar.gz"] + _aws_session.region = "us-test-1" + + _braket_client_mock = Mock(meta=Mock(region_name=job_region)) + _aws_session.braket_client = _braket_client_mock + return _aws_session + + +@pytest.fixture +def generate_get_job_response(): + def _get_job_response(**kwargs): + response = { + "ResponseMetadata": { + "RequestId": "d223b1a0-ee5c-4c75-afa7-3c29d5338b62", + "HTTPStatusCode": 200, + }, + "algorithmSpecification": { + "scriptModeConfig": { + "entryPoint": "my_file:start_here", + "s3Uri": "s3://amazon-braket-jobs/job-path/my_file.py", + } + }, + "checkpointConfig": { + "localPath": "/opt/omega/checkpoints", + "s3Uri": "s3://amazon-braket-jobs/job-path/checkpoints", + }, + "createdAt": datetime.datetime(2021, 6, 28, 21, 4, 51), + "deviceConfig": { + "device": "arn:aws:braket:::device/qpu/rigetti/Aspen-10", + }, + "hyperParameters": { + "foo": "bar", + }, + "inputDataConfig": [ + { + "channelName": "training_input", + "dataSource": { + "s3DataSource": { + "s3Uri": "s3://amazon-braket-jobs/job-path/input", + } + }, + } + ], + "instanceConfig": { + 
"instanceCount": 1, + "instanceType": "ml.m5.large", + "volumeSizeInGb": 1, + }, + "jobArn": "arn:aws:braket:us-west-2:875981177017:job/job-test-20210628140446", + "jobName": "job-test-20210628140446", + "outputDataConfig": {"s3Path": "s3://amazon-braket-jobs/job-path/data"}, + "roleArn": "arn:aws:iam::875981177017:role/AmazonBraketJobRole", + "status": "RUNNING", + "stoppingCondition": {"maxRuntimeInSeconds": 1200}, + } + response.update(kwargs) + + return response + + return _get_job_response + + +@pytest.fixture +def generate_cancel_job_response(): + def _cancel_job_response(**kwargs): + response = { + "ResponseMetadata": { + "RequestId": "857b0893-2073-4ad6-b828-744af8400dfe", + "HTTPStatusCode": 200, + }, + "cancellationStatus": "CANCELLING", + "jobArn": "arn:aws:braket:us-west-2:875981177017:job/job-test-20210628140446", + } + response.update(kwargs) + return response + + return _cancel_job_response + + +@pytest.fixture +def quantum_job_name(): + return "job-test-20210628140446" + + +@pytest.fixture +def job_region(): + return "us-west-2" + + +@pytest.fixture +def quantum_job_arn(quantum_job_name, job_region): + return f"arn:aws:braket:{job_region}:875981177017:job/{quantum_job_name}" + + +@pytest.fixture +def quantum_job(quantum_job_arn, aws_session): + return AwsQuantumJob(quantum_job_arn, aws_session) + + +def test_equality(quantum_job_arn, aws_session, job_region): + new_aws_session = Mock(braket_client=Mock(meta=Mock(region_name=job_region))) + quantum_job_1 = AwsQuantumJob(quantum_job_arn, aws_session) + quantum_job_2 = AwsQuantumJob(quantum_job_arn, aws_session) + quantum_job_3 = AwsQuantumJob(quantum_job_arn, new_aws_session) + other_quantum_job = AwsQuantumJob( + "arn:aws:braket:us-west-2:875981177017:job/other-job", aws_session + ) + non_quantum_job = quantum_job_1.arn + + assert quantum_job_1 == quantum_job_2 + assert quantum_job_1 == quantum_job_3 + assert quantum_job_1 is not quantum_job_2 + assert quantum_job_1 is not quantum_job_3 + assert quantum_job_1 is quantum_job_1 + assert quantum_job_1 != other_quantum_job + assert quantum_job_1 != non_quantum_job + + +def test_hash(quantum_job): + assert hash(quantum_job) == hash(quantum_job.arn) + + +@pytest.mark.parametrize( + "arn, expected_region", + [ + ("arn:aws:braket:us-west-2:875981177017:job/job-name", "us-west-2"), + ("arn:aws:braket:us-west-1:1234567890:job/job-name", "us-west-1"), + ], +) +@patch("braket.aws.aws_quantum_job.boto3.Session") +@patch("braket.aws.aws_quantum_job.AwsSession") +def test_quantum_job_constructor_default_session( + aws_session_mock, mock_session, arn, expected_region +): + mock_boto_session = Mock() + aws_session_mock.return_value = Mock() + mock_session.return_value = mock_boto_session + job = AwsQuantumJob(arn) + mock_session.assert_called_with(region_name=expected_region) + aws_session_mock.assert_called_with(boto_session=mock_boto_session) + assert job.arn == arn + assert job._aws_session == aws_session_mock.return_value + + +@pytest.mark.xfail(raises=ValueError) +def test_quantum_job_constructor_invalid_region(aws_session): + arn = "arn:aws:braket:unknown-region:875981177017:job/quantum_job_name" + AwsQuantumJob(arn, aws_session) + + +@patch("braket.aws.aws_quantum_job.boto3.Session") +def test_quantum_job_constructor_explicit_session(mock_session, quantum_job_arn, job_region): + aws_session_mock = Mock(braket_client=Mock(meta=Mock(region_name=job_region))) + job = AwsQuantumJob(quantum_job_arn, aws_session_mock) + assert job._aws_session == aws_session_mock + assert job.arn == 
quantum_job_arn + mock_session.assert_not_called() + + +def test_metadata(quantum_job, aws_session, generate_get_job_response, quantum_job_arn): + get_job_response_running = generate_get_job_response(status="RUNNING") + aws_session.get_job.return_value = get_job_response_running + assert quantum_job.metadata() == get_job_response_running + aws_session.get_job.assert_called_with(quantum_job_arn) + + get_job_response_completed = generate_get_job_response(status="COMPLETED") + aws_session.get_job.return_value = get_job_response_completed + assert quantum_job.metadata() == get_job_response_completed + aws_session.get_job.assert_called_with(quantum_job_arn) + assert aws_session.get_job.call_count == 2 + + +def test_metadata_caching(quantum_job, aws_session, generate_get_job_response, quantum_job_arn): + get_job_response_running = generate_get_job_response(status="RUNNING") + aws_session.get_job.return_value = get_job_response_running + assert quantum_job.metadata(True) == get_job_response_running + + get_job_response_completed = generate_get_job_response(status="COMPLETED") + aws_session.get_job.return_value = get_job_response_completed + assert quantum_job.metadata(True) == get_job_response_running + aws_session.get_job.assert_called_with(quantum_job_arn) + assert aws_session.get_job.call_count == 1 + + +def test_state(quantum_job, aws_session, generate_get_job_response, quantum_job_arn): + state_1 = "RUNNING" + get_job_response_running = generate_get_job_response(status=state_1) + aws_session.get_job.return_value = get_job_response_running + assert quantum_job.state() == state_1 + aws_session.get_job.assert_called_with(quantum_job_arn) + + state_2 = "COMPLETED" + get_job_response_completed = generate_get_job_response(status=state_2) + aws_session.get_job.return_value = get_job_response_completed + assert quantum_job.state() == state_2 + aws_session.get_job.assert_called_with(quantum_job_arn) + assert aws_session.get_job.call_count == 2 + + +def test_state_caching(quantum_job, aws_session, generate_get_job_response, quantum_job_arn): + state_1 = "RUNNING" + get_job_response_running = generate_get_job_response(status=state_1) + aws_session.get_job.return_value = get_job_response_running + assert quantum_job.state(True) == state_1 + + state_2 = "COMPLETED" + get_job_response_completed = generate_get_job_response(status=state_2) + aws_session.get_job.return_value = get_job_response_completed + assert quantum_job.state(True) == state_1 + aws_session.get_job.assert_called_with(quantum_job_arn) + assert aws_session.get_job.call_count == 1 + + +@pytest.fixture() +def result_setup(quantum_job_name): + with tempfile.TemporaryDirectory() as temp_dir: + os.chdir(temp_dir) + file_path = "results.json" + + with open(file_path, "w") as write_file: + write_file.write( + json.dumps( + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": {"converged": True, "energy": -0.2}, + "dataFormat": "plaintext", + } + ) + ) + + with tarfile.open("model.tar.gz", "w:gz") as tar: + tar.add(file_path, arcname=os.path.basename(file_path)) + + yield + + result_dir = f"{os.getcwd()}/{quantum_job_name}" + + if os.path.exists(result_dir): + os.remove(f"{result_dir}/results.json") + os.rmdir(f"{result_dir}/") + + if os.path.isfile("model.tar.gz"): + os.remove("model.tar.gz") + + os.chdir("..") + + +@pytest.mark.parametrize("state", AwsQuantumJob.TERMINAL_STATES) +def test_results_when_job_is_completed( + quantum_job, aws_session, generate_get_job_response, 
result_setup, state +): + expected_saved_data = {"converged": True, "energy": -0.2} + + get_job_response_completed = generate_get_job_response(status=state) + quantum_job._aws_session.get_job.return_value = get_job_response_completed + actual_data = quantum_job.result() + + job_metadata = quantum_job.metadata(True) + s3_path = job_metadata["outputDataConfig"]["s3Path"] + + output_bucket_uri = f"{s3_path}/output/model.tar.gz" + quantum_job._aws_session.download_from_s3.assert_called_with( + s3_uri=output_bucket_uri, filename="model.tar.gz" + ) + assert actual_data == expected_saved_data + + +def test_download_result_when_job_is_running( + quantum_job, aws_session, generate_get_job_response, result_setup +): + poll_timeout_seconds, poll_interval_seconds, state = 1, 0.5, "RUNNING" + get_job_response_completed = generate_get_job_response(status=state) + aws_session.get_job.return_value = get_job_response_completed + job_metadata = quantum_job.metadata(True) + + with pytest.raises( + TimeoutError, + match=f"{job_metadata['jobName']}: Polling for job completion " + f"timed out after {poll_timeout_seconds} seconds.", + ): + quantum_job.download_result( + poll_timeout_seconds=poll_timeout_seconds, poll_interval_seconds=poll_interval_seconds + ) + + +def test_download_result_when_extract_path_not_provided( + quantum_job, generate_get_job_response, aws_session, result_setup +): + state = "COMPLETED" + expected_saved_data = {"converged": True, "energy": -0.2} + get_job_response_completed = generate_get_job_response(status=state) + quantum_job._aws_session.get_job.return_value = get_job_response_completed + job_metadata = quantum_job.metadata(True) + job_name = job_metadata["jobName"] + quantum_job.download_result() + + with open(f"{job_name}/results.json", "r") as file: + actual_data = json.loads(file.read())["dataDictionary"] + assert expected_saved_data == actual_data + + +def test_download_result_when_extract_path_provided( + quantum_job, generate_get_job_response, aws_session, result_setup +): + expected_saved_data = {"converged": True, "energy": -0.2} + state = "COMPLETED" + get_job_response_completed = generate_get_job_response(status=state) + aws_session.get_job.return_value = get_job_response_completed + job_metadata = quantum_job.metadata(True) + job_name = job_metadata["jobName"] + + with tempfile.TemporaryDirectory() as temp_dir: + quantum_job.download_result(temp_dir) + + with open(f"{temp_dir}/{job_name}/results.json", "r") as file: + actual_data = json.loads(file.read())["dataDictionary"] + assert expected_saved_data == actual_data + + +def test_empty_dict_returned_when_result_not_saved( + quantum_job, generate_get_job_response, aws_session +): + state = "COMPLETED" + get_job_response_completed = generate_get_job_response(status=state) + aws_session.get_job.return_value = get_job_response_completed + + exception_response = { + "Error": { + "Code": "404", + "Message": "Not Found", + } + } + quantum_job._aws_session.download_from_s3 = Mock( + side_effect=ClientError(exception_response, "HeadObject") + ) + assert quantum_job.result() == {} + + +def test_results_not_in_s3_for_download(quantum_job, generate_get_job_response, aws_session): + state = "COMPLETED" + get_job_response_completed = generate_get_job_response(status=state) + aws_session.get_job.return_value = get_job_response_completed + job_metadata = quantum_job.metadata(True) + output_s3_path = job_metadata["outputDataConfig"]["s3Path"] + + error_message = f"Error retrieving results, could not find results at '{output_s3_path}" + + 
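# Aside: the result_setup fixture above and the download tests mirror what result()
# and download_result() ultimately do with the job output: fetch model.tar.gz, unpack
# results.json, and hand back its dataDictionary. A standard-library sketch of that
# unpacking step (read_persisted_results is illustrative, not the SDK's API):
import json
import tarfile


def read_persisted_results(tar_path="model.tar.gz"):
    with tarfile.open(tar_path, "r:gz") as tar:
        member = tar.extractfile("results.json")
        return json.loads(member.read())["dataDictionary"]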
exception_response = { + "Error": { + "Code": "404", + "Message": "Not Found", + } + } + quantum_job._aws_session.download_from_s3 = Mock( + side_effect=ClientError(exception_response, "HeadObject") + ) + with pytest.raises(ClientError, match=error_message): + quantum_job.download_result() + + +def test_results_raises_error_for_non_404_errors( + quantum_job, generate_get_job_response, aws_session +): + state = "COMPLETED" + get_job_response_completed = generate_get_job_response(status=state) + aws_session.get_job.return_value = get_job_response_completed + + error = "An error occurred \\(402\\) when calling the SomeObject operation: Something" + + exception_response = { + "Error": { + "Code": "402", + "Message": "Something", + } + } + quantum_job._aws_session.download_from_s3 = Mock( + side_effect=ClientError(exception_response, "SomeObject") + ) + with pytest.raises(ClientError, match=error): + quantum_job.result() + + +@patch("braket.aws.aws_quantum_job.AwsQuantumJob.download_result") +def test_results_json_file_not_in_tar( + result_download, quantum_job, aws_session, generate_get_job_response +): + state = "COMPLETED" + get_job_response_completed = generate_get_job_response(status=state) + quantum_job._aws_session.get_job.return_value = get_job_response_completed + assert quantum_job.result() == {} + + +@pytest.fixture +def entry_point(): + return "test-source-module.entry_point:func" + + +@pytest.fixture +def bucket(): + return "braket-region-id" + + +@pytest.fixture( + params=[ + None, + "aws.location/custom-jobs:tag.1.2.3", + "other.uri/custom-name:tag", + "other-custom-format.com", + ] +) +def image_uri(request): + return request.param + + +@pytest.fixture(params=["given_job_name", "default_job_name"]) +def job_name(request): + if request.param == "given_job_name": + return "test-job-name" + + +@pytest.fixture +def s3_prefix(job_name): + return f"{job_name}/non-default" + + +@pytest.fixture(params=["local_source", "s3_source"]) +def source_module(request, bucket, s3_prefix): + if request.param == "local_source": + return "test-source-module" + elif request.param == "s3_source": + return AwsSession.construct_s3_uri(bucket, "test-source-prefix", "source.tar.gz") + + +@pytest.fixture +def role_arn(): + return "arn:aws:iam::0000000000:role/AmazonBraketInternalSLR" + + +@pytest.fixture +def device_arn(): + return "arn:aws:braket:::device/qpu/test/device-name" + + +@pytest.fixture +def prepare_job_args(aws_session): + return { + "device": Mock(), + "source_module": Mock(), + "entry_point": Mock(), + "image_uri": Mock(), + "job_name": Mock(), + "code_location": Mock(), + "role_arn": Mock(), + "hyperparameters": Mock(), + "input_data": Mock(), + "instance_config": Mock(), + "stopping_condition": Mock(), + "output_data_config": Mock(), + "copy_checkpoints_from_job": Mock(), + "checkpoint_config": Mock(), + "aws_session": aws_session, + "tags": Mock(), + } + + +def test_str(quantum_job): + expected = f"AwsQuantumJob('arn':'{quantum_job.arn}')" + assert str(quantum_job) == expected + + +def test_arn(quantum_job_arn, aws_session): + quantum_job = AwsQuantumJob(quantum_job_arn, aws_session) + assert quantum_job.arn == quantum_job_arn + + +def test_name(quantum_job_arn, quantum_job_name, aws_session): + quantum_job = AwsQuantumJob(quantum_job_arn, aws_session) + assert quantum_job.name == quantum_job_name + + +@pytest.mark.xfail(raises=AttributeError) +def test_no_arn_setter(quantum_job): + quantum_job.arn = 123 + + +@pytest.mark.parametrize("wait_until_complete", [True, False]) 
+@patch("braket.aws.aws_quantum_job.AwsQuantumJob.logs") +@patch("braket.aws.aws_quantum_job.prepare_quantum_job") +def test_create_job( + mock_prepare_quantum_job, + mock_logs, + aws_session, + prepare_job_args, + quantum_job_arn, + wait_until_complete, +): + test_response_args = {"testArgs": "MyTestArg"} + mock_prepare_quantum_job.return_value = test_response_args + job = AwsQuantumJob.create(wait_until_complete=wait_until_complete, **prepare_job_args) + mock_prepare_quantum_job.assert_called_with(**prepare_job_args) + aws_session.create_job.assert_called_with(**test_response_args) + if wait_until_complete: + mock_logs.assert_called_once() + else: + mock_logs.assert_not_called() + assert job.arn == quantum_job_arn + + +def test_create_fake_arg(): + unexpected_kwarg = "create\\(\\) got an unexpected keyword argument 'fake_arg'" + with pytest.raises(TypeError, match=unexpected_kwarg): + AwsQuantumJob.create( + device="device", + source_module="source", + fake_arg="fake_value", + ) + + +def test_cancel_job(quantum_job_arn, aws_session, generate_cancel_job_response): + cancellation_status = "CANCELLING" + aws_session.cancel_job.return_value = generate_cancel_job_response( + cancellationStatus=cancellation_status + ) + quantum_job = AwsQuantumJob(quantum_job_arn, aws_session) + status = quantum_job.cancel() + aws_session.cancel_job.assert_called_with(quantum_job_arn) + assert status == cancellation_status + + +@pytest.mark.xfail(raises=ClientError) +def test_cancel_job_surfaces_exception(quantum_job, aws_session): + exception_response = { + "Error": { + "Code": "ValidationException", + "Message": "unit-test-error", + } + } + aws_session.cancel_job.side_effect = ClientError(exception_response, "cancel_job") + quantum_job.cancel() + + +@pytest.mark.parametrize( + "generate_get_job_response_kwargs", + [ + { + "status": "RUNNING", + }, + { + "status": "COMPLETED", + }, + { + "status": "COMPLETED", + "startedAt": datetime.datetime(2021, 1, 1, 1, 0, 0, 0), + }, + {"status": "COMPLETED", "endedAt": datetime.datetime(2021, 1, 1, 1, 0, 0, 0)}, + { + "status": "COMPLETED", + "startedAt": datetime.datetime(2021, 1, 1, 1, 0, 0, 0), + "endedAt": datetime.datetime(2021, 1, 1, 1, 0, 0, 0), + }, + ], +) +@patch( + "braket.jobs.metrics_data.cwl_insights_metrics_fetcher." 
+ "CwlInsightsMetricsFetcher.get_metrics_for_job" +) +def test_metrics( + metrics_fetcher_mock, + quantum_job, + aws_session, + generate_get_job_response, + generate_get_job_response_kwargs, +): + get_job_response_running = generate_get_job_response(**generate_get_job_response_kwargs) + aws_session.get_job.return_value = get_job_response_running + + expected_metrics = {"Test": [1]} + metrics_fetcher_mock.return_value = expected_metrics + metrics = quantum_job.metrics() + assert metrics == expected_metrics + + +@pytest.fixture +def log_stream_responses(): + return ( + ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + "Message": "This shouldn't get raised...", + } + }, + "DescribeLogStreams", + ), + {"logStreams": []}, + {"logStreams": [{"logStreamName": "stream-1"}]}, + ) + + +@pytest.fixture +def log_events_responses(): + return ( + {"nextForwardToken": None, "events": [{"timestamp": 1, "message": "hi there #1"}]}, + {"nextForwardToken": None, "events": []}, + { + "nextForwardToken": None, + "events": [ + {"timestamp": 1, "message": "hi there #1"}, + {"timestamp": 2, "message": "hi there #2"}, + ], + }, + {"nextForwardToken": None, "events": []}, + { + "nextForwardToken": None, + "events": [ + {"timestamp": 2, "message": "hi there #2"}, + {"timestamp": 2, "message": "hi there #2a"}, + {"timestamp": 3, "message": "hi there #3"}, + ], + }, + {"nextForwardToken": None, "events": []}, + ) + + +def test_logs( + quantum_job, + generate_get_job_response, + log_events_responses, + log_stream_responses, + capsys, +): + quantum_job._aws_session.get_job.side_effect = ( + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="COMPLETED"), + ) + quantum_job._aws_session.describe_log_streams.side_effect = log_stream_responses + quantum_job._aws_session.get_log_events.side_effect = log_events_responses + + quantum_job.logs(wait=True, poll_interval_seconds=0) + + captured = capsys.readouterr() + assert captured.out == "\n".join( + ( + "..", + "hi there #1", + "hi there #2", + "hi there #2a", + "hi there #3", + "", + ) + ) + + +@patch.dict("os.environ", {"JPY_PARENT_PID": "True"}) +def test_logs_multiple_instances( + quantum_job, + generate_get_job_response, + log_events_responses, + log_stream_responses, + capsys, +): + quantum_job._aws_session.get_job.side_effect = ( + generate_get_job_response(status="RUNNING", instanceConfig={"instanceCount": 2}), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="COMPLETED"), + ) + log_stream_responses[-1]["logStreams"].append({"logStreamName": "stream-2"}) + quantum_job._aws_session.describe_log_streams.side_effect = log_stream_responses + + event_counts = { + "stream-1": 0, + "stream-2": 0, + } + + def get_log_events(log_group, log_stream, start_time, start_from_head, next_token): + log_events_dict = { + "stream-1": log_events_responses, + "stream-2": log_events_responses, + } + log_events_dict["stream-1"] += ( + { + "nextForwardToken": None, + "events": [], + }, + { + "nextForwardToken": None, + "events": [], + }, + ) + log_events_dict["stream-2"] += ( + { + "nextForwardToken": None, + "events": [ + {"timestamp": 3, "message": "hi there #3"}, + {"timestamp": 4, "message": "hi there #4"}, + ], + }, + { + "nextForwardToken": None, + "events": [], + }, + ) + event_counts[log_stream] += 1 + return 
log_events_dict[log_stream][event_counts[log_stream]] + + quantum_job._aws_session.get_log_events.side_effect = get_log_events + + quantum_job.logs(wait=True, poll_interval_seconds=0) + + captured = capsys.readouterr() + assert captured.out == "\n".join( + ( + "..", + "\x1b[34mhi there #1\x1b[0m", + "\x1b[35mhi there #1\x1b[0m", + "\x1b[34mhi there #2\x1b[0m", + "\x1b[35mhi there #2\x1b[0m", + "\x1b[34mhi there #2a\x1b[0m", + "\x1b[35mhi there #2a\x1b[0m", + "\x1b[34mhi there #3\x1b[0m", + "\x1b[35mhi there #3\x1b[0m", + "\x1b[35mhi there #4\x1b[0m", + "", + ) + ) + + +def test_logs_error(quantum_job, generate_get_job_response, capsys): + quantum_job._aws_session.get_job.side_effect = ( + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="COMPLETED"), + ) + quantum_job._aws_session.describe_log_streams.side_effect = ( + ClientError( + { + "Error": { + "Code": "UnknownCode", + "Message": "Some error message", + } + }, + "DescribeLogStreams", + ), + ) + + with pytest.raises(ClientError, match="Some error message"): + quantum_job.logs(wait=True, poll_interval_seconds=0) + + +def test_initialize_session_for_valid_device(device_arn, aws_session, caplog): + first_region = aws_session.region + logger = logging.getLogger(__name__) + + aws_session.get_device.side_effect = [ + ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + } + }, + "getDevice", + ), + ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + } + }, + "getDevice", + ), + device_arn, + ] + + caplog.set_level(logging.INFO) + AwsQuantumJob._initialize_session(aws_session, device_arn, logger) + + assert f"Changed session region from '{first_region}' to '{aws_session.region}'" in caplog.text + + +def test_initialize_session_for_invalid_device(aws_session, device_arn): + logger = logging.getLogger(__name__) + aws_session.get_device.side_effect = ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + } + }, + "getDevice", + ) + + device_not_found = "QPU 'arn:aws:braket:::device/qpu/test/device-name' not found." 
+ with pytest.raises(ValueError, match=device_not_found): + AwsQuantumJob._initialize_session(aws_session, device_arn, logger) + + +def test_no_region_routing_simulator(aws_session): + logger = logging.getLogger(__name__) + + aws_session.get_device.side_effect = ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + } + }, + "getDevice", + ) + + device_arn = "arn:aws:braket:::device/simulator/test/device-name" + device_not_found = f"Simulator '{device_arn}' not found in 'us-test-1'" + with pytest.raises(ValueError, match=device_not_found): + AwsQuantumJob._initialize_session(aws_session, device_arn, logger) + + +def test_exception_in_credentials_session_region(device_arn, aws_session): + logger = logging.getLogger(__name__) + + aws_session.get_device.side_effect = ClientError( + { + "Error": { + "Code": "SomeOtherErrorMessage", + } + }, + "getDevice", + ) + + error_message = ( + "An error occurred \\(SomeOtherErrorMessage\\) " + "when calling the getDevice operation: Unknown" + ) + with pytest.raises(ClientError, match=error_message): + AwsQuantumJob._initialize_session(aws_session, device_arn, logger) + + +def test_exceptions_in_all_device_regions(device_arn, aws_session): + logger = logging.getLogger(__name__) + + aws_session.get_device.side_effect = [ + ClientError( + { + "Error": { + "Code": "ResourceNotFoundException", + } + }, + "getDevice", + ), + ClientError( + { + "Error": { + "Code": "SomeOtherErrorMessage", + } + }, + "getDevice", + ), + ] + + error_message = ( + "An error occurred \\(SomeOtherErrorMessage\\) " + "when calling the getDevice operation: Unknown" + ) + with pytest.raises(ClientError, match=error_message): + AwsQuantumJob._initialize_session(aws_session, device_arn, logger) diff --git a/test/unit_tests/braket/aws/test_aws_session.py b/test/unit_tests/braket/aws/test_aws_session.py index 8735690c8..5032eed2b 100644 --- a/test/unit_tests/braket/aws/test_aws_session.py +++ b/test/unit_tests/braket/aws/test_aws_session.py @@ -12,10 +12,16 @@ # language governing permissions and limitations under the License. 
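Taken together, the _initialize_session tests above pin down a routing rule: a ResourceNotFoundException from getDevice means "try the next region" for QPUs, any other ClientError surfaces immediately, simulators are never re-routed, and exhausting every region raises a ValueError. A simplified sketch of that flow (find_device_region and its arguments are illustrative, not the SDK's API):

from botocore.exceptions import ClientError


def find_device_region(sessions, device_arn):
    """Return the first session whose region knows the device; illustrative only."""
    last_error = None
    for session in sessions:
        try:
            session.get_device(device_arn)
            return session
        except ClientError as e:
            if e.response["Error"]["Code"] != "ResourceNotFoundException":
                raise  # anything other than "not found in this region" surfaces as-is
            last_error = e
    raise ValueError(f"QPU '{device_arn}' not found.") from last_error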
import json +import os +import tempfile +import time +from pathlib import Path from unittest.mock import MagicMock, Mock, patch +import boto3 import pytest from botocore.exceptions import ClientError +from botocore.stub import Stubber import braket._schemas as braket_schemas import braket._sdk as braket_sdk @@ -36,19 +42,118 @@ def boto_session(): @pytest.fixture -def aws_session(boto_session): - return AwsSession(boto_session=boto_session, braket_client=Mock()) +def braket_client(): + _braket_client = Mock() + _braket_client.meta.region_name = "us-west-2" + return _braket_client + + +@pytest.fixture +def aws_session(boto_session, braket_client, account_id): + _aws_session = AwsSession(boto_session=boto_session, braket_client=braket_client) + + _aws_session._sts = Mock() + _aws_session._sts.get_caller_identity.return_value = { + "Account": account_id, + } + + _aws_session._s3 = Mock() + return _aws_session + + +@pytest.fixture +def aws_explicit_session(): + _boto_session = Mock() + _boto_session.region_name = "us-test-1" + + creds = Mock() + creds.access_key = "access key" + creds.secret_key = "secret key" + creds.token = "token" + creds.method = "explicit" + _boto_session.get_credentials.return_value = creds + + _aws_session = Mock() + _aws_session.boto_session = _boto_session + _aws_session._default_bucket = "amazon-braket-us-test-1-00000000" + _aws_session.default_bucket.return_value = _aws_session._default_bucket + _aws_session._custom_default_bucket = False + _aws_session.account_id = "00000000" + _aws_session.region = "us-test-1" + return _aws_session + + +@pytest.fixture +def account_id(): + return "000000000" + + +@pytest.fixture +def job_role_name(): + return "AmazonBraketJobsExecutionRole-134534514345" + + +@pytest.fixture +def job_role_arn(job_role_name): + return f"arn:aws:iam::0000000000:role/{job_role_name}" + + +@pytest.fixture +def get_job_response(): + return { + "algorithmSpecification": { + "scriptModeConfig": { + "entryPoint": "my_file:start_here", + "s3Uri": "s3://amazon-braket-jobs/job-path/my_file.py", + } + }, + "checkpointConfig": { + "localPath": "/opt/omega/checkpoints", + "s3Uri": "s3://amazon-braket-jobs/job-path/checkpoints", + }, + "instanceConfig": { + "instanceCount": 1, + "instanceType": "ml.m5.large", + "volumeSizeInGb": 1, + }, + "jobArn": "arn:aws:braket:us-west-2:875981177017:job/job-test-20210628140446", + "jobName": "job-test-20210628140446", + "outputDataConfig": {"s3Path": "s3://amazon-braket-jobs/job-path/output"}, + "roleArn": "arn:aws:iam::875981177017:role/AmazonBraketJobRole", + "status": "RUNNING", + } + + +@pytest.fixture +def resource_not_found_response(): + return { + "Error": { + "Code": "ResourceNotFoundException", + "Message": "unit-test-error", + } + } + + +@pytest.fixture +def throttling_response(): + return { + "Error": { + "Code": "ThrottlingException", + "Message": "unit-test-error", + } + } def test_initializes_boto_client_if_required(boto_session): AwsSession(boto_session=boto_session) - boto_session.client.assert_called_with("braket", config=None) + boto_session.client.assert_any_call("braket", config=None) -def test_uses_supplied_braket_client(): +def test_user_supplied_braket_client(): boto_session = Mock() boto_session.region_name = "foobar" braket_client = Mock() + braket_client.meta.region_name = "foobar" aws_session = AwsSession(boto_session=boto_session, braket_client=braket_client) assert aws_session.braket_client == braket_client @@ -56,7 +161,85 @@ def test_uses_supplied_braket_client(): def test_config(boto_session): 
config = Mock() AwsSession(boto_session=boto_session, config=config) - boto_session.client.assert_called_with("braket", config=config) + boto_session.client.assert_any_call("braket", config=config) + + +def test_region(): + boto_region = "boto-region" + braket_region = "braket-region" + + boto_session = Mock() + boto_session.region_name = boto_region + braket_client = Mock() + braket_client.meta.region_name = braket_region + + assert ( + AwsSession( + boto_session=boto_session, + ).region + == boto_region + ) + + assert ( + AwsSession( + braket_client=braket_client, + ).region + == braket_region + ) + + regions_must_match = ( + "Boto Session region and Braket Client region must match and currently " + "they do not: Boto Session region is 'boto-region', but " + "Braket Client region is 'braket-region'." + ) + with pytest.raises(ValueError, match=regions_must_match): + AwsSession( + boto_session=boto_session, + braket_client=braket_client, + ) + + +def test_iam(aws_session): + aws_session._iam = Mock() + assert aws_session.iam_client + aws_session.boto_session.client.assert_not_called() + aws_session._iam = None + assert aws_session.iam_client + aws_session.boto_session.client.assert_called_with("iam", region_name="us-west-2") + + +def test_s3(aws_session): + assert aws_session.s3_client + aws_session.boto_session.client.assert_not_called() + aws_session._s3 = None + assert aws_session.s3_client + aws_session.boto_session.client.assert_called_with("s3", region_name="us-west-2") + + +def test_sts(aws_session): + assert aws_session.sts_client + aws_session.boto_session.client.assert_not_called() + aws_session._sts = None + assert aws_session.sts_client + aws_session.boto_session.client.assert_called_with("sts", region_name="us-west-2") + + +def test_logs(aws_session): + aws_session._logs = Mock() + assert aws_session.logs_client + aws_session.boto_session.client.assert_not_called() + aws_session._logs = None + assert aws_session.logs_client + aws_session.boto_session.client.assert_called_with("logs", region_name="us-west-2") + + +def test_ecr(aws_session): + aws_session._ecr = Mock() + assert aws_session.ecr_client + aws_session.boto_session.client.assert_not_called() + aws_session._ecr = None + assert aws_session.ecr_client + aws_session.boto_session.client.assert_called_with("ecr", region_name="us-west-2") @patch("os.path.exists") @@ -75,6 +258,7 @@ def test_populates_user_agent(os_path_exists_mock, metadata_file_exists, initial boto_session = Mock() boto_session.region_name = "foobar" braket_client = Mock() + braket_client.meta.region_name = "foobar" braket_client._client_config.user_agent = initial_user_agent nbi_metadata_path = "/opt/ml/metadata/resource-metadata.json" os_path_exists_mock.return_value = metadata_file_exists @@ -131,8 +315,7 @@ def test_retrieve_s3_object_body_client_error(boto_session): aws_session.retrieve_s3_object_body(bucket_name, filename) -def test_get_device(boto_session): - braket_client = Mock() +def test_get_device(boto_session, braket_client): return_val = {"deviceArn": "arn1", "deviceName": "name1"} braket_client.get_device.return_value = return_val aws_session = AwsSession(boto_session=boto_session, braket_client=braket_client) @@ -155,13 +338,30 @@ def test_create_quantum_task(aws_session): kwargs = { "backendArn": "arn:aws:us-west-2:abc:xyz:abc", "cwLogGroupArn": "arn:aws:us-west-2:abc:xyz:abc", - "destinationUrl": "http://s3-us-west-2.amazonaws.com/task-output-derebolt-1/output.json", + "destinationUrl": 
"http://s3-us-west-2.amazonaws.com/task-output-bar-1/output.json", "program": {"ir": '{"instructions":[]}', "qubitCount": 4}, } assert aws_session.create_quantum_task(**kwargs) == arn aws_session.braket_client.create_quantum_task.assert_called_with(**kwargs) +def test_create_quantum_task_with_job_token(aws_session): + arn = "arn:aws:braket:us-west-2:1234567890:task/task-name" + job_token = "arn:aws:braket:us-west-2:1234567890:job/job-name" + aws_session.braket_client.create_quantum_task.return_value = {"quantumTaskArn": arn} + + kwargs = { + "backendArn": "arn:aws:us-west-2:abc:xyz:abc", + "cwLogGroupArn": "arn:aws:us-west-2:abc:xyz:abc", + "destinationUrl": "http://s3-us-west-2.amazonaws.com/task-output-foo-1/output.json", + "program": {"ir": '{"instructions":[]}', "qubitCount": 4}, + } + with patch.dict(os.environ, {"AMZN_BRAKET_JOB_TOKEN": job_token}): + assert aws_session.create_quantum_task(**kwargs) == arn + kwargs.update({"jobToken": job_token}) + aws_session.braket_client.create_quantum_task.assert_called_with(**kwargs) + + def test_get_quantum_task(aws_session): arn = "foo:bar:arn" return_value = {"quantumTaskArn": arn} @@ -171,23 +371,10 @@ def test_get_quantum_task(aws_session): aws_session.braket_client.get_quantum_task.assert_called_with(quantumTaskArn=arn) -def test_get_quantum_task_retry(aws_session): +def test_get_quantum_task_retry(aws_session, throttling_response, resource_not_found_response): arn = "foo:bar:arn" return_value = {"quantumTaskArn": arn} - resource_not_found_response = { - "Error": { - "Code": "ResourceNotFoundException", - "Message": "unit-test-error", - } - } - throttling_response = { - "Error": { - "Code": "ThrottlingException", - "Message": "unit-test-error", - } - } - aws_session.braket_client.get_quantum_task.side_effect = [ ClientError(resource_not_found_response, "unit-test"), ClientError(throttling_response, "unit-test"), @@ -196,35 +383,81 @@ def test_get_quantum_task_retry(aws_session): assert aws_session.get_quantum_task(arn) == return_value aws_session.braket_client.get_quantum_task.assert_called_with(quantumTaskArn=arn) - aws_session.braket_client.get_quantum_task.call_count == 3 + assert aws_session.braket_client.get_quantum_task.call_count == 3 -def test_get_quantum_task_fail_after_retries(aws_session): - resource_not_found_response = { - "Error": { - "Code": "ResourceNotFoundException", - "Message": "unit-test-error", - } - } - throttling_response = { +def test_get_quantum_task_fail_after_retries( + aws_session, throttling_response, resource_not_found_response +): + aws_session.braket_client.get_quantum_task.side_effect = [ + ClientError(resource_not_found_response, "unit-test"), + ClientError(throttling_response, "unit-test"), + ClientError(throttling_response, "unit-test"), + ] + + with pytest.raises(ClientError): + aws_session.get_quantum_task("some-arn") + assert aws_session.braket_client.get_quantum_task.call_count == 3 + + +def test_get_quantum_task_does_not_retry_other_exceptions(aws_session): + exception_response = { "Error": { - "Code": "ThrottlingException", + "Code": "SomeOtherException", "Message": "unit-test-error", } } aws_session.braket_client.get_quantum_task.side_effect = [ + ClientError(exception_response, "unit-test"), + ] + + with pytest.raises(ClientError): + aws_session.get_quantum_task("some-arn") + assert aws_session.braket_client.get_quantum_task.call_count == 1 + + +def test_get_job(aws_session, get_job_response): + arn = "arn:aws:braket:us-west-2:1234567890:job/job-name" + 
aws_session.braket_client.get_job.return_value = get_job_response + + assert aws_session.get_job(arn) == get_job_response + aws_session.braket_client.get_job.assert_called_with(jobArn=arn) + + +def test_get_job_retry( + aws_session, get_job_response, throttling_response, resource_not_found_response +): + arn = "arn:aws:braket:us-west-2:1234567890:job/job-name" + + aws_session.braket_client.get_job.side_effect = [ + ClientError(resource_not_found_response, "unit-test"), + ClientError(throttling_response, "unit-test"), + get_job_response, + ] + + assert aws_session.get_job(arn) == get_job_response + aws_session.braket_client.get_job.assert_called_with(jobArn=arn) + assert aws_session.braket_client.get_job.call_count == 3 + + +def test_get_job_fail_after_retries(aws_session, throttling_response, resource_not_found_response): + arn = "arn:aws:braket:us-west-2:1234567890:job/job-name" + + aws_session.braket_client.get_job.side_effect = [ ClientError(resource_not_found_response, "unit-test"), ClientError(throttling_response, "unit-test"), ClientError(throttling_response, "unit-test"), ] with pytest.raises(ClientError): - aws_session.get_quantum_task("some-arn") - aws_session.braket_client.get_quantum_task.call_count == 3 + aws_session.get_job(arn) + aws_session.braket_client.get_job.assert_called_with(jobArn=arn) + assert aws_session.braket_client.get_job.call_count == 3 -def test_get_quantum_task_does_not_retry_other_exceptions(aws_session): +def test_get_job_does_not_retry_other_exceptions(aws_session): + arn = "arn:aws:braket:us-west-2:1234567890:job/job-name" exception_response = { "Error": { "Code": "SomeOtherException", @@ -232,13 +465,60 @@ def test_get_quantum_task_does_not_retry_other_exceptions(aws_session): } } - aws_session.braket_client.get_quantum_task.side_effect = [ + aws_session.braket_client.get_job.side_effect = [ ClientError(exception_response, "unit-test"), ] with pytest.raises(ClientError): - aws_session.get_quantum_task("some-arn") - aws_session.braket_client.get_quantum_task.call_count == 1 + aws_session.get_job(arn) + aws_session.braket_client.get_job.assert_called_with(jobArn=arn) + assert aws_session.braket_client.get_job.call_count == 1 + + +def test_cancel_job(aws_session): + arn = "arn:aws:braket:us-west-2:1234567890:job/job-name" + cancel_job_response = { + "ResponseMetadata": { + "RequestId": "857b0893-2073-4ad6-b828-744af8400dfe", + "HTTPStatusCode": 200, + }, + "cancellationStatus": "CANCELLING", + "jobArn": "arn:aws:braket:us-west-2:1234567890:job/job-name", + } + aws_session.braket_client.cancel_job.return_value = cancel_job_response + + assert aws_session.cancel_job(arn) == cancel_job_response + aws_session.braket_client.cancel_job.assert_called_with(jobArn=arn) + + +@pytest.mark.parametrize( + "exception_type", + [ + "ResourceNotFoundException", + "ValidationException", + "AccessDeniedException", + "ThrottlingException", + "InternalServiceException", + "ConflictException", + ], +) +def test_cancel_job_surfaces_errors(exception_type, aws_session): + arn = "arn:aws:braket:us-west-2:1234567890:job/job-name" + exception_response = { + "Error": { + "Code": "SomeOtherException", + "Message": "unit-test-error", + } + } + + aws_session.braket_client.cancel_job.side_effect = [ + ClientError(exception_response, "unit-test"), + ] + + with pytest.raises(ClientError): + aws_session.cancel_job(arn) + aws_session.braket_client.cancel_job.assert_called_with(jobArn=arn) + assert aws_session.braket_client.cancel_job.call_count == 1 @pytest.mark.parametrize( @@ -431,3 +711,541 
@@ def test_search_devices_arns(aws_session):
         ],
         PaginationConfig={"MaxItems": 100},
     )
+
+
+def test_create_job(aws_session):
+    arn = "foo:bar:arn"
+    aws_session.braket_client.create_job.return_value = {"jobArn": arn}
+
+    kwargs = {
+        "jobName": "job-name",
+        "roleArn": "role-arn",
+        "algorithmSpecification": {
+            "scriptModeConfig": {
+                "entryPoint": "entry-point",
+                "s3Uri": "s3-uri",
+                "compressionType": "GZIP",
+            }
+        },
+    }
+    assert aws_session.create_job(**kwargs) == arn
+    aws_session.braket_client.create_job.assert_called_with(**kwargs)
+
+
+@pytest.mark.parametrize(
+    "string, valid",
+    (
+        ("s3://bucket/key", True),
+        ("S3://bucket/key", True),
+        ("https://bucket-name-123.s3.us-west-2.amazonaws.com/key/with/dirs", True),
+        ("https://bucket-name-123.S3.us-west-2.amazonaws.com/key/with/dirs", True),
+        ("https://bucket-name-123.S3.us-west-2.amazonaws.com/", False),
+        ("https://bucket-name-123.S3.us-west-2.amazonaws.com", False),
+        ("https://S3.us-west-2.amazonaws.com", False),
+        ("s3://bucket/", False),
+        ("s3://bucket", False),
+        ("s3://////", False),
+        ("http://bucket/key", False),
+        ("bucket/key", False),
+    ),
+)
+def test_is_s3_uri(string, valid):
+    assert AwsSession.is_s3_uri(string) == valid
+
+
+@pytest.mark.parametrize(
+    "uri, bucket, key",
+    (
+        (
+            "s3://bucket-name-123/key/with/multiple/dirs",
+            "bucket-name-123",
+            "key/with/multiple/dirs",
+        ),
+        (
+            "s3://bucket-name-123/key-with_one.dirs",
+            "bucket-name-123",
+            "key-with_one.dirs",
+        ),
+        (
+            "https://bucket-name-123.s3.us-west-2.amazonaws.com/key/with/dirs",
+            "bucket-name-123",
+            "key/with/dirs",
+        ),
+        (
+            "https://bucket-name-123.S3.us-west-2.amazonaws.com/key/with/dirs",
+            "bucket-name-123",
+            "key/with/dirs",
+        ),
+    ),
+)
+def test_parse_s3_uri(uri, bucket, key):
+    assert (bucket, key) == AwsSession.parse_s3_uri(uri)
+
+
+@pytest.mark.parametrize(
+    "uri",
+    (
+        "s3://bucket.name-123/key-with_one.dirs",
+        "http://bucket-name-123/key/with/multiple/dirs",
+        "bucket-name-123/key/with/multiple/dirs",
+        "s3://bucket-name-123/",
+        "s3://bucket-name-123",
+    ),
+)
+def test_parse_s3_uri_invalid(uri):
+    with pytest.raises(ValueError, match=f"Not a valid S3 uri: {uri}"):
+        AwsSession.parse_s3_uri(uri)
+
+
+@pytest.mark.parametrize(
+    "bucket, dirs",
+    [
+        ("bucket", ("d1", "d2", "d3")),
+        ("bucket-123-braket", ("dir",)),
+        pytest.param(
+            "braket",
+            (),
+            marks=pytest.mark.xfail(raises=ValueError, strict=True),
+        ),
+    ],
+)
+def test_construct_s3_uri(bucket, dirs):
+    parsed_bucket, parsed_key = AwsSession.parse_s3_uri(AwsSession.construct_s3_uri(bucket, *dirs))
+    assert parsed_bucket == bucket
+    assert parsed_key == "/".join(dirs)
+
+
+def test_get_default_jobs_role(aws_session, job_role_arn, job_role_name):
+    iam_client = boto3.client("iam")
+    with Stubber(iam_client) as stub:
+        stub.add_response(
+            "list_roles",
+            {
+                "Roles": [
+                    {
+                        "Arn": "arn:aws:iam::0000000000:role/nonJobsRole",
+                        "RoleName": "nonJobsRole",
+                        "Path": "/",
+                        "RoleId": "nonJobsRole-213453451345-431513",
+                        "CreateDate": time.time(),
+                    }
+                ]
+                * 100,
+                "IsTruncated": True,
+                "Marker": "resp-marker",
+            },
+        )
+        stub.add_response(
+            "list_roles",
+            {
+                "Roles": [
+                    {
+                        "Arn": job_role_arn,
+                        "RoleName": job_role_name,
+                        "Path": "/",
+                        "RoleId": f"{job_role_name}-213453451345-431513",
+                        "CreateDate": time.time(),
+                    }
+                ],
+                "IsTruncated": False,
+            },
+            {"Marker": "resp-marker"},
+        )
+        aws_session._iam = iam_client
+        assert aws_session.get_default_jobs_role() == job_role_arn
+
+
+def test_get_default_jobs_role_not_found(aws_session, job_role_arn, 
job_role_name): + iam_client = boto3.client("iam") + with Stubber(iam_client) as stub: + stub.add_response( + "list_roles", + { + "Roles": [ + { + "Arn": "arn:aws:iam::0000000000:role/nonJobsRole", + "RoleName": "nonJobsRole", + "Path": "/", + "RoleId": "nonJobsRole-213453451345-431513", + "CreateDate": time.time(), + } + ] + * 100, + "IsTruncated": True, + "Marker": "resp-marker", + }, + ) + stub.add_response( + "list_roles", + { + "Roles": [ + { + "Arn": "arn:aws:iam::0000000000:role/nonJobsRole2", + "RoleName": "nonJobsRole2", + "Path": "/", + "RoleId": "nonJobsRole2-213453451345-431513", + "CreateDate": time.time(), + } + ], + "IsTruncated": False, + }, + {"Marker": "resp-marker"}, + ) + aws_session._iam = iam_client + with pytest.raises(RuntimeError): + aws_session.get_default_jobs_role() + + +def test_upload_to_s3(aws_session): + filename = "file.txt" + s3_uri = "s3://bucket-123/key" + bucket, key = "bucket-123", "key" + aws_session.upload_to_s3(filename, s3_uri) + aws_session._s3.upload_file.assert_called_with(filename, bucket, key) + + +def test_upload_local_data(aws_session): + with tempfile.TemporaryDirectory() as temp_dir: + os.chdir(temp_dir) + + Path("input-dir", "pref-dir", "sub-pref-dir").mkdir(parents=True) + Path("input-dir", "not-pref-dir").mkdir() + + # these should all get uploaded + Path("input-dir", "pref-dir", "sub-pref-dir", "very-nested.txt").touch() + Path("input-dir", "pref-dir", "nested.txt").touch() + Path("input-dir", "pref.txt").touch() + Path("input-dir", "pref-and-more.txt").touch() + + # these should not + Path("input-dir", "false-pref.txt").touch() + Path("input-dir", "not-pref-dir", "pref-fake.txt").touch() + + aws_session.upload_to_s3 = Mock() + aws_session.upload_local_data("input-dir/pref", "s3://bucket/pref") + call_args = {args for args, kwargs in aws_session.upload_to_s3.call_args_list} + assert call_args == { + ( + str(Path("input-dir", "pref-dir", "sub-pref-dir", "very-nested.txt")), + "s3://bucket/pref-dir/sub-pref-dir/very-nested.txt", + ), + (str(Path("input-dir", "pref-dir", "nested.txt")), "s3://bucket/pref-dir/nested.txt"), + (str(Path("input-dir", "pref.txt")), "s3://bucket/pref.txt"), + (str(Path("input-dir", "pref-and-more.txt")), "s3://bucket/pref-and-more.txt"), + } + os.chdir("..") + + +def test_upload_local_data_absolute(aws_session): + with tempfile.TemporaryDirectory() as temp_dir: + Path(temp_dir, "input-dir", "pref-dir", "sub-pref-dir").mkdir(parents=True) + Path(temp_dir, "input-dir", "not-pref-dir").mkdir() + + # these should all get uploaded + Path(temp_dir, "input-dir", "pref-dir", "sub-pref-dir", "very-nested.txt").touch() + Path(temp_dir, "input-dir", "pref-dir", "nested.txt").touch() + Path(temp_dir, "input-dir", "pref.txt").touch() + Path(temp_dir, "input-dir", "pref-and-more.txt").touch() + + # these should not + Path(temp_dir, "input-dir", "false-pref.txt").touch() + Path(temp_dir, "input-dir", "not-pref-dir", "pref-fake.txt").touch() + + aws_session.upload_to_s3 = Mock() + aws_session.upload_local_data(str(Path(temp_dir, "input-dir", "pref")), "s3://bucket/pref") + call_args = {args for args, kwargs in aws_session.upload_to_s3.call_args_list} + assert call_args == { + ( + str(Path(temp_dir, "input-dir", "pref-dir", "sub-pref-dir", "very-nested.txt")), + "s3://bucket/pref-dir/sub-pref-dir/very-nested.txt", + ), + ( + str(Path(temp_dir, "input-dir", "pref-dir", "nested.txt")), + "s3://bucket/pref-dir/nested.txt", + ), + (str(Path(temp_dir, "input-dir", "pref.txt")), "s3://bucket/pref.txt"), + ( + str(Path(temp_dir, 
"input-dir", "pref-and-more.txt")), + "s3://bucket/pref-and-more.txt", + ), + } + + +def test_download_from_s3(aws_session): + filename = "model.tar.gz" + s3_uri = ( + "s3://amazon-braket-jobs/job-path/output/" + "BraketJob-875981177017-job-test-20210628140446/output/model.tar.gz" + ) + bucket, key = ( + "amazon-braket-jobs", + "job-path/output/BraketJob-875981177017-job-test-20210628140446/output/model.tar.gz", + ) + aws_session.download_from_s3(s3_uri, filename) + aws_session._s3.download_file.assert_called_with(bucket, key, filename) + + +def test_copy_identical_s3(aws_session): + s3_uri = "s3://bucket/key" + aws_session.copy_s3_object(s3_uri, s3_uri) + aws_session.boto_session.client.return_value.copy.assert_not_called() + + +def test_copy_s3(aws_session): + source_s3_uri = "s3://here/now" + dest_s3_uri = "s3://there/then" + source_bucket, source_key = AwsSession.parse_s3_uri(source_s3_uri) + dest_bucket, dest_key = AwsSession.parse_s3_uri(dest_s3_uri) + aws_session.copy_s3_object(source_s3_uri, dest_s3_uri) + aws_session._s3.copy.assert_called_with( + { + "Bucket": source_bucket, + "Key": source_key, + }, + dest_bucket, + dest_key, + ) + + +def test_copy_identical_s3_directory(aws_session): + s3_uri = "s3://bucket/prefix/" + aws_session.copy_s3_directory(s3_uri, s3_uri) + aws_session.boto_session.client.return_value.copy.assert_not_called() + + +def test_copy_s3_directory(aws_session): + aws_session.list_keys = Mock(return_value=[f"now/key-{i}" for i in range(5)]) + source_s3_uri = "s3://here/now" + dest_s3_uri = "s3://there/then" + aws_session.copy_s3_directory(source_s3_uri, dest_s3_uri) + for i in range(5): + aws_session.s3_client.copy.assert_any_call( + { + "Bucket": "here", + "Key": f"now/key-{i}", + }, + "there", + f"then/key-{i}", + ) + + +def test_list_keys(aws_session): + bucket, prefix = "bucket", "prefix" + aws_session.s3_client.list_objects_v2.side_effect = [ + { + "IsTruncated": True, + "Contents": [ + {"Key": "copy-test/copy.txt"}, + {"Key": "copy-test/copy2.txt"}, + ], + "NextContinuationToken": "next-continuation-token", + }, + { + "IsTruncated": False, + "Contents": [ + {"Key": "copy-test/nested/double-nested/double-nested.txt"}, + {"Key": "copy-test/nested/nested.txt"}, + ], + }, + ] + keys = aws_session.list_keys(bucket, prefix) + assert keys == [ + "copy-test/copy.txt", + "copy-test/copy2.txt", + "copy-test/nested/double-nested/double-nested.txt", + "copy-test/nested/nested.txt", + ] + + +def test_default_bucket(aws_session, account_id): + region = "test-region-0" + aws_session.boto_session.region_name = region + assert aws_session.default_bucket() == f"amazon-braket-{region}-{account_id}" + + +def test_default_bucket_given(aws_session): + default_bucket = "default_bucket" + aws_session._default_bucket = default_bucket + assert aws_session.default_bucket() == default_bucket + aws_session._s3.create_bucket.assert_not_called() + + +@patch.dict("os.environ", {"AMZN_BRAKET_OUT_S3_BUCKET": "default_bucket_env"}) +def test_default_bucket_env_variable(boto_session, braket_client): + aws_session = AwsSession(boto_session=boto_session, braket_client=braket_client) + assert aws_session.default_bucket() == "default_bucket_env" + + +@pytest.mark.parametrize( + "region", + ( + "test-region-0", + "us-east-1", + ), +) +def test_create_s3_bucket_if_it_does_not_exist(aws_session, region, account_id): + bucket = f"amazon-braket-{region}-{account_id}" + aws_session._create_s3_bucket_if_it_does_not_exist(bucket, region) + kwargs = { + "Bucket": bucket, + "CreateBucketConfiguration": { 
+ "LocationConstraint": region, + }, + } + if region == "us-east-1": + del kwargs["CreateBucketConfiguration"] + aws_session._s3.create_bucket.assert_called_with(**kwargs) + aws_session._s3.put_public_access_block.assert_called_with( + Bucket=bucket, + PublicAccessBlockConfiguration={ + "BlockPublicAcls": True, + "IgnorePublicAcls": True, + "BlockPublicPolicy": True, + "RestrictPublicBuckets": True, + }, + ) + + +@pytest.mark.parametrize( + "error", + ( + ClientError( + { + "Error": { + "Code": "BucketAlreadyOwnedByYou", + "Message": "Your previous request to create the named bucket succeeded " + "and you already own it.", + } + }, + "CreateBucket", + ), + ClientError( + { + "Error": { + "Code": "OperationAborted", + "Message": "A conflicting conditional operation is currently in progress " + "against this resource. Please try again.", + } + }, + "CreateBucket", + ), + pytest.param( + ClientError( + { + "Error": { + "Code": "OtherCode", + "Message": "This should fail properly.", + } + }, + "CreateBucket", + ), + marks=pytest.mark.xfail(raises=ClientError, strict=True), + ), + ), +) +def test_create_s3_bucket_if_it_does_not_exist_error(aws_session, error, account_id): + region = "test-region-0" + bucket = f"amazon-braket-{region}-{account_id}" + aws_session._s3.create_bucket.side_effect = error + aws_session._create_s3_bucket_if_it_does_not_exist(bucket, region) + + +@pytest.mark.xfail(raises=ValueError) +def test_bucket_already_exists_for_another_account(aws_session): + exception_response = { + "Error": { + "Code": "BucketAlreadyExists", + "Message": "This should fail properly.", + } + } + bucket_name, region = "some-bucket-123", "test-region" + aws_session._s3.create_bucket.side_effect = ClientError(exception_response, "CreateBucket") + aws_session._create_s3_bucket_if_it_does_not_exist(bucket_name, region) + + +@pytest.mark.parametrize( + "limit, next_token", + ( + (None, None), + (10, None), + (None, "next-token"), + (10, "next-token"), + ), +) +def test_describe_log_streams(aws_session, limit, next_token): + aws_session._logs = Mock() + + log_group = "log_group" + log_stream_prefix = "log_stream_prefix" + + describe_log_stream_args = { + "logGroupName": log_group, + "logStreamNamePrefix": log_stream_prefix, + "orderBy": "LogStreamName", + } + + if limit: + describe_log_stream_args.update({"limit": limit}) + + if next_token: + describe_log_stream_args.update({"nextToken": next_token}) + + aws_session.describe_log_streams(log_group, log_stream_prefix, limit, next_token) + + aws_session._logs.describe_log_streams.assert_called_with(**describe_log_stream_args) + + +@pytest.mark.parametrize( + "next_token", + (None, "next-token"), +) +def test_get_log_events(aws_session, next_token): + aws_session._logs = Mock() + + log_group = "log_group" + log_stream_name = "log_stream_name" + start_time = "timestamp" + start_from_head = True + + log_events_args = { + "logGroupName": log_group, + "logStreamName": log_stream_name, + "startTime": start_time, + "startFromHead": start_from_head, + } + + if next_token: + log_events_args.update({"nextToken": next_token}) + + aws_session.get_log_events(log_group, log_stream_name, start_time, start_from_head, next_token) + + aws_session._logs.get_log_events.assert_called_with(**log_events_args) + + +@patch("boto3.Session") +def test_copy_session(boto_session_init, aws_session): + boto_session_init.return_value = Mock() + copied_session = AwsSession.copy_session(aws_session, "us-west-2") + boto_session_init.assert_called_with(region_name="us-west-2") + assert 
copied_session._default_bucket is None + + +@patch("boto3.Session") +def test_copy_explicit_session(boto_session_init, aws_explicit_session): + boto_session_init.return_value = Mock() + AwsSession.copy_session(aws_explicit_session, "us-west-2") + boto_session_init.assert_called_with( + aws_access_key_id="access key", + aws_secret_access_key="secret key", + aws_session_token="token", + region_name="us-west-2", + ) + + +@patch("boto3.Session") +def test_copy_session_custom_default_bucket(mock_boto, aws_session): + mock_boto.return_value.region_name = "us-test-1" + aws_session._default_bucket = "my-own-default" + aws_session._custom_default_bucket = True + copied_session = AwsSession.copy_session(aws_session) + assert copied_session._default_bucket == "my-own-default" diff --git a/test/unit_tests/braket/jobs/local/test_local_job.py b/test/unit_tests/braket/jobs/local/test_local_job.py new file mode 100644 index 000000000..185b059e3 --- /dev/null +++ b/test/unit_tests/braket/jobs/local/test_local_job.py @@ -0,0 +1,215 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import json +from unittest.mock import Mock, mock_open, patch + +import pytest + +from braket.jobs.local.local_job import LocalQuantumJob + + +@pytest.fixture +def aws_session(): + _aws_session = Mock() + return _aws_session + + +@pytest.fixture +def job_results(): + return {"dataFormat": "plaintext", "dataDictionary": {"some_results": {"excellent": "here"}}} + + +@pytest.fixture +def run_log(): + test_log = ( + "This is a multi-line log.\n" + "This is the next line.\n" + "Metrics - timestamp=1633027264.5406773; Cost=-4.034; iteration_number=0;\n" + "Metrics - timestamp=1633027288.6284382; Cost=-3.957; iteration_number=1;\n" + ) + return test_log + + +@pytest.fixture +def test_envs(): + return {"Test": "Env"} + + +@pytest.mark.parametrize( + "creation_kwargs", + [ + ( + { + "jobName": "Test-Job-Name", + "algorithmSpecification": {"containerImage": {"uri": "file://test-URI"}}, + "checkpointConfig": {"localPath": "test/local/path/"}, + } + ), + ( + { + "jobName": "Test-Job-Name", + "algorithmSpecification": {"containerImage": {"uri": "file://test-URI"}}, + "checkpointConfig": {}, + } + ), + ( + { + "jobName": "Test-Job-Name", + "algorithmSpecification": {"containerImage": {"uri": "file://test-URI"}}, + } + ), + ( + { + "jobName": "Test-Job-Name", + "algorithmSpecification": {}, + } + ), + ], +) +@patch("braket.jobs.local.local_job.prepare_quantum_job") +@patch("braket.jobs.local.local_job.retrieve_image") +@patch("braket.jobs.local.local_job.setup_container") +@patch("braket.jobs.local.local_job._LocalJobContainer") +@patch("os.path.isdir") +def test_create( + mock_dir, + mock_container, + mock_setup, + mock_retrieve_image, + mock_prepare_job, + aws_session, + creation_kwargs, + job_results, + run_log, + test_envs, +): + with patch("builtins.open", mock_open()) as file_open: + mock_dir.return_value = False + mock_prepare_job.return_value = creation_kwargs + + mock_container_open = 
mock_container.return_value.__enter__.return_value + mock_container_open.run_log = run_log + file_read = file_open() + file_read.read.return_value = json.dumps(job_results) + mock_setup.return_value = test_envs + + job = LocalQuantumJob.create( + device=Mock(), + source_module=Mock(), + entry_point=Mock(), + image_uri=Mock(), + job_name=Mock(), + code_location=Mock(), + role_arn=Mock(), + hyperparameters=Mock(), + input_data=Mock(), + output_data_config=Mock(), + checkpoint_config=Mock(), + aws_session=aws_session, + ) + assert job.name == "Test-Job-Name" + assert job.arn == "local:job/Test-Job-Name" + assert job.state() == "COMPLETED" + assert job.run_log == run_log + assert job.metadata() is None + assert job.cancel() is None + assert job.download_result() is None + assert job.logs() is None + assert job.result() == job_results["dataDictionary"] + assert job.metrics() == { + "Cost": [-4.034, -3.957], + "iteration_number": [0.0, 1.0], + "timestamp": [1633027264.5406773, 1633027288.6284382], + } + mock_setup.assert_called_with(mock_container_open, aws_session, **creation_kwargs) + mock_container_open.run_local_job.assert_called_with(test_envs) + + +def test_create_invalid_arg(): + unexpected_kwarg = "create\\(\\) got an unexpected keyword argument 'wait_until_complete'" + with pytest.raises(TypeError, match=unexpected_kwarg): + LocalQuantumJob.create( + device="device", + source_module="source", + wait_until_complete=True, + ) + + +@patch("os.path.isdir") +def test_read_runlog_file(mock_dir): + mock_dir.return_value = True + with patch("builtins.open", mock_open()) as file_open: + file_read = file_open() + file_read.read.return_value = "Test Log" + job = LocalQuantumJob("local:job/Fake-Job") + assert job.run_log == "Test Log" + + +@patch("braket.jobs.local.local_job.prepare_quantum_job") +@patch("os.path.isdir") +def test_create_existing_job(mock_dir, mock_prepare_job, aws_session): + mock_dir.return_value = True + mock_prepare_job.return_value = { + "jobName": "Test-Job-Name", + "algorithmSpecification": {"containerImage": {"uri": "file://test-URI"}}, + "checkpointConfig": {"localPath": "test/local/path/"}, + } + dir_already_exists = ( + "A local directory called Test-Job-Name already exists. Please use a different job name." + ) + with pytest.raises(ValueError, match=dir_already_exists): + LocalQuantumJob.create( + device=Mock(), + source_module=Mock(), + entry_point=Mock(), + image_uri=Mock(), + job_name=Mock(), + code_location=Mock(), + role_arn=Mock(), + hyperparameters=Mock(), + input_data=Mock(), + output_data_config=Mock(), + checkpoint_config=Mock(), + aws_session=aws_session, + ) + + +def test_invalid_arn(): + invalid_arn = "Arn Invalid-Arn is not a valid local job arn" + with pytest.raises(ValueError, match=invalid_arn): + LocalQuantumJob("Invalid-Arn") + + +def test_missing_job_dir(): + missing_dir = "Unable to find local job results for Missing-Dir" + with pytest.raises(ValueError, match=missing_dir): + LocalQuantumJob("local:job/Missing-Dir") + + +@patch("os.path.isdir") +def test_missing_runlog_file(mock_dir): + mock_dir.return_value = True + job = LocalQuantumJob("local:job/Fake-Dir") + no_file = "Unable to find logs in the local job directory Fake-Dir." + with pytest.raises(ValueError, match=no_file): + job.run_log + + +@patch("os.path.isdir") +def test_missing_results_file(mock_dir): + mock_dir.return_value = True + job = LocalQuantumJob("local:job/Fake-Dir") + no_results = "Unable to find results in the local job directory Fake-Dir." 
+ with pytest.raises(ValueError, match=no_results): + job.result() diff --git a/test/unit_tests/braket/jobs/local/test_local_job_container.py b/test/unit_tests/braket/jobs/local/test_local_job_container.py new file mode 100644 index 000000000..34b7a3855 --- /dev/null +++ b/test/unit_tests/braket/jobs/local/test_local_job_container.py @@ -0,0 +1,364 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import base64 +import subprocess +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + +from braket.jobs.local.local_job_container import _LocalJobContainer + + +@pytest.fixture +def repo_uri(): + return "012345678901.dkr.ecr.us-west-2.amazonaws.com" + + +@pytest.fixture +def image_uri(repo_uri): + return f"{repo_uri}/my-repo:my-tag" + + +@pytest.fixture +def aws_session(): + _aws_session = Mock() + return _aws_session + + +@patch("subprocess.check_output") +@patch("subprocess.run") +def test_start_and_stop(mock_run, mock_check_output, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + mock_check_output.side_effect = [ + str.encode(local_image_name), + str.encode(running_container_name), + ] + with _LocalJobContainer(image_uri, aws_session): + pass + mock_check_output.assert_any_call(["docker", "images", "-q", image_uri]) + mock_check_output.assert_any_call( + ["docker", "run", "-d", "--rm", local_image_name, "tail", "-f", "/dev/null"] + ) + assert mock_check_output.call_count == 2 + mock_run.assert_any_call(["docker", "stop", running_container_name]) + assert mock_run.call_count == 1 + + +@patch("subprocess.check_output") +@patch("subprocess.run") +def test_pull_container(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + test_token = "Test Token" + mock_check_output.side_effect = [ + str.encode(""), + str.encode(local_image_name), + str.encode(running_container_name), + ] + aws_session.ecr_client.get_authorization_token.return_value = { + "authorizationData": [{"authorizationToken": base64.b64encode(str.encode(test_token))}] + } + with _LocalJobContainer(image_uri, aws_session): + pass + mock_check_output.assert_any_call(["docker", "images", "-q", image_uri]) + mock_check_output.assert_any_call( + ["docker", "run", "-d", "--rm", local_image_name, "tail", "-f", "/dev/null"] + ) + assert mock_check_output.call_count == 3 + mock_run.assert_any_call(["docker", "login", "-u", "AWS", "-p", test_token, repo_uri]) + mock_run.assert_any_call(["docker", "pull", image_uri]) + mock_run.assert_any_call(["docker", "stop", running_container_name]) + assert mock_run.call_count == 3 + + +@patch("subprocess.check_output") +@patch("subprocess.run") +def test_run_job_success(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + env_variables = { + "ENV0": "VALUE0", + "ENV1": "VALUE1", + } + run_program_name = "Run 
Program Name" + expected_run_output = "Expected Run Output" + mock_check_output.side_effect = [ + str.encode(local_image_name), + str.encode(running_container_name), + str.encode(run_program_name), + str.encode(expected_run_output), + ] + with _LocalJobContainer(image_uri, aws_session) as container: + container.run_local_job(env_variables) + run_output = container.run_log + assert run_output == expected_run_output + mock_check_output.assert_any_call(["docker", "images", "-q", image_uri]) + mock_check_output.assert_any_call( + ["docker", "run", "-d", "--rm", local_image_name, "tail", "-f", "/dev/null"] + ) + mock_check_output.assert_any_call( + ["docker", "exec", running_container_name, "printenv", "SAGEMAKER_PROGRAM"] + ) + mock_check_output.assert_any_call( + [ + "docker", + "exec", + "-w", + "/opt/ml/code/", + "-e", + "ENV0=VALUE0", + "-e", + "ENV1=VALUE1", + running_container_name, + "python", + run_program_name, + ] + ) + assert mock_check_output.call_count == 4 + mock_run.assert_any_call(["docker", "stop", running_container_name]) + assert mock_run.call_count == 1 + + +@patch("subprocess.check_output") +@patch("subprocess.run") +def test_customer_script_fails(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + env_variables = { + "ENV0": "VALUE0", + "ENV1": "VALUE1", + } + run_program_name = "Run Program Name" + expected_error_output = "Expected Error Output" + mock_check_output.side_effect = [ + str.encode(local_image_name), + str.encode(running_container_name), + str.encode(run_program_name), + subprocess.CalledProcessError("Test Error", "test", str.encode(expected_error_output)), + ] + with _LocalJobContainer(image_uri, aws_session) as container: + container.run_local_job(env_variables) + run_output = container.run_log + assert run_output == expected_error_output + assert mock_check_output.call_count == 4 + mock_run.assert_any_call(["docker", "stop", running_container_name]) + assert mock_run.call_count == 1 + + +@patch("subprocess.check_output") +@patch("subprocess.run") +def test_make_dir(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + test_dir_path = "/test/dir/path" + mock_check_output.side_effect = [ + str.encode(local_image_name), + str.encode(running_container_name), + str.encode(""), + ] + with _LocalJobContainer(image_uri, aws_session) as container: + container.makedir(test_dir_path) + mock_check_output.assert_any_call(["docker", "images", "-q", image_uri]) + mock_check_output.assert_any_call( + ["docker", "run", "-d", "--rm", local_image_name, "tail", "-f", "/dev/null"] + ) + mock_check_output.assert_any_call( + ["docker", "exec", running_container_name, "mkdir", "-p", test_dir_path] + ) + assert mock_check_output.call_count == 3 + mock_run.assert_any_call(["docker", "stop", running_container_name]) + assert mock_run.call_count == 1 + + +@patch("subprocess.check_output") +@patch("subprocess.run") +def test_copy_to(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + source_path = str(Path("test", "source", "dir", "path", "srcfile.txt")) + dest_path = str(Path("test", "dest", "dir", "path", "dstfile.txt")) + mock_check_output.side_effect = [ + str.encode(local_image_name), + str.encode(running_container_name), + str.encode(""), + str.encode(""), + ] + with 
_LocalJobContainer(image_uri, aws_session) as container: + container.copy_to(source_path, dest_path) + mock_check_output.assert_any_call(["docker", "images", "-q", image_uri]) + mock_check_output.assert_any_call( + ["docker", "run", "-d", "--rm", local_image_name, "tail", "-f", "/dev/null"] + ) + mock_check_output.assert_any_call( + [ + "docker", + "exec", + running_container_name, + "mkdir", + "-p", + str(Path("test", "dest", "dir", "path")), + ] + ) + mock_check_output.assert_any_call( + ["docker", "cp", source_path, f"{running_container_name}:{dest_path}"] + ) + assert mock_check_output.call_count == 4 + mock_run.assert_any_call(["docker", "stop", running_container_name]) + assert mock_run.call_count == 1 + + +@patch("subprocess.check_output") +@patch("subprocess.run") +def test_copy_from(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + source_path = "/test/source/dir/path/srcfile.txt" + dest_path = "/test/dest/dir/path/dstfile.txt" + mock_check_output.side_effect = [ + str.encode(local_image_name), + str.encode(running_container_name), + str.encode(""), + str.encode(""), + ] + with _LocalJobContainer(image_uri, aws_session) as container: + container.copy_from(source_path, dest_path) + mock_check_output.assert_any_call(["docker", "images", "-q", image_uri]) + mock_check_output.assert_any_call( + ["docker", "run", "-d", "--rm", local_image_name, "tail", "-f", "/dev/null"] + ) + mock_check_output.assert_any_call( + ["docker", "cp", f"{running_container_name}:{source_path}", dest_path] + ) + assert mock_check_output.call_count == 3 + mock_run.assert_any_call(["docker", "stop", running_container_name]) + assert mock_run.call_count == 1 + + +@patch("subprocess.check_output") +@patch("subprocess.run") +@pytest.mark.xfail(raises=ValueError) +def test_run_fails_no_program(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + env_variables = { + "ENV0": "VALUE0", + "ENV1": "VALUE1", + } + mock_check_output.side_effect = [ + str.encode(local_image_name), + str.encode(running_container_name), + str.encode(""), + ] + with _LocalJobContainer(image_uri, aws_session) as container: + container.run_local_job(env_variables) + + +@patch("subprocess.check_output") +@patch("subprocess.run") +@pytest.mark.xfail(raises=subprocess.CalledProcessError) +def test_make_dir_fails(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + test_dir_path = "/test/dir/path" + mock_check_output.side_effect = [ + str.encode(local_image_name), + str.encode(running_container_name), + subprocess.CalledProcessError("Test Error", "test", str.encode("test output")), + ] + with _LocalJobContainer(image_uri, aws_session) as container: + container.makedir(test_dir_path) + + +@patch("subprocess.check_output") +@patch("subprocess.run") +@pytest.mark.xfail(raises=subprocess.CalledProcessError) +def test_copy_to_fails(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + source_path = "/test/source/dir/path/srcfile.txt" + dest_path = "/test/dest/dir/path/dstfile.txt" + mock_check_output.side_effect = [ + str.encode(local_image_name), + str.encode(running_container_name), + subprocess.CalledProcessError("Test Error", "test", str.encode("test output")), + ] + 
with _LocalJobContainer(image_uri, aws_session) as container: + container.copy_to(source_path, dest_path) + + +@patch("subprocess.check_output") +@patch("subprocess.run") +@pytest.mark.xfail(raises=subprocess.CalledProcessError) +def test_copy_from_fails(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + source_path = "/test/source/dir/path/srcfile.txt" + dest_path = "/test/dest/dir/path/dstfile.txt" + mock_check_output.side_effect = [ + str.encode(local_image_name), + str.encode(running_container_name), + subprocess.CalledProcessError("Test Error", "test", str.encode("test output")), + ] + with _LocalJobContainer(image_uri, aws_session) as container: + container.copy_from(source_path, dest_path) + + +@patch("subprocess.check_output") +@patch("subprocess.run") +@pytest.mark.xfail(raises=ValueError) +def test_pull_fails_no_auth(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + mock_check_output.side_effect = [ + str.encode(""), + str.encode(local_image_name), + str.encode(running_container_name), + ] + aws_session.ecr_client.get_authorization_token.return_value = {} + with _LocalJobContainer(image_uri, aws_session): + pass + + +@patch("subprocess.check_output") +@patch("subprocess.run") +@pytest.mark.xfail(raises=ValueError) +def test_pull_fails_invalid_uri(mock_run, mock_check_output, aws_session): + local_image_name = "LocalImageName" + running_container_name = "RunningContainer" + mock_check_output.side_effect = [ + str.encode(""), + str.encode(local_image_name), + str.encode(running_container_name), + ] + aws_session.ecr_client.get_authorization_token.return_value = {} + with _LocalJobContainer("TestURI", aws_session): + pass + + +@patch("subprocess.check_output") +@patch("subprocess.run") +@pytest.mark.xfail(raises=ValueError) +def test_pull_fails_unknown_reason(mock_run, mock_check_output, repo_uri, image_uri, aws_session): + test_token = "Test Token" + mock_check_output.side_effect = [ + str.encode(""), + str.encode(""), + ] + aws_session.ecr_client.get_authorization_token.return_value = { + "authorizationData": [{"authorizationToken": base64.b64encode(str.encode(test_token))}] + } + with _LocalJobContainer(image_uri, aws_session): + pass diff --git a/test/unit_tests/braket/jobs/local/test_local_job_container_setup.py b/test/unit_tests/braket/jobs/local/test_local_job_container_setup.py new file mode 100644 index 000000000..93e114ef6 --- /dev/null +++ b/test/unit_tests/braket/jobs/local/test_local_job_container_setup.py @@ -0,0 +1,228 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ +import os +from pathlib import Path +from unittest.mock import Mock, mock_open, patch + +import pytest + +from braket.jobs.local.local_job_container_setup import setup_container + + +@pytest.fixture +def aws_session(): + _aws_session = Mock() + _aws_session.boto_session.get_credentials.return_value.access_key = "Test Access Key" + _aws_session.boto_session.get_credentials.return_value.secret_key = "Test Secret Key" + _aws_session.boto_session.get_credentials.return_value.token = None + _aws_session.region = "Test Region" + _aws_session.list_keys.side_effect = lambda bucket, prefix: [ + key + for key in [ + "input-dir/", + "input-dir/file-1.txt", + "input-dir/file-2.txt", + ] + if key.startswith(prefix) + ] + return _aws_session + + +@pytest.fixture +def container(): + _container = Mock() + return _container + + +@pytest.fixture +def creation_kwargs(): + return { + "algorithmSpecification": { + "scriptModeConfig": { + "entryPoint": "my_file:start_here", + "s3Uri": "s3://amazon-braket-jobs/job-path/my_file.py", + } + }, + "checkpointConfig": { + "localPath": "/opt/omega/checkpoints", + "s3Uri": "s3://amazon-braket-jobs/job-path/checkpoints", + }, + "outputDataConfig": {"s3Path": "s3://test_bucket/test_location/"}, + "deviceConfig": {"device": "test device ARN"}, + "jobName": "Test-Job-Name", + "roleArn": "arn:aws:iam::875981177017:role/AmazonBraketJobRole", + } + + +@pytest.fixture +def compressed_script_mode_config(): + return { + "scriptModeConfig": { + "entryPoint": "my_file:start_here", + "s3Uri": "s3://amazon-braket-jobs/job-path/my_archive.gzip", + "compressionType": "gzip", + } + } + + +@pytest.fixture +def expected_envs(): + return { + "AMZN_BRAKET_CHECKPOINT_DIR": "/opt/omega/checkpoints", + "AMZN_BRAKET_DEVICE_ARN": "test device ARN", + "AMZN_BRAKET_IMAGE_SETUP_SCRIPT": "s3://amazon-braket-external-assets-preview-us-west-2/" + "HybridJobsAccess/scripts/setup-container.sh", + "AMZN_BRAKET_JOB_NAME": "Test-Job-Name", + "AMZN_BRAKET_JOB_RESULTS_DIR": "/opt/braket/model", + "AMZN_BRAKET_JOB_RESULTS_S3_PATH": "test_location/Test-Job-Name/output", + "AMZN_BRAKET_OUT_S3_BUCKET": "test_bucket", + "AMZN_BRAKET_SCRIPT_ENTRY_POINT": "my_file:start_here", + "AMZN_BRAKET_SCRIPT_S3_URI": "s3://amazon-braket-jobs/job-path/my_file.py", + "AMZN_BRAKET_TASK_RESULTS_S3_URI": "s3://test_bucket/jobs/Test-Job-Name/tasks", + "AWS_ACCESS_KEY_ID": "Test Access Key", + "AWS_DEFAULT_REGION": "Test Region", + "AWS_SECRET_ACCESS_KEY": "Test Secret Key", + } + + +@pytest.fixture +def input_data_config(): + return [ + # s3 prefix is a single file + { + "channelName": "single-file", + "dataSource": {"s3DataSource": {"s3Uri": "s3://input_bucket/input-dir/file-1.txt"}}, + }, + # s3 prefix is a directory no slash + { + "channelName": "directory-no-slash", + "dataSource": {"s3DataSource": {"s3Uri": "s3://input_bucket/input-dir"}}, + }, + # s3 prefix is a directory with slash + { + "channelName": "directory-slash", + "dataSource": {"s3DataSource": {"s3Uri": "s3://input_bucket/input-dir/"}}, + }, + # s3 prefix is a prefix for a directory + { + "channelName": "directory-prefix", + "dataSource": {"s3DataSource": {"s3Uri": "s3://input_bucket/input"}}, + }, + # s3 prefix is a prefix for multiple files + { + "channelName": "files-prefix", + "dataSource": {"s3DataSource": {"s3Uri": "s3://input_bucket/input-dir/file"}}, + }, + ] + + +def test_basic_setup(container, aws_session, creation_kwargs, expected_envs): + aws_session.parse_s3_uri.return_value = ["test_bucket", "test_location"] + envs = setup_container(container, 
aws_session, **creation_kwargs) + assert envs == expected_envs + container.makedir.assert_any_call("/opt/ml/model") + container.makedir.assert_any_call(expected_envs["AMZN_BRAKET_CHECKPOINT_DIR"]) + assert container.makedir.call_count == 2 + + +def test_compressed_script_mode( + container, aws_session, creation_kwargs, expected_envs, compressed_script_mode_config +): + creation_kwargs["algorithmSpecification"] = compressed_script_mode_config + expected_envs["AMZN_BRAKET_SCRIPT_S3_URI"] = "s3://amazon-braket-jobs/job-path/my_archive.gzip" + expected_envs["AMZN_BRAKET_SCRIPT_COMPRESSION_TYPE"] = "gzip" + aws_session.parse_s3_uri.return_value = ["test_bucket", "test_location"] + envs = setup_container(container, aws_session, **creation_kwargs) + assert envs == expected_envs + container.makedir.assert_any_call("/opt/ml/model") + container.makedir.assert_any_call(expected_envs["AMZN_BRAKET_CHECKPOINT_DIR"]) + assert container.makedir.call_count == 2 + + +@patch("json.dump") +@patch("tempfile.TemporaryDirectory") +def test_hyperparameters(tempfile, json, container, aws_session, creation_kwargs, expected_envs): + with patch("builtins.open", mock_open()): + tempfile.return_value.__enter__.return_value = "temporaryDir" + creation_kwargs["hyperParameters"] = {"test": "hyper"} + expected_envs["AMZN_BRAKET_HP_FILE"] = "/opt/braket/input/config/hyperparameters.json" + aws_session.parse_s3_uri.return_value = ["test_bucket", "test_location"] + envs = setup_container(container, aws_session, **creation_kwargs) + assert envs == expected_envs + container.makedir.assert_any_call("/opt/ml/model") + container.makedir.assert_any_call(expected_envs["AMZN_BRAKET_CHECKPOINT_DIR"]) + assert container.makedir.call_count == 2 + container.copy_to.assert_called_with( + os.path.join("temporaryDir", "hyperparameters.json"), + "/opt/ml/input/config/hyperparameters.json", + ) + + +def test_input(container, aws_session, creation_kwargs, input_data_config): + creation_kwargs.update({"inputDataConfig": input_data_config}) + setup_container(container, aws_session, **creation_kwargs) + download_locations = [call[0][1] for call in aws_session.download_from_s3.call_args_list] + expected_downloads = [ + Path("single-file", "file-1.txt"), + Path("directory-no-slash", "file-1.txt"), + Path("directory-no-slash", "file-2.txt"), + Path("directory-slash", "file-1.txt"), + Path("directory-slash", "file-2.txt"), + Path("directory-prefix", "input-dir", "file-1.txt"), + Path("directory-prefix", "input-dir", "file-2.txt"), + Path("files-prefix", "file-1.txt"), + Path("files-prefix", "file-2.txt"), + ] + + for download, expected_download in zip(download_locations, expected_downloads): + assert download.endswith(str(expected_download)) + + +def test_duplicate_input(container, aws_session, creation_kwargs, input_data_config): + input_data_config.append( + { + # this is a duplicate channel + "channelName": "single-file", + "dataSource": {"s3DataSource": {"s3Uri": "s3://input_bucket/irrelevant"}}, + } + ) + creation_kwargs.update({"inputDataConfig": input_data_config}) + dupes_not_allowed = "Duplicate channel names not allowed for input data: single-file" + with pytest.raises(ValueError, match=dupes_not_allowed): + setup_container(container, aws_session, **creation_kwargs) + + +def test_no_data_input(container, aws_session, creation_kwargs, input_data_config): + input_data_config.append( + { + # this channel won't match any data + "channelName": "no-data", + "dataSource": {"s3DataSource": {"s3Uri": "s3://input_bucket/irrelevant"}}, + } + ) + 
creation_kwargs.update({"inputDataConfig": input_data_config}) + no_data_found = "No data found for channel 'no-data'" + with pytest.raises(RuntimeError, match=no_data_found): + setup_container(container, aws_session, **creation_kwargs) + + +def test_temporary_credentials(container, aws_session, creation_kwargs, expected_envs): + aws_session.boto_session.get_credentials.return_value.token = "Test Token" + expected_envs["AWS_SESSION_TOKEN"] = "Test Token" + aws_session.parse_s3_uri.return_value = ["test_bucket", "test_location"] + envs = setup_container(container, aws_session, **creation_kwargs) + assert envs == expected_envs + container.makedir.assert_any_call("/opt/ml/model") + container.makedir.assert_any_call(expected_envs["AMZN_BRAKET_CHECKPOINT_DIR"]) + assert container.makedir.call_count == 2 diff --git a/test/unit_tests/braket/jobs/metrics_data/test_cwl_insights_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics_data/test_cwl_insights_metrics_fetcher.py new file mode 100644 index 000000000..f0d42f523 --- /dev/null +++ b/test/unit_tests/braket/jobs/metrics_data/test_cwl_insights_metrics_fetcher.py @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from unittest.mock import Mock, call, patch + +import pytest + +from braket.jobs.metrics_data import MetricsRetrievalError +from braket.jobs.metrics_data.cwl_insights_metrics_fetcher import CwlInsightsMetricsFetcher + + +@pytest.fixture +def aws_session(): + _aws_session = Mock() + return _aws_session + + +EXAMPLE_METRICS_LOG_LINES = [ + [ + {"field": "@timestamp", "value": "Test timestamp 0"}, + {"field": "@message", "value": "Test value 0"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 1"}, + {"field": "@message", "value": "Test value 1"}, + ], + [ + {"field": "@timestamp", "value": "Test timestamp 2"}, + ], + [ + {"field": "@message", "value": "Test value 3"}, + ], + [], +] + +EXPECTED_CALL_LIST = [ + call("Test timestamp 0", "Test value 0"), + call("Test timestamp 1", "Test value 1"), + call(None, "Test value 3"), +] + + +@patch("braket.jobs.metrics_data.cwl_insights_metrics_fetcher.LogMetricsParser.get_parsed_metrics") +@patch("braket.jobs.metrics_data.cwl_insights_metrics_fetcher.LogMetricsParser.parse_log_message") +def test_get_all_metrics_complete_results(mock_add_metrics, mock_get_metrics, aws_session): + logs_client_mock = Mock() + aws_session.logs_client = logs_client_mock + + logs_client_mock.start_query.return_value = {"queryId": "test"} + logs_client_mock.get_query_results.return_value = { + "status": "Complete", + "results": EXAMPLE_METRICS_LOG_LINES, + } + expected_result = {"Test": [0]} + mock_get_metrics.return_value = expected_result + + fetcher = CwlInsightsMetricsFetcher(aws_session) + + result = fetcher.get_metrics_for_job("test_job", job_start_time=1, job_end_time=2) + logs_client_mock.get_query_results.assert_called_with(queryId="test") + logs_client_mock.start_query.assert_called_with( + logGroupName="/aws/braket/jobs", + startTime=1, + endTime=2, + 
queryString="fields @timestamp, @message | filter @logStream like /^test_job\\//" + " | filter @message like /^Metrics - /", + limit=10000, + ) + assert mock_add_metrics.call_args_list == EXPECTED_CALL_LIST + assert result == expected_result + + +def test_get_all_metrics_timeout(aws_session): + logs_client_mock = Mock() + aws_session.logs_client = logs_client_mock + + logs_client_mock.start_query.return_value = {"queryId": "test"} + logs_client_mock.get_query_results.return_value = {"status": "Queued"} + + fetcher = CwlInsightsMetricsFetcher(aws_session, 0.1, 0.2) + result = fetcher.get_metrics_for_job("test_job") + logs_client_mock.get_query_results.assert_called() + assert result == {} + + +@pytest.mark.xfail(raises=MetricsRetrievalError) +def test_get_all_metrics_failed(aws_session): + logs_client_mock = Mock() + aws_session.logs_client = logs_client_mock + + logs_client_mock.start_query.return_value = {"queryId": "test"} + logs_client_mock.get_query_results.return_value = {"status": "Failed"} + + fetcher = CwlInsightsMetricsFetcher(aws_session) + fetcher.get_metrics_for_job("test_job") diff --git a/test/unit_tests/braket/jobs/metrics_data/test_cwl_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics_data/test_cwl_metrics_fetcher.py new file mode 100644 index 000000000..fdaff840b --- /dev/null +++ b/test/unit_tests/braket/jobs/metrics_data/test_cwl_metrics_fetcher.py @@ -0,0 +1,135 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from unittest.mock import Mock, call, patch + +import pytest + +from braket.jobs.metrics_data.cwl_metrics_fetcher import CwlMetricsFetcher + + +@pytest.fixture +def aws_session(): + _aws_session = Mock() + return _aws_session + + +EXAMPLE_METRICS_LOG_LINES = [ + { + "timestamp": "Test timestamp 0", + "message": "Metrics - Test value 0", + }, + { + "timestamp": "Test timestamp 1", + "message": "Metrics - Test value 1", + }, + { + "timestamp": "Test timestamp 2", + }, + { + "message": "Metrics - Test value 3", + }, + { + # This metrics fetcher will filter out log line that don't have a "Metrics -" tag. 
+ "message": "No prefix, Test value 4", + }, +] + +EXPECTED_CALL_LIST = [ + call("Test timestamp 0", "Metrics - Test value 0"), + call("Test timestamp 1", "Metrics - Test value 1"), + call(None, "Metrics - Test value 3"), +] + + +@patch("braket.jobs.metrics_data.cwl_metrics_fetcher.LogMetricsParser.get_parsed_metrics") +@patch("braket.jobs.metrics_data.cwl_metrics_fetcher.LogMetricsParser.parse_log_message") +def test_get_all_metrics_complete_results(mock_add_metrics, mock_get_metrics, aws_session): + logs_client_mock = Mock() + aws_session.logs_client = logs_client_mock + + logs_client_mock.describe_log_streams.return_value = { + "logStreams": [{"logStreamName": "stream name"}, {}] + } + logs_client_mock.get_log_events.return_value = { + "events": EXAMPLE_METRICS_LOG_LINES, + "nextForwardToken": None, + } + expected_result = {"Test": [0]} + mock_get_metrics.return_value = expected_result + + fetcher = CwlMetricsFetcher(aws_session) + result = fetcher.get_metrics_for_job("test_job") + assert mock_add_metrics.call_args_list == EXPECTED_CALL_LIST + assert result == expected_result + + +@patch("braket.jobs.metrics_data.cwl_metrics_fetcher.LogMetricsParser.parse_log_message") +def test_get_log_streams_timeout(mock_add_metrics, aws_session): + logs_client_mock = Mock() + aws_session.logs_client = logs_client_mock + + logs_client_mock.describe_log_streams.return_value = { + "logStreams": [{"logStreamName": "stream name"}], + "nextToken": "forever", + } + logs_client_mock.get_log_events.return_value = { + "events": EXAMPLE_METRICS_LOG_LINES, + } + + fetcher = CwlMetricsFetcher(aws_session, 0.1) + result = fetcher.get_metrics_for_job("test_job") + mock_add_metrics.assert_not_called() + assert result == {} + + +@patch("braket.jobs.metrics_data.cwl_metrics_fetcher.LogMetricsParser.parse_log_message") +def test_get_no_streams_returned(mock_add_metrics, aws_session): + logs_client_mock = Mock() + aws_session.logs_client = logs_client_mock + + logs_client_mock.describe_log_streams.return_value = {} + + fetcher = CwlMetricsFetcher(aws_session) + result = fetcher.get_metrics_for_job("test_job") + logs_client_mock.describe_log_streams.assert_called() + mock_add_metrics.assert_not_called() + assert result == {} + + +@patch("braket.jobs.metrics_data.cwl_metrics_fetcher.LogMetricsParser.get_parsed_metrics") +@patch("braket.jobs.metrics_data.cwl_metrics_fetcher.LogMetricsParser.parse_log_message") +def test_get_metrics_timeout(mock_add_metrics, mock_get_metrics, aws_session): + logs_client_mock = Mock() + aws_session.logs_client = logs_client_mock + + logs_client_mock.describe_log_streams.return_value = { + "logStreams": [{"logStreamName": "stream name"}] + } + logs_client_mock.get_log_events.side_effect = get_log_events_forever + expected_result = {"Test": [0]} + mock_get_metrics.return_value = expected_result + + fetcher = CwlMetricsFetcher(aws_session, 0.1) + result = fetcher.get_metrics_for_job("test_job") + logs_client_mock.get_log_events.assert_called() + mock_add_metrics.assert_called() + assert result == expected_result + + +def get_log_events_forever(*args, **kwargs): + next_token = "1" + token = kwargs.get("nextToken") + if token and token == "1": + next_token = "2" + return {"events": EXAMPLE_METRICS_LOG_LINES, "nextForwardToken": next_token} diff --git a/test/unit_tests/braket/jobs/metrics_data/test_log_metrics_parser.py b/test/unit_tests/braket/jobs/metrics_data/test_log_metrics_parser.py new file mode 100644 index 000000000..f3edadb02 --- /dev/null +++ 
b/test/unit_tests/braket/jobs/metrics_data/test_log_metrics_parser.py @@ -0,0 +1,164 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import pytest + +from braket.jobs.metrics_data import LogMetricsParser +from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType + +MALFORMED_METRICS_LOG_LINES = [ + {"timestamp": "Test timestamp 0", "message": ""}, + {"timestamp": "Test timestamp 1", "message": "No semicolon metric0=2.0"}, + {"timestamp": "Test timestamp 2", "message": "metric0=not_a_number;"}, + {"timestamp": "Test timestamp 3", "message": "also not a number metric0=2 . 0;"}, + {"timestamp": "Test timestamp 3", "message": "metric0=;"}, + {"timestamp": "Test timestamp 3", "message": "metric0= ;"}, + {"timestamp": "Test timestamp 4"}, + {"unknown": "Unknown"}, +] + +SIMPLE_METRICS_LOG_LINES = [ + # This is a standard line showing what our metrics may look like + { + "timestamp": "Test timestamp 0", + "message": "Metrics - metric0=0.0; metric1=1.0; metric2=2.0 ;", + }, + # This line overwrites the timestamp by having it output in the metrics. + { + "timestamp": "Test timestamp 1", + "message": "Metrics - timestamp=1628019160; metric0=0.1; metric2= 2.1;", + }, + # This line adds metric3 that won't have values for any other timestamp + { + "timestamp": "Test timestamp 2", + "message": "Metrics - metric0=0.2; metric1=1.2; metric2= 2.2 ; metric3=0.2;", + }, + # This line adds metrics expressed as exponents + { + "timestamp": "Test timestamp 3", + "message": "Metrics - metric0=-0.4; metric1=3.14e-22; metric2=3.14E22;", + }, +] + +SIMPLE_METRICS_RESULT = { + "timestamp": [ + "Test timestamp 0", + 1628019160, + "Test timestamp 2", + "Test timestamp 3", + ], + "metric0": [0.0, 0.1, 0.2, -0.4], + "metric1": [1.0, None, 1.2, 3.14e-22], + "metric2": [2.0, 2.1, 2.2, 3.14e22], + "metric3": [None, None, 0.2, None], +} + +# This will test how metrics are combined when multiple metrics have the same timestamp +SINGLE_TIMESTAMP_METRICS_LOG_LINES = [ + {"timestamp": "Test timestamp 0", "message": "Metrics - metric0=0.0;"}, + {"timestamp": "Test timestamp 0", "message": "Metrics - metric0=0.1; metric1=1.1;"}, + {"timestamp": "Test timestamp 0", "message": "Metrics - metric0=0.2; metric2=2.8;"}, + {"timestamp": "Test timestamp 0", "message": "Metrics - metric0=0.3; metric1=1.3;"}, + {"timestamp": "Test timestamp 0", "message": "Metrics - metric1=1.4; metric2=2.4;"}, + { + "timestamp": "Test timestamp 0", + "message": "Metrics - metric0=0.5; metric1=1.5; metric2=2.5;", + }, + {"timestamp": "Test timestamp 0", "message": "Metrics - metric1=0.6; metric0=0.6;"}, +] + + +ITERATION_AND_TIMESTAMPS_LOG_LINES = [ + {"timestamp": "Test timestamp 0", "message": "Metrics - iteration_number=0; metric0=0.0;"}, + { + "timestamp": "Test timestamp 1", + "message": "Metrics - metric0=0.1; metric1=1.1; iteration_number=0;", + }, + {"timestamp": "Test timestamp 2", "message": "Metrics - metric0=0.2; metric2=2.8;"}, + {"timestamp": "Test timestamp 3", "message": "Metrics - metric0=0.3; 
metric1=1.3;"}, + { + "timestamp": "Test timestamp 4", + "message": "Metrics - metric1=1.4; metric2=2.4; iteration_number=0;", + }, + { + "timestamp": "Test timestamp 5", + "message": "Metrics - metric0=0.5; metric1=1.5; metric2=2.5;", + }, + { + "timestamp": "Test timestamp 6", + "message": "Metrics - metric1=0.6; iteration_number=0; metric0=0.6;", + }, +] + + +SINGLE_TIMESTAMP_MAX_RESULTS = { + "timestamp": ["Test timestamp 0"], + "metric0": [0.6], + "metric1": [1.5], + "metric2": [2.8], +} + +SINGLE_TIMESTAMP_MIN_RESULTS = { + "timestamp": ["Test timestamp 0"], + "metric0": [0.0], + "metric1": [0.6], + "metric2": [2.4], +} + +ITERATION_NUMBER_MAX_RESULTS = { + "iteration_number": [0], + "timestamp": ["Test timestamp 6"], + "metric0": [0.6], + "metric1": [1.4], + "metric2": [2.4], +} + + +@pytest.mark.parametrize( + "log_events, metric_type, metric_stat, metrics_results", + [ + ([], MetricType.TIMESTAMP, MetricStatistic.MAX, {}), + (MALFORMED_METRICS_LOG_LINES, MetricType.TIMESTAMP, MetricStatistic.MAX, {}), + ( + SIMPLE_METRICS_LOG_LINES, + MetricType.TIMESTAMP, + MetricStatistic.MAX, + SIMPLE_METRICS_RESULT, + ), + ( + SINGLE_TIMESTAMP_METRICS_LOG_LINES, + MetricType.TIMESTAMP, + MetricStatistic.MAX, + SINGLE_TIMESTAMP_MAX_RESULTS, + ), + ( + SINGLE_TIMESTAMP_METRICS_LOG_LINES, + MetricType.TIMESTAMP, + MetricStatistic.MIN, + SINGLE_TIMESTAMP_MIN_RESULTS, + ), + ( + ITERATION_AND_TIMESTAMPS_LOG_LINES, + MetricType.ITERATION_NUMBER, + MetricStatistic.MAX, + ITERATION_NUMBER_MAX_RESULTS, + ), + # TODO: https://app.asana.com/0/1199668788990775/1200502190825620 + # We should also test some real-world data, once we have it. + ], +) +def test_get_all_metrics_complete_results(log_events, metric_type, metric_stat, metrics_results): + parser = LogMetricsParser() + for log_event in log_events: + parser.parse_log_message(log_event.get("timestamp"), log_event.get("message")) + assert parser.get_parsed_metrics(metric_type, metric_stat) == metrics_results diff --git a/test/unit_tests/braket/jobs/test_data_persistence.py b/test/unit_tests/braket/jobs/test_data_persistence.py new file mode 100644 index 000000000..d40d2c203 --- /dev/null +++ b/test/unit_tests/braket/jobs/test_data_persistence.py @@ -0,0 +1,274 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ +import json +import os +import tempfile +from dataclasses import dataclass +from unittest.mock import patch + +import numpy as np +import pytest + +from braket.jobs.data_persistence import load_job_checkpoint, save_job_checkpoint, save_job_result +from braket.jobs_data import PersistedJobDataFormat + + +@pytest.mark.parametrize( + "job_name, file_suffix, data_format, checkpoint_data, expected_saved_data", + [ + ( + "job_plaintext_simple_dict", + "", + PersistedJobDataFormat.PLAINTEXT, + {"converged": True, "energy": -0.2}, + json.dumps( + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": {"converged": True, "energy": -0.2}, + "dataFormat": "plaintext", + } + ), + ), + ( + "job_pickled_simple_dict", + "suffix1", + PersistedJobDataFormat.PICKLED_V4, + {"converged": True, "energy": -0.2}, + json.dumps( + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": { + "converged": "gASILg==\n", + "energy": "gASVCgAAAAAAAABHv8mZmZmZmZou\n", + }, + "dataFormat": "pickled_v4", + } + ), + ), + ], +) +def test_save_job_checkpoint( + job_name, file_suffix, data_format, checkpoint_data, expected_saved_data +): + with tempfile.TemporaryDirectory() as tmp_dir: + with patch.dict( + os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name} + ): + save_job_checkpoint(checkpoint_data, file_suffix, data_format) + + expected_file_location = ( + f"{tmp_dir}/{job_name}_{file_suffix}.json" + if file_suffix + else f"{tmp_dir}/{job_name}.json" + ) + with open(expected_file_location, "r") as expected_file: + assert expected_file.read() == expected_saved_data + + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize("checkpoint_data", [{}, None]) +def test_save_job_checkpoint_raises_error_empty_data(checkpoint_data): + job_name = "foo" + with tempfile.TemporaryDirectory() as tmp_dir: + with patch.dict( + os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name} + ): + save_job_checkpoint(checkpoint_data) + + +@pytest.mark.parametrize( + "job_name, file_suffix, data_format, saved_data, expected_checkpoint_data", + [ + ( + "job_plaintext_simple_dict", + "", + PersistedJobDataFormat.PLAINTEXT, + json.dumps( + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": {"converged": True, "energy": -0.2}, + "dataFormat": "plaintext", + } + ), + {"converged": True, "energy": -0.2}, + ), + ( + "job_pickled_simple_dict", + "", + PersistedJobDataFormat.PICKLED_V4, + json.dumps( + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": { + "converged": "gASILg==\n", + "energy": "gASVCgAAAAAAAABHv8mZmZmZmZou\n", + }, + "dataFormat": "pickled_v4", + } + ), + {"converged": True, "energy": -0.2}, + ), + ], +) +def test_load_job_checkpoint( + job_name, file_suffix, data_format, saved_data, expected_checkpoint_data +): + with tempfile.TemporaryDirectory() as tmp_dir: + file_path = ( + f"{tmp_dir}/{job_name}_{file_suffix}.json" + if file_suffix + else f"{tmp_dir}/{job_name}.json" + ) + with open(file_path, "w") as f: + f.write(saved_data) + + with patch.dict( + os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name} + ): + loaded_data = load_job_checkpoint(job_name, file_suffix) + assert loaded_data == expected_checkpoint_data + + +@pytest.mark.xfail(raises=FileNotFoundError) +def 
test_load_job_checkpoint_raises_error_file_not_exists(): + job_name = "old_job" + file_suffix = "correct_suffix" + with tempfile.TemporaryDirectory() as tmp_dir: + file_path = f"{tmp_dir}/{job_name}_{file_suffix}.json" + with open(file_path, "w") as _: + pass + + with patch.dict( + os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name} + ): + load_job_checkpoint(job_name, "wrong_suffix") + + +@pytest.mark.xfail(raises=ValueError) +def test_load_job_checkpoint_raises_error_corrupted_data(): + job_name = "old_job_corrupted_data" + file_suffix = "foo" + with tempfile.TemporaryDirectory() as tmp_dir: + file_path = f"{tmp_dir}/{job_name}_{file_suffix}.json" + with open(file_path, "w") as corrupted_file: + corrupted_file.write( + json.dumps( + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": { + "converged": "gASILg==\n", + "energy": "gASVCgBHv--corrupted---\n", + }, + "dataFormat": "pickled_v4", + } + ) + ) + + with patch.dict( + os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name} + ): + load_job_checkpoint(job_name, file_suffix) + + +@dataclass +class CustomClassToPersist: + float_val: float + str_val: str + bool_val: bool + + +def test_save_and_load_job_checkpoint(): + with tempfile.TemporaryDirectory() as tmp_dir: + job_name = "job_name_1" + data = { + "np_array": np.array([1]), + "custom_class": CustomClassToPersist(3.4, "str", True), + "none_value": None, + "nested_dict": {"a": {"b": False}}, + } + with patch.dict( + os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir, "AMZN_BRAKET_JOB_NAME": job_name} + ): + save_job_checkpoint(data, data_format=PersistedJobDataFormat.PICKLED_V4) + retrieved = load_job_checkpoint(job_name) + assert retrieved == data + + +@pytest.mark.parametrize( + "data_format, result_data, expected_saved_data", + [ + ( + PersistedJobDataFormat.PLAINTEXT, + {"converged": True, "energy": -0.2}, + json.dumps( + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": {"converged": True, "energy": -0.2}, + "dataFormat": "plaintext", + } + ), + ), + ( + PersistedJobDataFormat.PICKLED_V4, + {"converged": True, "energy": -0.2}, + json.dumps( + { + "braketSchemaHeader": { + "name": "braket.jobs_data.persisted_job_data", + "version": "1", + }, + "dataDictionary": { + "converged": "gASILg==\n", + "energy": "gASVCgAAAAAAAABHv8mZmZmZmZou\n", + }, + "dataFormat": "pickled_v4", + } + ), + ), + ], +) +def test_save_job_result(data_format, result_data, expected_saved_data): + with tempfile.TemporaryDirectory() as tmp_dir: + with patch.dict(os.environ, {"AMZN_BRAKET_JOB_RESULTS_DIR": tmp_dir}): + save_job_result(result_data, data_format) + + expected_file_location = f"{tmp_dir}/results.json" + with open(expected_file_location, "r") as expected_file: + assert expected_file.read() == expected_saved_data + + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize("result_data", [{}, None]) +def test_save_job_result_raises_error_empty_data(result_data): + with tempfile.TemporaryDirectory() as tmp_dir: + with patch.dict(os.environ, {"AMZN_BRAKET_CHECKPOINT_DIR": tmp_dir}): + save_job_result(result_data) diff --git a/test/unit_tests/braket/jobs/test_image_uris.py b/test/unit_tests/braket/jobs/test_image_uris.py new file mode 100644 index 000000000..ab608c5a8 --- /dev/null +++ b/test/unit_tests/braket/jobs/test_image_uris.py @@ -0,0 +1,57 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import pytest + +from braket.jobs.image_uris import Framework, retrieve_image + + +@pytest.mark.parametrize( + "region, framework, expected_uri", + [ + ( + "us-west-1", + Framework.BASE, + "292282985366.dkr.ecr.us-west-1.amazonaws.com/" + "amazon-braket-base-jobs:1.0-cpu-py37-ubuntu18.04", + ), + ( + "us-east-1", + Framework.PL_TENSORFLOW, + "292282985366.dkr.ecr.us-east-1.amazonaws.com/amazon-braket-tensorflow-jobs:" + "2.4.1-cpu-py37-ubuntu18.04", + ), + ( + "us-west-2", + Framework.PL_PYTORCH, + "292282985366.dkr.ecr.us-west-2.amazonaws.com/" + "amazon-braket-pytorch-jobs:1.8.1-cpu-py37-ubuntu18.04", + ), + ], +) +def test_retrieve_image_default_version(region, framework, expected_uri): + assert retrieve_image(framework, region) == expected_uri + + +@pytest.mark.parametrize( + "region, framework", + [ + ("eu-west-1", Framework.BASE), + (None, Framework.BASE), + ("us-west-1", None), + ("us-west-1", "foo"), + ], +) +@pytest.mark.xfail(raises=ValueError) +def test_retrieve_image_incorrect_input(region, framework): + retrieve_image(framework, region) diff --git a/test/unit_tests/braket/jobs/test_metrics.py b/test/unit_tests/braket/jobs/test_metrics.py new file mode 100644 index 000000000..1670cee3e --- /dev/null +++ b/test/unit_tests/braket/jobs/test_metrics.py @@ -0,0 +1,35 @@ +from unittest.mock import patch + +import pytest + +from braket.jobs.metrics import log_metric + + +@pytest.mark.parametrize( + "test_value, test_timestamp, test_iteration, result_string", + [ + # Happy case + (0.1, 1, 2, "Metrics - timestamp=1; TestName=0.1; iteration_number=2;"), + # We handle exponent values + (3.14e-22, 1, 2, "Metrics - timestamp=1; TestName=3.14e-22; iteration_number=2;"), + # When iteration number is not provided, we don't print it + (5, 1, None, "Metrics - timestamp=1; TestName=5;"), + # When iteration number is 0, we do print it + (5, 1, 0, "Metrics - timestamp=1; TestName=5; iteration_number=0;"), + # When timestamp is not provided, we use time.time() + (-3.14, None, 2, "Metrics - timestamp=time_mocked; TestName=-3.14; iteration_number=2;"), + ], +) +@patch("time.time") +@patch("builtins.print") +def test_log_metric( + print_mock, time_mock, test_value, test_timestamp, test_iteration, result_string +): + time_mock.return_value = "time_mocked" + log_metric( + metric_name="TestName", + value=test_value, + timestamp=test_timestamp, + iteration_number=test_iteration, + ) + print_mock.assert_called_with(result_string) diff --git a/test/unit_tests/braket/jobs/test_quantum_job_creation.py b/test/unit_tests/braket/jobs/test_quantum_job_creation.py new file mode 100644 index 000000000..824cc33cf --- /dev/null +++ b/test/unit_tests/braket/jobs/test_quantum_job_creation.py @@ -0,0 +1,636 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import datetime +import tempfile +import time +from collections import defaultdict +from dataclasses import asdict +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + +from braket.aws import AwsSession +from braket.jobs.config import ( + CheckpointConfig, + InstanceConfig, + OutputDataConfig, + S3DataSourceConfig, + StoppingCondition, +) +from braket.jobs.quantum_job_creation import ( + _generate_default_job_name, + _process_input_data, + _process_local_source_module, + _process_s3_source_module, + _tar_and_upload_to_code_location, + _validate_entry_point, + prepare_quantum_job, +) + + +@pytest.fixture +def aws_session(): + _aws_session = Mock(spec=AwsSession) + _aws_session.default_bucket.return_value = "default-bucket-name" + _aws_session.get_default_jobs_role.return_value = "default-role-arn" + return _aws_session + + +@pytest.fixture +def entry_point(): + return "test-source-dir.entry_point:func" + + +@pytest.fixture +def bucket(): + return "braket-region-id" + + +@pytest.fixture +def tags(): + return {"tag-key": "tag-value"} + + +@pytest.fixture( + params=[ + None, + "aws.location/amazon-braket-custom-jobs:tag.1.2.3", + "other.uri/amazon-braket-custom-name:tag", + "other.uri/custom-non-managed:tag", + "other-custom-format.com", + ] +) +def image_uri(request): + return request.param + + +@pytest.fixture(params=["given_job_name", "default_job_name"]) +def job_name(request): + if request.param == "given_job_name": + return "test-job-name" + + +@pytest.fixture +def s3_prefix(job_name): + return f"{job_name}/non-default" + + +@pytest.fixture(params=["local_source", "s3_source"]) +def source_module(request, bucket): + if request.param == "local_source": + return "test-source-module" + elif request.param == "s3_source": + return AwsSession.construct_s3_uri(bucket, "test-source-prefix", "source.tar.gz") + + +@pytest.fixture +def code_location(bucket, s3_prefix): + return AwsSession.construct_s3_uri(bucket, s3_prefix, "script") + + +@pytest.fixture +def role_arn(): + return "arn:aws:iam::0000000000:role/AmazonBraketInternalSLR" + + +@pytest.fixture +def device(): + return "arn:aws:braket:::device/qpu/test/device-name" + + +@pytest.fixture +def hyperparameters(): + return { + "param": "value", + "other-param": 100, + } + + +@pytest.fixture(params=["dict", "local"]) +def input_data(request, bucket): + if request.param == "dict": + return { + "s3_input": f"s3://{bucket}/data/prefix", + "local_input": "local/prefix", + "config_input": S3DataSourceConfig(f"s3://{bucket}/config/prefix"), + } + elif request.param == "local": + return "local/prefix" + + +@pytest.fixture +def instance_config(): + return InstanceConfig( + instanceType="ml.m5.large", + volumeSizeInGb=1, + ) + + +@pytest.fixture +def stopping_condition(): + return StoppingCondition( + maxRuntimeInSeconds=1200, + ) + + +@pytest.fixture +def output_data_config(bucket, s3_prefix): + return OutputDataConfig( + s3Path=AwsSession.construct_s3_uri(bucket, s3_prefix, "output"), + ) + + +@pytest.fixture +def checkpoint_config(bucket, s3_prefix): + return CheckpointConfig( + localPath="/opt/omega/checkpoints", + s3Uri=AwsSession.construct_s3_uri(bucket, 
s3_prefix, "checkpoints"), + ) + + +@pytest.fixture +def generate_get_job_response(): + def _get_job_response(**kwargs): + response = { + "ResponseMetadata": { + "RequestId": "d223b1a0-ee5c-4c75-afa7-3c29d5338b62", + "HTTPStatusCode": 200, + }, + "algorithmSpecification": { + "scriptModeConfig": { + "entryPoint": "my_file:start_here", + "s3Uri": "s3://amazon-braket-jobs/job-path/my_file.py", + } + }, + "checkpointConfig": { + "localPath": "/opt/omega/checkpoints", + "s3Uri": "s3://amazon-braket-jobs/job-path/checkpoints", + }, + "createdAt": datetime.datetime(2021, 6, 28, 21, 4, 51), + "deviceConfig": { + "device": "arn:aws:braket:::device/qpu/rigetti/Aspen-10", + }, + "hyperParameters": { + "foo": "bar", + }, + "inputDataConfig": [ + { + "channelName": "training_input", + "dataSource": { + "s3DataSource": { + "s3Uri": "s3://amazon-braket-jobs/job-path/input", + } + }, + } + ], + "instanceConfig": { + "instanceType": "ml.m5.large", + "volumeSizeInGb": 1, + }, + "jobArn": "arn:aws:braket:us-west-2:875981177017:job/job-test-20210628140446", + "jobName": "job-test-20210628140446", + "outputDataConfig": {"s3Path": "s3://amazon-braket-jobs/job-path/data"}, + "roleArn": "arn:aws:iam::875981177017:role/AmazonBraketJobRole", + "status": "RUNNING", + "stoppingCondition": {"maxRuntimeInSeconds": 1200}, + } + response.update(kwargs) + + return response + + return _get_job_response + + +@pytest.fixture(params=["fixtures", "defaults", "nones"]) +def create_job_args( + request, + aws_session, + entry_point, + image_uri, + source_module, + job_name, + code_location, + role_arn, + device, + hyperparameters, + input_data, + instance_config, + stopping_condition, + output_data_config, + checkpoint_config, + tags, +): + if request.param == "fixtures": + return dict( + (key, value) + for key, value in { + "device": device, + "source_module": source_module, + "entry_point": entry_point, + "image_uri": image_uri, + "job_name": job_name, + "code_location": code_location, + "role_arn": role_arn, + "hyperparameters": hyperparameters, + "input_data": input_data, + "instance_config": instance_config, + "stopping_condition": stopping_condition, + "output_data_config": output_data_config, + "checkpoint_config": checkpoint_config, + "aws_session": aws_session, + "tags": tags, + }.items() + if value is not None + ) + elif request.param == "defaults": + return { + "device": device, + "source_module": source_module, + "entry_point": entry_point, + "aws_session": aws_session, + } + elif request.param == "nones": + return defaultdict( + lambda: None, + device=device, + source_module=source_module, + entry_point=entry_point, + aws_session=aws_session, + ) + + +@patch("tarfile.TarFile.add") +@patch("importlib.util.find_spec") +@patch("braket.jobs.quantum_job_creation.Path") +@patch("time.time") +def test_create_job( + mock_time, + mock_path, + mock_findspec, + mock_tarfile, + aws_session, + source_module, + create_job_args, +): + mock_path.return_value.resolve.return_value.parent = "parent_dir" + mock_path.return_value.resolve.return_value.stem = source_module + mock_path.return_value.name = "file_name" + mock_time.return_value = datetime.datetime.now().timestamp() + expected_kwargs = _translate_creation_args(create_job_args) + result_kwargs = prepare_quantum_job(**create_job_args) + assert expected_kwargs == result_kwargs + + +def _translate_creation_args(create_job_args): + aws_session = create_job_args["aws_session"] + create_job_args = defaultdict(lambda: None, **create_job_args) + image_uri = create_job_args["image_uri"] 
+ job_name = create_job_args["job_name"] or _generate_default_job_name(image_uri) + default_bucket = aws_session.default_bucket() + code_location = create_job_args["code_location"] or AwsSession.construct_s3_uri( + default_bucket, "jobs", job_name, "script" + ) + role_arn = create_job_args["role_arn"] or aws_session.get_default_jobs_role() + device = create_job_args["device"] + hyperparameters = create_job_args["hyperparameters"] or {} + input_data = create_job_args["input_data"] or {} + instance_config = create_job_args["instance_config"] or InstanceConfig() + output_data_config = create_job_args["output_data_config"] or OutputDataConfig( + s3Path=AwsSession.construct_s3_uri(default_bucket, "jobs", job_name, "data") + ) + stopping_condition = create_job_args["stopping_condition"] or StoppingCondition() + checkpoint_config = create_job_args["checkpoint_config"] or CheckpointConfig( + s3Uri=AwsSession.construct_s3_uri(default_bucket, "jobs", job_name, "checkpoints") + ) + entry_point = create_job_args["entry_point"] + source_module = create_job_args["source_module"] + if not AwsSession.is_s3_uri(source_module): + entry_point = entry_point or Path(source_module).stem + algorithm_specification = { + "scriptModeConfig": { + "entryPoint": entry_point, + "s3Uri": f"{code_location}/source.tar.gz", + "compressionType": "GZIP", + } + } + if image_uri: + algorithm_specification["containerImage"] = {"uri": image_uri} + tags = create_job_args.get("tags", {}) + + test_kwargs = { + "jobName": job_name, + "roleArn": role_arn, + "algorithmSpecification": algorithm_specification, + "inputDataConfig": _process_input_data(input_data, job_name, aws_session), + "instanceConfig": asdict(instance_config), + "outputDataConfig": asdict(output_data_config), + "checkpointConfig": asdict(checkpoint_config), + "deviceConfig": {"device": device}, + "hyperParameters": hyperparameters, + "stoppingCondition": asdict(stopping_condition), + "tags": tags, + } + + return test_kwargs + + +@patch("time.time") +def test_generate_default_job_name(mock_time, image_uri): + job_type_mapping = { + None: "-default", + "aws.location/amazon-braket-custom-jobs:tag.1.2.3": "-custom", + "other.uri/amazon-braket-custom-name:tag": "-custom-name", + "other.uri/custom-non-managed:tag": "", + "other-custom-format.com": "", + } + job_type = job_type_mapping[image_uri] + mock_time.return_value = datetime.datetime.now().timestamp() + assert _generate_default_job_name(image_uri) == f"braket-job{job_type}-{time.time() * 1000:.0f}" + + +@pytest.mark.parametrize( + "source_module", + ( + "s3://bucket/source_module.tar.gz", + "s3://bucket/SOURCE_MODULE.TAR.GZ", + ), +) +def test_process_s3_source_module(source_module, aws_session): + _process_s3_source_module(source_module, "entry_point", aws_session, "code_location") + aws_session.copy_s3_object.assert_called_with(source_module, "code_location/source.tar.gz") + + +def test_process_s3_source_module_not_tar_gz(aws_session): + must_be_tar_gz = ( + "If source_module is an S3 URI, it must point to a tar.gz file. " + "Not a valid S3 URI for parameter `source_module`: s3://bucket/source_module" + ) + with pytest.raises(ValueError, match=must_be_tar_gz): + _process_s3_source_module( + "s3://bucket/source_module", "entry_point", aws_session, "code_location" + ) + + +def test_process_s3_source_module_no_entry_point(aws_session): + entry_point_required = "If source_module is an S3 URI, entry_point must be provided." 
+ with pytest.raises(ValueError, match=entry_point_required): + _process_s3_source_module("s3://bucket/source_module", None, aws_session, "code_location") + + +@patch("braket.jobs.quantum_job_creation._tar_and_upload_to_code_location") +@patch("braket.jobs.quantum_job_creation._validate_entry_point") +def test_process_local_source_module(validate_mock, tar_and_upload_mock, aws_session): + with tempfile.TemporaryDirectory() as temp_dir: + source_module = Path(temp_dir, "source_module") + source_module.touch() + + _process_local_source_module( + str(source_module), "entry_point", aws_session, "code_location" + ) + + source_module_abs_path = Path(temp_dir, "source_module").resolve() + validate_mock.assert_called_with(source_module_abs_path, "entry_point") + tar_and_upload_mock.assert_called_with(source_module_abs_path, aws_session, "code_location") + + +def test_process_local_source_module_not_found(aws_session): + with tempfile.TemporaryDirectory() as temp_dir: + source_module = str(Path(temp_dir, "source_module").as_posix()) + source_module_not_found = f"Source module not found: {source_module}" + with pytest.raises(ValueError, match=source_module_not_found): + _process_local_source_module(source_module, "entry_point", aws_session, "code_location") + + +def test_validate_entry_point_default_file(): + with tempfile.TemporaryDirectory() as temp_dir: + source_module_path = Path(temp_dir, "source_module.py") + source_module_path.touch() + # import source_module + _validate_entry_point(source_module_path, "source_module") + # from source_module import func + _validate_entry_point(source_module_path, "source_module:func") + # import . + _validate_entry_point(source_module_path, ".") + # from . import func + _validate_entry_point(source_module_path, ".:func") + + +def test_validate_entry_point_default_directory(): + with tempfile.TemporaryDirectory() as temp_dir: + source_module_path = Path(temp_dir, "source_module") + source_module_path.mkdir() + # import source_module + _validate_entry_point(source_module_path, "source_module") + # from source_module import func + _validate_entry_point(source_module_path, "source_module:func") + # import . + _validate_entry_point(source_module_path, ".") + # from . import func + _validate_entry_point(source_module_path, ".:func") + + +def test_validate_entry_point_submodule_file(): + with tempfile.TemporaryDirectory() as temp_dir: + source_module_path = Path(temp_dir, "source_module") + source_module_path.mkdir() + Path(source_module_path, "submodule.py").touch() + # from source_module import submodule + _validate_entry_point(source_module_path, "source_module.submodule") + # from source_module.submodule import func + _validate_entry_point(source_module_path, "source_module.submodule:func") + # from . import submodule + _validate_entry_point(source_module_path, ".submodule") + # from .submodule import func + _validate_entry_point(source_module_path, ".submodule:func") + + +def test_validate_entry_point_submodule_init(): + with tempfile.TemporaryDirectory() as temp_dir: + source_module_path = Path(temp_dir, "source_module") + source_module_path.mkdir() + Path(source_module_path, "submodule.py").touch() + with open(str(Path(source_module_path, "__init__.py")), "w") as f: + f.write("from . import submodule as renamed") + # from source_module import renamed + _validate_entry_point(source_module_path, "source_module:renamed") + # from . 
import renamed + _validate_entry_point(source_module_path, ".:renamed") + + +def test_validate_entry_point_source_module_not_found(): + with tempfile.TemporaryDirectory() as temp_dir: + source_module_path = Path(temp_dir, "source_module") + source_module_path.mkdir() + Path(source_module_path, "submodule.py").touch() + + # catches ModuleNotFoundError + module_not_found = "Entry point module was not found: fake_source_module.submodule" + with pytest.raises(ValueError, match=module_not_found): + _validate_entry_point(source_module_path, "fake_source_module.submodule") + + # catches AssertionError for module is not None + submodule_not_found = "Entry point module was not found: source_module.fake_submodule" + with pytest.raises(ValueError, match=submodule_not_found): + _validate_entry_point(source_module_path, "source_module.fake_submodule") + + +@patch("tarfile.TarFile.add") +def test_tar_and_upload_to_code_location(mock_tar_add, aws_session): + with tempfile.TemporaryDirectory() as temp_dir: + source_module_path = Path(temp_dir, "source_module") + source_module_path.mkdir() + _tar_and_upload_to_code_location(source_module_path, aws_session, "code_location") + mock_tar_add.assert_called_with(source_module_path, arcname="source_module") + local, s3 = aws_session.upload_to_s3.call_args_list[0][0] + assert local.endswith("source.tar.gz") + assert s3 == "code_location/source.tar.gz" + + +@patch("braket.jobs.quantum_job_creation._process_local_source_module") +@patch("braket.jobs.quantum_job_creation._validate_entry_point") +@patch("braket.jobs.quantum_job_creation._validate_params") +def test_copy_checkpoints( + mock_validate_input, + mock_validate_entry_point, + mock_process_local_source, + aws_session, + entry_point, + device, + checkpoint_config, + generate_get_job_response, +): + other_checkpoint_uri = "s3://amazon-braket-jobs/job-path/checkpoints" + aws_session.get_job.return_value = generate_get_job_response( + checkpointConfig={ + "s3Uri": other_checkpoint_uri, + } + ) + prepare_quantum_job( + device=device, + source_module="source_module", + entry_point=entry_point, + copy_checkpoints_from_job="other-job-arn", + checkpoint_config=checkpoint_config, + aws_session=aws_session, + ) + aws_session.copy_s3_directory.assert_called_with(other_checkpoint_uri, checkpoint_config.s3Uri) + + +def test_invalid_input_parameters(entry_point, aws_session): + error_message = ( + "'instance_config' should be of '<class 'braket.jobs.config.InstanceConfig'>' " + "but user provided <class 'int'>." 
+ ) + with pytest.raises(ValueError, match=error_message): + prepare_quantum_job( + aws_session=aws_session, + entry_point=entry_point, + device="arn:aws:braket:::device/quantum-simulator/amazon/sv1", + source_module="alpha_test_job", + hyperparameters={ + "param-1": "first parameter", + "param-2": "second param", + }, + instance_config=2, + ) + + +@pytest.mark.parametrize( + "input_data, input_data_configs", + ( + ( + "local/prefix", + [ + { + "channelName": "input", + "dataSource": { + "s3DataSource": { + "s3Uri": "s3://default-bucket-name/jobs/job-name/data/input/prefix", + }, + }, + } + ], + ), + ( + "s3://my-bucket/my/prefix-", + [ + { + "channelName": "input", + "dataSource": { + "s3DataSource": { + "s3Uri": "s3://my-bucket/my/prefix-", + }, + }, + } + ], + ), + ( + S3DataSourceConfig( + "s3://my-bucket/my/manifest.json", + content_type="text/csv", + ), + [ + { + "channelName": "input", + "dataSource": { + "s3DataSource": { + "s3Uri": "s3://my-bucket/my/manifest.json", + }, + }, + "contentType": "text/csv", + } + ], + ), + ( + { + "local-input": "local/prefix", + "s3-input": "s3://my-bucket/my/prefix-", + "config-input": S3DataSourceConfig( + "s3://my-bucket/my/manifest.json", + ), + }, + [ + { + "channelName": "local-input", + "dataSource": { + "s3DataSource": { + "s3Uri": "s3://default-bucket-name/jobs/job-name/" + "data/local-input/prefix", + }, + }, + }, + { + "channelName": "s3-input", + "dataSource": { + "s3DataSource": { + "s3Uri": "s3://my-bucket/my/prefix-", + }, + }, + }, + { + "channelName": "config-input", + "dataSource": { + "s3DataSource": { + "s3Uri": "s3://my-bucket/my/manifest.json", + }, + }, + }, + ], + ), + ), +) +def test_process_input_data(aws_session, input_data, input_data_configs): + job_name = "job-name" + assert _process_input_data(input_data, job_name, aws_session) == input_data_configs diff --git a/test/unit_tests/braket/jobs/test_serialization.py b/test/unit_tests/braket/jobs/test_serialization.py new file mode 100644 index 000000000..bbd1c6238 --- /dev/null +++ b/test/unit_tests/braket/jobs/test_serialization.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ + +import pytest + +from braket.jobs.serialization import deserialize_values, serialize_values +from braket.jobs_data import PersistedJobDataFormat + + +@pytest.mark.parametrize( + "data_format, submitted_data, expected_serialized_data", + [ + ( + PersistedJobDataFormat.PLAINTEXT, + {"converged": True, "energy": -0.2}, + {"converged": True, "energy": -0.2}, + ), + ( + PersistedJobDataFormat.PICKLED_V4, + {"converged": True, "energy": -0.2}, + {"converged": "gASILg==\n", "energy": "gASVCgAAAAAAAABHv8mZmZmZmZou\n"}, + ), + ], +) +def test_job_serialize_data(data_format, submitted_data, expected_serialized_data): + serialized_data = serialize_values(submitted_data, data_format) + assert serialized_data == expected_serialized_data + + +@pytest.mark.parametrize( + "data_format, submitted_data, expected_deserialized_data", + [ + ( + PersistedJobDataFormat.PLAINTEXT, + {"converged": True, "energy": -0.2}, + {"converged": True, "energy": -0.2}, + ), + ( + PersistedJobDataFormat.PICKLED_V4, + {"converged": "gASILg==\n", "energy": "gASVCgAAAAAAAABHv8mZmZmZmZou\n"}, + {"converged": True, "energy": -0.2}, + ), + ], +) +def test_job_deserialize_data(data_format, submitted_data, expected_deserialized_data): + deserialized_data = deserialize_values(submitted_data, data_format) + assert deserialized_data == expected_deserialized_data