#!/usr/bin/env python3
# This file is covered by the LICENSE file in the root of this project.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import __init__ as booger  # not referenced below; presumably imported for its side effects


def get_gaussian_kernel(kernel_size=3, sigma=2, channels=1):
    # note: `channels` is kept for signature compatibility but is unused here
    # Create an (x, y) coordinate grid of shape (kernel_size, kernel_size, 2)
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

    mean = (kernel_size - 1) / 2.
    variance = sigma**2.

    # Calculate the 2-dimensional gaussian kernel, which is the product of two
    # gaussian distributions for two different variables (here called x and y)
    gaussian_kernel = (1. / (2. * math.pi * variance)) *\
        torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) / (2 * variance))

    # Make sure the values in the gaussian kernel sum to 1
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # Reshape to a 2d depthwise convolutional weight
    gaussian_kernel = gaussian_kernel.view(kernel_size, kernel_size)

    return gaussian_kernel
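
# A quick sanity check (illustrative, not part of the original file): the
# kernel is normalized and peaks at the center of the window.
#
#   >>> k = get_gaussian_kernel(kernel_size=3, sigma=2)
#   >>> bool(torch.isclose(k.sum(), torch.tensor(1.)))
#   True
#   >>> int(k.argmax())  # flat index of the center of the 3x3 grid
#   4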


class KNN(nn.Module):
    def __init__(self, params, nclasses):
        super().__init__()
        print("*" * 80)
        print("Cleaning point-clouds with kNN post-processing")
        self.knn = params["knn"]
        self.search = params["search"]
        self.sigma = params["sigma"]
        self.cutoff = params["cutoff"]
        self.nclasses = nclasses
        print("kNN parameters:")
        print("knn:", self.knn)
        print("search:", self.search)
        print("sigma:", self.sigma)
        print("cutoff:", self.cutoff)
        print("nclasses:", self.nclasses)
        print("*" * 80)

    def forward(self, proj_range, unproj_range, proj_argmax, px, py):
        ''' Warning! Only works for un-batched pointclouds.
            If they come batched we need to iterate over the batch dimension or do
            something REALLY smart to handle unaligned number of points in memory
        '''
        # get device
        if proj_range.is_cuda:
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")

        # sizes of projection scan
        H, W = proj_range.shape

        # number of points
        P = unproj_range.shape

        # check if size of kernel is odd and complain
        if (self.search % 2 == 0):
            raise ValueError("Nearest neighbor kernel must have an odd size")

        # calculate padding
        pad = int((self.search - 1) / 2)

        # unfold neighborhood to get nearest neighbors for each pixel (range image)
        proj_unfold_k_rang = F.unfold(proj_range[None, None, ...],
                                      kernel_size=(self.search, self.search),
                                      padding=(pad, pad))

        # index with px, py to get ALL the pointcloud points
        idx_list = py * W + px
        unproj_unfold_k_rang = proj_unfold_k_rang[:, :, idx_list]
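        # e.g. a point that projects to pixel (px=10, py=3) on a W-column image
        # reads column 3 * W + 10 of the unfolded tensor (row-major order), so
        # unproj_unfold_k_rang has shape (1, search * search, npoints)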

        # WARNING, THIS IS A HACK
        # Make non valid (<0) range points extremely big so that they do not
        # interfere with the nearest-neighbor search
        unproj_unfold_k_rang[unproj_unfold_k_rang < 0] = float("inf")

        # now the matrix is unfolded TOTALLY; replace the middle points with
        # the actual range points (e.g. for search=5 the window has 25 entries
        # and the center index is 12)
        center = int(((self.search * self.search) - 1) / 2)
        unproj_unfold_k_rang[:, center, :] = unproj_range

        # now compare ranges
        k2_distances = torch.abs(unproj_unfold_k_rang - unproj_range)

        # make a kernel to weight the ranges according to distance in (x, y);
        # I make this 1 - kernel because I want neighbors that are close in
        # (x, y) to matter more
        inv_gauss_k = (
            1 - get_gaussian_kernel(self.search, self.sigma, 1)).view(1, -1, 1)
        inv_gauss_k = inv_gauss_k.to(device).type(proj_range.type())

        # apply weighting
        k2_distances = k2_distances * inv_gauss_k

        # find nearest neighbors
        _, knn_idx = k2_distances.topk(
            self.knn, dim=1, largest=False, sorted=False)
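        # e.g. with search=5 and knn=5, this keeps, for each point, the 5 of
        # the 25 window entries with the smallest weighted range difference;
        # knn_idx has shape (1, knn, npoints)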

        # do the same unfolding with the argmax
        proj_unfold_1_argmax = F.unfold(proj_argmax[None, None, ...].float(),
                                        kernel_size=(self.search, self.search),
                                        padding=(pad, pad)).long()
        unproj_unfold_1_argmax = proj_unfold_1_argmax[:, :, idx_list]

        # get the top k predictions from the knn at each pixel
        knn_argmax = torch.gather(
            input=unproj_unfold_1_argmax, dim=1, index=knn_idx)

        # fake an invalid argmax of classes + 1 for all cutoff items
        if self.cutoff > 0:
            knn_distances = torch.gather(input=k2_distances, dim=1, index=knn_idx)
            knn_invalid_idx = knn_distances > self.cutoff
            knn_argmax[knn_invalid_idx] = self.nclasses
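            # e.g. with cutoff=1.0, any of the k neighbors whose weighted range
            # difference exceeds 1.0 votes for the invalid class instead of its
            # own label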

        # now vote
        # argmax one-hot has an extra class for objects past the cutoff
        knn_argmax_onehot = torch.zeros(
            (1, self.nclasses + 1, P[0]), device=device).type(proj_range.type())
        ones = torch.ones_like(knn_argmax).type(proj_range.type())
        knn_argmax_onehot = knn_argmax_onehot.scatter_add_(1, knn_argmax, ones)

        # vote as an argmax over the one-hot counts, slicing off class 0
        # (unlabeled) and the appended invalid class so neither can win, then
        # adding 1 to recover the original class ids
        # (e.g. neighbor labels [2, 2, 7, 2, 9] count to 3 votes for class 2,
        # which wins)
        knn_argmax_out = knn_argmax_onehot[:, 1:-1].argmax(dim=1) + 1

        # reshape again
        knn_argmax_out = knn_argmax_out.view(P)

        return knn_argmax_out
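

# A minimal smoke test (an illustrative sketch, not part of the original
# project): the shapes and parameter values below are assumptions; real inputs
# come from the spherical projection of a LiDAR scan.
if __name__ == "__main__":
    H, W, npoints, nclasses = 64, 1024, 10000, 20
    params = {"knn": 5, "search": 5, "sigma": 1.0, "cutoff": 1.0}
    post = KNN(params, nclasses)

    proj_range = torch.rand(H, W) * 80.                # range image (meters)
    unproj_range = torch.rand(npoints) * 80.           # per-point ranges
    proj_argmax = torch.randint(0, nclasses, (H, W))   # per-pixel predictions
    px = torch.randint(0, W, (npoints,))               # projected pixel coords
    py = torch.randint(0, H, (npoints,))

    labels = post(proj_range, unproj_range, proj_argmax, px, py)
    print(labels.shape)  # expect torch.Size([10000])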