Skip to content

Commit

Permalink
fix: Aqua persistent tasks test and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws committed Nov 15, 2023
1 parent c8e2616 commit 81509fd
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 54 deletions.
6 changes: 2 additions & 4 deletions PyBraket/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
using Test, Aqua, Braket, Braket.AWS, PyBraket

withenv("JULIA_CONDAPKG_VERBOSITY"=>"-1") do
Aqua.test_all(PyBraket, ambiguities=false, unbound_args=false, piracies=false)
Aqua.test_ambiguities(PyBraket)
end
Aqua.test_all(PyBraket, ambiguities=false, unbound_args=false, piracies=false, persistent_tasks=false)
Aqua.test_ambiguities(PyBraket)

function set_aws_creds(test_type)
if test_type == "unit"
Expand Down
1 change: 1 addition & 0 deletions docs/src/device.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
```@docs
Device
AwsDevice
Braket.BraketDevice
isavailable
search_devices
get_devices
Expand Down
35 changes: 18 additions & 17 deletions src/aws_jobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,28 @@ function prepare_quantum_job(device::String, source_module::String, j_opts::Jobs
return (algo_spec=algo_spec, token=token, dev_conf=dev_conf, inst_conf=inst_conf, job_name=j_opts.job_name, out_conf=out_conf, role_arn=j_opts.role_arn, params=params)
end

function AwsQuantumJob(device::String, source_module::String, job_opts::JobsOptions)
args = prepare_quantum_job(device, source_module, job_opts)
algo_spec = args[:algo_spec]
token = args[:token]
dev_conf = args[:dev_conf]
inst_conf = args[:inst_conf]
job_name = args[:job_name]
out_conf = args[:out_conf]
role_arn = args[:role_arn]
params = args[:params]
response = BRAKET.create_job(algo_spec, token, dev_conf, inst_conf, job_name, out_conf, role_arn, params)
job = AwsQuantumJob(response["jobArn"])
job_opts.wait_until_complete && logs(job, wait=true)
return job
end

"""
AwsQuantumJob(device::String, source_module::String; kwargs...)
AwsQuantumJob(device::Union{String, BraketDevice}, source_module::String; kwargs...)
Create and launch an `AwsQuantumJob` which will use device `device` (a managed simulator, a QPU, or an [embedded simulator](https://docs.aws.amazon.com/braket/latest/developerguide/pennylane-embedded-simulators.html))
and will run the code (either a single file, or a Julia package, or a Python module) located at `source_module`. The keyword arguments
`kwargs` control the launch configuration of the job.
`kwargs` control the launch configuration of the job. `device` can be either the device's ARN as a `String`, or a [`BraketDevice`](@ref).
# Keyword Arguments
- `entry_point::String` - the function to run in `source_module` if `source_module` is a Python module/Julia package. Defaults to an empty string, in which case
Expand Down Expand Up @@ -517,20 +533,5 @@ and will run the code (either a single file, or a Julia package, or a Python mod
The default is `CheckpointConfig("/opt/jobs/checkpoints", "s3://{default_bucket_name}/jobs/{job_name}/checkpoints")`.
- `tags::Dict{String, String}` - specifies the key-value pairs for tagging this job.
"""
function AwsQuantumJob(device::String, source_module::String, job_opts::JobsOptions)
args = prepare_quantum_job(device, source_module, job_opts)
algo_spec = args[:algo_spec]
token = args[:token]
dev_conf = args[:dev_conf]
inst_conf = args[:inst_conf]
job_name = args[:job_name]
out_conf = args[:out_conf]
role_arn = args[:role_arn]
params = args[:params]
response = BRAKET.create_job(algo_spec, token, dev_conf, inst_conf, job_name, out_conf, role_arn, params)
job = AwsQuantumJob(response["jobArn"])
job_opts.wait_until_complete && logs(job, wait=true)
return job
end
AwsQuantumJob(device::String, source_module::String; kwargs...) = AwsQuantumJob(device, source_module, JobsOptions(; kwargs...))
AwsQuantumJob(device::BraketDevice, source_module::String; kwargs...) = AwsQuantumJob(convert(String, device), source_module, JobsOptions(; kwargs...))
15 changes: 15 additions & 0 deletions src/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,27 @@ const _GET_DEVICES_ORDER_BY_KEYS = Set(("arn", "name", "type", "provider_name",
@enum AwsDeviceType SIMULATOR QPU
const AwsDeviceTypeDict = Dict("SIMULATOR"=>SIMULATOR, "QPU"=>QPU)

"""
BraketDevice
An abstract type representing one of the devices available on Amazon Braket, which will automatically
generate its ARN when passed to the appropriate function.
# Examples
```jldoctest
julia> d = Braket.SV1()
julia> arn(d)
"arn:aws:braket:::device/quantum-simulator/amazon/sv1"
```
"""
abstract type BraketDevice end
for provider in (:AmazonDevice, :_XanaduDevice, :_DWaveDevice, :OQCDevice, :QuEraDevice, :IonQDevice, :RigettiDevice)
@eval begin
abstract type $provider <: BraketDevice end
end
end
arn(d::BraketDevice) = convert(String, d)

for (d, d_arn) in zip((:SV1, :DM1, :TN1), ("sv1", "dm1", "tn1"))
@eval begin
Expand Down
66 changes: 34 additions & 32 deletions src/local_jobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,43 @@ mutable struct LocalQuantumJob <: Job
end
end

function LocalQuantumJob(
device::String,
source_module::String,
j_opts::JobsOptions;
force_update::Bool=false,
config::AWSConfig=global_aws_config()
)
image_uri = isempty(j_opts.image_uri) ? retrieve_image(BASE, config) : j_opts.image_uri
args = prepare_quantum_job(device, source_module, j_opts)
algo_spec = args[:algo_spec]
job_name = args[:job_name]
ispath(job_name) && throw(ErrorException("a local directory called $job_name already exists. Please use a different job name."))
image_uri = haskey(algo_spec, "containerImage") ? algo_spec["containerImage"]["uri"] : retrieve_image(BASE, config)

run_log = ""
let local_job_container=LocalJobContainer(image_uri, args, force_update=force_update)
local_job_container = run_local_job!(local_job_container)
# copy results out
copy_from_container!(local_job_container, "/opt/ml/model", job_name)
!ispath(job_name) && mkdir(job_name)
write(joinpath(job_name, "log.txt"), local_job_container.run_log)
if haskey(args, :params) && haskey(args[:params], "checkpointConfig") && haskey(args[:params]["checkpointConfig"], "localPath")
checkpoint_path = args[:params]["checkpointConfig"]["localPath"]
copy_from_container!(local_job_container, checkpoint_path, joinpath(job_name, "checkpoints"))
end
run_log = local_job_container.run_log
stop_container!(local_job_container)
end
return LocalQuantumJob("local:job/$job_name", run_log=run_log)
end

"""
LocalQuantumJob(device::String, source_module::String; kwargs...)
LocalQuantumJob(device::Union{String, BraketDevice}, source_module::String; kwargs...)
Create and launch a `LocalQuantumJob` which will use device `device` (a managed simulator, a QPU, or an [embedded simulator](https://docs.aws.amazon.com/braket/latest/developerguide/pennylane-embedded-simulators.html))
and will run the code (either a single file, or a Julia package, or a Python module) located at `source_module`. A *local* job
and will run the code (either a single file, or a Julia package, or a Python module) located at `source_module`. `device` can be either the device's ARN as a `String`, or a [`BraketDevice`](@ref).
A *local* job
runs *locally* on your computational resource by launching the Job container locally using `docker`. The job will block
until it completes, replicating the `wait_until_complete` behavior of [`AwsQuantumJob`](@ref).
Expand Down Expand Up @@ -325,36 +357,6 @@ The keyword arguments `kwargs` control the launch configuration of the job.
The default is `CheckpointConfig("/opt/jobs/checkpoints", "s3://{default_bucket_name}/jobs/{job_name}/checkpoints")`.
- `tags::Dict{String, String}` - specifies the key-value pairs for tagging this job.
"""
function LocalQuantumJob(
device::String,
source_module::String,
j_opts::JobsOptions;
force_update::Bool=false,
config::AWSConfig=global_aws_config()
)
image_uri = isempty(j_opts.image_uri) ? retrieve_image(BASE, config) : j_opts.image_uri
args = prepare_quantum_job(device, source_module, j_opts)
algo_spec = args[:algo_spec]
job_name = args[:job_name]
ispath(job_name) && throw(ErrorException("a local directory called $job_name already exists. Please use a different job name."))
image_uri = haskey(algo_spec, "containerImage") ? algo_spec["containerImage"]["uri"] : retrieve_image(BASE, config)

run_log = ""
let local_job_container=LocalJobContainer(image_uri, args, force_update=force_update)
local_job_container = run_local_job!(local_job_container)
# copy results out
copy_from_container!(local_job_container, "/opt/ml/model", job_name)
!ispath(job_name) && mkdir(job_name)
write(joinpath(job_name, "log.txt"), local_job_container.run_log)
if haskey(args, :params) && haskey(args[:params], "checkpointConfig") && haskey(args[:params]["checkpointConfig"], "localPath")
checkpoint_path = args[:params]["checkpointConfig"]["localPath"]
copy_from_container!(local_job_container, checkpoint_path, joinpath(job_name, "checkpoints"))
end
run_log = local_job_container.run_log
stop_container!(local_job_container)
end
return LocalQuantumJob("local:job/$job_name", run_log=run_log)
end
LocalQuantumJob(device::String, source_module::String; force_update::Bool=false, config::AWSConfig=global_aws_config(), kwargs...) = LocalQuantumJob(device, source_module, JobsOptions(; kwargs...); force_update=force_update, config=config)
LocalQuantumJob(device::BraketDevice, source_module::String; kwargs...) = LocalQuantumJob(convert(String, device), source_module; kwargs...)

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Pkg, Test, Aqua, Braket

in_ci = tryparse(Bool, get(ENV, "BRAKET_CI", "false"))
Aqua.test_all(Braket, ambiguities=false, unbound_args=false, piracies=false, stale_deps=!in_ci, deps_compat=!in_ci)
Aqua.test_all(Braket, ambiguities=false, unbound_args=false, piracies=false, stale_deps=!in_ci, deps_compat=!in_ci, persistent_tasks=false)
Aqua.test_ambiguities(Braket)
Aqua.test_piracies(Braket, treat_as_own=[Braket.DecFP.Dec128])

Expand Down

0 comments on commit 81509fd

Please sign in to comment.