Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add remove_subtree! #621

Merged
merged 1 commit into from
May 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/RigidBodyDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ export
replace_joint!,
maximal_coordinates,
submechanism,
remove_fixed_tree_joints!
remove_fixed_tree_joints!,
remove_subtree!

# contact-related functionality
export # note: contact-related functionality may be changed significantly in the future
Expand Down
1 change: 1 addition & 0 deletions src/graphs/Graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export
tree_index,
ancestors,
lowest_common_ancestor,
subtree_vertices,
direction,
directions

Expand Down
25 changes: 24 additions & 1 deletion src/graphs/spanning_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ function SpanningTree(g::DirectedGraph{V, E}, root::V, flipped_edge_map::Union{A
SpanningTree(g, root, tree_edges)
end

# adds an edge and vertex to both the tree and the underlying graph
"""
Add an edge and vertex to both the tree and the underlying graph.
"""
function add_edge!(tree::SpanningTree{V, E}, source::V, target::V, edge::E) where {V, E}
@assert target ∉ vertices(tree)
add_edge!(tree.graph, source, target, edge)
Expand All @@ -96,6 +98,9 @@ function add_edge!(tree::SpanningTree{V, E}, source::V, target::V, edge::E) wher
tree
end

"""
Replace an edge in both the tree and the underlying graph.
"""
function replace_edge!(tree::SpanningTree{V, E}, old_edge::E, new_edge::E) where {V, E}
@assert old_edge ∈ edges(tree)
src = source(old_edge, tree)
Expand Down Expand Up @@ -145,3 +150,21 @@ function lowest_common_ancestor(v1::V, v2::V, tree::SpanningTree{V, E}) where {V
end
v1
end

"""
Return a list of vertices in the subtree rooted at `subtree_root`, including `subtree_root` itself.
The list is guaranteed to be topologically sorted.
"""
function subtree_vertices(subtree_root::V, tree::SpanningTree{V, E}) where {V, E}
@assert subtree_root ∈ vertices(tree)
frontier = [subtree_root]
subtree_vertices = V[]
while !isempty(frontier)
parent = pop!(frontier)
push!(subtree_vertices, parent)
for child in out_neighbors(parent, tree)
push!(frontier, child)
end
end
return subtree_vertices
end
35 changes: 35 additions & 0 deletions src/mechanism_modification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,41 @@ function submechanism(mechanism::Mechanism{T}, submechanismroot::RigidBody{T};
ret
end

"""
Remove all bodies in the subtree rooted at `subtree_root`, including `subtree_root` itself,
as well as all joints connected to these bodies.

The ordering of the joints that remain in the mechanism is retained.
"""
function remove_subtree!(mechanism::Mechanism{T}, subtree_root::RigidBody{T}) where {T}
@assert subtree_root != root_body(mechanism)
tree = mechanism.tree
graph = mechanism.graph
bodies_to_remove = subtree_vertices(subtree_root, tree)
new_tree_joints = copy(edges(tree))
for body in bodies_to_remove
# Remove the tree joint from our new ordered list of joints.
tree_joint = edge_to_parent(body, tree)
deleteat!(new_tree_joints, findfirst(isequal(tree_joint), new_tree_joints))
end
for body in bodies_to_remove
# Remove all edges to and from the vertex in the graph.
for joint in copy(in_edges(body, graph))
remove_edge!(graph, joint)
end
for joint in copy(out_edges(body, graph))
remove_edge!(graph, joint)
end

# Remove the vertex itself.
remove_vertex!(graph, body)
end
# Construct a new spanning tree with the new list of tree joints.
mechanism.tree = SpanningTree(graph, root_body(mechanism), new_tree_joints)
canonicalize_graph!(mechanism)
mechanism
end

"""
$(SIGNATURES)

Expand Down
19 changes: 19 additions & 0 deletions test/test_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,25 @@ Graphs.flip_direction(edge::Edge{Int32}) = Edge(-edge.data)
end
end

@testset "subtree_vertices" begin
graph = DirectedGraph{Vertex{Int64}, Edge{Int32}}()
root = Vertex(0)
add_vertex!(graph, root)
tree = SpanningTree(graph, root)
for i = 1 : 30
parent = vertices(tree)[rand(1 : num_vertices(graph))]
child = Vertex(i)
edge = Edge(Int32(i + 3))
add_edge!(tree, parent, child, edge)
end
for subtree_root in vertices(tree)
subtree = subtree_vertices(subtree_root, tree)
for vertex in vertices(tree)
@test (vertex ∈ subtree) == (subtree_root ∈ ancestors(vertex, tree))
end
end
end

@testset "reindex!" begin
Random.seed!(15)
graph = DirectedGraph{Vertex{Int64}, Edge{Float64}}()
Expand Down
51 changes: 51 additions & 0 deletions test/test_mechanism_modification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,4 +357,55 @@
showerror(devnull, e)
end
end

@testset "remove_subtree! - tree mechanism" begin
mechanism = parse_urdf(joinpath(@__DIR__, "urdf", "atlas.urdf"), floating=true)
@test_throws AssertionError remove_subtree!(mechanism, root_body(mechanism))

original_joints = copy(joints(mechanism))

# Behead.
num_bodies = length(bodies(mechanism))
num_joints = length(joints(mechanism))
head = findbody(mechanism, "head")
neck_joint = joint_to_parent(head, mechanism)
remove_subtree!(mechanism, head)
@test length(bodies(mechanism)) == num_bodies - 1
@test length(joints(mechanism)) == num_joints - 1
@test head ∉ bodies(mechanism)
@test neck_joint ∉ joints(mechanism)

# Lop off an arm.
num_bodies = length(bodies(mechanism))
num_joints = length(joints(mechanism))
r_clav = findbody(mechanism, "r_clav")
r_hand = findbody(mechanism, "r_hand")
r_arm = path(mechanism, r_clav, r_hand)
arm_joints = collect(r_arm)
arm_bodies = [r_clav; map(joint -> successor(joint, mechanism), arm_joints)]
remove_subtree!(mechanism, r_clav)
@test length(joints(mechanism)) == num_joints - length(arm_joints) - 1
@test length(bodies(mechanism)) == num_bodies - length(arm_bodies)
@test isempty(intersect(arm_joints, joints(mechanism)))
@test isempty(intersect(arm_bodies, bodies(mechanism)))
@test issorted(joints(mechanism), by=joint ->findfirst(isequal(joint), original_joints))
end

@testset "remove_subtree! - maximal coordinates" begin
original = parse_urdf(joinpath(@__DIR__, "urdf", "atlas.urdf"), floating=true)
mechanism = maximal_coordinates(original)
num_bodies = length(bodies(mechanism))
num_joints = length(joints(mechanism))
@test_throws AssertionError remove_subtree!(mechanism, findbody(original, "head")) # body not in tree
head = findbody(mechanism, "head")
head_joints = copy(in_joints(head, mechanism))
@test length(head_joints) == 2 # floating joint + neck loop joint
remove_subtree!(mechanism, head)
@test length(joints(mechanism)) == num_joints - length(head_joints)
@test length(bodies(mechanism)) == num_bodies - 1
for joint in head_joints
@test joint ∉ joints(mechanism)
end
@test head ∉ bodies(mechanism)
end
end # mechanism modification