Mdspanifying (currently tested) raft::matrix (#846)
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Mark Hoemmen (https://github.com/mhoemmen)
  - Divye Gala (https://github.com/divyegala)

URL: #846
cjnolet authored Oct 2, 2022
1 parent ae9e3b9 commit d265e58
Showing 44 changed files with 1,559 additions and 210 deletions.
26 changes: 18 additions & 8 deletions BUILD.md
@@ -5,6 +5,7 @@
- [Build Dependencies](#required_depenencies)
- [Header-only C++](#install_header_only_cpp)
- [C++ Shared Libraries](#shared_cpp_libs)
- [Improving Rebuild Times](#ccache)
- [Googletests](#gtests)
- [C++ Using Cmake](#cpp_using_cmake)
- [Python](#python)
@@ -29,7 +30,6 @@ In addition to the libraries included with cudatoolkit 11.0+, there are some oth
- [RMM](https://github.com/rapidsai/rmm) corresponding to RAFT version.

#### Optional
- [mdspan](https://github.com/rapidsai/mdspan) - On by default but can be disabled.
- [Thrust](https://github.com/NVIDIA/thrust) v1.15 / [CUB](https://github.com/NVIDIA/cub) - On by default but can be disabled.
- [cuCollections](https://github.com/NVIDIA/cuCollections) - Used in `raft::sparse::distance` API.
- [Libcu++](https://github.com/NVIDIA/libcudacxx) v1.7.0
@@ -53,11 +53,6 @@ The following example will download the needed dependencies and install the RAFT
./build.sh libraft --install
```

The `--minimal-deps` flag can be used to install the headers with minimal dependencies:
```bash
./build.sh libraft --install --minimal-deps
```

### <a id="shared_cpp_libs"></a>C++ Shared Libraries (optional)

For larger projects which make heavy use of the pairwise distances or nearest neighbors APIs, shared libraries can be built to speed up compile times. These shared libraries can also significantly improve re-compile times both while developing RAFT and developing against the APIs. Build all of the available shared libraries by passing `--compile-libs` flag to `build.sh`:
@@ -72,6 +67,14 @@ Individual shared libraries have their own flags and multiple can be used (thoug

Add the `--install` flag to the above example to also install the shared libraries into `$INSTALL_PREFIX/lib`.

### <a id="ccache"></a>`ccache` and `sccache`

`ccache` and `sccache` can cache parts of the build, which helps when rebuilding frequently, such as when working on a new feature. Pass either tool to `build.sh` through the `--cache-tool` option:

```bash
./build.sh libraft --cache-tool=ccache
```

### <a id="gtests"></a>Tests

Compile the tests using the `tests` target in `build.sh`.
@@ -86,10 +89,17 @@ Test compile times can be improved significantly by using the optional shared li
./build.sh libraft tests --compile-libs
```

To run C++ tests:
The tests are broken apart by algorithm category, so you will find several binaries in `cpp/build/` named `*_TEST`.

For example, to run the distance tests:
```bash
./cpp/build/DISTANCE_TEST
```

It can take some time to compile all of the tests. You can build individual tests by providing a semicolon-separated list to the `--limit-tests` option in `build.sh`:

```bash
./cpp/build/test_raft
./build.sh libraft tests --limit-tests="SPATIAL_TEST;DISTANCE_TEST;MATRIX_TEST"
```

### <a id="benchmarks"></a>Benchmarks
2 changes: 1 addition & 1 deletion cpp/include/raft/cluster/detail/single_linkage.cuh
@@ -54,7 +54,7 @@ void single_linkage(const raft::handle_t& handle,
size_t m,
size_t n,
raft::distance::DistanceType metric,
linkage_output<value_idx, value_t>* out,
linkage_output<value_idx>* out,
int c,
size_t n_clusters)
{
47 changes: 46 additions & 1 deletion cpp/include/raft/cluster/single_linkage.cuh
@@ -17,9 +17,12 @@

#include <raft/cluster/detail/single_linkage.cuh>
#include <raft/cluster/single_linkage_types.hpp>
#include <raft/core/device_mdspan.hpp>

namespace raft::cluster {

constexpr int DEFAULT_CONST_C = 15;

/**
* Single-linkage clustering, capable of constructing a KNN graph to
* scale the algorithm beyond the n^2 memory consumption of implementations
@@ -48,11 +51,53 @@ void single_linkage(const raft::handle_t& handle,
size_t m,
size_t n,
raft::distance::DistanceType metric,
linkage_output<value_idx, value_t>* out,
linkage_output<value_idx>* out,
int c,
size_t n_clusters)
{
detail::single_linkage<value_idx, value_t, dist_type>(
handle, X, m, n, metric, out, c, n_clusters);
}

/**
* Single-linkage clustering, capable of constructing a KNN graph to
* scale the algorithm beyond the n^2 memory consumption of implementations
* that use the fully-connected graph of pairwise distances by connecting
* a knn graph when k is not large enough to connect it.
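* @tparam value_t element type of the input matrix
* @tparam idx_t integer type used for indexing and for the dendrogram and labels outputs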
* @tparam dist_type method to use for constructing connectivities graph
* @param[in] handle raft handle
* @param[in] X dense input matrix in row-major layout
* @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2)
* @param[out] labels output labels vector (size n_rows)
* @param[in] metric distance metric to use when constructing connectivities graph
* @param[in] n_clusters number of clusters to assign data samples
* @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect
control of k. The algorithm will set `k = log(n) + c`
*/
template <typename value_t, typename idx_t, LinkageDistance dist_type = LinkageDistance::KNN_GRAPH>
void single_linkage(const raft::handle_t& handle,
raft::device_matrix_view<const value_t, idx_t, row_major> X,
raft::device_matrix_view<idx_t, idx_t, row_major> dendrogram,
raft::device_vector_view<idx_t, idx_t> labels,
raft::distance::DistanceType metric,
size_t n_clusters,
std::optional<int> c = std::make_optional<int>(DEFAULT_CONST_C))
{
linkage_output<idx_t> out_arrs;
out_arrs.children = dendrogram.data_handle();
out_arrs.labels = labels.data_handle();

single_linkage<idx_t, value_t, dist_type>(handle,
X.data_handle(),
static_cast<std::size_t>(X.extent(0)),
static_cast<std::size_t>(X.extent(1)),
metric,
&out_arrs,
c.has_value() ? c.value() : DEFAULT_CONST_C,
n_clusters);
}

}; // namespace raft::cluster
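
The mdspan-based overload above takes `c` as a `std::optional<int>` and falls back to `DEFAULT_CONST_C` when it is empty, and its doc comment states that the algorithm sets `k = log(n) + c`. The following host-only sketch illustrates that arithmetic. The helper names `resolve_c` and `knn_k` are hypothetical and not part of RAFT, and the use of the natural logarithm with truncation toward zero is an assumption that this diff does not confirm.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <optional>

// Mirrors the constant added in this commit.
constexpr int DEFAULT_CONST_C = 15;

// Hypothetical helper: same fallback as
// `c.has_value() ? c.value() : DEFAULT_CONST_C` in the overload above.
inline int resolve_c(std::optional<int> c) { return c.value_or(DEFAULT_CONST_C); }

// Hypothetical helper: k = log(n) + c.
// ASSUMPTION: natural log, truncated toward zero; the real rounding may differ.
inline int knn_k(std::size_t n_rows, std::optional<int> c = std::nullopt)
{
  assert(n_rows > 0);  // log(0) is undefined
  return static_cast<int>(std::log(static_cast<double>(n_rows))) + resolve_c(c);
}
```

Under these assumptions, `knn_k(1000)` yields 21 (6 from the truncated logarithm plus the default of 15), which shows how `c` indirectly controls `k` without the caller choosing `k` directly.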
30 changes: 21 additions & 9 deletions cpp/include/raft/cluster/single_linkage_types.hpp
@@ -16,6 +16,8 @@

#pragma once

#include <raft/core/device_mdspan.hpp>

namespace raft::cluster {

enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 };
@@ -27,23 +29,33 @@ enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 };
* @tparam idx_t integer type of the labels and children arrays
*/
template <typename value_idx, typename value_t>
template <typename idx_t>
class linkage_output {
public:
value_idx m;
value_idx n_clusters;
idx_t m;
idx_t n_clusters;

idx_t n_leaves;
idx_t n_connected_components;

value_idx n_leaves;
value_idx n_connected_components;
// TODO: These will be made private in a future release
idx_t* labels; // size: m
idx_t* children; // size: (m-1, 2)

value_idx* labels; // size: m
raft::device_vector_view<idx_t> get_labels()
{
return raft::make_device_vector_view<idx_t>(labels, m);
}

value_idx* children; // size: (m-1, 2)
raft::device_matrix_view<idx_t> get_children()
{
return raft::make_device_matrix_view<idx_t>(children, m - 1, 2);
}
};

class linkage_output_int_float : public linkage_output<int, float> {
class linkage_output_int_float : public linkage_output<int> {
};
class linkage_output__int64_float : public linkage_output<int64_t, float> {
class linkage_output__int64_float : public linkage_output<int64_t> {
};

}; // namespace raft::cluster
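
The `get_labels()` and `get_children()` accessors above wrap the raw `labels` and `children` pointers in non-owning views of shape `m` and `(m-1, 2)`. The sketch below shows the same pattern on host memory. `matrix_view` and `linkage_output_sketch` are hypothetical stand-ins written for illustration; they are not RAFT types and do not reproduce the mdspan API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a non-owning, row-major 2-D view (not RAFT's mdspan).
template <typename T>
struct matrix_view {
  T* data;
  std::size_t rows;
  std::size_t cols;

  // Row-major element access with bounds checking in debug builds.
  T& operator()(std::size_t i, std::size_t j) const
  {
    assert(i < rows && j < cols);
    return data[i * cols + j];
  }
};

// Simplified host-side analogue of linkage_output: raw pointers plus a shape.
template <typename idx_t>
struct linkage_output_sketch {
  idx_t m;
  idx_t* labels;    // size: m
  idx_t* children;  // size: (m-1, 2)

  // Wraps `children` without copying; the caller keeps ownership of the buffer.
  matrix_view<idx_t> get_children() const
  {
    return {children, static_cast<std::size_t>(m - 1), 2};
  }
};
```

Because the view only stores a pointer and extents, writes through the view land in the caller's buffer, and the view is valid only while that buffer is alive and `m` has been set.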
3 changes: 1 addition & 2 deletions cpp/include/raft/core/mdspan.hpp
@@ -255,5 +255,4 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx,
return unravel_index_impl<uint32_t>(static_cast<uint32_t>(idx), shape);
}
}

} // namespace raft
} // namespace raft
107 changes: 103 additions & 4 deletions cpp/include/raft/matrix/col_wise_sort.cuh
@@ -18,10 +18,11 @@

#pragma once

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/matrix/detail/columnWiseSort.cuh>

namespace raft {
namespace matrix {
namespace raft::matrix {

/**
* @brief sort columns within each row of row-major input matrix and return sorted indexes
@@ -50,7 +51,105 @@ void sort_cols_per_row(const InType* in,
detail::sortColumnsPerRow<InType, OutType>(
in, out, n_rows, n_columns, bAllocWorkspace, workspacePtr, workspaceSize, stream, sortedKeys);
}
}; // end namespace matrix
}; // end namespace raft

/**
* @brief sort columns within each row of row-major input matrix and return sorted indexes
* modelled as key-value sort with key being input matrix and value being index of values
* @tparam in_t: element type of input matrix
* @tparam out_t: element type of output matrix
* @tparam matrix_idx_t: integer type for matrix indexing
* @param[in] handle: raft handle
* @param[in] in: input matrix
* @param[out] out: output value(index) matrix
* @param[out] sorted_keys: Optional, output matrix for sorted keys (input)
*/
template <typename in_t, typename out_t, typename matrix_idx_t>
void sort_cols_per_row(const raft::handle_t& handle,
raft::device_matrix_view<const in_t, matrix_idx_t, raft::row_major> in,
raft::device_matrix_view<out_t, matrix_idx_t, raft::row_major> out,
std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>>
sorted_keys = std::nullopt)
{
RAFT_EXPECTS(in.extent(1) == out.extent(1) && in.extent(0) == out.extent(0),
"Input and output matrices must have the same shape.");

if (sorted_keys.has_value()) {
RAFT_EXPECTS(in.extent(1) == sorted_keys.value().extent(1) &&
in.extent(0) == sorted_keys.value().extent(0),
"Input and `sorted_keys` matrices must have the same shape.");
}

size_t workspace_size = 0;
bool alloc_workspace = false;

in_t* keys = sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr;

detail::sortColumnsPerRow<in_t, out_t>(in.data_handle(),
out.data_handle(),
in.extent(0),
in.extent(1),
alloc_workspace,
(void*)nullptr,
workspace_size,
handle.get_stream(),
keys);

if (alloc_workspace) {
auto workspace = raft::make_device_vector<char>(handle, workspace_size);

detail::sortColumnsPerRow<in_t, out_t>(in.data_handle(),
out.data_handle(),
in.extent(0),
in.extent(1),
alloc_workspace,
(void*)workspace.data_handle(),
workspace_size,
handle.get_stream(),
keys);
}
}
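
The doc comment for the overload above describes a key-value sort in which each input row supplies the keys and the column indices are the values. The following host-only sketch shows that semantics. `sort_cols_per_row_host` is a hypothetical name, not a RAFT function, and a nullable pointer stands in for the optional `sorted_keys` view. Ascending order and stable tie-breaking are assumptions; the device implementation is not shown in this diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Host-only sketch (hypothetical, not RAFT): for each row of a row-major
// matrix, write the column indices that order the row ascending. If
// `sorted_keys` is non-null, also write the row's values in sorted order.
template <typename in_t, typename out_t>
void sort_cols_per_row_host(const std::vector<in_t>& in,
                            std::vector<out_t>& out,
                            std::size_t n_rows,
                            std::size_t n_cols,
                            in_t* sorted_keys = nullptr)
{
  // Same shape requirement as the RAFT_EXPECTS checks above.
  assert(in.size() == n_rows * n_cols && out.size() == in.size());

  for (std::size_t r = 0; r < n_rows; ++r) {
    const in_t* row = in.data() + r * n_cols;
    out_t* idx      = out.data() + r * n_cols;

    // Values start as 0..n_cols-1 and are permuted by comparing the keys.
    std::iota(idx, idx + n_cols, out_t{0});
    std::stable_sort(idx, idx + n_cols, [row](out_t a, out_t b) { return row[a] < row[b]; });

    if (sorted_keys != nullptr) {
      in_t* keys = sorted_keys + r * n_cols;
      for (std::size_t c = 0; c < n_cols; ++c) { keys[c] = row[idx[c]]; }
    }
  }
}
```

For the 2x3 input `{3, 1, 2, 9, 8, 7}`, this sketch writes the indices `{1, 2, 0, 2, 1, 0}` and, when requested, the sorted keys `{1, 2, 3, 7, 8, 9}`.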

namespace sort_cols_per_row_impl {
template <typename T>
struct sorted_keys_alias {
};

template <>
struct sorted_keys_alias<std::nullopt_t> {
using type = double;
};

template <typename in_t, typename matrix_idx_t>
struct sorted_keys_alias<
std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>>> {
using type = typename raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>::value_type;
};

template <typename T>
using sorted_keys_t = typename sorted_keys_alias<T>::type;
} // namespace sort_cols_per_row_impl

/**
* @brief Overload of `sort_cols_per_row` to help the
* compiler find the above overload, in case users pass in
* `std::nullopt` for the optional `sorted_keys` argument.
*
* Please see above for documentation of `sort_cols_per_row`.
*/
template <typename in_t, typename out_t, typename matrix_idx_t, typename sorted_keys_vector_type>
void sort_cols_per_row(const raft::handle_t& handle,
raft::device_matrix_view<const in_t, matrix_idx_t, raft::row_major> in,
raft::device_matrix_view<out_t, matrix_idx_t, raft::row_major> out,
sorted_keys_vector_type sorted_keys)
{
using sorted_keys_type = sort_cols_per_row_impl::sorted_keys_t<
std::remove_const_t<std::remove_reference_t<sorted_keys_vector_type>>>;
std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>> sorted_keys_opt =
std::forward<sorted_keys_vector_type>(sorted_keys);

sort_cols_per_row(handle, in, out, sorted_keys_opt);
}

}; // end namespace raft::matrix

#endif
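
The forwarding overload of `sort_cols_per_row` is needed because template argument deduction does not consider implicit conversions. A parameter of type `std::optional<view<T>>` cannot deduce `T` from `std::nullopt` or from a bare view, so the primary overload is not viable for those calls. The sketch below reproduces the pattern with standard C++ only. `view` and `has_keys` are hypothetical names used for illustration and are not RAFT types or functions.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for a typed, non-owning view (not RAFT's device_matrix_view).
template <typename T>
struct view {
  T* data;
};

// Primary overload: T appears inside std::optional, so it must be deducible
// from the second argument as well as from the first.
template <typename T>
bool has_keys(view<const T> in, std::optional<view<T>> keys)
{
  (void)in;
  return keys.has_value();
}

// Forwarding overload: accepts any second argument (std::nullopt, a bare view),
// converts it to the optional type once T is known from `in`, then calls the
// primary overload. Partial ordering prefers the primary overload when an
// actual std::optional is passed, so this call does not recurse.
template <typename T, typename keys_t>
bool has_keys(view<const T> in, keys_t keys)
{
  std::optional<view<T>> keys_opt = keys;
  return has_keys(in, keys_opt);
}
```

With only the primary overload, a call such as `has_keys(in, std::nullopt)` would fail to compile; with the forwarding overload it resolves and reports that no keys were supplied.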