diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh deleted file mode 100644 index b7423b9ea4..0000000000 --- a/cpp/include/raft/matrix/argmax.cuh +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace raft::matrix { - -/** - * @brief Argmax: find the row idx with maximum value for each column - * @tparam math_t matrix element type - * @tparam idx_t integer type for matrix and vector indexing - * @param[in] handle: raft handle - * @param[in] in: input matrix of size (n_rows, n_cols) - * @param[out] out: output vector of size n_cols - */ -template -void argmax(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_vector_view out) -{ - RAFT_EXPECTS(out.extent(0) == in.extent(0), - "Size of output vector must equal number of rows in input matrix."); - detail::argmax( - in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream()); -} -} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/matrix_types.hpp b/cpp/include/raft/matrix/matrix_types.hpp new file mode 100644 index 0000000000..1f22154627 --- /dev/null +++ b/cpp/include/raft/matrix/matrix_types.hpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::matrix { + +struct print_separators { + char horizontal = ' '; + char vertical = '\n'; +}; + +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/print.cuh b/cpp/include/raft/matrix/print.cuh index def9fc9182..4d3a8ca938 100644 --- a/cpp/include/raft/matrix/print.cuh +++ b/cpp/include/raft/matrix/print.cuh @@ -20,6 +20,7 @@ #include #include #include +#include namespace raft::matrix { @@ -29,34 +30,18 @@ namespace raft::matrix { * @tparam idx_t integer type used for indexing * @param[in] handle: raft handle * @param[in] in: input matrix - * @param[in] h_separator: horizontal separator character - * @param[in] v_separator: vertical separator character + * @param[in] separators: horizontal and vertical separator characters */ template void print(const raft::handle_t& handle, raft::device_matrix_view in, - char h_separator = ' ', - char v_separator = '\n') + print_separators& separators) { - detail::print( - in.data_handle(), in.extent(0), in.extent(1), h_separator, v_separator, handle.get_stream()); -} - -/** - * @brief Prints the host data stored in CPU memory - * @tparam m_t type of matrix elements - * @tparam idx_t integer type used for indexing - * @param[in] handle raft handle for managing resources - * @param[in] in input matrix with column-major layout - * @param[in] h_separator: horizontal separator character - * @param[in] v_separator: vertical separator character - */ -template -void print(const raft::handle_t& handle, - raft::host_matrix_view in, - char h_separator = ' ', - char v_separator = '\n') -{ - detail::printHost(in.data_handle(), in.extent(0), in.extent(1), h_separator, v_separator); + detail::print(in.data_handle(), + in.extent(0), + in.extent(1), + separators.horizontal, + separators.vertical, + handle.get_stream()); } } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/print.hpp b/cpp/include/raft/matrix/print.hpp index 66e939be0f..86c314ed44 100644 --- a/cpp/include/raft/matrix/print.hpp +++ b/cpp/include/raft/matrix/print.hpp @@ -18,16 +18,19 @@ #include #include +#include namespace raft::matrix { /** * @brief Prints the data stored in CPU memory - * @param in: input matrix with column-major layout + * @param[in] in: input matrix with column-major layout + * @param[in] separators: horizontal and vertical separator characters */ template -void print(raft::host_matrix_view in) +void print(raft::host_matrix_view in, print_separators& separators) { - detail::printHost(in.data_handle(), in.extent(0), in.extent(1)); + detail::printHost( + in.data_handle(), in.extent(0), in.extent(1), separators.horizontal, separators.vertical); } } // namespace raft::matrix diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index fe2504606b..a18a750e4b 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -157,7 +157,6 @@ if(BUILD_TESTS) ConfigureTest(NAME MATRIX_TEST PATH - test/matrix/argmax.cu test/matrix/gather.cu test/matrix/math.cu test/matrix/matrix.cu diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu deleted file mode 100644 index 70884af4de..0000000000 --- a/cpp/test/matrix/argmax.cu +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.h" -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace matrix { - -template -struct ArgMaxInputs { - std::vector input_matrix; - std::vector output_matrix; - std::size_t n_cols; - std::size_t n_rows; -}; - -template -::std::ostream& operator<<(::std::ostream& os, const ArgMaxInputs& dims) -{ - return os; -} - -template -class ArgMaxTest : public ::testing::TestWithParam> { - public: - ArgMaxTest() : params(::testing::TestWithParam>::GetParam()) {} - - void test() - { - auto input = raft::make_device_matrix(handle, params.n_rows, params.n_cols); - auto output = raft::make_device_vector(handle, params.n_rows); - auto expected = raft::make_device_vector(handle, params.n_rows); - - raft::update_device(input.data_handle(), - params.input_matrix.data(), - params.n_rows * params.n_cols, - handle.get_stream()); - raft::update_device( - expected.data_handle(), params.output_matrix.data(), params.n_rows, handle.get_stream()); - - auto input_view = raft::make_device_matrix_view( - input.data_handle(), params.n_rows, params.n_cols); - - raft::matrix::argmax(handle, input_view, output.view()); - - handle.sync_stream(); - - ASSERT_TRUE(devArrMatch(output.data_handle(), - expected.data_handle(), - params.n_rows, - Compare(), - handle.get_stream())); - } - - protected: - raft::handle_t handle; - ArgMaxInputs params; -}; - -const std::vector> inputsf = { - {{0.1f, 0.2f, 0.3f, 0.4f, 0.4f, 0.3f, 0.2f, 0.1f, 0.2f, 0.3f, 0.5f, 0.0f}, {3, 0, 2}, 3, 4}}; - -const std::vector> inputsd = { - {{0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1, 0.2, 0.3, 0.5, 0.0}, {3, 0, 2}, 3, 4}}; - -typedef ArgMaxTest ArgMaxTestF; -TEST_P(ArgMaxTestF, Result) { test(); } - -typedef ArgMaxTest ArgMaxTestD; -TEST_P(ArgMaxTestD, Result) { test(); } - -INSTANTIATE_TEST_CASE_P(ArgMaxTest, ArgMaxTestF, ::testing::ValuesIn(inputsf)); -INSTANTIATE_TEST_CASE_P(ArgMaxTest, ArgMaxTestD, ::testing::ValuesIn(inputsd)); - -} // namespace matrix -} // namespace raft