Skip to content

Commit

Permalink
Merge pull request #394 from thomasrobiglio/fix_draw_dh
Browse files Browse the repository at this point in the history
bug fix for draw_dihypergraph
  • Loading branch information
thomasrobiglio authored Jun 9, 2023
2 parents 7a74996 + 6180428 commit 4402c47
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 49 deletions.
21 changes: 20 additions & 1 deletion tests/drawing/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ def test_draw_dihypergraph(diedgelist2, edgelist8):
# number of elements
assert len(ax1.lines) == 7 # number of source nodes
assert len(ax1.patches) == 4 # number of target nodes
assert len(ax1.collections) == DH.num_edges + 1 # hyperedges markers + nodes
assert len(ax1.collections) == DH.num_edges + 1 - len(
DH.edges.filterby("size", 1)
) # hyperedges markers + nodes

# zorder
for line, z in zip(ax1.lines, [1, 1, 1, 1, 0, 0, 0]): # lines for source nodes
Expand All @@ -253,3 +255,20 @@ def test_draw_dihypergraph(diedgelist2, edgelist8):
H = xgi.Hypergraph(edgelist8)
ax3 = xgi.draw_dihypergraph(H)
plt.close()


def test_draw_dihypergraph_with_str_labels_and_isolated_nodes():
DH1 = xgi.DiHypergraph()
DH1.add_edges_from(
[
[{"one"}, {"two", "three"}],
[{"two", "three"}, {"four", "five"}],
[{"six"}, {}],
]
)
ax4 = xgi.draw_dihypergraph(DH1)
assert len(ax4.lines) == 3
assert len(ax4.patches) == 4
assert len(ax4.collections) == DH1.num_edges + 1 - len(
DH1.edges.filterby("size", 1)
)
101 changes: 53 additions & 48 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,61 +1633,66 @@ def draw_dihypergraph(

G_aug = _augmented_projection(H_conv)
for dyad in H_conv.edges.filterby("size", 2).members():
index = max(G_aug.nodes) + 1
try:
index = max(n for n in G_aug.nodes if isinstance(n, int)) + 1
except ValueError:
# The list of node-labels has no integers, so I start from 0
index = 0
G_aug.add_edges_from([[list(dyad)[0], index], [list(dyad)[1], index]])

phantom_nodes = [n for n in list(G_aug.nodes) if n not in list(H_conv.nodes)]
pos = spring_layout(G_aug)

for id, he in DH.edges.members(dtype=dict).items():
d = len(he) - 1
# identify the center of the edge in the augemented projection
center = [n for n in phantom_nodes if set(G_aug.neighbors(n)) == he][0]
x_center, y_center = pos[center]
for node in DH.edges.dimembers(id)[0]:
x_coords = [pos[node][0], x_center]
y_coords = [pos[node][1], y_center]
line = plt.Line2D(
x_coords,
y_coords,
color=lines_fc[id],
lw=lines_lw[id],
zorder=max_order - d,
)
ax.add_line(line)
for node in DH.edges.dimembers(id)[1]:
dx, dy = pos[node][0] - x_center, pos[node][1] - y_center
# the following to avoid the point of the arrow overlapping the node
distance = np.hypot(dx, dy)
direction_vector = np.array([dx, dy]) / distance
shortened_distance = (
distance - node_size[node] * 0.003
) # Calculate the shortened length
dx = direction_vector[0] * shortened_distance
dy = direction_vector[1] * shortened_distance
arrow = FancyArrow(
x_center,
y_center,
dx,
dy,
color=lines_fc[id],
width=lines_lw[id] * 0.001,
length_includes_head=True,
head_width=line_head_width,
zorder=max_order - d,
)
ax.add_patch(arrow)
if edge_marker_toggle:
ax.scatter(
x=x_center,
y=y_center,
marker=edge_marker,
s=edge_marker_size**2,
c=edge_marker_fc[id],
edgecolors=edge_marker_ec[id],
linewidths=edge_marker_lw,
zorder=max_order,
)
if d > 0:
# identify the center of the edge in the augemented projection
center = [n for n in phantom_nodes if set(G_aug.neighbors(n)) == he][0]
x_center, y_center = pos[center]
for node in DH.edges.dimembers(id)[0]:
x_coords = [pos[node][0], x_center]
y_coords = [pos[node][1], y_center]
line = plt.Line2D(
x_coords,
y_coords,
color=lines_fc[id],
lw=lines_lw[id],
zorder=max_order - d,
)
ax.add_line(line)
for node in DH.edges.dimembers(id)[1]:
dx, dy = pos[node][0] - x_center, pos[node][1] - y_center
# the following to avoid the point of the arrow overlapping the node
distance = np.hypot(dx, dy)
direction_vector = np.array([dx, dy]) / distance
shortened_distance = (
distance - node_size[node] * 0.003
) # Calculate the shortened length
dx = direction_vector[0] * shortened_distance
dy = direction_vector[1] * shortened_distance
arrow = FancyArrow(
x_center,
y_center,
dx,
dy,
color=lines_fc[id],
width=lines_lw[id] * 0.001,
length_includes_head=True,
head_width=line_head_width,
zorder=max_order - d,
)
ax.add_patch(arrow)
if edge_marker_toggle:
ax.scatter(
x=x_center,
y=y_center,
marker=edge_marker,
s=edge_marker_size**2,
c=edge_marker_fc[id],
edgecolors=edge_marker_ec[id],
linewidths=edge_marker_lw,
zorder=max_order,
)

if hyperedge_labels:
# Get all valid keywords by inspecting the signatures of draw_node_labels
Expand Down

0 comments on commit 4402c47

Please sign in to comment.