Skip to content

Commit

Permalink
follow comments
Browse files Browse the repository at this point in the history
  • Loading branch information
QiJune committed Jul 6, 2017
1 parent 0c13b23 commit c7bdbdb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 36 deletions.
39 changes: 14 additions & 25 deletions paddle/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ limitations under the License. */

#pragma once

#include "paddle/framework/enforce.h"
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/cuda.h"
#define EIGEN_USE_GPU
#endif

#include "paddle/framework/enforce.h"
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#define EIGEN_USE_GPU
#endif
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"

Expand All @@ -34,37 +33,27 @@ class DeviceContext {
virtual ~DeviceContext() {}
};

// CPU device context that owns an Eigen::DefaultDevice.
class CpuDeviceContext : public DeviceContext {
 public:
  // Returns a copy of this context's Eigen::DefaultDevice.
  // The accessor is public so that code outside the class can call it; the
  // device is held by value, so there is no heap allocation to release.
  Eigen::DefaultDevice eigen_device() { return eigen_device_; }

 private:
  Eigen::DefaultDevice eigen_device_;
};
// CPU flavour of DeviceContext; declares no members of its own.
class CPUDeviceContext : public DeviceContext {
};

#ifndef PADDLE_ONLY_CPU
// Scoped device switch. The constructor records the device reported by
// GetCurrentDeviceId() and, if it differs from new_place, calls
// SetDeviceId(new_place.device); the destructor calls SetDeviceId() with the
// recorded device.
// NOTE(review): the DeviceGuard / GPUPlaceGuard line pairs below look like the
// before/after sides of a rename in this diff view — confirm against the
// actual header.
class DeviceGuard {
class GPUPlaceGuard {
public:
// Remember the current device first, then switch only when it is a different
// place.
explicit DeviceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
if (previous_ != new_place) {
paddle::platform::SetDeviceId(new_place.device);
}
}

// Always re-selects the remembered device, even if the constructor did not
// switch.
~DeviceGuard() { paddle::platform::SetDeviceId(previous_.device); }
~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }

private:
// Place that was current when the guard was constructed.
GPUPlace previous_;
};

class CudaDeviceContext : public DeviceContext {
class CUDADeviceContext : public DeviceContext {
public:
explicit CudaDeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
DeviceGuard guard(gpu_place_);
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_);
paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
"cudaStreamCreate failed");
eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
Expand All @@ -82,7 +71,7 @@ class CudaDeviceContext : public DeviceContext {

cublasHandle_t cublas_handle() {
if (!blas_handle_) {
DeviceGuard guard(gpu_place_);
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) ==
CUBLAS_STATUS_SUCCESS,
"cublasCreate failed");
Expand All @@ -95,7 +84,7 @@ class CudaDeviceContext : public DeviceContext {

cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) {
DeviceGuard guard(gpu_place_);
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) ==
CUDNN_STATUS_SUCCESS,
"cudnnCreate failed");
Expand All @@ -108,7 +97,7 @@ class CudaDeviceContext : public DeviceContext {

curandGenerator_t curand_generator() {
if (!rand_generator_) {
DeviceGuard guard(gpu_place_);
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
CURAND_STATUS_SUCCESS,
Expand All @@ -124,7 +113,7 @@ class CudaDeviceContext : public DeviceContext {
return rand_generator_;
}

~CudaDeviceContext() {
~CUDADeviceContext() {
Wait();
if (blas_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) ==
Expand Down
22 changes: 11 additions & 11 deletions paddle/platform/device_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ limitations under the License. */
#include "paddle/platform/device_context.h"
#include "gtest/gtest.h"

// For every device index in [0, GetDeviceCount()), constructs a CUDA device
// context, fetches its Eigen device, cuDNN handle, cuBLAS handle and cuRAND
// generator, then deletes the context.
// NOTE(review): the CudaDeviceContext / CUDADeviceContext line groups below
// look like the before/after sides of this diff — confirm against the actual
// test file.
TEST(DeviceContext, CudaDevice) {
TEST(CUDADeviceContext, Init) {
int count = paddle::platform::GetDeviceCount();
for (int i = 0; i < count; i++) {
// Variant that only fetches each handle; the results are marked unused.
paddle::platform::CudaDeviceContext* device_context =
new paddle::platform::CudaDeviceContext(i);
__attribute__((unused)) Eigen::GpuDevice gpu_device =
device_context->eigen_device();
__attribute__((unused)) cudnnHandle_t cudnn_handle =
device_context->cudnn_handle();
__attribute__((unused)) cublasHandle_t cublas_handle =
device_context->cublas_handle();
__attribute__((unused)) curandGenerator_t curand_handle =
device_context->curand_generator();
// Variant that additionally asserts each returned stream/handle is non-null.
// NOTE(review): a failing ASSERT_NE presumably leaves the test body before
// `delete device_context` runs — verify whether that leak matters here.
paddle::platform::CUDADeviceContext* device_context =
new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice gpu_device = device_context->eigen_device();
ASSERT_NE(nullptr, gpu_device.stream());
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
curandGenerator_t curand_handle = device_context->curand_generator();
ASSERT_NE(nullptr, curand_handle);
delete device_context;
}
}

0 comments on commit c7bdbdb

Please sign in to comment.