Skip to content

Commit

Permalink
Legend bug (#588)
Browse files Browse the repository at this point in the history
* remove rogue colorbar axis when hue is categorical

* make legend if specified even if no lgd_kwargs

* check if there is legend content before removing a legend

* fix bug with observationType labels and long legend formatting

* address UserWarning from empty legend items

* move psd.plot to plot_xy
  • Loading branch information
jordanplanders authored Jun 21, 2024
1 parent 33bbb5a commit 245728d
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 59 deletions.
89 changes: 64 additions & 25 deletions pyleoclim/core/psds.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,13 +700,22 @@ def plot(self, in_loglog=True, in_period=True, label=None, xlabel=None, ylabel='
ylim = [np.min(x_axis), np.max(x_axis)]
xlim = [np.min(y_axis), np.max(y_axis)]
x_axis, y_axis = y_axis, x_axis
ax.plot(x_axis,y_axis, **plot_kwargs)
# ax.plot(x_axis,y_axis, **plot_kwargs)

xticks, yticks = yticks, xticks
xlabel, ylabel = ylabel, xlabel
else:
ax.set_xlim(xlim)
ax.plot(x_axis, y_axis, **plot_kwargs)
# else:
# ax.set_xlim(xlim)
# ax.plot(x_axis, y_axis, **plot_kwargs)

ax = plotting.plot_xy(
x_axis, y_axis,
figsize=figsize, xlabel=xlabel, ylabel=ylabel,
title=title, savefig_settings=savefig_settings,
ax=ax, legend=legend, xlim=xlim, ylim=ylim,
plot_kwargs=plot_kwargs, lgd_kwargs=lgd_kwargs,
# invert_xaxis=invert_xaxis, invert_yaxis=invert_yaxis
)

#
# if transpose:
Expand Down Expand Up @@ -735,17 +744,34 @@ def plot(self, in_loglog=True, in_period=True, label=None, xlabel=None, ylabel='
idx = np.argwhere(q.frequency==0)
signif_x_axis = 1/np.delete(q.frequency, idx) if in_period else np.delete(q.frequency, idx)
signif_y_axis = np.delete(q.amplitude, idx)
legend = True
if transpose:
signif_x_axis, signif_y_axis = signif_y_axis, signif_x_axis

ax.plot(
plot_kwargs = {'label': f'{signif_method_label[self.signif_method]}, {q.label} threshold',
'color': signif_clr,
'linestyle': signif_linestyles[i%3],
'linewidth': signif_linewidth}

ax = plotting.plot_xy(
signif_x_axis, signif_y_axis,
label=f'{signif_method_label[self.signif_method]}, {q.label} threshold',
color=signif_clr,
linestyle=signif_linestyles[i%3],
linewidth=signif_linewidth,
# figsize=figsize,
# xlabel=xlabel, ylabel=ylabel,
# title=title, savefig_settings=savefig_settings,
ax=ax, legend=legend,
# xlim=xlim, ylim=ylim,
plot_kwargs=plot_kwargs, lgd_kwargs=lgd_kwargs,
# invert_xaxis=invert_xaxis, invert_yaxis=invert_yaxis
)

# ax.plot(
# signif_x_axis, signif_y_axis,
# label=f'{signif_method_label[self.signif_method]}, {q.label} threshold',
# color=signif_clr,
# linestyle=signif_linestyles[i%3],
# linewidth=signif_linewidth,
# )

if in_loglog:
ax.set_xscale('log')
ax.set_yscale('log')
Expand All @@ -760,8 +786,8 @@ def plot(self, in_loglog=True, in_period=True, label=None, xlabel=None, ylabel='
ax.yaxis.set_major_formatter(ScalarFormatter())
ax.yaxis.set_major_formatter(FormatStrFormatter('%g'))

ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
# ax.set_xlabel(xlabel)
# ax.set_ylabel(ylabel)

if plot_beta and self.beta_est_res is not None:
plot_beta_kwargs = {
Expand All @@ -775,23 +801,36 @@ def plot(self, in_loglog=True, in_period=True, label=None, xlabel=None, ylabel='
beta_y_axis = self.beta_est_res['Y_reg']
if transpose:
beta_x_axis, beta_y_axis = beta_y_axis, beta_x_axis
ax.plot(beta_x_axis, beta_y_axis , **plot_beta_kwargs)

if legend:
lgd_args = {'frameon': False}
lgd_args.update(lgd_kwargs)
ax.legend(**lgd_args)
# ax.plot(beta_x_axis, beta_y_axis , **plot_beta_kwargs)
legend = True
ax = plotting.plot_xy(
beta_x_axis, beta_y_axis,
# figsize=figsize,
# xlabel=xlabel, ylabel=ylabel,
title=title,
# savefig_settings=savefig_settings,
ax=ax, legend=legend,
# xlim=xlim, ylim=ylim,
plot_kwargs=plot_beta_kwargs, lgd_kwargs=lgd_kwargs,
# invert_xaxis=invert_xaxis, invert_yaxis=invert_yaxis
)

if title is not None:
ax.set_title(title)
# print('psd leg', legend)
# if legend:
# lgd_args = {'frameon': False}
# lgd_args.update(lgd_kwargs)
# ax.legend(**lgd_args)

if xlim is not None:
if False not in np.isfinite(xlim):
ax.set_xlim(xlim)
# if title is not None:
# ax.set_title(title)

if ylim is not None:
if False not in np.isfinite(ylim):
ax.set_ylim(ylim)
# if xlim is not None:
# if False not in np.isfinite(xlim):
# ax.set_xlim(xlim)
#
# if ylim is not None:
# if False not in np.isfinite(ylim):
# ax.set_ylim(ylim)

if 'fig' in locals():
if 'path' in savefig_settings:
Expand Down
3 changes: 3 additions & 0 deletions pyleoclim/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,9 @@ def plot(self, figsize=[10, 4],
if zorder is not None:
plot_kwargs.update({'zorder': zorder})

if self.label is None:
legend =False

res = plotting.plot_xy(
self.time, self.value,
figsize=figsize, xlabel=xlabel, ylabel=ylabel,
Expand Down
85 changes: 55 additions & 30 deletions pyleoclim/utils/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

from .plotting import savefig, make_scalar_mappable, keep_center_colormap, consolidate_legends
from .plotting import savefig, make_scalar_mappable, keep_center_colormap, consolidate_legends, tidy_labels
from .lipdutils import PLOT_DEFAULT, LipdToOntology, CaseInsensitiveDict


Expand Down Expand Up @@ -723,16 +723,17 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
fig=None, color_scale_type=None, # gs_slot=None,
cmap=None, **kwargs):

# ensure these are dictionaries if not specified in the function call
scatter_kwargs = {} if type(scatter_kwargs) != dict else scatter_kwargs
lgd_kwargs = {} if type(lgd_kwargs) != dict else lgd_kwargs
kwargs = {} if type(kwargs) != dict else kwargs

# color mapping specs
norm_kwargs = kwargs.pop('norm_kwargs', {})
ax_sm = kwargs.pop('scalar_mappable', None)

palette = None
hue_norm = None
# if (color_scale_type is not None) and (colorbar is None):
# colorbar = True

# plot_defaults = copy.copy(PLOT_DEFAULT)
f = copy.copy(PLOT_DEFAULT)
Expand All @@ -759,12 +760,6 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
ax = fig.add_subplot()

transform = ccrs.PlateCarree()
# if type(ax) == cartopy.mpl.geoaxes.GeoAxes:
# transform=ccrs.PlateCarree()
# if proj is not None:
# scatter_kwargs['transform'] = ccrs.PlateCarree()#proj
# else:
# scatter_kwargs['transform'] = ccrs.PlateCarree()

missing_d = {'hue': kwargs['missing_val_hue'] if 'missing_val_hue' in kwargs else 'k',
'marker': kwargs['missing_val_marker'] if 'missing_val_marker' in kwargs else r'$?$',
Expand All @@ -776,13 +771,14 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
if isinstance(scatter_kwargs, dict):
edgecolor = scatter_kwargs.pop('edgecolor', edgecolor)

if isinstance(scatter_kwargs, dict):
linewidth = scatter_kwargs.pop('linewidth', 1)

if isinstance(lgd_kwargs, dict):
handle_size = lgd_kwargs.pop('handle_size', 11)

if 'neighbor' in df.columns:
edgecolor_var = 'neighbor'
# if ~isinstance(edgecolor, np.ndarray):
# if isinstance(edgecolor, str):
# edgecolor = [edgecolor]
# edgecolor = np.array(edgecolor)


if isinstance(edgecolor, (list, np.ndarray)):
if len(edgecolor) == len(_df):
Expand All @@ -796,11 +792,14 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
if edgecolor_var in _df.columns:
_df['edgecolor'] = _df[edgecolor_var].map(edgecolor)

_df = _df.apply(lambda x: tidy_labels(x) if x.dtype == "str" else x)
hue_var = hue_var if hue_var in _df.columns else None
hue_var_type_numeric = False
if hue_var is not None:
hue_var_type_numeric = all(isinstance(i, (int, float)) for i in _df[_df[hue_var] != missing_val][
hue_var]) # pd.to_numeric(_df[hue_var], errors='coerce').notnull().all()
if hue_var_type_numeric is False:
colorbar = False

marker_var = marker_var if marker_var in _df.columns else None
marker_var_type_numeric = False
Expand All @@ -813,6 +812,11 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
size_var_type_numeric = all(isinstance(i, (int, float)) for i in _df[_df[size_var] != missing_val][
size_var]) # pd.to_numeric(_df[size_var], errors='coerce').notnull().all()

# if maker is None, and colorbar is True, legend should be False
if (marker_var is None):
if (colorbar is True):
legend = False

trait_vars = [trait_var for trait_var in [hue_var, marker_var, size_var] if
((trait_var != None) and (trait_var in _df.columns))]

Expand Down Expand Up @@ -887,12 +891,13 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va

# use hue mapping if supplied
if 'hue_mapping' in kwargs:
palette = kwargs['hue_mapping']
palette = kwargs.pop('hue_mapping', None)
# there should be different control for discrete and continuous hue
elif hue_var == 'archiveType':
palette = {key: value[0] for key, value in plot_defaults.items()}
elif isinstance(hue_var,str): #hue_var) == str:
hue_data = _df[_df[hue_var] != missing_val]

# If scalar mappable was passed, try to extract components.
if ax_sm is not None:
try:
Expand All @@ -904,13 +909,12 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
except:
hue_norm = None

# should not be changed to else, as the above block can set ax_sm to None
if ax_sm is None:
if cmap != None:
if cmap is not None:
palette = cmap
if len(hue_data[hue_var]) > 0:
if hue_var_type_numeric is not True:
# trait_val_types = [True if type(val) in (np.str_, str) else False for val in hue_data[hue_var]]
# if True in trait_val_types:
colorbar = False
if len(hue_data[hue_var].unique()) < 20:
palette = 'tab20'
Expand All @@ -935,6 +939,7 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
if ((type(hue_var) == str) and (type(palette) == dict)):
residual_traits = [trait for trait in _df[hue_var].unique() if
trait not in palette.keys()]

if len(residual_traits) > 0:
print(residual_traits)

Expand All @@ -953,7 +958,7 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
if 'neighbor' in hue_data.columns:
sns.scatterplot(data=hue_data, x=x, y=y, size=size_var,
transform=transform,
edgecolor=hue_data.edgecolor.values, linewidth=2,
edgecolor=hue_data.edgecolor.values, linewidth=linewidth,
style=marker_var, hue=hue_var, palette=[missing_d['hue'] for ik in range(len(hue_data))],
ax=ax, legend=False,
**scatter_kwargs)
Expand Down Expand Up @@ -997,13 +1002,13 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
if 'neighbor' in hue_data.columns:
sns.scatterplot(data=hue_data, x=x, y=y, size=size_var,
transform=transform,
edgecolor=hue_data.edgecolor.values,linewidth=2,
edgecolor=hue_data.edgecolor.values,linewidth=linewidth,
style=marker_var, hue=hue_var, palette=palette, ax=ax, legend=False, **scatter_kwargs)
if not isinstance(edgecolor, str):
edgecolor = None
linewidth = 0
else:
linewidth = 1
linewidth = linewidth

sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform,
edgecolor=edgecolor,linewidth=linewidth,
Expand All @@ -1018,7 +1023,17 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
missing_handles, missing_labels = ax.get_legend_handles_labels()

# h, l= consolidate_legends([ax], hue=hue_var, style =marker_var, size=size_var, colorbar=colorbar)
def replace_last(source_string, replace_what, replace_with):
head, _sep, tail = source_string.rpartition(replace_what)
return head + replace_with + tail

h, l = ax.get_legend_handles_labels()
# l = [replace_last(label, ' ', '\n') if len(label)>20 else label for label in l]
l = [label.replace(', ', ',\n') if len(label)>20 else label for label in l]
if len(l)> 20:
ncols=2
else:
ncols = 1

if ((len(l) == 2) and (l[-1] == 'missing')) or (len(l) < 2):
legend = False
Expand Down Expand Up @@ -1056,10 +1071,12 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
# first pass at sig figs approach to number formatting
_label = np.format_float_positional(np.float16(pair_l[ik]), unique=True, precision=2)
except:
try:
_label = LipdToOntology(pair_l[ik])
except:
_label = pair_l[ik]
_label = pair_l[ik]
if label == 'archiveType':
try:
_label = LipdToOntology(pair_l[ik])
except:
pass
if _label not in d_leg[key]['labels']:
d_leg[key]['labels'].append(_label)
d_leg[key]['handles'].append(pair_h[ik])
Expand All @@ -1074,6 +1091,15 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
# # ax_d.pop('leg', None)
# else:
# d_leg.pop(hue_var, None)
# print(d_leg, d_leg.keys())
tmpHandles = []
for key in [hue_var, marker_var]:
if key in d_leg.keys():
for handle in d_leg[key]['handles']:
handle.set_markersize(handle_size)
# tmpHandles.append(handle)

# d_leg[hue_var]['handles'] = tmpHandles

if (legend is True) and (len(d_leg.keys()) > 0):
# Finally rebuild legend in single list with formatted section headers
Expand Down Expand Up @@ -1110,10 +1136,12 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
handles.append(han)
labels.append('')

if 'loc' not in lgd_kwargs:
if 'loc' not in lgd_kwargs.keys():
lgd_kwargs['loc'] = 'upper left'
if 'bbox_to_anchor' not in lgd_kwargs:
if 'bbox_to_anchor' not in lgd_kwargs.keys():
lgd_kwargs['bbox_to_anchor'] = (-.1, 1) # (1, 1)
if 'labelspacing' not in lgd_kwargs.keys():
lgd_kwargs['labelspacing'] = .275

built_legend = ax_leg.legend(handles, labels, **lgd_kwargs)
if headers is True:
Expand All @@ -1125,9 +1153,6 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va

ax_leg.set_axis_off()
ax.legend().remove()
# if colorbar == False:
# ax_cb.remove()
# ax_d.pop('cb', None)

else:
ax.legend().remove()
Expand Down
Loading

0 comments on commit 245728d

Please sign in to comment.