Skip to content

Commit

Permalink
Adjoint support in Braket.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws committed Dec 8, 2022
1 parent 3d68d4f commit 174cb5d
Show file tree
Hide file tree
Showing 36 changed files with 1,399 additions and 308 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Mocking = "78c3b35d-d492-501b-9361-3d52fe80e533"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
2 changes: 1 addition & 1 deletion PyBraket/CondaPkg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ scipy = ""
numpy = ""

[pip.deps]
amazon-braket-sdk = ">=1.33.0.dev0"
amazon-braket-sdk = ">=1.35.0"
2 changes: 2 additions & 0 deletions PyBraket/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MicroMamba = "0b3b1443-0f03-428d-bdfb-f27f9c1191ea"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"

[compat]
Expand All @@ -30,4 +31,5 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
20 changes: 11 additions & 9 deletions PyBraket/src/pycircuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,20 @@ end
PyCircuit(c::Circuit) = convert(PyCircuit, c)

Py(o::Braket.Observables.HermitianObservable) = braketobs.Hermitian(Py(o.matrix))
Py(o::Braket.Observables.TensorProduct) = braketobs.TensorProduct(pylist(map(Py, o.factors)))
Py(o::Expectation) = circuit.result_types.Expectation(Py(o.observable), pylist(o.targets))
Py(o::Variance) = circuit.result_types.Variance(Py(o.observable), pylist(o.targets))
Py(o::Sample) = circuit.result_types.Sample(Py(o.observable), pylist(o.targets))
Py(o::Probability) = circuit.result_types.Probability(pylist(o.targets))
Py(o::DensityMatrix) = circuit.result_types.DensityMatrix(pylist(o.targets))
Py(o::Amplitude) = circuit.result_types.Amplitude(pylist(map(s->Py(s), o.states)))
Py(o::StateVector) = circuit.result_types.StateVector()
Py(o::Braket.Observables.TensorProduct) = o.coefficient * braketobs.TensorProduct(pylist(map(Py, o.factors)))
Py(o::Braket.Observables.Sum) = braketobs.Sum(pylist(map(Py, o.summands)))
Py(o::AdjointGradient) = circuit.result_types.AdjointGradient(Py(o.observable), pylist(o.target), pylist(o.parameters))
Py(o::Expectation) = circuit.result_types.Expectation(Py(o.observable), pylist(o.targets))
Py(o::Variance) = circuit.result_types.Variance(Py(o.observable), pylist(o.targets))
Py(o::Sample) = circuit.result_types.Sample(Py(o.observable), pylist(o.targets))
Py(o::Probability) = circuit.result_types.Probability(pylist(o.targets))
Py(o::DensityMatrix) = circuit.result_types.DensityMatrix(pylist(o.targets))
Py(o::Amplitude) = circuit.result_types.Amplitude(pylist(map(s->Py(s), o.states)))
Py(o::StateVector) = circuit.result_types.StateVector()

for (typ, py_typ, label) in ((:(Braket.Observables.H), :H, "h"), (:(Braket.Observables.X), :X, "x"), (:(Braket.Observables.Y), :Y, "y"), (:(Braket.Observables.Z), :Z, "z"), (:(Braket.Observables.I), :I, "i"))
@eval begin
Py(o::$typ) = braketobs.$py_typ()
Py(o::$typ) = o.coefficient * braketobs.$py_typ()
end
end

Expand Down
40 changes: 29 additions & 11 deletions PyBraket/src/pyschema.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module PySchema

using PythonCall, Braket, Braket.IR
import Braket: DwaveTiming, GateModelQpuParadigmProperties, ResultTypeValue, OqcDeviceCapabilities, GateFidelity2Q, OneQubitProperties, AdditionalMetadata, DwaveAdvantageDeviceLevelParameters, IonqDeviceCapabilities, TwoQubitProperties, IonqDeviceParameters, NativeQuilMetadata, QueraDeviceCapabilities, ExecutionDay, StandardizedGateModelQpuDeviceProperties, XanaduDeviceCapabilities, DeviceExecutionWindow, DwaveDeviceParameters, CoherenceTime, DwaveMetadata, Frame, DeviceCost, Dwave2000QDeviceLevelParameters, FidelityType, QueraMetadata, OqcProviderProperties, DeviceServiceProperties, XanaduDeviceParameters, BlackbirdProgram, Geometry, Fidelity1Q, AnnealingTaskResult, AnalogHamiltonianSimulationShotResult, RigettiDeviceCapabilities, DwaveProviderProperties, DeviceActionProperties, DwaveProviderLevelParameters, JaqcdDeviceActionProperties, Direction, PulseDeviceActionProperties, BlackbirdDeviceActionProperties, PerformanceLattice, PerformanceRydberg, DeviceActionType, IonqProviderProperties, PersistedJobDataFormat, PhotonicModelTaskResult, DeviceDocumentation, XanaduProviderProperties, QubitDirection, ContinuousVariableQpuParadigmProperties, PulseFunctionArgument, DwaveAdvantageDeviceParameters, RigettiDeviceParameters, AnalogHamiltonianSimulationTaskResult, TaskMetadata, GateModelSimulatorParadigmProperties, GateModelTaskResult, RigettiProviderProperties, OqcDeviceParameters, DeviceConnectivity, PulseFunction, Rydberg, PerformanceRydbergGlobal, RydbergGlobal, GateModelSimulatorDeviceCapabilities, XanaduMetadata, GateModelParameters, Performance, GateModelSimulatorDeviceParameters, AnalogHamiltonianSimulationShotMeasurement, OqcMetadata, QueraAhsParadigmProperties, Lattice, PersistedJobData, DwaveDeviceCapabilities, ProblemType, RigettiMetadata, SimulatorMetadata, Area, Problem, ResultType, PostProcessingType, ResultFormat, Dwave2000QDeviceParameters, AnalogHamiltonianSimulationShotMetadata, braketSchemaHeader, OpenQasmProgram, Port, OpenQASMDeviceActionProperties, AbstractProgram
import Braket.IR: Z, Sample, CPhaseShift01, PhaseDamping, Rz, GeneralizedAmplitudeDamping, XX, ZZ, PhaseFlip, Vi, Depolarizing, Variance, TwoQubitDepolarizing, DensityMatrix, CPhaseShift00, ECR, CompilerDirective, CCNot, Unitary, BitFlip, Y, Swap, CZ, EndVerbatimBox, Program, CNot, CSwap, Ry, I, Si, AmplitudeDamping, StateVector, ISwap, H, XY, YY, T, TwoQubitDephasing, X, Ti, CV, StartVerbatimBox, PauliChannel, PSwap, Expectation, Probability, PhaseShift, V, CPhaseShift, S, Rx, Kraus, Amplitude, CPhaseShift10, MultiQubitPauliChannel, CY, Setup, Hamiltonian, ShiftingField, AtomArrangement, TimeSeries, PhysicalField, AHSProgram, DrivingField, AbstractProgramResult
import Braket.IR: Z, Sample, CPhaseShift01, PhaseDamping, Rz, GeneralizedAmplitudeDamping, XX, ZZ, PhaseFlip, Vi, Depolarizing, Variance, TwoQubitDepolarizing, DensityMatrix, CPhaseShift00, ECR, CompilerDirective, CCNot, Unitary, BitFlip, Y, Swap, CZ, EndVerbatimBox, Program, CNot, AdjointGradient, CSwap, Ry, I, Si, AmplitudeDamping, StateVector, ISwap, H, XY, YY, T, TwoQubitDephasing, X, Ti, CV, StartVerbatimBox, PauliChannel, PSwap, Expectation, Probability, PhaseShift, V, CPhaseShift, S, Rx, Kraus, Amplitude, CPhaseShift10, MultiQubitPauliChannel, CY, Setup, Hamiltonian, ShiftingField, AtomArrangement, TimeSeries, PhysicalField, AHSProgram, DrivingField, AbstractProgramResult, IRObservable

const instructions = PythonCall.pynew()
const dwave_metadata_v1 = PythonCall.pynew()
Expand Down Expand Up @@ -71,6 +71,12 @@ const shared_models = PythonCall.pynew()
const schema_header = PythonCall.pynew()
const openqasm_device_action_properties = PythonCall.pynew()

function union_convert(::Type{IRObservable}, x)
PythonCall.pyisinstance(x, PythonCall.pybuiltins.str) && return pyconvert(String, x)
x_vec = Union{String, Vector{Vector{Vector{Float64}}}}[PythonCall.pyisinstance(x_, PythonCall.pybuiltins.str) ? pyconvert(String, x_) : pyconvert(Vector{Vector{Vector{Float64}}}, x_) for x_ in x]
return x_vec
end

function union_convert(union_type, x)
union_ts = union_type isa Union ? Type[] : [union_type]
union_t = union_type
Expand Down Expand Up @@ -117,11 +123,8 @@ function jl_convert_attr(n, t, attr)
return pyconvert(t, attr)
end
else
if !PythonCall.pyisnone(attr)
return union_convert(t, attr)
else
return nothing
end
PythonCall.pyisnone(attr) && return nothing
return union_convert(t, attr)
end
end

Expand All @@ -148,16 +151,28 @@ function jl_convert(::Type{T}, x::Py) where {T<:AbstractIR}
end
PythonCall.pyconvert_return(T(args..., pyconvert(String, pygetattr(x, "type"))))
end

function jl_convert(::Type{AbstractProgram}, x::Py)
T = Braket.lookup_type(pyconvert(braketSchemaHeader, pygetattr(x, "braketSchemaHeader")))
return PythonCall.pyconvert_return(pyconvert(T, x))
end
function jl_convert(::Type{AbstractProgramResult}, x::Py)
T_ = pyconvert(String, x.type)
T = Dict("expectation"=>Expectation, "variance"=>Variance, "statevector"=>StateVector, "densitymatrix"=>DensityMatrix, "sample"=>Sample, "amplitude"=>Amplitude, "probability"=>Probability)
T = Dict("adjoint_gradient"=>AdjointGradient, "expectation"=>Expectation, "variance"=>Variance, "statevector"=>StateVector, "densitymatrix"=>DensityMatrix, "sample"=>Sample, "amplitude"=>Amplitude, "probability"=>Probability)
return PythonCall.pyconvert_return(pyconvert(T[T_], x))
end
#=for (irT, pyT) in ((:(Braket.IR.Expectation), :(pyjaqcd.Expectation)),
(:(Braket.IR.Variance), :(pyjaqcd.Variance)),
(:(Braket.IR.Sample), :(pyjaqcd.Sample)),
(:(Braket.IR.Amplitude), :(pyjaqcd.Amplitude)),
(:(Braket.IR.StateVector), :(pyjaqcd.StateVector)),
(:(Braket.IR.Probability), :(pyjaqcd.Probability)),
(:(Braket.IR.DensityMatrix), :(pyjaqcd.DensityMatrix)),
(:(Braket.IR.AdjointGradient), :(pyjaqcd.AdjointGradient)))
@eval begin
Py(o::$irT) = $pyT(;arg_gen(o, fieldnames($irT))...)
end
end
=#

function __init__()
PythonCall.pycopy!(instructions, pyimport("braket.ir.jaqcd.instructions"))
Expand Down Expand Up @@ -363,6 +378,7 @@ function __init__()
PythonCall.pyconvert_add_rule("braket.ir.annealing.problem_v1:ProblemType", ProblemType, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.jaqcd.instructions:ECR", ECR, jl_convert)
PythonCall.pyconvert_add_rule("braket.task_result.rigetti_metadata_v1:RigettiMetadata", RigettiMetadata, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.jaqcd.results:AdjointGradient", AdjointGradient, jl_convert)
PythonCall.pyconvert_add_rule("braket.task_result.simulator_metadata_v1:SimulatorMetadata", SimulatorMetadata, jl_convert)
PythonCall.pyconvert_add_rule("braket.device_schema.quera.quera_ahs_paradigm_properties_v1:Area", Area, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.annealing.problem_v1:Problem", Problem, jl_convert)
Expand All @@ -386,11 +402,12 @@ function __init__()
PythonCall.pyconvert_add_rule("braket.device_schema.openqasm_device_action_properties:OpenQASMDeviceActionProperties", OpenQASMDeviceActionProperties, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.jaqcd.program_v1:Program", Program, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.jaqcd.program_v1:Program", AbstractProgram, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.openqasm.program_v1:Program", OpenQasmProgram, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.jaqcd.program_v1:Program", AbstractProgram, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.openqasm.program_v1:OpenQasmProgram", AbstractProgram, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.openqasm.program_v1:Program", AbstractProgram, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.blackbird.program_v1:Program", BlackbirdProgram, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.blackbird.program_v1:BlackbirdProgram", AbstractProgram, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.blackbird.program_v1:Program", AbstractProgram, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.ahs.program_v1:Program", AHSProgram, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.ahs.program_v1:AHSProgram", AbstractProgram, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.ahs.program_v1:Program", AbstractProgram, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.jaqcd.results:Amplitude", AbstractProgramResult, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.jaqcd.results:Expectation", AbstractProgramResult, jl_convert)
Expand All @@ -399,6 +416,7 @@ function __init__()
PythonCall.pyconvert_add_rule("braket.ir.jaqcd.results:StateVector", AbstractProgramResult, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.jaqcd.results:DensityMatrix", AbstractProgramResult, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.jaqcd.results:Variance", AbstractProgramResult, jl_convert)
PythonCall.pyconvert_add_rule("braket.ir.jaqcd.results:AdjointGradient", AbstractProgramResult, jl_convert)

end

Expand Down
2 changes: 1 addition & 1 deletion PyBraket/test/CondaPkg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ scipy = ""
numpy = ""

[pip.deps]
amazon-braket-sdk = ">=1.33.0.dev0"
amazon-braket-sdk = ">=1.35.0"
1 change: 1 addition & 0 deletions PyBraket/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PyBraket = "e85266a6-1825-490b-a80e-9b9469c53660"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand Down
Loading

0 comments on commit 174cb5d

Please sign in to comment.