Skip to content

Commit

Permalink
Use a ProcessPoolExecutor to talk to Julia (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws authored Aug 21, 2024
1 parent 1d6791f commit 1ac470f
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 115 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,28 @@ on:
jobs:
build:
runs-on: ${{ matrix.os }}
timeout-minutes: 10
timeout-minutes: 30
strategy:
# don't run all the jobs at once to avoid wasting CI
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v4
- name: Install juliaup
uses: julia-actions/install-juliaup@v2
with:
channel: '1'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install tox
pip install tox juliacall
pip install -e . # to put juliapkg.json in sys.path
python -c 'import juliacall' # force install of all deps
- name: Run Tests
run: |
tox -e unit-tests
Expand Down
2 changes: 1 addition & 1 deletion src/braket/juliapkg.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"packages": {
"BraketSimulator": {
"uuid": "76d27892-9a0b-406c-98e4-7c178e9b3dff",
"version": "0.0.2"
"rev": "ksh/precompile"
},
"JSON3": {
"uuid": "0f8b85d8-7281-11e9-16c2-39a750bddbf1",
Expand Down
15 changes: 1 addition & 14 deletions src/braket/simulator_v2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
from braket.ir.openqasm import Program as OpenQASMProgram

from braket.simulator_v2.density_matrix_simulator_v2 import ( # noqa: F401
DensityMatrixSimulatorV2,
)
from braket.simulator_v2.julia_import import jl, jlBraketSimulator # noqa: F401
from braket.simulator_v2.julia_import import setup_julia # noqa: F401
from braket.simulator_v2.state_vector_simulator_v2 import ( # noqa: F401
StateVectorSimulatorV2,
)

from ._version import __version__ # noqa: F401

payload = OpenQASMProgram(
source="""
OPENQASM 3.0;
qubit[1] q;
h q[0];
#pragma braket result state_vector
"""
)
StateVectorSimulatorV2().run_openqasm(payload)
StateVectorSimulatorV2().run_multiple([payload, payload])
180 changes: 123 additions & 57 deletions src/braket/simulator_v2/base_simulator_v2.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,121 @@
import sys
from collections.abc import Sequence
from typing import Optional, Union
from concurrent.futures import ProcessPoolExecutor, wait
from typing import List, Optional, Union

import numpy as np
from braket.default_simulator.simulator import BaseLocalSimulator
from braket.ir.jaqcd import DensityMatrix, Probability, StateVector
from braket.ir.openqasm import Program as OpenQASMProgram
from braket.task_result import GateModelTaskResult
from juliacall import JuliaError

from braket.simulator_v2.julia_import import jl
from braket.simulator_v2.julia_import import setup_julia


def _handle_julia_error(error):
# we don't import `JuliaError` explicitly here to avoid
# having to import juliacall on the main thread. we need
# to call *this* function on that thread in case getting
# the result from the submitted Future raises an exception
if type(error).__name__ == "JuliaError":
python_exception = getattr(error.exception, "alternate_type", None)
if python_exception is None:
py_error = error
else:
class_val = getattr(sys.modules["builtins"], str(python_exception))
py_error = class_val(str(error.exception.message))
raise py_error
else:
raise error


def translate_and_run(
device_id: str, openqasm_ir: OpenQASMProgram, shots: int = 0
) -> str:
jl = setup_julia()
jl_shots = shots
jl_inputs = (
jl.Dict[jl.String, jl.Any](
jl.Pair(jl.convert(jl.String, k), jl.convert(jl.Any, v))
for (k, v) in openqasm_ir.inputs.items()
)
if openqasm_ir.inputs
else jl.Dict[jl.String, jl.Any]()
)
if device_id == "braket_sv_v2":
device = jl.BraketSimulator.StateVectorSimulator(0, 0)
elif device_id == "braket_dm_v2":
device = jl.BraketSimulator.DensityMatrixSimulator(0, 0)

try:
result = jl.BraketSimulator.simulate._jl_call_nogil(
device,
openqasm_ir.source,
jl_inputs,
jl_shots,
)
py_result = str(result)
return py_result
except Exception as e:
_handle_julia_error(e)


def translate_and_run_multiple(
device_id: str,
programs: Sequence[OpenQASMProgram],
shots: Optional[int] = 0,
inputs: Optional[Union[dict, Sequence[dict]]] = {},
) -> List[str]:
jl = setup_julia()
irs = jl.Vector[jl.String]()
is_single_input = isinstance(inputs, dict) or len(inputs) == 1
py_inputs = {}
if (is_single_input and isinstance(inputs, dict)) or not is_single_input:
py_inputs = [inputs.copy() for _ in range(len(programs))]
elif is_single_input and not isinstance(inputs, dict):
py_inputs = [inputs[0].copy() for _ in range(len(programs))]
else:
py_inputs = inputs
jl_inputs = jl.Vector[jl.Dict[jl.String, jl.Any]]()
for p_ix, program in enumerate(programs):
irs.append(program.source)
if program.inputs:
jl_inputs.append(program.inputs | py_inputs[p_ix])
else:
jl_inputs.append(py_inputs[p_ix])

if device_id == "braket_sv_v2":
device = jl.BraketSimulator.StateVectorSimulator(0, 0)
elif device_id == "braket_dm_v2":
device = jl.BraketSimulator.DensityMatrixSimulator(0, 0)

try:
results = jl.BraketSimulator.simulate._jl_call_nogil(
device,
irs,
jl_inputs,
shots,
)
py_results = [str(result) for result in results]
except Exception as e:
_handle_julia_error(e)
return py_results


class BaseLocalSimulatorV2(BaseLocalSimulator):
def __init__(self, device):
def __init__(self, device: str):
self._device = device
executor = ProcessPoolExecutor(max_workers=1, initializer=setup_julia)
def no_op():
pass
# trigger worker creation here, because workers are created
# on an as-needed basis, *not* when the executor is created
f = executor.submit(no_op)
wait([f])
self._executor = executor

def __del__(self):
self._executor.shutdown(wait=False)

def initialize_simulation(self, **kwargs):
return
Expand All @@ -40,20 +141,15 @@ def run_openqasm(
as a result type when shots=0. Or, if StateVector and Amplitude result types
are requested when shots>0.
"""
f = self._executor.submit(
translate_and_run,
self._device,
openqasm_ir,
shots,
)
try:
jl_shots = shots
jl_inputs = (
jl.Dict[jl.String, jl.Any](
jl.Pair(jl.convert(jl.String, k), jl.convert(jl.Any, v))
for (k, v) in openqasm_ir.inputs.items()
)
if openqasm_ir.inputs
else jl.Dict[jl.String, jl.Any]()
)
jl_result = jl.BraketSimulator.simulate._jl_call_nogil(
self._device, openqasm_ir.source, jl_inputs, jl_shots
)
except JuliaError as e:
jl_result = f.result()
except Exception as e:
_handle_julia_error(e)

result = GateModelTaskResult.parse_raw_schema(jl_result)
Expand Down Expand Up @@ -85,34 +181,18 @@ def run_multiple(
list[GateModelTaskResult]: A list of result objects, with the ith object being
the result of the ith program.
"""
f = self._executor.submit(
translate_and_run_multiple,
self._device,
programs,
shots,
inputs,
)
try:
irs = jl.Vector[jl.String]()
is_single_input = isinstance(inputs, dict) or len(inputs) == 1
py_inputs = {}
if (is_single_input and isinstance(inputs, dict)) or not is_single_input:
py_inputs = [inputs.copy() for _ in range(len(programs))]
elif is_single_input and not isinstance(inputs, dict):
py_inputs = [inputs[0].copy() for _ in range(len(programs))]
else:
py_inputs = inputs
jl_inputs = jl.Vector[jl.Dict[jl.String, jl.Any]]()
for p_ix, program in enumerate(programs):
irs.append(program.source)
if program.inputs:
jl_inputs.append(program.inputs | py_inputs[p_ix])
else:
jl_inputs.append(py_inputs[p_ix])

jl_results = jl.BraketSimulator.simulate._jl_call_nogil(
self._device,
irs,
jl_inputs,
shots,
max_parallel=jl.convert(jl.Int, max_parallel),
)

except JuliaError as e:
jl_results = f.result()
except Exception as e:
_handle_julia_error(e)

results = [
GateModelTaskResult.parse_raw_schema(jl_result) for jl_result in jl_results
]
Expand Down Expand Up @@ -166,17 +246,3 @@ def reconstruct_complex(v):
task_result.resultTypes[result_ind].value = np.asarray(val)

return task_result


def _handle_julia_error(julia_error: JuliaError):
try:
print(julia_error)
python_exception = getattr(julia_error.exception, "alternate_type", None)
if python_exception is None:
error = julia_error
else:
class_val = getattr(sys.modules["builtins"], str(python_exception))
error = class_val(julia_error.exception.message)
except Exception:
raise julia_error
raise error
3 changes: 1 addition & 2 deletions src/braket/simulator_v2/density_matrix_simulator_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
)

from braket.simulator_v2.base_simulator_v2 import BaseLocalSimulatorV2
from braket.simulator_v2.julia_import import jlBraketSimulator


class DensityMatrixSimulatorV2(BaseLocalSimulatorV2):
Expand All @@ -19,7 +18,7 @@ class DensityMatrixSimulatorV2(BaseLocalSimulatorV2):
DEVICE_ID = "braket_dm_v2"

def __init__(self):
super().__init__(jlBraketSimulator.DensityMatrixSimulator(0, 0))
super().__init__(self.DEVICE_ID)

@property
def properties(self) -> GateModelSimulatorDeviceCapabilities:
Expand Down
Loading

0 comments on commit 1ac470f

Please sign in to comment.