From ed427c1bb05cc8f19990519cd5d47d0a1cd59c88 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Fri, 8 Mar 2024 21:00:41 +0100 Subject: [PATCH] feat(plot): improve plotter arguments This commit adds new arguments to the plotter that allow users to customize the plots more. --- stable_learning_control/utils/plot.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/stable_learning_control/utils/plot.py b/stable_learning_control/utils/plot.py index 7aa291a9..9770ad57 100644 --- a/stable_learning_control/utils/plot.py +++ b/stable_learning_control/utils/plot.py @@ -29,8 +29,10 @@ def plot_data( xaxis="Epoch", value="AverageEpRet", condition="Condition1", + errorbar="sd", smooth=1, font_scale=1.5, + style="darkgrid", **kwargs, ): """Function used to plot data. @@ -48,10 +50,14 @@ def plot_data( off-policy algorithms. The plotter will automatically figure out which of ``AverageEpRet`` or ``AverageTestEpRet`` to report for each separate logdir. - condition (str, optional): The condition to search for. By default ``None``. + condition (str, optional): The condition to search for. By default + ``Condition1``. + errorbar (str): The error bar you want to use for the plot. Defaults + to ``sd``. smooth (int): Smooth data by averaging it over a fixed window. This parameter says how wide the averaging window will be. font_scale (int): The font scale you want to use for the plot text. + style (str): The style you want to use for the plot. """ if smooth > 1: """ @@ -69,8 +75,8 @@ def plot_data( if isinstance(data, list): data = pd.concat(data, ignore_index=True) - sns.set(style="darkgrid", font_scale=font_scale) - sns.lineplot(data=data, x=xaxis, y=value, hue=condition, errorbar="sd", **kwargs) + sns.set(style=style, font_scale=font_scale) + sns.lineplot(data=data, x=xaxis, y=value, hue=condition, errorbar=errorbar, **kwargs) plt.legend(loc="best").set_draggable(True) xscale = np.max(np.asarray(data[xaxis])) > 5e3 @@ -209,6 +215,7 @@ def make_plots( values=None, count=False, font_scale=1.5, + style="darkgrid", smooth=1, select=None, exclude=None, @@ -233,7 +240,7 @@ def make_plots( rules (below).) xaxis (str): Pick what column from data is used for the x-axis. Defaults to ``TotalEnvInteracts``. - value (str): Pick what columns from data to graph on the y-axis. + values (list): Pick what columns from data to graph on the y-axis. Submitting multiple values will produce multiple graphs. Defaults to ``Performance``, which is not an actual output of any algorithm. Instead, ``Performance`` refers to either ``AverageEpRet``, the @@ -247,6 +254,8 @@ def make_plots( which is typically a set of identical experiments that only vary in random seed. But if you'd like to see all of those curves separately, use the ``--count`` flag. + font_scale (int): The font scale you want to use for the plot text. + style (str): The style you want to use for the plot. smooth (int): Smooth data by averaging it over a fixed window. This parameter says how wide the averaging window will be. select (list[str]): Optional selection rule: the plotter will only show @@ -271,6 +280,7 @@ def make_plots( smooth=smooth, estimator=estimator, font_scale=font_scale, + style=style, ) plt.show()