Fixing RAFT CI & a few small updates for SLHC Python wrapper (#178)
Authors:
  - Corey J. Nolet (@cjnolet)

Approvers:
  - Victor Lafargue (@viclafargue)
  - Alex Fender (@afender)

URL: #178
cjnolet authored Mar 24, 2021
1 parent 7091ae3 commit d076399
Showing 10 changed files with 246 additions and 107 deletions.
2 changes: 1 addition & 1 deletion cpp/cmake/Dependencies.cmake
@@ -23,7 +23,7 @@ if(NOT CUB_IS_PART_OF_CTK)
   set(CUB_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub CACHE STRING "Path to cub repo")
   ExternalProject_Add(cub
     GIT_REPOSITORY https://github.com/thrust/cub.git
-    GIT_TAG 1.8.0
+    GIT_TAG 1.12.0
     PREFIX ${CUB_DIR}
     CONFIGURE_COMMAND ""
     BUILD_COMMAND ""
7 changes: 4 additions & 3 deletions cpp/include/raft/sparse/hierarchy/common.h
@@ -29,7 +29,8 @@ enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 };
  * @tparam value_t
  */
 template <typename value_idx, typename value_t>
-struct linkage_output {
+class linkage_output {
+ public:
   value_idx m;
   value_idx n_clusters;
 
@@ -41,8 +42,8 @@ struct linkage_output {
   value_idx *children;  // size: (m-1, 2)
 };
 
-struct linkage_output_int_float : public linkage_output<int, float> {};
-struct linkage_output__int64_float : public linkage_output<int64_t, float> {};
+class linkage_output_int_float : public linkage_output<int, float> {};
+class linkage_output__int64_float : public linkage_output<int64_t, float> {};
 
 };  // namespace hierarchy
 };  // namespace raft
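
Aside, not part of the commit: a minimal standalone C++ illustration of why the struct-to-class change above is interface-neutral. In C++ the only difference between the two keywords is the default access level, so the explicit `public:` keeps linkage_output's members (and its subclasses' public bases) as accessible as before. The type names below are made up for the example.

// Illustration only (not RAFT code): `struct` members default to public,
// `class` members default to private, so an explicit `public:` preserves
// the original access level of the fields.
#include <cstdio>

struct with_struct { int m; };  // m is public by default

class with_class {
 public:  // required to keep m public, matching the struct version
  int m;
};

int main() {
  with_struct a{1};
  with_class b;
  b.m = 2;
  std::printf("%d %d\n", a.m, b.m);  // both members are publicly accessible
  return 0;
}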
7 changes: 4 additions & 3 deletions cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh
@@ -103,6 +103,8 @@ void build_dendrogram_host(const handle_t &handle, const value_idx *rows,
   std::vector<value_idx> mst_dst_h(n_edges);
   std::vector<value_t> mst_weights_h(n_edges);
 
+  printf("n_edges: %d\n", n_edges);
+
   update_host(mst_src_h.data(), rows, n_edges, stream);
   update_host(mst_dst_h.data(), cols, n_edges, stream);
   update_host(mst_weights_h.data(), data, n_edges, stream);
@@ -132,6 +134,8 @@ void build_dendrogram_host(const handle_t &handle, const value_idx *rows,
 
   out_size.resize(n_edges, stream);
 
+  printf("Finished dendrogram\n");
+
   raft::update_device(children, children_h.data(), n_edges * 2, stream);
   raft::update_device(out_size.data(), out_size_h.data(), n_edges, stream);
 }
@@ -319,9 +323,6 @@ void extract_flattened_clusters(const raft::handle_t &handle, value_idx *labels,
   raft::copy_async(label_roots.data(), children + children_cpy_start,
                    child_size, stream);
 
-  // thrust::device_ptr<value_idx> t_label_roots =
-  //   thrust::device_pointer_cast(label_roots.data());
-  //
   thrust::sort(thrust_policy, label_roots.data(),
                label_roots.data() + (child_size), thrust::greater<value_idx>());
 
2 changes: 2 additions & 0 deletions cpp/include/raft/sparse/hierarchy/detail/connectivities.cuh
@@ -68,6 +68,8 @@ struct distance_graph_impl<raft::hierarchy::LinkageDistance::KNN_GRAPH,
   raft::sparse::selection::knn_graph(handle, X, m, n, metric, knn_graph_coo,
                                      c);
 
+  CUDA_CHECK(cudaDeviceSynchronize());
+
   indices.resize(knn_graph_coo.nnz, stream);
   data.resize(knn_graph_coo.nnz, stream);
 
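
Aside, not from the commit: a self-contained sketch of why a checked cudaDeviceSynchronize() right after an asynchronous launch helps while debugging. Kernel launches return before the kernel runs, so failures otherwise surface only on some later CUDA call; synchronizing and checking at the launch site reports them where they happen. CHECK_CUDA below is a local stand-in for RAFT's CUDA_CHECK macro, and noop_kernel is a placeholder kernel; neither is RAFT code.

// Compile with nvcc. Illustration of the "synchronize and check" pattern.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_CUDA(call)                                              \
  do {                                                                \
    cudaError_t err = (call);                                         \
    if (err != cudaSuccess) {                                         \
      std::fprintf(stderr, "CUDA error %s at %s:%d\n",                \
                   cudaGetErrorString(err), __FILE__, __LINE__);      \
      std::exit(1);                                                   \
    }                                                                 \
  } while (0)

__global__ void noop_kernel() {}

int main() {
  noop_kernel<<<1, 1>>>();              // launch is asynchronous
  CHECK_CUDA(cudaDeviceSynchronize());  // blocks and surfaces any launch error here
  std::printf("kernel finished cleanly\n");
  return 0;
}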
20 changes: 15 additions & 5 deletions cpp/include/raft/sparse/hierarchy/detail/mst.cuh
@@ -100,6 +100,7 @@ raft::Graph_COO<value_idx, value_idx, value_t> connect_knn_graph(
   raft::copy_async(msf.weights.data() + msf.n_edges, connected_edges.vals(),
                    connected_edges.nnz, stream);
 
+  printf("connected components nnz: %d\n", final_nnz);
   raft::sparse::COO<value_t, value_idx> final_coo(d_alloc, stream);
   raft::sparse::linalg::symmetrize(handle, msf.src.data(), msf.dst.data(),
                                    msf.weights.data(), m, n, final_nnz,
@@ -162,16 +163,25 @@ void build_sorted_mst(const raft::handle_t &handle, const value_t *X,
     handle, indptr, indices, pw_dists, (value_idx)m, nnz, color.data(), stream,
     false);
 
-  if (linkage::get_n_components(color.data(), m, stream) > 1) {
+  int iters = 1;
+  int n_components = linkage::get_n_components(color.data(), m, stream);
+  while (n_components > 1 && iters < 100) {
+    printf("Found %d components. Going to try connecting graph\n",
+           n_components);
     mst_coo = connect_knn_graph<value_idx, value_t>(handle, X, mst_coo, m, n,
                                                     color.data());
 
+    printf("Edges: %d\n", mst_coo.n_edges);
+    iters++;
+
-    RAFT_EXPECTS(
-      mst_coo.n_edges == m - 1,
-      "MST was not able to connect knn graph in a single iteration.");
+    n_components = linkage::get_n_components(color.data(), m, stream);
+    //
+    // printf("Connecting knn graph!\n");
+    //
+    // RAFT_EXPECTS(
+    //   mst_coo.n_edges == m - 1,
+    //   "MST was not able to connect knn graph in a single iteration.");
   }
+  printf("Found %d components.\n", n_components);
 
   sort_coo_by_data(mst_coo.src.data(), mst_coo.dst.data(),
                    mst_coo.weights.data(), mst_coo.n_edges, stream);
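
Aside, not part of the commit: a tiny self-contained sketch of the control-flow change in build_sorted_mst() above. The single connect_knn_graph() pass guarded by RAFT_EXPECTS is replaced by a bounded retry loop that keeps reconnecting the kNN graph until only one connected component remains, or 100 iterations have passed. connect_pass below is a toy stand-in, not a RAFT API; it just pretends each pass halves the component count.

#include <cstdio>

// Toy stand-in (not RAFT): pretend one "connect" pass halves the component count.
static int connect_pass(int n_components) { return (n_components + 1) / 2; }

int main() {
  int n_components = 8;  // pretend the initial kNN graph has 8 components
  int iters = 1;

  // Bounded retry, mirroring `while (n_components > 1 && iters < 100)` above.
  while (n_components > 1 && iters < 100) {
    std::printf("Found %d components. Going to try connecting graph\n",
                n_components);
    n_components = connect_pass(n_components);
    iters++;
  }

  std::printf("Found %d components.\n", n_components);
  return 0;
}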
2 changes: 0 additions & 2 deletions cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh
@@ -190,8 +190,6 @@ MST_solver<vertex_t, edge_t, weight_t>::solve() {
   mst_result.dst.resize(mst_result.n_edges, stream);
   mst_result.weights.resize(mst_result.n_edges, stream);
 
-  // raft::print_device_vector("Colors before sending: ", color_index, 7, std::cout);
-
   return mst_result;
 }
 
2 changes: 2 additions & 0 deletions cpp/include/raft/sparse/op/sort.h
@@ -71,6 +71,8 @@ void coo_sort(int m, int n, int nnz, int *rows, int *cols, T *vals,
 
   CUSPARSE_CHECK(cusparseCreateIdentityPermutation(handle, nnz, d_P.data()));
 
+  printf("nnz: %d\n", nnz);
+
   CUSPARSE_CHECK(cusparseXcoosortByRow(handle, m, n, nnz, rows, cols,
                                        d_P.data(), pBuffer.data()));
 
(Diffs for the remaining three changed files are not rendered in this view.)
