diff --git a/src/braket/jobs/config.py b/src/braket/jobs/config.py index b846a8bf5..432e740e4 100644 --- a/src/braket/jobs/config.py +++ b/src/braket/jobs/config.py @@ -63,7 +63,7 @@ class S3DataSourceConfig: def __init__( self, s3_data: str, - content_type: str | None = None, + content_type: str = None, ): """Create a definition for input data used by a Braket Hybrid job. diff --git a/src/braket/jobs/data_persistence.py b/src/braket/jobs/data_persistence.py index ccd9a47dd..6ef5a6d18 100644 --- a/src/braket/jobs/data_persistence.py +++ b/src/braket/jobs/data_persistence.py @@ -65,9 +65,7 @@ def save_job_checkpoint( f.write(persisted_data.json()) -def load_job_checkpoint( - job_name: str | None = None, checkpoint_file_suffix: str = "" -) -> dict[str, Any]: +def load_job_checkpoint(job_name: str = None, checkpoint_file_suffix: str = "") -> dict[str, Any]: """ Loads the job checkpoint data stored for the job named 'job_name', with the checkpoint file that ends with the `checkpoint_file_suffix`. The `job_name` can refer to any job whose @@ -80,7 +78,7 @@ def load_job_checkpoint( Args: - job_name (str | None): str that specifies the name of the job whose checkpoints + job_name (str): str that specifies the name of the job whose checkpoints are to be loaded. Default: current job name. checkpoint_file_suffix (str): str specifying the file suffix that is used to diff --git a/src/braket/jobs/hybrid_job.py b/src/braket/jobs/hybrid_job.py index 521da7fe8..8de496f5e 100644 --- a/src/braket/jobs/hybrid_job.py +++ b/src/braket/jobs/hybrid_job.py @@ -25,7 +25,7 @@ from logging import Logger, getLogger from pathlib import Path from types import ModuleType -from typing import Any +from typing import Any, Dict, List import cloudpickle @@ -45,22 +45,22 @@ def hybrid_job( *, - device: str | None, + device: str, include_modules: str | ModuleType | Iterable[str | ModuleType] = None, - dependencies: str | Path | None = None, + dependencies: str | Path = None, local: bool = False, - job_name: str | None = None, - image_uri: str | None = None, + job_name: str = None, + image_uri: str = None, input_data: str | dict | S3DataSourceConfig = None, wait_until_complete: bool = False, - instance_config: InstanceConfig | None = None, - distribution: str | None = None, - copy_checkpoints_from_job: str | None = None, - checkpoint_config: CheckpointConfig | None = None, - role_arn: str | None = None, - stopping_condition: StoppingCondition | None = None, - output_data_config: OutputDataConfig | None = None, - aws_session: AwsSession | None = None, + instance_config: InstanceConfig = None, + distribution: str = None, + copy_checkpoints_from_job: str = None, + checkpoint_config: CheckpointConfig = None, + role_arn: str = None, + stopping_condition: StoppingCondition = None, + output_data_config: OutputDataConfig = None, + aws_session: AwsSession = None, tags: dict[str, str] = None, logger: Logger = getLogger(__name__), ) -> Callable: @@ -73,7 +73,7 @@ def hybrid_job( `copy_checkpoints_from_job`, `stopping_condition`, `tags`, and `logger`. Args: - device (str | None): Device ARN of the QPU device that receives priority quantum + device (str): Device ARN of the QPU device that receives priority quantum task queueing once the hybrid job begins running. Each QPU has a separate hybrid jobs queue so that only one hybrid job is running at a time. The device string is accessible in the hybrid job instance as the environment variable "AMZN_BRAKET_DEVICE_ARN". @@ -85,20 +85,20 @@ def hybrid_job( modules to be included. Any references to members of these modules in the hybrid job algorithm code will be serialized as part of the algorithm code. Default value `[]` - dependencies (str | Path | None): Path (absolute or relative) to a requirements.txt + dependencies (str | Path): Path (absolute or relative) to a requirements.txt file to be used for the hybrid job. local (bool): Whether to use local mode for the hybrid job. Default `False` - job_name (str | None): A string that specifies the name with which the job is created. + job_name (str): A string that specifies the name with which the job is created. Allowed pattern for job name: `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,50}$`. Defaults to f'{decorated-function-name}-{timestamp}'. - image_uri (str | None): A str that specifies the ECR image to use for executing the job. + image_uri (str): A str that specifies the ECR image to use for executing the job. `retrieve_image()` function may be used for retrieving the ECR image URIs for the containers supported by Braket. Default = ``. - input_data (str | Dict | S3DataSourceConfig): Information about the training + input_data (str | dict | S3DataSourceConfig): Information about the training data. Dictionary maps channel names to local paths or S3 URIs. Contents found at any local paths will be uploaded to S3 at f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}. If a local @@ -110,51 +110,59 @@ def hybrid_job( This would tail the job logs as it waits. Otherwise `False`. Ignored if using local mode. Default: `False`. - instance_config (InstanceConfig | None): Configuration of the instance(s) for running the + instance_config (InstanceConfig): Configuration of the instance(s) for running the classical code for the hybrid job. Defaults to `InstanceConfig(instanceType='ml.m5.large', instanceCount=1, volumeSizeInGB=30)`. - distribution (str | None): A str that specifies how the job should be distributed. + distribution (str): A str that specifies how the job should be distributed. If set to "data_parallel", the hyperparameters for the job will be set to use data parallelism features for PyTorch or TensorFlow. Default: None. - copy_checkpoints_from_job (str | None): A str that specifies the job ARN whose + copy_checkpoints_from_job (str): A str that specifies the job ARN whose checkpoint you want to use in the current job. Specifying this value will copy over the checkpoint data from `use_checkpoints_from_job`'s checkpoint_config s3Uri to the current job's checkpoint_config s3Uri, making it available at checkpoint_config.localPath during the job execution. Default: None - checkpoint_config (CheckpointConfig | None): Configuration that specifies the + checkpoint_config (CheckpointConfig): Configuration that specifies the location where checkpoint data is stored. Default: CheckpointConfig(localPath='/opt/jobs/checkpoints', s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints'). - role_arn (str | None): A str providing the IAM role ARN used to execute the + role_arn (str): A str providing the IAM role ARN used to execute the script. Default: IAM role returned by AwsSession's `get_default_jobs_role()`. - stopping_condition (StoppingCondition | None): The maximum length of time, in seconds, + stopping_condition (StoppingCondition): The maximum length of time, in seconds, and the maximum number of tasks that a job can run before being forcefully stopped. Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60). - output_data_config (OutputDataConfig | None): Specifies the location for the output of + output_data_config (OutputDataConfig): Specifies the location for the output of the job. Default: OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data', kmsKeyId=None). - aws_session (AwsSession | None): AwsSession for connecting to AWS Services. + aws_session (AwsSession): AwsSession for connecting to AWS Services. Default: AwsSession() - tags (dict[str, str] | None): Dict specifying the key-value pairs for tagging this job. + tags (dict[str, str]): Dict specifying the key-value pairs for tagging this job. Default: {}. logger (Logger): Logger object with which to write logs, such as task statuses while waiting for task to be in a terminal state. Default is `getLogger(__name__)` + + Returns: + Callable: the callable for creating a Hybrid Job. """ _validate_python_version(image_uri, aws_session) - def _hybrid_job(entry_point): + def _hybrid_job(entry_point: Callable) -> Callable: @functools.wraps(entry_point) - def job_wrapper(*args, **kwargs): + def job_wrapper(*args, **kwargs) -> Callable: + """ + The job wrapper. + Returns: + Callable: the callable for creating a Hybrid Job. + """ with _IncludeModules(include_modules), tempfile.TemporaryDirectory( dir="", prefix="decorator_job_" ) as temp_dir: @@ -208,7 +216,7 @@ def job_wrapper(*args, **kwargs): return _hybrid_job -def _validate_python_version(image_uri: str | None, aws_session: AwsSession | None = None): +def _validate_python_version(image_uri: str | None, aws_session: AwsSession | None = None) -> None: """Validate python version at job definition time""" aws_session = aws_session or AwsSession() # user provides a custom image_uri @@ -257,7 +265,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def _serialize_entry_point(entry_point: Callable, args: tuple, kwargs: dict) -> str: """Create an entry point from a function""" - def wrapped_entry_point(): + def wrapped_entry_point() -> Any: """Partial function wrapping entry point with given parameters""" return entry_point(*args, **kwargs) @@ -277,7 +285,7 @@ def wrapped_entry_point(): ) -def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict): +def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> Dict: """Capture function arguments as hyperparameters""" signature = inspect.signature(entry_point) bound_args = signature.bind(*args, **kwargs) @@ -322,7 +330,7 @@ def _sanitize(hyperparameter: Any) -> str: return sanitized -def _process_input_data(input_data): +def _process_input_data(input_data: Dict) -> List[str]: """ Create symlinks to data @@ -336,12 +344,12 @@ def _process_input_data(input_data): if not isinstance(input_data, dict): input_data = {"input": input_data} - def matches(prefix): + def matches(prefix: str) -> List[str]: return [ str(path) for path in Path(prefix).parent.iterdir() if str(path).startswith(str(prefix)) ] - def is_prefix(path): + def is_prefix(path: str) -> bool: return len(matches(path)) > 1 or not Path(path).exists() prefix_channels = set() diff --git a/src/braket/jobs/image_uris.py b/src/braket/jobs/image_uris.py index eedc5e795..3a3346abe 100644 --- a/src/braket/jobs/image_uris.py +++ b/src/braket/jobs/image_uris.py @@ -15,7 +15,7 @@ import os from enum import Enum from functools import cache -from typing import Dict +from typing import Dict, Set class Framework(str, Enum): @@ -26,7 +26,7 @@ class Framework(str, Enum): PL_PYTORCH = "PL_PYTORCH" -def built_in_images(region): +def built_in_images(region: str) -> Set[str]: return {retrieve_image(framework, region) for framework in Framework} diff --git a/src/braket/jobs/quantum_job_creation.py b/src/braket/jobs/quantum_job_creation.py index 65d674c28..e34df0508 100644 --- a/src/braket/jobs/quantum_job_creation.py +++ b/src/braket/jobs/quantum_job_creation.py @@ -232,13 +232,13 @@ def prepare_quantum_job( return create_job_kwargs -def _generate_default_job_name(image_uri: str | None = None, func: Callable | None = None) -> str: +def _generate_default_job_name(image_uri: str = None, func: Callable = None) -> str: """ Generate default job name using the image uri and entrypoint function. Args: - image_uri (str | None): URI for the image container. - func (Callable | None): The entry point function. + image_uri (str): URI for the image container. + func (Callable): The entry point function. Returns: str: Hybrid job name. @@ -361,7 +361,7 @@ def _validate_params(dict_arr: dict[str, tuple[any, any]]) -> None: Validate that config parameters are of the right type. Args: - dict_arr (Dict[str, Tuple[any, any]]): dict mapping parameter names to + dict_arr (dict[str, tuple[any, any]]): dict mapping parameter names to a tuple containing the provided value and expected type. """ for parameter_name, value_tuple in dict_arr.items():