Skip to content

Commit

Permalink
Fix single clbit display in condition for all 3 circuit drawers (#7285)
Browse files Browse the repository at this point in the history
* Get mpl and utils conditions working

* Almost on text drawer

* Finish text drawer

* Cleanup

* Finish latex and fix tests

* Working on measure with condition in utils

* Lint

* More lint

* Finish 7248 and 7284 plus measure issues

* Add mpl tests

* Fix mpl tests

* Add latex and text tests and bug fixes

* Lint and update image

* Reno and fix image

* Image

* Fix measure with registerless bit

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
enavarro51 and mergify[bot] authored Nov 26, 2021
1 parent 413ec5e commit c07c2cc
Show file tree
Hide file tree
Showing 17 changed files with 434 additions and 213 deletions.
200 changes: 104 additions & 96 deletions qiskit/visualization/latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
from qiskit.circuit.measure import Measure
from qiskit.visualization.qcstyle import load_style
from qiskit.circuit.tools.pi_check import pi_check
from .utils import get_gate_ctrl_text, get_param_str, get_bit_label, generate_latex_label
from .utils import (
get_gate_ctrl_text,
get_param_str,
get_bit_label,
generate_latex_label,
get_condition_label,
)


class QCircuitImage:
Expand Down Expand Up @@ -78,15 +84,12 @@ def __init__(
# image scaling
self.scale = 1.0 if scale is None else scale

# Map of qregs to sizes
self.qregs = {}

# Map of cregs to sizes
self.cregs = {}

# List of qubits and cbits in order of appearance in code and image
# May also include ClassicalRegisters if cregbundle=True
self.ordered_bits = []
self._ordered_bits = []

# Map from registers to the list they appear in the image
self.img_regs = {}
Expand Down Expand Up @@ -121,19 +124,19 @@ def __init__(
self.plot_barriers = plot_barriers

#################################
self.qubit_list = qubits
self.clbit_list = clbits
self.ordered_bits = qubits + clbits
self._qubits = qubits
self._clbits = clbits
self._ordered_bits = qubits + clbits
self.cregs = {reg: reg.size for reg in cregs}

self.bit_locations = {
self._bit_locations = {
bit: {"register": register, "index": index}
for register in cregs + qregs
for index, bit in enumerate(register)
}
for index, bit in list(enumerate(qubits)) + list(enumerate(clbits)):
if bit not in self.bit_locations:
self.bit_locations[bit] = {"register": None, "index": index}
if bit not in self._bit_locations:
self._bit_locations[bit] = {"register": None, "index": index}

self.cregbundle = cregbundle
# If there is any custom instruction that uses clasiscal bits
Expand All @@ -143,8 +146,8 @@ def __init__(
if node.op.name not in {"measure"} and node.cargs:
self.cregbundle = False

self.cregs_bits = [self.bit_locations[bit]["register"] for bit in clbits]
self.img_regs = {bit: ind for ind, bit in enumerate(self.ordered_bits)}
self.cregs_bits = [self._bit_locations[bit]["register"] for bit in clbits]
self.img_regs = {bit: ind for ind, bit in enumerate(self._ordered_bits)}

num_reg_bits = sum(reg.size for reg in self.cregs)
if self.cregbundle:
Expand Down Expand Up @@ -210,17 +213,17 @@ def _initialize_latex_array(self):
self.wire_separation = 1.0
self._latex = [
[
"\\cw" if isinstance(self.ordered_bits[j], Clbit) else "\\qw"
"\\cw" if isinstance(self._ordered_bits[j], Clbit) else "\\qw"
for _ in range(self.img_depth + 1)
]
for j in range(self.img_width)
]
self._latex.append([" "] * (self.img_depth + 1))

# quantum register
for ii, reg in enumerate(self.qubit_list):
register = self.bit_locations[reg]["register"]
index = self.bit_locations[reg]["index"]
for ii, reg in enumerate(self._qubits):
register = self._bit_locations[reg]["register"]
index = self._bit_locations[reg]["index"]
qubit_label = get_bit_label("latex", register, index, qubit=True, layout=self.layout)
qubit_label += " : "
if self.initial_state:
Expand All @@ -230,10 +233,10 @@ def _initialize_latex_array(self):

# classical register
offset = 0
if self.clbit_list:
for ii in range(len(self.qubit_list), self.img_width):
register = self.bit_locations[self.ordered_bits[ii + offset]]["register"]
index = self.bit_locations[self.ordered_bits[ii + offset]]["index"]
if self._clbits:
for ii in range(len(self._qubits), self.img_width):
register = self._bit_locations[self._ordered_bits[ii + offset]]["register"]
index = self._bit_locations[self._ordered_bits[ii + offset]]["index"]
clbit_label = get_bit_label(
"latex", register, index, qubit=False, cregbundle=self.cregbundle
)
Expand Down Expand Up @@ -328,9 +331,9 @@ def _get_image_depth(self):
sum_column_widths = sum(1 + v / 3 for v in max_column_widths)

max_reg_name = 3
for reg in self.ordered_bits:
if self.bit_locations[reg]["register"] is not None:
max_reg_name = max(max_reg_name, len(self.bit_locations[reg]["register"].name))
for reg in self._ordered_bits:
if self._bit_locations[reg]["register"] is not None:
max_reg_name = max(max_reg_name, len(self._bit_locations[reg]["register"].name))
sum_column_widths += 5 + max_reg_name / 3

# could be a fraction so ceil
Expand Down Expand Up @@ -420,7 +423,7 @@ def _build_latex_array(self):

def _build_multi_gate(self, op, gate_text, wire_list, cwire_list, col):
"""Add a multiple wire gate to the _latex list"""
cwire_start = len(self.qubit_list)
cwire_start = len(self._qubits)
num_cols_op = 1
if isinstance(op, (SwapGate, RZZGate)):
num_cols_op = self._build_symmetric_gate(op, gate_text, wire_list, col)
Expand Down Expand Up @@ -524,19 +527,31 @@ def _build_measure(self, node, col):
"""Build a meter and the lines to the creg"""
wire1 = self.img_regs[node.qargs[0]]
self._latex[wire1][col] = "\\meter"

if self.cregbundle:
wire2 = len(self.qubit_list)
cregindex = self.img_regs[node.cargs[0]] - wire2
for creg_size in self.cregs.values():
if cregindex >= creg_size:
cregindex -= creg_size
wire2 = len(self._qubits)
prev_reg = None
idx_str = ""
cond_offset = 1.5 if node.op.condition else 0.0
for i, reg in enumerate(self.cregs_bits):
# if it's a registerless bit
if reg is None:
if self._clbits[i] == node.cargs[0]:
break
wire2 += 1
else:
continue
# if it's a whole register or a bit in a register
if reg == self._bit_locations[node.cargs[0]]["register"]:
idx_str = str(self._bit_locations[node.cargs[0]]["index"])
break
cond_offset = 1.5 if node.op.condition else 0.0
if self.cregbundle and prev_reg and prev_reg == reg:
continue
wire2 += 1
prev_reg = reg

self._latex[wire2][col] = "\\dstick{_{_{\\hspace{%sem}%s}}} \\cw \\ar @{<=} [-%s,0]" % (
cond_offset,
str(cregindex),
idx_str,
str(wire2 - wire1),
)
else:
Expand All @@ -553,11 +568,11 @@ def _build_barrier(self, node, col):
if index - 1 == last:
last = index
else:
pos = self.img_regs[self.qubit_list[first]]
pos = self.img_regs[self._qubits[first]]
self._latex[pos][col - 1] += " \\barrier[0em]{" + str(last - first) + "}"
self._latex[pos][col] = "\\qw"
first = last = index
pos = self.img_regs[self.qubit_list[first]]
pos = self.img_regs[self._qubits[first]]
self._latex[pos][col - 1] += " \\barrier[0em]{" + str(last - first) + "}"
self._latex[pos][col] = "\\qw"

Expand All @@ -581,77 +596,70 @@ def _add_controls(self, wire_list, ctrlqargs, ctrl_state, col):

def _add_condition(self, op, wire_list, col):
"""Add a condition to the _latex list"""
# if_value - a bit string for the condition
# cwire - the wire number for the first wire for the condition register
# or if cregbundle, wire number of the condition register itself
# gap - the number of wires from cwire to the bottom gate qubit

label, clbit_mask, val_list = get_condition_label(
op.condition, self._clbits, self._bit_locations, self.cregbundle
)
if not self.reverse_bits:
val_list = val_list[::-1]
cond_is_bit = isinstance(op.condition[0], Clbit)
if cond_is_bit:
cond_reg = self.bit_locations[op.condition[0]]["register"]
if_value = op.condition[1]
cond_reg = (
op.condition[0] if not cond_is_bit else self._bit_locations[op.condition[0]]["register"]
)
# if cregbundle, add 1 to cwire for each register and each registerless bit, until
# the condition bit/register is found. If not cregbundle, add 1 to cwire for every
# bit until condition found.
cwire = len(self._qubits)
if self.cregbundle:
prev_reg = None
for i, reg in enumerate(self.cregs_bits):
# if it's a registerless bit
if reg is None:
if self._clbits[i] == op.condition[0]:
break
cwire += 1
continue
# if it's a whole register or a bit in a register
if reg == cond_reg:
break
if self.cregbundle and prev_reg and prev_reg == reg:
continue
cwire += 1
prev_reg = reg
else:
cond_reg = op.condition[0]
creg_size = self.cregs[cond_reg]
if_value = format(op.condition[1], "b").zfill(creg_size)
if not self.reverse_bits:
if_value = if_value[::-1]

cwire = len(self.qubit_list)
iter_cregs = iter(list(self.cregs)) if self.cregbundle else iter(self.cregs_bits)
for creg in iter_cregs:
if creg == cond_reg:
break
cwire += 1
for bit in clbit_mask:
if bit == "1":
break
cwire += 1

gap = cwire - max(wire_list)
meas_offset = -0.3 if isinstance(op, Measure) else 0.0
if self.cregbundle:
# Print the condition value at the bottom and put bullet on creg line
if cond_is_bit:
ctrl_bit = (
str(cond_reg.name) + "_" + str(self.bit_locations[op.condition[0]]["index"])
)
label = "T" if if_value is True else "F"
self._latex[cwire][col] = "\\control \\cw^(%s){^{\\mathtt{%s=%s}}} \\cwx[-%s]" % (
meas_offset,
ctrl_bit,
label,
str(gap),
)
else:
self._latex[cwire][col] = "\\control \\cw^(%s){^{\\mathtt{%s}}} \\cwx[-%s]" % (
meas_offset,
str(hex(op.condition[1])),
str(gap),
)
# Print the condition value at the bottom and put bullet on creg line
if cond_is_bit or self.cregbundle:
control = "\\control" if op.condition[1] else "\\controlo"
self._latex[cwire][col] = f"{control}" + " \\cw^(%s){^{\\mathtt{%s}}} \\cwx[-%s]" % (
meas_offset,
label,
str(gap),
)
else:
# Add the open and closed buttons to indicate the condition value
if cond_is_bit:
extra_gap = list(cond_reg).index(op.condition[0])
gap += extra_gap
control = "\\control" if if_value is True else "\\controlo"
self._latex[cwire + extra_gap][col] = (
f"{control}" + " \\cw^(%s){^{\\mathtt{%s}}} \\cwx[-%s]"
) % (
meas_offset,
str(hex(op.condition[1])),
str(gap),
)
else:
for i in range(creg_size - 1):
control = "\\control" if if_value[i] == "1" else "\\controlo"
self._latex[cwire + i][col] = f"{control} \\cw \\cwx[-" + str(gap) + "]"
gap = 1
# Add (hex condition value) below the last cwire
control = "\\control" if if_value[creg_size - 1] == "1" else "\\controlo"
self._latex[creg_size + cwire - 1][col] = (
f"{control}" + " \\cw^(%s){^{\\mathtt{%s}}} \\cwx[-%s]"
) % (
meas_offset,
str(hex(op.condition[1])),
str(gap),
)
creg_size = op.condition[0].size
for i in range(creg_size - 1):
control = "\\control" if val_list[i] == "1" else "\\controlo"
self._latex[cwire + i][col] = f"{control} \\cw \\cwx[-" + str(gap) + "]"
gap = 1
# Add (hex condition value) below the last cwire
control = "\\control" if val_list[creg_size - 1] == "1" else "\\controlo"
self._latex[creg_size + cwire - 1][col] = (
f"{control}" + " \\cw^(%s){^{\\mathtt{%s}}} \\cwx[-%s]"
) % (
meas_offset,
label,
str(gap),
)

def _truncate_float(self, matchobj, ndigits=4):
"""Truncate long floats."""
Expand Down
Loading

0 comments on commit c07c2cc

Please sign in to comment.