diff --git a/tests/drawing/test_draw.py b/tests/drawing/test_draw.py index 143ded3bf..aa8ec3a1e 100644 --- a/tests/drawing/test_draw.py +++ b/tests/drawing/test_draw.py @@ -57,29 +57,75 @@ def test_scalar_arg_to_dict(edgelist4): arg = "2" d = _scalar_arg_to_dict(arg, ids, min_val, max_val) + with pytest.raises(TypeError): + arg = (1, 2, 3) + d = _scalar_arg_to_dict(arg, ids, min_val, max_val) + def test_color_arg_to_dict(edgelist4): ids = [1, 2, 3] - arg = "black" - d = _color_arg_to_dict(arg, ids, None) + # single values + arg1 = "black" + arg2 = (0.1, 0.2, 0.3) + arg3 = (0.1, 0.2, 0.3, 0.5) + + # test iterables of colors + arg4 = [(0.1, 0.2, 0.3), (0.1, 0.2, 0.4), (0.1, 0.2, 0.5)] + arg5 = ["blue", "black", "red"] + arg6 = np.array(["blue", "black", "red"]) + arg7 = {0: (0.1, 0.2, 0.3), 1: (0.1, 0.2, 0.4), 2: (0.1, 0.2, 0.5)} + arg8 = {0: "blue", 1: "black", 2: "red"} + + # test iterables of values + arg9 = [0, 0.1, 0.2] + arg10 = {1: 0, 2: 0.1, 3: 0.2} + arg11 = np.array([0, 0.1, 0.2]) + + # test single values + d = _color_arg_to_dict(arg1, ids, None) assert d == {1: "black", 2: "black", 3: "black"} - with pytest.raises(TypeError): - arg = 0.3 - d = _color_arg_to_dict(arg, ids, None) + d = _color_arg_to_dict(arg2, ids, None) + assert d == {1: (0.1, 0.2, 0.3), 2: (0.1, 0.2, 0.3), 3: (0.1, 0.2, 0.3)} - with pytest.raises(TypeError): - arg = 1 - d = _color_arg_to_dict(arg, ids, None) + d = _color_arg_to_dict(arg3, ids, None) + for i in d: + assert np.allclose(d[i], np.array([0.1, 0.2, 0.3, 0.5])) + + # Test iterables of colors + d = _color_arg_to_dict(arg4, ids, None) + assert d == {1: (0.1, 0.2, 0.3), 2: (0.1, 0.2, 0.4), 3: (0.1, 0.2, 0.5)} + + d = _color_arg_to_dict(arg5, ids, None) + assert d == {1: "blue", 2: "black", 3: "red"} + + d = _color_arg_to_dict(arg6, ids, None) + assert d == {1: "blue", 2: "black", 3: "red"} - arg = ["black", "blue", "red"] - d = _color_arg_to_dict(arg, ids, None) - assert d == {1: "black", 2: "blue", 3: "red"} + d = _color_arg_to_dict(arg7, ids, None) + assert d == {1: (0.1, 0.2, 0.4), 2: (0.1, 0.2, 0.5)} - arg = np.array(["black", "blue", "red"]) - d = _color_arg_to_dict(arg, ids, None) - assert d == {1: "black", 2: "blue", 3: "red"} + d = _color_arg_to_dict(arg8, ids, None) + assert d == {1: "black", 2: "red"} + + # Test iterables of values + cdict = { + 1: np.array([[0.89173395, 0.93510188, 0.97539408, 1.0]]), + 2: np.array([[0.41708574, 0.68063053, 0.83823145, 1.0]]), + 3: np.array([[0.03137255, 0.28973472, 0.57031911, 1.0]]), + } + d = _color_arg_to_dict(arg9, ids, cm.Blues) + for i in d: + assert np.allclose(d[i], cdict[i]) + + d = _color_arg_to_dict(arg10, ids, cm.Blues) + for i in d: + assert np.allclose(d[i], cdict[i]) + + d = _color_arg_to_dict(arg11, ids, cm.Blues) + for i in d: + assert np.allclose(d[i], cdict[i]) H = xgi.Hypergraph(edgelist4) arg = H.nodes.degree @@ -88,6 +134,15 @@ def test_color_arg_to_dict(edgelist4): assert np.allclose(d[2], np.array([[0.98357555, 0.41279508, 0.28835063, 1.0]])) assert np.allclose(d[3], np.array([[0.59461745, 0.0461361, 0.07558631, 1.0]])) + # Test bad calls + with pytest.raises(TypeError): + arg = 0.3 + d = _color_arg_to_dict(arg, ids, None) + + with pytest.raises(TypeError): + arg = 1 + d = _color_arg_to_dict(arg, ids, None) + def test_draw(edgelist8): H = xgi.Hypergraph(edgelist8) diff --git a/xgi/convert.py b/xgi/convert.py index b32233842..5ae2017e5 100644 --- a/xgi/convert.py +++ b/xgi/convert.py @@ -735,8 +735,11 @@ def to_bipartite_graph(H, index=False): """ G = nx.Graph() - node_dict = dict(zip(H.nodes, range(H.num_nodes))) - edge_dict = dict(zip(H.edges, range(H.num_nodes, H.num_nodes + H.num_edges))) + n = H.num_nodes + m = H.num_edges + + node_dict = dict(zip(H.nodes, range(n))) + edge_dict = dict(zip(H.edges, range(n, n + m))) G.add_nodes_from(node_dict.values(), bipartite=0) G.add_nodes_from(edge_dict.values(), bipartite=1) for node in H.nodes: @@ -746,8 +749,8 @@ def to_bipartite_graph(H, index=False): if index: return ( G, - dict(zip(range(H.num_nodes), H.nodes)), - dict(zip(range(H.num_nodes, H.num_nodes + H.num_edges), H.edges)), + dict(zip(range(n), H.nodes)), + dict(zip(range(n, n + m), H.edges)), ) else: return G diff --git a/xgi/drawing/draw.py b/xgi/drawing/draw.py index 97c96b6df..023fda047 100644 --- a/xgi/drawing/draw.py +++ b/xgi/drawing/draw.py @@ -11,12 +11,13 @@ from matplotlib.patches import FancyArrow from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection from networkx import spring_layout +from numpy import ndarray from scipy.spatial import ConvexHull from .. import convert from ..classes import DiHypergraph, Hypergraph, SimplicialComplex, max_edge_order from ..exception import XGIError -from ..stats import EdgeStat, NodeStat +from ..stats import IDStat from .layout import _augmented_projection, barycenter_spring_layout __all__ = [ @@ -689,12 +690,12 @@ def draw_simplices( return ax -def _scalar_arg_to_dict(arg, ids, min_val, max_val): +def _scalar_arg_to_dict(scalar_arg, ids, min_val, max_val): """Map different types of arguments for drawing style to a dict with scalar values. Parameters ---------- - arg : int, float, dict, iterable, or NodeStat/EdgeStat + scalar_arg : int, float, dict, iterable, or NodeStat/EdgeStat Attributes for drawing parameter. ids : NodeView or EdgeView This is the node or edge IDs that attributes get mapped to. @@ -706,40 +707,89 @@ def _scalar_arg_to_dict(arg, ids, min_val, max_val): Returns ------- dict - An ID: attribute dictionary. + An ID: scalar dictionary. Raises ------ TypeError If a int, float, list, dict, or NodeStat/EdgeStat is not passed. """ - if isinstance(arg, str): + if isinstance(scalar_arg, str): raise TypeError( "Argument must be int, float, dict, iterable, " - f"or NodeStat/EdgeStat. Received {type(arg)}" + f"or NodeStat/EdgeStat. Received {type(scalar_arg)}" ) - elif isinstance(arg, dict): - return {id: arg[id] for id in arg if id in ids} - elif type(arg) in [int, float]: - return {id: arg for id in ids} - elif isinstance(arg, NodeStat) or isinstance(arg, EdgeStat): - vals = np.interp(arg.asnumpy(), [arg.min(), arg.max()], [min_val, max_val]) - return dict(zip(ids, vals)) - elif isinstance(arg, Iterable): - return {id: arg[idx] for idx, id in enumerate(ids)} - else: - raise TypeError( - "Argument must be int, float, dict, iterable, " - f"or NodeStat/EdgeStat. Received {type(arg)}" + + # Single argument + if isinstance(scalar_arg, (int, float)): + return {id: scalar_arg for id in ids} + + # IDStat + if isinstance(scalar_arg, IDStat): + vals = np.interp( + scalar_arg.asnumpy(), + [scalar_arg.min(), scalar_arg.max()], + [min_val, max_val], ) + return dict(zip(ids, vals)) + + # Iterables of floats or ints + if isinstance(scalar_arg, Iterable): + if isinstance(scalar_arg, dict): + try: + return {id: float(scalar_arg[id]) for id in scalar_arg if id in ids} + except ValueError as e: + raise TypeError( + "The input dict must have values that can be cast to floats." + ) + elif isinstance(scalar_arg, (list, ndarray)): + try: + return {id: float(scalar_arg[idx]) for idx, id in enumerate(ids)} + except ValueError as e: + raise TypeError( + "The input list or array must have values that can be cast to floats." + ) + else: + raise TypeError( + "Argument must be an dict, list, or numpy array of floats or ints." + ) + + raise TypeError( + "Argument must be int, float, dict, iterable, " + f"or NodeStat/EdgeStat. Received {type(scalar_arg)}" + ) -def _color_arg_to_dict(arg, ids, cmap): + +def _color_arg_to_dict(color_arg, ids, cmap): """Map different types of arguments for drawing style to a dict with color values. Parameters ---------- - arg : str, dict, iterable, or NodeStat/EdgeStat + color_arg : Several formats are accepted: + + Single color values + + * str + * 3- or 4-tuple + + Iterable of colors (each color specified as above) + + * numpy array + * list + * dict {id: color} pairs + + Iterable of numerical values (floats or ints) + + * list + * dict + * numpy array + + Stats + + * NodeStat + * EdgeStat + Attributes for drawing parameter. ids : NodeView or EdgeView This is the node or edge IDs that attributes get mapped to. @@ -749,33 +799,93 @@ def _color_arg_to_dict(arg, ids, cmap): Returns ------- dict - An ID: attribute dictionary. + An ID: color dictionary. Raises ------ TypeError If a string, dict, iterable, or NodeStat/EdgeStat is not passed. + + Notes + ----- + For the iterable of values, we do not accept tuples, + because there is the potential for ambiguity. """ - if isinstance(arg, dict): - return {id: arg[id] for id in arg if id in ids} - elif isinstance(arg, str): - return {id: arg for id in ids} - elif isinstance(arg, (NodeStat, EdgeStat)): + + # single argument. Must be a string or a tuple of floats + if isinstance(color_arg, str) or ( + isinstance(color_arg, tuple) and isinstance(color_arg[0], float) + ): + return {id: color_arg for id in ids} + + # Iterables of colors. The values of these iterables must strings or tuples. As of now, + # there is not a check to verify that the tuples contain floats. + if isinstance(color_arg, Iterable): + if isinstance(color_arg, dict) and isinstance( + next(iter(color_arg.values())), (str, tuple, ndarray) + ): + return {id: color_arg[id] for id in color_arg if id in ids} + if isinstance(color_arg, (list, ndarray)) and isinstance( + color_arg[0], (str, tuple, ndarray) + ): + return {id: color_arg[idx] for idx, id in enumerate(ids)} + + # Stats or iterable of values + if isinstance(color_arg, (Iterable, IDStat)): + # set max and min of interpolation based on color map if isinstance(cmap, ListedColormap): - vals = np.interp(arg.asnumpy(), [arg.min(), arg.max()], [0, cmap.N]) + minval = 0 + maxval = cmap.N elif isinstance(cmap, LinearSegmentedColormap): - vals = np.interp(arg.asnumpy(), [arg.min(), arg.max()], [0.1, 0.9]) + minval = 0.1 + maxval = 0.9 else: raise XGIError("Invalid colormap!") - return {id: np.array(cmap(vals[i])).reshape(1, -1) for i, id in enumerate(ids)} - elif isinstance(arg, Iterable): - return {id: arg[idx] for idx, id in enumerate(ids)} - else: - raise TypeError( - "Argument must be str, dict, iterable, or " - f"NodeStat/EdgeStat. Received {type(arg)}" - ) + # handle the case of IDStat vs iterables + if isinstance(color_arg, IDStat): + vals = np.interp( + color_arg.asnumpy(), + [color_arg.min(), color_arg.max()], + [minval, maxval], + ) + return { + id: np.array(cmap(vals[i])).reshape(1, -1) for i, id in enumerate(ids) + } + + elif isinstance(color_arg, Iterable): + if isinstance(color_arg, dict) and isinstance( + next(iter(color_arg.values())), (int, float) + ): + v = list(color_arg.values()) + vals = np.interp(v, [np.min(v), np.max(v)], [minval, maxval]) + # because we have ids, we can't just assume that the keys of arg correspond to + # the ids. + return { + id: np.array(cmap(v)).reshape(1, -1) + for v, id in zip(vals, color_arg.keys()) + if id in ids + } + + if isinstance(color_arg, (list, ndarray)) and isinstance( + color_arg[0], (int, float) + ): + vals = np.interp( + color_arg, [np.min(color_arg), np.max(color_arg)], [minval, maxval] + ) + return { + id: np.array(cmap(vals[i])).reshape(1, -1) + for i, id in enumerate(ids) + } + else: + raise TypeError( + "Argument must be an dict, list, or numpy array of floats." + ) + + raise TypeError( + "Argument must be str, dict, iterable, or " + f"NodeStat/EdgeStat. Received {type(color_arg)}" + ) def _CCW_sort(p):