Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move matplotlib imports to runtime #5485

Merged
merged 9 commits into from
Feb 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions qiskit/visualization/counts_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@
from .matplotlib import HAS_MATPLOTLIB
from .exceptions import VisualizationError

if HAS_MATPLOTLIB:
from matplotlib import get_backend
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator


def hamming_distance(str1, str2):
"""Calculate the Hamming distance between two bit strings
Expand Down Expand Up @@ -100,6 +95,9 @@ def plot_histogram(data, figsize=(7, 5), color=None, number_to_keep=None,
"""
if not HAS_MATPLOTLIB:
raise ImportError('Must have Matplotlib installed.')
from matplotlib import get_backend
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
if sort not in VALID_SORTS:
raise VisualizationError("Value of sort option, %s, isn't a "
"valid choice. Must be 'asc', "
Expand Down
19 changes: 11 additions & 8 deletions qiskit/visualization/gate_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,6 @@
from .matplotlib import HAS_MATPLOTLIB
from .exceptions import VisualizationError

if HAS_MATPLOTLIB:
import matplotlib
from matplotlib import get_backend
import matplotlib.pyplot as plt # pylint: disable=import-error
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
from matplotlib import ticker


class _GraphDist():
"""Transform the circles properly for non-square axes.
Expand Down Expand Up @@ -120,6 +112,9 @@ def plot_gate_map(backend, figsize=None,
if not HAS_MATPLOTLIB:
raise ImportError('Must have Matplotlib installed. To install, '
'run "pip install matplotlib".')
from matplotlib import get_backend
import matplotlib.pyplot as plt # pylint: disable=import-error
import matplotlib.patches as mpatches

if backend.configuration().simulator:
raise QiskitError('Requires a device backend, not simulator.')
Expand Down Expand Up @@ -436,6 +431,14 @@ def plot_error_map(backend, figsize=(12, 9), show_title=True):
except ImportError:
raise ImportError('Must have seaborn installed to use plot_error_map. '
'To install, run "pip install seaborn".')
if not HAS_MATPLOTLIB:
raise ImportError('Must have Matplotlib installed. To install, '
'run "pip install matplotlib".')
import matplotlib
from matplotlib import get_backend
import matplotlib.pyplot as plt # pylint: disable=import-error
import matplotlib.gridspec as gridspec
from matplotlib import ticker

color_map = sns.cubehelix_palette(reverse=True, as_cmap=True)

Expand Down
84 changes: 51 additions & 33 deletions qiskit/visualization/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,6 @@

import numpy as np

try:
from matplotlib import get_backend
from matplotlib import patches
from matplotlib import pyplot as plt

HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False

try:
from pylatexenc.latex2text import LatexNodes2Text
Expand Down Expand Up @@ -130,6 +122,10 @@ def __init__(self, qregs, cregs, ops,
if not HAS_MATPLOTLIB:
raise ImportError('The class MatplotlibDrawer needs matplotlib. '
'To install, run "pip install matplotlib".')
from matplotlib import patches
self.patches_mod = patches
from matplotlib import pyplot as plt
self.plt_mod = plt
if not HAS_PYLATEX:
raise ImportError('The class MatplotlibDrawer needs pylatexenc. '
'to install, run "pip install pylatexenc".')
Expand Down Expand Up @@ -300,7 +296,7 @@ def _get_text_width(self, text, fontsize, param=False):
return 0.0

if self._renderer:
t = plt.text(0.5, 0.5, text, fontsize=fontsize)
t = self.plt_mod.text(0.5, 0.5, text, fontsize=fontsize)
return t.get_window_extent(renderer=self._renderer).width / 60.0
else:
math_mode_match = self._mathmode_regex.search(text)
Expand Down Expand Up @@ -423,7 +419,7 @@ def _multiqubit_gate(self, xy, fc=None, ec=None, gt=None, sc=None, text='', subt

qubit_span = abs(ypos) - abs(ypos_max) + 1
height = HIG + (qubit_span - 1)
box = patches.Rectangle(
box = self.patches_mod.Rectangle(
xy=(xpos - 0.5 * wid, ypos - .5 * HIG), width=wid, height=height,
fc=fc, ec=ec, linewidth=self._lwidth15, zorder=PORDER_GATE)
self._ax.add_patch(box)
Expand Down Expand Up @@ -458,9 +454,9 @@ def _gate(self, xy, fc=None, ec=None, gt=None, sc=None, text='', subtext=''):
sub_width = self._get_text_width(subtext, sfs, param=True)
wid = max((text_width, sub_width, WID))

box = patches.Rectangle(xy=(xpos - 0.5 * wid, ypos - 0.5 * HIG),
width=wid, height=HIG, fc=fc, ec=ec,
linewidth=self._lwidth15, zorder=PORDER_GATE)
box = self.patches_mod.Rectangle(xy=(xpos - 0.5 * wid, ypos - 0.5 * HIG),
width=wid, height=HIG, fc=fc, ec=ec,
linewidth=self._lwidth15, zorder=PORDER_GATE)
self._ax.add_patch(box)

if text:
Expand Down Expand Up @@ -515,17 +511,17 @@ def _measure(self, qxy, cxy, cid, fc=None, ec=None, gt=None, sc=None):
self._gate(qxy, fc=fc, ec=ec, gt=gt, sc=sc)

# add measure symbol
arc = patches.Arc(xy=(qx, qy - 0.15 * HIG), width=WID * 0.7,
height=HIG * 0.7, theta1=0, theta2=180, fill=False,
ec=gt, linewidth=self._lwidth2, zorder=PORDER_GATE)
arc = self.patches_mod.Arc(xy=(qx, qy - 0.15 * HIG), width=WID * 0.7,
height=HIG * 0.7, theta1=0, theta2=180, fill=False,
ec=gt, linewidth=self._lwidth2, zorder=PORDER_GATE)
self._ax.add_patch(arc)
self._ax.plot([qx, qx + 0.35 * WID], [qy - 0.15 * HIG, qy + 0.20 * HIG],
color=gt, linewidth=self._lwidth2, zorder=PORDER_GATE)
# arrow
self._line(qxy, [cx, cy + 0.35 * WID], lc=self._style['cc'], ls=self._style['cline'])
arrowhead = patches.Polygon(((cx - 0.20 * WID, cy + 0.35 * WID),
(cx + 0.20 * WID, cy + 0.35 * WID),
(cx, cy + 0.04)), fc=self._style['cc'], ec=None)
arrowhead = self.patches_mod.Polygon(((cx - 0.20 * WID, cy + 0.35 * WID),
(cx + 0.20 * WID, cy + 0.35 * WID),
(cx, cy + 0.04)), fc=self._style['cc'], ec=None)
self._ax.add_artist(arrowhead)
# target
if self._cregbundle:
Expand All @@ -537,14 +533,15 @@ def _conditional(self, xy, istrue=False):
xpos, ypos = xy

fc = self._style['lc'] if istrue else self._style['bg']
box = patches.Circle(xy=(xpos, ypos), radius=WID * 0.15, fc=fc,
ec=self._style['lc'], linewidth=self._lwidth15, zorder=PORDER_GATE)
box = self.patches_mod.Circle(xy=(xpos, ypos), radius=WID * 0.15, fc=fc,
ec=self._style['lc'], linewidth=self._lwidth15,
zorder=PORDER_GATE)
self._ax.add_patch(box)

def _ctrl_qubit(self, xy, fc=None, ec=None, tc=None, text='', text_top=None):
xpos, ypos = xy
box = patches.Circle(xy=(xpos, ypos), radius=WID * 0.15,
fc=fc, ec=ec, linewidth=self._lwidth15, zorder=PORDER_GATE)
box = self.patches_mod.Circle(xy=(xpos, ypos), radius=WID * 0.15,
fc=fc, ec=ec, linewidth=self._lwidth15, zorder=PORDER_GATE)
self._ax.add_patch(box)
# display the control label at the top or bottom if there is one
if text_top is True:
Expand Down Expand Up @@ -583,9 +580,9 @@ def _set_ctrl_bits(self, ctrl_state, num_ctrl_qubits, qbit, ec=None, tc=None,
def _x_tgt_qubit(self, xy, ec=None, ac=None):
linewidth = self._lwidth2
xpos, ypos = xy
box = patches.Circle(xy=(xpos, ypos), radius=HIG * 0.35,
fc=ec, ec=ec, linewidth=linewidth,
zorder=PORDER_GATE)
box = self.patches_mod.Circle(xy=(xpos, ypos), radius=HIG * 0.35,
fc=ec, ec=ec, linewidth=linewidth,
zorder=PORDER_GATE)
self._ax.add_patch(box)

# add '+' symbol
Expand All @@ -611,10 +608,10 @@ def _barrier(self, config):
self._ax.plot([xpos, xpos], [ypos + 0.5, ypos - 0.5],
linewidth=self._scale, linestyle="dashed",
color=self._style['lc'], zorder=PORDER_TEXT)
box = patches.Rectangle(xy=(xpos - (0.3 * WID), ypos - 0.5),
width=0.6 * WID, height=1,
fc=self._style['bc'], ec=None, alpha=0.6,
linewidth=self._lwidth15, zorder=PORDER_GRAY)
box = self.patches_mod.Rectangle(xy=(xpos - (0.3 * WID), ypos - 0.5),
width=0.6 * WID, height=1,
fc=self._style['bc'], ec=None, alpha=0.6,
linewidth=self._lwidth15, zorder=PORDER_GRAY)
self._ax.add_patch(box)

def draw(self, filename=None, verbose=False):
Expand All @@ -636,16 +633,17 @@ def draw(self, filename=None, verbose=False):
self._figure.set_size_inches(self._style['figwidth'],
self._style['figwidth'] * fig_h / fig_w)
if self._global_phase:
plt.text(_xl, _yt, 'Global Phase: %s' % pi_check(self._global_phase,
output='mpl'))
self.plt_mod.text(_xl, _yt, 'Global Phase: %s' % pi_check(self._global_phase,
output='mpl'))

if filename:
self._figure.savefig(filename, dpi=self._style['dpi'], bbox_inches='tight',
facecolor=self._figure.get_facecolor())
if self._return_fig:
from matplotlib import get_backend
if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
plt.close(self._figure)
self.plt_mod.close(self._figure)
return self._figure

def _draw_regs(self):
Expand Down Expand Up @@ -1081,3 +1079,23 @@ def _draw_ops(self, verbose=False):
self._ax.text(x_coord, y_coord, str(ii + 1), ha='center',
va='center', fontsize=sfs,
color=self._style['tc'], clip_on=True, zorder=PORDER_TEXT)


class HasMatplotlibWrapper:
"""Wrapper to lazily import matplotlib."""
has_matplotlib = False

# pylint: disable=unused-import
def __bool__(self):
if not self.has_matplotlib:
try:
from matplotlib import get_backend
from matplotlib import patches
from matplotlib import pyplot as plt
self.has_matplotlib = True
except ImportError:
self.has_matplotlib = False
return self.has_matplotlib


HAS_MATPLOTLIB = HasMatplotlibWrapper()
35 changes: 25 additions & 10 deletions qiskit/visualization/pulse/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,7 @@

import numpy as np

try:
from matplotlib import pyplot as plt, gridspec
HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False

from qiskit.visualization.matplotlib import HAS_MATPLOTLIB
from qiskit.visualization.pulse.qcstyle import PulseStyle, SchedStyle
from qiskit.visualization.pulse.interpolation import step_wise
from qiskit.pulse.channels import (DriveChannel, ControlChannel,
Expand Down Expand Up @@ -286,9 +281,18 @@ def draw(self, pulse: Waveform,

Returns:
matplotlib.figure.Figure: A matplotlib figure object of the pulse envelope.

Raises:
ImportError: If matplotlib is not installed
"""
# If these self.style.dpi or self.style.figsize are None, they will
# revert back to their default rcParam keys.
if not HAS_MATPLOTLIB:
raise ImportError("Matplotlib needs to be installed to use "
"WaveformDrawer. It can be installed with "
"'pip install matplotlib'")

from matplotlib import pyplot as plt
figure = plt.figure(dpi=self.style.dpi, figsize=self.style.figsize)

interp_method = interp_method or step_wise
Expand Down Expand Up @@ -339,7 +343,18 @@ def __init__(self, style: SchedStyle):

Args:
style: Style sheet for pulse schedule visualization.
Raises:
ImportError: If matplotlib is not installed
"""
if not HAS_MATPLOTLIB:
raise ImportError("Matplotlib needs to be installed to use "
"ScheduleDrawer. It can be installed with "
"'pip install matplotlib'")

from matplotlib import pyplot as plt
self.plt_mod = plt
from matplotlib import gridspec
self.gridspec_mod = gridspec
self.style = style or SchedStyle()

def _build_channels(self, schedule: ScheduleComponent,
Expand Down Expand Up @@ -509,9 +524,9 @@ def _draw_table(self, figure,
h_waves = (figure.get_size_inches()[1] - h_table)

# create subplots
gs = gridspec.GridSpec(2, 1, height_ratios=[h_table, h_waves], hspace=0)
tb = plt.subplot(gs[0])
ax = plt.subplot(gs[1])
gs = self.gridspec_mod.GridSpec(2, 1, height_ratios=[h_table, h_waves], hspace=0)
tb = self.plt_mod.subplot(gs[0])
ax = self.plt_mod.subplot(gs[1])

# configure each cell
tb.axis('off')
Expand Down Expand Up @@ -795,7 +810,7 @@ def draw(self, schedule: ScheduleComponent,
Raises:
VisualizationError: When schedule cannot be drawn
"""
figure = plt.figure(dpi=self.style.dpi, figsize=self.style.figsize)
figure = self.plt_mod.figure(dpi=self.style.dpi, figsize=self.style.figsize)

if channels is None:
channels = []
Expand Down
12 changes: 5 additions & 7 deletions qiskit/visualization/pulse_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@
from qiskit.visualization.pulse.qcstyle import PulseStyle, SchedStyle
from qiskit.visualization.exceptions import VisualizationError
from qiskit.visualization.pulse import matplotlib as _matplotlib

try:
from matplotlib import get_backend
HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False
from qiskit.visualization.matplotlib import HAS_MATPLOTLIB


def pulse_drawer(data: Union[Waveform, Union[Schedule, Instruction]],
Expand Down Expand Up @@ -156,6 +151,9 @@ def pulse_drawer(data: Union[Waveform, Union[Schedule, Instruction]],

if not HAS_MATPLOTLIB:
raise ImportError('Must have Matplotlib installed.')
from matplotlib import get_backend
from matplotlib import pyplot as plt

if isinstance(data, Waveform):
drawer = _matplotlib.WaveformDrawer(style=style)
image = drawer.draw(data, dt=dt, interp_method=interp_method, scale=scale)
Expand All @@ -175,7 +173,7 @@ def pulse_drawer(data: Union[Waveform, Union[Schedule, Instruction]],

if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
_matplotlib.plt.close(image)
plt.close(image)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is calling close on qiskit.visualization.pulse.matplotlib (which it was before) the same as calling it on matplotlib (which it is now)?

Copy link
Member Author

@mtreinish mtreinish Feb 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be, matplotlib.pyplot is a global state instance.

That being said this comment reminded me that I need to check the pulse v2 drawer because I originally wrote this before that merged. Let me do a quick pass on that now and verify that matplotlib doesn't get imported at import time from there.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, just double checked and it looks like @nkanazawa1989 already made the matplotlib imports occur at runtime for the v2 drawer so there aren't any changes I need to make there.

if image and interactive:
image.show()
return image
Loading