Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(plot): improve plotter arguments #425

Merged
merged 1 commit into from
Mar 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions stable_learning_control/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ def plot_data(
xaxis="Epoch",
value="AverageEpRet",
condition="Condition1",
errorbar="sd",
smooth=1,
font_scale=1.5,
style="darkgrid",
**kwargs,
):
"""Function used to plot data.
Expand All @@ -48,10 +50,14 @@ def plot_data(
off-policy algorithms. The plotter will automatically figure out
which of ``AverageEpRet`` or ``AverageTestEpRet`` to report for
each separate logdir.
condition (str, optional): The condition to search for. By default ``None``.
condition (str, optional): The condition to search for. By default
``Condition1``.
errorbar (str): The error bar you want to use for the plot. Defaults
to ``sd``.
smooth (int): Smooth data by averaging it over a fixed window. This
parameter says how wide the averaging window will be.
font_scale (int): The font scale you want to use for the plot text.
style (str): The style you want to use for the plot.
"""
if smooth > 1:
"""
Expand All @@ -69,8 +75,8 @@ def plot_data(

if isinstance(data, list):
data = pd.concat(data, ignore_index=True)
sns.set(style="darkgrid", font_scale=font_scale)
sns.lineplot(data=data, x=xaxis, y=value, hue=condition, errorbar="sd", **kwargs)
sns.set(style=style, font_scale=font_scale)
sns.lineplot(data=data, x=xaxis, y=value, hue=condition, errorbar=errorbar, **kwargs)
plt.legend(loc="best").set_draggable(True)

xscale = np.max(np.asarray(data[xaxis])) > 5e3
Expand Down Expand Up @@ -209,6 +215,7 @@ def make_plots(
values=None,
count=False,
font_scale=1.5,
style="darkgrid",
smooth=1,
select=None,
exclude=None,
Expand All @@ -233,7 +240,7 @@ def make_plots(
rules (below).)
xaxis (str): Pick what column from data is used for the x-axis.
Defaults to ``TotalEnvInteracts``.
value (str): Pick what columns from data to graph on the y-axis.
values (list): Pick what columns from data to graph on the y-axis.
Submitting multiple values will produce multiple graphs. Defaults
to ``Performance``, which is not an actual output of any algorithm.
Instead, ``Performance`` refers to either ``AverageEpRet``, the
Expand All @@ -247,6 +254,8 @@ def make_plots(
which is typically a set of identical experiments that only vary
in random seed. But if you'd like to see all of those curves
separately, use the ``--count`` flag.
font_scale (int): The font scale you want to use for the plot text.
style (str): The style you want to use for the plot.
smooth (int): Smooth data by averaging it over a fixed window. This
parameter says how wide the averaging window will be.
select (list[str]): Optional selection rule: the plotter will only show
Expand All @@ -271,6 +280,7 @@ def make_plots(
smooth=smooth,
estimator=estimator,
font_scale=font_scale,
style=style,
)
plt.show()

Expand Down