Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gram matrix support for sparse input #1296

Merged
merged 34 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
959bb29
gram matrix support for csr
mfoerste4 Dec 8, 2022
36c56b1
Add CSRxDense kernel compute, also add row norm for CSR
mfoerste4 Feb 2, 2023
a99c129
fix RBF for dense with offset
mfoerste4 Feb 3, 2023
60017db
add matrix wrapper to unify kernel API
mfoerste4 Feb 21, 2023
d0a0de3
merge with 23.04
mfoerste4 Feb 21, 2023
9f46742
finalize merge, adjust/add tests
mfoerste4 Feb 21, 2023
c096495
add test and fix rbf
mfoerste4 Feb 22, 2023
c8f3a2d
Merge branch 'branch-23.04' into sparse_kernels
cjnolet Feb 28, 2023
f87e514
Merge branch 'branch-23.04' into sparse_kernels
cjnolet Mar 11, 2023
8174693
review suggestions
mfoerste4 Mar 12, 2023
66e5534
Merge branch 'branch-23.04' into sparse_kernels
cjnolet Mar 13, 2023
3ab7226
Merge branch 'sparse_kernels' of github.com:mfoerste4/raft into spars…
mfoerste4 Mar 13, 2023
5bbcd00
review comments norm
mfoerste4 Mar 14, 2023
86a0314
removed handle member, but re-introduced old API to ensure backwards …
mfoerste4 Mar 14, 2023
fd49e35
Merge branch 'rapidsai:branch-23.04' into sparse_kernels
mfoerste4 Mar 18, 2023
591b77d
changed GramMatrix API to support device_mdspan/device_csr_matrix_vie…
mfoerste4 Mar 20, 2023
0e30226
mere conflict
mfoerste4 Mar 27, 2023
c672226
Merge branch 'rapidsai:branch-23.04' into sparse_kernels
mfoerste4 Mar 29, 2023
2403b2d
utilize public API for spmm, gemm
mfoerste4 Mar 30, 2023
f57be13
refactored rowNormCsr to utilize csr_row_op
mfoerste4 Mar 30, 2023
3f61b64
changed order of arguments according to best practice
mfoerste4 Mar 30, 2023
2b6090a
moved kernel computation to public section
mfoerste4 Mar 30, 2023
c34f242
Merge branch 'branch-23.04' into sparse_kernels
mfoerste4 Apr 5, 2023
563032c
removed outdated docstring
mfoerste4 Apr 5, 2023
23e308d
fix row-major algorithm selection for cusparse spmm
mfoerste4 Apr 5, 2023
a5ee783
fixed doc build
mfoerste4 Apr 5, 2023
d7001ff
Merge branch 'branch-23.04' into sparse_kernels
mfoerste4 Apr 6, 2023
d7be021
Merge branch 'branch-23.06' into sparse_kernels
cjnolet Apr 12, 2023
d7d2f5b
reverted changeset 2b6090a860e6fe36c6c63beb50939bceca13d6f2
mfoerste4 Apr 18, 2023
8d851b3
Merge branch 'branch-23.06' into sparse_kernels
mfoerste4 Apr 18, 2023
089612c
Merge branch 'rapidsai:branch-23.06' into sparse_kernels
mfoerste4 Apr 20, 2023
f2ebd76
merge API conflicts with recent updates to sparse structures
mfoerste4 Apr 20, 2023
b22d7fe
Merge branch 'branch-23.06' into sparse_kernels
mfoerste4 Apr 24, 2023
ae8fbb5
Fixing build
cjnolet Apr 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/bench/prims/distance/fused_l2_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <common/benchmark.hpp>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/util/cudart_utils.hpp>
#if defined RAFT_COMPILED
#include <raft/distance/specializations.cuh>
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/core/detail/nvtx.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
30 changes: 30 additions & 0 deletions cpp/include/raft/core/device_mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,36 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col
return device_matrix_view<ElementType, IndexType, LayoutPolicy>{ptr, extents};
}

/**
* @brief Create a 2-dim mdspan instance for device pointer with a strided layout
* that is restricted to stride 1 in the trailing dimension. It's
* expected that the given layout policy match the layout of the underlying
* pointer.
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param[in] ptr on device to wrap
* @param[in] n_rows number of rows in pointer
* @param[in] n_cols number of columns in pointer
* @param[in] stride leading dimension / stride of data
*/
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_strided_matrix_view(ElementType* ptr,
IndexType n_rows,
IndexType n_cols,
IndexType stride)
{
constexpr auto is_row_major = std::is_same_v<LayoutPolicy, layout_c_contiguous>;
IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1;
IndexType stride1 = is_row_major ? 1 : (stride > 0 ? stride : n_rows);

assert(is_row_major ? stride0 >= n_cols : stride1 >= n_rows);
matrix_extent<IndexType> extents{n_rows, n_cols};

auto layout = make_strided_layout(extents, std::array<IndexType, 2>{stride0, stride1});
return device_matrix_view<ElementType, IndexType, layout_stride>{ptr, layout};
}

/**
* @brief Create a 1-dim mdspan instance for device pointer.
* @tparam ElementType the data type of the vector elements
Expand Down
Loading