Skip to content

Commit

Permalink
Pulling out argmax for now since the test seems to be failing in centos.
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Sep 30, 2022
1 parent e25caef commit 37ae236
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 165 deletions.
42 changes: 0 additions & 42 deletions cpp/include/raft/matrix/argmax.cuh

This file was deleted.

26 changes: 26 additions & 0 deletions cpp/include/raft/matrix/matrix_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace raft::matrix {

struct print_separators {
char horizontal = ' ';
char vertical = '\n';
};

} // namespace raft::matrix
33 changes: 9 additions & 24 deletions cpp/include/raft/matrix/print.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/core/host_mdspan.hpp>
#include <raft/matrix/detail/matrix.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/matrix/matrix_types.hpp>

namespace raft::matrix {

Expand All @@ -29,34 +30,18 @@ namespace raft::matrix {
* @tparam idx_t integer type used for indexing
* @param[in] handle: raft handle
* @param[in] in: input matrix
* @param[in] h_separator: horizontal separator character
* @param[in] v_separator: vertical separator character
* @param[in] separators: horizontal and vertical separator characters
*/
template <typename m_t, typename idx_t>
void print(const raft::handle_t& handle,
raft::device_matrix_view<const m_t, idx_t, col_major> in,
char h_separator = ' ',
char v_separator = '\n')
print_separators& separators)
{
detail::print(
in.data_handle(), in.extent(0), in.extent(1), h_separator, v_separator, handle.get_stream());
}

/**
* @brief Prints the host data stored in CPU memory
* @tparam m_t type of matrix elements
* @tparam idx_t integer type used for indexing
* @param[in] handle raft handle for managing resources
* @param[in] in input matrix with column-major layout
* @param[in] h_separator: horizontal separator character
* @param[in] v_separator: vertical separator character
*/
template <typename m_t, typename idx_t>
void print(const raft::handle_t& handle,
raft::host_matrix_view<const m_t, idx_t, col_major> in,
char h_separator = ' ',
char v_separator = '\n')
{
detail::printHost(in.data_handle(), in.extent(0), in.extent(1), h_separator, v_separator);
detail::print(in.data_handle(),
in.extent(0),
in.extent(1),
separators.horizontal,
separators.vertical,
handle.get_stream());
}
} // namespace raft::matrix
9 changes: 6 additions & 3 deletions cpp/include/raft/matrix/print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@

#include <raft/core/host_mdspan.hpp>
#include <raft/matrix/detail/print.hpp>
#include <raft/matrix/matrix_types.hpp>

namespace raft::matrix {

/**
* @brief Prints the data stored in CPU memory
* @param in: input matrix with column-major layout
* @param[in] in: input matrix with column-major layout
* @param[in] separators: horizontal and vertical separator characters
*/
template <typename m_t, typename idx_t>
void print(raft::host_matrix_view<const m_t, idx_t, col_major> in)
void print(raft::host_matrix_view<const m_t, idx_t, col_major> in, print_separators& separators)
{
detail::printHost(in.data_handle(), in.extent(0), in.extent(1));
detail::printHost(
in.data_handle(), in.extent(0), in.extent(1), separators.horizontal, separators.vertical);
}
} // namespace raft::matrix
1 change: 0 additions & 1 deletion cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ if(BUILD_TESTS)

ConfigureTest(NAME MATRIX_TEST
PATH
test/matrix/argmax.cu
test/matrix/gather.cu
test/matrix/math.cu
test/matrix/matrix.cu
Expand Down
95 changes: 0 additions & 95 deletions cpp/test/matrix/argmax.cu

This file was deleted.

0 comments on commit 37ae236

Please sign in to comment.