From 81509fd9bf576110f1e5fd779c0f027d2928cf31 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 15 Nov 2023 22:50:39 +0000 Subject: [PATCH] fix: Aqua persistent tasks test and docs --- PyBraket/test/runtests.jl | 6 ++-- docs/src/device.md | 1 + src/aws_jobs.jl | 35 +++++++++++---------- src/device.jl | 15 +++++++++ src/local_jobs.jl | 66 ++++++++++++++++++++------------------- test/runtests.jl | 2 +- 6 files changed, 71 insertions(+), 54 deletions(-) diff --git a/PyBraket/test/runtests.jl b/PyBraket/test/runtests.jl index 604fe04f..ef624576 100644 --- a/PyBraket/test/runtests.jl +++ b/PyBraket/test/runtests.jl @@ -1,9 +1,7 @@ using Test, Aqua, Braket, Braket.AWS, PyBraket -withenv("JULIA_CONDAPKG_VERBOSITY"=>"-1") do - Aqua.test_all(PyBraket, ambiguities=false, unbound_args=false, piracies=false) - Aqua.test_ambiguities(PyBraket) -end +Aqua.test_all(PyBraket, ambiguities=false, unbound_args=false, piracies=false, persistent_tasks=false) +Aqua.test_ambiguities(PyBraket) function set_aws_creds(test_type) if test_type == "unit" diff --git a/docs/src/device.md b/docs/src/device.md index ff3a8617..f1bab3a6 100644 --- a/docs/src/device.md +++ b/docs/src/device.md @@ -3,6 +3,7 @@ ```@docs Device AwsDevice +Braket.BraketDevice isavailable search_devices get_devices diff --git a/src/aws_jobs.jl b/src/aws_jobs.jl index 13eea178..2e3bc9e1 100644 --- a/src/aws_jobs.jl +++ b/src/aws_jobs.jl @@ -474,12 +474,28 @@ function prepare_quantum_job(device::String, source_module::String, j_opts::Jobs return (algo_spec=algo_spec, token=token, dev_conf=dev_conf, inst_conf=inst_conf, job_name=j_opts.job_name, out_conf=out_conf, role_arn=j_opts.role_arn, params=params) end +function AwsQuantumJob(device::String, source_module::String, job_opts::JobsOptions) + args = prepare_quantum_job(device, source_module, job_opts) + algo_spec = args[:algo_spec] + token = args[:token] + dev_conf = args[:dev_conf] + inst_conf = args[:inst_conf] + job_name = args[:job_name] + out_conf = args[:out_conf] + role_arn = args[:role_arn] + params = args[:params] + response = BRAKET.create_job(algo_spec, token, dev_conf, inst_conf, job_name, out_conf, role_arn, params) + job = AwsQuantumJob(response["jobArn"]) + job_opts.wait_until_complete && logs(job, wait=true) + return job +end + """ - AwsQuantumJob(device::String, source_module::String; kwargs...) + AwsQuantumJob(device::Union{String, BraketDevice}, source_module::String; kwargs...) Create and launch an `AwsQuantumJob` which will use device `device` (a managed simulator, a QPU, or an [embedded simulator](https://docs.aws.amazon.com/braket/latest/developerguide/pennylane-embedded-simulators.html)) and will run the code (either a single file, or a Julia package, or a Python module) located at `source_module`. The keyword arguments -`kwargs` control the launch configuration of the job. +`kwargs` control the launch configuration of the job. `device` can be either the device's ARN as a `String`, or a [`BraketDevice`](@ref). # Keyword Arguments - `entry_point::String` - the function to run in `source_module` if `source_module` is a Python module/Julia package. Defaults to an empty string, in which case @@ -517,20 +533,5 @@ and will run the code (either a single file, or a Julia package, or a Python mod The default is `CheckpointConfig("/opt/jobs/checkpoints", "s3://{default_bucket_name}/jobs/{job_name}/checkpoints")`. - `tags::Dict{String, String}` - specifies the key-value pairs for tagging this job. """ -function AwsQuantumJob(device::String, source_module::String, job_opts::JobsOptions) - args = prepare_quantum_job(device, source_module, job_opts) - algo_spec = args[:algo_spec] - token = args[:token] - dev_conf = args[:dev_conf] - inst_conf = args[:inst_conf] - job_name = args[:job_name] - out_conf = args[:out_conf] - role_arn = args[:role_arn] - params = args[:params] - response = BRAKET.create_job(algo_spec, token, dev_conf, inst_conf, job_name, out_conf, role_arn, params) - job = AwsQuantumJob(response["jobArn"]) - job_opts.wait_until_complete && logs(job, wait=true) - return job -end AwsQuantumJob(device::String, source_module::String; kwargs...) = AwsQuantumJob(device, source_module, JobsOptions(; kwargs...)) AwsQuantumJob(device::BraketDevice, source_module::String; kwargs...) = AwsQuantumJob(convert(String, device), source_module, JobsOptions(; kwargs...)) diff --git a/src/device.jl b/src/device.jl index 0b87687f..3b402bce 100644 --- a/src/device.jl +++ b/src/device.jl @@ -8,12 +8,27 @@ const _GET_DEVICES_ORDER_BY_KEYS = Set(("arn", "name", "type", "provider_name", @enum AwsDeviceType SIMULATOR QPU const AwsDeviceTypeDict = Dict("SIMULATOR"=>SIMULATOR, "QPU"=>QPU) +""" + BraketDevice + +An abstract type representing one of the devices available on Amazon Braket, which will automatically +generate its ARN when passed to the appropriate function. + +# Examples +```jldoctest +julia> d = Braket.SV1() + +julia> arn(d) +"arn:aws:braket:::device/quantum-simulator/amazon/sv1" +``` +""" abstract type BraketDevice end for provider in (:AmazonDevice, :_XanaduDevice, :_DWaveDevice, :OQCDevice, :QuEraDevice, :IonQDevice, :RigettiDevice) @eval begin abstract type $provider <: BraketDevice end end end +arn(d::BraketDevice) = convert(String, d) for (d, d_arn) in zip((:SV1, :DM1, :TN1), ("sv1", "dm1", "tn1")) @eval begin diff --git a/src/local_jobs.jl b/src/local_jobs.jl index 586739a6..e1208573 100644 --- a/src/local_jobs.jl +++ b/src/local_jobs.jl @@ -283,11 +283,43 @@ mutable struct LocalQuantumJob <: Job end end +function LocalQuantumJob( + device::String, + source_module::String, + j_opts::JobsOptions; + force_update::Bool=false, + config::AWSConfig=global_aws_config() + ) + image_uri = isempty(j_opts.image_uri) ? retrieve_image(BASE, config) : j_opts.image_uri + args = prepare_quantum_job(device, source_module, j_opts) + algo_spec = args[:algo_spec] + job_name = args[:job_name] + ispath(job_name) && throw(ErrorException("a local directory called $job_name already exists. Please use a different job name.")) + image_uri = haskey(algo_spec, "containerImage") ? algo_spec["containerImage"]["uri"] : retrieve_image(BASE, config) + + run_log = "" + let local_job_container=LocalJobContainer(image_uri, args, force_update=force_update) + local_job_container = run_local_job!(local_job_container) + # copy results out + copy_from_container!(local_job_container, "/opt/ml/model", job_name) + !ispath(job_name) && mkdir(job_name) + write(joinpath(job_name, "log.txt"), local_job_container.run_log) + if haskey(args, :params) && haskey(args[:params], "checkpointConfig") && haskey(args[:params]["checkpointConfig"], "localPath") + checkpoint_path = args[:params]["checkpointConfig"]["localPath"] + copy_from_container!(local_job_container, checkpoint_path, joinpath(job_name, "checkpoints")) + end + run_log = local_job_container.run_log + stop_container!(local_job_container) + end + return LocalQuantumJob("local:job/$job_name", run_log=run_log) +end + """ - LocalQuantumJob(device::String, source_module::String; kwargs...) + LocalQuantumJob(device::Union{String, BraketDevice}, source_module::String; kwargs...) Create and launch a `LocalQuantumJob` which will use device `device` (a managed simulator, a QPU, or an [embedded simulator](https://docs.aws.amazon.com/braket/latest/developerguide/pennylane-embedded-simulators.html)) -and will run the code (either a single file, or a Julia package, or a Python module) located at `source_module`. A *local* job +and will run the code (either a single file, or a Julia package, or a Python module) located at `source_module`. `device` can be either the device's ARN as a `String`, or a [`BraketDevice`](@ref). +A *local* job runs *locally* on your computational resource by launching the Job container locally using `docker`. The job will block until it completes, replicating the `wait_until_complete` behavior of [`AwsQuantumJob`](@ref). @@ -325,36 +357,6 @@ The keyword arguments `kwargs` control the launch configuration of the job. The default is `CheckpointConfig("/opt/jobs/checkpoints", "s3://{default_bucket_name}/jobs/{job_name}/checkpoints")`. - `tags::Dict{String, String}` - specifies the key-value pairs for tagging this job. """ -function LocalQuantumJob( - device::String, - source_module::String, - j_opts::JobsOptions; - force_update::Bool=false, - config::AWSConfig=global_aws_config() - ) - image_uri = isempty(j_opts.image_uri) ? retrieve_image(BASE, config) : j_opts.image_uri - args = prepare_quantum_job(device, source_module, j_opts) - algo_spec = args[:algo_spec] - job_name = args[:job_name] - ispath(job_name) && throw(ErrorException("a local directory called $job_name already exists. Please use a different job name.")) - image_uri = haskey(algo_spec, "containerImage") ? algo_spec["containerImage"]["uri"] : retrieve_image(BASE, config) - - run_log = "" - let local_job_container=LocalJobContainer(image_uri, args, force_update=force_update) - local_job_container = run_local_job!(local_job_container) - # copy results out - copy_from_container!(local_job_container, "/opt/ml/model", job_name) - !ispath(job_name) && mkdir(job_name) - write(joinpath(job_name, "log.txt"), local_job_container.run_log) - if haskey(args, :params) && haskey(args[:params], "checkpointConfig") && haskey(args[:params]["checkpointConfig"], "localPath") - checkpoint_path = args[:params]["checkpointConfig"]["localPath"] - copy_from_container!(local_job_container, checkpoint_path, joinpath(job_name, "checkpoints")) - end - run_log = local_job_container.run_log - stop_container!(local_job_container) - end - return LocalQuantumJob("local:job/$job_name", run_log=run_log) -end LocalQuantumJob(device::String, source_module::String; force_update::Bool=false, config::AWSConfig=global_aws_config(), kwargs...) = LocalQuantumJob(device, source_module, JobsOptions(; kwargs...); force_update=force_update, config=config) LocalQuantumJob(device::BraketDevice, source_module::String; kwargs...) = LocalQuantumJob(convert(String, device), source_module; kwargs...) diff --git a/test/runtests.jl b/test/runtests.jl index d7b3d5ce..d558e546 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,7 @@ using Pkg, Test, Aqua, Braket in_ci = tryparse(Bool, get(ENV, "BRAKET_CI", "false")) -Aqua.test_all(Braket, ambiguities=false, unbound_args=false, piracies=false, stale_deps=!in_ci, deps_compat=!in_ci) +Aqua.test_all(Braket, ambiguities=false, unbound_args=false, piracies=false, stale_deps=!in_ci, deps_compat=!in_ci, persistent_tasks=false) Aqua.test_ambiguities(Braket) Aqua.test_piracies(Braket, treat_as_own=[Braket.DecFP.Dec128])