Skip to content

Commit

Permalink
feature: Implement .cancel() for AwsQuantumJob (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijc authored Jul 20, 2021
1 parent 72b16ee commit cdc53e2
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/braket/aws/aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,12 @@ def cancel(self) -> str:
Returns:
str: Representing the status of the job.
Raises:
ClientError: If there are errors invoking the CancelJob API.
"""
cancellation_response = self._aws_session.cancel_job(self._arn)
return cancellation_response["cancellationStatus"]

def result(self) -> Dict[str, Any]:
"""Retrieves the job result persisted using save_job_result() function.
Expand Down
12 changes: 12 additions & 0 deletions src/braket/aws/aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,18 @@ def get_job(self, arn: str) -> Dict[str, Any]:
"""
return self.braket_client.get_job(jobArn=arn)

def cancel_job(self, arn: str) -> Dict[str, Any]:
"""
Cancel the quantum job.
Args:
arn (str): The ARN of the quantum job to cancel.
Returns:
Dict[str, Any]: The response from the Amazon Braket `CancelJob` operation.
"""
return self.braket_client.cancel_job(jobArn=arn)

def get_execution_role(self, aws_session):
"""Return the role ARN whose credentials are used to call the API.
Throws an exception if role doesn't exist.
Expand Down
41 changes: 41 additions & 0 deletions test/unit_tests/braket/aws/test_aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from unittest.mock import Mock, patch

import pytest
from botocore.exceptions import ClientError

from braket.aws import AwsQuantumJob

Expand Down Expand Up @@ -83,6 +84,23 @@ def _get_job_response(**kwargs):
return _get_job_response


@pytest.fixture
def generate_cancel_job_response():
def _cancel_job_response(**kwargs):
response = {
"ResponseMetadata": {
"RequestId": "857b0893-2073-4ad6-b828-744af8400dfe",
"HTTPStatusCode": 200,
},
"cancellationStatus": "CANCELLING",
"jobArn": "arn:aws:braket:us-west-2:875981177017:job/job-test-20210628140446",
}
response.update(kwargs)
return response

return _cancel_job_response


@pytest.fixture
def quantum_job_name():
return "job-test-20210625121626"
Expand Down Expand Up @@ -215,3 +233,26 @@ def test_state_caching(quantum_job, aws_session, generate_get_job_response, quan
assert quantum_job.state(True) == state_1
aws_session.get_job.assert_called_with(quantum_job_arn)
assert aws_session.get_job.call_count == 1


def test_cancel_job(quantum_job_arn, aws_session, generate_cancel_job_response):
cancellation_status = "CANCELLING"
aws_session.cancel_job.return_value = generate_cancel_job_response(
cancellationStatus=cancellation_status
)
quantum_job = AwsQuantumJob(quantum_job_arn, aws_session)
status = quantum_job.cancel()
aws_session.cancel_job.assert_called_with(quantum_job_arn)
assert status == cancellation_status


@pytest.mark.xfail(raises=ClientError)
def test_cancel_job_surfaces_exception(quantum_job, aws_session):
exception_response = {
"Error": {
"Code": "ValidationException",
"Message": "unit-test-error",
}
}
aws_session.cancel_job.side_effect = ClientError(exception_response, "cancel_job")
quantum_job.cancel()
46 changes: 46 additions & 0 deletions test/unit_tests/braket/aws/test_aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,52 @@ def test_get_job_does_not_retry_other_exceptions(aws_session):
assert aws_session.braket_client.get_job.call_count == 1


def test_cancel_job(aws_session):
arn = "arn:aws:braket:us-west-2:1234567890:job/job-name"
cancel_job_response = {
"ResponseMetadata": {
"RequestId": "857b0893-2073-4ad6-b828-744af8400dfe",
"HTTPStatusCode": 200,
},
"cancellationStatus": "CANCELLING",
"jobArn": "arn:aws:braket:us-west-2:1234567890:job/job-name",
}
aws_session.braket_client.cancel_job.return_value = cancel_job_response

assert aws_session.cancel_job(arn) == cancel_job_response
aws_session.braket_client.cancel_job.assert_called_with(jobArn=arn)


@pytest.mark.parametrize(
"exception_type",
[
"ResourceNotFoundException",
"ValidationException",
"AccessDeniedException",
"ThrottlingException",
"InternalServiceException",
"ConflictException",
],
)
def test_cancel_job_surfaces_errors(exception_type, aws_session):
arn = "arn:aws:braket:us-west-2:1234567890:job/job-name"
exception_response = {
"Error": {
"Code": "SomeOtherException",
"Message": "unit-test-error",
}
}

aws_session.braket_client.cancel_job.side_effect = [
ClientError(exception_response, "unit-test"),
]

with pytest.raises(ClientError):
aws_session.cancel_job(arn)
aws_session.braket_client.cancel_job.assert_called_with(jobArn=arn)
assert aws_session.braket_client.cancel_job.call_count == 1


@pytest.mark.parametrize(
"input,output",
[
Expand Down

0 comments on commit cdc53e2

Please sign in to comment.