Skip to content

Commit

Permalink
Switched dag_drawer from pydot to retworkx (#8162)
Browse files Browse the repository at this point in the history
* switched `dag_drawer` from pydot to `retworkx`

* added optional annotation to method

* added graphviz annotations to tests

* minor fixes

* added mapping to `retworkx`

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
prakharb10 and mergify[bot] authored Jun 15, 2022
1 parent 1e872b7 commit 5f77531
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 106 deletions.
4 changes: 4 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"sphinx.ext.mathjax",
"sphinx.ext.viewcode",
"sphinx.ext.extlinks",
"sphinx.ext.intersphinx",
"jupyter_sphinx",
"sphinx_autodoc_typehints",
"reno.sphinxext",
Expand Down Expand Up @@ -59,6 +60,9 @@
# (e.g., if this is set to ['foo.'], then foo.bar is shown under B, not F).
modindex_common_prefix = ["qiskit."]

intersphinx_mapping = {
'retworkx': ('https://qiskit.org/documentation/retworkx/', None),
}

# -- Options for HTML output -------------------------------------------------

Expand Down
119 changes: 14 additions & 105 deletions qiskit/visualization/dag_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,89 +15,22 @@
"""
Visualization function for DAG circuit representation.
"""

import os
import sys
import tempfile
from retworkx.visualization import graphviz_draw

from qiskit.dagcircuit.dagnode import DAGOpNode, DAGInNode, DAGOutNode
from qiskit.circuit import Qubit
from qiskit.utils import optionals as _optionals
from qiskit.exceptions import InvalidFileError
from .exceptions import VisualizationError

FILENAME_EXTENSIONS = {
"bmp",
"canon",
"cgimage",
"cmap",
"cmapx",
"cmapx_np",
"dot",
"dot_json",
"eps",
"exr",
"fig",
"gd",
"gd2",
"gif",
"gv",
"icns",
"ico",
"imap",
"imap_np",
"ismap",
"jp2",
"jpe",
"jpeg",
"jpg",
"json",
"json0",
"mp",
"pct",
"pdf",
"pic",
"pict",
"plain",
"plain-ext",
"png",
"pov",
"ps",
"ps2",
"psd",
"sgi",
"svg",
"svgz",
"tga",
"tif",
"tiff",
"tk",
"vdx",
"vml",
"vmlz",
"vrml",
"wbmp",
"webp",
"xdot",
"xdot1.2",
"xdot1.4",
"xdot_json",
}


@_optionals.HAS_PYDOT.require_in_call
@_optionals.HAS_GRAPHVIZ.require_in_call
def dag_drawer(dag, scale=0.7, filename=None, style="color"):
"""Plot the directed acyclic graph (dag) to represent operation dependencies
in a quantum circuit.
Note this function leverages
`pydot <https://github.com/erocarrera/pydot>`_ to generate the graph, which
means that having `Graphviz <https://www.graphviz.org/>`_ installed on your
system is required for this to work.
The current release of Graphviz can be downloaded here: <https://graphviz.gitlab.io/download/>.
Download the version of the software that matches your environment and follow the instructions
to install Graph Visualization Software (Graphviz) on your operating system.
This function calls the :func:`~retworkx.visualization.graphviz_draw` function from the ``retworkx``
package to draw the DAG.
Args:
dag (DAGCircuit): The dag to draw.
Expand All @@ -112,7 +45,6 @@ def dag_drawer(dag, scale=0.7, filename=None, style="color"):
Raises:
VisualizationError: when style is not recognized.
MissingOptionalLibraryError: when pydot or pillow are not installed.
InvalidFileError: when filename provided is not valid
Example:
Expand All @@ -135,7 +67,6 @@ def dag_drawer(dag, scale=0.7, filename=None, style="color"):
dag = circuit_to_dag(circ)
dag_drawer(dag)
"""
import pydot

# NOTE: use type str checking to avoid potential cyclical import
# the two tradeoffs ere that it will not handle subclasses and it is
Expand Down Expand Up @@ -225,38 +156,16 @@ def edge_attr_func(edge):
e["label"] = label
return e

dot_str = dag._multi_graph.to_dot(node_attr_func, edge_attr_func, graph_attrs)
dot = pydot.graph_from_dot_data(dot_str)[0]

image_type = None
if filename:
if "." not in filename:
raise InvalidFileError("Parameter 'filename' must be in format 'name.extension'")
extension = filename.split(".")[-1]
if extension not in FILENAME_EXTENSIONS:
raise InvalidFileError(
"Filename extension must be one of: " + " ".join(FILENAME_EXTENSIONS)
)
dot.write(filename, format=extension)
return None
elif ("ipykernel" in sys.modules) and ("spyder" not in sys.modules):
_optionals.HAS_PIL.require_now("dag_drawer")
from PIL import Image

with tempfile.TemporaryDirectory() as tmpdirname:
tmp_path = os.path.join(tmpdirname, "dag.png")
dot.write_png(tmp_path)
with Image.open(tmp_path) as test_image:
image = test_image.copy()
os.remove(tmp_path)
return image
else:
_optionals.HAS_PIL.require_now("dag_drawer")
from PIL import Image

with tempfile.TemporaryDirectory() as tmpdirname:
tmp_path = os.path.join(tmpdirname, "dag.png")
dot.write_png(tmp_path)
image = Image.open(tmp_path)
image.show()
os.remove(tmp_path)
return None
image_type = filename.split(".")[-1]
return graphviz_draw(
dag._multi_graph,
node_attr_func,
edge_attr_func,
graph_attrs,
filename,
image_type,
)
9 changes: 8 additions & 1 deletion test/python/visualization/test_dag_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,27 @@ def setUp(self):
circuit.cx(qr[0], qr[1])
self.dag = circuit_to_dag(circuit)

@unittest.skipUnless(_optionals.HAS_GRAPHVIZ, "Graphviz not installed")
def test_dag_drawer_invalid_style(self):
"""Test dag draw with invalid style."""
self.assertRaises(VisualizationError, dag_drawer, self.dag, style="multicolor")

@unittest.skipUnless(_optionals.HAS_GRAPHVIZ, "Graphviz not installed")
def test_dag_drawer_checks_filename_correct_format(self):
"""filename must contain name and extension"""
with self.assertRaisesRegex(
InvalidFileError, "Parameter 'filename' must be in format 'name.extension'"
):
dag_drawer(self.dag, filename="aaabc")

@unittest.skipUnless(_optionals.HAS_GRAPHVIZ, "Graphviz not installed")
def test_dag_drawer_checks_filename_extension(self):
"""filename must have a valid extension"""
with self.assertRaisesRegex(InvalidFileError, "Filename extension must be one of: .*"):
with self.assertRaisesRegex(
ValueError,
"The specified value for the image_type argument, 'abc' is not a "
"valid choice. It must be one of: .*",
):
dag_drawer(self.dag, filename="aa.abc")

@unittest.skipUnless(_optionals.HAS_GRAPHVIZ, "Graphviz not installed")
Expand Down

0 comments on commit 5f77531

Please sign in to comment.