Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature_3dssd_FPS_With_Dist_OP #66

Merged
merged 14 commits into from
Aug 30, 2020
12 changes: 7 additions & 5 deletions mmdet3d/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
sigmoid_focal_loss)

from .ball_query import ball_query
from .furthest_point_sample import furthest_point_sample
from .furthest_point_sample import (furthest_point_sample,
furthest_point_sample_with_dist)
from .gather_points import gather_points
from .group_points import (GroupAll, QueryAndGroup, group_points,
grouping_operation)
Expand All @@ -24,8 +25,9 @@
'SigmoidFocalLoss', 'SparseBasicBlock', 'SparseBottleneck',
'RoIAwarePool3d', 'points_in_boxes_gpu', 'points_in_boxes_cpu',
'make_sparse_convmodule', 'ball_query', 'furthest_point_sample',
'three_interpolate', 'three_nn', 'gather_points', 'grouping_operation',
'group_points', 'GroupAll', 'QueryAndGroup', 'PointSAModule',
'PointSAModuleMSG', 'PointFPModule', 'points_in_boxes_batch',
'get_compiler_version', 'get_compiling_cuda_version'
'furthest_point_sample_with_dist', 'three_interpolate', 'three_nn',
'gather_points', 'grouping_operation', 'group_points', 'GroupAll',
'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule',
'points_in_boxes_batch', 'get_compiler_version',
'get_compiling_cuda_version'
]
5 changes: 3 additions & 2 deletions mmdet3d/ops/furthest_point_sample/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .furthest_point_sample import furthest_point_sample
from .furthest_point_sample import (furthest_point_sample,
furthest_point_sample_with_dist)

__all__ = ['furthest_point_sample']
__all__ = ['furthest_point_sample', 'furthest_point_sample_with_dist']
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
46 changes: 46 additions & 0 deletions mmdet3d/ops/furthest_point_sample/furthest_point_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,50 @@ def backward(xyz, a=None):
return None, None


class FurthestPointSamplingWithDist(Function):
    """Furthest Point Sampling With Distance.

    Uses iterative furthest point sampling to select a set of features whose
    corresponding points have the furthest distance. Unlike coordinate-based
    sampling, the caller supplies the pairwise distance matrix directly.
    """

    @staticmethod
    def forward(ctx, points_dist: torch.Tensor,
                num_points: int) -> torch.Tensor:
        """Sample ``num_points`` indices per batch from a distance matrix.

        Args:
            points_dist (Tensor): (B, N, N) Distance between each point pair.
                Must be contiguous.
            num_points (int): Number of points in the sampled set.

        Returns:
            Tensor: (B, num_points) int32 indices of the sampled points.
        """
        assert points_dist.is_contiguous()

        B, N, _ = points_dist.size()
        # Allocate the work buffers from the input tensor so they land on the
        # same device as ``points_dist``. The legacy ``torch.cuda.IntTensor``
        # / ``torch.cuda.FloatTensor`` constructors always use the *current*
        # CUDA device, which is wrong when the input lives on another GPU.
        output = points_dist.new_zeros([B, num_points], dtype=torch.int32)
        # Running minimum distance of every point to the already selected
        # set; starts at a large sentinel so the first update always wins.
        temp = points_dist.new_full([B, N], 1e10)

        furthest_point_sample_ext.furthest_point_sampling_with_dist_wrapper(
            B, N, num_points, points_dist, temp, output)
        # Indices are integers, so no gradient can flow through them.
        ctx.mark_non_differentiable(output)
        return output

    @staticmethod
    def backward(xyz, a=None):
        """Return no gradients: one ``None`` per ``forward`` input."""
        return None, None


# Public functional entry points for the two autograd Functions.
furthest_point_sample = FurthestPointSampling.apply
furthest_point_sample_with_dist = FurthestPointSamplingWithDist.apply

if __name__ == '__main__':
    # Manual regression check against reference data stored in the working
    # directory. Presumably 'features_for_fps_distance.npy' holds a (B, N, N)
    # distance matrix and 'fps_idx.npy' the expected indices - confirm
    # against whoever produced the files.
    import numpy as np
    fps_idx = np.load('fps_idx.npy')
    features_for_fps_distance = np.load('features_for_fps_distance.npy')
    fps_idx = torch.from_numpy(fps_idx).cuda()
    features_for_fps_distance = torch.from_numpy(
        features_for_fps_distance).cuda()
    fps_idx_t = furthest_point_sample_with_dist(features_for_fps_distance, 512)
    # Compare element by element: a summed difference of zero can hide
    # mismatches whose positive and negative errors cancel out.
    assert fps_idx_t.shape == fps_idx.shape
    assert bool((fps_idx_t == fps_idx).all())
27 changes: 27 additions & 0 deletions mmdet3d/ops/furthest_point_sample/src/furthest_point_sample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp,
int *idxs, cudaStream_t stream);

int furthest_point_sampling_with_dist_wrapper(int b, int n, int m,
encore-zhou marked this conversation as resolved.
Show resolved Hide resolved
at::Tensor points_tensor,
at::Tensor temp_tensor,
at::Tensor idx_tensor);

void furthest_point_sampling_with_dist_kernel_launcher(int b, int n, int m,
const float *dataset,
float *temp, int *idxs,
cudaStream_t stream);

int furthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor,
at::Tensor temp_tensor,
Expand All @@ -32,7 +42,24 @@ int furthest_point_sampling_wrapper(int b, int n, int m,
return 1;
}

int furthest_point_sampling_with_dist_wrapper(int b, int n, int m,
at::Tensor points_tensor,
at::Tensor temp_tensor,
at::Tensor idx_tensor) {

const float *points = points_tensor.data<float>();
float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>();

cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
furthest_point_sampling_with_dist_kernel_launcher(b, n, m, points, temp, idx, stream);
return 1;
}

// Python bindings for the extension module. Each m.def call registers one
// host wrapper under the given Python name; the third argument is the
// docstring attached to the bound function.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Wrapper taking point data directly (defined earlier in this file).
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper,
"furthest_point_sampling_wrapper");
// Variant that samples from a caller-supplied distance tensor.
m.def("furthest_point_sampling_with_dist_wrapper",
&furthest_point_sampling_with_dist_wrapper,
"furthest_point_sampling_with_dist_wrapper");
}
189 changes: 189 additions & 0 deletions mmdet3d/ops/furthest_point_sample/src/furthest_point_sample_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,192 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
exit(-1);
}
}

// Iterative furthest point sampling driven by a precomputed distance matrix.
//
// Launch layout: one thread block per batch element (blockIdx.x selects the
// batch) with exactly `block_size` threads. `block_size` must be a power of
// two so the tree reduction below pairs every slot.
//
// dataset: (B, N, N) pairwise distances, read only
// temp:    (B, N)    running minimum distance of each point to the already
//                    selected set; must be pre-filled by the caller
// idxs:    (B, M)    output, indices of the selected points
template <unsigned int block_size>
__global__ void furthest_point_sampling_with_dist_kernel(
    int b, int n, int m, const float *__restrict__ dataset,
    float *__restrict__ temp, int *__restrict__ idxs) {
  if (m <= 0)
    return;
  __shared__ float dists[block_size];
  __shared__ int dists_i[block_size];

  // Compute the per-batch offsets in 64-bit arithmetic: batch * n * n
  // exceeds the range of a 32-bit int for large distance matrices.
  const long long batch_index = blockIdx.x;
  dataset += batch_index * n * n;
  temp += batch_index * n;
  idxs += batch_index * m;

  int tid = threadIdx.x;
  const int stride = block_size;

  // The first selected point is always index 0.
  int old = 0;
  if (threadIdx.x == 0)
    idxs[0] = old;

  __syncthreads();
  for (int j = 1; j < m; j++) {
    int besti = 0;
    float best = -1;
    // Row of the distance matrix belonging to the most recently selected
    // point.
    const float *dist_row = dataset + static_cast<long long>(old) * n;
    for (int k = tid; k < n; k += stride) {
      // Fold the new point into the running minimum, then track the
      // candidate that is furthest from the selected set.
      float d2 = min(dist_row[k], temp[k]);
      temp[k] = d2;
      besti = d2 > best ? k : besti;
      best = d2 > best ? d2 : best;
    }
    dists[tid] = best;
    dists_i[tid] = besti;
    __syncthreads();

    // Tree reduction over the per-thread candidates. The trip count is a
    // compile-time constant, so every thread reaches every barrier.
#pragma unroll
    for (int half = block_size / 2; half > 0; half >>= 1) {
      if (tid < half) {
        __update(dists, dists_i, tid, tid + half);
      }
      __syncthreads();
    }

    old = dists_i[0];
    if (tid == 0)
      idxs[j] = old;

    // Every thread must have read dists_i[0] before thread 0 overwrites it
    // with its candidate for the next round.
    __syncthreads();
  }
}

// Launches the distance-based furthest point sampling kernel on `stream`,
// using one thread block per batch element.
//
// dataset: (B, N, N)
// tmp: (B, N)
// output:
//      idx: (B, M)
//
// Terminates the process with exit(-1) if the launch reports a CUDA error.
void furthest_point_sampling_with_dist_kernel_launcher(int b, int n, int m,
                                                       const float *dataset,
                                                       float *temp, int *idxs,
                                                       cudaStream_t stream) {
  // An empty batch has nothing to sample, and a zero-sized grid is an
  // invalid launch configuration.
  if (b <= 0)
    return;

  cudaError_t err;
  unsigned int n_threads = opt_n_threads(n);

  // The template argument sizes the kernel's shared arrays, so it must
  // always equal the number of threads actually launched.
  switch (n_threads) {
  case 1024:
    furthest_point_sampling_with_dist_kernel<1024><<<b, n_threads, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
    break;
  case 512:
    furthest_point_sampling_with_dist_kernel<512><<<b, n_threads, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
    break;
  case 256:
    furthest_point_sampling_with_dist_kernel<256><<<b, n_threads, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
    break;
  case 128:
    furthest_point_sampling_with_dist_kernel<128><<<b, n_threads, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
    break;
  case 64:
    furthest_point_sampling_with_dist_kernel<64><<<b, n_threads, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
    break;
  case 32:
    furthest_point_sampling_with_dist_kernel<32><<<b, n_threads, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
    break;
  case 16:
    furthest_point_sampling_with_dist_kernel<16><<<b, n_threads, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
    break;
  case 8:
    furthest_point_sampling_with_dist_kernel<8><<<b, n_threads, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
    break;
  case 4:
    furthest_point_sampling_with_dist_kernel<4><<<b, n_threads, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
    break;
  case 2:
    furthest_point_sampling_with_dist_kernel<2><<<b, n_threads, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
    break;
  case 1:
    furthest_point_sampling_with_dist_kernel<1><<<b, n_threads, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
    break;
  default:
    // Unexpected thread count: fall back to a self-consistent 512-thread
    // launch instead of pairing the <512> kernel with a different block
    // size.
    furthest_point_sampling_with_dist_kernel<512><<<b, 512, 0, stream>>>(
        b, n, m, dataset, temp, idxs);
  }

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
22 changes: 21 additions & 1 deletion tests/test_pointnet_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
import torch

from mmdet3d.ops import (ball_query, furthest_point_sample, gather_points,
from mmdet3d.ops import (ball_query, furthest_point_sample,
furthest_point_sample_with_dist, gather_points,
grouping_operation, three_interpolate, three_nn)


Expand Down Expand Up @@ -312,3 +313,22 @@ def test_three_nn():

assert torch.allclose(dist, expected_dist, 1e-4)
assert torch.all(idx == expected_idx)


def test_fps_with_dist():
    """Check distance-based FPS against known indices for two small clouds."""
    if not torch.cuda.is_available():
        pytest.skip()

    first_cloud = [[-0.2748, 1.0020, -1.1674], [0.1015, 1.3952, -1.2681],
                   [-0.8070, 2.4137, -0.5845], [-1.0001, 2.1982, -0.5859],
                   [0.3841, 1.8983, -0.7431]]
    second_cloud = [[-1.0696, 3.0758, -0.1899], [-0.2559, 3.5521, -0.1402],
                    [0.8164, 4.0081, -0.1839], [-1.1000, 3.0213, -0.8205],
                    [-0.0518, 3.7251, -0.3950]]
    points = torch.tensor([first_cloud, second_cloud]).cuda()

    # Broadcast (B, 1, N, 3) against (B, N, 1, 3) to obtain every pairwise
    # offset, then reduce the coordinate axis to squared distances (B, N, N).
    pairwise_offset = points.unsqueeze(dim=1) - points.unsqueeze(dim=2)
    pairwise_sq_dist = (pairwise_offset**2).sum(-1)

    sampled_idx = furthest_point_sample_with_dist(pairwise_sq_dist, 3)

    reference_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).cuda()
    assert torch.all(sampled_idx == reference_idx)