Skip to content

Commit

Permalink
Revert "accelerate dropout (#9902)" (#10082)
Browse files Browse the repository at this point in the history
* Revert "accelerate dropout (#9902)"

This reverts commit 2e331c6.

* Correct discard
  • Loading branch information
reyoung authored and dzhwinter committed Apr 20, 2018
1 parent ad91bfe commit f2e400d
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 44 deletions.
50 changes: 29 additions & 21 deletions paddle/fluid/operators/dropout_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,34 @@ namespace paddle {
namespace operators {

template <typename T>
__global__ void RandomGenerator(const size_t n, const T* src,
const T* cpu_mask_data, T* mask_data, T* dst) {
__global__ void RandomGenerator(const size_t n, const int seed,
const float dropout_prob, const T* src,
T* mask_data, T* dst) {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<float> dist(0, 1);

int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;

T mask;
T dest;
for (; idx < n; idx += blockDim.x * gridDim.x) {
mask_data[idx] = cpu_mask_data[idx];
dst[idx] = mask_data[idx] * src[idx];
T s = src[idx];
if (step_size == 0) {
rng.discard(idx);
step_size = blockDim.x * gridDim.x;
} else {
rng.discard(step_size);
}
if (dist(rng) < dropout_prob) {
mask = static_cast<T>(0);
} else {
mask = static_cast<T>(1);
}
dest = s * mask;
mask_data[idx] = mask;
dst[idx] = dest;
}
}

Expand All @@ -56,27 +78,15 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
std::random_device rnd;
int seed =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
std::minstd_rand engine;
engine.seed(seed);
std::uniform_real_distribution<float> dist(0, 1);
framework::Vector<T> cpu_mask(size);
for (size_t i = 0; i < size; ++i) {
if (dist(engine) < dropout_prob) {
cpu_mask[i] = static_cast<T>(0);
} else {
cpu_mask[i] = static_cast<T>(1);
}
}

int threads = 512;
int grid = (x->numel() + threads - 1) / threads;
RandomGenerator<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, x_data, cpu_mask.CUDAData(context.GetPlace()), mask_data,
y_data);
size, seed, dropout_prob, x_data, mask_data, y_data);
} else {
auto X = EigenVector<T>::Flatten(*x);
auto Y = EigenVector<T>::Flatten(*y);
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
}
Expand All @@ -89,8 +99,6 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, double>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(dropout_grad,
ops::DropoutGradKernel<plat::CUDADeviceContext, double>,
ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
12 changes: 6 additions & 6 deletions paddle/fluid/operators/dropout_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename DeviceContext, typename T>
class CPUDropoutKernel : public framework::OpKernel<T> {
Expand Down Expand Up @@ -60,8 +60,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
}
}
} else {
auto X = EigenVector<T>::Flatten(*x);
auto Y = EigenVector<T>::Flatten(*y);
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Y.device(place) = X * (1.0f - dropout_prob);
Expand All @@ -81,9 +81,9 @@ class DropoutGradKernel : public framework::OpKernel<T> {
auto* mask = context.Input<Tensor>("Mask");
grad_x->mutable_data<T>(context.GetPlace());

auto M = EigenVector<T>::Flatten(*mask);
auto dX = EigenVector<T>::Flatten(*grad_x);
auto dY = EigenVector<T>::Flatten(*grad_y);
auto M = EigenMatrix<T>::Reshape(*mask, 1);
auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);

auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Expand Down
29 changes: 12 additions & 17 deletions paddle/fluid/operators/dropout_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include <unistd.h>
#include <iostream>

#include <string>
#include <thread> // NOLINT
Expand All @@ -33,16 +32,14 @@ namespace m = paddle::operators::math;

USE_OP(dropout);

static paddle::framework::DDim dims = {10, 10};

void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto var = scope->Var("X");
auto tensor = var->GetMutable<f::LoDTensor>();
tensor->Resize(dims);
tensor->Resize({10, 10});

std::vector<float> init;
for (int64_t i = 0; i < f::product(dims); ++i) {
for (int64_t i = 0; i < 10 * 10; ++i) {
init.push_back(1.0);
}

Expand All @@ -51,19 +48,18 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
auto place = ctx.GetPlace();
auto out_var = scope->Var("Out");
auto out_tensor = out_var->GetMutable<f::LoDTensor>();
out_tensor->Resize(dims);
out_tensor->Resize({10, 10});
out_tensor->mutable_data<float>(place); // allocate

auto mask_var = scope->Var("Mask");
auto mask_tensor = mask_var->GetMutable<f::LoDTensor>();
mask_tensor->Resize(dims);
mask_tensor->Resize({10, 10});
mask_tensor->mutable_data<float>(place); // allocate

// run
f::AttributeMap attrs;
float dropout_prob = 0.5;
attrs.insert({"is_test", false});
attrs.insert({"fix_seed", true});
attrs.insert({"fix_seed", 1});
attrs.insert({"seed", 3});
attrs.insert({"dropout_prob", dropout_prob});
auto dropout_op = f::OpRegistry::CreateOp(
Expand All @@ -73,7 +69,6 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) {

std::vector<float> out_vec;
TensorToVector(*out_tensor, ctx, &out_vec);
ctx.Wait();

std::vector<float> std_out = {
0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1,
Expand All @@ -88,22 +83,22 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
}
}

// TODO(wyi): Due to
// https://github.com/PaddlePaddle/Paddle/issues/9507, I temporarily
// disable this test to remove the prevention of the merge of
// unrelated PRs.
/*
TEST(Dropout, CPUDense) {
f::Scope scope;
p::CPUPlace place;
p::CPUDeviceContext ctx(place);
Compare(&scope, ctx);
Compare(scope, ctx);
}
// TODO(wyi, dzhwinter): Due to
// https://github.com/PaddlePaddle/Paddle/issues/9507, I temporarily
// disable this test to remove the prevention of the merge of
// unrelated PRs.
/*
TEST(Dropout, GPUDense) {
f::Scope scope;
p::CUDAPlace place;
p::CUDADeviceContext ctx(place);
Compare(&scope, ctx);
Compare(scope, ctx);
}
*/

0 comments on commit f2e400d

Please sign in to comment.