Skip to content

Commit

Permalink
added contour plots
Browse files Browse the repository at this point in the history
  • Loading branch information
mmaelicke committed Feb 8, 2021
1 parent cf57aa7 commit 3a3f92f
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 60 deletions.
69 changes: 9 additions & 60 deletions skgstat/SpaceTimeVariogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
"""
import numpy as np
from scipy.spatial.distance import pdist
from scipy.ndimage.interpolation import zoom
from scipy.interpolate import griddata
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import inspect

from skgstat import binning, estimators, Variogram, stmodels, plotting
Expand Down Expand Up @@ -1403,65 +1400,17 @@ def contourf(self, ax=None, zoom_factor=100., levels=10,
return self._plot2d(kind='contourf', ax=ax, zoom_factor=zoom_factor,
levels=levels, cmap=cmap, method=method, **kwargs)

def _plot2d(self, kind='contour', ax=None, zoom_factor=100.,
levels=10, method="fast", **kwargs):
# get or create the figure
if ax is not None:
fig = ax.get_figure()
else:
fig, ax = plt.subplots(1, 1, figsize=kwargs.get('figsize', (8, 8)))

# prepare the meshgrid
xx, yy = self.meshbins
z = self.experimental
x = xx.flatten()
y = yy.flatten()

xxi = zoom(xx, zoom_factor, order=1)
yyi = zoom(yy, zoom_factor, order=1)

# interpolation, either fast or precise
if method.lower() == "fast":
zi = zoom(z.reshape((self.t_lags, self.x_lags)), zoom_factor,
order=1, prefilter=False)
elif method.lower() == "precise":
# zoom the meshgrid by linear interpolation

# interpolate the semivariance
zi = griddata((x, y), z, (xxi, yyi), method='linear')
else:
raise ValueError("method has to be one of ['fast', 'precise']")

# get the bounds
zmin = np.nanmin(zi)
zmax = np.nanmax(zi)

# get the plotting parameters
lev = np.linspace(0, zmax, levels)
c = kwargs.get('color') if 'color' in kwargs else kwargs.get('c', 'k')
cmap = kwargs.get('cmap', 'RdYlBu_r')

# plot
if kind.lower() == 'contour':
ax.contour(xxi, yyi, zi, colors=c, levels=lev, vmin=zmin * 1.1,
vmax=zmax * 0.9, linewidths=kwargs.get('linewidths', 0.3)
)
elif kind.lower() == 'contourf':
C = ax.contourf(xxi, yyi, zi, cmap=cmap, levels=lev, vmin=zmin *
1.1,
vmax=zmax * 0.9)
if kwargs.get('colorbar', True):
plt.colorbar(C, ax=ax)
else:
raise ValueError("%s is not a valid 2D plot" % kind)
def _plot2d(self, kind='contour', ax=None, zoom_factor=100., levels=10, method="fast", **kwargs):
# get the backend
used_backend = plotting.backend()

# some labels
ax.set_xlabel(kwargs.get('xlabel', 'space'))
ax.set_ylabel(kwargs.get('ylabel', 'time'))
ax.set_xlim(kwargs.get('xlim', (0, self.xbins[-1])))
ax.set_ylim(kwargs.get('ylim', (0, self.tbins[-1])))
if used_backend == 'matplotlib':
return plotting.matplotlib_plot_2d(self, kind=kind, ax=ax, zoom_factor=zoom_factor, level=10, method=method, **kwargs)
elif used_backend == 'plotly':
return plotting.plotly_plot_2d(self, kind=kind, fig=ax, **kwargs)

return fig
# if we reach this line, somethings wrong with plotting backend
raise ValueError('The plotting backend has an undefined state.')

def marginals(self, plot=True, axes=None, sharey=True, include_model=False,
**kwargs):
Expand Down
1 change: 1 addition & 0 deletions skgstat/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .variogram_dd_plot import matplotlib_dd_plot, plotly_dd_plot
from .directtional_variogram import matplotlib_pair_field, plotly_pair_field
from .stvariogram_plot3d import matplotlib_plot_3d, plotly_plot_3d
from .stvariogram_plot2d import matplotlib_plot_2d, plotly_plot_2d


ALLOWED_BACKENDS = [
Expand Down
115 changes: 115 additions & 0 deletions skgstat/plotting/stvariogram_plot2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage.interpolation import zoom
from scipy.interpolate import griddata

try:
import plotly.graph_objects as go
except ImportError:
pass


def matplotlib_plot_2d(stvariogram, kind='contour', ax=None, zoom_factor=100., levels=10, method='fast', **kwargs):
# get or create the figure
if ax is not None:
fig = ax.get_figure()
else:
fig, ax = plt.subplots(1, 1, figsize=kwargs.get('figsize', (8, 8)))

# prepare the meshgrid
xx, yy = stvariogram.meshbins
z = stvariogram.experimental
x = xx.flatten()
y = yy.flatten()

xxi = zoom(xx, zoom_factor, order=1)
yyi = zoom(yy, zoom_factor, order=1)

# interpolation, either fast or precise
if method.lower() == "fast":
zi = zoom(z.reshape((stvariogram.t_lags, stvariogram.x_lags)), zoom_factor, order=1, prefilter=False)
elif method.lower() == "precise":
# zoom the meshgrid by linear interpolation
# interpolate the semivariance
zi = griddata((x, y), z, (xxi, yyi), method='linear')
else:
raise ValueError("method has to be one of ['fast', 'precise']")

# get the bounds
zmin = np.nanmin(zi)
zmax = np.nanmax(zi)

# get the plotting parameters
lev = np.linspace(0, zmax, levels)
c = kwargs.get('color', kwargs.get('c', 'k'))
cmap = kwargs.get('cmap', 'RdYlBu_r')

# plot
if kind.lower() == 'contour':
ax.contour(xxi, yyi, zi, colors=c, levels=lev, vmin=zmin * 1.1, vmax=zmax * 0.9, linewidths=kwargs.get('linewidths', 0.3))
elif kind.lower() == 'contourf':
C = ax.contourf(xxi, yyi, zi, cmap=cmap, levels=lev, vmin=zmin *1.1, vmax=zmax * 0.9)
if kwargs.get('colorbar', True):
plt.colorbar(C, ax=ax)
else:
raise ValueError("%s is not a valid 2D plot" % kind)

# some labels
ax.set_xlabel(kwargs.get('xlabel', 'space'))
ax.set_ylabel(kwargs.get('ylabel', 'time'))
ax.set_xlim(kwargs.get('xlim', (0, stvariogram.xbins[-1])))
ax.set_ylim(kwargs.get('ylim', (0, stvariogram.tbins[-1])))

return fig


def plotly_plot_2d(stvariogram, kind='contour', fig=None, **kwargs):
# get base data
x = stvariogram.xbins
y = stvariogram.tbins
z = stvariogram.experimental.reshape((len(x), len(y)))

# get settings
showlabels = kwargs.get('showlabels', True)
colorscale = kwargs.get('colorscale', 'Earth_r')
smooth = kwargs.get('line_smoothing', 0.0)
coloring = kwargs.get('coloring', 'heatmap')
if kind == 'contourf':
coloring = 'lines'
lw = kwargs.get('line_width', kwargs.get('lw', 2))
label_color = kwargs.get('label_color', 'black')
else:
label_color = kwargs.get('label_color', 'white')
lw = kwargs.get('line_width', kwargs.get('lw', .3))

# get the figure
if fig is None:
fig = go.Figure()

# do the plot
fig.add_trace(
go.Contour(
x=x,
y=y,
z=z,
line_smoothing=smooth,
colorscale=colorscale,
contours=dict(
coloring=coloring,
showlabels=showlabels,
labelfont=dict(
color=label_color,
size=kwargs.get('label_size', 14)
)
),
line_width=lw
)
)

# update the labels
fig.update_layout(scene=dict(
xaxis_title=kwargs.get('xlabel', 'space'),
yaxis_title=kwargs.get('ylabel', 'time')
))

return fig

0 comments on commit 3a3f92f

Please sign in to comment.