diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d09e1b329b..a1902e2f3b 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -281,6 +281,7 @@ if(BUILD_RAFT_TESTS) test/sparse/add.cu test/sparse/convert_coo.cu test/sparse/convert_csr.cu + test/sparse/connect_components.cu test/sparse/csr_row_slice.cu test/sparse/csr_to_dense.cu test/sparse/csr_transpose.cu @@ -290,6 +291,8 @@ if(BUILD_RAFT_TESTS) test/sparse/distance.cu test/sparse/filter.cu test/sparse/knn.cu + test/sparse/knn_graph.cu + test/sparse/linkage.cu test/sparse/norm.cu test/sparse/row_op.cu test/sparse/selection.cu diff --git a/cpp/cmake/Dependencies.cmake b/cpp/cmake/Dependencies.cmake index 080efb5b1f..e1e21a9cd3 100644 --- a/cpp/cmake/Dependencies.cmake +++ b/cpp/cmake/Dependencies.cmake @@ -23,7 +23,7 @@ if(NOT CUB_IS_PART_OF_CTK) set(CUB_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub CACHE STRING "Path to cub repo") ExternalProject_Add(cub GIT_REPOSITORY https://github.com/thrust/cub.git - GIT_TAG 1.8.0 + GIT_TAG 1.12.0 PREFIX ${CUB_DIR} CONFIGURE_COMMAND "" BUILD_COMMAND "" diff --git a/cpp/include/raft/sparse/hierarchy/common.h b/cpp/include/raft/sparse/hierarchy/common.h new file mode 100644 index 0000000000..48326bd347 --- /dev/null +++ b/cpp/include/raft/sparse/hierarchy/common.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft { +namespace hierarchy { + +enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 }; + +/** + * Simple POCO for consolidating linkage results. This closely + * mirrors the trained instance variables populated in + * Scikit-learn's AgglomerativeClustering estimator. + * @tparam value_idx + * @tparam value_t + */ +template +struct linkage_output { + value_idx m; + value_idx n_clusters; + + value_idx n_leaves; + value_idx n_connected_components; + + value_idx *labels; // size: m + + value_idx *children; // size: (m-1, 2) +}; + +struct linkage_output_int_float : public linkage_output {}; +struct linkage_output__int64_float : public linkage_output {}; + +}; // namespace hierarchy +}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh b/cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh new file mode 100644 index 0000000000..aa2cfc8fe9 --- /dev/null +++ b/cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh @@ -0,0 +1,361 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace raft { + +namespace hierarchy { +namespace detail { + +template +class UnionFind { + public: + value_idx next_label; + std::vector parent; + std::vector size; + + value_idx n_indices; + + UnionFind(value_idx N_) + : n_indices(2 * N_ - 1), + parent(2 * N_ - 1, -1), + size(2 * N_ - 1, 1), + next_label(N_) { + memset(size.data() + N_, 0, (size.size() - N_) * sizeof(value_idx)); + } + + value_idx find(value_idx n) { + value_idx p; + p = n; + + while (parent[n] != -1) n = parent[n]; + + // path compression + while (parent[p] != n) { + p = parent[p == -1 ? n_indices - 1 : p]; + parent[p == -1 ? n_indices - 1 : p] = n; + } + return n; + } + + void perform_union(value_idx m, value_idx n) { + size[next_label] = size[m] + size[n]; + parent[m] = next_label; + parent[n] = next_label; + + next_label += 1; + } +}; + +/** + * Standard single-threaded agglomerative labeling on host. This should work + * well for smaller sizes of m. This is a C++ port of the original reference + * implementation of HDBSCAN. + * + * @tparam value_idx + * @tparam value_t + * @param[in] handle the raft handle + * @param[in] rows src edges of the sorted MST + * @param[in] cols dst edges of the sorted MST + * @param[in] nnz the number of edges in the sorted MST + * @param[out] out_src parents of output + * @param[out] out_dst children of output + * @param[out] out_delta distances of output + * @param[out] out_size cluster sizes of output + */ +template +void build_dendrogram_host(const handle_t &handle, const value_idx *rows, + const value_idx *cols, const value_t *data, + size_t nnz, value_idx *children, + rmm::device_uvector &out_delta, + rmm::device_uvector &out_size) { + auto d_alloc = handle.get_device_allocator(); + auto stream = handle.get_stream(); + + value_idx n_edges = nnz; + + std::vector mst_src_h(n_edges); + std::vector mst_dst_h(n_edges); + std::vector mst_weights_h(n_edges); + + update_host(mst_src_h.data(), rows, n_edges, stream); + update_host(mst_dst_h.data(), cols, n_edges, stream); + update_host(mst_weights_h.data(), data, n_edges, stream); + + CUDA_CHECK(cudaStreamSynchronize(stream)); + + std::vector children_h(n_edges * 2); + std::vector out_size_h(n_edges); + + UnionFind U(nnz + 1); + + for (value_idx i = 0; i < nnz; i++) { + value_idx a = mst_src_h[i]; + value_idx b = mst_dst_h[i]; + + value_idx aa = U.find(a); + value_idx bb = U.find(b); + + value_idx children_idx = i * 2; + + children_h[children_idx] = aa; + children_h[children_idx + 1] = bb; + out_size_h[i] = U.size[aa] + U.size[bb]; + + U.perform_union(aa, bb); + } + + out_size.resize(n_edges, stream); + + raft::update_device(children, children_h.data(), n_edges * 2, stream); + raft::update_device(out_size.data(), out_size_h.data(), n_edges, stream); +} + +/** + * Parallel agglomerative labeling. This amounts to a parallel Kruskal's + * MST algorithm, which breaks apart the sorted MST results into overlapping + * subsets and independently runs Kruskal's algorithm on each subset, + * merging them back together into a single hierarchy when complete. + * + * This outputs the same format as the reference HDBSCAN, but as 4 separate + * arrays, rather than a single 2D array. + * + * Reference: http://cucis.ece.northwestern.edu/publications/pdf/HenPat12.pdf + * + * TODO: Investigate potential for the following end-to-end single-hierarchy batching: + * For each of k (independent) batches over the input: + * - Sample n elements from X + * - Compute mutual reachability graph of batch + * - Construct labels from batch + * + * The sampled datasets should have some overlap across batches. This will + * allow for the cluster hierarchies to be merged. Being able to batch + * will reduce the memory cost so that the full n^2 pairwise distances + * don't need to be materialized in memory all at once. + * + * @tparam value_idx + * @tparam value_t + * @param[in] handle the raft handle + * @param[in] rows src edges of the sorted MST + * @param[in] cols dst edges of the sorted MST + * @param[in] nnz the number of edges in the sorted MST + * @param[out] out_src parents of output + * @param[out] out_dst children of output + * @param[out] out_delta distances of output + * @param[out] out_size cluster sizes of output + * @param[in] k_folds number of folds for parallelizing label step + */ +template +void build_dendrogram_device(const handle_t &handle, const value_idx *rows, + const value_idx *cols, const value_t *data, + value_idx nnz, value_idx *children, + value_t *out_delta, value_idx *out_size, + value_idx k_folds) { + ASSERT(k_folds < nnz / 2, "k_folds must be < n_edges / 2"); + /** + * divide (sorted) mst coo into overlapping subsets. Easiest way to do this is to + * break it into k-folds and iterate through two folds at a time. + */ + + // 1. Generate ranges for the overlapping subsets + + // 2. Run union-find in parallel for each pair of folds + + // 3. Sort individual label hierarchies + + // 4. Merge label hierarchies together +} + +template +__global__ void write_levels_kernel(const value_idx *children, + value_idx *parents, value_idx n_vertices) { + value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < n_vertices) { + value_idx level = tid / 2; + value_idx child = children[tid]; + parents[child] = level; + } +} + +/** + * Instead of propagating a label from roots to children, + * the children each iterate up the tree until they find + * the label of their parent. This increases the potential + * parallelism. + * @tparam value_idx + * @param children + * @param parents + * @param n_leaves + * @param labels + */ +template +__global__ void inherit_labels(const value_idx *children, + const value_idx *levels, size_t n_leaves, + value_idx *labels, int cut_level, + value_idx n_vertices) { + value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; + + if (tid < n_vertices) { + value_idx node = children[tid]; + value_idx cur_level = tid / 2; + + /** + * Any roots above the cut level should be ignored. + * Any leaves at the cut level should already be labeled + */ + if (cur_level > cut_level) return; + + value_idx cur_parent = node; + value_idx label = labels[cur_parent]; + + while (label == -1) { + cur_parent = cur_level + n_leaves; + cur_level = levels[cur_parent]; + label = labels[cur_parent]; + } + + labels[node] = label; + } +} + +template +struct init_label_roots { + init_label_roots(value_idx *labels_) : labels(labels_) {} + + template + __host__ __device__ void operator()(Tuple t) { + labels[thrust::get<1>(t)] = thrust::get<0>(t); + } + + private: + value_idx *labels; +}; + +/** + * Cuts the dendrogram at a particular level where the number of nodes + * is equal to n_clusters, then propagates the resulting labels + * to all the children. + * + * @tparam value_idx + * @param handle + * @param labels + * @param children + * @param n_clusters + * @param n_leaves + */ +template +void extract_flattened_clusters(const raft::handle_t &handle, value_idx *labels, + const value_idx *children, size_t n_clusters, + size_t n_leaves) { + auto d_alloc = handle.get_device_allocator(); + auto stream = handle.get_stream(); + auto thrust_policy = rmm::exec_policy(stream); + + /** + * Compute levels for each node + * + * 1. Initialize "levels" array of size n_leaves * 2 + * + * 2. For each entry in children, write parent + * out for each of the children + */ + + size_t n_edges = (n_leaves - 1) * 2; + + thrust::device_ptr d_ptr = + thrust::device_pointer_cast(children); + value_idx n_vertices = + *(thrust::max_element(thrust_policy, d_ptr, d_ptr + n_edges)) + 1; + + // Prevent potential infinite loop from labeling disconnected + // connectivities graph. + RAFT_EXPECTS(n_vertices == (n_leaves - 1) * 2, + "Multiple components found in MST or MST is invalid. " + "Cannot find single-linkage solution."); + + rmm::device_uvector levels(n_vertices, stream); + + value_idx n_blocks = ceildiv(n_vertices, (value_idx)tpb); + write_levels_kernel<<>>(children, levels.data(), + n_vertices); + /** + * Step 1: Find label roots: + * + * 1. Copying children[children.size()-(n_clusters-1):] entries to + * separate arrayo + * 2. sort array + * 3. take first n_clusters entries + */ + + value_idx child_size = (n_clusters - 1) * 2; + rmm::device_uvector label_roots(child_size, stream); + + value_idx children_cpy_start = n_edges - child_size; + raft::copy_async(label_roots.data(), children + children_cpy_start, + child_size, stream); + + // thrust::device_ptr t_label_roots = + // thrust::device_pointer_cast(label_roots.data()); + // + thrust::sort(thrust_policy, label_roots.data(), + label_roots.data() + (child_size), thrust::greater()); + + rmm::device_uvector tmp_labels(n_vertices, stream); + + // Init labels to -1 + thrust::fill(thrust_policy, tmp_labels.data(), tmp_labels.data() + n_vertices, + -1); + + // Write labels for cluster roots to "labels" + thrust::counting_iterator first(0); + + auto z_iter = thrust::make_zip_iterator(thrust::make_tuple( + first, label_roots.data() + (label_roots.size() - n_clusters))); + + thrust::for_each(thrust_policy, z_iter, z_iter + n_clusters, + init_label_roots(tmp_labels.data())); + + /** + * Step 2: Propagate labels by having children iterate through their parents + * 1. Initialize labels to -1 + * 2. For each element in levels array, propagate until parent's + * label is !=-1 + */ + value_idx cut_level = (n_edges / 2) - (n_clusters - 1); + + inherit_labels<<>>(children, levels.data(), + n_leaves, tmp_labels.data(), + cut_level, n_vertices); + + // copy tmp labels to actual labels + raft::copy_async(labels, tmp_labels.data(), n_leaves, stream); +} + +}; // namespace detail +}; // namespace hierarchy +}; // namespace raft diff --git a/cpp/include/raft/sparse/hierarchy/detail/connectivities.cuh b/cpp/include/raft/sparse/hierarchy/detail/connectivities.cuh new file mode 100644 index 0000000000..229f2034b0 --- /dev/null +++ b/cpp/include/raft/sparse/hierarchy/detail/connectivities.cuh @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace raft { +namespace hierarchy { +namespace detail { + +template +struct distance_graph_impl { + void run(const raft::handle_t &handle, const value_t *X, size_t m, size_t n, + raft::distance::DistanceType metric, + rmm::device_uvector &indptr, + rmm::device_uvector &indices, + rmm::device_uvector &data, int c); +}; + +/** + * Connectivities specialization to build a knn graph + * @tparam value_idx + * @tparam value_t + */ +template +struct distance_graph_impl { + void run(const raft::handle_t &handle, const value_t *X, size_t m, size_t n, + raft::distance::DistanceType metric, + rmm::device_uvector &indptr, + rmm::device_uvector &indices, + rmm::device_uvector &data, int c) { + auto d_alloc = handle.get_device_allocator(); + auto stream = handle.get_stream(); + + // Need to symmetrize knn into undirected graph + raft::sparse::COO knn_graph_coo(d_alloc, stream); + + raft::sparse::selection::knn_graph(handle, X, m, n, metric, knn_graph_coo, + c); + + indices.resize(knn_graph_coo.nnz, stream); + data.resize(knn_graph_coo.nnz, stream); + + raft::sparse::convert::sorted_coo_to_csr(&knn_graph_coo, indptr.data(), + d_alloc, stream); + + //TODO: This is a bug in the coo_to_csr prim + value_idx max_offset = 0; + raft::update_host(&max_offset, indptr.data() + (m - 1), 1, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + max_offset += (knn_graph_coo.nnz - max_offset); + + raft::update_device(indptr.data() + m, &max_offset, 1, stream); + + raft::copy_async(indices.data(), knn_graph_coo.cols(), knn_graph_coo.nnz, + stream); + raft::copy_async(data.data(), knn_graph_coo.vals(), knn_graph_coo.nnz, + stream); + } +}; + +/** + * Returns a CSR connectivities graph based on the given linkage distance. + * @tparam value_idx + * @tparam value_t + * @tparam dist_type + * @param[in] handle raft handle + * @param[in] X dense data for which to construct connectivites + * @param[in] m number of rows in X + * @param[in] n number of columns in X + * @param[in] metric distance metric to use + * @param[out] indptr indptr array of connectivities graph + * @param[out] indices column indices array of connectivities graph + * @param[out] data distances array of connectivities graph + * @param[out] c constant 'c' used for nearest neighbors-based distances + * which will guarantee k <= log(n) + c + */ +template +void get_distance_graph(const raft::handle_t &handle, const value_t *X, + size_t m, size_t n, raft::distance::DistanceType metric, + rmm::device_uvector &indptr, + rmm::device_uvector &indices, + rmm::device_uvector &data, int c) { + auto stream = handle.get_stream(); + + indptr.resize(m + 1, stream); + + distance_graph_impl dist_graph; + dist_graph.run(handle, X, m, n, metric, indptr, indices, data, c); + + // a little adjustment for distances of 0. + // TODO: This will only need to be done when src_v==dst_v + raft::linalg::unaryOp( + data.data(), data.data(), data.size(), + [] __device__(value_t input) { + if (input == 0) + return std::numeric_limits::max(); + else + return input; + }, + stream); +} + +}; // namespace detail +}; // namespace hierarchy +}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/hierarchy/detail/mst.cuh b/cpp/include/raft/sparse/hierarchy/detail/mst.cuh new file mode 100644 index 0000000000..1b890c239b --- /dev/null +++ b/cpp/include/raft/sparse/hierarchy/detail/mst.cuh @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace raft { +namespace hierarchy { +namespace detail { + +/** + * Sorts a COO by its weight + * @tparam value_idx + * @tparam value_t + * @param[inout] rows source edges + * @param[inout] cols dest edges + * @param[inout] data edge weights + * @param[in] nnz number of edges in edge list + * @param[in] stream cuda stream for which to order cuda operations + */ +template +void sort_coo_by_data(value_idx *rows, value_idx *cols, value_t *data, + value_idx nnz, cudaStream_t stream) { + thrust::device_ptr t_rows = thrust::device_pointer_cast(rows); + thrust::device_ptr t_cols = thrust::device_pointer_cast(cols); + thrust::device_ptr t_data = thrust::device_pointer_cast(data); + + auto first = thrust::make_zip_iterator(thrust::make_tuple(rows, cols)); + + thrust::sort_by_key(thrust::cuda::par.on(stream), t_data, t_data + nnz, + first); +} + +/** + * Connect an unconnected knn graph (one in which mst returns an msf). The + * device buffers underlying the Graph_COO object are modified in-place. + * @tparam value_idx index type + * @tparam value_t floating-point value type + * @param[in] handle raft handle + * @param[in] X original dense data from which knn grpah was constructed + * @param[inout] msf edge list containing the mst result + * @param[in] m number of rows in X + * @param[in] n number of columns in X + * @param[in] color the color labels array returned from the mst invocation + * @return updated MST edge list + */ +template +raft::Graph_COO connect_knn_graph( + const raft::handle_t &handle, const value_t *X, + raft::Graph_COO &msf, size_t m, size_t n, + value_idx *color) { + auto d_alloc = handle.get_device_allocator(); + auto stream = handle.get_stream(); + + raft::sparse::COO connected_edges(d_alloc, stream); + + raft::linkage::connect_components(handle, connected_edges, + X, color, m, n); + + int final_nnz = connected_edges.nnz + msf.n_edges; + + msf.src.resize(final_nnz, stream); + msf.dst.resize(final_nnz, stream); + msf.weights.resize(final_nnz, stream); + + /** + * Construct final edge list + */ + raft::copy_async(msf.src.data() + msf.n_edges, connected_edges.rows(), + connected_edges.nnz, stream); + raft::copy_async(msf.dst.data() + msf.n_edges, connected_edges.cols(), + connected_edges.nnz, stream); + raft::copy_async(msf.weights.data() + msf.n_edges, connected_edges.vals(), + connected_edges.nnz, stream); + + raft::sparse::COO final_coo(d_alloc, stream); + raft::sparse::linalg::symmetrize(handle, msf.src.data(), msf.dst.data(), + msf.weights.data(), m, n, final_nnz, + final_coo); + + rmm::device_uvector indptr2(m + 1, stream); + + raft::sparse::convert::sorted_coo_to_csr(final_coo.rows(), final_coo.nnz, + indptr2.data(), m, d_alloc, stream); + + value_idx max_offset = 0; + raft::update_host(&max_offset, indptr2.data() + (m - 1), 1, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + max_offset += (final_nnz - max_offset); + + raft::update_device(indptr2.data() + m, &max_offset, 1, stream); + + return raft::mst::mst( + handle, indptr2.data(), final_coo.cols(), final_coo.vals(), m, + final_coo.nnz, color, stream, false, true); +} + +/** + * Constructs an MST and sorts the resulting edges in ascending + * order by their weight. + * + * Hierarchical clustering heavily relies upon the ordering + * and vertices returned in the MST. If the result of the + * MST was actually a minimum-spanning forest, the CSR + * being passed into the MST is not connected. In such a + * case, this graph will be connected by performing a + * KNN across the components. + * @tparam value_idx + * @tparam value_t + * @param[in] handle raft handle + * @param[in] indptr CSR indptr of connectivities graph + * @param[in] indices CSR indices array of connectivities graph + * @param[in] pw_dists CSR weights array of connectivities graph + * @param[in] m number of rows in X / src vertices in connectivities graph + * @param[in] n number of columns in X + * @param[out] mst_src output src edges + * @param[out] mst_dst output dst edges + * @param[out] mst_weight output weights (distances) + */ +template +void build_sorted_mst(const raft::handle_t &handle, const value_t *X, + const value_idx *indptr, const value_idx *indices, + const value_t *pw_dists, size_t m, size_t n, + rmm::device_uvector &mst_src, + rmm::device_uvector &mst_dst, + rmm::device_uvector &mst_weight, + const size_t nnz) { + auto d_alloc = handle.get_device_allocator(); + auto stream = handle.get_stream(); + + rmm::device_uvector color(m, stream); + + auto mst_coo = raft::mst::mst( + handle, indptr, indices, pw_dists, (value_idx)m, nnz, color.data(), stream, + false); + + if (linkage::get_n_components(color.data(), m, stream) > 1) { + mst_coo = connect_knn_graph(handle, X, mst_coo, m, n, + color.data()); + + printf("Edges: %d\n", mst_coo.n_edges); + + RAFT_EXPECTS( + mst_coo.n_edges == m - 1, + "MST was not able to connect knn graph in a single iteration."); + } + + sort_coo_by_data(mst_coo.src.data(), mst_coo.dst.data(), + mst_coo.weights.data(), mst_coo.n_edges, stream); + + // TODO: be nice if we could pass these directly into the MST + mst_src.resize(mst_coo.n_edges, stream); + mst_dst.resize(mst_coo.n_edges, stream); + mst_weight.resize(mst_coo.n_edges, stream); + + raft::copy_async(mst_src.data(), mst_coo.src.data(), mst_coo.n_edges, stream); + raft::copy_async(mst_dst.data(), mst_coo.dst.data(), mst_coo.n_edges, stream); + raft::copy_async(mst_weight.data(), mst_coo.weights.data(), mst_coo.n_edges, + stream); +} + +}; // namespace detail +}; // namespace hierarchy +}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/hierarchy/single_linkage.hpp b/cpp/include/raft/sparse/hierarchy/single_linkage.hpp new file mode 100644 index 0000000000..a773b31e05 --- /dev/null +++ b/cpp/include/raft/sparse/hierarchy/single_linkage.hpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +namespace raft { +namespace hierarchy { + +static const size_t EMPTY = 0; + +/** + * Single-linkage clustering, capable of constructing a KNN graph to + * scale the algorithm beyond the n^2 memory consumption of implementations + * that use the fully-connected graph of pairwise distances by connecting + * a knn graph when k is not large enough to connect it. + + * @tparam value_idx + * @tparam value_t + * @tparam dist_type method to use for constructing connectivities graph + * @param[in] handle raft handle + * @param[in] X dense input matrix in row-major layout + * @param[in] m number of rows in X + * @param[in] n number of columns in X + * @param[in] metric distance metrix to use when constructing connectivities graph + * @param[out] out struct containing output dendrogram and cluster assignments + * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect control + * of k. The algorithm will set `k = log(n) + c` + * @param[in] n_clusters number of clusters to assign data samples + */ +template +void single_linkage(const raft::handle_t &handle, const value_t *X, size_t m, + size_t n, raft::distance::DistanceType metric, + linkage_output *out, int c, + size_t n_clusters) { + ASSERT(n_clusters <= m, + "n_clusters must be less than or equal to the number of data points"); + + auto stream = handle.get_stream(); + auto d_alloc = handle.get_device_allocator(); + + rmm::device_uvector indptr(EMPTY, stream); + rmm::device_uvector indices(EMPTY, stream); + rmm::device_uvector pw_dists(EMPTY, stream); + + /** + * 1. Construct distance graph + */ + detail::get_distance_graph( + handle, X, m, n, metric, indptr, indices, pw_dists, c); + + rmm::device_uvector mst_rows(EMPTY, stream); + rmm::device_uvector mst_cols(EMPTY, stream); + rmm::device_uvector mst_data(EMPTY, stream); + + /** + * 2. Construct MST, sorted by weights + */ + detail::build_sorted_mst( + handle, X, indptr.data(), indices.data(), pw_dists.data(), m, n, mst_rows, + mst_cols, mst_data, indices.size()); + + pw_dists.release(); + + /** + * Perform hierarchical labeling + */ + size_t n_edges = mst_rows.size(); + + rmm::device_uvector out_delta(n_edges, stream); + rmm::device_uvector out_size(n_edges, stream); + // Create dendrogram + detail::build_dendrogram_host( + handle, mst_rows.data(), mst_cols.data(), mst_data.data(), n_edges, + out->children, out_delta, out_size); + detail::extract_flattened_clusters(handle, out->labels, out->children, + n_clusters, m); + + out->m = m; + out->n_clusters = n_clusters; + out->n_leaves = m; + out->n_connected_components = 1; +} + +}; // namespace hierarchy +}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/linalg/symmetrize.cuh b/cpp/include/raft/sparse/linalg/symmetrize.cuh index bb298008b7..f1b2d1e28c 100644 --- a/cpp/include/raft/sparse/linalg/symmetrize.cuh +++ b/cpp/include/raft/sparse/linalg/symmetrize.cuh @@ -24,13 +24,15 @@ #include #include -#include - +#include #include #include +#include #include #include +#include +#include #include #include @@ -304,6 +306,92 @@ void from_knn_symmetrize_matrix( CUDA_CHECK(cudaPeekAtLastError()); } +template +__global__ void compute_duplicates_diffs(const value_idx *rows, + const value_idx *cols, value_idx *diff, + size_t nnz) { + size_t tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid >= nnz) return; + + value_idx d = 1; + if (tid == 0 || (rows[tid - 1] == rows[tid] && cols[tid - 1] == cols[tid])) + d = 0; + diff[tid] = d; +} + +template +__global__ void reduce_duplicates_kernel( + const value_idx *src_rows, const value_idx *src_cols, const value_t *src_vals, + const value_idx *index, value_idx *out_rows, value_idx *out_cols, + value_t *out_vals, size_t nnz) { + size_t tid = blockDim.x * blockIdx.x + threadIdx.x; + + if (tid < nnz) { + value_idx idx = index[tid]; + atomicMax(&out_vals[idx], src_vals[tid]); + out_rows[idx] = src_rows[tid]; + out_cols[idx] = src_cols[tid]; + } +} + +/** + * Symmetrizes a COO matrix + */ +template +void symmetrize(const raft::handle_t &handle, const value_idx *rows, + const value_idx *cols, const value_t *vals, size_t m, size_t n, + size_t nnz, raft::sparse::COO &out) { + auto d_alloc = handle.get_device_allocator(); + auto stream = handle.get_stream(); + + auto exec_policy = rmm::exec_policy(stream); + + // copy rows to cols and cols to rows + rmm::device_uvector symm_rows(nnz * 2, stream); + rmm::device_uvector symm_cols(nnz * 2, stream); + rmm::device_uvector symm_vals(nnz * 2, stream); + + raft::copy_async(symm_rows.data(), rows, nnz, stream); + raft::copy_async(symm_rows.data() + nnz, cols, nnz, stream); + raft::copy_async(symm_cols.data(), cols, nnz, stream); + raft::copy_async(symm_cols.data() + nnz, rows, nnz, stream); + + raft::copy_async(symm_vals.data(), vals, nnz, stream); + raft::copy_async(symm_vals.data() + nnz, vals, nnz, stream); + + // sort COO + raft::sparse::op::coo_sort(m, n, nnz * 2, symm_rows.data(), symm_cols.data(), + symm_vals.data(), d_alloc, stream); + + // compute diffs & take exclusive scan + rmm::device_uvector diff((nnz * 2) + 1, stream); + + CUDA_CHECK(cudaMemsetAsync(diff.data(), 0, + ((nnz * 2) + 1) * sizeof(value_idx), stream)); + + compute_duplicates_diffs<<>>(symm_rows.data(), symm_cols.data(), + diff.data(), nnz * 2); + + thrust::exclusive_scan(exec_policy, diff.data(), diff.data() + diff.size(), + diff.data()); + + // compute final size + value_idx size = 0; + raft::update_host(&size, diff.data() + (diff.size() - 1), 1, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + size++; + + out.allocate(size, m, n, true, stream); + + // perform reduce + reduce_duplicates_kernel<<>>( + symm_rows.data(), symm_cols.data(), symm_vals.data(), diff.data() + 1, + out.rows(), out.cols(), out.vals(), nnz * 2); +} + }; // end NAMESPACE linalg }; // end NAMESPACE sparse -}; // end NAMESPACE raft \ No newline at end of file +}; // end NAMESPACE raft diff --git a/cpp/include/raft/sparse/selection/connect_components.cuh b/cpp/include/raft/sparse/selection/connect_components.cuh new file mode 100644 index 0000000000..a1944fc0b3 --- /dev/null +++ b/cpp/include/raft/sparse/selection/connect_components.cuh @@ -0,0 +1,471 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include + +#include + +namespace raft { +namespace linkage { + +/** + * \brief A key identifier paired with a corresponding value + */ +template +struct KeyValuePair { + typedef _Key Key; ///< Key data type + typedef _Value Value; ///< Value data type + + Key key; ///< Item key + Value value; ///< Item value + + /// Constructor + __host__ __device__ __forceinline__ KeyValuePair() {} + + /// Copy Constructor + __host__ __device__ __forceinline__ + KeyValuePair(cub::KeyValuePair<_Key, _Value> kvp) + : key(kvp.key), value(kvp.value) {} + + /// Constructor + __host__ __device__ __forceinline__ KeyValuePair(Key const &key, + Value const &value) + : key(key), value(value) {} + + /// Inequality operator + __host__ __device__ __forceinline__ bool operator!=(const KeyValuePair &b) { + return (value != b.value) || (key != b.key); + } +}; + +/** + * Functor with reduction ops for performing fused 1-nn + * computation and guaranteeing only cross-component + * neighbors are considered. + * @tparam value_idx + * @tparam value_t + */ +template +struct FixConnectivitiesRedOp { + value_idx *colors; + value_idx m; + + FixConnectivitiesRedOp(value_idx *colors_, value_idx m_) + : colors(colors_), m(m_){}; + + typedef typename cub::KeyValuePair KVP; + DI void operator()(value_idx rit, KVP *out, const KVP &other) { + if (rit < m && other.value < out->value && + colors[rit] != colors[other.key]) { + out->key = other.key; + out->value = other.value; + } + } + + DI KVP operator()(value_idx rit, const KVP &a, const KVP &b) { + if (rit < m && a.value < b.value && colors[rit] != colors[a.key]) { + return a; + } else + return b; + } + + DI void init(value_t *out, value_t maxVal) { *out = maxVal; } + DI void init(KVP *out, value_t maxVal) { + out->key = -1; + out->value = maxVal; + } +}; + +/** + * Count the unique vertices adjacent to each component. + * This is essentially a count_unique_by_key. + */ +template +__global__ void count_components_by_color_kernel(value_idx *out_indptr, + const value_idx *colors_indptr, + const value_idx *colors_nn, + value_idx n_colors) { + value_idx tid = threadIdx.x; + value_idx row = blockIdx.x; + + __shared__ extern value_idx count_smem[]; + + value_idx start_offset = colors_indptr[row]; + value_idx stop_offset = colors_indptr[row + 1]; + + for (value_idx i = tid; i < n_colors; i += blockDim.x) { + count_smem[i] = 0; + } + + __syncthreads(); + + for (value_idx i = tid; i < (stop_offset - start_offset); i += blockDim.x) { + count_smem[colors_nn[start_offset + i]] = 1; + } + + __syncthreads(); + + for (value_idx i = tid; i < n_colors; i += blockDim.x) { + // TODO: Warp-level reduction + atomicAdd(out_indptr + row, count_smem[i] > 0); + } +} + +/** + * Compute indptr for the min set of unique components that neighbor the components + * of each source vertex + * @tparam value_idx + * @param[out] out_indptr output indptr + * @param[in] colors_indptr indptr of components for each source vertex + * @param[in] colors_nn array of components for the 1-nn around each source vertex + * @param[in] n_colors number of components + * @param[in] stream cuda stream for which to order cuda operations + */ +template +void count_components_by_color(value_idx *out_indptr, + const value_idx *colors_indptr, + const value_idx *colors_nn, value_idx n_colors, + cudaStream_t stream) { + count_components_by_color_kernel<<>>( + out_indptr, colors_indptr, colors_nn, n_colors); +} + +/** + * colors_nn is not assumed to be sorted wrt colors_indptr + * so we need to perform atomic reductions in each thread. + */ +template +__global__ void min_components_by_color_kernel( + value_idx *out_cols, value_t *out_vals, value_idx *out_rows, + const value_idx *out_indptr, const value_idx *colors_indptr, + const value_idx *colors_nn, const value_idx *indices, + const cub::KeyValuePair *kvp, value_idx n_colors) { + __shared__ extern char min_smem[]; + + int *mutex = (int *)min_smem; + + cub::KeyValuePair *min = + (cub::KeyValuePair *)(mutex + n_colors); + value_idx *src_inds = (value_idx *)(min + n_colors); + + value_idx start_offset = colors_indptr[blockIdx.x]; + value_idx stop_offset = colors_indptr[blockIdx.x + 1]; + + // initialize + for (value_idx i = threadIdx.x; i < n_colors; i += blockDim.x) { + mutex[i] = 0; + auto skvp = min + i; + skvp->key = -1; + skvp->value = std::numeric_limits::max(); + } + + __syncthreads(); + + for (value_idx i = threadIdx.x; i < (stop_offset - start_offset); + i += blockDim.x) { + value_idx new_color = colors_nn[start_offset + i]; + while (atomicCAS(mutex + new_color, 0, 1) == 1) + ; + __threadfence(); + auto cur_kvp = kvp[start_offset + i]; + if (cur_kvp.value < min[new_color].value) { + src_inds[new_color] = indices[start_offset + i]; + min[new_color].key = cur_kvp.key; + min[new_color].value = cur_kvp.value; + } + __threadfence(); + atomicCAS(mutex + new_color, 1, 0); + } + + __syncthreads(); + + // printf("block %d thread %d did final sync\n", blockIdx.x, threadIdx.x); + + value_idx out_offset = out_indptr[blockIdx.x]; + + // TODO: Do this across threads, using an atomic counter for each color + if (threadIdx.x == 0) { + value_idx cur_offset = 0; + + for (value_idx i = 0; i < n_colors; i++) { + auto min_color = min[i]; + if (min_color.key > -1) { + out_rows[out_offset + cur_offset] = src_inds[i]; + out_cols[out_offset + cur_offset] = min_color.key; + out_vals[out_offset + cur_offset] = min_color.value; + cur_offset += 1; + } + } + } +} + +/** + * Computes the min set of unique components that neighbor the + * components of each source vertex. + * @tparam value_idx + * @tparam value_t + * @param[out] coo output edge list + * @param[in] out_indptr output indptr for ordering edge list + * @param[in] colors_indptr indptr of source components + * @param[in] colors_nn components of nearest neighbors to each source component + * @param[in] indices indices of source vertices for each component + * @param[in] kvp indices and distances of each destination vertex for each component + * @param[in] n_colors number of components + * @param[in] stream cuda stream for which to order cuda operations + */ +template +void min_components_by_color(raft::sparse::COO &coo, + const value_idx *out_indptr, + const value_idx *colors_indptr, + const value_idx *colors_nn, + const value_idx *indices, + const cub::KeyValuePair *kvp, + value_idx n_colors, cudaStream_t stream) { + int smem_bytes = (n_colors * sizeof(int)) + (n_colors * sizeof(kvp)) + + ((n_colors + 1) * sizeof(value_idx)); + + min_components_by_color_kernel<<>>( + coo.cols(), coo.vals(), coo.rows(), out_indptr, colors_indptr, colors_nn, + indices, kvp, n_colors); +} + +/** + * Gets max maximum value (max number of components) from array of + * components. Note that this does not assume the components are + * drawn from a monotonically increasing set. + * @tparam value_idx + * @param[in] colors array of components + * @param[in] n_rows size of components array + * @param[in] stream cuda stream for which to order cuda operations + * @return total number of components + */ +template +value_idx get_n_components(value_idx *colors, size_t n_rows, + cudaStream_t stream) { + thrust::device_ptr t_colors = thrust::device_pointer_cast(colors); + return *(thrust::max_element(thrust::cuda::par.on(stream), t_colors, + t_colors + n_rows)) + + 1; +} + +/** + * Build CSR indptr array for sorted edge list mapping components of source + * vertices to the components of their nearest neighbor vertices + * @tparam value_idx + * @param[out] degrees output indptr array + * @param[in] components_indptr indptr of original CSR array of components + * @param[in] nn_components indptr of nearest neighbors CSR array of components + * @param[in] n_components size of nn_components + * @param[in] stream cuda stream for which to order cuda operations + */ +template +void build_output_colors_indptr(value_idx *degrees, + const value_idx *components_indptr, + const value_idx *nn_components, + value_idx n_components, cudaStream_t stream) { + CUDA_CHECK(cudaMemsetAsync(degrees, 0, (n_components + 1) * sizeof(value_idx), + stream)); + + /** + * Create COO array by first computing CSR indptr w/ degrees of each + * color followed by COO row/col/val arrays. + */ + // map each component to a separate warp, perform warp reduce by key to find + // number of unique components in output. + + count_components_by_color(degrees, components_indptr, nn_components, + n_components, stream); + + thrust::device_ptr t_degrees = + thrust::device_pointer_cast(degrees); + thrust::exclusive_scan(thrust::cuda::par.on(stream), t_degrees, + t_degrees + n_components + 1, t_degrees); +} + +/** + * Functor to look up a component for a vertex + * @tparam value_idx + * @tparam value_t + */ +template +struct LookupColorOp { + value_idx *colors; + + LookupColorOp(value_idx *colors_) : colors(colors_) {} + + DI value_idx operator()(const cub::KeyValuePair &kvp) { + return colors[kvp.key]; + } +}; + +/** + * Compute the cross-component 1-nearest neighbors for each row in X using + * the given array of components + * @tparam value_idx + * @tparam value_t + * @param[out] kvp mapping of closest neighbor vertex and distance for each vertex in the given array of components + * @param[out] nn_colors components of nearest neighbors for each vertex + * @param[in] colors components of each vertex + * @param[in] X original dense data + * @param[in] n_rows number of rows in original dense data + * @param[in] n_cols number of columns in original dense data + * @param[in] d_alloc device allocator to use + * @param[in] stream cuda stream for which to order cuda operations + */ +template +void perform_1nn(cub::KeyValuePair *kvp, + value_idx *nn_colors, value_idx *colors, const value_t *X, + size_t n_rows, size_t n_cols, + std::shared_ptr d_alloc, + cudaStream_t stream) { + rmm::device_uvector workspace(n_rows, stream); + rmm::device_uvector x_norm(n_rows, stream); + + raft::linalg::rowNorm(x_norm.data(), X, n_cols, n_rows, raft::linalg::L2Norm, + true, stream); + + FixConnectivitiesRedOp red_op(colors, n_rows); + raft::distance::fusedL2NN, + value_idx>(kvp, X, X, x_norm.data(), x_norm.data(), + n_rows, n_rows, n_cols, workspace.data(), + red_op, red_op, true, true, stream); + + LookupColorOp extract_colors_op(colors); + thrust::transform(thrust::cuda::par.on(stream), kvp, kvp + n_rows, nn_colors, + extract_colors_op); +} + +/** + * Sort nearest neighboring components wrt component of source vertices + * @tparam value_idx + * @tparam value_t + * @param[inout] colors components array of source vertices + * @param[inout] nn_colors nearest neighbors components array + * @param[inout] kvp nearest neighbor source vertex / distance array + * @param[inout] src_indices array of source vertex indices which will become arg_sort + * indices + * @param n_rows number of components in `colors` + * @param stream stream for which to order CUDA operations + */ +template +void sort_by_color(value_idx *colors, value_idx *nn_colors, + cub::KeyValuePair *kvp, + value_idx *src_indices, size_t n_rows, cudaStream_t stream) { + thrust::counting_iterator arg_sort_iter(0); + thrust::copy(thrust::cuda::par.on(stream), arg_sort_iter, + arg_sort_iter + n_rows, src_indices); + + auto keys = thrust::make_zip_iterator(thrust::make_tuple(colors)); + auto vals = thrust::make_zip_iterator( + thrust::make_tuple((raft::linkage::KeyValuePair *)kvp, + src_indices, nn_colors)); + + // get all the colors in contiguous locations so we can map them to warps. + thrust::sort_by_key(thrust::cuda::par.on(stream), keys, keys + n_rows, vals); +} + +/** + * Connects the components of an otherwise unconnected knn graph + * by computing a 1-nn to neighboring components of each data point + * (e.g. component(nn) != component(self)) and reducing the results to + * include the set of smallest destination components for each source + * component. The result will not necessarily contain + * n_components^2 - n_components number of elements because many components + * will likely not be contained in the neighborhoods of 1-nns. + * @tparam value_idx + * @tparam value_t + * @param[in] handle raft handle + * @param[out] out output edge list containing nearest cross-component + * edges. + * @param[in] X original (row-major) dense matrix for which knn graph should be constructed. + * @param[in] colors array containing component number for each row of X + * @param n_rows number of rows in X + * @param n_cols number of cols in X + */ +template +void connect_components(const raft::handle_t &handle, + raft::sparse::COO &out, + const value_t *X, value_idx *colors, size_t n_rows, + size_t n_cols) { + auto d_alloc = handle.get_device_allocator(); + auto stream = handle.get_stream(); + + value_idx n_components = get_n_components(colors, n_rows, stream); + + /** + * First compute 1-nn for all colors where the color of each data point + * is guaranteed to be != color of its nearest neighbor. + */ + rmm::device_uvector nn_colors(n_rows, stream); + rmm::device_uvector> temp_inds_dists( + n_rows, stream); + rmm::device_uvector src_indices(n_rows, stream); + rmm::device_uvector color_neigh_degrees(n_components + 1, stream); + rmm::device_uvector colors_indptr(n_components + 1, stream); + + perform_1nn(temp_inds_dists.data(), nn_colors.data(), colors, X, n_rows, + n_cols, d_alloc, stream); + + /** + * Sort data points by color (neighbors are not sorted) + */ + // max_color + 1 = number of connected components + // sort nn_colors by key w/ original colors + sort_by_color(colors, nn_colors.data(), temp_inds_dists.data(), + src_indices.data(), n_rows, stream); + + // create an indptr array for newly sorted colors + raft::sparse::convert::sorted_coo_to_csr(colors, n_rows, colors_indptr.data(), + n_components + 1, d_alloc, stream); + + // create output degree array for closest components per row + build_output_colors_indptr(color_neigh_degrees.data(), colors_indptr.data(), + nn_colors.data(), n_components, stream); + + value_idx nnz; + raft::update_host(&nnz, color_neigh_degrees.data() + n_components, 1, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + raft::sparse::COO min_edges(d_alloc, stream, nnz); + min_components_by_color(min_edges, color_neigh_degrees.data(), + colors_indptr.data(), nn_colors.data(), + src_indices.data(), temp_inds_dists.data(), + n_components, stream); + + // symmetrize + raft::sparse::linalg::symmetrize(handle, min_edges.rows(), min_edges.cols(), + min_edges.vals(), n_rows, n_rows, nnz, out); +} + +}; // end namespace linkage +}; // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/selection/knn_graph.cuh b/cpp/include/raft/sparse/selection/knn_graph.cuh new file mode 100644 index 0000000000..a78fb8d0f6 --- /dev/null +++ b/cpp/include/raft/sparse/selection/knn_graph.cuh @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include + +#include + +namespace raft { +namespace sparse { +namespace selection { + +/** + * Fills indices array of pairwise distance array + * @tparam value_idx + * @param indices + * @param m + */ +template +__global__ void fill_indices(value_idx *indices, size_t m, size_t nnz) { + value_idx tid = (blockIdx.x * blockDim.x) + threadIdx.x; + if (tid >= nnz) return; + value_idx v = tid / m; + indices[tid] = v; +} + +template +value_idx build_k(value_idx n_samples, int c) { + // from "kNN-MST-Agglomerative: A fast & scalable graph-based data clustering + // approach on GPU" + return min(n_samples, + max((value_idx)2, (value_idx)floor(log2(n_samples)) + c)); +} + +template +__global__ void conv_indices_kernel(in_t *inds, out_t *out, size_t nnz) { + size_t tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid >= nnz) return; + out_t v = inds[tid]; + out[tid] = v; +} + +template +void conv_indices(in_t *inds, out_t *out, size_t size, cudaStream_t stream) { + size_t blocks = ceildiv(size, (size_t)tpb); + conv_indices_kernel<<>>(inds, out, size); +} + +/** + * Constructs a (symmetrized) knn graph edge list from + * dense input vectors. + * + * Note: The resulting KNN graph is not guaranteed to be connected. + * + * @tparam value_idx + * @tparam value_t + * @param[in] handle raft handle + * @param[in] X dense matrix of input data samples and observations + * @param[in] m number of data samples (rows) in X + * @param[in] n number of observations (columns) in X + * @param[in] metric distance metric to use when constructing neighborhoods + * @param[out] out output edge list + * @param c + */ +template +void knn_graph(const handle_t &handle, const value_t *X, size_t m, size_t n, + distance::DistanceType metric, + raft::sparse::COO &out, int c = 15) { + int k = build_k(m, c); + + printf("K=%d\n", k); + + auto d_alloc = handle.get_device_allocator(); + auto stream = handle.get_stream(); + + size_t nnz = m * k; + + rmm::device_uvector rows(nnz, stream); + rmm::device_uvector indices(nnz, stream); + rmm::device_uvector data(nnz, stream); + + size_t blocks = ceildiv(nnz, (size_t)256); + fill_indices<<>>(rows.data(), k, nnz); + + std::vector inputs; + inputs.push_back(const_cast(X)); + + std::vector sizes; + sizes.push_back(m); + + // This is temporary. Once faiss is updated, we should be able to + // pass value_idx through to knn. + rmm::device_uvector int64_indices(nnz, stream); + + uint32_t knn_start = curTimeMillis(); + raft::spatial::knn::brute_force_knn( + handle, inputs, sizes, n, const_cast(X), m, int64_indices.data(), + data.data(), k, true, true, nullptr, metric); + + // convert from current knn's 64-bit to 32-bit. + conv_indices(int64_indices.data(), indices.data(), nnz, stream); + + raft::sparse::linalg::symmetrize(handle, rows.data(), indices.data(), + data.data(), m, k, nnz, out); +} + +}; // namespace selection +}; // namespace sparse +}; // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp b/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp index e686fff587..ff0df5f6e5 100644 --- a/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp +++ b/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp @@ -330,6 +330,8 @@ void brute_force_knn_impl( userStream, trans.data()); } + raft::print_device_vector("before sqrt", res_D, n * k, std::cout); + // Perform necessary post-processing if ((m == faiss::MetricType::METRIC_L2 || m == faiss::MetricType::METRIC_Lp) && @@ -344,6 +346,10 @@ void brute_force_knn_impl( [p] __device__(float input) { return powf(input, p); }, userStream); } + CUDA_CHECK(cudaStreamSynchronize(userStream)); + + raft::print_device_vector("after sqrt", res_D, n * k, std::cout); + query_metric_processor->revert(search_items); query_metric_processor->postprocess(out_D); for (size_t i = 0; i < input.size(); i++) { diff --git a/cpp/test/sparse/connect_components.cu b/cpp/test/sparse/connect_components.cu new file mode 100644 index 0000000000..8d264252e0 --- /dev/null +++ b/cpp/test/sparse/connect_components.cu @@ -0,0 +1,564 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../test_utils.h" + +namespace raft { +namespace sparse { + +using namespace std; + +template +struct ConnectComponentsInputs { + value_idx n_row; + value_idx n_col; + std::vector data; + + std::vector expected_labels; + + int n_clusters; + + int c; +}; + +template +class ConnectComponentsTest : public ::testing::TestWithParam< + ConnectComponentsInputs> { + protected: + void basicTest() { + raft::handle_t handle; + + auto d_alloc = handle.get_device_allocator(); + auto stream = handle.get_stream(); + + params = ::testing::TestWithParam< + ConnectComponentsInputs>::GetParam(); + + out_edges = new raft::sparse::COO( + handle.get_device_allocator(), handle.get_stream()); + + rmm::device_uvector data(params.n_row * params.n_col, + handle.get_stream()); + + // Allocate result labels and expected labels on device + raft::allocate(labels, params.n_row); + raft::allocate(labels_ref, params.n_row); + + raft::copy(data.data(), params.data.data(), data.size(), + handle.get_stream()); + raft::copy(labels_ref, params.expected_labels.data(), params.n_row, + handle.get_stream()); + + rmm::device_uvector indptr(params.n_row + 1, stream); + + /** + * 1. Construct knn graph + */ + raft::sparse::COO knn_graph_coo(d_alloc, stream); + + raft::sparse::selection::knn_graph( + handle, data.data(), params.n_row, params.n_col, + raft::distance::DistanceType::L2Unexpanded, knn_graph_coo, params.c); + + raft::sparse::convert::sorted_coo_to_csr(&knn_graph_coo, indptr.data(), + d_alloc, stream); + + value_idx max_offset = 0; + raft::update_host(&max_offset, indptr.data() + (params.n_row - 1), 1, + stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + max_offset += (knn_graph_coo.nnz - max_offset); + + raft::update_device(indptr.data() + params.n_row, &max_offset, 1, stream); + + /** + * 2. Construct MST, sorted by weights + */ + rmm::device_uvector colors(params.n_row, stream); + + auto mst_coo = raft::mst::mst( + handle, indptr.data(), knn_graph_coo.cols(), knn_graph_coo.vals(), + params.n_row, knn_graph_coo.nnz, colors.data(), stream, false); + + raft::print_device_vector("colors", colors.data(), colors.size(), + std::cout); + + /** + * 3. connect_components to fix connectivities + */ + raft::linkage::connect_components( + handle, *out_edges, data.data(), colors.data(), params.n_row, + params.n_col); + + int final_nnz = out_edges->nnz + mst_coo.n_edges; + + mst_coo.src.resize(final_nnz, stream); + mst_coo.dst.resize(final_nnz, stream); + mst_coo.weights.resize(final_nnz, stream); + + printf("New nnz: %d\n", final_nnz); + + /** + * Construct final edge list + */ + raft::copy_async(mst_coo.src.data() + mst_coo.n_edges, out_edges->rows(), + out_edges->nnz, stream); + raft::copy_async(mst_coo.dst.data() + mst_coo.n_edges, out_edges->cols(), + out_edges->nnz, stream); + raft::copy_async(mst_coo.weights.data() + mst_coo.n_edges, + out_edges->vals(), out_edges->nnz, stream); + + raft::sparse::COO final_coo(d_alloc, stream); + raft::sparse::linalg::symmetrize( + handle, mst_coo.src.data(), mst_coo.dst.data(), mst_coo.weights.data(), + params.n_row, params.n_col, final_nnz, final_coo); + + rmm::device_uvector indptr2(params.n_row + 1, stream); + + raft::sparse::convert::sorted_coo_to_csr(final_coo.rows(), final_coo.nnz, + indptr2.data(), params.n_row, + d_alloc, stream); + + max_offset = 0; + raft::update_host(&max_offset, indptr2.data() + (params.n_row - 1), 1, + stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + max_offset += (final_nnz - max_offset); + + raft::update_device(indptr2.data() + params.n_row, &max_offset, 1, stream); + + auto output_mst = raft::mst::mst( + handle, indptr2.data(), final_coo.cols(), final_coo.vals(), params.n_row, + final_coo.nnz, colors.data(), stream, false, true); + + CUDA_CHECK(cudaStreamSynchronize(stream)); + + printf("output edges: %d\n", output_mst.n_edges); + + final_edges = output_mst.n_edges; + } + + void SetUp() override { basicTest(); } + + void TearDown() override { + // CUDA_CHECK(cudaFree(labels)); + // CUDA_CHECK(cudaFree(labels_ref)); + } + + protected: + ConnectComponentsInputs params; + value_idx *labels, *labels_ref; + raft::sparse::COO *out_edges; + + value_idx final_edges; +}; + +const std::vector> fix_conn_inputsf2 = { + // Test n_clusters == n_points + {10, + 5, + {0.21390334, 0.50261639, 0.91036676, 0.59166485, 0.71162682, 0.10248392, + 0.77782677, 0.43772379, 0.4035871, 0.3282796, 0.47544681, 0.59862974, + 0.12319357, 0.06239463, 0.28200272, 0.1345717, 0.50498218, 0.5113505, + 0.16233086, 0.62165332, 0.42281548, 0.933117, 0.41386077, 0.23264562, + 0.73325968, 0.37537541, 0.70719873, 0.14522645, 0.73279625, 0.9126674, + 0.84854131, 0.28890216, 0.85267903, 0.74703138, 0.83842071, 0.34942792, + 0.27864171, 0.70911132, 0.21338564, 0.32035554, 0.73788331, 0.46926692, + 0.57570162, 0.42559178, 0.87120209, 0.22734951, 0.01847905, 0.75549396, + 0.76166195, 0.66613745}, + {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, + 10, + -1}, + // Test n_points == 100 + {100, + 10, + {6.26168372e-01, 9.30437651e-01, 6.02450208e-01, + 2.73025296e-01, 9.53050619e-01, 3.32164396e-01, + 6.88942598e-01, 5.79163537e-01, 6.70341547e-01, + 2.70140602e-02, 9.30429671e-01, 7.17721157e-01, + 9.89948537e-01, 7.75253347e-01, 1.34491522e-02, + 2.48522428e-02, 3.51413378e-01, 7.64405834e-01, + 7.86373507e-01, 7.18748577e-01, 8.66998621e-01, + 6.80316582e-01, 2.51288712e-01, 4.91078420e-01, + 3.76246281e-01, 4.86828710e-01, 5.67464772e-01, + 5.30734742e-01, 8.99478296e-01, 7.66699088e-01, + 9.49339111e-01, 3.55248484e-01, 9.06046929e-01, + 4.48407772e-01, 6.96395305e-01, 2.44277335e-01, + 7.74840000e-01, 5.21046603e-01, 4.66423971e-02, + 5.12019638e-02, 8.95019614e-01, 5.28956953e-01, + 4.31536306e-01, 5.83857744e-01, 4.41787364e-01, + 4.68656523e-01, 5.73971433e-01, 6.79989654e-01, + 3.19650588e-01, 6.12579596e-01, 6.49126442e-02, + 8.39131142e-01, 2.85252117e-01, 5.84848929e-01, + 9.46507115e-01, 8.58440748e-01, 3.61528940e-01, + 2.44215959e-01, 3.80101125e-01, 4.57128957e-02, + 8.82216988e-01, 8.31498633e-01, 7.23474381e-01, + 7.75788607e-01, 1.40864146e-01, 6.62092382e-01, + 5.13985168e-01, 3.00686418e-01, 8.70109949e-01, + 2.43187753e-01, 2.89391938e-01, 2.84214238e-01, + 8.70985521e-01, 8.77491176e-01, 6.72537226e-01, + 3.30929686e-01, 1.85934324e-01, 9.16222614e-01, + 6.18239142e-01, 2.64768597e-01, 5.76145451e-01, + 8.62961369e-01, 6.84757925e-01, 7.60549082e-01, + 1.27645356e-01, 4.51004673e-01, 3.92292980e-01, + 4.63170803e-01, 4.35449330e-02, 2.17583404e-01, + 5.71832605e-02, 2.06763039e-01, 3.70116249e-01, + 2.09750028e-01, 6.17283019e-01, 8.62549231e-01, + 9.84156240e-02, 2.66249156e-01, 3.87635103e-01, + 2.85591012e-02, 4.24826068e-01, 4.45795088e-01, + 6.86227676e-01, 1.08848960e-01, 5.96731841e-02, + 3.71770228e-01, 1.91548833e-01, 6.95136078e-01, + 9.00700636e-01, 8.76363105e-01, 2.67334632e-01, + 1.80619709e-01, 7.94060419e-01, 1.42854171e-02, + 1.09372387e-01, 8.74028108e-01, 6.46403232e-01, + 4.86588834e-01, 5.93446175e-02, 6.11886291e-01, + 8.83865057e-01, 3.15879821e-01, 2.27043992e-01, + 9.76764951e-01, 6.15620336e-01, 9.76199360e-01, + 2.40548962e-01, 3.21795663e-01, 8.75087904e-02, + 8.11234663e-01, 6.96070480e-01, 8.12062321e-01, + 1.21958818e-01, 3.44348628e-02, 8.72630414e-01, + 3.06162776e-01, 1.76043529e-02, 9.45894971e-01, + 5.33896401e-01, 6.21642973e-01, 4.93062535e-01, + 4.48984262e-01, 2.24560379e-01, 4.24052195e-02, + 4.43447610e-01, 8.95646149e-01, 6.05220676e-01, + 1.81840491e-01, 9.70831206e-01, 2.12563586e-02, + 6.92582693e-01, 7.55946922e-01, 7.95086143e-01, + 6.05328941e-01, 3.99350764e-01, 4.32846636e-01, + 9.81114529e-01, 4.98266428e-01, 6.37127930e-03, + 1.59085889e-01, 6.34682067e-05, 5.59429440e-01, + 7.38827633e-01, 8.93214770e-01, 2.16494306e-01, + 9.35430573e-02, 4.75665868e-02, 7.80503518e-01, + 7.86240041e-01, 7.06854594e-01, 2.13725879e-02, + 7.68246091e-01, 4.50234808e-01, 5.21231104e-01, + 5.01989826e-03, 4.22081572e-02, 1.65337732e-01, + 8.54134740e-01, 4.99430262e-01, 8.94525601e-01, + 1.14028379e-01, 3.69739861e-01, 1.32955599e-01, + 2.65563824e-01, 2.52811151e-01, 1.44792843e-01, + 6.88449594e-01, 4.44921417e-01, 8.23296587e-01, + 1.93266317e-01, 1.19033309e-01, 1.36368966e-01, + 3.42600285e-01, 5.64505195e-01, 5.57594559e-01, + 7.44257892e-01, 8.38231569e-02, 4.11548847e-01, + 3.21010077e-01, 8.55081359e-01, 4.30105779e-01, + 1.16229135e-01, 9.87731964e-02, 3.14712335e-01, + 4.50880592e-01, 2.72289598e-01, 6.31615256e-01, + 8.97432958e-01, 4.44764250e-01, 8.03776440e-01, + 2.68767748e-02, 2.43374608e-01, 4.02141103e-01, + 4.98881209e-01, 5.33173003e-01, 8.82890436e-01, + 7.16149148e-01, 4.19664401e-01, 2.29335357e-01, + 2.88637806e-01, 3.44696803e-01, 6.78171906e-01, + 5.69849716e-01, 5.86454477e-01, 3.54474989e-01, + 9.03876540e-01, 6.45980000e-01, 6.34887593e-01, + 7.88039746e-02, 2.04814126e-01, 7.82251754e-01, + 2.43147074e-01, 7.50951808e-01, 1.72799092e-02, + 2.95349590e-01, 6.57991826e-01, 8.81214312e-01, + 5.73970708e-01, 2.77610881e-01, 1.82155097e-01, + 7.69797417e-02, 6.44792402e-01, 9.46950998e-01, + 7.73064845e-01, 6.04733624e-01, 5.80094567e-01, + 1.67498426e-01, 2.66514296e-01, 6.50140368e-01, + 1.91170299e-01, 2.08752199e-01, 3.01664091e-01, + 9.85033484e-01, 2.92909152e-01, 8.65816607e-01, + 1.85222119e-01, 2.28814559e-01, 1.34286382e-02, + 2.89234322e-01, 8.18668708e-01, 4.71706924e-01, + 9.23199803e-01, 2.80879188e-01, 1.47319284e-01, + 4.13915748e-01, 9.31274932e-02, 6.66322195e-01, + 9.66953974e-01, 3.19405786e-01, 6.69486551e-01, + 5.03096313e-02, 6.95225201e-01, 5.78469859e-01, + 6.29481655e-01, 1.39252534e-01, 1.22564968e-01, + 6.80663678e-01, 6.34607157e-01, 6.42765834e-01, + 1.57127410e-02, 2.92132086e-01, 5.24423878e-01, + 4.68676824e-01, 2.86003928e-01, 7.18608322e-01, + 8.95617933e-01, 5.48844309e-01, 1.74517278e-01, + 5.24379196e-01, 2.13526524e-01, 5.88375435e-01, + 9.88560185e-01, 4.17435771e-01, 6.14438688e-01, + 9.53760881e-01, 5.27151288e-01, 7.03017278e-01, + 3.44448559e-01, 4.47059676e-01, 2.83414901e-01, + 1.98979011e-01, 4.24917361e-01, 5.73172761e-01, + 2.32398853e-02, 1.65887230e-01, 4.05552785e-01, + 9.29665524e-01, 2.26135696e-01, 9.20563384e-01, + 7.65259963e-01, 4.54820075e-01, 8.97710267e-01, + 3.78559302e-03, 9.15219382e-01, 3.55705698e-01, + 6.94905124e-01, 8.58540202e-01, 3.89790666e-01, + 2.49478206e-01, 7.93679304e-01, 4.75830027e-01, + 4.40425353e-01, 3.70579459e-01, 1.40578049e-01, + 1.70386675e-01, 7.04056121e-01, 4.85963102e-01, + 9.68450060e-01, 6.77178001e-01, 2.65934654e-01, + 2.58915007e-01, 6.70052890e-01, 2.61945109e-01, + 8.46207759e-01, 1.01928951e-01, 2.85611334e-01, + 2.45776933e-01, 2.66658783e-01, 3.71724077e-01, + 4.34319025e-01, 4.24407347e-01, 7.15417683e-01, + 8.07997684e-01, 1.64296275e-01, 6.01638065e-01, + 8.60606804e-02, 2.68719187e-01, 5.11764101e-01, + 9.75844338e-01, 7.81226782e-01, 2.20925515e-01, + 7.18135040e-01, 9.82395577e-01, 8.39160243e-01, + 9.08058083e-01, 6.88010677e-01, 8.14271847e-01, + 5.12460821e-01, 1.17311345e-01, 5.96075228e-01, + 9.17455497e-01, 2.12052706e-01, 7.04074603e-01, + 8.72872565e-02, 8.76047818e-01, 6.96235046e-01, + 8.54801557e-01, 2.49729159e-01, 9.76594604e-01, + 2.87386363e-01, 2.36461559e-02, 9.94075254e-01, + 4.25193986e-01, 7.61869994e-01, 5.13334255e-01, + 6.44711165e-02, 8.92156689e-01, 3.55235167e-01, + 1.08154647e-01, 8.78446825e-01, 2.43833016e-01, + 9.23071293e-01, 2.72724115e-01, 9.46631338e-01, + 3.74510294e-01, 4.08451278e-02, 9.78392777e-01, + 3.65079221e-01, 6.37199516e-01, 5.51144906e-01, + 5.25978080e-01, 1.42803678e-01, 4.05451674e-01, + 7.79788219e-01, 6.26009784e-01, 3.35249497e-01, + 1.43159543e-02, 1.80363779e-01, 5.05096904e-01, + 2.82619947e-01, 5.83561392e-01, 3.10951324e-01, + 8.73223968e-01, 4.38545619e-01, 4.81348800e-01, + 6.68497085e-01, 3.79345401e-01, 9.58832501e-01, + 1.89869550e-01, 2.34083070e-01, 2.94066207e-01, + 5.74892667e-02, 6.92106828e-02, 9.61127686e-02, + 6.72650672e-02, 8.47345378e-01, 2.80916761e-01, + 7.32177357e-03, 9.80785961e-01, 5.73192225e-02, + 8.48781331e-01, 8.83225408e-01, 7.34398275e-01, + 7.70381941e-01, 6.20778343e-01, 8.96822048e-01, + 5.40732486e-01, 3.69704071e-01, 5.77305837e-01, + 2.08221827e-01, 7.34275341e-01, 1.06110900e-01, + 3.49496706e-01, 8.34948910e-01, 1.56403291e-02, + 6.78576376e-01, 8.96141268e-01, 5.94835119e-01, + 1.43943153e-01, 3.49618530e-01, 2.10440392e-01, + 3.46585620e-01, 1.05153093e-01, 3.45446174e-01, + 2.72177079e-01, 7.07946300e-01, 4.33717726e-02, + 3.31232203e-01, 3.91874320e-01, 4.76338141e-01, + 6.22777789e-01, 2.95989228e-02, 4.32855769e-01, + 7.61049310e-01, 3.63279149e-01, 9.47210350e-01, + 6.43721247e-01, 6.58025802e-01, 1.05247633e-02, + 5.29974442e-01, 7.30675767e-01, 4.30041079e-01, + 6.62634841e-01, 8.25936616e-01, 9.91253704e-01, + 6.79399281e-01, 5.44177006e-01, 7.52876048e-01, + 3.32139049e-01, 7.98732398e-01, 7.38865223e-01, + 9.16055132e-01, 6.11736493e-01, 9.63672879e-01, + 1.83778839e-01, 7.27558919e-02, 5.91602822e-01, + 3.25235484e-01, 2.34741217e-01, 9.52346277e-01, + 9.18556407e-01, 9.35373324e-01, 6.89209070e-01, + 2.56049054e-01, 6.17975395e-01, 7.82285691e-01, + 9.84983432e-01, 6.62322741e-01, 2.04144457e-01, + 3.98446577e-01, 1.38918297e-01, 3.05919921e-01, + 3.14043787e-01, 5.91072666e-01, 7.44703771e-01, + 8.92272567e-01, 9.78017873e-01, 9.01203161e-01, + 1.41526372e-01, 4.14878484e-01, 6.80683651e-01, + 5.01733152e-02, 8.14635389e-01, 2.27926375e-01, + 9.03269815e-01, 8.68443745e-01, 9.86939190e-01, + 7.40779486e-01, 2.61005311e-01, 3.19276232e-01, + 9.69509248e-01, 1.11908818e-01, 4.49198556e-01, + 1.27056715e-01, 3.84064823e-01, 5.14591811e-01, + 2.10747488e-01, 9.53884090e-01, 8.43167950e-01, + 4.51187972e-01, 3.75331782e-01, 6.23566461e-01, + 3.55290379e-01, 2.95705968e-01, 1.69622690e-01, + 1.42981830e-01, 2.72180991e-01, 9.46468040e-01, + 3.70932500e-01, 9.94292830e-01, 4.62587505e-01, + 7.14817405e-01, 2.45370540e-02, 3.00906377e-01, + 5.75768304e-01, 9.71448393e-01, 6.95574827e-02, + 3.93693854e-01, 5.29306116e-01, 5.04694554e-01, + 6.73797120e-02, 6.76596969e-01, 5.50948898e-01, + 3.24909641e-01, 7.70337719e-01, 6.51842631e-03, + 3.03264879e-01, 7.61037886e-03, 2.72289601e-01, + 1.50502041e-01, 6.71103888e-02, 7.41503703e-01, + 1.92088941e-01, 2.19043977e-01, 9.09320161e-01, + 2.37993569e-01, 6.18107973e-02, 8.31447852e-01, + 2.23355609e-01, 1.84789435e-01, 4.16104518e-01, + 4.21573859e-01, 8.72446305e-02, 2.97294197e-01, + 4.50328256e-01, 8.72199917e-01, 2.51279916e-01, + 4.86219272e-01, 7.57071329e-01, 4.85655942e-01, + 1.06187277e-01, 4.92341327e-01, 1.46017513e-01, + 5.25421017e-01, 4.22637906e-01, 2.24685018e-01, + 8.72648431e-01, 5.54051490e-01, 1.80745062e-01, + 2.12756336e-01, 5.20883169e-01, 7.60363654e-01, + 8.30254678e-01, 5.00003328e-01, 4.69017439e-01, + 6.38105527e-01, 3.50638261e-02, 5.22217353e-02, + 9.06516882e-02, 8.52975842e-01, 1.19985883e-01, + 3.74926753e-01, 6.50302066e-01, 1.98875727e-01, + 6.28362507e-02, 4.32693501e-01, 3.10500685e-01, + 6.20732833e-01, 4.58503272e-01, 3.20790034e-01, + 7.91284868e-01, 7.93054570e-01, 2.93406765e-01, + 8.95399023e-01, 1.06441034e-01, 7.53085241e-02, + 8.67523104e-01, 1.47963482e-01, 1.25584706e-01, + 3.81545040e-02, 6.34338619e-01, 1.76368938e-02, + 5.75553531e-02, 5.31607516e-01, 2.63869588e-01, + 9.41945823e-01, 9.24028838e-02, 5.21496463e-01, + 7.74866558e-01, 5.65210610e-01, 7.28015327e-02, + 6.51963790e-01, 8.94727453e-01, 4.49571590e-01, + 1.29932405e-01, 8.64026259e-01, 9.92599934e-01, + 7.43721560e-01, 8.87300215e-01, 1.06369925e-01, + 8.11335531e-01, 7.87734900e-01, 9.87344678e-01, + 5.32502820e-01, 4.42612382e-01, 9.64041183e-01, + 1.66085871e-01, 1.12937664e-01, 5.24423470e-01, + 6.54689333e-01, 4.59119726e-01, 5.22774091e-01, + 3.08722276e-02, 6.26979315e-01, 4.49754105e-01, + 8.07495757e-01, 2.34199499e-01, 1.67765675e-01, + 9.22168418e-01, 3.73210378e-01, 8.04432575e-01, + 5.61890354e-01, 4.47025593e-01, 6.43155678e-01, + 2.40407640e-01, 5.91631279e-01, 1.59369206e-01, + 7.75799090e-01, 8.32067212e-01, 5.59791576e-02, + 6.39105224e-01, 4.85274738e-01, 2.12630838e-01, + 2.81431312e-02, 7.16205363e-01, 6.83885011e-01, + 5.23869697e-01, 9.99418314e-01, 8.35331599e-01, + 4.69877463e-02, 6.74712562e-01, 7.99273684e-01, + 2.77001890e-02, 5.75809742e-01, 2.78513031e-01, + 8.36209905e-01, 7.25472379e-01, 4.87173943e-01, + 7.88311357e-01, 9.64676177e-01, 1.75752651e-01, + 4.98112580e-01, 8.08850418e-02, 6.40981131e-01, + 4.06647450e-01, 8.46539387e-01, 2.12620694e-01, + 9.11012851e-01, 8.25041445e-01, 8.90065575e-01, + 9.63626055e-01, 5.96689242e-01, 1.63372670e-01, + 4.51640148e-01, 3.43026542e-01, 5.80658851e-01, + 2.82327625e-01, 4.75535418e-01, 6.27760926e-01, + 8.46314115e-01, 9.61961932e-01, 3.19806094e-01, + 5.05508062e-01, 5.28102944e-01, 6.13045057e-01, + 7.44714938e-01, 1.50586073e-01, 7.91878033e-01, + 4.89839179e-01, 3.10496849e-01, 8.82309038e-01, + 2.86922314e-01, 4.84687559e-01, 5.20838630e-01, + 4.62955493e-01, 2.38185305e-01, 5.47259907e-02, + 7.10916137e-01, 7.31887202e-01, 6.25602317e-01, + 8.77741168e-01, 4.19881322e-01, 4.81222328e-01, + 1.28224501e-01, 2.46034010e-01, 3.34971854e-01, + 7.37216484e-01, 5.62134821e-02, 7.14089724e-01, + 9.85549393e-01, 4.66295827e-01, 3.08722434e-03, + 4.70237690e-01, 2.66524167e-01, 7.93875484e-01, + 4.54795911e-02, 8.09702944e-01, 1.47709735e-02, + 1.70082405e-01, 6.35905179e-01, 3.75379109e-01, + 4.30315011e-01, 3.15788760e-01, 5.58065230e-01, + 2.24643800e-01, 2.42142981e-01, 6.57283636e-01, + 3.34921891e-01, 1.26588975e-01, 7.68064155e-01, + 9.43856291e-01, 4.47518596e-01, 5.44453573e-01, + 9.95764932e-01, 7.16444391e-01, 8.51019765e-01, + 1.01179183e-01, 4.45473958e-01, 4.60327322e-01, + 4.96895844e-02, 4.72907738e-01, 5.58987444e-01, + 3.41027487e-01, 1.56175026e-01, 7.58283148e-01, + 6.83600909e-01, 2.14623396e-01, 3.27348880e-01, + 3.92517893e-01, 6.70418431e-01, 5.16440832e-01, + 8.63140348e-01, 5.73277464e-01, 3.46608058e-01, + 7.39396341e-01, 7.20852434e-01, 2.35653246e-02, + 3.89935659e-01, 7.53783745e-01, 6.34563528e-01, + 8.79339335e-01, 7.41599159e-02, 5.62433904e-01, + 6.15553852e-01, 4.56956324e-01, 5.20047447e-01, + 5.26845015e-02, 5.58471266e-01, 1.63632233e-01, + 5.38936665e-02, 6.49593683e-01, 2.56838748e-01, + 8.99035326e-01, 7.20847756e-01, 5.68954684e-01, + 7.43684755e-01, 5.70924238e-01, 3.82318724e-01, + 4.89328290e-01, 5.62208561e-01, 4.97540804e-02, + 4.18011085e-01, 6.88041565e-01, 2.16234653e-01, + 7.89548214e-01, 8.46136387e-01, 8.46816189e-01, + 1.73842353e-01, 6.11627842e-02, 8.44440559e-01, + 4.50646654e-01, 3.74785037e-01, 4.87196697e-01, + 4.56276448e-01, 9.13284391e-01, 4.15715464e-01, + 7.13597697e-01, 1.23641270e-02, 5.10031271e-01, + 4.74601930e-02, 2.55731159e-01, 3.22090006e-01, + 1.91165703e-01, 4.51170940e-01, 7.50843157e-01, + 4.42420576e-01, 4.25380660e-01, 4.50667257e-01, + 6.55689206e-01, 9.68257670e-02, 1.96528793e-01, + 8.97343028e-01, 4.99940904e-01, 6.65504083e-01, + 9.41828079e-01, 4.54397338e-01, 5.61893331e-01, + 5.09839880e-01, 4.53117514e-01, 8.96804127e-02, + 1.74888861e-01, 6.65641378e-01, 2.81668336e-01, + 1.89532742e-01, 5.61668382e-01, 8.68330157e-02, + 8.25092797e-01, 5.18106324e-01, 1.71904024e-01, + 3.68385523e-01, 1.62005436e-01, 7.48507399e-01, + 9.30274827e-01, 2.38198517e-01, 9.52222901e-01, + 5.23587800e-01, 6.94384557e-01, 1.09338652e-01, + 4.83356794e-01, 2.73050402e-01, 3.68027050e-01, + 5.92366466e-01, 1.83192289e-01, 8.60376029e-01, + 7.13926203e-01, 8.16750052e-01, 1.57890291e-01, + 6.25691951e-01, 5.24831646e-01, 1.73873797e-01, + 1.02429784e-01, 9.17488471e-01, 4.03584434e-01, + 9.31170884e-01, 2.79386137e-01, 8.77745206e-01, + 2.45200576e-01, 1.28896951e-01, 3.15713052e-01, + 5.27874291e-01, 2.16444335e-01, 7.03883817e-01, + 7.74738919e-02, 8.42422142e-01, 3.75598924e-01, + 3.51002411e-01, 6.22752776e-01, 4.82407943e-01, + 7.43107867e-01, 9.46182666e-01, 9.44344819e-01, + 3.28124763e-01, 1.06147431e-01, 1.65102684e-01, + 3.84060507e-01, 2.91057722e-01, 7.68173662e-02, + 1.03543651e-01, 6.76698940e-01, 1.43141994e-01, + 7.21342202e-01, 6.69471294e-03, 9.07298311e-01, + 5.57080171e-01, 8.10954489e-01, 4.11120526e-01, + 2.06407453e-01, 2.59590556e-01, 7.58512718e-01, + 5.79873897e-01, 2.92875650e-01, 2.83686529e-01, + 2.42829343e-01, 9.19323719e-01, 3.46832864e-01, + 3.58238858e-01, 7.42827585e-01, 2.05760059e-01, + 9.58438860e-01, 5.66326411e-01, 6.60292846e-01, + 5.61095078e-02, 6.79465531e-01, 7.05118513e-01, + 4.44713264e-01, 2.09732933e-01, 5.22732436e-01, + 1.74396512e-01, 5.29356748e-01, 4.38475687e-01, + 4.94036404e-01, 4.09785794e-01, 6.40025507e-01, + 5.79371821e-01, 1.57726118e-01, 6.04572263e-01, + 5.41072639e-01, 5.18847173e-01, 1.97093284e-01, + 8.91767002e-01, 4.29050835e-01, 8.25490570e-01, + 3.87699807e-01, 4.50705808e-01, 2.49371643e-01, + 3.36074898e-01, 9.29925118e-01, 6.65393649e-01, + 9.07275994e-01, 3.73075859e-01, 4.14044139e-03, + 2.37463702e-01, 2.25893784e-01, 2.46900245e-01, + 4.50350196e-01, 3.48618117e-01, 5.07193932e-01, + 5.23435142e-01, 8.13611417e-01, 8.92715622e-01, + 1.02623450e-01, 3.06088345e-01, 7.80461650e-01, + 2.21453645e-01, 2.01419652e-01, 2.84254457e-01, + 3.68286735e-01, 7.39358243e-01, 8.97879394e-01, + 9.81599566e-01, 7.56526442e-01, 7.37645545e-01, + 4.23976657e-02, 8.25922012e-01, 2.60956996e-01, + 2.90702065e-01, 8.98388344e-01, 3.03733299e-01, + 8.49071471e-01, 3.45835425e-01, 7.65458276e-01, + 5.68094872e-01, 8.93770930e-01, 9.93161641e-01, + 5.63368667e-02, 4.26548945e-01, 5.46745780e-01, + 5.75674571e-01, 7.94599487e-01, 7.18935553e-02, + 4.46492976e-01, 6.40240123e-01, 2.73246969e-01, + 2.00465968e-01, 1.30718835e-01, 1.92492005e-01, + 1.96617189e-01, 6.61271644e-01, 8.12687657e-01, + 8.66342445e-01 + + }, + {0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 4, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + 10, + -4}}; + +typedef ConnectComponentsTest ConnectComponentsTestF_Int; +TEST_P(ConnectComponentsTestF_Int, Result) { + /** + * Verify the src & dst vertices on each edge have different colors + */ + EXPECT_TRUE(final_edges == params.n_row - 1); +} + +INSTANTIATE_TEST_CASE_P(ConnectComponentsTest, ConnectComponentsTestF_Int, + ::testing::ValuesIn(fix_conn_inputsf2)); +}; // namespace sparse +}; // end namespace raft diff --git a/cpp/test/sparse/knn_graph.cu b/cpp/test/sparse/knn_graph.cu new file mode 100644 index 0000000000..ec41b32374 --- /dev/null +++ b/cpp/test/sparse/knn_graph.cu @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "../test_utils.h" + +#include +#include + +#include + +namespace raft { +namespace sparse { + +template +__global__ void assert_symmetry(value_idx *rows, value_idx *cols, value_t *vals, + value_idx nnz, value_idx *sum) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + + if (tid >= nnz) return; + + atomicAdd(sum, rows[tid]); + atomicAdd(sum, -1 * cols[tid]); +} + +template +struct KNNGraphInputs { + value_idx m; + value_idx n; + + std::vector X; + + int k = 2; +}; + +template +::std::ostream &operator<<(::std::ostream &os, + const KNNGraphInputs &dims) { + return os; +} + +template +class KNNGraphTest + : public ::testing::TestWithParam> { + void SetUp() override { + params = + ::testing::TestWithParam>::GetParam(); + + raft::handle_t handle; + + auto alloc = handle.get_device_allocator(); + stream = handle.get_stream(); + + out = new raft::sparse::COO(alloc, stream); + + allocate(X, params.X.size()); + + update_device(X, params.X.data(), params.X.size(), stream); + + raft::sparse::selection::knn_graph( + handle, X, params.m, params.n, raft::distance::DistanceType::L2Unexpanded, + *out); + + rmm::device_uvector sum(1, stream); + + CUDA_CHECK(cudaMemsetAsync(sum.data(), 0, 1 * sizeof(value_idx), stream)); + + /** + * Assert the knn graph is symmetric + */ + assert_symmetry<<nnz, 256), 256, 0, stream>>>( + out->rows(), out->cols(), out->vals(), out->nnz, sum.data()); + + raft::update_host(&sum_h, sum.data(), 1, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + } + + void TearDown() override { + CUDA_CHECK(cudaFree(X)); + + delete out; + } + + protected: + cudaStream_t stream; + + // input data + raft::sparse::COO *out; + + value_t *X; + + value_idx sum_h; + + KNNGraphInputs params; +}; + +const std::vector> knn_graph_inputs_fint = { + // Test n_clusters == n_points + {4, 2, {0, 100, 0.01, 0.02, 5000, 10000, -5, -2}, 2}}; + +typedef KNNGraphTest KNNGraphTestF_int; +TEST_P(KNNGraphTestF_int, Result) { + // nnz should not be larger than twice m * k + ASSERT_TRUE(out->nnz <= (params.m * params.k * 2)); + ASSERT_TRUE(sum_h == 0); +} + +INSTANTIATE_TEST_CASE_P(KNNGraphTest, KNNGraphTestF_int, + ::testing::ValuesIn(knn_graph_inputs_fint)); + +} // namespace sparse +} // namespace raft diff --git a/cpp/test/sparse/linkage.cu b/cpp/test/sparse/linkage.cu new file mode 100644 index 0000000000..fd1aca92ca --- /dev/null +++ b/cpp/test/sparse/linkage.cu @@ -0,0 +1,501 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../test_utils.h" + +namespace raft { + +using namespace std; + +template +struct LinkageInputs { + IdxT n_row; + IdxT n_col; + + std::vector data; + + std::vector expected_labels; + + int n_clusters; + + int c; +}; + +template +::std::ostream &operator<<(::std::ostream &os, + const LinkageInputs &dims) { + return os; +} + +template +class LinkageTest : public ::testing::TestWithParam> { + protected: + void basicTest() { + raft::handle_t handle; + + params = ::testing::TestWithParam>::GetParam(); + + rmm::device_uvector data(params.n_row * params.n_col, + handle.get_stream()); + + // Allocate result labels and expected labels on device + raft::allocate(labels, params.n_row); + raft::allocate(labels_ref, params.n_row); + + raft::copy(data.data(), params.data.data(), data.size(), + handle.get_stream()); + raft::copy(labels_ref, params.expected_labels.data(), params.n_row, + handle.get_stream()); + + raft::hierarchy::linkage_output out_arrs; + out_arrs.labels = labels; + + rmm::device_uvector out_children(params.n_row * 2, + handle.get_stream()); + + out_arrs.children = out_children.data(); + + raft::hierarchy::single_linkage< + IdxT, T, raft::hierarchy::LinkageDistance::KNN_GRAPH>( + handle, data.data(), params.n_row, params.n_col, + raft::distance::DistanceType::L2Unexpanded, &out_arrs, params.c, + params.n_clusters); + + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); + } + + void SetUp() override { basicTest(); } + + void TearDown() override { + CUDA_CHECK(cudaFree(labels)); + CUDA_CHECK(cudaFree(labels_ref)); + } + + protected: + LinkageInputs params; + IdxT *labels, *labels_ref; + + double score; +}; + +const std::vector> linkage_inputsf2 = { + // Test n_clusters == n_points + {10, + 5, + {0.21390334, 0.50261639, 0.91036676, 0.59166485, 0.71162682, 0.10248392, + 0.77782677, 0.43772379, 0.4035871, 0.3282796, 0.47544681, 0.59862974, + 0.12319357, 0.06239463, 0.28200272, 0.1345717, 0.50498218, 0.5113505, + 0.16233086, 0.62165332, 0.42281548, 0.933117, 0.41386077, 0.23264562, + 0.73325968, 0.37537541, 0.70719873, 0.14522645, 0.73279625, 0.9126674, + 0.84854131, 0.28890216, 0.85267903, 0.74703138, 0.83842071, 0.34942792, + 0.27864171, 0.70911132, 0.21338564, 0.32035554, 0.73788331, 0.46926692, + 0.57570162, 0.42559178, 0.87120209, 0.22734951, 0.01847905, 0.75549396, + 0.76166195, 0.66613745}, + {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, + 10, + -1}, + // // Test outlier points + {9, + 2, + {-1, -50, 3, 4, 5000, 10000, 1, 3, 4, 5, 0.000005, 0.00002, 2000000, 500000, + 10, 50, 30, 5}, + {6, 0, 5, 0, 0, 4, 3, 2, 1}, + 7, + -1}, + + // Test n_clusters == (n_points / 2) + {10, + 5, + {0.21390334, 0.50261639, 0.91036676, 0.59166485, 0.71162682, 0.10248392, + 0.77782677, 0.43772379, 0.4035871, 0.3282796, 0.47544681, 0.59862974, + 0.12319357, 0.06239463, 0.28200272, 0.1345717, 0.50498218, 0.5113505, + 0.16233086, 0.62165332, 0.42281548, 0.933117, 0.41386077, 0.23264562, + 0.73325968, 0.37537541, 0.70719873, 0.14522645, 0.73279625, 0.9126674, + 0.84854131, 0.28890216, 0.85267903, 0.74703138, 0.83842071, 0.34942792, + 0.27864171, 0.70911132, 0.21338564, 0.32035554, 0.73788331, 0.46926692, + 0.57570162, 0.42559178, 0.87120209, 0.22734951, 0.01847905, 0.75549396, + 0.76166195, 0.66613745}, + {1, 0, 4, 0, 0, 3, 2, 0, 2, 1}, + 5, + -1}, + + // Test n_points == 100 + {100, + 10, + {6.26168372e-01, 9.30437651e-01, 6.02450208e-01, + 2.73025296e-01, 9.53050619e-01, 3.32164396e-01, + 6.88942598e-01, 5.79163537e-01, 6.70341547e-01, + 2.70140602e-02, 9.30429671e-01, 7.17721157e-01, + 9.89948537e-01, 7.75253347e-01, 1.34491522e-02, + 2.48522428e-02, 3.51413378e-01, 7.64405834e-01, + 7.86373507e-01, 7.18748577e-01, 8.66998621e-01, + 6.80316582e-01, 2.51288712e-01, 4.91078420e-01, + 3.76246281e-01, 4.86828710e-01, 5.67464772e-01, + 5.30734742e-01, 8.99478296e-01, 7.66699088e-01, + 9.49339111e-01, 3.55248484e-01, 9.06046929e-01, + 4.48407772e-01, 6.96395305e-01, 2.44277335e-01, + 7.74840000e-01, 5.21046603e-01, 4.66423971e-02, + 5.12019638e-02, 8.95019614e-01, 5.28956953e-01, + 4.31536306e-01, 5.83857744e-01, 4.41787364e-01, + 4.68656523e-01, 5.73971433e-01, 6.79989654e-01, + 3.19650588e-01, 6.12579596e-01, 6.49126442e-02, + 8.39131142e-01, 2.85252117e-01, 5.84848929e-01, + 9.46507115e-01, 8.58440748e-01, 3.61528940e-01, + 2.44215959e-01, 3.80101125e-01, 4.57128957e-02, + 8.82216988e-01, 8.31498633e-01, 7.23474381e-01, + 7.75788607e-01, 1.40864146e-01, 6.62092382e-01, + 5.13985168e-01, 3.00686418e-01, 8.70109949e-01, + 2.43187753e-01, 2.89391938e-01, 2.84214238e-01, + 8.70985521e-01, 8.77491176e-01, 6.72537226e-01, + 3.30929686e-01, 1.85934324e-01, 9.16222614e-01, + 6.18239142e-01, 2.64768597e-01, 5.76145451e-01, + 8.62961369e-01, 6.84757925e-01, 7.60549082e-01, + 1.27645356e-01, 4.51004673e-01, 3.92292980e-01, + 4.63170803e-01, 4.35449330e-02, 2.17583404e-01, + 5.71832605e-02, 2.06763039e-01, 3.70116249e-01, + 2.09750028e-01, 6.17283019e-01, 8.62549231e-01, + 9.84156240e-02, 2.66249156e-01, 3.87635103e-01, + 2.85591012e-02, 4.24826068e-01, 4.45795088e-01, + 6.86227676e-01, 1.08848960e-01, 5.96731841e-02, + 3.71770228e-01, 1.91548833e-01, 6.95136078e-01, + 9.00700636e-01, 8.76363105e-01, 2.67334632e-01, + 1.80619709e-01, 7.94060419e-01, 1.42854171e-02, + 1.09372387e-01, 8.74028108e-01, 6.46403232e-01, + 4.86588834e-01, 5.93446175e-02, 6.11886291e-01, + 8.83865057e-01, 3.15879821e-01, 2.27043992e-01, + 9.76764951e-01, 6.15620336e-01, 9.76199360e-01, + 2.40548962e-01, 3.21795663e-01, 8.75087904e-02, + 8.11234663e-01, 6.96070480e-01, 8.12062321e-01, + 1.21958818e-01, 3.44348628e-02, 8.72630414e-01, + 3.06162776e-01, 1.76043529e-02, 9.45894971e-01, + 5.33896401e-01, 6.21642973e-01, 4.93062535e-01, + 4.48984262e-01, 2.24560379e-01, 4.24052195e-02, + 4.43447610e-01, 8.95646149e-01, 6.05220676e-01, + 1.81840491e-01, 9.70831206e-01, 2.12563586e-02, + 6.92582693e-01, 7.55946922e-01, 7.95086143e-01, + 6.05328941e-01, 3.99350764e-01, 4.32846636e-01, + 9.81114529e-01, 4.98266428e-01, 6.37127930e-03, + 1.59085889e-01, 6.34682067e-05, 5.59429440e-01, + 7.38827633e-01, 8.93214770e-01, 2.16494306e-01, + 9.35430573e-02, 4.75665868e-02, 7.80503518e-01, + 7.86240041e-01, 7.06854594e-01, 2.13725879e-02, + 7.68246091e-01, 4.50234808e-01, 5.21231104e-01, + 5.01989826e-03, 4.22081572e-02, 1.65337732e-01, + 8.54134740e-01, 4.99430262e-01, 8.94525601e-01, + 1.14028379e-01, 3.69739861e-01, 1.32955599e-01, + 2.65563824e-01, 2.52811151e-01, 1.44792843e-01, + 6.88449594e-01, 4.44921417e-01, 8.23296587e-01, + 1.93266317e-01, 1.19033309e-01, 1.36368966e-01, + 3.42600285e-01, 5.64505195e-01, 5.57594559e-01, + 7.44257892e-01, 8.38231569e-02, 4.11548847e-01, + 3.21010077e-01, 8.55081359e-01, 4.30105779e-01, + 1.16229135e-01, 9.87731964e-02, 3.14712335e-01, + 4.50880592e-01, 2.72289598e-01, 6.31615256e-01, + 8.97432958e-01, 4.44764250e-01, 8.03776440e-01, + 2.68767748e-02, 2.43374608e-01, 4.02141103e-01, + 4.98881209e-01, 5.33173003e-01, 8.82890436e-01, + 7.16149148e-01, 4.19664401e-01, 2.29335357e-01, + 2.88637806e-01, 3.44696803e-01, 6.78171906e-01, + 5.69849716e-01, 5.86454477e-01, 3.54474989e-01, + 9.03876540e-01, 6.45980000e-01, 6.34887593e-01, + 7.88039746e-02, 2.04814126e-01, 7.82251754e-01, + 2.43147074e-01, 7.50951808e-01, 1.72799092e-02, + 2.95349590e-01, 6.57991826e-01, 8.81214312e-01, + 5.73970708e-01, 2.77610881e-01, 1.82155097e-01, + 7.69797417e-02, 6.44792402e-01, 9.46950998e-01, + 7.73064845e-01, 6.04733624e-01, 5.80094567e-01, + 1.67498426e-01, 2.66514296e-01, 6.50140368e-01, + 1.91170299e-01, 2.08752199e-01, 3.01664091e-01, + 9.85033484e-01, 2.92909152e-01, 8.65816607e-01, + 1.85222119e-01, 2.28814559e-01, 1.34286382e-02, + 2.89234322e-01, 8.18668708e-01, 4.71706924e-01, + 9.23199803e-01, 2.80879188e-01, 1.47319284e-01, + 4.13915748e-01, 9.31274932e-02, 6.66322195e-01, + 9.66953974e-01, 3.19405786e-01, 6.69486551e-01, + 5.03096313e-02, 6.95225201e-01, 5.78469859e-01, + 6.29481655e-01, 1.39252534e-01, 1.22564968e-01, + 6.80663678e-01, 6.34607157e-01, 6.42765834e-01, + 1.57127410e-02, 2.92132086e-01, 5.24423878e-01, + 4.68676824e-01, 2.86003928e-01, 7.18608322e-01, + 8.95617933e-01, 5.48844309e-01, 1.74517278e-01, + 5.24379196e-01, 2.13526524e-01, 5.88375435e-01, + 9.88560185e-01, 4.17435771e-01, 6.14438688e-01, + 9.53760881e-01, 5.27151288e-01, 7.03017278e-01, + 3.44448559e-01, 4.47059676e-01, 2.83414901e-01, + 1.98979011e-01, 4.24917361e-01, 5.73172761e-01, + 2.32398853e-02, 1.65887230e-01, 4.05552785e-01, + 9.29665524e-01, 2.26135696e-01, 9.20563384e-01, + 7.65259963e-01, 4.54820075e-01, 8.97710267e-01, + 3.78559302e-03, 9.15219382e-01, 3.55705698e-01, + 6.94905124e-01, 8.58540202e-01, 3.89790666e-01, + 2.49478206e-01, 7.93679304e-01, 4.75830027e-01, + 4.40425353e-01, 3.70579459e-01, 1.40578049e-01, + 1.70386675e-01, 7.04056121e-01, 4.85963102e-01, + 9.68450060e-01, 6.77178001e-01, 2.65934654e-01, + 2.58915007e-01, 6.70052890e-01, 2.61945109e-01, + 8.46207759e-01, 1.01928951e-01, 2.85611334e-01, + 2.45776933e-01, 2.66658783e-01, 3.71724077e-01, + 4.34319025e-01, 4.24407347e-01, 7.15417683e-01, + 8.07997684e-01, 1.64296275e-01, 6.01638065e-01, + 8.60606804e-02, 2.68719187e-01, 5.11764101e-01, + 9.75844338e-01, 7.81226782e-01, 2.20925515e-01, + 7.18135040e-01, 9.82395577e-01, 8.39160243e-01, + 9.08058083e-01, 6.88010677e-01, 8.14271847e-01, + 5.12460821e-01, 1.17311345e-01, 5.96075228e-01, + 9.17455497e-01, 2.12052706e-01, 7.04074603e-01, + 8.72872565e-02, 8.76047818e-01, 6.96235046e-01, + 8.54801557e-01, 2.49729159e-01, 9.76594604e-01, + 2.87386363e-01, 2.36461559e-02, 9.94075254e-01, + 4.25193986e-01, 7.61869994e-01, 5.13334255e-01, + 6.44711165e-02, 8.92156689e-01, 3.55235167e-01, + 1.08154647e-01, 8.78446825e-01, 2.43833016e-01, + 9.23071293e-01, 2.72724115e-01, 9.46631338e-01, + 3.74510294e-01, 4.08451278e-02, 9.78392777e-01, + 3.65079221e-01, 6.37199516e-01, 5.51144906e-01, + 5.25978080e-01, 1.42803678e-01, 4.05451674e-01, + 7.79788219e-01, 6.26009784e-01, 3.35249497e-01, + 1.43159543e-02, 1.80363779e-01, 5.05096904e-01, + 2.82619947e-01, 5.83561392e-01, 3.10951324e-01, + 8.73223968e-01, 4.38545619e-01, 4.81348800e-01, + 6.68497085e-01, 3.79345401e-01, 9.58832501e-01, + 1.89869550e-01, 2.34083070e-01, 2.94066207e-01, + 5.74892667e-02, 6.92106828e-02, 9.61127686e-02, + 6.72650672e-02, 8.47345378e-01, 2.80916761e-01, + 7.32177357e-03, 9.80785961e-01, 5.73192225e-02, + 8.48781331e-01, 8.83225408e-01, 7.34398275e-01, + 7.70381941e-01, 6.20778343e-01, 8.96822048e-01, + 5.40732486e-01, 3.69704071e-01, 5.77305837e-01, + 2.08221827e-01, 7.34275341e-01, 1.06110900e-01, + 3.49496706e-01, 8.34948910e-01, 1.56403291e-02, + 6.78576376e-01, 8.96141268e-01, 5.94835119e-01, + 1.43943153e-01, 3.49618530e-01, 2.10440392e-01, + 3.46585620e-01, 1.05153093e-01, 3.45446174e-01, + 2.72177079e-01, 7.07946300e-01, 4.33717726e-02, + 3.31232203e-01, 3.91874320e-01, 4.76338141e-01, + 6.22777789e-01, 2.95989228e-02, 4.32855769e-01, + 7.61049310e-01, 3.63279149e-01, 9.47210350e-01, + 6.43721247e-01, 6.58025802e-01, 1.05247633e-02, + 5.29974442e-01, 7.30675767e-01, 4.30041079e-01, + 6.62634841e-01, 8.25936616e-01, 9.91253704e-01, + 6.79399281e-01, 5.44177006e-01, 7.52876048e-01, + 3.32139049e-01, 7.98732398e-01, 7.38865223e-01, + 9.16055132e-01, 6.11736493e-01, 9.63672879e-01, + 1.83778839e-01, 7.27558919e-02, 5.91602822e-01, + 3.25235484e-01, 2.34741217e-01, 9.52346277e-01, + 9.18556407e-01, 9.35373324e-01, 6.89209070e-01, + 2.56049054e-01, 6.17975395e-01, 7.82285691e-01, + 9.84983432e-01, 6.62322741e-01, 2.04144457e-01, + 3.98446577e-01, 1.38918297e-01, 3.05919921e-01, + 3.14043787e-01, 5.91072666e-01, 7.44703771e-01, + 8.92272567e-01, 9.78017873e-01, 9.01203161e-01, + 1.41526372e-01, 4.14878484e-01, 6.80683651e-01, + 5.01733152e-02, 8.14635389e-01, 2.27926375e-01, + 9.03269815e-01, 8.68443745e-01, 9.86939190e-01, + 7.40779486e-01, 2.61005311e-01, 3.19276232e-01, + 9.69509248e-01, 1.11908818e-01, 4.49198556e-01, + 1.27056715e-01, 3.84064823e-01, 5.14591811e-01, + 2.10747488e-01, 9.53884090e-01, 8.43167950e-01, + 4.51187972e-01, 3.75331782e-01, 6.23566461e-01, + 3.55290379e-01, 2.95705968e-01, 1.69622690e-01, + 1.42981830e-01, 2.72180991e-01, 9.46468040e-01, + 3.70932500e-01, 9.94292830e-01, 4.62587505e-01, + 7.14817405e-01, 2.45370540e-02, 3.00906377e-01, + 5.75768304e-01, 9.71448393e-01, 6.95574827e-02, + 3.93693854e-01, 5.29306116e-01, 5.04694554e-01, + 6.73797120e-02, 6.76596969e-01, 5.50948898e-01, + 3.24909641e-01, 7.70337719e-01, 6.51842631e-03, + 3.03264879e-01, 7.61037886e-03, 2.72289601e-01, + 1.50502041e-01, 6.71103888e-02, 7.41503703e-01, + 1.92088941e-01, 2.19043977e-01, 9.09320161e-01, + 2.37993569e-01, 6.18107973e-02, 8.31447852e-01, + 2.23355609e-01, 1.84789435e-01, 4.16104518e-01, + 4.21573859e-01, 8.72446305e-02, 2.97294197e-01, + 4.50328256e-01, 8.72199917e-01, 2.51279916e-01, + 4.86219272e-01, 7.57071329e-01, 4.85655942e-01, + 1.06187277e-01, 4.92341327e-01, 1.46017513e-01, + 5.25421017e-01, 4.22637906e-01, 2.24685018e-01, + 8.72648431e-01, 5.54051490e-01, 1.80745062e-01, + 2.12756336e-01, 5.20883169e-01, 7.60363654e-01, + 8.30254678e-01, 5.00003328e-01, 4.69017439e-01, + 6.38105527e-01, 3.50638261e-02, 5.22217353e-02, + 9.06516882e-02, 8.52975842e-01, 1.19985883e-01, + 3.74926753e-01, 6.50302066e-01, 1.98875727e-01, + 6.28362507e-02, 4.32693501e-01, 3.10500685e-01, + 6.20732833e-01, 4.58503272e-01, 3.20790034e-01, + 7.91284868e-01, 7.93054570e-01, 2.93406765e-01, + 8.95399023e-01, 1.06441034e-01, 7.53085241e-02, + 8.67523104e-01, 1.47963482e-01, 1.25584706e-01, + 3.81545040e-02, 6.34338619e-01, 1.76368938e-02, + 5.75553531e-02, 5.31607516e-01, 2.63869588e-01, + 9.41945823e-01, 9.24028838e-02, 5.21496463e-01, + 7.74866558e-01, 5.65210610e-01, 7.28015327e-02, + 6.51963790e-01, 8.94727453e-01, 4.49571590e-01, + 1.29932405e-01, 8.64026259e-01, 9.92599934e-01, + 7.43721560e-01, 8.87300215e-01, 1.06369925e-01, + 8.11335531e-01, 7.87734900e-01, 9.87344678e-01, + 5.32502820e-01, 4.42612382e-01, 9.64041183e-01, + 1.66085871e-01, 1.12937664e-01, 5.24423470e-01, + 6.54689333e-01, 4.59119726e-01, 5.22774091e-01, + 3.08722276e-02, 6.26979315e-01, 4.49754105e-01, + 8.07495757e-01, 2.34199499e-01, 1.67765675e-01, + 9.22168418e-01, 3.73210378e-01, 8.04432575e-01, + 5.61890354e-01, 4.47025593e-01, 6.43155678e-01, + 2.40407640e-01, 5.91631279e-01, 1.59369206e-01, + 7.75799090e-01, 8.32067212e-01, 5.59791576e-02, + 6.39105224e-01, 4.85274738e-01, 2.12630838e-01, + 2.81431312e-02, 7.16205363e-01, 6.83885011e-01, + 5.23869697e-01, 9.99418314e-01, 8.35331599e-01, + 4.69877463e-02, 6.74712562e-01, 7.99273684e-01, + 2.77001890e-02, 5.75809742e-01, 2.78513031e-01, + 8.36209905e-01, 7.25472379e-01, 4.87173943e-01, + 7.88311357e-01, 9.64676177e-01, 1.75752651e-01, + 4.98112580e-01, 8.08850418e-02, 6.40981131e-01, + 4.06647450e-01, 8.46539387e-01, 2.12620694e-01, + 9.11012851e-01, 8.25041445e-01, 8.90065575e-01, + 9.63626055e-01, 5.96689242e-01, 1.63372670e-01, + 4.51640148e-01, 3.43026542e-01, 5.80658851e-01, + 2.82327625e-01, 4.75535418e-01, 6.27760926e-01, + 8.46314115e-01, 9.61961932e-01, 3.19806094e-01, + 5.05508062e-01, 5.28102944e-01, 6.13045057e-01, + 7.44714938e-01, 1.50586073e-01, 7.91878033e-01, + 4.89839179e-01, 3.10496849e-01, 8.82309038e-01, + 2.86922314e-01, 4.84687559e-01, 5.20838630e-01, + 4.62955493e-01, 2.38185305e-01, 5.47259907e-02, + 7.10916137e-01, 7.31887202e-01, 6.25602317e-01, + 8.77741168e-01, 4.19881322e-01, 4.81222328e-01, + 1.28224501e-01, 2.46034010e-01, 3.34971854e-01, + 7.37216484e-01, 5.62134821e-02, 7.14089724e-01, + 9.85549393e-01, 4.66295827e-01, 3.08722434e-03, + 4.70237690e-01, 2.66524167e-01, 7.93875484e-01, + 4.54795911e-02, 8.09702944e-01, 1.47709735e-02, + 1.70082405e-01, 6.35905179e-01, 3.75379109e-01, + 4.30315011e-01, 3.15788760e-01, 5.58065230e-01, + 2.24643800e-01, 2.42142981e-01, 6.57283636e-01, + 3.34921891e-01, 1.26588975e-01, 7.68064155e-01, + 9.43856291e-01, 4.47518596e-01, 5.44453573e-01, + 9.95764932e-01, 7.16444391e-01, 8.51019765e-01, + 1.01179183e-01, 4.45473958e-01, 4.60327322e-01, + 4.96895844e-02, 4.72907738e-01, 5.58987444e-01, + 3.41027487e-01, 1.56175026e-01, 7.58283148e-01, + 6.83600909e-01, 2.14623396e-01, 3.27348880e-01, + 3.92517893e-01, 6.70418431e-01, 5.16440832e-01, + 8.63140348e-01, 5.73277464e-01, 3.46608058e-01, + 7.39396341e-01, 7.20852434e-01, 2.35653246e-02, + 3.89935659e-01, 7.53783745e-01, 6.34563528e-01, + 8.79339335e-01, 7.41599159e-02, 5.62433904e-01, + 6.15553852e-01, 4.56956324e-01, 5.20047447e-01, + 5.26845015e-02, 5.58471266e-01, 1.63632233e-01, + 5.38936665e-02, 6.49593683e-01, 2.56838748e-01, + 8.99035326e-01, 7.20847756e-01, 5.68954684e-01, + 7.43684755e-01, 5.70924238e-01, 3.82318724e-01, + 4.89328290e-01, 5.62208561e-01, 4.97540804e-02, + 4.18011085e-01, 6.88041565e-01, 2.16234653e-01, + 7.89548214e-01, 8.46136387e-01, 8.46816189e-01, + 1.73842353e-01, 6.11627842e-02, 8.44440559e-01, + 4.50646654e-01, 3.74785037e-01, 4.87196697e-01, + 4.56276448e-01, 9.13284391e-01, 4.15715464e-01, + 7.13597697e-01, 1.23641270e-02, 5.10031271e-01, + 4.74601930e-02, 2.55731159e-01, 3.22090006e-01, + 1.91165703e-01, 4.51170940e-01, 7.50843157e-01, + 4.42420576e-01, 4.25380660e-01, 4.50667257e-01, + 6.55689206e-01, 9.68257670e-02, 1.96528793e-01, + 8.97343028e-01, 4.99940904e-01, 6.65504083e-01, + 9.41828079e-01, 4.54397338e-01, 5.61893331e-01, + 5.09839880e-01, 4.53117514e-01, 8.96804127e-02, + 1.74888861e-01, 6.65641378e-01, 2.81668336e-01, + 1.89532742e-01, 5.61668382e-01, 8.68330157e-02, + 8.25092797e-01, 5.18106324e-01, 1.71904024e-01, + 3.68385523e-01, 1.62005436e-01, 7.48507399e-01, + 9.30274827e-01, 2.38198517e-01, 9.52222901e-01, + 5.23587800e-01, 6.94384557e-01, 1.09338652e-01, + 4.83356794e-01, 2.73050402e-01, 3.68027050e-01, + 5.92366466e-01, 1.83192289e-01, 8.60376029e-01, + 7.13926203e-01, 8.16750052e-01, 1.57890291e-01, + 6.25691951e-01, 5.24831646e-01, 1.73873797e-01, + 1.02429784e-01, 9.17488471e-01, 4.03584434e-01, + 9.31170884e-01, 2.79386137e-01, 8.77745206e-01, + 2.45200576e-01, 1.28896951e-01, 3.15713052e-01, + 5.27874291e-01, 2.16444335e-01, 7.03883817e-01, + 7.74738919e-02, 8.42422142e-01, 3.75598924e-01, + 3.51002411e-01, 6.22752776e-01, 4.82407943e-01, + 7.43107867e-01, 9.46182666e-01, 9.44344819e-01, + 3.28124763e-01, 1.06147431e-01, 1.65102684e-01, + 3.84060507e-01, 2.91057722e-01, 7.68173662e-02, + 1.03543651e-01, 6.76698940e-01, 1.43141994e-01, + 7.21342202e-01, 6.69471294e-03, 9.07298311e-01, + 5.57080171e-01, 8.10954489e-01, 4.11120526e-01, + 2.06407453e-01, 2.59590556e-01, 7.58512718e-01, + 5.79873897e-01, 2.92875650e-01, 2.83686529e-01, + 2.42829343e-01, 9.19323719e-01, 3.46832864e-01, + 3.58238858e-01, 7.42827585e-01, 2.05760059e-01, + 9.58438860e-01, 5.66326411e-01, 6.60292846e-01, + 5.61095078e-02, 6.79465531e-01, 7.05118513e-01, + 4.44713264e-01, 2.09732933e-01, 5.22732436e-01, + 1.74396512e-01, 5.29356748e-01, 4.38475687e-01, + 4.94036404e-01, 4.09785794e-01, 6.40025507e-01, + 5.79371821e-01, 1.57726118e-01, 6.04572263e-01, + 5.41072639e-01, 5.18847173e-01, 1.97093284e-01, + 8.91767002e-01, 4.29050835e-01, 8.25490570e-01, + 3.87699807e-01, 4.50705808e-01, 2.49371643e-01, + 3.36074898e-01, 9.29925118e-01, 6.65393649e-01, + 9.07275994e-01, 3.73075859e-01, 4.14044139e-03, + 2.37463702e-01, 2.25893784e-01, 2.46900245e-01, + 4.50350196e-01, 3.48618117e-01, 5.07193932e-01, + 5.23435142e-01, 8.13611417e-01, 8.92715622e-01, + 1.02623450e-01, 3.06088345e-01, 7.80461650e-01, + 2.21453645e-01, 2.01419652e-01, 2.84254457e-01, + 3.68286735e-01, 7.39358243e-01, 8.97879394e-01, + 9.81599566e-01, 7.56526442e-01, 7.37645545e-01, + 4.23976657e-02, 8.25922012e-01, 2.60956996e-01, + 2.90702065e-01, 8.98388344e-01, 3.03733299e-01, + 8.49071471e-01, 3.45835425e-01, 7.65458276e-01, + 5.68094872e-01, 8.93770930e-01, 9.93161641e-01, + 5.63368667e-02, 4.26548945e-01, 5.46745780e-01, + 5.75674571e-01, 7.94599487e-01, 7.18935553e-02, + 4.46492976e-01, 6.40240123e-01, 2.73246969e-01, + 2.00465968e-01, 1.30718835e-01, 1.92492005e-01, + 1.96617189e-01, 6.61271644e-01, 8.12687657e-01, + 8.66342445e-01 + + }, + {0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 4, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + 10, + -4}}; + +typedef LinkageTest LinkageTestF_Int; +TEST_P(LinkageTestF_Int, Result) { + EXPECT_TRUE( + raft::devArrMatch(labels, labels_ref, params.n_row, raft::Compare())); +} + +INSTANTIATE_TEST_CASE_P(LinkageTest, LinkageTestF_Int, + ::testing::ValuesIn(linkage_inputsf2)); +} // end namespace raft diff --git a/cpp/test/sparse/symmetrize.cu b/cpp/test/sparse/symmetrize.cu index 07dd9d11a2..d104028d2b 100644 --- a/cpp/test/sparse/symmetrize.cu +++ b/cpp/test/sparse/symmetrize.cu @@ -19,6 +19,7 @@ #include #include "../test_utils.h" +#include #include #include @@ -27,27 +28,122 @@ namespace raft { namespace sparse { +template +__global__ void assert_symmetry(value_idx *rows, value_idx *cols, value_t *vals, + value_idx nnz, value_idx *sum) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + + if (tid >= nnz) return; + + atomicAdd(sum, rows[tid]); + atomicAdd(sum, -1 * cols[tid]); +} + +template +struct SparseSymmetrizeInputs { + value_idx n_cols; + + std::vector indptr_h; + std::vector indices_h; + std::vector data_h; +}; + +template +::std::ostream &operator<<( + ::std::ostream &os, const SparseSymmetrizeInputs &dims) { + return os; +} + +template +class SparseSymmetrizeTest : public ::testing::TestWithParam< + SparseSymmetrizeInputs> { + protected: + void make_data() { + std::vector indptr_h = params.indptr_h; + std::vector indices_h = params.indices_h; + std::vector data_h = params.data_h; + + allocate(indptr, indptr_h.size()); + allocate(indices, indices_h.size()); + allocate(data, data_h.size()); + + update_device(indptr, indptr_h.data(), indptr_h.size(), stream); + update_device(indices, indices_h.data(), indices_h.size(), stream); + update_device(data, data_h.data(), data_h.size(), stream); + } + + void SetUp() override { + params = ::testing::TestWithParam< + SparseSymmetrizeInputs>::GetParam(); + + raft::handle_t handle; + + auto alloc = handle.get_device_allocator(); + stream = handle.get_stream(); + + make_data(); + + value_idx m = params.indptr_h.size() - 1; + value_idx n = params.n_cols; + value_idx nnz = params.indices_h.size(); + + raft::mr::device::buffer coo_rows(alloc, stream, nnz); + + raft::sparse::convert::csr_to_coo(indptr, m, coo_rows.data(), nnz, stream); + + raft::sparse::COO out(alloc, stream); + + raft::sparse::linalg::symmetrize(handle, coo_rows.data(), indices, data, m, + n, coo_rows.size(), out); + + raft::mr::device::buffer sum(alloc, stream, 1); + + CUDA_CHECK(cudaMemsetAsync(sum.data(), 0, 1 * sizeof(value_idx), stream)); + + assert_symmetry<<>>( + out.rows(), out.cols(), out.vals(), out.nnz, sum.data()); + + raft::update_host(&sum_h, sum.data(), 1, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + } + + void TearDown() override { + CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaFree(indptr)); + CUDA_CHECK(cudaFree(indices)); + CUDA_CHECK(cudaFree(data)); + } + + protected: + cudaStream_t stream; + + // input data + value_idx *indptr, *indices; + value_t *data; + + value_idx sum_h; + + SparseSymmetrizeInputs params; +}; + template -struct SparseSymmetrizeInput { +struct COOSymmetrizeInputs { int m, n, nnz; unsigned long long int seed; }; template -class SparseSymmetrizeTest - : public ::testing::TestWithParam> { +class COOSymmetrizeTest + : public ::testing::TestWithParam> { protected: void SetUp() override {} void TearDown() override {} - - protected: - SparseSymmetrizeInput params; }; -const std::vector> inputsf = {{5, 10, 5, 1234ULL}}; +const std::vector> inputsf = {{5, 10, 5, 1234ULL}}; -typedef SparseSymmetrizeTest COOSymmetrize; +typedef COOSymmetrizeTest COOSymmetrize; TEST_P(COOSymmetrize, Result) { cudaStream_t stream; cudaStreamCreate(&stream); @@ -104,8 +200,29 @@ TEST_P(COOSymmetrize, Result) { delete[] exp_vals_h; } -INSTANTIATE_TEST_CASE_P(SparseSymmetrizeTest, COOSymmetrize, +INSTANTIATE_TEST_CASE_P(COOSymmetrizeTest, COOSymmetrize, ::testing::ValuesIn(inputsf)); +const std::vector> symm_inputs_fint = { + // Test n_clusters == n_points + { + 2, + {0, 2, 4, 6, 8}, + {0, 1, 0, 1, 0, 1, 0, 1}, + {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}, + }, + {2, + {0, 2, 4, 6, 8}, + {0, 1, 0, 1, 0, 1, 0, 1}, // indices + {1.0f, 3.0f, 1.0f, 5.0f, 50.0f, 28.0f, 16.0f, 2.0f}}, + +}; + +typedef SparseSymmetrizeTest SparseSymmetrizeTestF_int; +TEST_P(SparseSymmetrizeTestF_int, Result) { ASSERT_TRUE(sum_h == 0); } + +INSTANTIATE_TEST_CASE_P(SparseSymmetrizeTest, SparseSymmetrizeTestF_int, + ::testing::ValuesIn(symm_inputs_fint)); + } // namespace sparse } // namespace raft