Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plot grid modif #1311

Merged
merged 4 commits into from
Nov 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 58 additions & 48 deletions pyart/graph/gridmapdisplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
except ImportError:
_LAMBERT_GRIDLINES = False

class GridMapDisplay(object):

class GridMapDisplay():
"""
A class for creating plots from a grid object using xarray
with a cartopy projection.
Expand All @@ -71,31 +72,31 @@ def __init__(self, grid, debug=False):
if not _CARTOPY_AVAILABLE:
raise MissingOptionalDependency(
'Cartopy is required to use GridMapDisplay but is not '
+ 'installed!')
'installed!')
if not _XARRAY_AVAILABLE:
raise MissingOptionalDependency(
'Xarray is required to use GridMapDisplay but is not '
+ 'installed!')
'installed!')
if not _NETCDF4_AVAILABLE:
raise MissingOptionalDependency(
'netCDF4 is required to use GridMapDisplay but is not '
+ 'installed!')
'installed!')

# set attributes
self.grid = grid
self.debug = debug
self.mappables = []
self.fields = []
self.origin = 'origin'

def plot_grid(self, field, level=0, vmin=None, vmax=None,
norm=None, cmap=None, mask_outside=False,
title=None, title_flag=True, axislabels=(None, None),
axislabels_flag=False, colorbar_flag=True,
colorbar_label=None, colorbar_orient='vertical',
ax=None, fig=None, lat_lines=None,
lon_lines=None, projection=None,
embellish=True, ticks=None, ticklabs=None,
embellish=True, add_grid_lines=True, ticks=None, ticklabs=None,
imshow=False, **kwargs):
"""
Plot the grid using xarray and cartopy.
Expand Down Expand Up @@ -158,8 +159,10 @@ def plot_grid(self, field, level=0, vmin=None, vmax=None,
Map projection supported by cartopy. Used for all subsequent calls
to the GeoAxes object generated. Defaults to PlateCarree.
embellish : bool
True by default. Set to False to supress drawinf of coastlines
True by default. Set to False to supress drawing of coastlines
etc... Use for speedup when specifying shapefiles.
add_grid_lines : bool
True by default. Set to False to supress drawing of lat/lon lines
Note that lat lon labels only work with certain projections.
ticks : array
Colorbar custom tick label locations.
Expand All @@ -176,13 +179,6 @@ def plot_grid(self, field, level=0, vmin=None, vmax=None,
vmin, vmax = common.parse_vmin_vmax(self.grid, field, vmin, vmax)
cmap = common.parse_cmap(cmap, field)

if lon_lines is None:
lon_lines = np.linspace(np.around(ds.lon.min()-.1, decimals=2),
np.around(ds.lon.max()+.1, decimals=2), 5)
if lat_lines is None:
lat_lines = np.linspace(np.around(ds.lat.min()-.1, decimals=2),
np.around(ds.lat.max()+.1, decimals=2), 5)

# mask the data where outside the limits
if mask_outside:
data = ds[field].data
Expand Down Expand Up @@ -228,7 +224,7 @@ def plot_grid(self, field, level=0, vmin=None, vmax=None,
ax = plt.axes(projection=projection)

# plot the grid using xarray
if norm is not None: # if norm is set do not override with vmin/vmax
if norm is not None: # if norm is set do not override with vmin/vmax
vmin = vmax = None

if imshow:
Expand All @@ -252,6 +248,16 @@ def plot_grid(self, field, level=0, vmin=None, vmax=None,
ax.add_feature(coastlines, linestyle='-', edgecolor='k',
linewidth=2)

if add_grid_lines:
if lon_lines is None:
lon_lines = np.linspace(
np.around(ds.lon.min()-.1, decimals=2),
np.around(ds.lon.max()+.1, decimals=2), 5)
if lat_lines is None:
lat_lines = np.linspace(
np.around(ds.lat.min()-.1, decimals=2),
np.around(ds.lat.max()+.1, decimals=2), 5)

# labeling gridlines poses some difficulties depending on the
# projection, so we need some projection-specific methods
if ax.projection in [cartopy.crs.PlateCarree(),
Expand All @@ -261,7 +267,7 @@ def plot_grid(self, field, level=0, vmin=None, vmax=None,
xlocs=lon_lines, ylocs=lat_lines)
ax.set_extent([lon_lines.min(), lon_lines.max(),
lat_lines.min(), lat_lines.max()],
crs=projection)
crs=projection)
ax.set_xticks(lon_lines, crs=projection)
ax.set_yticks(lat_lines, crs=projection)

Expand Down Expand Up @@ -296,8 +302,6 @@ def plot_grid(self, field, level=0, vmin=None, vmax=None,
field=field, ax=ax, fig=fig,
ticks=ticks, ticklabs=ticklabs)

return

def plot_crosshairs(self, lon=None, lat=None, linestyle='--', color='r',
linewidth=2, ax=None):
"""
Expand Down Expand Up @@ -439,7 +443,7 @@ def plot_latitudinal_level(self, field, y_index, vmin=None, vmax=None,
if len(z_1d) > 1:
z_1d = _interpolate_axes_edges(z_1d)
xd, yd = np.meshgrid(x_1d, z_1d)
if norm is not None: # if norm is set do not override with vmin, vmax
if norm is not None: # if norm is set do not override with vmin, vmax
vmin = vmax = None

pm = ax.pcolormesh(
Expand All @@ -464,7 +468,6 @@ def plot_latitudinal_level(self, field, y_index, vmin=None, vmax=None,
self.plot_colorbar(mappable=pm, label=colorbar_label,
orientation=colorbar_orient, field=field,
ax=ax, fig=fig, ticks=ticks, ticklabs=ticklabs)
return

def plot_longitude_slice(self, field, lon=None, lat=None, **kwargs):
"""
Expand Down Expand Up @@ -580,7 +583,7 @@ def plot_longitudinal_level(self, field, x_index, vmin=None, vmax=None,
z_1d = _interpolate_axes_edges(z_1d)
xd, yd = np.meshgrid(y_1d, z_1d)

if norm is not None: # if norm is set do not override with vmin, vmax
if norm is not None: # if norm is set do not override with vmin, vmax
vmin = vmax = None

pm = ax.pcolormesh(
Expand All @@ -605,8 +608,7 @@ def plot_longitudinal_level(self, field, x_index, vmin=None, vmax=None,
self.plot_colorbar(mappable=pm, label=colorbar_label,
orientation=colorbar_orient, field=field,
ax=ax, fig=fig, ticks=ticks, ticklabs=ticklabs)
return


def plot_cross_section(self, field, start, end,
steps=100, interp_type='linear', x_axis=None,
vmin=None, vmax=None,
Expand All @@ -618,7 +620,8 @@ def plot_cross_section(self, field, start, end,
ax=None, fig=None, ticks=None,
ticklabs=None, **kwargs):
"""
Plot a cross section through a set of given points (latitude, longitude).
Plot a cross section through a set of given points (latitude,
longitude).

This uses the MetPy cross section interpolation function.

Expand All @@ -629,14 +632,15 @@ def plot_cross_section(self, field, start, end,
field : str
Field to be plotted.
start : tuple
A latitude-longitude pair designating the start point of the cross section
(units are degrees north and degrees east).
A latitude-longitude pair designating the start point of the cross
section (units are degrees north and degrees east).
end : tuple
A latitude-longitude pair designating the end point of the cross section
(units are degrees north and degrees east).
A latitude-longitude pair designating the end point of the cross
section (units are degrees north and degrees east).
steps: int
The number of points along the geodesic between the start and the end point
(including the end points) to use in the cross section. Defaults to 100.
The number of points along the geodesic between the start and the
end point (including the end points) to use in the cross section.
Defaults to 100.
interp_type: str
The interpolation method, either ‘linear’ or ‘nearest’
(see xarray.DataArray.interp() for details). Defaults to ‘linear’.
Expand Down Expand Up @@ -701,21 +705,24 @@ def plot_cross_section(self, field, start, end,
cmap = common.parse_cmap(cmap, field)

# Convert the grid into an xarray object
ds= self.grid.to_xarray()
ds = self.grid.to_xarray()

# Extract the proj parameters
proj_params = self.grid.get_projparams()

# Convert the projection information into cartopy
radar_crs = cartopy.crs.AzimuthalEquidistant(central_longitude=proj_params["lon_0"],
central_latitude=proj_params["lat_0"])

# Now, convert that to cf-compliant coordinate information and assign it to data
radar_crs = cartopy.crs.AzimuthalEquidistant(
central_longitude=proj_params["lon_0"],
central_latitude=proj_params["lat_0"])

# Now, convert that to cf-compliant coordinate information and assign
# it to data
projection_info = radar_crs.to_cf()
ds = ds.metpy.assign_crs(projection_info)

# Calculate the cross section, which returns a dataset
ds = cross_section(ds, start, end, steps, interp_type).set_coords(('lat', 'lon'))
ds = cross_section(ds, start, end, steps, interp_type).set_coords(
('lat', 'lon'))

# Convert from meters to km for the different variables
ds["z"] = ds["z"] / 1000
Expand All @@ -724,7 +731,7 @@ def plot_cross_section(self, field, start, end,
if x_axis == 'y':
ds["y"] = ds["y"] / 1000
ds.y.attrs["units"] = 'North South distance from radar (km)'

if x_axis == 'x':
ds["x"] = ds["x"] / 1000
ds.y.attrs["units"] = 'East West distance from radar (km)'
Expand All @@ -733,8 +740,6 @@ def plot_cross_section(self, field, start, end,
plot = ds[field].plot(y='z', x=x_axis, vmin=vmin, vmax=vmax, norm=norm,
add_colorbar=False, ax=ax, cmap=cmap, **kwargs)



self.mappables.append(plot)
self.fields.append(field)

Expand All @@ -743,19 +748,20 @@ def plot_cross_section(self, field, start, end,

if title_flag:
if title is None:
ax.set_title(common.generate_cross_section_title(self.grid, field, start, end))
ax.set_title(
common.generate_cross_section_title(
self.grid, field, start, end))
else:
ax.set_title(title)

if colorbar_flag:
self.plot_colorbar(mappable=plot, label=colorbar_label,
orientation=colorbar_orient, field=field,
ax=ax, fig=fig, ticks=ticks, ticklabs=ticklabs)
return

def plot_colorbar(self, mappable=None, orientation='horizontal', label=None,
cax=None, ax=None, fig=None, field=None, ticks=None,
ticklabs=None):
def plot_colorbar(self, mappable=None, orientation='horizontal',
label=None, cax=None, ax=None, fig=None, field=None,
ticks=None, ticklabs=None):
"""
Plot a colorbar.

Expand Down Expand Up @@ -789,8 +795,7 @@ def plot_colorbar(self, mappable=None, orientation='horizontal', label=None,
if mappable is None:
if len(self.mappables) == 0:
raise ValueError('mappable must be specified.')
else:
mappable = self.mappables[-1]
mappable = self.mappables[-1]

if label is None:
if len(self.fields) == 0:
Expand All @@ -809,7 +814,6 @@ def plot_colorbar(self, mappable=None, orientation='horizontal', label=None,
if ticklabs is not None:
cb.set_ticklabels(ticklabs)
cb.set_label(label)
return

def _find_nearest_grid_indices(self, lon, lat):
""" Find the nearest x, y grid indices for a given latitude and
Expand Down Expand Up @@ -977,6 +981,7 @@ def cartopy_coastlines(self):
# These methods are a hack to allow gridlines when the projection is lambert
# https://nbviewer.jupyter.org/gist/ajdawson/dd536f786741e987ae4e


def find_side(ls, side):
"""
Given a shapely LineString which is assumed to be rectangular, return the
Expand All @@ -989,6 +994,7 @@ def find_side(ls, side):
'top': [(minx, maxy), (maxx, maxy)]}
return sgeom.LineString(points[side])


def lambert_xticks(ax, ticks):
""" Draw ticks on the bottom x-axis of a Lambert Conformal projection. """
def te(xy):
Expand All @@ -1003,6 +1009,7 @@ def lc(t, n, b):
ax.set_xticklabels([ax.xaxis.get_major_formatter()(xtick) for
xtick in xticklabels])


def lambert_yticks(ax, ticks):
""" Draw ticks on the left y-axis of a Lambert Conformal projection. """
def te(xy):
Expand All @@ -1017,8 +1024,11 @@ def lc(t, n, b):
ax.set_yticklabels([ax.yaxis.get_major_formatter()(ytick) for
ytick in yticklabels])


def _lambert_ticks(ax, ticks, tick_location, line_constructor, tick_extractor):
""" Get the tick locations and labels for a Lambert Conformal projection. """
"""
Get the tick locations and labels for a Lambert Conformal projection.
"""
outline_patch = sgeom.LineString(
ax.spines['geo'].get_path().vertices.tolist())
axis = find_side(outline_patch, tick_location)
Expand Down
Loading