Skip to content

Commit

Permalink
functionallity for passing station specific style kwargs implemented …
Browse files Browse the repository at this point in the history
…in lineplots
  • Loading branch information
vergauwenthomas committed Oct 9, 2024
1 parent c139d04 commit db6f259
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 57 deletions.
18 changes: 10 additions & 8 deletions metobs_toolkit/dataset_visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def make_plot(
legend=True,
show_outliers=True,
show_filled=True,
name_color_def={},
sta_plot_kwargs_dict={},
_ax=None, # needed for GUI, not recommended use
):
"""Make a timeseries plot.
Expand Down Expand Up @@ -100,12 +100,14 @@ def make_plot(
If true the filled values for gaps and missing observations will
be included in the plot. This is only true when colorby == 'name'.
The default is True.
name_color_def : dict, optional
If colorby is 'name', a colormap is used as color defenitions for
the name. If a name_color_def dictionary is given, then the color
defenition (value) for a station name (key) is used as defined by
the user. Colors are strings that can be represent by matplotlib
name, or in hex-form. The default is {}.
sta_plot_kwargs_dict : dict, optional
sta_plot_kwargs_dict is a nested dictionary that can contain extra
styling arguments that is used for a specific station, if colorby=name.
The keys are the station names, and the values are a dict with the
keys elements of ['color', 'linewidth', 'zorder', 'linestyle']. Refer
to the corresponding keyword in matplotlib for their meaning and
possible types. The default is {}.
Returns
Expand Down Expand Up @@ -271,7 +273,7 @@ def make_plot(
show_outliers=show_outliers,
show_filled=show_filled,
settings=self.settings,
name_col_def=name_color_def,
sta_plot_kwargs_dict=sta_plot_kwargs_dict,
_ax=_ax,
)

Expand Down
39 changes: 29 additions & 10 deletions metobs_toolkit/modeldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1896,7 +1896,7 @@ def make_plot(
show_outliers=True,
show_filled=True,
legend=True,
name_color_def={},
sta_plot_kwargs_dict={},
_ax=None, # needed for GUI, not recommended use
):
"""Plot timeseries of the modeldata.
Expand Down Expand Up @@ -1942,12 +1942,13 @@ def make_plot(
The default is True.
legend : bool, optional
If True, a legend is added to the plot. The default is True.
name_color_def : dict, optional
If colorby is 'name', a colormap is used as color defenitions for
the name. If a name_color_def dictionary is given, then the color
defenition (value) for a station name (key) is used as defined by
the user. Colors are strings that can be represent by matplotlib
name, or in hex-form. The default is {}.
sta_plot_kwargs_dict : dict, optional
sta_plot_kwargs_dict is a nested dictionary that can contain extra
styling arguments that is used for a specific station.
The keys are the station names, and the values are a dict with the
keys elements of ['color', 'linewidth', 'zorder', 'linestyle']. Refer
to the corresponding keyword in matplotlib for their meaning and
possible types. The default is {}.
Returns
-------
Expand Down Expand Up @@ -2113,10 +2114,28 @@ def make_plot(
show_outliers=show_outliers,
show_filled=show_filled,
settings=Dataset.settings,
name_col_def=name_color_def,
sta_plot_kwargs_dict=sta_plot_kwargs_dict,
_ax=_ax,
)

# use the col_map to update the sta_plot_kwargs_dict,
# so that the same colors are use for modeldata

# Note: a simplile dict update will not work since all other-than-color
# elements are removed
for staname, col in col_map.items():
if staname in sta_plot_kwargs_dict.keys():
if (
"color" in sta_plot_kwargs_dict[staname].keys()
): # loop for readability
# update color
sta_plot_kwargs_dict[staname]["color"] = col
else:
# add color-pair
sta_plot_kwargs_dict[staname]["color"] = col
else:
sta_plot_kwargs_dict[staname] = {"color": col}

# Make plot of the model on the previous axes
ax, col_map = model_timeseries_plot(
df=model_df,
Expand All @@ -2126,7 +2145,7 @@ def make_plot(
show_primary_legend=False,
add_second_legend=True,
_ax=_ax,
colorby_name_colordict=col_map,
sta_plot_kwargs_dict=sta_plot_kwargs_dict,
)

else:
Expand All @@ -2138,7 +2157,7 @@ def make_plot(
ylabel=y_label,
show_primary_legend=legend,
add_second_legend=False,
name_col_def=name_color_def,
sta_plot_kwargs_dict=sta_plot_kwargs_dict,
_ax=_ax,
)

Expand Down
101 changes: 62 additions & 39 deletions metobs_toolkit/plotting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ def _create_linecollection(
const_color=None,
value_col_name="value",
label_col_name="label",
line_force_kwargs={},
):

# 1. convert datetime to numerics values
Expand All @@ -537,16 +538,33 @@ def _create_linecollection(
points = np.array([inxval, linedf[value_col_name]]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)

# 3. get styling info
# 3. styling attributes

style_kwargs = {
"color": const_color,
"linewidth": linewidth,
"zorder": linezorder,
"linestyle": None,
}

# 4. update styling attributes
style_kwargs.update(line_force_kwargs)

# 5. Construct arguments
if const_color is None:
color = linedf[label_col_name].map(colormapper).to_list()
else:
color = [const_color] * linedf.shape[0]
linewidth = [linewidth] * linedf.shape[0]
zorder = linezorder
linestyle = linedf[label_col_name].map(linestylemapper).fillna("-").to_list()
color = [style_kwargs["color"]] * linedf.shape[0]

# 4. Make line collection
linewidth = [style_kwargs["linewidth"]] * linedf.shape[0]
zorder = style_kwargs["zorder"]

if style_kwargs["linestyle"] is None:
linestyle = linedf[label_col_name].map(linestylemapper).fillna("-").to_list()
else:
linestyle = style_kwargs["linestyle"]

# 6. create plot
lc = LineCollection(
segments=segments,
colors=color,
Expand All @@ -566,9 +584,8 @@ def timeseries_plot(
show_outliers,
show_filled,
settings,
name_col_def={},
sta_plot_kwargs_dict={},
_ax=None, # needed for GUI, not recommended use
colorby_name_colordict=None,
): # when colorscheme will be reused
"""Make a timeseries plot.
Expand Down Expand Up @@ -596,9 +613,6 @@ def timeseries_plot(
_ax : matplotlib.pyplot.axes
An axes to plot on. If None, a new axes will be made. The
default is None.
colorby_name_colorscheme : dict
A colormapper for the station names. If None, a new colormapper will
be created. The default is None.
Returns
-------
Expand Down Expand Up @@ -823,20 +837,19 @@ def timeseries_plot(
# all lines are solid lines
line_style_mapper = {lab: "-" for lab in line_labels}

# create color mapper if none is given
if colorby_name_colordict is None:
col_mapper = make_cat_colormapper(
mergedf.index.get_level_values("name").unique(),
plot_settings["time_series"]["colormap"],
name_col_def,
)
else:
# col_mapper = colorby_name_colordict
col_mapper = make_cat_colormapper(
mergedf.index.get_level_values("name").unique(),
plot_settings["time_series"]["colormap"],
colorby_name_colordict,
)
# create color mapper

col_mapper = make_cat_colormapper(
mergedf.index.get_level_values("name").unique(),
plot_settings["time_series"]["colormap"],
)

# update the colormapper using the station plot kwargs
kwargs_col_map = {}
for staname, stakwargs in sta_plot_kwargs_dict.items():
if "color" in stakwargs.keys():
kwargs_col_map[staname] = stakwargs["color"]
col_mapper.update(kwargs_col_map)

# iterate over station and make line collection to avoid interpolation
# sort mergedf by station name so the same colors are used for the same stations
Expand All @@ -852,14 +865,21 @@ def timeseries_plot(
# interpolation in the plot
stadf.loc[~stadf.index.isin(linedf.index), "value"] = np.nan

# plot kwargs to forece for this station
if sta in sta_plot_kwargs_dict.keys():
stakwargs = sta_plot_kwargs_dict[sta]
else:
stakwargs = {}

# make line collection
sta_line_lc = _create_linecollection(
linedf=stadf,
colormapper=None,
const_color=col_mapper[sta],
linestylemapper=line_style_mapper,
# plotsettings=plot_settings,
line_force_kwargs=stakwargs,
)

ax.add_collection(sta_line_lc)

if show_legend is True:
Expand Down Expand Up @@ -912,13 +932,12 @@ def model_timeseries_plot(
show_primary_legend,
add_second_legend=True,
_ax=None, # needed for GUI, not recommended use
colorby_name_colordict=None, # automatic --> not by user
figsize=(15, 5),
colormap="tab20",
legend_n_columns=5,
linewidth=2,
linezorder=1,
name_col_def={}, # Defined by user
sta_plot_kwargs_dict={},
):
"""Make a timeseries plot for modeldata.
Expand Down Expand Up @@ -959,9 +978,6 @@ def model_timeseries_plot(
The width of the plotted lines. The default is 2.
linezorder: int, optional.
The zorder of the lines in the plot. The default is 1.
name_col_def: dict, optional
A dictionary to force colors for a station. The keys are the names and
the values are the colors to use for them.
Expand Down Expand Up @@ -991,19 +1007,25 @@ def model_timeseries_plot(
line_style_mapper = {"modeldata": "--"}

# create color mapper if none is given
if colorby_name_colordict is None:
col_mapper = make_cat_colormapper(
df.index.get_level_values("name").unique(),
colormap,
name_col_def,
)
else:
col_mapper = colorby_name_colordict
col_mapper = make_cat_colormapper(
df.index.get_level_values("name").unique(), colormap
)

# update the colormapper using the station plot kwargs
kwargs_col_map = {}
for staname, stakwargs in sta_plot_kwargs_dict.items():
if "color" in stakwargs.keys():
kwargs_col_map[staname] = stakwargs["color"]
col_mapper.update(kwargs_col_map)

# iterate over station and make line collection to avoid interpolation
for sta in df.index.get_level_values("name").unique():
stadf = xs_save(df, sta, "name") # subset to one station

if sta in sta_plot_kwargs_dict.keys():
force_line_kwargs = sta_plot_kwargs_dict[sta]
else:
force_line_kwargs = {}
# make line collection
sta_line_lc = _create_linecollection(
linedf=stadf,
Expand All @@ -1012,6 +1034,7 @@ def model_timeseries_plot(
linestylemapper=line_style_mapper,
linewidth=linewidth,
linezorder=linezorder,
line_force_kwargs=force_line_kwargs,
)
ax.add_collection(sta_line_lc)

Expand Down

0 comments on commit db6f259

Please sign in to comment.