Skip to content

Commit

Permalink
feat: Add support for tags in AwsQuantumJob.create() (#138)
Browse files Browse the repository at this point in the history
* feat: Add support for tags in AwsQuantumJob.create()

* reformat
  • Loading branch information
kshitijc authored Nov 9, 2021
1 parent da7b589 commit 90a9886
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 4 deletions.
5 changes: 5 additions & 0 deletions src/braket/aws/aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def create(
copy_checkpoints_from_job: str = None,
checkpoint_config: CheckpointConfig = None,
aws_session: AwsSession = None,
tags: Dict[str, str] = None,
) -> AwsQuantumJob:
"""Creates a job by invoking the Braket CreateJob API.
Expand Down Expand Up @@ -149,6 +150,9 @@ def create(
aws_session (AwsSession): AwsSession for connecting to AWS Services.
Default: AwsSession()
tags (Dict[str, str]): Dict specifying the key-value pairs for tagging this job.
Default: None.
Returns:
AwsQuantumJob: Job tracking the execution on Amazon Braket.
Expand All @@ -173,6 +177,7 @@ def create(
copy_checkpoints_from_job=copy_checkpoints_from_job,
checkpoint_config=checkpoint_config,
aws_session=aws_session,
tags=tags,
)

job_arn = aws_session.create_job(**create_job_kwargs)
Expand Down
5 changes: 1 addition & 4 deletions src/braket/jobs/metrics_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
# language governing permissions and limitations under the License.

from braket.jobs.metrics_data.cwl_metrics_fetcher import CwlMetricsFetcher # noqa: F401
from braket.jobs.metrics_data.definitions import ( # noqa: F401
MetricPeriod,
MetricStatistic,
)
from braket.jobs.metrics_data.definitions import MetricPeriod, MetricStatistic # noqa: F401
from braket.jobs.metrics_data.exceptions import MetricsRetrievalError # noqa: F401
from braket.jobs.metrics_data.log_metrics_parser import LogMetricsParser # noqa: F401
5 changes: 5 additions & 0 deletions src/braket/jobs/quantum_job_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def prepare_quantum_job(
copy_checkpoints_from_job: str = None,
checkpoint_config: CheckpointConfig = None,
aws_session: AwsSession = None,
tags: Dict[str, str] = None,
):
"""Creates a job by invoking the Braket CreateJob API.
Expand Down Expand Up @@ -119,6 +120,9 @@ def prepare_quantum_job(
aws_session (AwsSession): AwsSession for connecting to AWS Services.
Default: AwsSession()
tags (Dict[str, str]): Dict specifying the key-value pairs for tagging this job.
Default: None.
Returns:
AwsQuantumJob: Job tracking the execution on Amazon Braket.
Expand Down Expand Up @@ -198,6 +202,7 @@ def prepare_quantum_job(
"deviceConfig": asdict(device_config),
"hyperParameters": hyperparameters,
"stoppingCondition": asdict(stopping_condition),
"tags": tags,
}

return create_job_kwargs
Expand Down
1 change: 1 addition & 0 deletions test/unit_tests/braket/aws/test_aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def prepare_job_args(aws_session):
"copy_checkpoints_from_job": Mock(),
"checkpoint_config": Mock(),
"aws_session": aws_session,
"tags": Mock(),
}


Expand Down
9 changes: 9 additions & 0 deletions test/unit_tests/braket/jobs/test_quantum_job_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def bucket():
return "braket-region-id"


@pytest.fixture
def tags():
return {"tag-key": "tag-value"}


@pytest.fixture(
params=[
None,
Expand Down Expand Up @@ -224,6 +229,7 @@ def create_job_args(
stopping_condition,
output_data_config,
checkpoint_config,
tags,
):
if request.param == "fixtures":
return dict(
Expand All @@ -243,6 +249,7 @@ def create_job_args(
"output_data_config": output_data_config,
"checkpoint_config": checkpoint_config,
"aws_session": aws_session,
"tags": tags,
}.items()
if value is not None
)
Expand Down Expand Up @@ -319,6 +326,7 @@ def _translate_creation_args(create_job_args):
}
if image_uri:
algorithm_specification["containerImage"] = {"uri": image_uri}
tags = create_job_args.get("tags")

test_kwargs = {
"jobName": job_name,
Expand All @@ -331,6 +339,7 @@ def _translate_creation_args(create_job_args):
"deviceConfig": {"device": device},
"hyperParameters": hyperparameters,
"stoppingCondition": asdict(stopping_condition),
"tags": tags,
}

return test_kwargs
Expand Down

0 comments on commit 90a9886

Please sign in to comment.