diff --git a/qiskit_experiments/framework/experiment_data.py b/qiskit_experiments/framework/experiment_data.py index 20fab44cc8..3428875086 100644 --- a/qiskit_experiments/framework/experiment_data.py +++ b/qiskit_experiments/framework/experiment_data.py @@ -1311,6 +1311,12 @@ def figure( raise ExperimentEntryNotFound(f"Figure {figure_key} not found.") figure_key = self._figures.keys()[figure_key] + # All figures must have '.svg' in their names when added, as the extension is added to the key + # name in the `add_figures()` method of this class. + if isinstance(figure_key, str): + if not figure_key.endswith(".svg"): + figure_key += ".svg" + figure_data = self._figures.get(figure_key, None) if figure_data is None and self.service: figure = self.service.figure(experiment_id=self.experiment_id, figure_name=figure_key) diff --git a/releasenotes/notes/access_figure_without_extension-5b7438c19e223d6b.yaml b/releasenotes/notes/access_figure_without_extension-5b7438c19e223d6b.yaml new file mode 100644 index 0000000000..bb7d34a29d --- /dev/null +++ b/releasenotes/notes/access_figure_without_extension-5b7438c19e223d6b.yaml @@ -0,0 +1,4 @@ +--- +upgrade: + - | + Figures in `ExperimentData` objects can now be accessed without '.svg' extension. \ No newline at end of file diff --git a/test/database_service/test_db_experiment_data.py b/test/database_service/test_db_experiment_data.py index 3cbfa52928..3d83fa8512 100644 --- a/test/database_service/test_db_experiment_data.py +++ b/test/database_service/test_db_experiment_data.py @@ -409,6 +409,8 @@ def test_get_figure(self): exp_data = ExperimentData(experiment_type="qiskit_test") figure_template = "hello world {}" name_template = "figure_{}.svg" + name_template_wo_ext = "figure_{}" + for idx in range(3): exp_data.add_figures( str.encode(figure_template.format(idx)), figure_names=name_template.format(idx) @@ -418,6 +420,11 @@ def test_get_figure(self): self.assertEqual(expected_figure, exp_data.figure(name_template.format(idx)).figure) self.assertEqual(expected_figure, exp_data.figure(idx).figure) + # Check that figure will be returned without file extension in name + expected_figure = str.encode(figure_template.format(idx)) + self.assertEqual(expected_figure, exp_data.figure(name_template_wo_ext.format(idx)).figure) + self.assertEqual(expected_figure, exp_data.figure(idx).figure) + file_name = uuid.uuid4().hex self.addCleanup(os.remove, file_name) exp_data.figure(idx, file_name)