-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathmain.py
96 lines (83 loc) · 4.62 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# -*- coding: utf-8 -*-
import os
import torch
import argparse
import numpy as np
import open3d as o3d
from segment import seg_point, seg_box, seg_mask
import sam2point.dataset as dataset
import sam2point.configs as configs
from sam2point.voxelizer import Voxelizer
from sam2point.utils import cal
from show import render_scene, render_scene_outdoor
# One-time GPU configuration, executed at import time.
# Guarded by torch.cuda.is_available() so that importing this module on a
# CPU-only machine no longer crashes: the original unconditionally queried
# device 0 via torch.cuda.get_device_properties(0), which raises when no
# CUDA device is present. On CUDA machines behavior is identical.
if torch.cuda.is_available():
    # Keep a bfloat16 autocast context open for the whole process (the
    # __enter__() call deliberately never exits).
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
def main():
    """Segment one sample of the chosen dataset with a point or box prompt.

    Pipeline: parse CLI args -> load the sample's point cloud and colors ->
    voxelize -> run SAM2Point segmentation for the selected prompt -> project
    the voxel mask back onto the points -> paint segmented points green ->
    render the scene into ./results.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', choices=['S3DIS', 'ScanNet', 'Objaverse', 'KITTI', 'Semantic3D'], default='Objaverse', help='dataset selected')
    parser.add_argument('--prompt_type', choices=['point', 'box'], default='point', help='prompt type selected. Mask prompt will be supported soon...')
    parser.add_argument('--sample_idx', type=int, default=0, help='the index of the scene or object')
    parser.add_argument('--prompt_idx', type=int, default=0, help='the index of the prompt')
    parser.add_argument('--voxel_size', type=float, default=0.02, help='voxel size')
    parser.add_argument('--theta', type=float, default=0.)
    parser.add_argument('--mode', type=str, default='bilinear')
    args = parser.parse_args()

    # Load the sample; `point` and `color` are numpy arrays concatenated
    # column-wise below, so presumably both are (N, 3) — confirm in dataset.py.
    # Objaverse/KITTI/Semantic3D override the voxel size per prompt from the
    # configs and force nearest-neighbour mode with theta=0.5.
    if args.dataset == 'S3DIS':
        info = configs.S3DIS_samples[args.sample_idx]
        point, color = dataset.load_S3DIS_sample(info['path'])
    elif args.dataset == 'ScanNet':
        info = configs.ScanNet_samples[args.sample_idx]
        point, color = dataset.load_ScanNet_sample(info['path'])
    elif args.dataset == 'Objaverse':
        info = configs.Objaverse_samples[args.sample_idx]
        point, color = dataset.load_Objaverse_sample(info['path'])
        args.voxel_size = info[configs.VOXEL[args.prompt_type]][args.prompt_idx]
        args.mode, args.theta = 'nearest', 0.5
    elif args.dataset == 'KITTI':
        info = configs.KITTI_samples[args.sample_idx]
        point, color = dataset.load_KITTI_sample(info['path'])
        args.voxel_size = info[configs.VOXEL[args.prompt_type]][args.prompt_idx]
        args.mode, args.theta = 'nearest', 0.5
    elif args.dataset == 'Semantic3D':
        info = configs.Semantic3D_samples[args.sample_idx]
        point, color = dataset.load_Semantic3D_sample(info['path'], args.sample_idx)
        args.voxel_size = info[configs.VOXEL[args.prompt_type]][args.prompt_idx]
        args.mode, args.theta = 'nearest', 0.5
    print(args)

    # Voxelize the cloud. The label column is a dummy (first point coordinate
    # cast to int) — segmentation here is prompt-driven, not label-driven.
    point_color = np.concatenate([point, color], axis=1)
    voxelizer = Voxelizer(voxel_size=args.voxel_size, clip_bound=None)
    labels_in = point[:, :1].astype(int)
    locs, feats, labels, inds_reconstruct = voxelizer.voxelize(point, color, labels_in)

    # Run segmentation for the selected prompt and remember the prompt
    # geometry so the renderer can draw it.
    if args.prompt_type == 'point':
        mask = seg_point(locs, feats, info['point_prompts'], args)
        point_prompts = np.array(info['point_prompts'])
        prompt_point = list(point_prompts[args.prompt_idx])
        prompt_box = None
    elif args.prompt_type == 'box':
        mask = seg_box(locs, feats, info['box_prompts'], args)
        box_prompts = np.array(info['box_prompts'])
        prompt_point = None
        prompt_box = list(box_prompts[args.prompt_idx])
    else:
        # Unreachable through argparse `choices`, but bail out explicitly:
        # the original fell through to the code below with `mask` undefined,
        # which would raise a NameError and bury this friendly message.
        print("Wrong prompt type! Please select prompt type from {point, box}. Mask prompt will be released soon. Please be patient. :))")
        return

    # Map the voxel-level mask back to the original points via the
    # reconstruction indices. `mask` supports .unsqueeze/.numpy, so it is
    # presumably a torch tensor indexed by voxel coordinates — TODO confirm.
    point_locs = locs[inds_reconstruct]
    point_mask = mask[point_locs[:, 0], point_locs[:, 1], point_locs[:, 2]]
    point_mask = point_mask.unsqueeze(-1)
    point_mask_not = ~point_mask

    # Paint segmented points pure green; keep the original color elsewhere.
    point, color = point_color[:, :3], point_color[:, 3:]
    new_color = color * point_mask_not.numpy() + (color * 0 + np.array([[0., 1., 0.]])) * point_mask.numpy()

    # Render under results/<dataset>_sample<i>_<type>-prompt<j>.
    os.makedirs('results', exist_ok=True)
    name_list = [args.dataset, "sample" + str(args.sample_idx), args.prompt_type + "-prompt" + str(args.prompt_idx)]
    name = '_'.join(name_list)
    if args.dataset == 'KITTI':
        # KITTI gets both a far and a close-up render.
        render_scene_outdoor(point, new_color, name, prompt_point=prompt_point, prompt_box=prompt_box)
        render_scene_outdoor(point, new_color, name, prompt_point=prompt_point, prompt_box=prompt_box, close=True)
    elif args.dataset == 'Semantic3D':
        render_scene_outdoor(point, new_color, name, prompt_point=prompt_point, prompt_box=prompt_box, semantic=True, args=args)
    else:
        render_scene(point, new_color, name, prompt_point=prompt_point, prompt_box=prompt_box)
# Script entry point: only run the pipeline when executed directly.
if __name__ == '__main__':
    main()