-
Notifications
You must be signed in to change notification settings - Fork 6
/
traversability_fusion
executable file
·173 lines (145 loc) · 7.09 KB
/
traversability_fusion
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import rospy
from sensor_msgs.msg import PointCloud2
from ros_numpy import msgify, numpify
from message_filters import ApproximateTimeSynchronizer, Subscriber
from numpy.lib.recfunctions import structured_to_unstructured, unstructured_to_structured
import scipy.spatial
import numpy as np
from threading import Lock
from tf2_ros import Buffer, TransformException, TransformListener
# Topic names used by the node. They are passed unqualified to the rospy
# subscribers/publisher and to rospy.resolve_name() for logging.
# Input: PointCloud2 subscribed as the geometric traversability cloud.
GEOMETRIC_TOPIC = 'geometric_traversability'
# Input: PointCloud2 subscribed as the semantic traversability cloud.
SEMANTIC_TOPIC = 'semantic_traversability'
# Output: PointCloud2 published as the fusion result.
FUSED_TOPIC = 'fused_traversability'
class TravFusion(object):
    """ROS node fusing geometric and semantic traversability point clouds.

    The node keeps the latest geometric and semantic PointCloud2 messages and,
    when triggered, publishes a fused cloud on FUSED_TOPIC. The fused cloud is
    the geometric cloud in which the cost of selected points is replaced by
    the cost of the nearest semantic point plus an offset (see fuse_clouds).

    Private parameters (defaults as read in __init__):
        ~fixed_frame (str, 'map'): Fixed frame used in the transform lookup
            between the two clouds.
        ~trigger (str, 'geometric'): What triggers fusion: 'geometric',
            'semantic', 'both' or 'timer'.
        ~sync (bool, False): Subscribe through an ApproximateTimeSynchronizer
            instead of two independent subscribers.
        ~max_time_diff (float, 0.2): Maximum stamp difference [s] between the
            clouds to be fused; also the synchronizer slop.
        ~dist_th (float, 0.25): Maximum distance to the nearest semantic point.
        ~flat_cost_th (float, 0.5): Lower bound of geometric cost to update.
        ~obstacle_cost_th (float, 1.0): Upper bound of geometric cost to update.
        ~semantic_cost_offset (float, ~flat_cost_th): Added to semantic cost.
        ~timeout (float, 0.5): Transform lookup timeout [s].
        ~rate (float, 0.0): Timer rate [Hz]; no timer is created if <= 0.
    """

    # Accepted values of the ~trigger parameter.
    TRIGGERS = ('both', 'geometric', 'semantic', 'timer')

    def __init__(self):
        """Read parameters, set up TF, the publisher, subscribers and timer.

        Raises:
            ValueError: If ~trigger is not one of TRIGGERS.
        """
        self.fixed_frame = rospy.get_param('~fixed_frame', 'map')
        self.trigger = rospy.get_param('~trigger', 'geometric')
        # Explicit check instead of assert, which is stripped under -O.
        if self.trigger not in self.TRIGGERS:
            raise ValueError('Invalid ~trigger %r, expected one of: %s.'
                             % (self.trigger, ', '.join(self.TRIGGERS)))
        self.sync = rospy.get_param('~sync', False)
        self.max_time_diff = rospy.get_param('~max_time_diff', 0.2)
        self.dist_th = rospy.get_param('~dist_th', 0.25)
        self.flat_cost_th = rospy.get_param('~flat_cost_th', 0.5)
        self.obstacle_cost_th = rospy.get_param('~obstacle_cost_th', 1.0)
        self.semantic_cost_offset = rospy.get_param('~semantic_cost_offset', self.flat_cost_th)
        self.timeout = rospy.get_param('~timeout', 0.5)
        self.rate = rospy.get_param('~rate', 0.0)

        # Guards geom_msg and sem_msg, which are written from callbacks.
        # NOTE: Lock is not reentrant; never call fuse_available_msgs()
        # while holding it.
        self.lock = Lock()
        self.geom_msg = None
        self.sem_msg = None
        # Fields read from both clouds: position and traversability cost.
        self.data_fields = ['x', 'y', 'z', 'cost']

        self.tf_buffer = Buffer(rospy.Duration.from_sec(2 * self.max_time_diff))
        self.tf_sub = TransformListener(self.tf_buffer)

        self.fused_pc_pub = rospy.Publisher(FUSED_TOPIC, PointCloud2, queue_size=1)

        # Warn about configurations in which nothing ever triggers fusion:
        # the synchronized callback fuses only for trigger 'both', and
        # trigger 'timer' relies on a positive rate.
        if self.rate <= 0.0 and (self.trigger == 'timer'
                                 or (self.sync and self.trigger != 'both')):
            rospy.logwarn('Fusion will never be triggered with trigger=%s, sync=%s and rate=%.3f.',
                          self.trigger, self.sync, self.rate)

        rospy.loginfo("Subscribing to %s and %s...",
                      rospy.resolve_name(GEOMETRIC_TOPIC),
                      rospy.resolve_name(SEMANTIC_TOPIC))
        if self.sync:
            self.geom_sub = Subscriber(GEOMETRIC_TOPIC, PointCloud2)
            self.sem_sub = Subscriber(SEMANTIC_TOPIC, PointCloud2)
            self.time_synch = ApproximateTimeSynchronizer([self.geom_sub, self.sem_sub], queue_size=1,
                                                          slop=self.max_time_diff)
            self.time_synch.registerCallback(self.update_trav)
        else:
            self.geom_sub = rospy.Subscriber(GEOMETRIC_TOPIC, PointCloud2, self.update_geom_trav, queue_size=1)
            self.sem_sub = rospy.Subscriber(SEMANTIC_TOPIC, PointCloud2, self.update_sem_trav, queue_size=1)

        self.timer = None
        if self.rate > 0.0:
            self.timer = rospy.Timer(rospy.Duration.from_sec(1.0 / self.rate), self.timer_callback)

    def update_trav(self, geom_msg, sem_msg):
        """Store a synchronized pair of clouds; fuse if trigger is 'both'."""
        with self.lock:
            self.geom_msg = geom_msg
            self.sem_msg = sem_msg
        # The lock must be released here: fuse_available_msgs() acquires it.
        if self.trigger == 'both':
            self.fuse_available_msgs()

    def update_geom_trav(self, msg):
        """Store the geometric cloud; fuse if trigger is 'geometric' or 'both'."""
        with self.lock:
            self.geom_msg = msg
        # The lock must be released here: fuse_available_msgs() acquires it.
        if self.trigger in ('geometric', 'both'):
            self.fuse_available_msgs()

    def update_sem_trav(self, msg):
        """Store the semantic cloud; fuse if trigger is 'semantic' or 'both'."""
        with self.lock:
            self.sem_msg = msg
        # The lock must be released here: fuse_available_msgs() acquires it.
        if self.trigger in ('semantic', 'both'):
            self.fuse_available_msgs()

    def fuse_clouds(self, geom_cloud, sem_cloud, sem_to_geom):
        """Correct geometric costs using the nearest semantic points.

        A geometric point is updated if its nearest semantic point (after
        transforming semantic points with sem_to_geom) is at most dist_th
        away and its own cost lies within [flat_cost_th, obstacle_cost_th].
        Its new cost is the semantic cost plus semantic_cost_offset.
        Points with non-finite coordinates are never matched.

        Args:
            geom_cloud: Structured array with fields 'x', 'y', 'z', 'cost';
                any shape (it is flattened for matching).
            sem_cloud: Structured array with the same fields.
            sem_to_geom: Matrix whose [:3, :3] part is the rotation and
                [:3, 3:] part the translation from semantic to geometric
                cloud coordinates.

        Returns:
            A copy of geom_cloud (same shape and dtype) with updated 'cost'.
            The input clouds are not modified.
        """
        geom_data = structured_to_unstructured(geom_cloud.ravel()[self.data_fields])
        sem_data = structured_to_unstructured(sem_cloud.ravel()[self.data_fields])
        geom_pts, geom_cost = geom_data[..., :3], geom_data[..., 3]
        sem_pts, sem_cost = sem_data[..., :3], sem_data[..., 3]

        # Row-vector form of x' = R x + t.
        R = sem_to_geom[:3, :3]
        t = sem_to_geom[:3, 3:]
        sem_pts = np.matmul(sem_pts, R.T) + t.T

        fused_cloud = geom_cloud.copy()

        # The KD-tree needs finite coordinates; organized clouds may carry
        # NaN points, which are left untouched.
        geom_ok = np.isfinite(geom_pts).all(axis=1)
        sem_ok = np.isfinite(sem_pts).all(axis=1)
        if not geom_ok.any() or not sem_ok.any():
            return fused_cloud
        sem_pts, sem_cost = sem_pts[sem_ok], sem_cost[sem_ok]

        # Match geometric to semantic, so that every geometric point
        # with appropriate costs is updated.
        tree = scipy.spatial.cKDTree(sem_pts)
        n_geom = geom_pts.shape[0]
        dists = np.full(n_geom, np.inf)
        idxs = np.zeros(n_geom, dtype=np.intp)
        nn_dists, nn_idxs = tree.query(geom_pts[geom_ok], k=1)
        dists[geom_ok] = nn_dists
        idxs[geom_ok] = nn_idxs

        mask = geom_ok & (dists <= self.dist_th)
        mask &= geom_cost >= self.flat_cost_th
        mask &= geom_cost <= self.obstacle_cost_th

        # The mask is per flattened point, so write through a flat view of
        # the (contiguous) copy; this keeps organized 2-D clouds working.
        flat_cost = fused_cloud.reshape(-1)['cost']
        flat_cost[mask] = sem_cost[idxs[mask]] + self.semantic_cost_offset
        return fused_cloud

    def fuse_available_msgs(self):
        """Fuse the latest stored clouds and publish the result.

        Publishes nothing if no cloud is stored, and the only stored cloud
        if one is missing. Falls back to publishing the geometric cloud
        unchanged if the stamps differ by more than max_time_diff or the
        transform between the clouds cannot be found. Must not be called
        with self.lock held.
        """
        with self.lock:
            geom_msg = self.geom_msg
            sem_msg = self.sem_msg

        # Handle special cases first.
        # Publish what is available if there is nothing to merge.
        if geom_msg is None and sem_msg is None:
            return
        if sem_msg is None:
            # Publish geometric traversability as fused.
            self.fused_pc_pub.publish(geom_msg)
            return
        if geom_msg is None:
            # Publish semantic traversability as fused.
            self.fused_pc_pub.publish(sem_msg)
            return

        # Publish geometric in case of large time difference.
        time_diff = abs((sem_msg.header.stamp - geom_msg.header.stamp).to_sec())
        if time_diff > self.max_time_diff:
            rospy.logwarn('Geometric and semantic clouds have big time stamp difference: %.3f s. '
                          'Using geometric segmentation as fusion result.',
                          time_diff)
            self.fused_pc_pub.publish(geom_msg)
            return

        # Find transform from semantic to geometric cloud to match points and update costs.
        # Publish geometric if we cannot transform between clouds.
        try:
            transform_msg = self.tf_buffer.lookup_transform_full(
                geom_msg.header.frame_id, geom_msg.header.stamp,
                sem_msg.header.frame_id, sem_msg.header.stamp,
                self.fixed_frame, rospy.Duration.from_sec(self.timeout))
        except TransformException as ex:
            rospy.logerr('Cannot transform from %s to %s through fixed %s: %s. '
                         'Using geometric traversability as fusion result.',
                         sem_msg.header.frame_id, geom_msg.header.frame_id, self.fixed_frame, ex)
            self.fused_pc_pub.publish(geom_msg)
            return
        sem_to_geom = numpify(transform_msg.transform)

        # Take geometric costs as default and correct it with semantic.
        geom_cloud = numpify(geom_msg)
        sem_cloud = numpify(sem_msg)
        rospy.logdebug('Running traversability fusion: geometric + semantic...')
        fused_cloud = self.fuse_clouds(geom_cloud, sem_cloud, sem_to_geom)
        fused_msg = msgify(PointCloud2, fused_cloud)
        fused_msg.header = geom_msg.header
        self.fused_pc_pub.publish(fused_msg)

    def timer_callback(self, event):
        """Timer callback: fuse whatever is currently stored (event unused)."""
        self.fuse_available_msgs()
def main():
    """Initialize the ROS node, create the fusion object and spin."""
    rospy.init_node('traversability_fusion', log_level=rospy.INFO)
    # Keep a named reference to the node object while callbacks are served.
    fusion_node = TravFusion()
    rospy.spin()


# Run the node only when the file is executed as a script.
if __name__ == '__main__':
    main()