remove api fields (#129)
ajberdy authored Nov 5, 2021
1 parent 4dd1b8a commit d53b858
Showing 7 changed files with 15 additions and 105 deletions.
10 changes: 2 additions & 8 deletions src/braket/aws/aws_quantum_job.py
@@ -32,7 +32,6 @@
OutputDataConfig,
S3DataSourceConfig,
StoppingCondition,
VpcConfig,
)
from braket.jobs.metrics_data.cwl_insights_metrics_fetcher import CwlInsightsMetricsFetcher

@@ -77,7 +76,6 @@ def create(
output_data_config: OutputDataConfig = None,
copy_checkpoints_from_job: str = None,
checkpoint_config: CheckpointConfig = None,
vpc_config: VpcConfig = None,
aws_session: AwsSession = None,
) -> AwsQuantumJob:
"""Creates a job by invoking the Braket CreateJob API.
@@ -150,10 +148,7 @@ def create(
checkpoint_config (CheckpointConfig): Configuration that specifies the location where
checkpoint data is stored.
Default: CheckpointConfig(localPath='/opt/jobs/checkpoints',
s3Uri=None).
vpc_config (VpcConfig): Configuration that specifies the security groups and subnets
to use for running the job. Default: None.
s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints').
aws_session (AwsSession): AwsSession for connecting to AWS Services.
Default: AwsSession()
@@ -181,7 +176,6 @@ def create(
output_data_config=output_data_config,
copy_checkpoints_from_job=copy_checkpoints_from_job,
checkpoint_config=checkpoint_config,
vpc_config=vpc_config,
aws_session=aws_session,
)

@@ -314,7 +308,7 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None:
stream_prefix = f"{self.name}/"
stream_names = [] # The list of log streams
positions = {} # The current position in each stream, map of stream name -> position
instance_count = self.metadata(use_cached_value=True)["instanceConfig"]["instanceCount"]
instance_count = 1 # currently only support a single instance
has_streams = False
color_wrap = logs.ColorWrap()

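The practical effect of this file's changes: `AwsQuantumJob.create` no longer accepts a `vpc_config` argument, the documented default checkpoint location is derived from the default bucket and job name instead of being `None`, and `logs()` assumes a single instance rather than reading `instanceCount` from the job metadata. A minimal sketch of a post-change call follows; the parameter names for the device and source arguments, and the device ARN, are assumptions since they do not appear in this diff.

```python
from braket.aws.aws_quantum_job import AwsQuantumJob

# Sketch only: the "device" and "source_module" parameter names are assumed,
# and the device ARN is a placeholder; neither is shown in this diff.
job = AwsQuantumJob.create(
    device="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
    source_module="algorithm_script",
    # checkpoint_config now defaults to
    # CheckpointConfig(localPath='/opt/jobs/checkpoints',
    #                  s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints')
    # vpc_config=VpcConfig(...)  # no longer a valid keyword; removed by this commit
)
job.logs()  # log streaming now assumes a single instance instead of reading instanceCount
```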
1 change: 0 additions & 1 deletion src/braket/jobs/__init__.py
@@ -17,7 +17,6 @@
OutputDataConfig,
S3DataSourceConfig,
StoppingCondition,
VpcConfig,
)
from braket.jobs.data_persistence import ( # noqa: F401
load_job_checkpoint,
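Because `VpcConfig` is dropped from this package-level import block, any code that pulled it in via `braket.jobs` must be updated. Assuming the surrounding imports keep their `# noqa: F401` re-export role, the remaining names stay importable, as in this sketch:

```python
# Still available after this commit:
from braket.jobs import OutputDataConfig, S3DataSourceConfig, StoppingCondition

# No longer available; VpcConfig was deleted from braket.jobs.config as well:
# from braket.jobs import VpcConfig  # would now raise ImportError
```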
38 changes: 3 additions & 35 deletions src/braket/jobs/config.py
@@ -11,8 +11,8 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from dataclasses import dataclass, field
from typing import List, Optional
from dataclasses import dataclass
from typing import Optional


@dataclass
@@ -28,9 +28,7 @@ class InstanceConfig:
"""Configuration of the instances used to execute the job."""

instanceType: str = "ml.m5.large"
instanceCount: int = 1
volumeSizeInGb: int = 30
volumeKmsKeyId = None


@dataclass
@@ -48,20 +46,9 @@ class StoppingCondition:
maxRuntimeInSeconds: int = 5 * 24 * 60 * 60


@dataclass
class VpcConfig:
# TODO: Add securityGroupIds and subnets default values here.
# TODO: Ensure that length of the list for securityGroupIds is between 1 and 5
# and for subnets between 1 and 16.
"""Configuration that specifies the security groups and subnets to use for running the job."""

securityGroupIds: List[str]
subnets: List[str]


@dataclass
class DeviceConfig:
devices: List[str] = field(default_factory=list)
device: str


class S3DataSourceConfig:
@@ -71,40 +58,21 @@ class S3DataSourceConfig:
config (dict[str, dict]): config passed to the Braket API
"""

class DistributionType:
FULLY_REPLICATED = "FULLY_REPLICATED"
SHARDED_BY_S3_KEY = "SHARDED_BY_S3_KEY"

class S3DataType:
S3_PREFIX = "S3_PREFIX"
MANIFEST_FILE = "MANIFEST_FILE"

def __init__(
self,
s3_data,
distribution=DistributionType.FULLY_REPLICATED,
content_type=None,
s3_data_type=S3DataType.S3_PREFIX,
):
"""Create a definition for input data used by a Braket job.
Args:
s3_data (str): Defines the location of s3 data to train on.
distribution (str): Valid values: 'FullyReplicated', 'ShardedByS3Key'
(default: 'FullyReplicated').
content_type (str): MIME type of the input data (default: None).
s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile'.
If 'S3Prefix', ``s3_data`` defines a prefix of s3 objects to train on.
All objects with s3 keys beginning with ``s3_data`` will be used to train.
If 'ManifestFile', then ``s3_data`` defines a single S3 manifest file,
listing the S3 data to train on.
"""
self.config = {
"dataSource": {
"s3DataSource": {
"s3DataType": s3_data_type,
"s3Uri": s3_data,
"s3DataDistributionType": distribution,
}
}
}
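Taken together, the `config.py` changes shrink the configuration surface: `InstanceConfig` loses `instanceCount` and `volumeKmsKeyId`, the `VpcConfig` dataclass is deleted, `DeviceConfig` holds a single device ARN instead of a list, and `S3DataSourceConfig` keeps only the S3 location (the distribution, content-type, and data-type options and their nested keys are gone). A sketch of the simplified classes in use, with a hypothetical bucket, prefix, and device ARN:

```python
from braket.jobs.config import DeviceConfig, InstanceConfig, S3DataSourceConfig

# One instance type and volume size; no instance count or KMS key after this commit.
instance_config = InstanceConfig(instanceType="ml.m5.xlarge", volumeSizeInGb=50)

# A single device ARN string replaces the old devices=[...] list (ARN is a placeholder).
device_config = DeviceConfig(device="arn:aws:braket:::device/quantum-simulator/amazon/sv1")

# Only the S3 URI remains configurable; the resulting config nests just s3Uri.
input_config = S3DataSourceConfig("s3://example-bucket/jobs/input-data")
print(input_config.config)
# {'dataSource': {'s3DataSource': {'s3Uri': 's3://example-bucket/jobs/input-data'}}}
```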
2 changes: 1 addition & 1 deletion src/braket/jobs/local/local_job.py
@@ -100,7 +100,7 @@ def create(
checkpoint_config (CheckpointConfig): Configuration that specifies the location where
checkpoint data is stored.
Default: CheckpointConfig(localPath='/opt/jobs/checkpoints',
s3Uri=None).
s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints').
aws_session (AwsSession): AwsSession for connecting to AWS Services.
Default: AwsSession()
19 changes: 5 additions & 14 deletions src/braket/jobs/quantum_job_creation.py
@@ -20,7 +20,7 @@
import time
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from braket.aws.aws_session import AwsSession
from braket.jobs.config import (
@@ -30,7 +30,6 @@
OutputDataConfig,
S3DataSourceConfig,
StoppingCondition,
VpcConfig,
)


@@ -49,7 +48,6 @@ def prepare_quantum_job(
output_data_config: OutputDataConfig = None,
copy_checkpoints_from_job: str = None,
checkpoint_config: CheckpointConfig = None,
vpc_config: VpcConfig = None,
aws_session: AwsSession = None,
):
"""Creates a job by invoking the Braket CreateJob API.
@@ -116,10 +114,7 @@
checkpoint_config (CheckpointConfig): Configuration that specifies the location where
checkpoint data is stored.
Default: CheckpointConfig(localPath='/opt/jobs/checkpoints',
s3Uri=None).
vpc_config (VpcConfig): Configuration that specifies the security groups and subnets
to use for running the job. Default: None.
s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints').
aws_session (AwsSession): AwsSession for connecting to AWS Services.
Default: AwsSession()
Expand All @@ -135,12 +130,11 @@ def prepare_quantum_job(
"stopping_condition": (stopping_condition, StoppingCondition),
"output_data_config": (output_data_config, OutputDataConfig),
"checkpoint_config": (checkpoint_config, CheckpointConfig),
"vpc_config": (vpc_config, VpcConfig),
}

_validate_params(param_datatype_map)
aws_session = aws_session or AwsSession()
device_config = DeviceConfig(devices=[device_arn])
device_config = DeviceConfig(device_arn)
job_name = job_name or _generate_default_job_name(image_uri)
role_arn = role_arn or aws_session.get_execution_role()
hyperparameters = hyperparameters or {}
@@ -206,17 +200,14 @@
"stoppingCondition": asdict(stopping_condition),
}

if vpc_config:
create_job_kwargs["vpcConfig"] = asdict(vpc_config)

return create_job_kwargs


def _generate_default_job_name(image_uri: str) -> str:
def _generate_default_job_name(image_uri: Optional[str]) -> str:
"""
Generate default job name using the image uri and a timestamp
Args:
image_uri (str): URI for the image container.
image_uri (str, optional): URI for the image container.
Returns:
str: Job name.
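In `prepare_quantum_job`, the `vpc_config` parameter, its entry in the validation map, and the `vpcConfig` key in `create_job_kwargs` are all removed; `DeviceConfig` is now built from the single device ARN; and `_generate_default_job_name` is typed to accept an optional `image_uri`. A small before/after sketch of the `DeviceConfig` construction, with a placeholder ARN:

```python
from dataclasses import asdict

from braket.jobs.config import DeviceConfig

device_arn = "arn:aws:braket:::device/quantum-simulator/amazon/sv1"  # placeholder

# Before this commit: device_config = DeviceConfig(devices=[device_arn])
# After this commit:
device_config = DeviceConfig(device_arn)
print(asdict(device_config))
# {'device': 'arn:aws:braket:::device/quantum-simulator/amazon/sv1'}
```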
3 changes: 0 additions & 3 deletions test/unit_tests/braket/aws/test_aws_quantum_job.py
@@ -68,10 +68,8 @@ def _get_job_response(**kwargs):
"inputDataConfig": [
{
"channelName": "training_input",
"compressionType": "NONE",
"dataSource": {
"s3DataSource": {
"s3DataType": "S3_PREFIX",
"s3Uri": "s3://amazon-braket-jobs/job-path/input",
}
},
@@ -502,7 +500,6 @@ def prepare_job_args(aws_session):
"output_data_config": Mock(),
"copy_checkpoints_from_job": Mock(),
"checkpoint_config": Mock(),
"vpc_config": Mock(),
"aws_session": aws_session,
}

