diff --git a/src/steiner_tree.jl b/src/steiner_tree.jl index 6ccc8db..ecffb05 100644 --- a/src/steiner_tree.jl +++ b/src/steiner_tree.jl @@ -1,7 +1,7 @@ using Graphs: Graphs, IsDirected, nv, steiner_tree using SimpleTraits: SimpleTraits, Not, @traitfn -@traitfn function Graphs.steiner_tree( +@traitfn function namedgraph_steiner_tree( g::AbstractNamedGraph::(!IsDirected), term_vert, distmx=weights(g) ) position_tree = steiner_tree( @@ -15,3 +15,15 @@ using SimpleTraits: SimpleTraits, Not, @traitfn end return tree end + +@traitfn function Graphs.steiner_tree( + g::AbstractNamedGraph::(!IsDirected), term_vert, args... +) + return namedgraph_steiner_tree(g, term_vert, args...) +end + +@traitfn function Graphs.steiner_tree( + g::AbstractNamedGraph::(!IsDirected), term_vert::Vector{<:Integer}, args... +) + return namedgraph_steiner_tree(g, term_vert, args...) +end diff --git a/test/test_namedgraph.jl b/test/test_namedgraph.jl index 5ec2802..eddaf0f 100644 --- a/test/test_namedgraph.jl +++ b/test/test_namedgraph.jl @@ -680,6 +680,16 @@ end for e in es @test has_edge(st, e) end + + g = named_path_graph(4) + terminal_vertices = [1, 3] + st = steiner_tree(g, terminal_vertices) + es = [1 => 2, 2 => 3] + @test ne(st) == 2 + @test nv(st) == 3 + for e in es + @test has_edge(st, e) + end end @testset "topological_sort_by_dfs" begin g = NamedDiGraph(["A", "B", "C", "D", "E", "F", "G"])