Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed the _color_arg_to_dict function #402

Merged
merged 8 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 65 additions & 14 deletions tests/drawing/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,67 @@ def test_scalar_arg_to_dict(edgelist4):
def test_color_arg_to_dict(edgelist4):
ids = [1, 2, 3]

arg = "black"
d = _color_arg_to_dict(arg, ids, None)
# single values
arg1 = "black"
arg2 = (0.1, 0.2, 0.3)
arg3 = (0.1, 0.2, 0.3, 0.5)

# test iterables of colors
arg4 = [(0.1, 0.2, 0.3), (0.1, 0.2, 0.4), (0.1, 0.2, 0.5)]
arg5 = ["blue", "black", "red"]
arg6 = np.array(["blue", "black", "red"])
arg7 = {0: (0.1, 0.2, 0.3), 1: (0.1, 0.2, 0.4), 2: (0.1, 0.2, 0.5)}
arg8 = {0: "blue", 1: "black", 2: "red"}

# test iterables of values
arg9 = [0, 0.1, 0.2]
arg10 = {1: 0, 2: 0.1, 3: 0.2}
arg11 = np.array([0, 0.1, 0.2])

# test single values
d = _color_arg_to_dict(arg1, ids, None)
assert d == {1: "black", 2: "black", 3: "black"}

with pytest.raises(TypeError):
arg = 0.3
d = _color_arg_to_dict(arg, ids, None)
d = _color_arg_to_dict(arg2, ids, None)
assert d == {1: (0.1, 0.2, 0.3), 2: (0.1, 0.2, 0.3), 3: (0.1, 0.2, 0.3)}

with pytest.raises(TypeError):
arg = 1
d = _color_arg_to_dict(arg, ids, None)
d = _color_arg_to_dict(arg3, ids, None)
for i in d:
assert np.allclose(d[i], np.array([0.1, 0.2, 0.3, 0.5]))

# Test iterables of colors
d = _color_arg_to_dict(arg4, ids, None)
assert d == {1: (0.1, 0.2, 0.3), 2: (0.1, 0.2, 0.4), 3: (0.1, 0.2, 0.5)}

d = _color_arg_to_dict(arg5, ids, None)
assert d == {1: "blue", 2: "black", 3: "red"}

d = _color_arg_to_dict(arg6, ids, None)
assert d == {1: "blue", 2: "black", 3: "red"}

d = _color_arg_to_dict(arg7, ids, None)
assert d == {1: (0.1, 0.2, 0.4), 2: (0.1, 0.2, 0.5)}

arg = ["black", "blue", "red"]
d = _color_arg_to_dict(arg, ids, None)
assert d == {1: "black", 2: "blue", 3: "red"}
d = _color_arg_to_dict(arg8, ids, None)
assert d == {1: "black", 2: "red"}

arg = np.array(["black", "blue", "red"])
d = _color_arg_to_dict(arg, ids, None)
assert d == {1: "black", 2: "blue", 3: "red"}
# Test iterables of values
cdict = {
1: np.array([[0.89173395, 0.93510188, 0.97539408, 1.0]]),
2: np.array([[0.41708574, 0.68063053, 0.83823145, 1.0]]),
3: np.array([[0.03137255, 0.28973472, 0.57031911, 1.0]]),
}
d = _color_arg_to_dict(arg9, ids, cm.Blues)
for i in d:
assert np.allclose(d[i], cdict[i])

d = _color_arg_to_dict(arg10, ids, cm.Blues)
for i in d:
assert np.allclose(d[i], cdict[i])

d = _color_arg_to_dict(arg11, ids, cm.Blues)
for i in d:
assert np.allclose(d[i], cdict[i])

H = xgi.Hypergraph(edgelist4)
arg = H.nodes.degree
Expand All @@ -88,6 +130,15 @@ def test_color_arg_to_dict(edgelist4):
assert np.allclose(d[2], np.array([[0.98357555, 0.41279508, 0.28835063, 1.0]]))
assert np.allclose(d[3], np.array([[0.59461745, 0.0461361, 0.07558631, 1.0]]))

# Test bad calls
with pytest.raises(TypeError):
arg = 0.3
d = _color_arg_to_dict(arg, ids, None)

with pytest.raises(TypeError):
arg = 1
d = _color_arg_to_dict(arg, ids, None)


def test_draw(edgelist8):
H = xgi.Hypergraph(edgelist8)
Expand Down
11 changes: 7 additions & 4 deletions xgi/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,8 +735,11 @@ def to_bipartite_graph(H, index=False):
"""
G = nx.Graph()

node_dict = dict(zip(H.nodes, range(H.num_nodes)))
edge_dict = dict(zip(H.edges, range(H.num_nodes, H.num_nodes + H.num_edges)))
n = H.num_nodes
m = H.num_edges

node_dict = dict(zip(H.nodes, range(n)))
edge_dict = dict(zip(H.edges, range(n, n + m)))
G.add_nodes_from(node_dict.values(), bipartite=0)
G.add_nodes_from(edge_dict.values(), bipartite=1)
for node in H.nodes:
Expand All @@ -746,8 +749,8 @@ def to_bipartite_graph(H, index=False):
if index:
return (
G,
dict(zip(range(H.num_nodes), H.nodes)),
dict(zip(range(H.num_nodes, H.num_nodes + H.num_edges), H.edges)),
dict(zip(range(n), H.nodes)),
dict(zip(range(n, n + m), H.edges)),
)
else:
return G
Expand Down
89 changes: 78 additions & 11 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from matplotlib.patches import FancyArrow
from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection
from networkx import spring_layout
from numpy import ndarray
from scipy.spatial import ConvexHull

from .. import convert
from ..classes import DiHypergraph, Hypergraph, SimplicialComplex, max_edge_order
from ..exception import XGIError
from ..stats import EdgeStat, NodeStat
from ..stats import IDStat
from .layout import _augmented_projection, barycenter_spring_layout

__all__ = [
Expand Down Expand Up @@ -722,7 +723,7 @@ def _scalar_arg_to_dict(arg, ids, min_val, max_val):
return {id: arg[id] for id in arg if id in ids}
elif type(arg) in [int, float]:
return {id: arg for id in ids}
elif isinstance(arg, NodeStat) or isinstance(arg, EdgeStat):
elif isinstance(arg, IDStat):
vals = np.interp(arg.asnumpy(), [arg.min(), arg.max()], [min_val, max_val])
return dict(zip(ids, vals))
elif isinstance(arg, Iterable):
Expand All @@ -739,7 +740,30 @@ def _color_arg_to_dict(arg, ids, cmap):

Parameters
----------
arg : str, dict, iterable, or NodeStat/EdgeStat
arg : There are several different valid arguments. They are:
nwlandry marked this conversation as resolved.
Show resolved Hide resolved

Single color values

* str
* 3- or 4-tuple

Iterable of colors

* numpy array of color values as described above
* list of color values as described above
* dict of color values as described above
nwlandry marked this conversation as resolved.
Show resolved Hide resolved

Iterable of values

* list of floats
* dict of floats
* numpy array of floats
nwlandry marked this conversation as resolved.
Show resolved Hide resolved

Stats

* NodeStat
* EdgeStat

Attributes for drawing parameter.
ids : NodeView or EdgeView
This is the node or edge IDs that attributes get mapped to.
Expand All @@ -755,22 +779,65 @@ def _color_arg_to_dict(arg, ids, cmap):
------
TypeError
If a string, dict, iterable, or NodeStat/EdgeStat is not passed.

Notes
-----
For the iterable of values, we do not accept numpy arrays or tuples,
because there is the potential for ambiguity.
nwlandry marked this conversation as resolved.
Show resolved Hide resolved
"""
if isinstance(arg, dict):
return {id: arg[id] for id in arg if id in ids}
elif isinstance(arg, str):

# single argument. Must be a string or a tuple of floats
if isinstance(arg, str) or (isinstance(arg, tuple) and isinstance(arg[0], float)):
return {id: arg for id in ids}
elif isinstance(arg, (NodeStat, EdgeStat)):

# Iterables of colors. The values of these iterables must strings or tuples. As of now,
# there is not a check to verify that the tuples contain floats.
if isinstance(arg, Iterable):
if isinstance(arg, dict) and isinstance(
next(iter(arg.values())), (str, tuple, ndarray)
):
return {id: arg[id] for id in arg if id in ids}
if isinstance(arg, (list, ndarray)) and isinstance(
arg[0], (str, tuple, ndarray)
):
return {id: arg[idx] for idx, id in enumerate(ids)}

# Stats or iterable of values
if isinstance(arg, (Iterable, IDStat)):
# set max and min of interpolation based on color map
if isinstance(cmap, ListedColormap):
vals = np.interp(arg.asnumpy(), [arg.min(), arg.max()], [0, cmap.N])
minval = 0
maxval = cmap.N
elif isinstance(cmap, LinearSegmentedColormap):
vals = np.interp(arg.asnumpy(), [arg.min(), arg.max()], [0.1, 0.9])
minval = 0.1
maxval = 0.9
Comment on lines +840 to +841
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how are these decided?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These values are so that the colors don't fade into the background but I suppose we could add them as an arg in the future?

else:
raise XGIError("Invalid colormap!")

# handle the case of IDStat vs iterables
if isinstance(arg, IDStat):
vals = np.interp(arg.asnumpy(), [arg.min(), arg.max()], [minval, maxval])

elif isinstance(arg, Iterable):
if isinstance(arg, dict) and isinstance(
next(iter(arg.values())), (int, float)
):
v = list(arg.values())
vals = np.interp(v, [np.min(v), np.max(v)], [minval, maxval])
# because we have ids, we can't just assume that the keys of arg correspond to
# the ids.
return {
id: np.array(cmap(v)).reshape(1, -1)
for v, id in zip(vals, arg.keys())
if id in ids
}

if isinstance(arg, (list, ndarray)) and isinstance(arg[0], (int, float)):
vals = np.interp(arg, [np.min(arg), np.max(arg)], [minval, maxval])
else:
raise TypeError("Argument must be an iterable of floats.")

return {id: np.array(cmap(vals[i])).reshape(1, -1) for i, id in enumerate(ids)}
nwlandry marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(arg, Iterable):
return {id: arg[idx] for idx, id in enumerate(ids)}
else:
raise TypeError(
"Argument must be str, dict, iterable, or "
Expand Down