-
Notifications
You must be signed in to change notification settings - Fork 27
/
superpixel_segmenter.py
73 lines (67 loc) · 2.46 KB
/
superpixel_segmenter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
from multiprocessing import Pool
from skimage.segmentation import slic
from nuscenes.nuscenes import NuScenes
def compute_slic(cam_token):
cam = nusc.get("sample_data", cam_token)
im = Image.open(os.path.join(nusc.dataroot, cam["filename"]))
segments_slic = slic(
im, n_segments=150, compactness=6, sigma=3.0, start_label=0
).astype(np.uint8)
im = Image.fromarray(segments_slic)
im.save(
"./superpixels/nuscenes/superpixels_slic/" + cam["token"] + ".png"
)
def compute_slic_30(cam_token):
cam = nusc.get("sample_data", cam_token)
im = Image.open(os.path.join(nusc.dataroot, cam["filename"]))
segments_slic = slic(
im, n_segments=30, compactness=6, sigma=3.0, start_label=0
).astype(np.uint8)
im = Image.fromarray(segments_slic)
im.save(
"./superpixels/nuscenes/superpixels_slic_30/" + cam["token"] + ".png"
)
if __name__ == "__main__":
nuscenes_path = "datasets/nuscenes"
parser = argparse.ArgumentParser(description="arg parser")
parser.add_argument(
"--model", type=str, default="minkunet", help="specify the model targeted, either minkunet or voxelnet"
)
assert os.path.exists(nuscenes_path), f"nuScenes not found in {nuscenes_path}"
args = parser.parse_args()
assert args.model in ["minkunet", "voxelnet"]
nusc = NuScenes(
version="v1.0-trainval", dataroot=nuscenes_path, verbose=False
)
os.makedirs("superpixels/nuscenes/superpixels_slic/")
camera_list = [
"CAM_FRONT",
"CAM_FRONT_RIGHT",
"CAM_BACK_RIGHT",
"CAM_BACK",
"CAM_BACK_LEFT",
"CAM_FRONT_LEFT",
]
with Pool(6) as p:
for scene_idx in tqdm(range(len(nusc.scene))):
scene = nusc.scene[scene_idx]
current_sample_token = scene["first_sample_token"]
while current_sample_token != "":
current_sample = nusc.get("sample", current_sample_token)
if args.model == "minkunet":
func = compute_slic
elif args.model == "voxelnet":
func = compute_slic_30
p.map(
func,
[
current_sample["data"][camera_name]
for camera_name in camera_list
],
)
current_sample_token = current_sample["next"]