diff --git a/pyleoclim/core/ensemblegeoseries.py b/pyleoclim/core/ensemblegeoseries.py index e73c178d..33d27191 100644 --- a/pyleoclim/core/ensemblegeoseries.py +++ b/pyleoclim/core/ensemblegeoseries.py @@ -13,6 +13,8 @@ import warnings +from collections import Counter + import numpy as np import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec @@ -86,6 +88,19 @@ def __init__(self, series_list,label=None,lat=None,lon=None,elevation=None,archi super().__init__(series_list,label) + # Assign archiveType if it isn't passed and it is present in all series in series_list + if archiveType is None: + if not all([isinstance(ts, GeoSeries) for ts in series_list]): + pass + else: + archiveList = [str(ts.archiveType) for ts in series_list] + #Check that they're all the same + if all([a == archiveList[0] for a in archiveList]): + archiveType = archiveList[0] + #If they aren't, pick the most common one + else: + archiveType = Counter(archiveList).most_common(1)[0][0] + if lat is None: # check that all components are GeoSeries if not all([isinstance(ts, GeoSeries) for ts in series_list]): @@ -619,8 +634,16 @@ def dashboard(self, figsize=[11, 8], gs=None, plt_kwargs=None, histplt_kwargs=No if archiveType not in lipdutils.PLOT_DEFAULT.keys(): archiveType = 'Other' else: - archiveType = 'Other' - + if not all([isinstance(ts, GeoSeries) for ts in self.series_list]): + archiveType = 'Other' + else: + archiveList = [str(ts.archiveType) for ts in self.series_list] + #Check that they're all the same + if all([a == archiveList[0] for a in archiveList]): + archiveType = archiveList[0] + #If they aren't, pick the most common one + else: + archiveType = Counter(archiveList).most_common(1)[0][0] # if 'marker' not in plt_kwargs.keys(): # plt_kwargs.update({'marker': lipdutils.PLOT_DEFAULT[archiveType][1]}) if 'curve_clr' not in plt_kwargs.keys(): diff --git a/pyleoclim/core/ensembleseries.py b/pyleoclim/core/ensembleseries.py index 5058faef..6617c16b 100644 --- a/pyleoclim/core/ensembleseries.py +++ b/pyleoclim/core/ensembleseries.py @@ -13,6 +13,7 @@ from ..core.multipleseries import MultipleSeries import warnings +warnings.filterwarnings("ignore") import seaborn as sns import matplotlib.pyplot as plt @@ -666,7 +667,7 @@ def correlation(self, target=None, timespan=None, alpha=0.05, method = 'ttest', return corr_ens def plot_traces(self, figsize=[10, 4], xlabel=None, ylabel=None, title=None, num_traces=10, seed=None, - xlim=None, ylim=None, linestyle='-', savefig_settings=None, ax=None, plot_legend=True, + xlim=None, ylim=None, linestyle='-', savefig_settings=None, ax=None, legend=True, color=sns.xkcd_rgb['pale red'], lw=0.5, alpha=0.3, lgd_kwargs=None): '''Plot EnsembleSeries as a subset of traces. @@ -804,8 +805,8 @@ def plot_traces(self, figsize=[10, 4], xlabel=None, ylabel=None, title=None, num for idx in random_draw_idx: self.series_list[idx].plot(xlabel=xlabel, ylabel=ylabel, zorder=99, linewidth=lw, - xlim=xlim, ylim=ylim, ax=ax, color=color, alpha=alpha,linestyle='-') - ax.plot(np.nan, np.nan, color=color, label=f'example members (n={num_traces})',linestyle='-') + xlim=xlim, ylim=ylim, ax=ax, color=color, alpha=alpha,linestyle='-', label='_ignore') + l1, = ax.plot(np.nan, np.nan, color=color, label=f'example members (n={num_traces})',linestyle='-') if title is not None: ax.set_title(title) @@ -813,10 +814,14 @@ def plot_traces(self, figsize=[10, 4], xlabel=None, ylabel=None, title=None, num if self.label is not None: ax.set_title(self.label) - if plot_legend: + if legend==True: lgd_args = {'frameon': False} lgd_args.update(lgd_kwargs) ax.legend(**lgd_args) + elif legend==False: + ax.legend().remove() + else: + raise ValueError('legend should be set to either True or False') if 'fig' in locals(): if 'path' in savefig_settings: