diff --git a/src/common/cuda_utils.cc b/src/common/cuda_utils.cc
new file mode 100644
index 000000000000..cbf46508cdbf
--- /dev/null
+++ b/src/common/cuda_utils.cc
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file cuda_utils.cc
+ * \brief Common CUDA utilities.
+ */
+
+#include <mxnet/base.h>
+#include <mshadow/base.h>
+#include "cuda_utils.h"
+
+#if MXNET_USE_CUDA
+
+namespace mxnet {
+namespace common {
+namespace cuda {
+
+int get_load_type(size_t N) {
+  using namespace mshadow;
+  if (N % 8 == 0) {
+    return kFloat64;
+  } else if (N % 4 == 0) {
+    return kFloat32;
+  } else if (N % 2 == 0) {
+    return kFloat16;
+  } else {
+    return kInt8;
+  }
+}
+}  // namespace cuda
+}  // namespace common
+}  // namespace mxnet
+
+#endif  // MXNET_USE_CUDA
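Aside (not part of the patch): get_load_type picks the widest type whose size in bytes
divides the requested row size, so that a row can later be read with aligned vector
loads. The standalone sketch below mirrors that mapping, with local names standing in
for mshadow's kFloat64/kFloat32/kFloat16/kInt8 values:

    #include <cstddef>
    #include <cstdio>

    // Local stand-in for the mshadow type id returned by get_load_type().
    enum LoadType { LOAD_8B, LOAD_4B, LOAD_2B, LOAD_1B };

    LoadType load_type_for(size_t row_bytes) {
      if (row_bytes % 8 == 0) return LOAD_8B;  // 8-byte vector loads
      if (row_bytes % 4 == 0) return LOAD_4B;  // 4-byte loads
      if (row_bytes % 2 == 0) return LOAD_2B;  // 2-byte loads
      return LOAD_1B;                          // fallback: 1 byte at a time
    }

    int main() {
      // A row of 127 fp16 values (254 bytes) is only divisible by 2, so it gets
      // 2-byte loads; a row of 128 values (256 bytes) gets 8-byte loads.
      std::printf("%d %d\n", load_type_for(254), load_type_for(256));
      return 0;
    }
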
diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index acc8d5fac6df..f16607d8b716
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -20,7 +20,7 @@
 /*!
  * Copyright (c) 2015 by Contributors
  * \file cuda_utils.h
- * \brief CUDA debugging utilities.
+ * \brief Common CUDA utilities.
  */
 #ifndef MXNET_COMMON_CUDA_UTILS_H_
 #define MXNET_COMMON_CUDA_UTILS_H_
@@ -326,6 +326,15 @@ class DeviceStore {
   bool restore_;
 };
 
+/*! \brief Get the largest datatype suitable for reading
+ *  the requested number of bytes.
+ *
+ *  \param N Number of bytes to be read
+ *  \return mshadow representation of the type that can
+ *          be used for reading
+ */
+int get_load_type(size_t N);
+
 }  // namespace cuda
 }  // namespace common
 }  // namespace mxnet
@@ -550,7 +559,7 @@ static inline __device__ void atomicAdd(double *address, double val) {
 // Overload atomicAdd for half precision
 // Taken from:
 // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
-#if defined(__CUDA_ARCH__)
+#ifdef __CUDACC__
 static inline __device__ void atomicAdd(mshadow::half::half_t *address,
                                         mshadow::half::half_t val) {
   unsigned int *address_as_ui =
@@ -615,6 +624,28 @@ __device__ inline DType ldg(const DType* address) {
   return *address;
 #endif
 }
-#endif
+
+template <typename OP, typename T>
+__device__ inline T warp_reduce(T value, OP redfun) {
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
+  return value;
+}
+
+template <typename OP>
+__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
+  float v = static_cast<float>(value);
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
+  return mshadow::half::half_t(v);
+}
+
+#endif  // __CUDACC__
 
 #endif  // MXNET_COMMON_CUDA_UTILS_H_
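Aside (illustrative, not part of the patch): with warp_reduce now exported from
common/cuda_utils.h, any CUDA translation unit can do a shuffle-based reduction across
the 32 lanes of a warp. In the hypothetical kernel below, each lane loads one element
and the functor leaves the warp maximum in lane 0; MaxOp, warp_max_kernel and the
include path are names made up for this example:

    #include <cuda_runtime.h>
    #include "../common/cuda_utils.h"  // assumed relative include path

    // Functor handed to warp_reduce(); any associative, commutative op works here.
    struct MaxOp {
      __device__ float operator()(float a, float b) const { return a > b ? a : b; }
    };

    // Launch with one warp (32 threads): in has 32 elements, out gets their max.
    __global__ void warp_max_kernel(const float* in, float* out) {
      float v = in[threadIdx.x];     // one element per lane
      v = warp_reduce(v, MaxOp());   // shuffle-based reduction down to lane 0
      if (threadIdx.x == 0) {
        *out = v;
      }
    }
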
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index ad02ec9da35e..0ff42f4d7d63
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -34,6 +34,7 @@
 #include "../mxnet_op.h"
 #include "../operator_common.h"
 #include "../tensor/broadcast_reduce_op.h"
+#include "../../common/cuda_utils.h"
 
 namespace mxnet {
 namespace op {
@@ -312,27 +313,6 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, IType *length,
 
 const int softmax_threads_per_block = 512;
 
-template <typename OP, typename T>
-__device__ inline T warp_reduce(T value, OP redfun) {
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
-  return value;
-}
-
-template <typename OP>
-__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
-  float v = static_cast<float>(value);
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
-  return mshadow::half::half_t(v);
-}
-
 template <typename OP, bool negate, typename AType, typename LType,
           typename DType, typename OType, typename IType>
 __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, IType *length,
@@ -356,7 +336,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
   // the division by zero warning generated for such invalid cases.
   const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
 
-  const LType * in_aligned = reinterpret_cast<const LType*>(in);
+  const LType* in_aligned = reinterpret_cast<const LType*>(in);
   size_t base = my_row * row_length;
 
   for (index_t i = my_id; i < row_length; i += threads_per_row) {
@@ -420,7 +400,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
   }
   __syncthreads();
 
-  LType * out_aligned = reinterpret_cast<LType*>(out);
+  LType* out_aligned = reinterpret_cast<LType*>(out);
 
   for (index_t i = my_id; i < row_length; i += threads_per_row) {
     out_aligned[base + i] = persistent_storage[my_local_row * row_length + i];
@@ -429,18 +409,6 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
 }
 
 namespace {
-int get_load_type(size_t N) {
-  if (N % 8 == 0) {
-    return kFloat64;
-  } else if (N % 4 == 0) {
-    return kFloat32;
-  } else if (N % 2 == 0) {
-    return kFloat16;
-  } else {
-    return kInt8;
-  }
-}
-
 int get_rows_per_block(size_t N) {
   const int warp_size = 32;
   // How many read instructions should 1 thread at least do
@@ -479,9 +447,9 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
   // Using 20 kB of shared memory for persistent storage in the optimized case
   const size_t max_opt_M = 20 * 1024 / DSize;
   if (stride[axis] == 1 &&
-      M <= max_opt_M &&
+      static_cast<size_t>(M) <= max_opt_M &&
       std::is_same<DType, OType>::value) {
-    int ltype = get_load_type(M * sizeof(DType));
+    int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType));
     MSHADOW_TYPE_SWITCH(ltype, LType, {
       int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType));
      int nblocks = (N + rows_per_block - 1) / rows_per_block;
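Aside (illustrative arithmetic only, not part of the patch): the dispatch above first
checks that a row fits in the 20 kB of per-block shared memory, then sizes the vector
loads from the row's byte count. The sketch below walks through those numbers for one
concrete shape; rows_per_block is a placeholder, since get_rows_per_block's policy is
cut off in this excerpt:

    #include <cstddef>
    #include <cstdio>

    int main() {
      // Example shape: N = 1000 softmax rows of M = 128 float32 elements each.
      const size_t M = 128, N = 1000, DSize = sizeof(float);

      // The optimized path needs a whole row to fit in 20 kB of shared memory:
      // max_opt_M = 20 * 1024 / 4 = 5120 elements, so M = 128 qualifies.
      const size_t max_opt_M = 20 * 1024 / DSize;
      std::printf("optimized path: %s\n", M <= max_opt_M ? "yes" : "no");

      // M * sizeof(DType) = 512 bytes is divisible by 8, so get_load_type()
      // returns kFloat64 and each row is moved with 512 / 8 = 64 vector loads.
      const size_t loads_per_row = M * DSize / sizeof(double);

      // The grid size is a ceiling division over rows; 4 is a placeholder here,
      // the real value comes from get_rows_per_block().
      const size_t rows_per_block = 4;
      const size_t nblocks = (N + rows_per_block - 1) / rows_per_block;
      std::printf("loads per row: %zu, blocks: %zu\n", loads_per_row, nblocks);
      return 0;
    }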