diff --git a/networkx/drawing/layout.py b/networkx/drawing/layout.py index fa120d670748..c6f2f16ec92f 100644 --- a/networkx/drawing/layout.py +++ b/networkx/drawing/layout.py @@ -33,6 +33,7 @@ "spiral_layout", "multipartite_layout", "arf_layout", + "arc_layout", ] @@ -1219,6 +1220,50 @@ def arf_layout( return dict(zip(G.nodes(), p)) +def arc_layout(G: nx.Graph, subset_key="layer", radius=1, rotation=0) -> dict: + """Arc layout for networkx + + Provides a layout where a multipartite graph is + displayed on a unit circle. This could provide clear + visuals for data that is highly clustered. + + Parameters + ---------- + G : nx.Graph + Networkx (Di)Graph + subset_key : object + Node attribute to cluster the network on + radius : float + Radius of the unit circle + rotation : float + Rotation of the axes of the unit circle + + Returns + ------- + pos : dict + A dictionary of positions keyed by node. + + Examples + -------- + >>> subset_sizes = [5, 5, 4, 3, 2, 4, 4, 3] + >>> G = nx.complete_multipartite_graph(*subset_sizes) + >>> pos = nx.arf_layout(G, subset_key = "layer") + """ + import numpy as np + + attrs = nx.get_node_attributes(G, subset_key) + categories = set(attrs.values()) + angles = np.linspace(0, 2 * np.pi, len(categories), 0) + rotation + pos = {} + for category, angle in zip(categories, angles): + # collect nodes + subset = [node for node, attr in attrs.items() if attr == category] + radii = np.linspace(0, radius, len(subset), 0) + for node, rad in zip(subset, radii): + pos[node] = rad * np.array([np.cos(angle), np.sin(angle)]) + return pos + + def rescale_layout(pos, scale=1): """Returns scaled position array to (-scale, scale) in all axes. diff --git a/networkx/drawing/tests/test_layout.py b/networkx/drawing/tests/test_layout.py index 48d0e6d9d888..6bba0defdfbc 100644 --- a/networkx/drawing/tests/test_layout.py +++ b/networkx/drawing/tests/test_layout.py @@ -303,6 +303,15 @@ def test_multipartite_layout(self): pytest.raises(ValueError, nx.multipartite_layout, G, align="foo") + def test_arc_layout(self): + sizes = (0, 5, 7, 2, 8) + G = nx.complete_multipartite_graph(*sizes) + + vpos = nx.arc_layout(G, subset_key="subset") + assert len(vpos) == len(G) + + pytest.raises(TypeError, nx.multipartite_layout, G, radius="non-sensible input") + def test_kamada_kawai_costfn_1d(self): costfn = nx.drawing.layout._kamada_kawai_costfn @@ -448,6 +457,18 @@ def test_multipartite_layout_nonnumeric_partition_labels(): assert len(pos) == len(G) +def test_arc_layout_nonnumeric_paratition_labels(): + """see gh-5123.""" + g = nx.Graph() + g.add_node(0, subset="s0") + g.add_node(1, subset="s0") + g.add_node(2, subset="s1") + g.add_node(3, subset="s1") + g.add_edges_from([(0, 2), (0, 3), (1, 2)]) + pos = nx.arc_layout(g, subset_key="subset") + assert len(pos) == len(g) + + def test_multipartite_layout_layer_order(): """Return the layers in sorted order if the layers of the multipartite graph are sortable. See gh-5691"""