From e106b3b7e20b940fb2845a4d849c16ca33f37caf Mon Sep 17 00:00:00 2001 From: gadial Date: Fri, 24 Jun 2022 17:00:51 +0300 Subject: [PATCH] FigureData wrapper for figures (#814) * Added FigureData wrapper for figures * FigureData is now kept when adding the figure into a composite analysis result * Update qiskit_experiments/database_service/db_experiment_data.py Co-authored-by: Yael Ben-Haim Co-authored-by: Christopher J. Wood --- .../database_service/db_experiment_data.py | 86 +++++++++++++++++-- qiskit_experiments/framework/__init__.py | 2 + .../notes/figure_data-ecf5a82c95844b6a.yaml | 9 ++ .../test_db_experiment_data.py | 51 +++++++++-- 4 files changed, 134 insertions(+), 14 deletions(-) create mode 100644 releasenotes/notes/figure_data-ecf5a82c95844b6a.yaml diff --git a/qiskit_experiments/database_service/db_experiment_data.py b/qiskit_experiments/database_service/db_experiment_data.py index ba472420bd..73bd8d6501 100644 --- a/qiskit_experiments/database_service/db_experiment_data.py +++ b/qiskit_experiments/database_service/db_experiment_data.py @@ -20,6 +20,7 @@ import time from typing import Optional, List, Any, Union, Callable, Dict, Tuple import copy +import io from concurrent import futures from threading import Event from functools import wraps @@ -29,12 +30,12 @@ import numpy as np from matplotlib import pyplot +from matplotlib.figure import Figure as MatplotlibFigure from qiskit import QiskitError from qiskit.providers import Job, Backend, Provider from qiskit.result import Result from qiskit.providers.jobstatus import JobStatus, JOB_FINAL_STATES from qiskit_experiments.framework.json import ExperimentEncoder, ExperimentDecoder - from .database_service import DatabaseServiceV1 from .exceptions import DbExperimentDataError, DbExperimentEntryNotFound, DbExperimentEntryExists from .db_analysis_result import DbAnalysisResultV1 as DbAnalysisResult @@ -130,6 +131,64 @@ def __json_encode__(self): return self.__getstate__() +class FigureData: + """Wrapper for figures and figure metadata""" + + def __init__(self, figure, name=None, metadata=None): + """Creates a new figure data object""" + self.figure = figure + self._name = name + self.metadata = metadata or {} + + # name is read only + @property + def name(self) -> str: + """The name of the figure""" + return self._name + + @property + def metadata(self) -> dict: + """The metadata dictionary stored with the figure""" + return self._metadata + + @metadata.setter + def metadata(self, new_metadata: dict): + """Set the metadata to new value; must be a dictionary""" + if not isinstance(new_metadata, dict): + raise ValueError("figure metadata must be a dictionary") + self._metadata = new_metadata + + def copy(self, new_name: Optional[str] = None): + """Creates a copy of the figure data""" + name = new_name or self.name + return FigureData(figure=self.figure, name=name, metadata=copy.deepcopy(self.metadata)) + + def __json_encode__(self) -> Dict[str, Any]: + """Return the json representation of the figure data""" + return {"figure": self.figure, "name": self.name, "metadata": self.metadata} + + @classmethod + def __json_decode__(cls, args: Dict[str, Any]) -> "FigureData": + """Initialize a figure data from the json representation""" + return cls(**args) + + def _repr_png_(self): + if isinstance(self.figure, MatplotlibFigure): + b = io.BytesIO() + self.figure.savefig(b, format="png", bbox_inches="tight") + png = b.getvalue() + return png + else: + return None + + def _repr_svg_(self): + if isinstance(self.figure, str): + return self.figure + if isinstance(self.figure, bytes): + return str(self.figure) + return None + + class DbExperimentData: """Base common type for all versioned DbExperimentData classes. @@ -703,7 +762,15 @@ def add_figures( with open(figure, "rb") as file: figure = file.read() - self._figures[fig_name] = figure + # check whether the figure is already wrapped, meaning it came from a sub-experiment + if isinstance(figure, FigureData): + figure_data = figure.copy(new_name=fig_name) + + else: + figure_metadata = {"qubits": self.metadata.get("physical_qubits")} + figure_data = FigureData(figure=figure, name=fig_name, metadata=figure_metadata) + + self._figures[fig_name] = figure_data save = save_figure if save_figure is not None else self.auto_save if save and self._service: @@ -769,8 +836,8 @@ def figure( the content of the figure is returned instead. Returns: - The size of the figure if `file_name` is specified. Otherwise the - content of the figure in bytes. + The size of the figure if `file_name` is specified. + Otherwise the :class:`.FigureData`. Raises: DbExperimentEntryNotFound: If the figure cannot be found. @@ -780,9 +847,8 @@ def figure( figure_data = self._figures.get(figure_key, None) if figure_data is None and self.service: - figure_data = self.service.figure( - experiment_id=self.experiment_id, figure_name=figure_key - ) + figure = self.service.figure(experiment_id=self.experiment_id, figure_name=figure_key) + figure_data = FigureData(figure=figure, name=figure_key) self._figures[figure_key] = figure_data if figure_data is None: @@ -790,7 +856,7 @@ def figure( if file_name: with open(file_name, "wb") as output: - num_bytes = output.write(figure_data) + num_bytes = output.write(figure_data.figure) return num_bytes return figure_data @@ -1037,6 +1103,10 @@ def save(self) -> None: for name, figure in self._figures.items(): if figure is None: continue + # currently only the figure and its name are stored in the database + if isinstance(figure, FigureData): + figure = figure.figure + LOG.debug("Figure metadata is currently not saved to the database") if isinstance(figure, pyplot.Figure): figure = plot_to_svg_bytes(figure) data = {"experiment_id": self.experiment_id, "figure": figure, "figure_name": name} diff --git a/qiskit_experiments/framework/__init__.py b/qiskit_experiments/framework/__init__.py index 45bd35c22a..1a52e77d41 100644 --- a/qiskit_experiments/framework/__init__.py +++ b/qiskit_experiments/framework/__init__.py @@ -216,6 +216,7 @@ AnalysisConfig ExperimentEncoder ExperimentDecoder + FigureData .. _composite-experiment: @@ -253,6 +254,7 @@ ExperimentStatus, JobStatus, AnalysisStatus, + FigureData, ) from .base_analysis import BaseAnalysis from .base_experiment import BaseExperiment diff --git a/releasenotes/notes/figure_data-ecf5a82c95844b6a.yaml b/releasenotes/notes/figure_data-ecf5a82c95844b6a.yaml new file mode 100644 index 0000000000..b363d8ccfd --- /dev/null +++ b/releasenotes/notes/figure_data-ecf5a82c95844b6a.yaml @@ -0,0 +1,9 @@ +--- +upgrade: + - | + Adds a :class:`.FigureData` class for adding metadata to analysis result figures. Figures added to + :class:`.ExperimentData` are now stored using this class. The raw image object (SVG or matplotlib.Figure) + can be accessed using the :attr:`.FigureData.figure` attribute. + + Note that currently metadata is only stored locally and will be discarded when saved to the cloud + experiment service database. diff --git a/test/database_service/test_db_experiment_data.py b/test/database_service/test_db_experiment_data.py index c38bacb8a0..07d02ccb0e 100644 --- a/test/database_service/test_db_experiment_data.py +++ b/test/database_service/test_db_experiment_data.py @@ -270,7 +270,7 @@ def test_add_figure(self): with self.subTest(name=name): exp_data = DbExperimentData(backend=self.backend, experiment_type="qiskit_test") fn = exp_data.add_figures(figure, figure_name) - self.assertEqual(hello_bytes, exp_data.figure(fn)) + self.assertEqual(hello_bytes, exp_data.figure(fn).figure) def test_add_figure_plot(self): """Test adding a matplotlib figure.""" @@ -280,7 +280,7 @@ def test_add_figure_plot(self): service = self._set_mock_service() exp_data = DbExperimentData(backend=self.backend, experiment_type="qiskit_test") exp_data.add_figures(figure, save_figure=True) - self.assertEqual(figure, exp_data.figure(0)) + self.assertEqual(figure, exp_data.figure(0).figure) service.create_figure.assert_called_once() _, kwargs = service.create_figure.call_args self.assertIsInstance(kwargs["figure"], bytes) @@ -305,7 +305,7 @@ def test_add_figures(self): exp_data = DbExperimentData(backend=self.backend, experiment_type="qiskit_test") added_names = exp_data.add_figures(figures, figure_names) for idx, added_fn in enumerate(added_names): - self.assertEqual(hello_bytes[idx], exp_data.figure(added_fn)) + self.assertEqual(hello_bytes[idx], exp_data.figure(added_fn).figure) def test_add_figure_overwrite(self): """Test updating an existing figure.""" @@ -318,7 +318,7 @@ def test_add_figure_overwrite(self): exp_data.add_figures(friend_bytes, fn) exp_data.add_figures(friend_bytes, fn, overwrite=True) - self.assertEqual(friend_bytes, exp_data.figure(fn)) + self.assertEqual(friend_bytes, exp_data.figure(fn).figure) def test_add_figure_save(self): """Test saving a figure in the database.""" @@ -331,6 +331,45 @@ def test_add_figure_save(self): self.assertEqual(kwargs["figure"], hello_bytes) self.assertEqual(kwargs["experiment_id"], exp_data.experiment_id) + def test_add_figure_metadata(self): + hello_bytes = str.encode("hello world") + qubits = [0, 1, 2] + exp_data = DbExperimentData( + backend=self.backend, + experiment_type="qiskit_test", + metadata={"physical_qubits": qubits}, + ) + exp_data.add_figures(hello_bytes) + exp_data.figure(0).metadata["foo"] = "bar" + figure_data = exp_data.figure(0) + + self.assertEqual(figure_data.metadata["qubits"], qubits) + self.assertEqual(figure_data.metadata["foo"], "bar") + expected_name_prefix = "qiskit_test_Fig-0_Exp-" + self.assertEqual(figure_data.name[: len(expected_name_prefix)], expected_name_prefix) + + exp_data2 = DbExperimentData( + backend=self.backend, + experiment_type="qiskit_test", + metadata={"physical_qubits": [1, 2, 3, 4]}, + ) + exp_data2.add_figures(figure_data, "new_name.svg") + figure_data = exp_data2.figure("new_name.svg") + + # metadata should not change when adding to new ExperimentData + self.assertEqual(figure_data.metadata["qubits"], qubits) + self.assertEqual(figure_data.metadata["foo"], "bar") + # name should change + self.assertEqual(figure_data.name, "new_name.svg") + + # can set the metadata to new dictionary + figure_data.metadata = {"bar": "foo"} + self.assertEqual(figure_data.metadata["bar"], "foo") + + # cannot set the metadata to something other than dictionary + with self.assertRaises(ValueError): + figure_data.metadata = ["foo", "bar"] + def test_add_figure_bad_input(self): """Test adding figures with bad input.""" exp_data = DbExperimentData(backend=self.backend, experiment_type="qiskit_test") @@ -347,8 +386,8 @@ def test_get_figure(self): ) idx = randrange(3) expected_figure = str.encode(figure_template.format(idx)) - self.assertEqual(expected_figure, exp_data.figure(name_template.format(idx))) - self.assertEqual(expected_figure, exp_data.figure(idx)) + self.assertEqual(expected_figure, exp_data.figure(name_template.format(idx)).figure) + self.assertEqual(expected_figure, exp_data.figure(idx).figure) file_name = uuid.uuid4().hex self.addCleanup(os.remove, file_name)