Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] RAFT distance prims public API update #4280

Merged
2 changes: 1 addition & 1 deletion cpp/bench/prims/distance_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include <raft/cudart_utils.h>
#include <common/ml_benchmark.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>

namespace MLCommon {
namespace Bench {
Expand Down
11 changes: 7 additions & 4 deletions cpp/bench/prims/fused_l2_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
#include <raft/cudart_utils.h>
#include <common/ml_benchmark.hpp>
#include <limits>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/distance/fused_l2_nn.hpp>
#include <raft/handle.hpp>
#include <raft/linalg/norm.cuh>
#include <raft/random/rng.cuh>

Expand All @@ -43,13 +44,15 @@ struct FusedL2NN : public Fixture {
alloc(out, params.m);
alloc(workspace, params.m);
raft::random::Rng r(123456ULL);
raft::handle_t handle;
handle.set_stream(stream);

r.uniform(x, params.m * params.k, T(-1.0), T(1.0), stream);
r.uniform(y, params.n * params.k, T(-1.0), T(1.0), stream);
raft::linalg::rowNorm(xn, x, params.k, params.m, raft::linalg::L2Norm, true, stream);
raft::linalg::rowNorm(yn, y, params.k, params.n, raft::linalg::L2Norm, true, stream);
auto blks = raft::ceildiv(params.m, 256);
raft::distance::initKernel<T, cub::KeyValuePair<int, T>, int>
<<<blks, 256, 0, stream>>>(out, params.m, std::numeric_limits<T>::max(), op);
raft::distance::initialize<T, cub::KeyValuePair<int, T>, int>(
handle, out, params.m, std::numeric_limits<T>::max(), op);
}

void deallocateBuffers(const ::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ set(CUML_BRANCH_VERSION_raft "${CUML_VERSION_MAJOR}.${CUML_VERSION_MINOR}")
find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft}
FORK rapidsai
PINNED_TAG branch-${CUML_BRANCH_VERSION_raft}
)
)
2 changes: 1 addition & 1 deletion cpp/src/hdbscan/detail/reachability.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
#include <rmm/exec_policy.hpp>

#include <cuml/neighbors/knn.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>

#include <thrust/transform.h>

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/kmeans/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

#include <raft/cudart_utils.h>
#include <raft/comms/comms.hpp>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/distance/fused_l2_nn.hpp>
#include <raft/linalg/binary_op.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/mean_squared_error.cuh>
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

#include <raft/sparse/distance/common.h>
#include <cuml/metrics/metrics.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <raft/sparse/distance/distance.cuh>
#include <raft/sparse/distance/distance.hpp>
#include "pairwise_distance_canberra.cuh"
#include "pairwise_distance_chebyshev.cuh"
#include "pairwise_distance_correlation.cuh"
Expand Down
16 changes: 5 additions & 11 deletions cpp/src/metrics/pairwise_distance_canberra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_canberra.cuh"
Expand All @@ -33,12 +33,9 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::distance<raft::distance::DistanceType::Canberra, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

void pairwise_distance_canberra(const raft::handle_t& handle,
Expand All @@ -51,12 +48,9 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::distance<raft::distance::DistanceType::Canberra, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/metrics/pairwise_distance_canberra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#pragma once

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>

namespace ML {
Expand Down
14 changes: 5 additions & 9 deletions cpp/src/metrics/pairwise_distance_chebyshev.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_chebyshev.cuh"
Expand All @@ -32,11 +32,9 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());
// Call the distance function
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::distance<raft::distance::DistanceType::Linf, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

void pairwise_distance_chebyshev(const raft::handle_t& handle,
Expand All @@ -49,11 +47,9 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());
// Call the distance function
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::distance<raft::distance::DistanceType::Linf, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/metrics/pairwise_distance_chebyshev.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/
#pragma once
#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>

namespace ML {
Expand Down
16 changes: 5 additions & 11 deletions cpp/src/metrics/pairwise_distance_correlation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_correlation.cuh"
Expand All @@ -33,13 +33,10 @@ void pairwise_distance_correlation(const raft::handle_t& handle,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::
pairwise_distance_impl<double, int, raft::distance::DistanceType::CorrelationExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::CorrelationExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

void pairwise_distance_correlation(const raft::handle_t& handle,
Expand All @@ -52,13 +49,10 @@ void pairwise_distance_correlation(const raft::handle_t& handle,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::
pairwise_distance_impl<float, int, raft::distance::DistanceType::CorrelationExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::CorrelationExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/metrics/pairwise_distance_correlation.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#pragma once

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>

namespace ML {
Expand Down
18 changes: 7 additions & 11 deletions cpp/src/metrics/pairwise_distance_cosine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_cosine.cuh"
Expand All @@ -33,12 +33,10 @@ void pairwise_distance_cosine(const raft::handle_t& handle,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::CosineExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::
distance<raft::distance::DistanceType::CosineExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

void pairwise_distance_cosine(const raft::handle_t& handle,
Expand All @@ -51,11 +49,9 @@ void pairwise_distance_cosine(const raft::handle_t& handle,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::CosineExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::CosineExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/metrics/pairwise_distance_cosine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
#pragma once

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>

namespace ML {
Expand Down
41 changes: 18 additions & 23 deletions cpp/src/metrics/pairwise_distance_euclidean.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_euclidean.cuh"
Expand All @@ -34,29 +34,27 @@ void pairwise_distance_euclidean(const raft::handle_t& handle,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::L2Expanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::
distance<raft::distance::DistanceType::L2Expanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
pairwise_distance_impl<double, int, raft::distance::DistanceType::L2SqrtExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2Unexpanded:
raft::distance::
pairwise_distance_impl<double, int, raft::distance::DistanceType::L2Unexpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2Unexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtUnexpanded:
raft::distance::
pairwise_distance_impl<double, int, raft::distance::DistanceType::L2SqrtUnexpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
Expand All @@ -73,29 +71,26 @@ void pairwise_distance_euclidean(const raft::handle_t& handle,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::L2Expanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
raft::distance::distance<raft::distance::DistanceType::L2Expanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
pairwise_distance_impl<float, int, raft::distance::DistanceType::L2SqrtExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2Unexpanded:
raft::distance::
pairwise_distance_impl<float, int, raft::distance::DistanceType::L2Unexpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2Unexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtUnexpanded:
raft::distance::
pairwise_distance_impl<float, int, raft::distance::DistanceType::L2SqrtUnexpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
distance<raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/metrics/pairwise_distance_euclidean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/
#pragma once
#include <raft/distance/distance.cuh>
#include <raft/distance/distance.hpp>
#include <raft/handle.hpp>

namespace ML {
Expand Down
Loading