Skip to content

Commit

Permalink
Added arc_layout
Browse files Browse the repository at this point in the history
  • Loading branch information
cvanelteren committed Sep 26, 2023
1 parent 620f675 commit 8f1bcf3
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
45 changes: 45 additions & 0 deletions networkx/drawing/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"spiral_layout",
"multipartite_layout",
"arf_layout",
"arc_layout",
]


Expand Down Expand Up @@ -1219,6 +1220,50 @@ def arf_layout(
return dict(zip(G.nodes(), p))


def arc_layout(G: nx.Graph, subset_key="subset", radius=1, rotation=0) -> dict:
"""Arc layout for networkx
Provides a layout where a multipartite graph is
displayed on a unit circle. This could provide clear
visuals for data that is highly clustered.
Parameters
----------
G : nx.Graph
Networkx (Di)Graph
subset_key : object
Node attribute to cluster the network on
radius : float
Radius of the unit circle
rotation : float
Rotation of the axes of the unit circle
Returns
-------
pos : dict
A dictionary of positions keyed by node.
Examples
--------
>>> subset_sizes = [5, 5, 4, 3, 2, 4, 4, 3]
>>> G = nx.complete_multipartite_graph(*subset_sizes)
>>> pos = nx.arf_layout(G, subset_key = "subset")
"""
import numpy as np

attrs = nx.get_node_attributes(G, subset_key)
categories = set(attrs.values())
angles = np.linspace(0, 2 * np.pi, len(categories), 0) + rotation
pos = {}
for category, angle in zip(categories, angles):
# collect nodes
subset = [node for node, attr in attrs.items() if attr == category]
radii = np.linspace(0, radius, len(subset), 0)
for node, rad in zip(subset, radii):
pos[node] = rad * np.array([np.cos(angle), np.sin(angle)])
return pos


def rescale_layout(pos, scale=1):
"""Returns scaled position array to (-scale, scale) in all axes.
Expand Down
21 changes: 21 additions & 0 deletions networkx/drawing/tests/test_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,15 @@ def test_multipartite_layout(self):

pytest.raises(ValueError, nx.multipartite_layout, G, align="foo")

def test_arc_layout(self):
sizes = (0, 5, 7, 2, 8)
G = nx.complete_multipartite_graph(*sizes)

vpos = nx.arc_layout(G, subset_key="subset")
assert len(vpos) == len(G)

pytest.raises(TypeError, nx.multipartite_layout, G, radius="non-sensible input")

def test_kamada_kawai_costfn_1d(self):
costfn = nx.drawing.layout._kamada_kawai_costfn

Expand Down Expand Up @@ -448,6 +457,18 @@ def test_multipartite_layout_nonnumeric_partition_labels():
assert len(pos) == len(G)


def test_arc_layout_nonnumeric_paratition_labels():
"""see gh-5123."""
g = nx.Graph()
g.add_node(0, subset="s0")
g.add_node(1, subset="s0")
g.add_node(2, subset="s1")
g.add_node(3, subset="s1")
g.add_edges_from([(0, 2), (0, 3), (1, 2)])
pos = nx.arc_layout(g, subset_key="subset")
assert len(pos) == len(g)


def test_multipartite_layout_layer_order():
"""Return the layers in sorted order if the layers of the multipartite
graph are sortable. See gh-5691"""
Expand Down

0 comments on commit 8f1bcf3

Please sign in to comment.