Skip to content

Commit

Permalink
Pivot to stack based postfix/rpn deserialization
Browse files Browse the repository at this point in the history
This commit changes how the deserialization works to use a postfix
stack based approach. Operands are push on the stack and then popped off
based on the operation being run. The result of the operation is then
pushed on the stack. This handles nested objects much more cleanly than
the recursion based approach because we just keep pushing on the stack
instead of recursing, making the accounting much simpler. After the
expression payload is finished being processed there will be a single
value on the stack and that is returned as the final expression.
  • Loading branch information
mtreinish committed Nov 5, 2024
1 parent 60a2172 commit a903387
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 144 deletions.
2 changes: 2 additions & 0 deletions qiskit/circuit/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
self._parameter_symbols = {self: symbol}
self._name_map = None
self._qpy_replay = []
self._standalone_param = True

def assign(self, parameter, value):
if parameter != self:
Expand Down Expand Up @@ -174,3 +175,4 @@ def __setstate__(self, state):
self._parameter_symbols = {self: self._symbol_expr}
self._name_map = None
self._qpy_replay = []
self._standalone_param = True
33 changes: 27 additions & 6 deletions qiskit/circuit/parameterexpression.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def op_code_to_method(op_code: _OPCode):
@dataclass
class _INSTRUCTION:
op: _OPCode
lhs: ParameterValueType
lhs: ParameterValueType | None
rhs: ParameterValueType | None = None


Expand All @@ -103,6 +103,7 @@ class ParameterExpression:
"_symbol_expr",
"_name_map",
"_qpy_replay",
"_standalone_param",
]

def __init__(self, symbol_map: dict, expr, *, _qpy_replay=None):
Expand All @@ -124,6 +125,7 @@ def __init__(self, symbol_map: dict, expr, *, _qpy_replay=None):
self._parameter_keys = frozenset(p._hash_key() for p in self._parameter_symbols)
self._symbol_expr = expr
self._name_map: dict | None = None
self._standalone_param = False
if _qpy_replay is not None:
self._qpy_replay = _qpy_replay
else:
Expand All @@ -143,7 +145,10 @@ def _names(self) -> dict:

def conjugate(self) -> "ParameterExpression":
"""Return the conjugate."""
new_op = _INSTRUCTION(_OPCode.CONJ, self)
if self._standalone_param:
new_op = _INSTRUCTION(_OPCode.CONJ, self)
else:
new_op = _INSTRUCTION(_OPCode.CONJ, None)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
conjugated = ParameterExpression(
Expand Down Expand Up @@ -357,10 +362,16 @@ def _apply_operation(

if reflected:
expr = operation(other_expr, self_expr)
new_op = _INSTRUCTION(op_code, other, self)
if self._standalone_param:
new_op = _INSTRUCTION(op_code, other, self)
else:
new_op = _INSTRUCTION(op_code, other, None)
else:
expr = operation(self_expr, other_expr)
new_op = _INSTRUCTION(op_code, self, other)
if self._standalone_param:
new_op = _INSTRUCTION(op_code, self, other)
else:
new_op = _INSTRUCTION(op_code, None, other)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

Expand All @@ -386,6 +397,13 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
# If it is not contained then return 0
return 0.0

if self._standalone_param:
new_op = _INSTRUCTION(_OPCode.DERIV, self, param)
else:
new_op = _INSTRUCTION(_OPCode.DERIV, None, param)
qpy_replay = self._qpy_replay.copy()
qpy_replay.append(new_op)

# Compute the gradient of the parameter expression w.r.t. param
key = self._parameter_symbols[param]
expr_grad = symengine.Derivative(self._symbol_expr, key)
Expand All @@ -399,7 +417,7 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
parameter_symbols[parameter] = symbol
# If the gradient corresponds to a parameter expression then return the new expression.
if len(parameter_symbols) > 0:
return ParameterExpression(parameter_symbols, expr=expr_grad)
return ParameterExpression(parameter_symbols, expr=expr_grad, _qpy_replay=qpy_replay)
# If no free symbols left, return a complex or float gradient
expr_grad_cplx = complex(expr_grad)
if expr_grad_cplx.imag != 0:
Expand Down Expand Up @@ -446,7 +464,10 @@ def __rpow__(self, other):
return self._apply_operation(pow, other, reflected=True, op_code=_OPCode.POW)

def _call(self, ufunc, op_code):
new_op = _INSTRUCTION(op_code, self)
if self._standalone_param:
new_op = _INSTRUCTION(op_code, self)
else:
new_op = _INSTRUCTION(op_code, None)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
return ParameterExpression(
Expand Down
243 changes: 105 additions & 138 deletions qiskit/qpy/binary_io/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _write_parameter_vec(file_obj, obj):
file_obj.write(name_bytes)


def _encode_replay_entry(inst, expression_tracking, file_obj, version, side=False):
def _encode_replay_entry(inst, file_obj, version, side=False):
inst_type = None
inst_data = None
if inst is None:
Expand All @@ -75,52 +75,47 @@ def _encode_replay_entry(inst, expression_tracking, file_obj, version, side=Fals
inst_type = "i"
inst_data = struct.pack("!Qq", 0, inst)
elif isinstance(inst, ParameterExpression):
if inst not in expression_tracking:
if not side:
entry = struct.pack(
formats.PARAM_EXPR_ELEM_V4_PACK,
255,
"s".encode("utf8"),
b"\x00",
"n".encode("utf8"),
b"\x00",
)
else:
entry = struct.pack(
formats.PARAM_EXPR_ELEM_V4_PACK,
255,
"n".encode("utf8"),
b"\x00",
"s".encode("utf8"),
b"\x00",
)

file_obj.write(entry)
_write_parameter_expression_v13(file_obj, inst, version)
if not side:
entry = struct.pack(
formats.PARAM_EXPR_ELEM_V4_PACK,
255,
"e".encode("utf8"),
b"\x00",
"n".encode("utf8"),
b"\x00",
)
else:
entry = struct.pack(
formats.PARAM_EXPR_ELEM_V4_PACK,
255,
"n".encode("utf8"),
b"\x00",
"e".encode("utf8"),
b"\x00",
)
file_obj.write(entry)
inst_type = "n"
inst_data = b"\x00"
if not side:
entry = struct.pack(
formats.PARAM_EXPR_ELEM_V4_PACK,
255,
"s".encode("utf8"),
b"\x00",
"n".encode("utf8"),
b"\x00",
)
else:
entry = struct.pack(
formats.PARAM_EXPR_ELEM_V4_PACK,
255,
"n".encode("utf8"),
b"\x00",
"s".encode("utf8"),
b"\x00",
)
file_obj.write(entry)
_write_parameter_expression_v13(file_obj, inst, version)
if not side:
entry = struct.pack(
formats.PARAM_EXPR_ELEM_V4_PACK,
255,
"e".encode("utf8"),
b"\x00",
"n".encode("utf8"),
b"\x00",
)
else:
inst_type = "n"
inst_data = b"\x00"
entry = struct.pack(
formats.PARAM_EXPR_ELEM_V4_PACK,
255,
"n".encode("utf8"),
b"\x00",
"e".encode("utf8"),
b"\x00",
)
file_obj.write(entry)
inst_type = "n"
inst_data = b"\x00"
else:
raise exceptions.QpyError("Invalid parameter expression type")
return inst_type, inst_data
Expand All @@ -147,16 +142,13 @@ def _encode_replay_subs(subs, file_obj, version):


def _write_parameter_expression_v13(file_obj, obj, version):
expression_tracking = {
obj,
}
symbol_map = {}
for inst in obj._qpy_replay:
if isinstance(inst, _SUBS):
symbol_map.update(_encode_replay_subs(inst, file_obj, version))
continue
lhs_type, lhs = _encode_replay_entry(inst.lhs, expression_tracking, file_obj, version)
rhs_type, rhs = _encode_replay_entry(inst.rhs, expression_tracking, file_obj, version, True)
lhs_type, lhs = _encode_replay_entry(inst.lhs, file_obj, version)
rhs_type, rhs = _encode_replay_entry(inst.rhs, file_obj, version, True)
entry = struct.pack(
formats.PARAM_EXPR_ELEM_V4_PACK,
inst.op,
Expand Down Expand Up @@ -553,86 +545,70 @@ def _read_parameter_expression_v4(file_obj, vectors, version):
def _read_parameter_expr_v13(buf, symbol_map, version, vectors):
param_uuid_map = {symbol.uuid: symbol for symbol in symbol_map if isinstance(symbol, Parameter)}
name_map = {str(v): k for k, v in symbol_map.items()}
expression = None
data = buf.read(formats.PARAM_EXPR_ELEM_V4_SIZE)
rhs = None
lhs = None
stack = []
while data:
expression_data = formats.PARAM_EXPR_ELEM_V4._make(
struct.unpack(formats.PARAM_EXPR_ELEM_V4_PACK, data)
)
if lhs is None:
if expression_data.LHS_TYPE == b"p":
lhs = param_uuid_map[uuid.UUID(bytes=expression_data.LHS)]
elif expression_data.LHS_TYPE == b"f":
lhs = struct.unpack("!Qd", expression_data.LHS)[1]
elif expression_data.LHS_TYPE == b"n":
lhs = None
elif expression_data.LHS_TYPE == b"c":
lhs = complex(*struct.unpack("!dd", expression_data.LHS))
elif expression_data.LHS_TYPE == b"i":
lhs = struct.unpack("!Qq", expression_data.LHS)[1]
elif expression_data.LHS_TYPE == b"s":
lhs = _read_parameter_expr_v13(buf, symbol_map, version, vectors)
data = buf.read(formats.PARAM_EXPR_ELEM_V4_SIZE)
continue
elif expression_data.LHS_TYPE == b"e":
return expression
elif expression_data.LHS_TYPE == b"u":
size = struct.unpack_from("!QQ", expression_data.LHS)[0]
subs_map_data = buf.read(size)
with io.BytesIO(subs_map_data) as mapping_buf:
mapping = common.read_mapping(
mapping_buf, deserializer=loads_value, version=version, vectors=vectors
)
expression = expression.subs(
{name_map[k]: v for k, v in mapping.items()}, allow_unknown_parameters=True
)
data = buf.read(formats.PARAM_EXPR_ELEM_V4_SIZE)
continue
else:
raise exceptions.QpyError(
"Unknown ParameterExpression operation type {expression_data.LHS_TYPE}"
)
if rhs is None:
if expression_data.RHS_TYPE == b"p":
rhs = param_uuid_map[uuid.UUID(bytes=expression_data.RHS)]
elif expression_data.RHS_TYPE == b"f":
rhs = struct.unpack("!Qd", expression_data.RHS)[1]
elif expression_data.RHS_TYPE == b"n":
rhs = None
elif expression_data.RHS_TYPE == b"c":
rhs = complex(*struct.unpack("!dd", expression_data.LHS))
elif expression_data.RHS_TYPE == b"i":
rhs = struct.unpack("!Qq", expression_data.RHS)[1]
elif expression_data.RHS_TYPE == b"s":
rhs = _read_parameter_expr_v13(buf, symbol_map, version, vectors)
data = buf.read(formats.PARAM_EXPR_ELEM_V4_SIZE)
continue
elif expression_data.RHS_TYPE == b"e":
return expression
else:
raise exceptions.QpyError(
f"Unknown ParameterExpression operation type {expression_data.RHS_TYPE}"
# LHS
if expression_data.LHS_TYPE == b"p":
stack.append(param_uuid_map[uuid.UUID(bytes=expression_data.LHS)])
elif expression_data.LHS_TYPE == b"f":
stack.append(struct.unpack("!Qd", expression_data.LHS)[1])
elif expression_data.LHS_TYPE == b"n":
pass
elif expression_data.LHS_TYPE == b"c":
stack.append(complex(*struct.unpack("!dd", expression_data.LHS)))
elif expression_data.LHS_TYPE == b"i":
stack.append(struct.unpack("!Qq", expression_data.LHS)[1])
elif expression_data.LHS_TYPE == b"s":
data = buf.read(formats.PARAM_EXPR_ELEM_V4_SIZE)
continue
elif expression_data.LHS_TYPE == b"e":
data = buf.read(formats.PARAM_EXPR_ELEM_V4_SIZE)
continue
elif expression_data.LHS_TYPE == b"u":
size = struct.unpack_from("!QQ", expression_data.LHS)[0]
subs_map_data = buf.read(size)
with io.BytesIO(subs_map_data) as mapping_buf:
mapping = common.read_mapping(
mapping_buf, deserializer=loads_value, version=version, vectors=vectors
)
reverse_op = False
if expression is None:
if isinstance(lhs, ParameterExpression):
expression = lhs
elif isinstance(rhs, ParameterExpression):
expression = rhs
reverse_op = True
rhs = lhs
else:
raise exceptions.QpyError("Invalid ParameterExpression payload construction")
stack.append({name_map[k]: v for k, v in mapping.items()})
else:
raise exceptions.QpyError(
"Unknown ParameterExpression operation type {expression_data.LHS_TYPE}"
)
# RHS
if expression_data.RHS_TYPE == b"p":
stack.append(param_uuid_map[uuid.UUID(bytes=expression_data.RHS)])
elif expression_data.RHS_TYPE == b"f":
stack.append(struct.unpack("!Qd", expression_data.RHS)[1])
elif expression_data.RHS_TYPE == b"n":
pass
elif expression_data.RHS_TYPE == b"c":
stack.append(complex(*struct.unpack("!dd", expression_data.RHS)))
elif expression_data.RHS_TYPE == b"i":
stack.append(struct.unpack("!Qq", expression_data.RHS)[1])
elif expression_data.RHS_TYPE == b"s":
data = buf.read(formats.PARAM_EXPR_ELEM_V4_SIZE)
continue
elif expression_data.RHS_TYPE == b"e":
data = buf.read(formats.PARAM_EXPR_ELEM_V4_SIZE)
continue
else:
raise exceptions.QpyError(
f"Unknown ParameterExpression operation type {expression_data.RHS_TYPE}"
)
if expression_data.OP_CODE == 255:
continue
method_str = op_code_to_method(_OPCode(expression_data.OP_CODE))
# Handle reverse operators
if rhs is None and expression is not None:
reverse_op = True
rhs = lhs
if expression_data.OP_CODE in {0, 1, 2, 3, 4, 13, 15}:
if reverse_op:
# Map arithmetic operators to reverse methods
rhs = stack.pop()
lhs = stack.pop()
# Reverse ops
if not isinstance(lhs, ParameterExpression) and isinstance(rhs, ParameterExpression):
if expression_data.OP_CODE == 0:
method_str = "__radd__"
elif expression_data.OP_CODE == 1:
Expand All @@ -641,23 +617,14 @@ def _read_parameter_expr_v13(buf, symbol_map, version, vectors):
method_str = "__rmul__"
elif expression_data.OP_CODE == 3:
method_str = "__rtruediv__"

expression = getattr(expression, method_str)(rhs)
stack.append(getattr(rhs, method_str)(lhs))
else:
stack.append(getattr(lhs, method_str)(rhs))
else:
expression = getattr(expression, method_str)()
lhs = None
rhs = None
lhs = stack.pop()
stack.append(getattr(lhs, method_str)())
data = buf.read(formats.PARAM_EXPR_ELEM_V4_SIZE)
if expression is None:
if isinstance(lhs, ParameterExpression):
expression = lhs
elif isinstance(rhs, ParameterExpression):
expression = rhs
reverse_op = True
rhs = lhs
else:
raise exceptions.QpyError("Invalid ParameterExpression payload construction")
return expression
return stack.pop()


def _read_expr(
Expand Down

0 comments on commit a903387

Please sign in to comment.