From ea69a4590ba058514cf7615e26c5113261c3e2ee Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 1 Oct 2024 14:33:13 -0400 Subject: [PATCH] Handle schedules too This commit updates the schedule serialization path too, as it was also directly loading symengine expressions. The code handling the workaround is extracted to a standalone function which is used in both spots now instead of calling symengine directly. --- qiskit/qpy/binary_io/schedules.py | 2 +- qiskit/qpy/binary_io/value.py | 25 +--------------------- qiskit/qpy/common.py | 35 ++++++++++++++++++++++++++++++- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/qiskit/qpy/binary_io/schedules.py b/qiskit/qpy/binary_io/schedules.py index eae5e6f57ad9..75b25f3b20a3 100644 --- a/qiskit/qpy/binary_io/schedules.py +++ b/qiskit/qpy/binary_io/schedules.py @@ -106,7 +106,7 @@ def _loads_symbolic_expr(expr_bytes, use_symengine=False): return None expr_bytes = zlib.decompress(expr_bytes) if use_symengine: - return load_basic(expr_bytes) + return common.load_symengine_payload(expr_bytes) else: from sympy import parse_expr diff --git a/qiskit/qpy/binary_io/value.py b/qiskit/qpy/binary_io/value.py index bb550c771ebc..63f208790402 100644 --- a/qiskit/qpy/binary_io/value.py +++ b/qiskit/qpy/binary_io/value.py @@ -290,30 +290,7 @@ def _read_parameter_expression_v3(file_obj, vectors, use_symengine): payload = file_obj.read(data.expr_size) if use_symengine: - # This is a horrible hack to workaround the symengine version checking - # it's deserialization does. There were no changes to the serialization - # format between 0.11 and 0.13 but the deserializer checks that it can't - # load across a major or minor version boundary. This works around it - # by just lying about the generating version. - symengine_version = symengine.__version__.split(".") - major = payload[2] - minor = payload[3] - if int(symengine_version[1]) != minor: - if minor not in (11, 13): - raise exceptions.QpyError( - f"Incompatible symengine version {major}.{minor} used to generate the QPY " - "payload" - ) - minor_version = int(symengine_version[1]) - if minor_version not in (11, 13): - raise exceptions.QpyError( - f"Incompatible installed symengine version {symengine.__version__} to load " - "this QPY payload" - ) - payload = bytearray(payload) - payload[3] = minor_version - payload = bytes(payload) - expr_ = load_basic(payload) + expr_ = common.load_symengine_payload(payload) else: from sympy.parsing.sympy_parser import parse_expr diff --git a/qiskit/qpy/common.py b/qiskit/qpy/common.py index 048320d5cad6..84acebcf335d 100644 --- a/qiskit/qpy/common.py +++ b/qiskit/qpy/common.py @@ -18,7 +18,12 @@ import io import struct -from qiskit.qpy import formats +import symengine +from symengine.lib.symengine_wrapper import ( # pylint: disable = no-name-in-module + load_basic, +) + +from qiskit.qpy import formats, exceptions QPY_VERSION = 12 QPY_COMPATIBILITY_VERSION = 10 @@ -304,3 +309,31 @@ def mapping_from_binary(binary_data, deserializer, **kwargs): mapping = read_mapping(container, deserializer, **kwargs) return mapping + + +def load_symengine_payload(payload: bytes) -> symengine.Expr: + """Load a symengine expression from it's serialized cereal payload.""" + # This is a horrible hack to workaround the symengine version checking + # it's deserialization does. There were no changes to the serialization + # format between 0.11 and 0.13 but the deserializer checks that it can't + # load across a major or minor version boundary. This works around it + # by just lying about the generating version. + symengine_version = symengine.__version__.split(".") + major = payload[2] + minor = payload[3] + if int(symengine_version[1]) != minor: + if minor not in (11, 13): + raise exceptions.QpyError( + f"Incompatible symengine version {major}.{minor} used to generate the QPY " + "payload" + ) + minor_version = int(symengine_version[1]) + if minor_version not in (11, 13): + raise exceptions.QpyError( + f"Incompatible installed symengine version {symengine.__version__} to load " + "this QPY payload" + ) + payload = bytearray(payload) + payload[3] = minor_version + payload = bytes(payload) + return load_basic(payload)