Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Checkpointing Large Runs #239

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions bqskit/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -227,6 +227,7 @@
from bqskit.passes.group import PassGroup
from bqskit.passes.io.checkpoint import LoadCheckpointPass
from bqskit.passes.io.checkpoint import SaveCheckpointPass
from bqskit.passes.io.intermediate import CheckpointRestartPass
from bqskit.passes.io.intermediate import RestoreIntermediatePass
from bqskit.passes.io.intermediate import SaveIntermediatePass
from bqskit.passes.mapping.apply import ApplyPlacement
@@ -344,6 +345,7 @@
'ParallelDo',
'LoadCheckpointPass',
'SaveCheckpointPass',
'CheckpointRestartPass',
'SaveIntermediatePass',
'RestoreIntermediatePass',
'GroupSingleQuditGatePass',
30 changes: 26 additions & 4 deletions bqskit/passes/control/foreach.py
Original file line number Diff line number Diff line change
@@ -63,6 +63,7 @@ def __init__(
collection_filter: Callable[[Operation], bool] | None = None,
replace_filter: ReplaceFilterFn | str = 'always',
batch_size: int | None = None,
blocks_to_run: list[int] = [],
) -> None:
"""
Construct a ForEachBlockPass.
@@ -127,6 +128,11 @@ def __init__(
Defaults to 'always'. #TODO: address importability
batch_size (int): (Deprecated).
blocks_to_run (List[int]):
A list of blocks to run the ForEachBlockPass body on. By default
you run on all blocks. This is mainly used with checkpointing,
where some blocks have already finished while others have not.
"""
if batch_size is not None:
import warnings
@@ -140,7 +146,7 @@ def __init__(
self.collection_filter = collection_filter or default_collection_filter
self.replace_filter = replace_filter or default_replace_filter
self.workflow = Workflow(loop_body)

self.blocks_to_run = sorted(blocks_to_run)
if not callable(self.collection_filter):
raise TypeError(
'Expected callable method that maps Operations to booleans for'
@@ -171,9 +177,20 @@ async def run(self, circuit: Circuit, data: PassData) -> None:

# Collect blocks
blocks: list[tuple[int, Operation]] = []
for cycle, op in circuit.operations_with_cycles():
if self.collection_filter(op):
blocks.append((cycle, op))
if (len(self.blocks_to_run) == 0):
self.blocks_to_run = list(range(circuit.num_operations))

block_ids = self.blocks_to_run.copy()
next_id = block_ids.pop(0)
for i, (cycle, op) in enumerate(circuit.operations_with_cycles()):
if i == next_id:
if self.collection_filter(op):
blocks.append((cycle, op))
if len(block_ids) > 0:
next_id = block_ids.pop(0)
else:
# No more blocks to run on
break

# No blocks, no work
if len(blocks) == 0:
@@ -212,6 +229,11 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
block_data['model'] = submodel
block_data['point'] = CircuitPoint(cycle, op.location[0])
block_data['calculate_error_bound'] = self.calculate_error_bound
# Need to zero pad block ids for consistency
num_digits = len(str(circuit.num_operations))
block_data['block_num'] = str(
self.blocks_to_run[i],
).zfill(num_digits)
for key in data:
if key.startswith(self.pass_down_key_prefix):
block_data[key] = data[key]
2 changes: 2 additions & 0 deletions bqskit/passes/io/__init__.py
Original file line number Diff line number Diff line change
@@ -3,10 +3,12 @@

from bqskit.passes.io.checkpoint import LoadCheckpointPass
from bqskit.passes.io.checkpoint import SaveCheckpointPass
from bqskit.passes.io.intermediate import CheckpointRestartPass
from bqskit.passes.io.intermediate import RestoreIntermediatePass
from bqskit.passes.io.intermediate import SaveIntermediatePass

__all__ = [
'CheckpointRestartPass',
'LoadCheckpointPass',
'SaveCheckpointPass',
'SaveIntermediatePass',
170 changes: 124 additions & 46 deletions bqskit/passes/io/intermediate.py
Original file line number Diff line number Diff line change
@@ -3,18 +3,24 @@

import logging
import pickle
import shutil
from os import listdir
from os import mkdir
from os.path import exists
from os.path import join
from re import findall
from typing import cast
from typing import Sequence

from bqskit.compiler.basepass import BasePass
from bqskit.compiler.passdata import PassData
from bqskit.ir.circuit import Circuit
from bqskit.ir.gates.circuitgate import CircuitGate
from bqskit.ir.lang.qasm2.qasm2 import OPENQASM2Language
from bqskit.ir.operation import Operation
from bqskit.passes.alias import PassAlias
from bqskit.passes.util.converttou3 import ToU3Pass
from bqskit.utils.typing import is_sequence

_logger = logging.getLogger(__name__)

@@ -32,6 +38,7 @@ def __init__(
path_to_save_dir: str,
project_name: str | None = None,
save_as_qasm: bool = True,
overwrite: bool = False,
) -> None:
"""
Constructor for the SaveIntermediatePass.
@@ -57,15 +64,18 @@ def __init__(
else 'unnamed_project'

enum = 1
if exists(self.pathdir + self.projname):
while exists(self.pathdir + self.projname + f'_{enum}'):
enum += 1
self.projname += f'_{enum}'
_logger.warning(
f'Path {path_to_save_dir} already exists, '
f'saving to {self.pathdir + self.projname} '
'instead.',
)
if exists(join(self.pathdir, self.projname)):
if overwrite:
shutil.rmtree(join(self.pathdir, self.projname))
else:
while exists(join(self.pathdir, self.projname + f'_{enum}')):
enum += 1
self.projname += f'_{enum}'
_logger.warning(
f'Path {path_to_save_dir} already exists, '
f'saving to {self.pathdir + self.projname} '
'instead.',
)

mkdir(self.pathdir + self.projname)

@@ -102,8 +112,8 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
block.params,
)
subcircuit.unfold((0, 0))
await ToU3Pass().run(subcircuit, PassData(subcircuit))
if self.as_qasm:
await ToU3Pass().run(subcircuit, PassData(subcircuit))
with open(block_skeleton + f'{enum}.qasm', 'w') as f:
f.write(OPENQASM2Language().encode(subcircuit))
else:
@@ -117,7 +127,10 @@ async def run(self, circuit: Circuit, data: PassData) -> None:


class RestoreIntermediatePass(BasePass):
def __init__(self, project_directory: str, load_blocks: bool = True):
def __init__(
self, project_directory: str, load_blocks: bool = True,
as_circuit_gate: bool = False,
):
"""
Constructor for the RestoreIntermediatePass.
@@ -130,30 +143,20 @@ def __init__(self, project_directory: str, load_blocks: bool = True):
the user must explicitly call load_blocks() themselves. Defaults
to True.
as_circuit_gate (bool): If True, blocks are reloaded as a circuit
gate rather than a circuit.
Raises:
ValueError: If `project_directory` does not exist or if
`structure.pickle` is invalid.
"""
self.proj_dir = project_directory
if not exists(self.proj_dir):
raise TypeError(
f"Project directory '{self.proj_dir}' does not exist.",
)
if not exists(self.proj_dir + '/structure.pickle'):
raise TypeError(
f'Project directory `{self.proj_dir}` does not '
'contain `structure.pickle`.',
)

with open(self.proj_dir + '/structure.pickle', 'rb') as f:
self.structure = pickle.load(f)

if not isinstance(self.structure, list):
raise TypeError('The provided `structure.pickle` is not a list.')

self.block_list: list[str] = []
if load_blocks:
self.reload_blocks()
self.as_circuit_gate = as_circuit_gate
# We will detect automatically if blocks are saved as qasm or pickle
self.saved_as_qasm = False

self.load_blocks = load_blocks

def reload_blocks(self) -> None:
"""
@@ -164,11 +167,18 @@ def reload_blocks(self) -> None:
ValueError: if there are more block files than indices in the
`structure.pickle`.
"""
files = listdir(self.proj_dir)
files = sorted(listdir(self.proj_dir))
# Files are of the form block_*.pickle or block_*.qasm
self.block_list = [f for f in files if 'block_' in f]
pickle_list = [f for f in self.block_list if '.pickle' in f]
if len(pickle_list) == 0:
self.saved_as_qasm = True
self.block_list = [f for f in self.block_list if '.qasm' in f]
else:
self.block_list = pickle_list
if len(self.block_list) > len(self.structure):
raise ValueError(
'More block files than indicies in `structure.pickle`',
'More block files than indices in `structure.pickle`',
)

async def run(self, circuit: Circuit, data: PassData) -> None:
@@ -179,21 +189,89 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
ValueError: if a block file and the corresponding index in
`structure.pickle` are different lengths.
"""
# If the circuit is empty, just append blocks in order
if circuit.depth == 0:
for block in self.block_list:
# Get block
block_num = int(findall(r'\d+', block)[0])
with open(self.proj_dir + '/' + block) as f:

if not exists(self.proj_dir):
raise TypeError(
f"Project directory '{self.proj_dir}' does not exist.",
)
if not exists(self.proj_dir + '/structure.pickle'):
raise TypeError(
f'Project directory `{self.proj_dir}` does not '
'contain `structure.pickle`.',
)

with open(self.proj_dir + '/structure.pickle', 'rb') as f:
self.structure = pickle.load(f)

if not isinstance(self.structure, list):
raise TypeError('The provided `structure.pickle` is not a list.')

if self.load_blocks:
self.reload_blocks()

# Get circuit from checkpoint, ignore previous circuit
new_circuit = Circuit(circuit.num_qudits, circuit.radixes)
for block in self.block_list:
# Get block
block_num = int(findall(r'\d+', block)[0])
if self.saved_as_qasm:
with open(join(self.proj_dir, block)) as f:
block_circ = OPENQASM2Language().decode(f.read())
# Get location
block_location = self.structure[block_num]
if block_circ.num_qudits != len(block_location):
raise ValueError(
f'{block} and `structure.pickle` locations are '
'different sizes.',
)
# Append to circuit
circuit.append_circuit(block_circ, block_location)
else:
with open(join(self.proj_dir, block), 'rb') as f:
block_circ = pickle.load(f)
# Get location
block_location = self.structure[block_num]
if block_circ.num_qudits != len(block_location):
raise ValueError(
f'{block} and `structure.pickle` locations are '
'different sizes.',
)
# Append to circuit
new_circuit.append_circuit(
block_circ, block_location,
as_circuit_gate=self.as_circuit_gate,
)

circuit.become(new_circuit)
# Check if the circuit has been partitioned, if so, try to replace
# blocks


class CheckpointRestartPass(PassAlias):
    """
    Alias that either restores a saved checkpoint or runs and then saves one.

    The decision is made once, when the pass is constructed, by checking
    whether `structure.pickle` exists inside the project's checkpoint
    directory.
    """

    def __init__(
        self, base_checkpoint_dir: str,
        project_name: str,
        default_passes: BasePass | Sequence[BasePass],
        save_as_qasm: bool = True,
    ) -> None:
        """
        Construct a CheckpointRestartPass.

        Args:
            base_checkpoint_dir (str): Directory that contains (or will
                contain) the project's checkpoint directory.

            project_name (str): Name of the project; joined onto
                `base_checkpoint_dir` to locate the checkpoint.

            default_passes (BasePass | Sequence[BasePass]): The pass or
                passes to run when no checkpoint is found. The given
                sequence is copied and never modified.

            save_as_qasm (bool): Forwarded to the `SaveIntermediatePass`
                that is appended when no checkpoint is found.
                (Default: True)
        """
        if not is_sequence(default_passes):
            default_passes = [cast(BasePass, default_passes)]

        # Always copy, even when a list was supplied: the save pass is
        # appended below and must not leak into the caller's own list
        # (which may be reused to build other workflows).
        passes: list[BasePass] = list(
            cast(Sequence[BasePass], default_passes),
        )

        full_checkpoint_dir = join(base_checkpoint_dir, project_name)

        # Check if checkpoint files exist
        if not exists(join(full_checkpoint_dir, 'structure.pickle')):
            _logger.info('Checkpoint does not exist!')
            save_pass = SaveIntermediatePass(
                base_checkpoint_dir, project_name,
                save_as_qasm=save_as_qasm, overwrite=True,
            )
            passes.append(save_pass)
            self.passes = passes
        else:
            # Already checkpointed, restore instead of running the
            # default passes.
            _logger.info('Restoring from Checkpoint!')
            self.passes = [
                RestoreIntermediatePass(
                    full_checkpoint_dir, as_circuit_gate=True,
                ),
            ]

    def get_passes(self) -> list[BasePass]:
        """Return the passes to be run, see :class:`PassAlias` for more."""
        return self.passes
58 changes: 57 additions & 1 deletion bqskit/passes/processing/scan.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,10 @@
from __future__ import annotations

import logging
import pickle
from os import mkdir
from os.path import exists
from os.path import join
from typing import Any
from typing import Callable

@@ -12,6 +16,7 @@
from bqskit.ir.opt.cost.functions import HilbertSchmidtResidualsGenerator
from bqskit.ir.opt.cost.generator import CostFunctionGenerator
from bqskit.utils.typing import is_real_number

_logger = logging.getLogger(__name__)


@@ -29,6 +34,7 @@ def __init__(
cost: CostFunctionGenerator = HilbertSchmidtResidualsGenerator(),
instantiate_options: dict[str, Any] = {},
collection_filter: Callable[[Operation], bool] | None = None,
checkpoint_proj: str | None = None,
) -> None:
"""
Construct a ScanningGateRemovalPass.
@@ -94,6 +100,9 @@ def __init__(
'cost_fn_gen': self.cost,
}
self.instantiate_options.update(instantiate_options)
self.checkpoint_proj = checkpoint_proj
if (self.checkpoint_proj and not exists(self.checkpoint_proj)):
mkdir(self.checkpoint_proj)

async def run(self, circuit: Circuit, data: PassData) -> None:
"""Perform the pass's operation, see :class:`BasePass` for more."""
@@ -108,7 +117,45 @@ async def run(self, circuit: Circuit, data: PassData) -> None:

circuit_copy = circuit.copy()
reverse_iter = not self.start_from_left
for cycle, op in circuit.operations_with_cycles(reverse=reverse_iter):

start_ind = 0
iterator = circuit.operations_with_cycles(reverse=reverse_iter)
all_ops = [x for x in iterator]

# Things needed for saving data
if self.checkpoint_proj:
block_num: str = data.get('block_num', '0')
save_data_file = join(
self.checkpoint_proj,
f'block_{block_num}.data',
)
save_circuit_file = join(
self.checkpoint_proj, f'block_{block_num}.pickle',
)
if exists(save_data_file):
_logger.debug(f'Reloading block {block_num}!')
# Reload ind from previous stop
with open(save_data_file, 'rb') as df:
new_data = pickle.load(df)
data.update(new_data)
with open(save_circuit_file, 'rb') as cf:
circuit_copy = pickle.load(cf)
start_ind = data.get('ind', 0)
if start_ind >= len(all_ops):
all_ops = []
_logger.debug('Block is already finished!')
else:
all_ops = all_ops[start_ind:]
_logger.debug('starting at ', start_ind)
else:
# Initial checkpoint
with open(save_data_file, 'wb') as df:
data['ind'] = 0
pickle.dump(data, df)
with open(save_circuit_file, 'wb') as cf:
pickle.dump(circuit_copy, cf)

for i, (cycle, op) in enumerate(all_ops):

if not self.collection_filter(op):
_logger.debug(f'Skipping operation {op} at cycle {cycle}.')
@@ -131,6 +178,15 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
if self.cost(working_copy, target) < self.success_threshold:
_logger.debug('Successfully removed operation.')
circuit_copy = working_copy
# Create checkpoint
if self.checkpoint_proj:
with open(save_circuit_file, 'wb') as cf:
pickle.dump(circuit_copy, cf)

if self.checkpoint_proj:
with open(save_data_file, 'wb') as df:
data['ind'] = i + start_ind + 1
pickle.dump(data, df)

circuit.become(circuit_copy)