diff --git a/cpp/bench/prims/gram_matrix.cu b/cpp/bench/prims/gram_matrix.cu
index 5ce6a21032..c561a875c2 100644
--- a/cpp/bench/prims/gram_matrix.cu
+++ b/cpp/bench/prims/gram_matrix.cu
@@ -58,7 +58,7 @@ struct GramMatrix : public Fixture {
       std::unique_ptr<GramMatrixBase<T>>(KernelFactory<T>::create(p.kernel_params, cublas_handle));
   }
 
-  ~GramMatrix() { CUBLAS_CHECK(cublasDestroy(cublas_handle)); }
+  ~GramMatrix() { CUBLAS_CHECK_NO_THROW(cublasDestroy(cublas_handle)); }
 
  protected:
   void allocateBuffers(const ::benchmark::State& state) override
diff --git a/cpp/cmake/modules/ConfigureCUDA.cmake b/cpp/cmake/modules/ConfigureCUDA.cmake
index 0d43f702d2..8dbab5fbfc 100644
--- a/cpp/cmake/modules/ConfigureCUDA.cmake
+++ b/cpp/cmake/modules/ConfigureCUDA.cmake
@@ -25,8 +25,10 @@ endif()
 list(APPEND CUML_CUDA_FLAGS --expt-extended-lambda --expt-relaxed-constexpr)
 
 # set warnings as errors
-# list(APPEND CUML_CUDA_FLAGS -Werror=cross-execution-space-call)
-# list(APPEND CUML_CUDA_FLAGS -Xcompiler=-Wall,-Werror,-Wno-error=deprecated-declarations)
+if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.2.0)
+  list(APPEND CUML_CUDA_FLAGS -Werror=all-warnings)
+endif()
+list(APPEND CUML_CUDA_FLAGS -Xcompiler=-Wall,-Werror,-Wno-error=deprecated-declarations)
 
 if(DISABLE_DEPRECATION_WARNING)
   list(APPEND CUML_CXX_FLAGS -Wno-deprecated-declarations)
diff --git a/cpp/include/cuml/ensemble/randomforest.hpp b/cpp/include/cuml/ensemble/randomforest.hpp
index 20a95d1c1b..4f82396dfc 100644
--- a/cpp/include/cuml/ensemble/randomforest.hpp
+++ b/cpp/include/cuml/ensemble/randomforest.hpp
@@ -19,8 +19,13 @@
 #include
 #include
 #include
+
 #include
 
+namespace raft {
+class handle_t;  // forward decl
+}
+
 namespace ML {
 
 enum RF_type {
diff --git a/cpp/include/cuml/random_projection/rproj_c.h b/cpp/include/cuml/random_projection/rproj_c.h
index 20a8d2cc62..d4f1702b54 100644
--- a/cpp/include/cuml/random_projection/rproj_c.h
+++ b/cpp/include/cuml/random_projection/rproj_c.h
@@ -17,6 +17,7 @@
 #pragma once
 
 #include
+
 #include
 #include
@@ -95,4 +96,4 @@ void RPROJtransform(const raft::handle_t& handle,
 
 size_t johnson_lindenstrauss_min_dim(size_t n_samples, double eps);
 
-}  // namespace ML
\ No newline at end of file
+}  // namespace ML
diff --git a/cpp/include/cuml/tree/decisiontree.hpp b/cpp/include/cuml/tree/decisiontree.hpp
index 54020c45ec..827859f816 100644
--- a/cpp/include/cuml/tree/decisiontree.hpp
+++ b/cpp/include/cuml/tree/decisiontree.hpp
@@ -15,13 +15,12 @@
  */
 #pragma once
 
-#include
+
 #include "algo_helper.h"
 #include "flatnode.h"
 
-namespace raft {
-class handle_t;
-}
+#include
+#include
 
 namespace ML {
diff --git a/cpp/src/common/allocatorAdapter.hpp b/cpp/src/common/allocatorAdapter.hpp
index f0f41d9d28..ca3f5c2cc5 100644
--- a/cpp/src/common/allocatorAdapter.hpp
+++ b/cpp/src/common/allocatorAdapter.hpp
@@ -16,14 +16,15 @@
 
 #pragma once
 
-#include
-
-#include
-
 #include
 #include
 #include
+#include
+
+#include
+#include
+
 namespace ML {
 
 template <typename T>
diff --git a/cpp/src/common/cumlHandle.hpp b/cpp/src/common/cumlHandle.hpp
index 4120df9e41..4b0a4793fc 100644
--- a/cpp/src/common/cumlHandle.hpp
+++ b/cpp/src/common/cumlHandle.hpp
@@ -17,6 +17,7 @@
 #pragma once
 
 #include
+
 #include
 #include
 #include
diff --git a/cpp/src/common/cuml_api.cpp b/cpp/src/common/cuml_api.cpp
index e5fe9a6646..cca2793bca 100644
--- a/cpp/src/common/cuml_api.cpp
+++ b/cpp/src/common/cuml_api.cpp
@@ -14,17 +14,17 @@
  * limitations under the License.
 */
 
+#include "cumlHandle.hpp"
+
 #include
-#include
-#include
+#include
 #include
-#include
-#include
 #include
 #include
-#include "cumlHandle.hpp"
+#include
+#include
 
 namespace ML {
 namespace detail {
diff --git a/cpp/src/common/tensor.hpp b/cpp/src/common/tensor.hpp
index 8bb4b17221..8578556199 100644
--- a/cpp/src/common/tensor.hpp
+++ b/cpp/src/common/tensor.hpp
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+
 #include
 
 namespace ML {
@@ -171,7 +172,7 @@ class Tensor {
   std::shared_ptr<hostAllocator> _hAllocator;
 
   /// Raw pointer to where the tensor data begins
-  DataPtrT _data;
+  DataPtrT _data{};
 
   /// Array of strides (in sizeof(T) terms) per each dimension
   IndexT _stride[Dim];
@@ -179,9 +180,9 @@ class Tensor {
   /// Size per each dimension
   IndexT _size[Dim];
 
-  AllocState _state;
+  AllocState _state{};
 
-  cudaStream_t _stream;
+  cudaStream_t _stream{};
 };
 
 };  // end namespace ML
diff --git a/cpp/src/dbscan/adjgraph/algo.cuh b/cpp/src/dbscan/adjgraph/algo.cuh
index effafb6c7f..13cbf3eae6 100644
--- a/cpp/src/dbscan/adjgraph/algo.cuh
+++ b/cpp/src/dbscan/adjgraph/algo.cuh
@@ -16,15 +16,17 @@
 
 #pragma once
 
-#include
-#include
-#include
-#include
 #include "../common.cuh"
 #include "pack.h"
 
+#include
+
+#include
 #include
+#include
+#include
+
 using namespace thrust;
 
 namespace ML {
@@ -32,8 +34,6 @@ namespace Dbscan {
 namespace AdjGraph {
 namespace Algo {
 
-using namespace MLCommon;
-
 static const int TPB_X = 256;
 
 /**
@@ -61,4 +61,4 @@ void launcher(const raft::handle_t& handle,
 }  // namespace Algo
 }  // namespace AdjGraph
 }  // namespace Dbscan
-}  // namespace ML
\ No newline at end of file
+}  // namespace ML
diff --git a/cpp/src/dbscan/dbscan.cuh b/cpp/src/dbscan/dbscan.cuh
index 5250536aae..467ecb0839 100644
--- a/cpp/src/dbscan/dbscan.cuh
+++ b/cpp/src/dbscan/dbscan.cuh
@@ -16,13 +16,16 @@
 
 #pragma once
 
+#include "runner.cuh"
+
 #include
+
 #include
 #include
 #include
-#include "runner.cuh"
 
 #include
+#include
 
 namespace ML {
 namespace Dbscan {
@@ -65,7 +68,7 @@ size_t compute_batch_size(size_t& estimated_memory,
 
   // To avoid overflow, we need: batch_size <= MAX_LABEL / n_rows (floor div)
   Index_ MAX_LABEL = std::numeric_limits<Index_>::max();
-  if (batch_size > MAX_LABEL / n_rows) {
+  if (batch_size > static_cast<std::size_t>(MAX_LABEL / n_rows)) {
     Index_ new_batch_size = MAX_LABEL / n_rows;
     CUML_LOG_WARN(
       "Batch size limited by the chosen integer type (%d bytes). %d -> %d. "
@@ -77,7 +80,8 @@ size_t compute_batch_size(size_t& estimated_memory,
   }
 
   // Warn when a smaller index type could be used
-  if (sizeof(Index_) > sizeof(int) && batch_size < std::numeric_limits<int>::max() / n_rows) {
+  if ((sizeof(Index_) > sizeof(int)) &&
+      (batch_size < std::numeric_limits<int>::max() / static_cast<std::size_t>(n_rows))) {
     CUML_LOG_WARN(
       "You are using an index type of size (%d bytes) but a smaller index "
       "type (%d bytes) would be sufficient. Using the smaller integer type "
@@ -110,8 +114,11 @@ void dbscanFitImpl(const raft::handle_t& handle,
   int algo_adj = 1;
   int algo_ccl = 2;
 
-  int my_rank, n_rank;
-  Index_ start_row, n_owned_rows;
+  int my_rank{0};
+  int n_rank{1};
+  Index_ start_row{0};
+  Index_ n_owned_rows{n_rows};
+
   if (opg) {
     const auto& comm = handle.get_comms();
     my_rank = comm.get_rank();
@@ -122,10 +129,6 @@ void dbscanFitImpl(const raft::handle_t& handle,
     n_owned_rows = max(Index_(0), end_row - start_row);
     // Note: it is possible for a node to have no work in theory. It won't
     // happen in practice (because n_rows is much greater than n_rank)
-  } else {
-    my_rank = 0;
-    n_rank = 1;
-    n_owned_rows = n_rows;
   }
 
   CUML_LOG_DEBUG("#%d owns %ld rows", (int)my_rank, (unsigned long)n_owned_rows);
@@ -200,4 +203,4 @@ void dbscanFitImpl(const raft::handle_t& handle,
 }
 
 }  // namespace Dbscan
-}  // namespace ML
\ No newline at end of file
+}  // namespace ML
diff --git a/cpp/src/dbscan/mergelabels/runner.cuh b/cpp/src/dbscan/mergelabels/runner.cuh
index e43ba382d1..506239e04a 100644
--- a/cpp/src/dbscan/mergelabels/runner.cuh
+++ b/cpp/src/dbscan/mergelabels/runner.cuh
@@ -18,6 +18,7 @@
 #include
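// --- Illustrative sketch, not part of the diff above -----------------------
// The compute_batch_size hunks in dbscan.cuh guard two hazards: (a) batch_size
// is a size_t while MAX_LABEL / n_rows is a (possibly smaller, signed) Index_,
// so the comparison needs an explicit cast to avoid a signed/unsigned mismatch
// warning; (b) labels are on the order of batch_size * n_rows and must fit in
// Index_. A minimal standalone version of that guard follows; the function
// name clamp_batch_size is hypothetical, and n_rows > 0 is assumed.

#include <cstddef>
#include <cstdio>
#include <limits>

template <typename Index_>
std::size_t clamp_batch_size(std::size_t batch_size, Index_ n_rows)
{
  // Enforce batch_size <= MAX_LABEL / n_rows (floor division, as in the diff)
  // so the largest computed label cannot overflow Index_.
  const Index_ MAX_LABEL = std::numeric_limits<Index_>::max();
  if (batch_size > static_cast<std::size_t>(MAX_LABEL / n_rows)) {
    batch_size = static_cast<std::size_t>(MAX_LABEL / n_rows);
    std::fprintf(stderr,
                 "batch size limited by the chosen index type (%zu bytes)\n",
                 sizeof(Index_));
  }
  return batch_size;
}

int main()
{
  // With 32-bit labels and 100000 rows, batches cap at
  // 2147483647 / 100000 = 21474 rows, so 50000 is clamped to 21474.
  std::printf("%zu\n", clamp_batch_size<int>(50000, 100000));
  return 0;
}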