Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change: Update to new BraketSimulator interface #90

Merged
merged 6 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
# don't run on draft PRs
if: ${{ github.event_name == 'push' || !github.event.pull_request.draft }}
# allow windows to fail
continue-on-error: ${{ matrix.os == 'windows-latest' }}
continue-on-error: ${{ matrix.os == 'windows-latest' || matrix.group == 'PyBraket-unit' }}
strategy:
fail-fast: true
max-parallel: 2
Expand Down Expand Up @@ -61,7 +61,7 @@ jobs:
# don't run on draft PRs
if: ${{ github.event_name == 'push' || !github.event.pull_request.draft }}
# allow failures on nightly or beta Julia
continue-on-error: ${{ matrix.version == 'nightly'}}
continue-on-error: ${{ matrix.version == 'nightly' || matrix.group == 'PyBraket-unit' }}
strategy:
fail-fast: true
max-parallel: 2
Expand Down
6 changes: 4 additions & 2 deletions PyBraket/CondaPkg.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
[deps]
python = ">=3"
python = ">=3.9,<=3.11"
pydantic = ""
scipy = ""
numpy = ""

[pip.deps]
amazon-braket-sdk = ">=1.70.0"
amazon-braket-sdk = ">=1.81.0"
amazon-braket-default-simulator = ">=1.10.0"
urllib3 = "<2"
botocore = ">=1.34"
35 changes: 9 additions & 26 deletions src/local_simulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@
"""
function simulate(d::LocalSimulator, task_spec::Union{Circuit, AnalogHamiltonianSimulation, AbstractProgram}, args...; shots::Int=0, inputs::Dict{String, Float64} = Dict{String, Float64}(), kwargs...)
sim = d._delegate
@debug "Single task. Starting run..."
stats = @timed _run_internal(sim, task_spec, args...; inputs=inputs, shots=shots, kwargs...)
@debug "Single task. Time to run internally: $(stats.time). GC time: $(stats.gctime)."
local_result = stats.value
local_result = _run_internal(sim, task_spec, args...; inputs=inputs, shots=shots, kwargs...)
return LocalQuantumTask(local_result.task_metadata.id, local_result)
end

Expand Down Expand Up @@ -152,45 +149,31 @@
program = ir(circuit, Val(:OpenQASM))
full_inputs = isnothing(program.inputs) ? inputs : merge(program.inputs, inputs)
full_program = OpenQasmProgram(program.braketSchemaHeader, program.source, full_inputs)
r = simulate(simulator, full_program, args...; shots=shots, kwargs...)
r = simulate(simulator, full_program, shots; kwargs...)

Check warning on line 152 in src/local_simulator.jl

View check run for this annotation

Codecov / codecov/patch

src/local_simulator.jl#L152

Added line #L152 was not covered by tests
return format_result(r)
elseif haskey(properties(simulator).action, "braket.ir.jaqcd.program")
validate_circuit_and_shots(circuit, shots)
validate_circuit_and_shots(circuit, shots)

Check warning on line 155 in src/local_simulator.jl

View check run for this annotation

Codecov / codecov/patch

src/local_simulator.jl#L155

Added line #L155 was not covered by tests
program = ir(circuit, Val(:JAQCD))
qubits = qubit_count(circuit)
r = simulate(simulator, program, qubits, args...; shots=shots, inputs=inputs, kwargs...)
r = simulate(simulator, program, qubits, shots; inputs=inputs, kwargs...)

Check warning on line 158 in src/local_simulator.jl

View check run for this annotation

Codecov / codecov/patch

src/local_simulator.jl#L158

Added line #L158 was not covered by tests
return format_result(r)
else
throw(ErrorException("$(typeof(simulator)) does not support qubit gate-based programs."))
end
end
function _run_internal(simulator, program::OpenQasmProgram, args...; shots::Int=0, inputs::Dict{String, Float64}=Dict{String, Float64}(), kwargs...)
if haskey(properties(simulator).action, "braket.ir.openqasm.program")
stats = @timed begin
simulate(simulator, program; shots=shots, inputs=inputs, kwargs...)
end
@debug "Time to invoke simulator: $(stats.time)"
r = stats.value
stats = @timed format_result(r)
@debug "Time to format results: $(stats.time)"
return stats.value
r = simulate(simulator, program, shots; inputs=inputs, kwargs...)
return format_result(r)

Check warning on line 167 in src/local_simulator.jl

View check run for this annotation

Codecov / codecov/patch

src/local_simulator.jl#L166-L167

Added lines #L166 - L167 were not covered by tests
else
throw(ErrorException("$(typeof(simulator)) does not support qubit gate-based programs."))
end
end
function _run_internal(simulator, program::Program, args...; shots::Int=0, inputs::Dict{String, Float64}=Dict{String, Float64}(), kwargs...)
if haskey(properties(simulator).action, "braket.ir.jaqcd.program")
stats = @timed qubit_count(program)
@debug "Time to get qubit count: $(stats.time)"
qubits = stats.value
stats = @timed begin
simulate(simulator, program, qubits, args...; shots=shots, inputs=inputs, kwargs...)
end
@debug "Time to invoke simulator: $(stats.time)"
r = stats.value
stats = @timed format_result(r)
@debug "Time to format results: $(stats.time)"
return stats.value
qubits = qubit_count(program)
r = simulate(simulator, program, qubits, shots; inputs=inputs, kwargs...)
return format_result(r)

Check warning on line 176 in src/local_simulator.jl

View check run for this annotation

Codecov / codecov/patch

src/local_simulator.jl#L174-L176

Added lines #L174 - L176 were not covered by tests
else
throw(ErrorException("$(typeof(simulator)) does not support qubit gate-based programs."))
end
Expand Down
2 changes: 1 addition & 1 deletion test/task_batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ XANADU_ARN = "arn:aws:braket:us-east-1::device/qpu/xanadu/Borealis"
n_tasks = 10
raw_tasks = [Braket.AwsQuantumTask("arn:fake:$i") for i in 1:n_tasks]
specs = [c for ix in 1:n_tasks]
t = Braket.AwsQuantumTaskBatch(deepcopy(raw_tasks), nothing, Set{String}(), "fake_device", specs, ("fake_bucket", "fake_prefix"), 100, 1, 10)
t = Braket.AwsQuantumTaskBatch(copy(raw_tasks), nothing, Set{String}(), "fake_device", specs, ("fake_bucket", "fake_prefix"), 100, 1, 10)
t._results = fill("fake_result", n_tasks)
@test isnothing(Braket.retry_unsuccessful_tasks(t))
t._results = convert(Vector{Any}, fill(nothing, n_tasks))
Expand Down
Loading