Skip to content

Commit

Permalink
images in dataset, ros costmapping node, resnetCNN-based costmaps
Browse files Browse the repository at this point in the history
  • Loading branch information
striest committed Jun 2, 2022
1 parent bc13b4b commit 1016ab5
Show file tree
Hide file tree
Showing 9 changed files with 455 additions and 32 deletions.
27 changes: 25 additions & 2 deletions maxent_irl_costmaps/algos/mppi_irl.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ def __init__(self, expert_dataset, mppi, batch_size=64):
"""
self.expert_dataset = expert_dataset
self.mppi = mppi
self.mppi_itrs = 5
self.mppi_itrs = 10

# mlp_hiddens = [128, 128]
# self.network = MLP(insize = len(expert_dataset.feature_keys), outsize=1, hiddens=mlp_hiddens)

hiddens = []
hiddens = [128,]
# hiddens = []
self.network = ResnetCostmapCNN(in_channels=len(expert_dataset.feature_keys), out_channels=1, hidden_channels=hiddens)

print(sum([x.numel() for x in self.network.parameters()]))
Expand Down Expand Up @@ -97,6 +98,14 @@ def gradient_step(self, batch):
HACK = {"state":initial_state, "steer_angle":torch.zeros(1)}
x = self.mppi.model.get_observations(HACK)

map_params = {
'resolution': map_metadata['resolution'].item(),
'height': map_metadata['height'].item(),
'width': map_metadata['width'].item(),
'origin': map_metadata['origin']
}
self.mppi.reset()
self.mppi.cost_fn.update_map_params(map_params)
self.mppi.cost_fn.update_costmap(costmap)
self.mppi.cost_fn.update_goal(expert_traj[-1, :2])

Expand Down Expand Up @@ -148,6 +157,11 @@ def visualize(self):

map_features = data['map_features'][0]
map_metadata = {k:v[0] for k,v in data['metadata'].items()}
metadata = data['metadata']
xmin = metadata['origin'][0, 0]
ymin = metadata['origin'][0, 1]
xmax = xmin + metadata['width'][0]
ymax = ymin + metadata['height'][0]
expert_traj = data['traj'][0]

#compute costmap
Expand All @@ -162,6 +176,15 @@ def visualize(self):
initial_state = expert_traj[0]
HACK = {"state":initial_state, "steer_angle":torch.zeros(1)}
x = self.mppi.model.get_observations(HACK)

map_params = {
'resolution': map_metadata['resolution'].item(),
'height': map_metadata['height'].item(),
'width': map_metadata['width'].item(),
'origin': map_metadata['origin']
}
self.mppi.reset()
self.mppi.cost_fn.update_map_params(map_params)
self.mppi.cost_fn.update_costmap(costmap)
self.mppi.cost_fn.update_goal(expert_traj[-1, :2])

Expand Down
20 changes: 12 additions & 8 deletions maxent_irl_costmaps/dataset/maxent_irl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MaxEntIRLDataset(Dataset):
Ok, the ony diff now is that there are multiple bag files and we save the trajdata to a temporary pt file.
"""
def __init__(self, bag_fp, preprocess_fp, map_features_topic='/local_gridmap', odom_topic='/integrated_to_init', horizon=70, dt=0.1, fill_value=0.):
def __init__(self, bag_fp, preprocess_fp, map_features_topic='/local_gridmap', odom_topic='/integrated_to_init', image_topic='/multisense/left/image_rect_color', horizon=70, dt=0.1, fill_value=0.):
"""
Args:
bag_fp: The bag to get data from
Expand All @@ -33,16 +33,14 @@ def __init__(self, bag_fp, preprocess_fp, map_features_topic='/local_gridmap', o
self.preprocess_fp = preprocess_fp
self.map_features_topic = map_features_topic
self.odom_topic = odom_topic
self.image_topic = image_topic
self.horizon = horizon
self.dt = dt
self.fill_value = fill_value #I don't know if this is the best way to do this, but setting the fill value to 0 implies that missing features contribute nothing to the cost.
self.N = 0

self.initialize_dataset()

# self.dataset, self.feature_keys = load_data(bag_fp, self.map_features_topic, self.odom_topic, self.horizon, self.dt, self.fill_value)
# self.normalize_map_features()

def initialize_dataset(self):
"""
Profile the trajectories in the bag to:
Expand All @@ -64,7 +62,7 @@ def initialize_dataset(self):
if preprocess:
for tfp in os.listdir(self.bag_fp):
raw_fp = os.path.join(self.bag_fp, tfp)
data, feature_keys = load_data(raw_fp, self.map_features_topic, self.odom_topic, self.horizon, self.dt, self.fill_value)
data, feature_keys = load_data(raw_fp, self.map_features_topic, self.odom_topic, self.image_topic, self.horizon, self.dt, self.fill_value)
if data is None:
continue
for i in range(len(data['traj'])):
Expand All @@ -80,7 +78,6 @@ def initialize_dataset(self):

torch.save(subdata, pp_fp)
self.N += 1


#Actually read all the data to get statistics.
#need number of trajs, and the mean/std of all the map features.
Expand Down Expand Up @@ -111,14 +108,14 @@ def initialize_dataset(self):
self.feature_var = var_new
K += k

self.feature_std = self.feature_var.sqrt()
self.feature_std = self.feature_var.sqrt() + 1e-6
self.feature_std[~torch.isfinite(self.feature_std)] = 1e-6

def visualize(self):
"""
Get a rough sense of features
"""
n_panes = len(self.feature_keys)
n_panes = len(self.feature_keys) + 1

nx = int(np.sqrt(n_panes))
ny = int(n_panes / nx) + 1
Expand All @@ -134,12 +131,19 @@ def visualize(self):
ymin = metadata['origin'][1]
xmax = xmin + metadata['width']
ymax = ymin + metadata['height']


for ax, feat, feat_key in zip(axs, feats, self.feature_keys):
ax.imshow(feat, origin='lower', cmap='gray', extent=(xmin, xmax, ymin, ymax))
ax.plot(traj[:, 0], traj[:, 1], c='y')
ax.set_title(feat_key)

if 'image' in data.keys():
image = data['image']
ax = axs[len(self.feature_keys)]
ax.imshow(image.permute(1, 2, 0))
ax.set_title('Image')

return fig, axs

def __len__(self):
Expand Down
2 changes: 1 addition & 1 deletion maxent_irl_costmaps/networks/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, in_channels, out_channels, hidden_channels, hidden_activation

def forward(self, x):
cnn_out = self.cnn.forward(x)
return cnn_out
return cnn_out.sigmoid()

class ResnetCostmapBlock(nn.Module):
"""
Expand Down
67 changes: 63 additions & 4 deletions maxent_irl_costmaps/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import torch
import numpy as np
import scipy.interpolate, scipy.spatial

def load_data(bag_fp, map_features_topic, odom_topic, horizon, dt, fill_value):
import cv2

def load_data(bag_fp, map_features_topic, odom_topic, image_topic, horizon, dt, fill_value):
"""
Extract map features and trajectory data from the bag.
"""
print(bag_fp)
map_features_list = []
traj = []
vels = []
timestamps = []
dataset = []

Expand All @@ -26,12 +29,24 @@ def load_data(bag_fp, map_features_topic, odom_topic, horizon, dt, fill_value):
pose.orientation.w,
])

twist = msg.twist.twist
v = np.array([
twist.linear.x,
twist.linear.y,
twist.linear.z,
twist.angular.x,
twist.angular.y,
twist.angular.z,
])

traj.append(p)
vels.append(v)
timestamps.append(msg.header.stamp.to_sec())
elif topic == map_features_topic:
map_features_list.append(msg)

traj = np.stack(traj, axis=0)
vels = np.stack(vels, axis=0)
timestamps = np.array(timestamps)

#edge case check
Expand All @@ -51,6 +66,16 @@ def load_data(bag_fp, map_features_topic, odom_topic, horizon, dt, fill_value):
rots = scipy.spatial.transform.Rotation.from_quat(traj[:, 3:])
interp_q = scipy.spatial.transform.Slerp(timestamps[idxs], rots[idxs])

interp_vx = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 0])
interp_vy = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 1])
interp_vz = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 2])

interp_wx = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 3])
interp_wy = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 4])
interp_wz = scipy.interpolate.interp1d(timestamps[idxs], vels[idxs, 5])

map_target_times = []

#get a registered trajectory for each map.
for i, map_features in enumerate(map_features_list):
print('{}/{}'.format(i+1, len(map_features_list)), end='\r')
Expand Down Expand Up @@ -82,14 +107,26 @@ def load_data(bag_fp, map_features_topic, odom_topic, horizon, dt, fill_value):

start_time = map_features.info.header.stamp.to_sec()
targets = start_time + np.arange(horizon) * dt
map_target_times.append(start_time)

xs = interp_x(targets)
ys = interp_y(targets)
zs = interp_z(targets)
qs = interp_q(targets).as_quat()

vxs = interp_vx(targets)
vys = interp_vy(targets)
vzs = interp_vz(targets)
wxs = interp_wx(targets)
wys = interp_wy(targets)
wzs = interp_wz(targets)

#handle transforms to deserialize map/costmap
traj = np.concatenate([np.stack([xs, ys, zs], axis=-1), qs], axis=-1)
traj = np.concatenate([
np.stack([xs, ys, zs], axis=-1),
qs,
np.stack([vxs, vys, vzs, wxs, wys, wzs], axis=-1)
], axis=-1)

map_metadata = map_features.info
xmin = map_metadata.pose.position.x - 0.5 * (map_metadata.length_x)
Expand All @@ -110,14 +147,36 @@ def load_data(bag_fp, map_features_topic, odom_topic, horizon, dt, fill_value):
'map_features': torch.tensor(map_feature_data).float(),
'metadata': metadata_out
}

dataset.append(data)

#convert from gridmap to occgrid metadata
feature_keys = dataset[0]['feature_keys']
dataset = {
'map_features':[x['map_features'] for x in dataset],
'traj':[x['traj'] for x in dataset],
'metadata':[x['metadata'] for x in dataset]
'metadata':[x['metadata'] for x in dataset],
}

#If image topic exists, add to bag
if image_topic is not None:
image_timestamps = []
for topic, msg, t in bag.read_messages(topics=[image_topic]):
image_timestamps.append(t.to_sec())
image_timestamps = np.array(image_timestamps)
#get closest image to targets
dists = np.abs(np.expand_dims(image_timestamps, axis=0) - np.expand_dims(map_target_times, axis=1))
image_targets = np.argmin(dists, axis=1)

images = []
for i, (topic, msg, t) in enumerate(bag.read_messages(topics=[image_topic])):
n_hits = np.sum(image_targets == i)
for j in range(n_hits):
img = np.frombuffer(msg.data, dtype=np.uint8)
img = cv2.imdecode(img, cv2.IMREAD_UNCHANGED)
img = cv2.resize(img, dsize=(224, 224), interpolation=cv2.INTER_AREA)
images.append(torch.tensor(img).permute(2, 0, 1)[[2, 1, 0]] / 255.)

dataset['image'] = images

return dataset, feature_keys
60 changes: 53 additions & 7 deletions scripts/ros/gridmap_to_costmap_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import torch

from nav_msgs.msg import OccupancyGrid
from nav_msgs.msg import OccupancyGrid, Odometry
from grid_map_msgs.msg import GridMap

from rosbag_to_dataset.dtypes.gridmap import GridMapConvert
Expand All @@ -13,16 +13,58 @@ class CostmapperNode:
"""
Node that listens to gridmaps from perception and uses IRL nets to make them into costmaps
"""
def __init__(self, grid_map_topic, cost_map_topic):
def __init__(self, grid_map_topic, cost_map_topic, odom_topic, dataset, network):
"""
Args:
grid_map_topic: the topic to get map features from
cost_map_topic: The topic to publish costmaps to
odom_topic: The topic to get height from
dataset: The dataset that the network was trained on. (Need to get feature mean/var)
network: the network to produce costmaps.
"""
self.feature_keys = dataset.feature_keys
self.feature_mean = dataset.feature_mean
self.feature_std = dataset.feature_std
self.map_metadata = dataset.metadata
self.network = network
self.current_height = 0.

#we will set the output resolution dynamically
self.grid_map_cvt = GridMapConvert(channels=self.feature_keys, size=[1, 1])

self.grid_map_sub = rospy.Subscriber(grid_map_topic, GridMap, self.handle_grid_map, queue_size=1)
self.odom_sub = rospy.Subscriber(odom_topic, Odometry, self.handle_odom, queue_size=1)
self.cost_map_pub = rospy.Publisher(cost_map_topic, OccupancyGrid, queue_size=1)
self.grid_map_cvt = GridMapConvert(channels=['diff'], output_resolution=[120, 120])

def handle_odom(self, msg):
self.current_height = msg.pose.pose.position.z

def handle_grid_map(self, msg):
rospy.loginfo('handling gridmap...')
nx = int(msg.info.length_x / msg.info.resolution)
ny = int(msg.info.length_y / msg.info.resolution)
self.grid_map_cvt.size = [nx, ny]
gridmap = self.grid_map_cvt.ros_to_numpy(msg)
gridmap[~np.isfinite(gridmap)] = 0.
costmap = (gridmap[0] > 1.3).astype(np.uint8) * 100

rospy.loginfo_throttle(1.0, "output shape: {}".format(gridmap['data'].shape))

map_feats = torch.from_numpy(gridmap['data']).float()
for k in self.feature_keys:
if 'height' in k or 'terrain' in k:
idx = self.feature_keys.index(k)
map_feats[idx] -= self.current_height

map_feats[~torch.isfinite(map_feats)] = 0.
map_feats[map_feats.abs() > 100.] = 0.

map_feats_norm = (map_feats - self.feature_mean.view(-1, 1, 1)) / self.feature_std.view(-1, 1, 1)
with torch.no_grad():
costmap = self.network.forward(map_feats_norm.view(1, *map_feats_norm.shape))[0]

#experiment w/ normalizing
rospy.loginfo_throttle(1.0, "min = {}, max = {}".format(costmap.min(), costmap.max()))
costmap = (costmap - costmap.min()) / (costmap.max() - costmap.min())
costmap = (costmap * 100.).long().numpy()

costmap_msg = OccupancyGrid()
costmap_msg.header.stamp = msg.info.header.stamp
Expand All @@ -33,7 +75,7 @@ def handle_grid_map(self, msg):
costmap_msg.info.origin.position.x = msg.info.pose.position.x - msg.info.length_x/2.
costmap_msg.info.origin.position.y = msg.info.pose.position.y - msg.info.length_y/2.

costmap_msg.data = costmap[::-1, ::-1].flatten()
costmap_msg.data = costmap.flatten()

self.cost_map_pub.publish(costmap_msg)

Expand All @@ -42,7 +84,11 @@ def handle_grid_map(self, msg):

grid_map_topic = '/local_gridmap'
cost_map_topic = '/local_cost_map_final_occupancy_grid'
odom_topic = '/integrated_to_init'
mppi_irl = torch.load('../training/ackermann_costmaps/baseline2.pt')

# mppi_irl.visualize()

costmapper = CostmapperNode(grid_map_topic, cost_map_topic)
costmapper = CostmapperNode(grid_map_topic, cost_map_topic, odom_topic, mppi_irl.expert_dataset, mppi_irl.network)

rospy.spin()
Loading

0 comments on commit 1016ab5

Please sign in to comment.