From c7bdbdb708dd4543c9319091c436aa7972c2bcc3 Mon Sep 17 00:00:00 2001
From: qijun
Date: Thu, 6 Jul 2017 04:17:13 +0000
Subject: [PATCH] follow comments

---
NOTE(review): the line structure of this patch was restored from a copy whose
newlines and indentation had been collapsed. Tokens, hashes and hunk counts are
unchanged. Indentation and the placement of blank context lines were inferred
from the hunk line counts and the surrounding formatting -- in particular,
which blank lines surround `#ifndef PADDLE_ONLY_CPU` in the second hunk of
device_context.h is an assumption; confirm against the original blobs before
applying.

 paddle/platform/device_context.h       | 39 +++++++++-----------------
 paddle/platform/device_context_test.cc | 22 +++++++--------
 2 files changed, 25 insertions(+), 36 deletions(-)

diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
index 94ec3b751576e..fcef0a5e3058f 100644
--- a/paddle/platform/device_context.h
+++ b/paddle/platform/device_context.h
@@ -14,15 +14,14 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/framework/enforce.h"
 #ifndef PADDLE_ONLY_CPU
 #include "paddle/platform/cuda.h"
-#define EIGEN_USE_GPU
-#endif
-
-#include "paddle/framework/enforce.h"
 #include "paddle/platform/dynload/cublas.h"
 #include "paddle/platform/dynload/cudnn.h"
 #include "paddle/platform/dynload/curand.h"
+#define EIGEN_USE_GPU
+#endif
 #include "paddle/platform/place.h"
 #include "unsupported/Eigen/CXX11/Tensor"
 
@@ -34,37 +33,27 @@ class DeviceContext {
   virtual ~DeviceContext() {}
 };
 
-class CpuDeviceContext : public DeviceContext {
-  Eigen::DefaultDevice eigen_device() {
-    if (!eigen_device_) {
-      eigen_device_ = new Eigen::DefaultDevice();
-    }
-    return *eigen_device_;
-  }
-
- private:
-  Eigen::DefaultDevice* eigen_device_{nullptr};
-};
+class CPUDeviceContext : public DeviceContext {};
 
 #ifndef PADDLE_ONLY_CPU
-class DeviceGuard {
+class GPUPlaceGuard {
  public:
-  explicit DeviceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
+  explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
     if (previous_ != new_place) {
       paddle::platform::SetDeviceId(new_place.device);
     }
   }
 
-  ~DeviceGuard() { paddle::platform::SetDeviceId(previous_.device); }
+  ~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }
 
  private:
   GPUPlace previous_;
 };
 
-class CudaDeviceContext : public DeviceContext {
+class CUDADeviceContext : public DeviceContext {
  public:
-  explicit CudaDeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
-    DeviceGuard guard(gpu_place_);
+  explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
+    GPUPlaceGuard guard(gpu_place_);
     paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
                                      "cudaStreamCreate failed");
     eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
@@ -82,7 +71,7 @@ class CudaDeviceContext : public DeviceContext {
 
   cublasHandle_t cublas_handle() {
     if (!blas_handle_) {
-      DeviceGuard guard(gpu_place_);
+      GPUPlaceGuard guard(gpu_place_);
       PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) ==
                          CUBLAS_STATUS_SUCCESS,
                      "cublasCreate failed");
@@ -95,7 +84,7 @@ class CudaDeviceContext : public DeviceContext {
 
   cudnnHandle_t cudnn_handle() {
     if (!dnn_handle_) {
-      DeviceGuard guard(gpu_place_);
+      GPUPlaceGuard guard(gpu_place_);
       PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) ==
                          CUDNN_STATUS_SUCCESS,
                      "cudnnCreate failed");
@@ -108,7 +97,7 @@ class CudaDeviceContext : public DeviceContext {
 
   curandGenerator_t curand_generator() {
     if (!rand_generator_) {
-      DeviceGuard guard(gpu_place_);
+      GPUPlaceGuard guard(gpu_place_);
       PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
                          &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
                          CURAND_STATUS_SUCCESS,
@@ -124,7 +113,7 @@ class CudaDeviceContext : public DeviceContext {
     return rand_generator_;
   }
 
-  ~CudaDeviceContext() {
+  ~CUDADeviceContext() {
     Wait();
     if (blas_handle_) {
       PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) ==
diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc
index 9074206cee9e4..61be4a307dbf0 100644
--- a/paddle/platform/device_context_test.cc
+++ b/paddle/platform/device_context_test.cc
@@ -15,19 +15,19 @@ limitations under the License. */
 #include "paddle/platform/device_context.h"
 #include "gtest/gtest.h"
 
-TEST(DeviceContext, CudaDevice) {
+TEST(CUDADeviceContext, Init) {
   int count = paddle::platform::GetDeviceCount();
   for (int i = 0; i < count; i++) {
-    paddle::platform::CudaDeviceContext* device_context =
-        new paddle::platform::CudaDeviceContext(i);
-    __attribute__((unused)) Eigen::GpuDevice gpu_device =
-        device_context->eigen_device();
-    __attribute__((unused)) cudnnHandle_t cudnn_handle =
-        device_context->cudnn_handle();
-    __attribute__((unused)) cublasHandle_t cublas_handle =
-        device_context->cublas_handle();
-    __attribute__((unused)) curandGenerator_t curand_handle =
-        device_context->curand_generator();
+    paddle::platform::CUDADeviceContext* device_context =
+        new paddle::platform::CUDADeviceContext(i);
+    Eigen::GpuDevice gpu_device = device_context->eigen_device();
+    ASSERT_NE(nullptr, gpu_device.stream());
+    cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
+    ASSERT_NE(nullptr, cudnn_handle);
+    cublasHandle_t cublas_handle = device_context->cublas_handle();
+    ASSERT_NE(nullptr, cublas_handle);
+    curandGenerator_t curand_handle = device_context->curand_generator();
+    ASSERT_NE(nullptr, curand_handle);
     delete device_context;
   }
 }