diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index 778521d00..b306004dd 100644 --- a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -14,6 +14,7 @@ from subprocess import Popen from types import FrameType from typing import Literal +from typing import MutableMapping from typing import overload from typing import TYPE_CHECKING @@ -233,7 +234,7 @@ def submit( request_data: bool = False, logging_level: int | None = None, max_logging_depth: int = -1, - data: dict[str, Any] | None = None, + data: MutableMapping[str, Any] | None = None, ) -> uuid.UUID: """ Submit a compilation job to the Compiler. @@ -271,7 +272,7 @@ def submit( task.logging_level = logging_level or self._discover_lowest_log_level() task.max_logging_depth = max_logging_depth if data is not None: - task.data = data + task.data.update(data) # Submit task to runtime self._send(RuntimeMessage.SUBMIT, task) @@ -306,7 +307,7 @@ def compile( request_data: Literal[False] = ..., logging_level: int | None = ..., max_logging_depth: int = ..., - data: dict[str, Any] | None = ..., + data: MutableMapping[str, Any] | None = ..., ) -> Circuit: ... @@ -318,7 +319,7 @@ def compile( request_data: Literal[True], logging_level: int | None = ..., max_logging_depth: int = ..., - data: dict[str, Any] | None = ..., + data: MutableMapping[str, Any] | None = ..., ) -> tuple[Circuit, PassData]: ... @@ -330,7 +331,7 @@ def compile( request_data: bool, logging_level: int | None = ..., max_logging_depth: int = ..., - data: dict[str, Any] | None = ..., + data: MutableMapping[str, Any] | None = ..., ) -> Circuit | tuple[Circuit, PassData]: ... @@ -341,7 +342,7 @@ def compile( request_data: bool = False, logging_level: int | None = None, max_logging_depth: int = -1, - data: dict[str, Any] | None = None, + data: MutableMapping[str, Any] | None = None, ) -> Circuit | tuple[Circuit, PassData]: """Submit a task, wait for its results; see :func:`submit` for more.""" task_id = self.submit( diff --git a/bqskit/compiler/workflow.py b/bqskit/compiler/workflow.py index 65399ff87..9a0ad2677 100644 --- a/bqskit/compiler/workflow.py +++ b/bqskit/compiler/workflow.py @@ -3,7 +3,6 @@ import copy import logging -import dill from typing import Iterable from typing import Iterator from typing import overload @@ -11,6 +10,8 @@ from typing import TYPE_CHECKING from typing import Union +import dill + from bqskit.compiler.basepass import BasePass from bqskit.utils.random import seed_random_sources from bqskit.utils.typing import is_iterable @@ -40,7 +41,7 @@ def __init__(self, passes: WorkflowLike, name: str = '') -> None: """ if isinstance(passes, Workflow): self._passes: list[BasePass] = copy.deepcopy(passes._passes) - self._name = copy.deepcopy(passes._name) if name == '' else name + self._name: str = name if name else copy.deepcopy(passes._name) return if isinstance(passes, BasePass): diff --git a/bqskit/ir/circuit.py b/bqskit/ir/circuit.py index 2fd69f846..1c628125c 100644 --- a/bqskit/ir/circuit.py +++ b/bqskit/ir/circuit.py @@ -3,10 +3,10 @@ import copy import logging -import warnings import pickle -import dill +import warnings from typing import Any +from typing import Callable from typing import cast from typing import Collection from typing import Dict @@ -20,6 +20,7 @@ from typing import Tuple from typing import TYPE_CHECKING +import dill import numpy as np import numpy.typing as npt @@ -3243,9 +3244,15 @@ def from_operation(op: Operation) -> Circuit: circuit.append_gate(op.gate, list(range(circuit.num_qudits)), op.params) return circuit - def __reduce__(self): + def __reduce__(self) -> tuple[ + Callable[ + [int, tuple[int, ...], list[tuple[bool, bytes]], bytes], + Circuit, + ], + tuple[int, tuple[int, ...], list[tuple[bool, bytes]], bytes], + ]: """Return the pickle state of the circuit.""" - serialized_gates = [] + serialized_gates: list[tuple[bool, bytes]] = [] gate_table = {} for gate in self.gate_set: gate_table[gate] = len(serialized_gates) @@ -3254,7 +3261,7 @@ def __reduce__(self): else: serialized_gates.append((True, dill.dumps(gate, recurse=True))) - cycles = [] + cycles: list[list[tuple[int, tuple[int, ...], list[float]]]] = [] last_cycle = -1 for cycle, op in self.operations_with_cycles(): @@ -3265,7 +3272,7 @@ def __reduce__(self): marshalled_op = ( gate_table[op.gate], op.location._location, - op.params + op.params, ) cycles[-1].append(marshalled_op) @@ -3280,7 +3287,12 @@ def __reduce__(self): # endregion -def rebuild_circuit(num_qudits, radixes, serialized_gates, serialized_cycles) -> Circuit: +def rebuild_circuit( + num_qudits: int, + radixes: tuple[int, ...], + serialized_gates: list[tuple[bool, bytes]], + serialized_cycles: bytes, +) -> Circuit: """Rebuild a circuit from a pickle state.""" circuit = Circuit(num_qudits, radixes) diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 3a7072273..f4897b5ec 100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -556,10 +556,13 @@ def schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: """Schedule tasks between this node's employees.""" if len(tasks) == 0: return - assignments = self.assign_tasks(tasks) - - # for e, assignment in sorted(zip(self.employees, assignments), key=lambda x: x[0].num_idle_workers, reverse=True): - for e, assignment in zip(self.employees, assignments): + assignments = zip(self.employees, self.assign_tasks(tasks)) + sorted_assignments = sorted( + assignments, + key=lambda x: x[0].num_idle_workers, + reverse=True, + ) + for e, assignment in sorted_assignments: num_tasks = len(assignment) if num_tasks == 0: diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 02d2bf25b..1c0a9fdf6 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -30,6 +30,7 @@ from bqskit.runtime.result import RuntimeResult from bqskit.runtime.task import RuntimeTask + def listen(server: DetachedServer, port: int) -> None: """Listening thread listens for client connections.""" listener = Listener(('0.0.0.0', port)) @@ -137,8 +138,9 @@ def handle_message( if path not in sys.path: sys.path.append(path) for employee in self.employees: - employee.conn.send((RuntimeMessage.IMPORTPATH, path)) - + employee.conn.send( + (RuntimeMessage.IMPORTPATH, path), + ) elif msg == RuntimeMessage.DISCONNECT: self.handle_disconnect(conn) diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index f36d540dc..ccffa3b9a 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -3,10 +3,11 @@ import inspect import logging -import dill from typing import Any from typing import Coroutine +import dill + from bqskit.runtime.address import RuntimeAddress @@ -41,7 +42,7 @@ def __init__( self.task_id = RuntimeTask.task_counter self.serialized_fnargs = dill.dumps(fnargs) - self._fnargs = None + self._fnargs: tuple[Any, Any, Any] | None = None self._name = fnargs[0].__name__ """Tuple of function pointer, arguments, and keyword arguments.""" @@ -84,6 +85,7 @@ def fnargs(self) -> tuple[Any, Any, Any]: """Return the function pointer, arguments, and keyword arguments.""" if self._fnargs is None: self._fnargs = dill.loads(self.serialized_fnargs) + assert self._fnargs is not None # for type checker return self._fnargs def step(self, send_val: Any = None) -> Any: diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 575d10864..7c4c8434d 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -358,8 +358,11 @@ def _handle_result(self, result: RuntimeResult) -> None: task = self._tasks[box.dest_addr] if task.wake_on_next or box.ready: + # print(f'Worker {self._id} is waking task + # {task.return_address}, with {task.wake_on_next=}, + # {box.ready=}') self._ready_task_ids.put(box.dest_addr) # Wake it - box.dest_addr = None # Prevent double wake + box.dest_addr = None # Prevent double wake def _handle_cancel(self, addr: RuntimeAddress) -> None: """ @@ -394,7 +397,7 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None: if not t.is_descendant_of(addr) ] - def _get_next_ready_task(self) -> RuntimeTask: + def _get_next_ready_task(self) -> RuntimeTask | None: """Return the next ready task if one exists, otherwise block.""" while True: if self._ready_task_ids.empty() and len(self._delayed_tasks) > 0: @@ -487,6 +490,8 @@ def _process_await(self, task: RuntimeTask, future: RuntimeFuture) -> None: # # raise RuntimeError(m) # task.wake_on_next = True task.wake_on_next = future._next_flag + # print(f'Worker {self._id} is waiting on task + # {task.return_address}, with {task.wake_on_next=}') if box.ready: self._ready_task_ids.put(task.return_address) @@ -497,7 +502,8 @@ def _process_task_completion(self, task: RuntimeTask, result: Any) -> None: packaged_result = RuntimeResult(task.return_address, result, self._id) if task.return_address not in self._tasks: - print(f'Task was cancelled: {task.return_address}, {task.fnargs[0].__name__}') + # print(f'Task was cancelled: {task.return_address}, + # {task.fnargs[0].__name__}') return if task.return_address.worker_id == self._id: diff --git a/setup.py b/setup.py index 298874919..a546d6267 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ 'numpy>=1.22.0', 'scipy>=1.8.0', 'typing-extensions>=4.0.0', - 'dill>=0.3.8' + 'dill>=0.3.8', ], python_requires='>=3.8, <4', entry_points={