diff --git a/src/RigidBodyDynamics.jl b/src/RigidBodyDynamics.jl index 1365ee42..96ff7017 100644 --- a/src/RigidBodyDynamics.jl +++ b/src/RigidBodyDynamics.jl @@ -87,7 +87,8 @@ export replace_joint!, maximal_coordinates, submechanism, - remove_fixed_tree_joints! + remove_fixed_tree_joints!, + remove_subtree! # contact-related functionality export # note: contact-related functionality may be changed significantly in the future diff --git a/src/graphs/Graphs.jl b/src/graphs/Graphs.jl index bf19eb4b..84a92239 100644 --- a/src/graphs/Graphs.jl +++ b/src/graphs/Graphs.jl @@ -39,6 +39,7 @@ export tree_index, ancestors, lowest_common_ancestor, + subtree_vertices, direction, directions diff --git a/src/graphs/spanning_tree.jl b/src/graphs/spanning_tree.jl index 2572f56e..b5728483 100644 --- a/src/graphs/spanning_tree.jl +++ b/src/graphs/spanning_tree.jl @@ -82,7 +82,9 @@ function SpanningTree(g::DirectedGraph{V, E}, root::V, flipped_edge_map::Union{A SpanningTree(g, root, tree_edges) end -# adds an edge and vertex to both the tree and the underlying graph +""" +Add an edge and vertex to both the tree and the underlying graph. +""" function add_edge!(tree::SpanningTree{V, E}, source::V, target::V, edge::E) where {V, E} @assert target ∉ vertices(tree) add_edge!(tree.graph, source, target, edge) @@ -96,6 +98,9 @@ function add_edge!(tree::SpanningTree{V, E}, source::V, target::V, edge::E) wher tree end +""" +Replace an edge in both the tree and the underlying graph. +""" function replace_edge!(tree::SpanningTree{V, E}, old_edge::E, new_edge::E) where {V, E} @assert old_edge ∈ edges(tree) src = source(old_edge, tree) @@ -145,3 +150,21 @@ function lowest_common_ancestor(v1::V, v2::V, tree::SpanningTree{V, E}) where {V end v1 end + +""" +Return a list of vertices in the subtree rooted at `subtree_root`, including `subtree_root` itself. +The list is guaranteed to be topologically sorted. +""" +function subtree_vertices(subtree_root::V, tree::SpanningTree{V, E}) where {V, E} + @assert subtree_root ∈ vertices(tree) + frontier = [subtree_root] + subtree_vertices = V[] + while !isempty(frontier) + parent = pop!(frontier) + push!(subtree_vertices, parent) + for child in out_neighbors(parent, tree) + push!(frontier, child) + end + end + return subtree_vertices +end diff --git a/src/mechanism_modification.jl b/src/mechanism_modification.jl index 2ac1767a..bd3b55c4 100644 --- a/src/mechanism_modification.jl +++ b/src/mechanism_modification.jl @@ -152,6 +152,41 @@ function submechanism(mechanism::Mechanism{T}, submechanismroot::RigidBody{T}; ret end +""" +Remove all bodies in the subtree rooted at `subtree_root`, including `subtree_root` itself, +as well as all joints connected to these bodies. + +The ordering of the joints that remain in the mechanism is retained. +""" +function remove_subtree!(mechanism::Mechanism{T}, subtree_root::RigidBody{T}) where {T} + @assert subtree_root != root_body(mechanism) + tree = mechanism.tree + graph = mechanism.graph + bodies_to_remove = subtree_vertices(subtree_root, tree) + new_tree_joints = copy(edges(tree)) + for body in bodies_to_remove + # Remove the tree joint from our new ordered list of joints. + tree_joint = edge_to_parent(body, tree) + deleteat!(new_tree_joints, findfirst(isequal(tree_joint), new_tree_joints)) + end + for body in bodies_to_remove + # Remove all edges to and from the vertex in the graph. + for joint in copy(in_edges(body, graph)) + remove_edge!(graph, joint) + end + for joint in copy(out_edges(body, graph)) + remove_edge!(graph, joint) + end + + # Remove the vertex itself. + remove_vertex!(graph, body) + end + # Construct a new spanning tree with the new list of tree joints. + mechanism.tree = SpanningTree(graph, root_body(mechanism), new_tree_joints) + canonicalize_graph!(mechanism) + mechanism +end + """ $(SIGNATURES) diff --git a/test/test_graph.jl b/test/test_graph.jl index b23e578d..1d8c6145 100644 --- a/test/test_graph.jl +++ b/test/test_graph.jl @@ -326,6 +326,25 @@ Graphs.flip_direction(edge::Edge{Int32}) = Edge(-edge.data) end end + @testset "subtree_vertices" begin + graph = DirectedGraph{Vertex{Int64}, Edge{Int32}}() + root = Vertex(0) + add_vertex!(graph, root) + tree = SpanningTree(graph, root) + for i = 1 : 30 + parent = vertices(tree)[rand(1 : num_vertices(graph))] + child = Vertex(i) + edge = Edge(Int32(i + 3)) + add_edge!(tree, parent, child, edge) + end + for subtree_root in vertices(tree) + subtree = subtree_vertices(subtree_root, tree) + for vertex in vertices(tree) + @test (vertex ∈ subtree) == (subtree_root ∈ ancestors(vertex, tree)) + end + end + end + @testset "reindex!" begin Random.seed!(15) graph = DirectedGraph{Vertex{Int64}, Edge{Float64}}() diff --git a/test/test_mechanism_modification.jl b/test/test_mechanism_modification.jl index 000c2185..4c45a90a 100644 --- a/test/test_mechanism_modification.jl +++ b/test/test_mechanism_modification.jl @@ -357,4 +357,55 @@ showerror(devnull, e) end end + + @testset "remove_subtree! - tree mechanism" begin + mechanism = parse_urdf(joinpath(@__DIR__, "urdf", "atlas.urdf"), floating=true) + @test_throws AssertionError remove_subtree!(mechanism, root_body(mechanism)) + + original_joints = copy(joints(mechanism)) + + # Behead. + num_bodies = length(bodies(mechanism)) + num_joints = length(joints(mechanism)) + head = findbody(mechanism, "head") + neck_joint = joint_to_parent(head, mechanism) + remove_subtree!(mechanism, head) + @test length(bodies(mechanism)) == num_bodies - 1 + @test length(joints(mechanism)) == num_joints - 1 + @test head ∉ bodies(mechanism) + @test neck_joint ∉ joints(mechanism) + + # Lop off an arm. + num_bodies = length(bodies(mechanism)) + num_joints = length(joints(mechanism)) + r_clav = findbody(mechanism, "r_clav") + r_hand = findbody(mechanism, "r_hand") + r_arm = path(mechanism, r_clav, r_hand) + arm_joints = collect(r_arm) + arm_bodies = [r_clav; map(joint -> successor(joint, mechanism), arm_joints)] + remove_subtree!(mechanism, r_clav) + @test length(joints(mechanism)) == num_joints - length(arm_joints) - 1 + @test length(bodies(mechanism)) == num_bodies - length(arm_bodies) + @test isempty(intersect(arm_joints, joints(mechanism))) + @test isempty(intersect(arm_bodies, bodies(mechanism))) + @test issorted(joints(mechanism), by=joint ->findfirst(isequal(joint), original_joints)) + end + + @testset "remove_subtree! - maximal coordinates" begin + original = parse_urdf(joinpath(@__DIR__, "urdf", "atlas.urdf"), floating=true) + mechanism = maximal_coordinates(original) + num_bodies = length(bodies(mechanism)) + num_joints = length(joints(mechanism)) + @test_throws AssertionError remove_subtree!(mechanism, findbody(original, "head")) # body not in tree + head = findbody(mechanism, "head") + head_joints = copy(in_joints(head, mechanism)) + @test length(head_joints) == 2 # floating joint + neck loop joint + remove_subtree!(mechanism, head) + @test length(joints(mechanism)) == num_joints - length(head_joints) + @test length(bodies(mechanism)) == num_bodies - 1 + for joint in head_joints + @test joint ∉ joints(mechanism) + end + @test head ∉ bodies(mechanism) + end end # mechanism modification