diff --git a/README.md b/README.md index 51bef144d..010ae06c8 100644 --- a/README.md +++ b/README.md @@ -188,7 +188,7 @@ bin_center, gamma = gs.vario_estimate((x, y), field) fit_model = gs.Stable(dim=2) fit_model.fit_variogram(bin_center, gamma, nugget=False) # output -ax = fit_model.plot(x_max=bin_center[-1]) +ax = fit_model.plot(x_max=max(bin_center)) ax.scatter(bin_center, gamma) print(fit_model) ``` diff --git a/docs/source/index.rst b/docs/source/index.rst index 3ef28c988..763a01cea 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -220,7 +220,7 @@ model again. fit_model = gs.Stable(dim=2) fit_model.fit_variogram(bin_center, gamma, nugget=False) # output - ax = fit_model.plot(x_max=bin_center[-1]) + ax = fit_model.plot(x_max=max(bin_center)) ax.scatter(bin_center, gamma) print(fit_model) diff --git a/examples/01_random_field/07_higher_dimensions.py b/examples/01_random_field/07_higher_dimensions.py index 795a230e3..43f19912c 100755 --- a/examples/01_random_field/07_higher_dimensions.py +++ b/examples/01_random_field/07_higher_dimensions.py @@ -53,9 +53,8 @@ # In order to "prove" correctness, we can calculate an empirical variogram # of the generated field and fit our model to it. -bin_edges = range(size) bin_center, vario = gs.vario_estimate( - pos, field, bin_edges, sampling_size=2000, mesh_type="structured" + pos, field, sampling_size=2000, mesh_type="structured" ) model.fit_variogram(bin_center, vario) print(model) @@ -67,9 +66,16 @@ # Let's have a look at the fit and a x-y cross-section of the 4D field: f, a = plt.subplots(1, 2, gridspec_kw={"width_ratios": [2, 1]}, figsize=[9, 3]) -model.plot(x_max=size + 1, ax=a[0]) +model.plot(x_max=max(bin_center), ax=a[0]) a[0].scatter(bin_center, vario) a[1].imshow(field[:, :, 0, 0].T, origin="lower") a[0].set_title("isotropic empirical variogram with fitted model") a[1].set_title("x-y cross-section") f.show() + +############################################################################### +# GSTools also provides plotting routines for higher dimensions. +# Fields are shown by 2D cross-sections, where other dimensions can be +# controlled via sliders. + +srf.plot() diff --git a/examples/02_cov_model/02_aniso_rotation.py b/examples/02_cov_model/02_aniso_rotation.py index 58d7b7365..2a8bac788 100755 --- a/examples/02_cov_model/02_aniso_rotation.py +++ b/examples/02_cov_model/02_aniso_rotation.py @@ -52,3 +52,4 @@ # - in 3D: given by yaw, pitch, and roll (known as # `Tait–Bryan `_ # angles) +# - in nD: See the random field example about higher dimensions diff --git a/examples/03_variogram/03_directional_2d.py b/examples/03_variogram/03_directional_2d.py index c74dc7c0f..4a4943d9c 100755 --- a/examples/03_variogram/03_directional_2d.py +++ b/examples/03_variogram/03_directional_2d.py @@ -16,7 +16,7 @@ angle = np.pi / 8 model = gs.Exponential(dim=2, len_scale=[10, 5], angles=angle) -x = y = range(100) +x = y = range(101) srf = gs.SRF(model, seed=123456) field = srf((x, y), mesh_type="structured") @@ -56,9 +56,7 @@ model.plot("vario_axis", axis=1, ax=ax1, x_max=40, label="fit on axis 1") ax1.set_title("Fitting an anisotropic model") -srf.plot(ax=ax2) -ax2.set_aspect("equal") - +srf.plot(ax=ax2, show_colorbar=False) plt.show() ############################################################################### diff --git a/examples/03_variogram/04_directional_3d.py b/examples/03_variogram/04_directional_3d.py index 023954f09..d7b11de3e 100755 --- a/examples/03_variogram/04_directional_3d.py +++ b/examples/03_variogram/04_directional_3d.py @@ -17,11 +17,12 @@ dim = 3 # rotation around z, y, x -angles = [np.pi / 2, np.pi / 4, np.pi / 8] +angles = [np.deg2rad(90), np.deg2rad(45), np.deg2rad(22.5)] model = gs.Gaussian(dim=3, len_scale=[16, 8, 4], angles=angles) x = y = z = range(50) +pos = (x, y, z) srf = gs.SRF(model, seed=1001) -field = srf.structured((x, y, z)) +field = srf.structured(pos) ############################################################################### # Here we generate the axes of the rotated coordinate system @@ -37,9 +38,9 @@ # with the longest correlation length (flattest gradient). # Then check the transversal directions and so on. -bins = range(0, 40, 3) bin_center, dir_vario, counts = gs.vario_estimate( - *([x, y, z], field, bins), + pos, + field, direction=main_axes, bandwidth=10, sampling_size=2000, @@ -51,48 +52,45 @@ ############################################################################### # Afterwards we can use the estimated variogram to fit a model to it. # Note, that the rotation angles need to be set beforehand. -# -# We can use the `counts` of data pairs per bin as weights for the fitting -# routines to give more attention to areas where more data was available. -# In order to not introduce to much offset at the origin, we disable -# fitting the nugget. print("Original:") print(model) -model.fit_variogram(bin_center, dir_vario, weights=counts, nugget=False) +model.fit_variogram(bin_center, dir_vario) print("Fitted:") print(model) ############################################################################### -# Plotting. - -fig = plt.figure(figsize=[15, 5]) -ax1 = fig.add_subplot(131) -ax2 = fig.add_subplot(132, projection=Axes3D.name) -ax3 = fig.add_subplot(133) - -srf.plot(ax=ax1) -ax1.set_aspect("equal") - -ax2.plot([0, axis1[0]], [0, axis1[1]], [0, axis1[2]], label="0.") -ax2.plot([0, axis2[0]], [0, axis2[1]], [0, axis2[2]], label="1.") -ax2.plot([0, axis3[0]], [0, axis3[1]], [0, axis3[2]], label="2.") -ax2.set_xlim(-1, 1) -ax2.set_ylim(-1, 1) -ax2.set_zlim(-1, 1) -ax2.set_xlabel("X") -ax2.set_ylabel("Y") -ax2.set_zlabel("Z") -ax2.set_title("Tait-Bryan main axis") -ax2.legend(loc="lower left") - -ax3.scatter(bin_center, dir_vario[0], label="0. axis") -ax3.scatter(bin_center, dir_vario[1], label="1. axis") -ax3.scatter(bin_center, dir_vario[2], label="2. axis") -model.plot("vario_axis", axis=0, ax=ax3, label="fit on axis 0") -model.plot("vario_axis", axis=1, ax=ax3, label="fit on axis 1") -model.plot("vario_axis", axis=2, ax=ax3, label="fit on axis 2") -ax3.set_title("Fitting an anisotropic model") -ax3.legend() +# Plotting main axes and the fitted directional variogram. + +fig = plt.figure(figsize=[10, 5]) +ax1 = fig.add_subplot(121, projection=Axes3D.name) +ax2 = fig.add_subplot(122) + +ax1.plot([0, axis1[0]], [0, axis1[1]], [0, axis1[2]], label="0.") +ax1.plot([0, axis2[0]], [0, axis2[1]], [0, axis2[2]], label="1.") +ax1.plot([0, axis3[0]], [0, axis3[1]], [0, axis3[2]], label="2.") +ax1.set_xlim(-1, 1) +ax1.set_ylim(-1, 1) +ax1.set_zlim(-1, 1) +ax1.set_xlabel("X") +ax1.set_ylabel("Y") +ax1.set_zlabel("Z") +ax1.set_title("Tait-Bryan main axis") +ax1.legend(loc="lower left") + +x_max = max(bin_center) +ax2.scatter(bin_center, dir_vario[0], label="0. axis") +ax2.scatter(bin_center, dir_vario[1], label="1. axis") +ax2.scatter(bin_center, dir_vario[2], label="2. axis") +model.plot("vario_axis", axis=0, ax=ax2, x_max=x_max, label="fit on axis 0") +model.plot("vario_axis", axis=1, ax=ax2, x_max=x_max, label="fit on axis 1") +model.plot("vario_axis", axis=2, ax=ax2, x_max=x_max, label="fit on axis 2") +ax2.set_title("Fitting an anisotropic model") +ax2.legend() plt.show() + +############################################################################### +# Also, let's have a look at the field. + +srf.plot() diff --git a/examples/03_variogram/05_auto_fit_variogram.py b/examples/03_variogram/05_auto_fit_variogram.py index 9a7963147..53b113d6a 100644 --- a/examples/03_variogram/05_auto_fit_variogram.py +++ b/examples/03_variogram/05_auto_fit_variogram.py @@ -19,7 +19,7 @@ bin_center, gamma = gs.vario_estimate((x, y), field) print("estimated bin number:", len(bin_center)) -print("maximal bin distance:", bin_center[-1]) +print("maximal bin distance:", max(bin_center)) ############################################################################### # Fit the variogram with a stable model (no nugget fitted). @@ -31,5 +31,5 @@ ############################################################################### # Plot the fitting result. -ax = fit_model.plot(x_max=bin_center[-1]) +ax = fit_model.plot(x_max=max(bin_center)) ax.scatter(bin_center, gamma) diff --git a/gstools/field/base.py b/gstools/field/base.py index 30224a794..5e6519ae8 100755 --- a/gstools/field/base.py +++ b/gstools/field/base.py @@ -268,7 +268,9 @@ def vtk_export( fieldname=fieldname, ) - def plot(self, field="field", fig=None, ax=None): # pragma: no cover + def plot( + self, field="field", fig=None, ax=None, **kwargs + ): # pragma: no cover """ Plot the spatial random field. @@ -283,6 +285,8 @@ def plot(self, field="field", fig=None, ax=None): # pragma: no cover ax : :class:`Axes` or :any:`None` Axes to plot on. If `None`, a new one will be added to the figure. Default: `None` + **kwargs + Forwarded to the plotting routine. """ # just import if needed; matplotlib is not required by setup from gstools.field.plot import plot_field, plot_vec_field @@ -294,11 +298,11 @@ def plot(self, field="field", fig=None, ax=None): # pragma: no cover ) elif self.value_type == "scalar": - r = plot_field(self, field, fig, ax) + r = plot_field(self, field, fig, ax, **kwargs) elif self.value_type == "vector": if self.model.dim == 2: - r = plot_vec_field(self, field, fig, ax) + r = plot_vec_field(self, field, fig, ax, **kwargs) else: raise NotImplementedError( "Streamflow plotting only supported for 2d case." diff --git a/gstools/field/plot.py b/gstools/field/plot.py index 572d88a80..b0474f6de 100644 --- a/gstools/field/plot.py +++ b/gstools/field/plot.py @@ -13,9 +13,12 @@ # pylint: disable=C0103 import numpy as np from scipy import interpolate as inter +from scipy.spatial import ConvexHull import matplotlib.pyplot as plt from matplotlib.widgets import Slider, RadioButtons from gstools.covmodel.plot import _get_fig_ax +from gstools.tools.geometric import rotation_planes + __all__ = ["plot_field", "plot_vec_field"] @@ -23,7 +26,9 @@ # plotting routines ####################################################### -def plot_field(fld, field="field", fig=None, ax=None): # pragma: no cover +def plot_field( + fld, field="field", fig=None, ax=None, **kwargs +): # pragma: no cover """ Plot a spatial field. @@ -39,186 +44,236 @@ def plot_field(fld, field="field", fig=None, ax=None): # pragma: no cover ax : :class:`Axes` or :any:`None`, optional Axes to plot on. If `None`, a new one will be added to the figure. Default: `None` + **kwargs + Forwarded to the plotting routine. """ plot_field = getattr(fld, field) assert not (fld.pos is None or plot_field is None) if fld.dim == 1: - ax = _plot_1d(fld.pos, plot_field, fig, ax) - elif fld.dim == 2: - ax = _plot_2d( - fld.pos, plot_field, fld.mesh_type, fig, ax, fld.model.latlon - ) - elif fld.dim == 3: - ax = _plot_3d(fld.pos, plot_field, fld.mesh_type, fig, ax) - else: - raise ValueError("Field.plot: only possible for dim=1,2,3!") - return ax + return plot_1d(fld.pos, plot_field, fig, ax, **kwargs) + return plot_nd( + fld.pos, plot_field, fld.mesh_type, fig, ax, fld.model.latlon, **kwargs + ) -def _plot_1d(pos, field, fig=None, ax=None): # pragma: no cover - """Plot a 1d field.""" +def plot_1d(pos, field, fig=None, ax=None, ax_names=None): # pragma: no cover + """ + Plot a 1D field. + + Parameters + ---------- + pos : :class:`list` + the position tuple, containing either the point coordinates (x, y, ...) + or the axes descriptions (for mesh_type='structured') + field : :class:`numpy.ndarray` + Field values. + fig : :class:`Figure` or :any:`None`, optional + Figure to plot the axes on. If `None`, a new one will be created. + Default: `None` + ax : :class:`Axes` or :any:`None`, optional + Axes to plot on. If `None`, a new one will be added to the figure. + Default: `None` + ax_names : :class:`list` of :class:`str`, optional + Axes names. The default is ["$x$", "field"]. + + Returns + ------- + ax : :class:`Axes` + Axis containing the plot. + """ fig, ax = _get_fig_ax(fig, ax) title = "Field 1D: " + str(field.shape) x = pos[0] x = x.flatten() arg = np.argsort(x) + ax_names = _ax_names(1, ax_names=ax_names) ax.plot(x[arg], field.ravel()[arg]) - ax.set_xlabel("X") - ax.set_ylabel("field") + ax.set_xlabel(ax_names[0]) + ax.set_ylabel(ax_names[1]) ax.set_title(title) fig.show() return ax -def _plot_2d( - pos, field, mesh_type, fig=None, ax=None, latlon=False +def plot_nd( + pos, + field, + mesh_type, + fig=None, + ax=None, + latlon=False, + resolution=128, + ax_names=None, + aspect="quad", + show_colorbar=True, + convex_hull=False, + contour_plot=True, + **kwargs ): # pragma: no cover - """Plot a 2d field.""" - fig, ax = _get_fig_ax(fig, ax) - title = "Field 2D " + mesh_type + ": " + str(field.shape) - y = pos[0] if latlon else pos[1] - x = pos[1] if latlon else pos[0] - if mesh_type == "unstructured": - cont = ax.tricontourf(x, y, field.ravel(), levels=256) - else: - plot_field = field if latlon else field.T - try: - cont = ax.contourf(x, y, plot_field, levels=256) - except TypeError: - cont = ax.contourf(x, y, plot_field, 256) - if latlon: - ax.set_ylabel("Lat in deg") - ax.set_xlabel("Lon in deg") - else: - ax.set_xlabel("X") - ax.set_ylabel("Y") - ax.set_title(title) - fig.colorbar(cont) - fig.show() - return ax + """ + Plot field in arbitrary dimensions. + Parameters + ---------- + pos : :class:`list` + the position tuple, containing either the point coordinates (x, y, ...) + or the axes descriptions (for mesh_type='structured') + field : :class:`numpy.ndarray` + Field values. + fig : :class:`Figure` or :any:`None`, optional + Figure to plot the axes on. If `None`, a new one will be created. + Default: `None` + ax : :class:`Axes` or :any:`None`, optional + Axes to plot on. If `None`, a new one will be added to the figure. + Default: `None` + latlon : :class:`bool`, optional + Whether the data is representing 2D fields on earths surface described + by latitude and longitude. When using this, the estimator will + use great-circle distance for variogram estimation. + Note, that only an isotropic variogram can be estimated and a + ValueError will be raised, if a direction was specified. + Bin edges need to be given in radians in this case. + Default: False + resolution : :class:`int`, optional + Resolution of the imshow plot. The default is 128. + ax_names : :class:`list` of :class:`str`, optional + Axes names. The default is ["$x$", "field"]. + aspect : :class:`str` or :any:`None` or :class:`float`, optional + Aspect of the plot. Can be "auto", "equal", "quad", None or a number + describing the aspect ratio. + The default is "quad". + show_colorbar : :class:`bool`, optional + Whether to show the colorbar. The default is True. + convex_hull : :class:`bool`, optional + Whether to show the convex hull in 2D with unstructured data. + The default is False. + contour_plot : :class:`bool`, optional + Whether to use a contour-plot in 2D. The default is True. -def _plot_3d(pos, field, mesh_type, fig=None, ax=None): # pragma: no cover - """Plot 3D field.""" - dir1, dir2 = np.mgrid[0:1:51j, 0:1:51j] - levels = np.linspace(field.min(), field.max(), 100, endpoint=True) - - x_min = pos[0].min() - x_max = pos[0].max() - y_min = pos[1].min() - y_max = pos[1].max() - z_min = pos[2].min() - z_max = pos[2].max() - x_range = x_max - x_min - y_range = y_max - y_min - z_range = z_max - z_min - x_step = x_range / 50.0 - y_step = y_range / 50.0 - z_step = z_range / 50.0 - ax_info = { - "x": [x_min, x_max, x_range, x_step], - "y": [y_min, y_max, y_range, y_step], - "z": [z_min, z_max, z_range, z_step], - } + Returns + ------- + ax : :class:`Axes` + Axis containing the plot. + """ + dim = len(pos) + assert dim > 1 + assert not latlon or dim == 2 + if dim == 2 and contour_plot: + return _plot_2d( + pos, field, mesh_type, fig, ax, latlon, ax_names, **kwargs + ) + pos = pos[::-1] if latlon else pos + field = field.T if (latlon and mesh_type != "unstructured") else field + ax_names = _ax_names(dim, latlon, ax_names) + # init planes + planes = rotation_planes(dim) + plane_names = [ + " {} - {}".format(ax_names[p[0]], ax_names[p[1]]) for p in planes + ] + ax_ends = [[p.min(), p.max()] for p in pos] + ax_rngs = [end[1] - end[0] for end in ax_ends] + ax_steps = [rng / resolution for rng in ax_rngs] + ax_extents = [ax_ends[p[0]] + ax_ends[p[1]] for p in planes] + # create figure + reformat = fig is None and ax is None fig, ax = _get_fig_ax(fig, ax) - title = "Field 3D " + mesh_type + ": " + str(field.shape) - fig.subplots_adjust(left=0.2, right=0.8, bottom=0.25) - sax = plt.axes([0.15, 0.1, 0.65, 0.03]) - z_height = Slider( - sax, - "z value", - z_min, - z_max, - valinit=z_min + z_range / 2.0, - valstep=z_step, - ) - rax = plt.axes([0.05, 0.7, 0.1, 0.15]) - radio = RadioButtons(rax, ("x slice", "y slice", "z slice"), active=2) - z_dir_tmp = "z" - # create container - container_class = type( - "info", (object,), {"z_height": z_height, "z_dir_tmp": z_dir_tmp} + ax.set_title("Field {}D {} {}".format(dim, mesh_type, field.shape)) + if reformat: # only format fig if it was created here + fig.set_size_inches(8, 5.5 + 0.5 * (dim - 2)) + # init additional axis, radio-buttons and sliders + s_frac = 0.5 * (dim - 2) / (6 + 0.5 * (dim - 2)) + s_size = s_frac / max(dim - 2, 1) + left, bottom = (0.25, s_frac + 0.13) if dim > 2 else (None, None) + fig.subplots_adjust(left=left, bottom=bottom) + slider = [] + for i in range(dim - 2, 0, -1): + slider_ax = fig.add_axes([0.3, i * s_size, 0.435, s_size * 0.6]) + slider.append(Slider(slider_ax, "", 0, 1, facecolor="grey")) + slider[-1].vline.set_color("k") + # create radio buttons + if dim > 2: + rax = fig.add_axes( + [0.05, 0.85 - 2 * s_frac, 0.15, 2 * s_frac], frame_on=0, alpha=0 + ) + rax.set_title(" Plane", loc="left") + radio = RadioButtons(rax, plane_names, activecolor="grey") + # make radio buttons circular + rpos = rax.get_position().get_points() + fh, fw = fig.get_figheight(), fig.get_figwidth() + rscale = (rpos[:, 1].ptp() / rpos[:, 0].ptp()) * (fh / fw) + for circ in radio.circles: + circ.set_radius(0.06) + circ.height /= rscale + elif mesh_type == "unstructured" and convex_hull: + # show convex hull in 2D + hull = ConvexHull(pos.T) + for simplex in hull.simplices: + ax.plot(pos[0, simplex], pos[1, simplex], "k") + # init imshow and colorbar axis + grid = np.mgrid[0 : 1 : resolution * 1j, 0 : 1 : resolution * 1j] + f_ini, vmin, vmax = np.full_like(grid[0], np.nan), field.min(), field.max() + im = ax.imshow( + f_ini.T, interpolation="bicubic", origin="lower", vmin=vmin, vmax=vmax ) - container = container_class() - - def get_plane(z_val_in, z_dir): - """Get the plane.""" - if z_dir == "z": - x_io = dir1 * x_range + x_min - y_io = dir2 * y_range + y_min - z_io = np.full_like(x_io, z_val_in) - elif z_dir == "y": - x_io = dir1 * x_range + x_min - z_io = dir2 * z_range + z_min - y_io = np.full_like(x_io, z_val_in) - else: - y_io = dir1 * y_range + y_min - z_io = dir2 * z_range + z_min - x_io = np.full_like(y_io, z_val_in) + # actions + def inter_plane(cuts, axes): + """Interpolate plane.""" + plane_ax = [] + for i, (rng, end, cut) in enumerate(zip(ax_rngs, ax_ends, cuts)): + if i in axes: + plane_ax.append(grid[axes.index(i)] * rng + end[0]) + else: + plane_ax.append(np.full_like(grid[0], cut, dtype=float)) + # needs to be a tuple + plane_ax = tuple(plane_ax) if mesh_type != "unstructured": - # contourf plots image like for griddata, therefore transpose - plane = inter.interpn( - pos, field, np.array((x_io, y_io, z_io)).T, bounds_error=False - ).T - else: - plane = inter.griddata( - pos, field, (x_io, y_io, z_io), method="linear" - ) - if z_dir == "x": - return y_io, z_io, plane - elif z_dir == "y": - return x_io, z_io, plane - return x_io, y_io, plane - - def update(__): - """Widget update.""" - z_dir_in = radio.value_selected[0] - if z_dir_in != container.z_dir_tmp: - sax.clear() - container.z_height = Slider( - sax, - z_dir_in + " value", - ax_info[z_dir_in][0], - ax_info[z_dir_in][1], - valinit=ax_info[z_dir_in][0] + ax_info[z_dir_in][2] / 2.0, - valstep=ax_info[z_dir_in][3], - ) - container.z_height.on_changed(update) - container.z_dir_tmp = z_dir_in - z_val = container.z_height.val - ax.clear() - xx, yy, zz = get_plane(z_val, z_dir_in) - cont = ax.contourf( - xx, - yy, - zz, - vmin=field.min(), - vmax=field.max(), - levels=levels, - ) - # cont.cmap.set_under("k", alpha=0.0) - # cont.cmap.set_bad("k", alpha=0.0) - if z_dir_in == "x": - ax.set_xlabel("Y") - ax.set_ylabel("Z") - elif z_dir_in == "y": - ax.set_xlabel("X") - ax.set_ylabel("Z") - else: - ax.set_xlabel("X") - ax.set_ylabel("Y") - ax.set_xlim([x_min, x_max]) - ax.set_ylim([y_min, y_max]) - ax.set_title(title) + return inter.interpn(pos, field, plane_ax, bounds_error=False) + return inter.griddata(pos.T, field, plane_ax, method="nearest") + + def update_field(*args): + """Sliders update.""" + p = plane_names.index(radio.value_selected) if dim > 2 else 0 + # dummy cut values for selected plane-axes (setting to 0) + cuts = [s.val for s in slider] + cuts.insert(planes[p][0], 0) + cuts.insert(planes[p][1], 0) + im.set_array(inter_plane(cuts, planes[p]).T) fig.canvas.draw_idle() - return cont - container.z_height.on_changed(update) - radio.on_clicked(update) - cont = update(0) - cax = plt.axes([0.85, 0.2, 0.03, 0.6]) - fig.colorbar(cont, cax=cax, ax=ax) + def update_plane(label): + """Radio button update.""" + p = plane_names.index(label) + cut_select = [i for i in range(dim) if i not in planes[p]] + # reset sliders + for i, s in zip(cut_select, slider): + s.label.set_text(ax_names[i]) + s.valmin, s.valmax = ax_ends[i] + s.valinit = ax_ends[i][0] + ax_rngs[i] / 2.0 + s.valstep = ax_steps[i] + s.ax.set_xlim(*ax_ends[i]) + # update representation + s.poly.xy[:2] = (s.valmin, 0), (s.valmin, 1) + s.vline.set_data(2 * [s.valinit], [-0.1, 1.1]) + s.reset() + im.set_extent(ax_extents[p]) + if aspect == "quad": + asp = ax_rngs[planes[p][0]] / ax_rngs[planes[p][1]] + if aspect is not None: + ax.set_aspect(asp if aspect == "quad" else aspect) + ax.set_xlabel(ax_names[planes[p][0]]) + ax.set_ylabel(ax_names[planes[p][1]]) + update_field() + + # initial plot on xy plane + update_plane(plane_names[0]) + # bind actions + if dim > 2: + radio.on_clicked(update_plane) + for s in slider: + s.on_changed(update_field) + if show_colorbar: + fig.colorbar(im, ax=ax) fig.show() return ax @@ -242,8 +297,8 @@ def plot_vec_field(fld, field="field", fig=None, ax=None): # pragma: no cover """ if fld.mesh_type == "unstructured": raise RuntimeError( - "Only structured vector fields are supported" - + " for plotting. Please create one on a structured grid." + "Only structured vector fields are supported " + "for plotting. Please create one on a structured grid." ) plot_field = getattr(fld, field) assert not (fld.pos is None or plot_field is None) @@ -269,3 +324,47 @@ def plot_vec_field(fld, field="field", fig=None, ax=None): # pragma: no cover fig.colorbar(sp.lines) fig.show() return ax + + +def _ax_names(dim, latlon=False, ax_names=None): + if ax_names is not None: + assert len(ax_names) >= dim + return ax_names[:dim] + if dim == 2 and latlon: + return ["lon", "lat"] + if dim <= 3: + return ["$x$", "$y$", "$z$"][:dim] + (dim == 1) * ["field"] + return ["$x_{" + str(i) + "}$" for i in range(dim)] + + +def _plot_2d( + pos, + field, + mesh_type, + fig=None, + ax=None, + latlon=False, + ax_names=None, + levels=64, + antialias=True, +): # pragma: no cover + """Plot a 2d field with a contour plot.""" + fig, ax = _get_fig_ax(fig, ax) + title = "Field 2D " + mesh_type + ": " + str(field.shape) + ax_names = _ax_names(2, latlon, ax_names=ax_names) + x, y = pos[::-1] if latlon else pos + if mesh_type == "unstructured": + cont = ax.tricontourf(x, y, field.ravel(), levels=levels) + if antialias: + ax.tricontour(x, y, field.ravel(), levels=levels, zorder=-10) + else: + plot_field = field if latlon else field.T + cont = ax.contourf(x, y, plot_field, levels=levels) + if antialias: + ax.contour(x, y, plot_field, levels=levels, zorder=-10) + ax.set_xlabel(ax_names[0]) + ax.set_ylabel(ax_names[1]) + ax.set_title(title) + fig.colorbar(cont) + fig.show() + return ax