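"""ObjectDetector: a Drake LeafSystem that detects and segments objects in the
station's RGBD camera images and publishes segmentation masks for point-cloud
cropping and grasp selection."""
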
import os
from typing import List, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torchvision

# Suppress torchvision's beta-transforms warning at import time.
torchvision.disable_beta_transforms_warning()

from pydrake.all import (
    AbstractValue,
    Context,
    Diagram,
    DiagramBuilder,
    ImageDepth32F,
    ImageLabel16I,
    ImageRgba8U,
    LeafSystem,
)
from TidySpotFSM import SpotState


class ObjectDetector(LeafSystem):
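    """Detects and segments objects in the HardwareStation's RGBD camera images.

    Depending on `use_groundedsam`, perception runs either GroundedSAM or a
    ground-truth sensor built on the simulator's label images. Masks are
    published on two output ports: one covering all cameras and one covering
    only the front cameras (used for grasp selection).
    """
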
def __init__(self, station: Diagram, camera_names: List[str], image_size: Tuple[int, int],
use_groundedsam: bool, groundedsam_path: str = os.path.join(os.getcwd(), "third_party/Grounded-Segment-Anything"),
device: str = "cpu"):
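        """
        Args:
            station: HardwareStation diagram containing the RgbdSensor subsystems.
            camera_names: names of the cameras to read images from.
            image_size: (width, height) of the camera images, used to size the image input ports.
            use_groundedsam: if True, segment with GroundedSAM; otherwise use the
                simulator's ground-truth label images.
            groundedsam_path: path to the Grounded-Segment-Anything checkout.
            device: device string passed to GroundedSAM, e.g. "cpu" or "cuda".
        """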
LeafSystem.__init__(self)
if use_groundedsam:
from perception.GroundedSAM import GroundedSAM
self.perceptor = GroundedSAM(groundedsam_path, device=device)
else:
# if we aren't using GroundedSAM, then we're using groundtruth
from perception.GroundTruthSensor import GroundTruthSensor
self.perceptor = GroundTruthSensor(station)
self._camera_names = camera_names
# Get the cameras (RgbdSensor) from the HardwareStation
        self._cameras = {
            camera_name: station.GetSubsystemByName(f"rgbd_sensor_{camera_name}")
            for camera_name in camera_names
        }
# Input ports
self.fsm_state_input = self.DeclareAbstractInputPort("fsm_state", AbstractValue.Make(SpotState.IDLE))
self._camera_inputs_indexes = {
camera_name: {
image_type: self.DeclareAbstractInputPort(
f"{camera_name}.{image_type}",
AbstractValue.Make(image_class(image_size[0], image_size[1]))
).get_index()
for image_type, image_class in {
'rgb_image': ImageRgba8U,
'depth_image': ImageDepth32F,
'label_image': ImageLabel16I,
}.items()
}
for camera_name in camera_names
}
# Output ports
self.DeclareAbstractOutputPort(
"object_detection_segmentations",
lambda: AbstractValue.Make(dict()),
self.SegmentAllCameras
)
self.DeclareAbstractOutputPort(
"grasping_object_segmentation",
lambda: AbstractValue.Make(dict()),
self.SegmentFrontCameras
        )

def SegmentAllCameras(self, context: Context, output):
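        """Output-port callback: detect and segment objects in every camera view.

        The back camera is skipped because its view is obscured, and segmentation
        is suppressed while the FSM is in GRASP_OBJECT. Emits a dict mapping
        camera name to segmentation mask.
        """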
# print("ObjectDetector: SegmentAllCameras")
fsm_state = self.fsm_state_input.Eval(context)
if fsm_state == SpotState.GRASP_OBJECT:
output.set_value({})
return
segmentation_mask_dict = {}
for camera_name in self._camera_names:
if camera_name == "back": # Skip back camera because of obscured view
continue
rgb_image = cv2.cvtColor(self.get_color_image(camera_name, context).data, cv2.COLOR_RGBA2RGB)
label_image = self.get_label_image(camera_name, context).data
mask, confidence = self.perceptor.detect_and_segment_objects(rgb_image, camera_name, label_image)
if mask is not None:
segmentation_mask_dict[camera_name] = mask
        if segmentation_mask_dict:
            print(f"Detected objects in {len(segmentation_mask_dict)} camera views, sending masks to PointCloudCropper")
        else:
            print("No objects detected")
output.set_value(segmentation_mask_dict)
# print("ObjectDetector: SegmentAllCameras complete)
def SegmentFrontCameras(self, context: Context, output): # For grasp selection
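        """Output-port callback for grasp selection: segment the front camera views.

        Runs the perceptor on the frontleft and frontright images and emits only
        the highest-confidence mask, as a dict mapping camera name to mask.
        """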
# print("ObjectDetector: SegmentFrontCameras")
segmentation_mask_dict = {}
frontleft_rgb_image = cv2.cvtColor(self.get_color_image("frontleft", context).data, cv2.COLOR_RGBA2RGB)
frontright_rgb_image = cv2.cvtColor(self.get_color_image("frontright", context).data, cv2.COLOR_RGBA2RGB)
frontleft_label_image = self.get_label_image("frontleft", context).data
frontright_label_image = self.get_label_image("frontright", context).data
front_left_mask, front_left_confidence = self.perceptor.detect_and_segment_objects(frontleft_rgb_image, "frontleft", frontleft_label_image)
front_right_mask, front_right_confidence = self.perceptor.detect_and_segment_objects(frontright_rgb_image, "frontright", frontright_label_image)
masks_confidences = [
(front_left_mask, front_left_confidence, "frontleft"),
(front_right_mask, front_right_confidence, "frontright")
]
# Filter out None masks
valid_masks = [(mask, confidence, name) for mask, confidence, name in masks_confidences if mask is not None]
if valid_masks:
# Select the mask with the highest confidence
# print("Detected object! Sending mask to PointCloudCropper")
segmentation_mask, _, camera_name = max(valid_masks, key=lambda x: x[1])
segmentation_mask_dict[camera_name] = segmentation_mask
# if segmentation_mask_dict:
# print("Object to grasp detected! Sending mask to PointCloudCropper")
# else:
# print("No objects detected")
output.set_value(segmentation_mask_dict)
# print("ObjectDetector: GetClosestObjectSegmentation complete")
def test_segmentation(self, object_detector_context: Context, camera_name):
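        """Run the perceptor once on the named camera and print the result."""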
rgb_image = cv2.cvtColor(self.get_color_image(camera_name, object_detector_context).data, cv2.COLOR_RGBA2RGB)
        mask, confidence = self.perceptor.detect_and_segment_objects(rgb_image, camera_name)
        if mask is None:
            print(f"{camera_name} segmentation failed, no object detections found")
        else:
            print("Mask shape:", mask.shape)
        print(f"{camera_name} segmentation test complete")

def connect_cameras(self, station: Diagram, builder: DiagramBuilder):
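        """Connect the station's camera image output ports to this system's image inputs."""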
for camera_name in self._camera_names:
for image_type in ['rgb_image', 'depth_image', 'label_image']:
builder.Connect(
station.GetOutputPort(f"{camera_name}.{image_type}"),
self.get_input_port(self._camera_inputs_indexes[camera_name][image_type])
                )

def get_image(self, camera_name: str, image_type: str, camera_hub_context: Context):
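        """Evaluate and return the requested image type from the named camera's input port."""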
        return self.get_input_port(self._camera_inputs_indexes[camera_name][image_type]).Eval(camera_hub_context)

    def get_color_image(self, camera_name: str, camera_hub_context: Context):
        return self.get_image(camera_name, 'rgb_image', camera_hub_context)

    def get_depth_image(self, camera_name: str, camera_hub_context: Context):
        return self.get_image(camera_name, 'depth_image', camera_hub_context)

    def get_label_image(self, camera_name: str, camera_hub_context: Context):
        return self.get_image(camera_name, 'label_image', camera_hub_context)

def display_all_camera_images(self, camera_hub_context: Context):
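        """Plot the color, depth, and label images of every camera in a grid."""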
        # squeeze=False keeps `axes` two-dimensional even when there is only one camera.
        fig, axes = plt.subplots(len(self._camera_names), 3, figsize=(15, 5 * len(self._camera_names)), squeeze=False)
for i, camera_name in enumerate(self._camera_names):
color_img = self.get_color_image(camera_name, camera_hub_context).data
depth_img = self.get_depth_image(camera_name, camera_hub_context).data
label_img = self.get_label_image(camera_name, camera_hub_context).data
# Plot the color image.
axes[i, 0].imshow(color_img)
axes[i, 0].set_title(f"{camera_name} Color image")
axes[i, 0].axis('off')
# Plot the depth image.
axes[i, 1].imshow(np.squeeze(depth_img))
axes[i, 1].set_title(f"{camera_name} Depth image")
axes[i, 1].axis('off')
# Plot the label image.
axes[i, 2].imshow(label_img)
axes[i, 2].set_title(f"{camera_name} Label image")
axes[i, 2].axis('off')
plt.tight_layout()
plt.show()