Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed the _color_arg_to_dict function #402

Merged
merged 8 commits into from
Jun 30, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .. import convert
from ..classes import DiHypergraph, Hypergraph, SimplicialComplex, max_edge_order
from ..exception import XGIError
from ..stats import EdgeStat, NodeStat
from ..stats import IDStat
from .layout import _augmented_projection, barycenter_spring_layout

__all__ = [
Expand Down Expand Up @@ -644,6 +644,7 @@ def draw_simplices(
color=dyad_color[id],
lw=dyad_lw[id],
zorder=max_order - 1,
cmap=settings["dyad"],
)
ax.add_line(line)
else:
Expand Down Expand Up @@ -722,7 +723,7 @@ def _scalar_arg_to_dict(arg, ids, min_val, max_val):
return {id: arg[id] for id in arg if id in ids}
elif type(arg) in [int, float]:
return {id: arg for id in ids}
elif isinstance(arg, NodeStat) or isinstance(arg, EdgeStat):
elif isinstance(arg, IDStat):
vals = np.interp(arg.asnumpy(), [arg.min(), arg.max()], [min_val, max_val])
return dict(zip(ids, vals))
elif isinstance(arg, Iterable):
Expand Down Expand Up @@ -756,21 +757,55 @@ def _color_arg_to_dict(arg, ids, cmap):
TypeError
If a string, dict, iterable, or NodeStat/EdgeStat is not passed.
"""
if isinstance(arg, dict):
return {id: arg[id] for id in arg if id in ids}
elif isinstance(arg, str):

# single argument. Must be a string or a tuple of floats
if isinstance(arg, str) or (isinstance(arg, tuple) and isinstance(arg[0], float)):
return {id: arg for id in ids}
elif isinstance(arg, (NodeStat, EdgeStat)):

# Iterables of colors. The values of these iterables must strings or tuples. As of now,
# there is not a check to verify that the tuples contain floats.
if isinstance(arg, Iterable):
if isinstance(arg, dict) and isinstance(next(iter(arg.values())), (str, tuple)):
return {id: arg[id] for id in arg if id in ids}
if isinstance(arg, (list, np.ndarray)) and isinstance(arg[0], (str, tuple)):
return {id: arg[idx] for idx, id in enumerate(ids)}

# Stats or iterable of values
if isinstance(arg, (Iterable, IDStat)):
# set max and min of interpolation based on color map
if isinstance(cmap, ListedColormap):
vals = np.interp(arg.asnumpy(), [arg.min(), arg.max()], [0, cmap.N])
minval = 0
maxval = cmap.N
elif isinstance(cmap, LinearSegmentedColormap):
vals = np.interp(arg.asnumpy(), [arg.min(), arg.max()], [0.1, 0.9])
minval = 0.1
maxval = 0.9
Comment on lines +840 to +841
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how are these decided?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These values are so that the colors don't fade into the background but I suppose we could add them as an arg in the future?

else:
raise XGIError("Invalid colormap!")

# handle the case of IDStat vs iterables
if isinstance(arg, IDStat):
vals = np.interp(arg.asnumpy(), [arg.min(), arg.max()], [minval, maxval])

elif isinstance(arg, Iterable):
if isinstance(arg, dict) and isinstance(
next(iter(arg.values())), (int, float)
):
v = list(arg.values())
vals = np.interp(v, [np.min(v), np.max(v)], [minval, maxval])
# because we have ids, we can't just assume that the keys of arg correspond to
# the ids.
return {
id: np.array(cmap(v)).reshape(1, -1)
for v, id in zip(vals, arg.keys())
if id in ids
}

if isinstance(arg, (list, np.ndarray)) and isinstance(arg[0], (int, float)):
vals = np.interp(arg, [np.min(arg), np.max(arg)], [minval, maxval])
else:
raise TypeError("Argument must be an iterable of floats.")

return {id: np.array(cmap(vals[i])).reshape(1, -1) for i, id in enumerate(ids)}
nwlandry marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(arg, Iterable):
return {id: arg[idx] for idx, id in enumerate(ids)}
else:
raise TypeError(
"Argument must be str, dict, iterable, or "
Expand Down