Skip to content

Commit

Permalink
change: use nospecialize for some calls for compilation and make sure… (
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws authored Jul 3, 2024
1 parent a389d80 commit 06f2594
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion PyBraket/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Braket = "=0.9.1"
CondaPkg = "=0.2.22"
DataStructures = "=0.18.20"
LinearAlgebra = "1.6"
PythonCall = "=0.9.19"
PythonCall = "=0.9.20"
Statistics = "1"
StructTypes = "=1.10.0"
Test = "1.6"
Expand Down
2 changes: 1 addition & 1 deletion PyBraket/src/local_simulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
Braket.simulate(d::PyLocalSimulator, task_spec::Braket.AnalogHamiltonianSimulation; shots::Int=0, kwargs...) = simulate(d, ir(task_spec); shots=shots, kwargs...)

function Braket._run_internal(simulator::PyLocalSimulator, task_spec::AnalogHamiltonianSimulation, args...; kwargs...)
raw_py_result = simulator._run_internal(Py(ir(task_spec)), args...; kwargs...)
raw_py_result = simulator.run(Py(ir(task_spec)), args...; kwargs...).result()
jl_task_metadata = pyconvert(Braket.TaskMetadata, raw_py_result.task_metadata)
jl_measurements = map(raw_py_result.measurements) do m
jl_status = pyconvert(String, pystr(m.status))
Expand Down
7 changes: 4 additions & 3 deletions src/circuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,9 @@ QubitSet with 2 elements:
qubits(c::Circuit) = (qs = union!(copy(c.moments._qubits), c.qubit_observable_set); QubitSet(qs))
function qubits(p::Program)
inst_qubits = mapreduce(ix->ix.target, union, p.instructions, init=Set{Int}())
bri_qubits = mapreduce(ix->ix.target, union, p.basis_rotation_instructions, init=Set{Int}())
res_qubits = mapreduce(ix->(hasproperty(ix, :targets) && !isnothing(ix.targets)) ? reduce(vcat, ix.targets) : Set{Int}(), union, p.results, init=Set{Int}())
return union(inst_qubits, res_qubits)
return union(inst_qubits, bri_qubits, res_qubits)
end
"""
qubit_count(c::Circuit) -> Int
Expand Down Expand Up @@ -584,7 +585,7 @@ function add_instruction!(c::Circuit, ix::Instruction{O}) where {O<:Operator}
return c
end

function add_instruction!(c::Circuit, ix::Instruction{O}, target) where {O<:Operator}
function add_instruction!(c::Circuit, @nospecialize(ix::Instruction{O}), target) where {O<:Operator}
to_add = Instruction[]
if qubit_count(ix.operator) == 1
to_add = [remap(ix, q) for q in target]
Expand All @@ -595,7 +596,7 @@ function add_instruction!(c::Circuit, ix::Instruction{O}, target) where {O<:Oper
return c
end

function add_instruction!(c::Circuit, ix::Instruction{O}, target_mapping::Dict{<:Integer, <:Integer}) where {O<:Operator}
function add_instruction!(c::Circuit, @nospecialize(ix::Instruction{O}), target_mapping::Dict{<:Integer, <:Integer}) where {O<:Operator}
to_add = [remap(ix, target_mapping)]
foreach(ix->add_instruction!(c, ix), to_add)
return c
Expand Down
2 changes: 1 addition & 1 deletion src/local_simulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ function simulate(d::LocalSimulator, task_specs::Vector{T}, args...; shots::Int=
end
end
@debug "batch size is $(length(task_specs)). Time to run internally: $(stats.time). GC time: $(stats.gctime)."
return LocalQuantumTaskBatch([local_result.task_metadata.id for local_result in results], results)
return LocalQuantumTaskBatch([local_result.result.task_metadata.id for local_result in results], results)
end
(d::LocalSimulator)(args...; kwargs...) = simulate(d, args...; kwargs...)

Expand Down
6 changes: 3 additions & 3 deletions src/schemas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ qubit_count(ixs::Vector{Instruction}) = length(qubits(ixs))

bind_value!(ix::Instruction{O}, param_values::Dict{Symbol, Number}) where {O<:Operator} = Instruction{O}(bind_value!(ix.operator, param_values), ix.target)

remap(ix::Instruction{O}, mapping::Dict{<:Integer, <:Integer}) where {O<:Operator} = Instruction{O}(copy(ix.operator), [mapping[q] for q in ix.target])
remap(ix::Instruction{O}, target::VecOrQubitSet) where {O<:Operator} = Instruction{O}(copy(ix.operator), target[1:length(ix.target)])
remap(ix::Instruction{O}, target::IntOrQubit) where {O<:Operator} = Instruction{O}(copy(ix.operator), target)
remap(@nospecialize(ix::Instruction{O}), mapping::Dict{<:Integer, <:Integer}) where {O} = Instruction{O}(copy(ix.operator), [mapping[q] for q in ix.target])
remap(@nospecialize(ix::Instruction{O}), target::VecOrQubitSet) where {O} = Instruction{O}(copy(ix.operator), target[1:length(ix.target)])
remap(@nospecialize(ix::Instruction{O}), target::IntOrQubit) where {O} = Instruction{O}(copy(ix.operator), target)

function StructTypes.constructfrom(::Type{Program}, obj)
new_obj = copy(obj)
Expand Down

0 comments on commit 06f2594

Please sign in to comment.