Skip to content

Commit

Permalink
fix: only call aws_session.bucket() when s3_destination_folder is not…
Browse files Browse the repository at this point in the history
… prov… (#37)

fix: only call aws_session.bucket() when s3_destination_folder is not provided
  • Loading branch information
ajberdy authored Aug 27, 2021
1 parent 0fb288c commit 5b06e24
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/braket/aws/aws_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,15 @@ def run(
See Also:
`braket.aws.aws_quantum_task.AwsQuantumTask.create()`
"""
default_s3_location = (
os.environ.get("AMZN_BRAKET_OUT_S3_BUCKET") or self._aws_session.default_bucket(),
os.environ.get("AMZN_BRAKET_TASK_RESULTS_S3_PATH") or "tasks",
)
return AwsQuantumTask.create(
self._aws_session,
self._arn,
task_specification,
s3_destination_folder or default_s3_location,
s3_destination_folder
or (
os.environ.get("AMZN_BRAKET_OUT_S3_BUCKET") or self._aws_session.default_bucket(),
os.environ.get("AMZN_BRAKET_TASK_RESULTS_S3_PATH") or "tasks",
),
shots if shots is not None else self._default_shots,
poll_timeout_seconds=poll_timeout_seconds,
poll_interval_seconds=poll_interval_seconds,
Expand Down Expand Up @@ -200,15 +200,15 @@ def run_batch(
See Also:
`braket.aws.aws_quantum_task_batch.AwsQuantumTaskBatch`
"""
default_s3_location = (
os.environ.get("AMZN_BRAKET_OUT_S3_BUCKET") or self._aws_session.default_bucket(),
os.environ.get("AMZN_BRAKET_TASK_RESULTS_S3_PATH") or "tasks",
)
return AwsQuantumTaskBatch(
AwsDevice._copy_aws_session(self._aws_session, max_connections=max_connections),
self._arn,
task_specifications,
s3_destination_folder or default_s3_location,
s3_destination_folder
or (
os.environ.get("AMZN_BRAKET_OUT_S3_BUCKET") or self._aws_session.default_bucket(),
os.environ.get("AMZN_BRAKET_TASK_RESULTS_S3_PATH") or "tasks",
),
shots if shots is not None else self._default_shots,
max_parallel=max_parallel if max_parallel is not None else self._default_max_parallel,
max_workers=max_connections,
Expand Down
21 changes: 21 additions & 0 deletions test/unit_tests/braket/aws/test_aws_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,27 @@ def test_run_with_qpu_no_shots(aws_quantum_task_mock, device, circuit, s3_destin
)


@patch("braket.aws.aws_quantum_task.AwsQuantumTask.create")
def test_default_bucket_not_called(aws_quantum_task_mock, device, circuit, s3_destination_folder):
device = device(RIGETTI_ARN)
run_and_assert(
aws_quantum_task_mock,
device,
MOCK_DEFAULT_S3_DESTINATION_FOLDER,
AwsDevice.DEFAULT_SHOTS_QPU,
AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
circuit,
s3_destination_folder,
None,
None,
None,
None,
None,
)
device._aws_session.default_bucket.assert_not_called()


@patch("braket.aws.aws_quantum_task.AwsQuantumTask.create")
def test_run_with_shots_poll_timeout_kwargs(
aws_quantum_task_mock, device, circuit, s3_destination_folder
Expand Down

0 comments on commit 5b06e24

Please sign in to comment.