fix: doc fixes (#748)
krneta authored Oct 18, 2023
1 parent 4cc0eba commit f4957ec
Showing 5 changed files with 52 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/braket/jobs/config.py
@@ -63,7 +63,7 @@ class S3DataSourceConfig:
def __init__(
self,
s3_data: str,
- content_type: str | None = None,
+ content_type: str = None,
):
"""Create a definition for input data used by a Braket Hybrid job.
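For orientation, here is a minimal sketch of how `S3DataSourceConfig` might be constructed with the signature shown above; the bucket, prefix, and content type are hypothetical placeholders, not part of this commit.

    from braket.jobs.config import S3DataSourceConfig

    # Point a named input channel at an S3 prefix; content_type remains optional.
    input_config = S3DataSourceConfig(
        s3_data="s3://example-braket-bucket/jobs/my-job/input/",  # placeholder bucket/prefix
        content_type="text/csv",  # illustrative; omit to leave unset
    )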
6 changes: 2 additions & 4 deletions src/braket/jobs/data_persistence.py
@@ -65,9 +65,7 @@ def save_job_checkpoint(
f.write(persisted_data.json())


- def load_job_checkpoint(
-     job_name: str | None = None, checkpoint_file_suffix: str = ""
- ) -> dict[str, Any]:
+ def load_job_checkpoint(job_name: str = None, checkpoint_file_suffix: str = "") -> dict[str, Any]:
"""
Loads the job checkpoint data stored for the job named 'job_name', with the checkpoint
file that ends with the `checkpoint_file_suffix`. The `job_name` can refer to any job whose
@@ -80,7 +78,7 @@ def load_job_checkpoint(
Args:
- job_name (str | None): str that specifies the name of the job whose checkpoints
+ job_name (str): str that specifies the name of the job whose checkpoints
are to be loaded. Default: current job name.
checkpoint_file_suffix (str): str specifying the file suffix that is used to
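As a rough illustration of the checkpoint helpers defined in this module (assuming they are called from inside a running hybrid job; the payload and suffix below are made up, and `save_job_checkpoint` is assumed to take the checkpoint dict first with the file suffix as a keyword):

    from braket.jobs.data_persistence import load_job_checkpoint, save_job_checkpoint

    # Persist intermediate state from within a hybrid job...
    save_job_checkpoint({"iteration": 10, "theta": [0.1, 0.2]}, checkpoint_file_suffix="iter10")

    # ...and restore it later; job_name defaults to the current job's name.
    checkpoint = load_job_checkpoint(checkpoint_file_suffix="iter10")
    print(checkpoint["iteration"])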
78 changes: 43 additions & 35 deletions src/braket/jobs/hybrid_job.py
@@ -25,7 +25,7 @@
from logging import Logger, getLogger
from pathlib import Path
from types import ModuleType
- from typing import Any
+ from typing import Any, Dict, List

import cloudpickle

@@ -45,22 +45,22 @@

def hybrid_job(
*,
- device: str | None,
+ device: str,
include_modules: str | ModuleType | Iterable[str | ModuleType] = None,
- dependencies: str | Path | None = None,
+ dependencies: str | Path = None,
local: bool = False,
- job_name: str | None = None,
- image_uri: str | None = None,
+ job_name: str = None,
+ image_uri: str = None,
input_data: str | dict | S3DataSourceConfig = None,
wait_until_complete: bool = False,
- instance_config: InstanceConfig | None = None,
- distribution: str | None = None,
- copy_checkpoints_from_job: str | None = None,
- checkpoint_config: CheckpointConfig | None = None,
- role_arn: str | None = None,
- stopping_condition: StoppingCondition | None = None,
- output_data_config: OutputDataConfig | None = None,
- aws_session: AwsSession | None = None,
+ instance_config: InstanceConfig = None,
+ distribution: str = None,
+ copy_checkpoints_from_job: str = None,
+ checkpoint_config: CheckpointConfig = None,
+ role_arn: str = None,
+ stopping_condition: StoppingCondition = None,
+ output_data_config: OutputDataConfig = None,
+ aws_session: AwsSession = None,
tags: dict[str, str] = None,
logger: Logger = getLogger(__name__),
) -> Callable:
@@ -73,7 +73,7 @@ def hybrid_job(
`copy_checkpoints_from_job`, `stopping_condition`, `tags`, and `logger`.
Args:
- device (str | None): Device ARN of the QPU device that receives priority quantum
+ device (str): Device ARN of the QPU device that receives priority quantum
task queueing once the hybrid job begins running. Each QPU has a separate hybrid jobs
queue so that only one hybrid job is running at a time. The device string is accessible
in the hybrid job instance as the environment variable "AMZN_BRAKET_DEVICE_ARN".
@@ -85,20 +85,20 @@
modules to be included. Any references to members of these modules in the hybrid job
algorithm code will be serialized as part of the algorithm code. Default value `[]`
- dependencies (str | Path | None): Path (absolute or relative) to a requirements.txt
+ dependencies (str | Path): Path (absolute or relative) to a requirements.txt
file to be used for the hybrid job.
local (bool): Whether to use local mode for the hybrid job. Default `False`
- job_name (str | None): A string that specifies the name with which the job is created.
+ job_name (str): A string that specifies the name with which the job is created.
Allowed pattern for job name: `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,50}$`. Defaults to
f'{decorated-function-name}-{timestamp}'.
- image_uri (str | None): A str that specifies the ECR image to use for executing the job.
+ image_uri (str): A str that specifies the ECR image to use for executing the job.
`retrieve_image()` function may be used for retrieving the ECR image URIs
for the containers supported by Braket. Default = `<Braket base image_uri>`.
- input_data (str | Dict | S3DataSourceConfig): Information about the training
+ input_data (str | dict | S3DataSourceConfig): Information about the training
data. Dictionary maps channel names to local paths or S3 URIs. Contents found
at any local paths will be uploaded to S3 at
f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}. If a local
Expand All @@ -110,51 +110,59 @@ def hybrid_job(
This would tail the job logs as it waits. Otherwise `False`. Ignored if using
local mode. Default: `False`.
- instance_config (InstanceConfig | None): Configuration of the instance(s) for running the
+ instance_config (InstanceConfig): Configuration of the instance(s) for running the
classical code for the hybrid job. Defaults to
`InstanceConfig(instanceType='ml.m5.large', instanceCount=1, volumeSizeInGB=30)`.
- distribution (str | None): A str that specifies how the job should be distributed.
+ distribution (str): A str that specifies how the job should be distributed.
If set to "data_parallel", the hyperparameters for the job will be set to use data
parallelism features for PyTorch or TensorFlow. Default: None.
- copy_checkpoints_from_job (str | None): A str that specifies the job ARN whose
+ copy_checkpoints_from_job (str): A str that specifies the job ARN whose
checkpoint you want to use in the current job. Specifying this value will copy
over the checkpoint data from `use_checkpoints_from_job`'s checkpoint_config
s3Uri to the current job's checkpoint_config s3Uri, making it available at
checkpoint_config.localPath during the job execution. Default: None
- checkpoint_config (CheckpointConfig | None): Configuration that specifies the
+ checkpoint_config (CheckpointConfig): Configuration that specifies the
location where checkpoint data is stored.
Default: CheckpointConfig(localPath='/opt/jobs/checkpoints',
s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints').
- role_arn (str | None): A str providing the IAM role ARN used to execute the
+ role_arn (str): A str providing the IAM role ARN used to execute the
script. Default: IAM role returned by AwsSession's `get_default_jobs_role()`.
- stopping_condition (StoppingCondition | None): The maximum length of time, in seconds,
+ stopping_condition (StoppingCondition): The maximum length of time, in seconds,
and the maximum number of tasks that a job can run before being forcefully stopped.
Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).
- output_data_config (OutputDataConfig | None): Specifies the location for the output of
+ output_data_config (OutputDataConfig): Specifies the location for the output of
the job.
Default: OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data',
kmsKeyId=None).
- aws_session (AwsSession | None): AwsSession for connecting to AWS Services.
+ aws_session (AwsSession): AwsSession for connecting to AWS Services.
Default: AwsSession()
- tags (dict[str, str] | None): Dict specifying the key-value pairs for tagging this job.
+ tags (dict[str, str]): Dict specifying the key-value pairs for tagging this job.
Default: {}.
logger (Logger): Logger object with which to write logs, such as task statuses
while waiting for task to be in a terminal state. Default is `getLogger(__name__)`
Returns:
Callable: the callable for creating a Hybrid Job.
"""
_validate_python_version(image_uri, aws_session)

- def _hybrid_job(entry_point):
+ def _hybrid_job(entry_point: Callable) -> Callable:
@functools.wraps(entry_point)
- def job_wrapper(*args, **kwargs):
+ def job_wrapper(*args, **kwargs) -> Callable:
+ """
+ The job wrapper.
+ Returns:
+     Callable: the callable for creating a Hybrid Job.
+ """
with _IncludeModules(include_modules), tempfile.TemporaryDirectory(
dir="", prefix="decorator_job_"
) as temp_dir:
@@ -208,7 +216,7 @@ def job_wrapper(*args, **kwargs):
return _hybrid_job


- def _validate_python_version(image_uri: str | None, aws_session: AwsSession | None = None):
+ def _validate_python_version(image_uri: str | None, aws_session: AwsSession | None = None) -> None:
"""Validate python version at job definition time"""
aws_session = aws_session or AwsSession()
# user provides a custom image_uri
@@ -257,7 +265,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def _serialize_entry_point(entry_point: Callable, args: tuple, kwargs: dict) -> str:
"""Create an entry point from a function"""

- def wrapped_entry_point():
+ def wrapped_entry_point() -> Any:
"""Partial function wrapping entry point with given parameters"""
return entry_point(*args, **kwargs)

@@ -277,7 +285,7 @@ def wrapped_entry_point():
)


- def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict):
+ def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> Dict:
"""Capture function arguments as hyperparameters"""
signature = inspect.signature(entry_point)
bound_args = signature.bind(*args, **kwargs)
@@ -322,7 +330,7 @@ def _sanitize(hyperparameter: Any) -> str:
return sanitized


- def _process_input_data(input_data):
+ def _process_input_data(input_data: Dict) -> List[str]:
"""
Create symlinks to data
@@ -336,12 +344,12 @@ def _process_input_data(input_data):
if not isinstance(input_data, dict):
input_data = {"input": input_data}

- def matches(prefix):
+ def matches(prefix: str) -> List[str]:
return [
str(path) for path in Path(prefix).parent.iterdir() if str(path).startswith(str(prefix))
]

- def is_prefix(path):
+ def is_prefix(path: str) -> bool:
return len(matches(path)) > 1 or not Path(path).exists()

prefix_channels = set()
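To make the decorator signature above concrete, a hedged usage sketch follows; the device ARN, entry-point body, and shot count are illustrative placeholders rather than part of this commit, and the entry point simply runs a Bell circuit on the chosen device.

    from braket.aws import AwsDevice
    from braket.circuits import Circuit
    from braket.jobs import hybrid_job

    DEVICE_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/sv1"  # placeholder device ARN

    @hybrid_job(device=DEVICE_ARN, wait_until_complete=False)
    def run_experiment(shots: int = 100):
        # Runs inside the job container; AMZN_BRAKET_DEVICE_ARN is populated from `device`.
        device = AwsDevice(DEVICE_ARN)
        bell = Circuit().h(0).cnot(0, 1)
        return device.run(bell, shots=shots).result().measurement_counts

    job = run_experiment(shots=1000)  # submits the hybrid job and returns a job handle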
4 changes: 2 additions & 2 deletions src/braket/jobs/image_uris.py
@@ -15,7 +15,7 @@
import os
from enum import Enum
from functools import cache
- from typing import Dict
+ from typing import Dict, Set


class Framework(str, Enum):
@@ -26,7 +26,7 @@ class Framework(str, Enum):
PL_PYTORCH = "PL_PYTORCH"


- def built_in_images(region):
+ def built_in_images(region: str) -> Set[str]:
return {retrieve_image(framework, region) for framework in Framework}


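Since `built_in_images` above simply maps `retrieve_image` over every `Framework` member, a small sketch of the underlying helper may help; the region is an arbitrary example and the exact URI returned depends on the framework and region.

    from braket.jobs.image_uris import Framework, retrieve_image

    # Resolve the Braket-managed container image for a framework in a given region.
    uri = retrieve_image(Framework.PL_PYTORCH, "us-west-2")
    print(uri)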
8 changes: 4 additions & 4 deletions src/braket/jobs/quantum_job_creation.py
@@ -232,13 +232,13 @@ def prepare_quantum_job(
return create_job_kwargs


- def _generate_default_job_name(image_uri: str | None = None, func: Callable | None = None) -> str:
+ def _generate_default_job_name(image_uri: str = None, func: Callable = None) -> str:
"""
Generate default job name using the image uri and entrypoint function.
Args:
- image_uri (str | None): URI for the image container.
- func (Callable | None): The entry point function.
+ image_uri (str): URI for the image container.
+ func (Callable): The entry point function.
Returns:
str: Hybrid job name.
@@ -361,7 +361,7 @@ def _validate_params(dict_arr: dict[str, tuple[any, any]]) -> None:
Validate that config parameters are of the right type.
Args:
- dict_arr (Dict[str, Tuple[any, any]]): dict mapping parameter names to
+ dict_arr (dict[str, tuple[any, any]]): dict mapping parameter names to
a tuple containing the provided value and expected type.
"""
for parameter_name, value_tuple in dict_arr.items():
