diff --git a/tests/drawing/test_draw.py b/tests/drawing/test_draw.py index 90a8d0d3c..143ded3bf 100644 --- a/tests/drawing/test_draw.py +++ b/tests/drawing/test_draw.py @@ -230,7 +230,9 @@ def test_draw_dihypergraph(diedgelist2, edgelist8): # number of elements assert len(ax1.lines) == 7 # number of source nodes assert len(ax1.patches) == 4 # number of target nodes - assert len(ax1.collections) == DH.num_edges + 1 # hyperedges markers + nodes + assert len(ax1.collections) == DH.num_edges + 1 - len( + DH.edges.filterby("size", 1) + ) # hyperedges markers + nodes # zorder for line, z in zip(ax1.lines, [1, 1, 1, 1, 0, 0, 0]): # lines for source nodes @@ -253,3 +255,20 @@ def test_draw_dihypergraph(diedgelist2, edgelist8): H = xgi.Hypergraph(edgelist8) ax3 = xgi.draw_dihypergraph(H) plt.close() + + +def test_draw_dihypergraph_with_str_labels_and_isolated_nodes(): + DH1 = xgi.DiHypergraph() + DH1.add_edges_from( + [ + [{"one"}, {"two", "three"}], + [{"two", "three"}, {"four", "five"}], + [{"six"}, {}], + ] + ) + ax4 = xgi.draw_dihypergraph(DH1) + assert len(ax4.lines) == 3 + assert len(ax4.patches) == 4 + assert len(ax4.collections) == DH1.num_edges + 1 - len( + DH1.edges.filterby("size", 1) + ) diff --git a/xgi/drawing/draw.py b/xgi/drawing/draw.py index 61c041bda..97c96b6df 100644 --- a/xgi/drawing/draw.py +++ b/xgi/drawing/draw.py @@ -1633,7 +1633,11 @@ def draw_dihypergraph( G_aug = _augmented_projection(H_conv) for dyad in H_conv.edges.filterby("size", 2).members(): - index = max(G_aug.nodes) + 1 + try: + index = max(n for n in G_aug.nodes if isinstance(n, int)) + 1 + except ValueError: + # The list of node-labels has no integers, so I start from 0 + index = 0 G_aug.add_edges_from([[list(dyad)[0], index], [list(dyad)[1], index]]) phantom_nodes = [n for n in list(G_aug.nodes) if n not in list(H_conv.nodes)] @@ -1641,53 +1645,54 @@ def draw_dihypergraph( for id, he in DH.edges.members(dtype=dict).items(): d = len(he) - 1 - # identify the center of the edge in the augemented projection - center = [n for n in phantom_nodes if set(G_aug.neighbors(n)) == he][0] - x_center, y_center = pos[center] - for node in DH.edges.dimembers(id)[0]: - x_coords = [pos[node][0], x_center] - y_coords = [pos[node][1], y_center] - line = plt.Line2D( - x_coords, - y_coords, - color=lines_fc[id], - lw=lines_lw[id], - zorder=max_order - d, - ) - ax.add_line(line) - for node in DH.edges.dimembers(id)[1]: - dx, dy = pos[node][0] - x_center, pos[node][1] - y_center - # the following to avoid the point of the arrow overlapping the node - distance = np.hypot(dx, dy) - direction_vector = np.array([dx, dy]) / distance - shortened_distance = ( - distance - node_size[node] * 0.003 - ) # Calculate the shortened length - dx = direction_vector[0] * shortened_distance - dy = direction_vector[1] * shortened_distance - arrow = FancyArrow( - x_center, - y_center, - dx, - dy, - color=lines_fc[id], - width=lines_lw[id] * 0.001, - length_includes_head=True, - head_width=line_head_width, - zorder=max_order - d, - ) - ax.add_patch(arrow) - if edge_marker_toggle: - ax.scatter( - x=x_center, - y=y_center, - marker=edge_marker, - s=edge_marker_size**2, - c=edge_marker_fc[id], - edgecolors=edge_marker_ec[id], - linewidths=edge_marker_lw, - zorder=max_order, - ) + if d > 0: + # identify the center of the edge in the augemented projection + center = [n for n in phantom_nodes if set(G_aug.neighbors(n)) == he][0] + x_center, y_center = pos[center] + for node in DH.edges.dimembers(id)[0]: + x_coords = [pos[node][0], x_center] + y_coords = [pos[node][1], y_center] + line = plt.Line2D( + x_coords, + y_coords, + color=lines_fc[id], + lw=lines_lw[id], + zorder=max_order - d, + ) + ax.add_line(line) + for node in DH.edges.dimembers(id)[1]: + dx, dy = pos[node][0] - x_center, pos[node][1] - y_center + # the following to avoid the point of the arrow overlapping the node + distance = np.hypot(dx, dy) + direction_vector = np.array([dx, dy]) / distance + shortened_distance = ( + distance - node_size[node] * 0.003 + ) # Calculate the shortened length + dx = direction_vector[0] * shortened_distance + dy = direction_vector[1] * shortened_distance + arrow = FancyArrow( + x_center, + y_center, + dx, + dy, + color=lines_fc[id], + width=lines_lw[id] * 0.001, + length_includes_head=True, + head_width=line_head_width, + zorder=max_order - d, + ) + ax.add_patch(arrow) + if edge_marker_toggle: + ax.scatter( + x=x_center, + y=y_center, + marker=edge_marker, + s=edge_marker_size**2, + c=edge_marker_fc[id], + edgecolors=edge_marker_ec[id], + linewidths=edge_marker_lw, + zorder=max_order, + ) if hyperedge_labels: # Get all valid keywords by inspecting the signatures of draw_node_labels