Skip to content

Commit

Permalink
pre-commit (ish)
Browse files Browse the repository at this point in the history
  • Loading branch information
edyounis committed Apr 2, 2024
1 parent 72c073d commit 6787100
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 27 deletions.
13 changes: 7 additions & 6 deletions bqskit/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from subprocess import Popen
from types import FrameType
from typing import Literal
from typing import MutableMapping
from typing import overload
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -233,7 +234,7 @@ def submit(
request_data: bool = False,
logging_level: int | None = None,
max_logging_depth: int = -1,
data: dict[str, Any] | None = None,
data: MutableMapping[str, Any] | None = None,
) -> uuid.UUID:
"""
Submit a compilation job to the Compiler.
Expand Down Expand Up @@ -271,7 +272,7 @@ def submit(
task.logging_level = logging_level or self._discover_lowest_log_level()
task.max_logging_depth = max_logging_depth
if data is not None:
task.data = data
task.data.update(data)

# Submit task to runtime
self._send(RuntimeMessage.SUBMIT, task)
Expand Down Expand Up @@ -306,7 +307,7 @@ def compile(
request_data: Literal[False] = ...,
logging_level: int | None = ...,
max_logging_depth: int = ...,
data: dict[str, Any] | None = ...,
data: MutableMapping[str, Any] | None = ...,
) -> Circuit:
...

Expand All @@ -318,7 +319,7 @@ def compile(
request_data: Literal[True],
logging_level: int | None = ...,
max_logging_depth: int = ...,
data: dict[str, Any] | None = ...,
data: MutableMapping[str, Any] | None = ...,
) -> tuple[Circuit, PassData]:
...

Expand All @@ -330,7 +331,7 @@ def compile(
request_data: bool,
logging_level: int | None = ...,
max_logging_depth: int = ...,
data: dict[str, Any] | None = ...,
data: MutableMapping[str, Any] | None = ...,
) -> Circuit | tuple[Circuit, PassData]:
...

Expand All @@ -341,7 +342,7 @@ def compile(
request_data: bool = False,
logging_level: int | None = None,
max_logging_depth: int = -1,
data: dict[str, Any] | None = None,
data: MutableMapping[str, Any] | None = None,
) -> Circuit | tuple[Circuit, PassData]:
"""Submit a task, wait for its results; see :func:`submit` for more."""
task_id = self.submit(
Expand Down
5 changes: 3 additions & 2 deletions bqskit/compiler/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

import copy
import logging
import dill
from typing import Iterable
from typing import Iterator
from typing import overload
from typing import Sequence
from typing import TYPE_CHECKING
from typing import Union

import dill

from bqskit.compiler.basepass import BasePass
from bqskit.utils.random import seed_random_sources
from bqskit.utils.typing import is_iterable
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(self, passes: WorkflowLike, name: str = '') -> None:
"""
if isinstance(passes, Workflow):
self._passes: list[BasePass] = copy.deepcopy(passes._passes)
self._name = copy.deepcopy(passes._name) if name == '' else name
self._name: str = name if name else copy.deepcopy(passes._name)
return

if isinstance(passes, BasePass):
Expand Down
26 changes: 19 additions & 7 deletions bqskit/ir/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import copy
import logging
import warnings
import pickle
import dill
import warnings
from typing import Any
from typing import Callable
from typing import cast
from typing import Collection
from typing import Dict
Expand All @@ -20,6 +20,7 @@
from typing import Tuple
from typing import TYPE_CHECKING

import dill
import numpy as np
import numpy.typing as npt

Expand Down Expand Up @@ -3243,9 +3244,15 @@ def from_operation(op: Operation) -> Circuit:
circuit.append_gate(op.gate, list(range(circuit.num_qudits)), op.params)
return circuit

def __reduce__(self):
def __reduce__(self) -> tuple[
Callable[
[int, tuple[int, ...], list[tuple[bool, bytes]], bytes],
Circuit,
],
tuple[int, tuple[int, ...], list[tuple[bool, bytes]], bytes],
]:
"""Return the pickle state of the circuit."""
serialized_gates = []
serialized_gates: list[tuple[bool, bytes]] = []
gate_table = {}
for gate in self.gate_set:
gate_table[gate] = len(serialized_gates)
Expand All @@ -3254,7 +3261,7 @@ def __reduce__(self):
else:
serialized_gates.append((True, dill.dumps(gate, recurse=True)))

cycles = []
cycles: list[list[tuple[int, tuple[int, ...], list[float]]]] = []
last_cycle = -1
for cycle, op in self.operations_with_cycles():

Expand All @@ -3265,7 +3272,7 @@ def __reduce__(self):
marshalled_op = (
gate_table[op.gate],
op.location._location,
op.params
op.params,
)
cycles[-1].append(marshalled_op)

Expand All @@ -3280,7 +3287,12 @@ def __reduce__(self):
# endregion


def rebuild_circuit(num_qudits, radixes, serialized_gates, serialized_cycles) -> Circuit:
def rebuild_circuit(
num_qudits: int,
radixes: tuple[int, ...],
serialized_gates: list[tuple[bool, bytes]],
serialized_cycles: bytes,
) -> Circuit:
"""Rebuild a circuit from a pickle state."""
circuit = Circuit(num_qudits, radixes)

Expand Down
11 changes: 7 additions & 4 deletions bqskit/runtime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,13 @@ def schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None:
"""Schedule tasks between this node's employees."""
if len(tasks) == 0:
return
assignments = self.assign_tasks(tasks)

# for e, assignment in sorted(zip(self.employees, assignments), key=lambda x: x[0].num_idle_workers, reverse=True):
for e, assignment in zip(self.employees, assignments):
assignments = zip(self.employees, self.assign_tasks(tasks))
sorted_assignments = sorted(
assignments,
key=lambda x: x[0].num_idle_workers,
reverse=True,
)
for e, assignment in sorted_assignments:
num_tasks = len(assignment)

if num_tasks == 0:
Expand Down
6 changes: 4 additions & 2 deletions bqskit/runtime/detached.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from bqskit.runtime.result import RuntimeResult
from bqskit.runtime.task import RuntimeTask


def listen(server: DetachedServer, port: int) -> None:
"""Listening thread listens for client connections."""
listener = Listener(('0.0.0.0', port))
Expand Down Expand Up @@ -137,8 +138,9 @@ def handle_message(
if path not in sys.path:
sys.path.append(path)
for employee in self.employees:
employee.conn.send((RuntimeMessage.IMPORTPATH, path))

employee.conn.send(
(RuntimeMessage.IMPORTPATH, path),
)

elif msg == RuntimeMessage.DISCONNECT:
self.handle_disconnect(conn)
Expand Down
6 changes: 4 additions & 2 deletions bqskit/runtime/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

import inspect
import logging
import dill
from typing import Any
from typing import Coroutine

import dill

from bqskit.runtime.address import RuntimeAddress


Expand Down Expand Up @@ -41,7 +42,7 @@ def __init__(
self.task_id = RuntimeTask.task_counter

self.serialized_fnargs = dill.dumps(fnargs)
self._fnargs = None
self._fnargs: tuple[Any, Any, Any] | None = None
self._name = fnargs[0].__name__
"""Tuple of function pointer, arguments, and keyword arguments."""

Expand Down Expand Up @@ -84,6 +85,7 @@ def fnargs(self) -> tuple[Any, Any, Any]:
"""Return the function pointer, arguments, and keyword arguments."""
if self._fnargs is None:
self._fnargs = dill.loads(self.serialized_fnargs)
assert self._fnargs is not None # for type checker
return self._fnargs

def step(self, send_val: Any = None) -> Any:
Expand Down
12 changes: 9 additions & 3 deletions bqskit/runtime/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,11 @@ def _handle_result(self, result: RuntimeResult) -> None:
task = self._tasks[box.dest_addr]

if task.wake_on_next or box.ready:
# print(f'Worker {self._id} is waking task
# {task.return_address}, with {task.wake_on_next=},
# {box.ready=}')
self._ready_task_ids.put(box.dest_addr) # Wake it
box.dest_addr = None # Prevent double wake
box.dest_addr = None # Prevent double wake

def _handle_cancel(self, addr: RuntimeAddress) -> None:
"""
Expand Down Expand Up @@ -394,7 +397,7 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None:
if not t.is_descendant_of(addr)
]

def _get_next_ready_task(self) -> RuntimeTask:
def _get_next_ready_task(self) -> RuntimeTask | None:
"""Return the next ready task if one exists, otherwise block."""
while True:
if self._ready_task_ids.empty() and len(self._delayed_tasks) > 0:
Expand Down Expand Up @@ -487,6 +490,8 @@ def _process_await(self, task: RuntimeTask, future: RuntimeFuture) -> None:
# # raise RuntimeError(m)
# task.wake_on_next = True
task.wake_on_next = future._next_flag
# print(f'Worker {self._id} is waiting on task
# {task.return_address}, with {task.wake_on_next=}')

if box.ready:
self._ready_task_ids.put(task.return_address)
Expand All @@ -497,7 +502,8 @@ def _process_task_completion(self, task: RuntimeTask, result: Any) -> None:
packaged_result = RuntimeResult(task.return_address, result, self._id)

if task.return_address not in self._tasks:
print(f'Task was cancelled: {task.return_address}, {task.fnargs[0].__name__}')
# print(f'Task was cancelled: {task.return_address},
# {task.fnargs[0].__name__}')
return

if task.return_address.worker_id == self._id:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
'numpy>=1.22.0',
'scipy>=1.8.0',
'typing-extensions>=4.0.0',
'dill>=0.3.8'
'dill>=0.3.8',
],
python_requires='>=3.8, <4',
entry_points={
Expand Down

0 comments on commit 6787100

Please sign in to comment.