Skip to content

Commit

Permalink
RAFT distance prims public API update (#4280)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #4280
  • Loading branch information
cjnolet authored Oct 22, 2021
1 parent f6ded9e commit ee2e863
Show file tree
Hide file tree
Showing 40 changed files with 111 additions and 179 deletions.
2 changes: 1 addition & 1 deletion cpp/bench/prims/distance_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include <raft/cudart_utils.h>
#include <common/ml_benchmark.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>

namespace MLCommon {
namespace Bench {
Expand Down
11 changes: 7 additions & 4 deletions cpp/bench/prims/fused_l2_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
#include <raft/cudart_utils.h>
#include <common/ml_benchmark.hpp>
#include <limits>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/distance/fused_l2_nn.hpp>
#include <raft/handle.hpp>
#include <raft/linalg/norm.cuh>
#include <raft/random/rng.cuh>

Expand All @@ -43,13 +44,15 @@ struct FusedL2NN : public Fixture {
alloc(out, params.m);
alloc(workspace, params.m);
raft::random::Rng r(123456ULL);
raft::handle_t handle;
handle.set_stream(stream);

r.uniform(x, params.m * params.k, T(-1.0), T(1.0), stream);
r.uniform(y, params.n * params.k, T(-1.0), T(1.0), stream);
raft::linalg::rowNorm(xn, x, params.k, params.m, raft::linalg::L2Norm, true, stream);
raft::linalg::rowNorm(yn, y, params.k, params.n, raft::linalg::L2Norm, true, stream);
auto blks = raft::ceildiv(params.m, 256);
raft::distance::initKernel<T, cub::KeyValuePair<int, T>, int>
<<<blks, 256, 0, stream>>>(out, params.m, std::numeric_limits<T>::max(), op);
raft::distance::initialize<T, cub::KeyValuePair<int, T>, int>(
handle, out, params.m, std::numeric_limits<T>::max(), op);
}

void deallocateBuffers(const ::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ set(CUML_BRANCH_VERSION_raft "${CUML_VERSION_MAJOR}.${CUML_VERSION_MINOR}")
find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft}
FORK rapidsai
PINNED_TAG branch-${CUML_BRANCH_VERSION_raft}
)
)
2 changes: 1 addition & 1 deletion cpp/src/hdbscan/detail/reachability.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
#include <rmm/exec_policy.hpp>

#include <cuml/neighbors/knn.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>

#include <thrust/transform.h>

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/kmeans/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

#include <raft/cudart_utils.h>
#include <raft/comms/comms.hpp>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/distance/fused_l2_nn.hpp>
#include <raft/linalg/binary_op.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/mean_squared_error.cuh>
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

#include <raft/sparse/distance/common.h>
#include <cuml/metrics/metrics.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <raft/sparse/distance/distance.cuh>
#include <raft/sparse/distance/distance.hpp>
#include "pairwise_distance_canberra.cuh"
#include "pairwise_distance_chebyshev.cuh"
#include "pairwise_distance_correlation.cuh"
Expand Down
16 changes: 5 additions & 11 deletions cpp/src/metrics/pairwise_distance_canberra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_canberra.cuh"
Expand All @@ -33,12 +33,9 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::distance<raft::distance::DistanceType::Canberra, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

void pairwise_distance_canberra(const raft::handle_t& handle,
Expand All @@ -51,12 +48,9 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::distance<raft::distance::DistanceType::Canberra, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/metrics/pairwise_distance_canberra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#pragma once

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>

namespace ML {
Expand Down
14 changes: 5 additions & 9 deletions cpp/src/metrics/pairwise_distance_chebyshev.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_chebyshev.cuh"
Expand All @@ -32,11 +32,9 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());
// Call the distance function
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::distance<raft::distance::DistanceType::Linf, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

void pairwise_distance_chebyshev(const raft::handle_t& handle,
Expand All @@ -49,11 +47,9 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());
// Call the distance function
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::distance<raft::distance::DistanceType::Linf, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/metrics/pairwise_distance_chebyshev.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/
#pragma once
#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>

namespace ML {
Expand Down
16 changes: 5 additions & 11 deletions cpp/src/metrics/pairwise_distance_correlation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_correlation.cuh"
Expand All @@ -33,13 +33,10 @@ void pairwise_distance_correlation(const raft::handle_t& handle,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::
pairwise_distance_impl<double, int, raft::distance::DistanceType::CorrelationExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::CorrelationExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

void pairwise_distance_correlation(const raft::handle_t& handle,
Expand All @@ -52,13 +49,10 @@ void pairwise_distance_correlation(const raft::handle_t& handle,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::
pairwise_distance_impl<float, int, raft::distance::DistanceType::CorrelationExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::CorrelationExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/metrics/pairwise_distance_correlation.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#pragma once

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>

namespace ML {
Expand Down
18 changes: 7 additions & 11 deletions cpp/src/metrics/pairwise_distance_cosine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_cosine.cuh"
Expand All @@ -33,12 +33,10 @@ void pairwise_distance_cosine(const raft::handle_t& handle,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::CosineExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::
distance<raft::distance::DistanceType::CosineExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

void pairwise_distance_cosine(const raft::handle_t& handle,
Expand All @@ -51,11 +49,9 @@ void pairwise_distance_cosine(const raft::handle_t& handle,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::CosineExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::CosineExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/metrics/pairwise_distance_cosine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
#pragma once

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>

namespace ML {
Expand Down
41 changes: 18 additions & 23 deletions cpp/src/metrics/pairwise_distance_euclidean.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_euclidean.cuh"
Expand All @@ -34,29 +34,27 @@ void pairwise_distance_euclidean(const raft::handle_t& handle,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::L2Expanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::
distance<raft::distance::DistanceType::L2Expanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
pairwise_distance_impl<double, int, raft::distance::DistanceType::L2SqrtExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2Unexpanded:
raft::distance::
pairwise_distance_impl<double, int, raft::distance::DistanceType::L2Unexpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2Unexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtUnexpanded:
raft::distance::
pairwise_distance_impl<double, int, raft::distance::DistanceType::L2SqrtUnexpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
Expand All @@ -73,29 +71,26 @@ void pairwise_distance_euclidean(const raft::handle_t& handle,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::L2Expanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::distance<raft::distance::DistanceType::L2Expanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
pairwise_distance_impl<float, int, raft::distance::DistanceType::L2SqrtExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2Unexpanded:
raft::distance::
pairwise_distance_impl<float, int, raft::distance::DistanceType::L2Unexpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2Unexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtUnexpanded:
raft::distance::
pairwise_distance_impl<float, int, raft::distance::DistanceType::L2SqrtUnexpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/metrics/pairwise_distance_euclidean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/
#pragma once
#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>

namespace ML {
Expand Down
Loading

0 comments on commit ee2e863

Please sign in to comment.