Skip to content
This repository has been archived by the owner on Jul 13, 2022. It is now read-only.

Commit

Permalink
FigureData wrapper for figures (qiskit-community#814)
Browse files Browse the repository at this point in the history
* Added FigureData wrapper for figures
* FigureData is now kept when adding the figure into a composite analysis result
* Update qiskit_experiments/database_service/db_experiment_data.py

Co-authored-by: Yael Ben-Haim <[email protected]>
Co-authored-by: Christopher J. Wood <[email protected]>
  • Loading branch information
3 people authored Jun 24, 2022
1 parent 494da6f commit e106b3b
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 14 deletions.
86 changes: 78 additions & 8 deletions qiskit_experiments/database_service/db_experiment_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import time
from typing import Optional, List, Any, Union, Callable, Dict, Tuple
import copy
import io
from concurrent import futures
from threading import Event
from functools import wraps
Expand All @@ -29,12 +30,12 @@
import numpy as np

from matplotlib import pyplot
from matplotlib.figure import Figure as MatplotlibFigure
from qiskit import QiskitError
from qiskit.providers import Job, Backend, Provider
from qiskit.result import Result
from qiskit.providers.jobstatus import JobStatus, JOB_FINAL_STATES
from qiskit_experiments.framework.json import ExperimentEncoder, ExperimentDecoder

from .database_service import DatabaseServiceV1
from .exceptions import DbExperimentDataError, DbExperimentEntryNotFound, DbExperimentEntryExists
from .db_analysis_result import DbAnalysisResultV1 as DbAnalysisResult
Expand Down Expand Up @@ -130,6 +131,64 @@ def __json_encode__(self):
return self.__getstate__()


class FigureData:
"""Wrapper for figures and figure metadata"""

def __init__(self, figure, name=None, metadata=None):
"""Creates a new figure data object"""
self.figure = figure
self._name = name
self.metadata = metadata or {}

# name is read only
@property
def name(self) -> str:
"""The name of the figure"""
return self._name

@property
def metadata(self) -> dict:
"""The metadata dictionary stored with the figure"""
return self._metadata

@metadata.setter
def metadata(self, new_metadata: dict):
"""Set the metadata to new value; must be a dictionary"""
if not isinstance(new_metadata, dict):
raise ValueError("figure metadata must be a dictionary")
self._metadata = new_metadata

def copy(self, new_name: Optional[str] = None):
"""Creates a copy of the figure data"""
name = new_name or self.name
return FigureData(figure=self.figure, name=name, metadata=copy.deepcopy(self.metadata))

def __json_encode__(self) -> Dict[str, Any]:
"""Return the json representation of the figure data"""
return {"figure": self.figure, "name": self.name, "metadata": self.metadata}

@classmethod
def __json_decode__(cls, args: Dict[str, Any]) -> "FigureData":
"""Initialize a figure data from the json representation"""
return cls(**args)

def _repr_png_(self):
if isinstance(self.figure, MatplotlibFigure):
b = io.BytesIO()
self.figure.savefig(b, format="png", bbox_inches="tight")
png = b.getvalue()
return png
else:
return None

def _repr_svg_(self):
if isinstance(self.figure, str):
return self.figure
if isinstance(self.figure, bytes):
return str(self.figure)
return None


class DbExperimentData:
"""Base common type for all versioned DbExperimentData classes.
Expand Down Expand Up @@ -703,7 +762,15 @@ def add_figures(
with open(figure, "rb") as file:
figure = file.read()

self._figures[fig_name] = figure
# check whether the figure is already wrapped, meaning it came from a sub-experiment
if isinstance(figure, FigureData):
figure_data = figure.copy(new_name=fig_name)

else:
figure_metadata = {"qubits": self.metadata.get("physical_qubits")}
figure_data = FigureData(figure=figure, name=fig_name, metadata=figure_metadata)

self._figures[fig_name] = figure_data

save = save_figure if save_figure is not None else self.auto_save
if save and self._service:
Expand Down Expand Up @@ -769,8 +836,8 @@ def figure(
the content of the figure is returned instead.
Returns:
The size of the figure if `file_name` is specified. Otherwise the
content of the figure in bytes.
The size of the figure if `file_name` is specified.
Otherwise the :class:`.FigureData`.
Raises:
DbExperimentEntryNotFound: If the figure cannot be found.
Expand All @@ -780,17 +847,16 @@ def figure(

figure_data = self._figures.get(figure_key, None)
if figure_data is None and self.service:
figure_data = self.service.figure(
experiment_id=self.experiment_id, figure_name=figure_key
)
figure = self.service.figure(experiment_id=self.experiment_id, figure_name=figure_key)
figure_data = FigureData(figure=figure, name=figure_key)
self._figures[figure_key] = figure_data

if figure_data is None:
raise DbExperimentEntryNotFound(f"Figure {figure_key} not found.")

if file_name:
with open(file_name, "wb") as output:
num_bytes = output.write(figure_data)
num_bytes = output.write(figure_data.figure)
return num_bytes
return figure_data

Expand Down Expand Up @@ -1037,6 +1103,10 @@ def save(self) -> None:
for name, figure in self._figures.items():
if figure is None:
continue
# currently only the figure and its name are stored in the database
if isinstance(figure, FigureData):
figure = figure.figure
LOG.debug("Figure metadata is currently not saved to the database")
if isinstance(figure, pyplot.Figure):
figure = plot_to_svg_bytes(figure)
data = {"experiment_id": self.experiment_id, "figure": figure, "figure_name": name}
Expand Down
2 changes: 2 additions & 0 deletions qiskit_experiments/framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@
AnalysisConfig
ExperimentEncoder
ExperimentDecoder
FigureData
.. _composite-experiment:
Expand Down Expand Up @@ -253,6 +254,7 @@
ExperimentStatus,
JobStatus,
AnalysisStatus,
FigureData,
)
from .base_analysis import BaseAnalysis
from .base_experiment import BaseExperiment
Expand Down
9 changes: 9 additions & 0 deletions releasenotes/notes/figure_data-ecf5a82c95844b6a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
upgrade:
- |
Adds a :class:`.FigureData` class for adding metadata to analysis result figures. Figures added to
:class:`.ExperimentData` are now stored using this class. The raw image object (SVG or matplotlib.Figure)
can be accessed using the :attr:`.FigureData.figure` attribute.
Note that currently metadata is only stored locally and will be discarded when saved to the cloud
experiment service database.
51 changes: 45 additions & 6 deletions test/database_service/test_db_experiment_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def test_add_figure(self):
with self.subTest(name=name):
exp_data = DbExperimentData(backend=self.backend, experiment_type="qiskit_test")
fn = exp_data.add_figures(figure, figure_name)
self.assertEqual(hello_bytes, exp_data.figure(fn))
self.assertEqual(hello_bytes, exp_data.figure(fn).figure)

def test_add_figure_plot(self):
"""Test adding a matplotlib figure."""
Expand All @@ -280,7 +280,7 @@ def test_add_figure_plot(self):
service = self._set_mock_service()
exp_data = DbExperimentData(backend=self.backend, experiment_type="qiskit_test")
exp_data.add_figures(figure, save_figure=True)
self.assertEqual(figure, exp_data.figure(0))
self.assertEqual(figure, exp_data.figure(0).figure)
service.create_figure.assert_called_once()
_, kwargs = service.create_figure.call_args
self.assertIsInstance(kwargs["figure"], bytes)
Expand All @@ -305,7 +305,7 @@ def test_add_figures(self):
exp_data = DbExperimentData(backend=self.backend, experiment_type="qiskit_test")
added_names = exp_data.add_figures(figures, figure_names)
for idx, added_fn in enumerate(added_names):
self.assertEqual(hello_bytes[idx], exp_data.figure(added_fn))
self.assertEqual(hello_bytes[idx], exp_data.figure(added_fn).figure)

def test_add_figure_overwrite(self):
"""Test updating an existing figure."""
Expand All @@ -318,7 +318,7 @@ def test_add_figure_overwrite(self):
exp_data.add_figures(friend_bytes, fn)

exp_data.add_figures(friend_bytes, fn, overwrite=True)
self.assertEqual(friend_bytes, exp_data.figure(fn))
self.assertEqual(friend_bytes, exp_data.figure(fn).figure)

def test_add_figure_save(self):
"""Test saving a figure in the database."""
Expand All @@ -331,6 +331,45 @@ def test_add_figure_save(self):
self.assertEqual(kwargs["figure"], hello_bytes)
self.assertEqual(kwargs["experiment_id"], exp_data.experiment_id)

def test_add_figure_metadata(self):
hello_bytes = str.encode("hello world")
qubits = [0, 1, 2]
exp_data = DbExperimentData(
backend=self.backend,
experiment_type="qiskit_test",
metadata={"physical_qubits": qubits},
)
exp_data.add_figures(hello_bytes)
exp_data.figure(0).metadata["foo"] = "bar"
figure_data = exp_data.figure(0)

self.assertEqual(figure_data.metadata["qubits"], qubits)
self.assertEqual(figure_data.metadata["foo"], "bar")
expected_name_prefix = "qiskit_test_Fig-0_Exp-"
self.assertEqual(figure_data.name[: len(expected_name_prefix)], expected_name_prefix)

exp_data2 = DbExperimentData(
backend=self.backend,
experiment_type="qiskit_test",
metadata={"physical_qubits": [1, 2, 3, 4]},
)
exp_data2.add_figures(figure_data, "new_name.svg")
figure_data = exp_data2.figure("new_name.svg")

# metadata should not change when adding to new ExperimentData
self.assertEqual(figure_data.metadata["qubits"], qubits)
self.assertEqual(figure_data.metadata["foo"], "bar")
# name should change
self.assertEqual(figure_data.name, "new_name.svg")

# can set the metadata to new dictionary
figure_data.metadata = {"bar": "foo"}
self.assertEqual(figure_data.metadata["bar"], "foo")

# cannot set the metadata to something other than dictionary
with self.assertRaises(ValueError):
figure_data.metadata = ["foo", "bar"]

def test_add_figure_bad_input(self):
"""Test adding figures with bad input."""
exp_data = DbExperimentData(backend=self.backend, experiment_type="qiskit_test")
Expand All @@ -347,8 +386,8 @@ def test_get_figure(self):
)
idx = randrange(3)
expected_figure = str.encode(figure_template.format(idx))
self.assertEqual(expected_figure, exp_data.figure(name_template.format(idx)))
self.assertEqual(expected_figure, exp_data.figure(idx))
self.assertEqual(expected_figure, exp_data.figure(name_template.format(idx)).figure)
self.assertEqual(expected_figure, exp_data.figure(idx).figure)

file_name = uuid.uuid4().hex
self.addCleanup(os.remove, file_name)
Expand Down

0 comments on commit e106b3b

Please sign in to comment.