doc fixes #748

Merged 2 commits on Oct 18, 2023
2 changes: 1 addition & 1 deletion src/braket/jobs/config.py
@@ -63,7 +63,7 @@ class S3DataSourceConfig:
def __init__(
self,
s3_data: str,
content_type: str | None = None,
content_type: str = None,
):
"""Create a definition for input data used by a Braket Hybrid job.

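For context, a minimal sketch of constructing `S3DataSourceConfig` with the signature shown in this hunk; the bucket name, prefix, and content type below are placeholders, not values from this PR:

```python
from braket.jobs.config import S3DataSourceConfig

# Placeholder S3 URI; content_type is optional and defaults to None.
input_config = S3DataSourceConfig(
    s3_data="s3://example-bucket/jobs/my-input/",
    content_type="text/csv",
)
```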
6 changes: 2 additions & 4 deletions src/braket/jobs/data_persistence.py
@@ -65,9 +65,7 @@ def save_job_checkpoint(
f.write(persisted_data.json())


def load_job_checkpoint(
job_name: str | None = None, checkpoint_file_suffix: str = ""
) -> dict[str, Any]:
def load_job_checkpoint(job_name: str = None, checkpoint_file_suffix: str = "") -> dict[str, Any]:
"""
Loads the job checkpoint data stored for the job named 'job_name', with the checkpoint
file that ends with the `checkpoint_file_suffix`. The `job_name` can refer to any job whose
@@ -80,7 +78,7 @@ def load_job_checkpoint(


Args:
job_name (str | None): str that specifies the name of the job whose checkpoints
job_name (str): str that specifies the name of the job whose checkpoints
are to be loaded. Default: current job name.

checkpoint_file_suffix (str): str specifying the file suffix that is used to
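As context for the docstring change, a hedged sketch of calling `load_job_checkpoint` with the signature shown above; the job name and file suffix are illustrative only:

```python
from braket.jobs.data_persistence import load_job_checkpoint

# Illustrative names; when called inside a running job, job_name
# defaults to the current job's name.
checkpoint_data = load_job_checkpoint(
    job_name="my-hybrid-job",
    checkpoint_file_suffix="checkpoint-1",
)
print(checkpoint_data)  # dict[str, Any] with the previously persisted values
```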
78 changes: 43 additions & 35 deletions src/braket/jobs/hybrid_job.py
@@ -25,7 +25,7 @@
from logging import Logger, getLogger
from pathlib import Path
from types import ModuleType
from typing import Any
from typing import Any, Dict, List

import cloudpickle

@@ -45,22 +45,22 @@

def hybrid_job(
*,
device: str | None,
device: str,
include_modules: str | ModuleType | Iterable[str | ModuleType] = None,
dependencies: str | Path | None = None,
dependencies: str | Path = None,
local: bool = False,
job_name: str | None = None,
image_uri: str | None = None,
job_name: str = None,
image_uri: str = None,
input_data: str | dict | S3DataSourceConfig = None,
wait_until_complete: bool = False,
instance_config: InstanceConfig | None = None,
distribution: str | None = None,
copy_checkpoints_from_job: str | None = None,
checkpoint_config: CheckpointConfig | None = None,
role_arn: str | None = None,
stopping_condition: StoppingCondition | None = None,
output_data_config: OutputDataConfig | None = None,
aws_session: AwsSession | None = None,
instance_config: InstanceConfig = None,
distribution: str = None,
copy_checkpoints_from_job: str = None,
checkpoint_config: CheckpointConfig = None,
role_arn: str = None,
stopping_condition: StoppingCondition = None,
output_data_config: OutputDataConfig = None,
aws_session: AwsSession = None,
tags: dict[str, str] = None,
logger: Logger = getLogger(__name__),
) -> Callable:
@@ -73,7 +73,7 @@ def hybrid_job(
`copy_checkpoints_from_job`, `stopping_condition`, `tags`, and `logger`.

Args:
device (str | None): Device ARN of the QPU device that receives priority quantum
device (str): Device ARN of the QPU device that receives priority quantum
task queueing once the hybrid job begins running. Each QPU has a separate hybrid jobs
queue so that only one hybrid job is running at a time. The device string is accessible
in the hybrid job instance as the environment variable "AMZN_BRAKET_DEVICE_ARN".
@@ -85,20 +85,20 @@
modules to be included. Any references to members of these modules in the hybrid job
algorithm code will be serialized as part of the algorithm code. Default value `[]`

dependencies (str | Path | None): Path (absolute or relative) to a requirements.txt
dependencies (str | Path): Path (absolute or relative) to a requirements.txt
file to be used for the hybrid job.

local (bool): Whether to use local mode for the hybrid job. Default `False`

job_name (str | None): A string that specifies the name with which the job is created.
job_name (str): A string that specifies the name with which the job is created.
Allowed pattern for job name: `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,50}$`. Defaults to
f'{decorated-function-name}-{timestamp}'.

image_uri (str | None): A str that specifies the ECR image to use for executing the job.
image_uri (str): A str that specifies the ECR image to use for executing the job.
`retrieve_image()` function may be used for retrieving the ECR image URIs
for the containers supported by Braket. Default = `<Braket base image_uri>`.

input_data (str | Dict | S3DataSourceConfig): Information about the training
input_data (str | dict | S3DataSourceConfig): Information about the training
data. Dictionary maps channel names to local paths or S3 URIs. Contents found
at any local paths will be uploaded to S3 at
f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}. If a local
@@ -110,51 +110,59 @@
This would tail the job logs as it waits. Otherwise `False`. Ignored if using
local mode. Default: `False`.

instance_config (InstanceConfig | None): Configuration of the instance(s) for running the
instance_config (InstanceConfig): Configuration of the instance(s) for running the
classical code for the hybrid job. Defaults to
`InstanceConfig(instanceType='ml.m5.large', instanceCount=1, volumeSizeInGB=30)`.

distribution (str | None): A str that specifies how the job should be distributed.
distribution (str): A str that specifies how the job should be distributed.
If set to "data_parallel", the hyperparameters for the job will be set to use data
parallelism features for PyTorch or TensorFlow. Default: None.

copy_checkpoints_from_job (str | None): A str that specifies the job ARN whose
copy_checkpoints_from_job (str): A str that specifies the job ARN whose
checkpoint you want to use in the current job. Specifying this value will copy
over the checkpoint data from `use_checkpoints_from_job`'s checkpoint_config
s3Uri to the current job's checkpoint_config s3Uri, making it available at
checkpoint_config.localPath during the job execution. Default: None

checkpoint_config (CheckpointConfig | None): Configuration that specifies the
checkpoint_config (CheckpointConfig): Configuration that specifies the
location where checkpoint data is stored.
Default: CheckpointConfig(localPath='/opt/jobs/checkpoints',
s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints').

role_arn (str | None): A str providing the IAM role ARN used to execute the
role_arn (str): A str providing the IAM role ARN used to execute the
script. Default: IAM role returned by AwsSession's `get_default_jobs_role()`.

stopping_condition (StoppingCondition | None): The maximum length of time, in seconds,
stopping_condition (StoppingCondition): The maximum length of time, in seconds,
and the maximum number of tasks that a job can run before being forcefully stopped.
Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).

output_data_config (OutputDataConfig | None): Specifies the location for the output of
output_data_config (OutputDataConfig): Specifies the location for the output of
the job.
Default: OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data',
kmsKeyId=None).

aws_session (AwsSession | None): AwsSession for connecting to AWS Services.
aws_session (AwsSession): AwsSession for connecting to AWS Services.
Default: AwsSession()

tags (dict[str, str] | None): Dict specifying the key-value pairs for tagging this job.
tags (dict[str, str]): Dict specifying the key-value pairs for tagging this job.
Default: {}.

logger (Logger): Logger object with which to write logs, such as task statuses
while waiting for task to be in a terminal state. Default is `getLogger(__name__)`

Returns:
Callable: the callable for creating a Hybrid Job.
"""
_validate_python_version(image_uri, aws_session)

def _hybrid_job(entry_point):
def _hybrid_job(entry_point: Callable) -> Callable:
Contributor comment: Curious to know more about what these changes are

@functools.wraps(entry_point)
def job_wrapper(*args, **kwargs):
def job_wrapper(*args, **kwargs) -> Callable:
"""
The job wrapper.
Returns:
Callable: the callable for creating a Hybrid Job.
"""
with _IncludeModules(include_modules), tempfile.TemporaryDirectory(
dir="", prefix="decorator_job_"
) as temp_dir:
@@ -208,7 +216,7 @@ def job_wrapper(*args, **kwargs):
return _hybrid_job


def _validate_python_version(image_uri: str | None, aws_session: AwsSession | None = None):
def _validate_python_version(image_uri: str | None, aws_session: AwsSession | None = None) -> None:
"""Validate python version at job definition time"""
aws_session = aws_session or AwsSession()
# user provides a custom image_uri
@@ -257,7 +265,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def _serialize_entry_point(entry_point: Callable, args: tuple, kwargs: dict) -> str:
"""Create an entry point from a function"""

def wrapped_entry_point():
def wrapped_entry_point() -> Any:
"""Partial function wrapping entry point with given parameters"""
return entry_point(*args, **kwargs)

@@ -277,7 +285,7 @@ def wrapped_entry_point():
)


def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict):
def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> Dict:
"""Capture function arguments as hyperparameters"""
signature = inspect.signature(entry_point)
bound_args = signature.bind(*args, **kwargs)
@@ -322,7 +330,7 @@ def _sanitize(hyperparameter: Any) -> str:
return sanitized


def _process_input_data(input_data):
def _process_input_data(input_data: Dict) -> List[str]:
"""
Create symlinks to data

@@ -336,12 +344,12 @@
if not isinstance(input_data, dict):
input_data = {"input": input_data}

def matches(prefix):
def matches(prefix: str) -> List[str]:
return [
str(path) for path in Path(prefix).parent.iterdir() if str(path).startswith(str(prefix))
]

def is_prefix(path):
def is_prefix(path: str) -> bool:
return len(matches(path)) > 1 or not Path(path).exists()

prefix_channels = set()
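To tie together the parameters documented in the `hybrid_job` docstring above, here is a minimal usage sketch of the decorator. The simulator ARN is the standard SV1 ARN and the entry-point body is illustrative, not taken from this PR; the environment variable lookup mirrors the docstring's note about `AMZN_BRAKET_DEVICE_ARN`:

```python
import os

from braket.aws import AwsDevice
from braket.circuits import Circuit
from braket.jobs import hybrid_job


@hybrid_job(device="arn:aws:braket:::device/quantum-simulator/amazon/sv1")
def run_bell(shots: int = 100):
    # Inside the job container, the decorator's `device` argument is exposed
    # as the AMZN_BRAKET_DEVICE_ARN environment variable (see docstring above).
    device = AwsDevice(os.environ["AMZN_BRAKET_DEVICE_ARN"])
    bell = Circuit().h(0).cnot(0, 1)
    return device.run(bell, shots=shots).result().measurement_counts


job = run_bell(shots=1000)  # creates the hybrid job and returns the job object
```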
4 changes: 2 additions & 2 deletions src/braket/jobs/image_uris.py
@@ -15,7 +15,7 @@
import os
from enum import Enum
from functools import cache
from typing import Dict
from typing import Dict, Set


class Framework(str, Enum):
@@ -26,7 +26,7 @@ class Framework(str, Enum):
PL_PYTORCH = "PL_PYTORCH"


def built_in_images(region):
def built_in_images(region: str) -> Set[str]:
return {retrieve_image(framework, region) for framework in Framework}


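For reference, a short sketch of how `retrieve_image` and the `Framework` enum from this module are typically used, as `built_in_images` does above; the region string is a placeholder:

```python
from braket.jobs.image_uris import Framework, retrieve_image

# Placeholder region; returns the ECR URI of the corresponding Braket job container.
image_uri = retrieve_image(Framework.PL_PYTORCH, "us-west-2")
```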
8 changes: 4 additions & 4 deletions src/braket/jobs/quantum_job_creation.py
@@ -232,13 +232,13 @@ def prepare_quantum_job(
return create_job_kwargs


def _generate_default_job_name(image_uri: str | None = None, func: Callable | None = None) -> str:
def _generate_default_job_name(image_uri: str = None, func: Callable = None) -> str:
"""
Generate default job name using the image uri and entrypoint function.

Args:
image_uri (str | None): URI for the image container.
func (Callable | None): The entry point function.
image_uri (str): URI for the image container.
func (Callable): The entry point function.

Returns:
str: Hybrid job name.
@@ -361,7 +361,7 @@ def _validate_params(dict_arr: dict[str, tuple[any, any]]) -> None:
Validate that config parameters are of the right type.

Args:
dict_arr (Dict[str, Tuple[any, any]]): dict mapping parameter names to
dict_arr (dict[str, tuple[any, any]]): dict mapping parameter names to
a tuple containing the provided value and expected type.
"""
for parameter_name, value_tuple in dict_arr.items():
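The `_validate_params` docstring describes `dict_arr` as a mapping from parameter names to `(provided value, expected type)` tuples. A self-contained sketch of that shape and the kind of check it implies; the loop is an illustration of the described behavior, not the helper's actual implementation:

```python
# parameter name -> (provided value, expected type), per the docstring above
dict_arr = {
    "job_name": ("my-job", str),
    "wait_until_complete": (True, bool),
}

for name, (value, expected_type) in dict_arr.items():
    if not isinstance(value, expected_type):
        raise ValueError(f"'{name}' must be of type {expected_type.__name__}")
```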