
add cuda generator #26786

Merged (17 commits) on Sep 4, 2020
57 changes: 10 additions & 47 deletions paddle/fluid/framework/generator.cc
@@ -30,46 +30,28 @@ limitations under the License. */
namespace paddle {
namespace framework {

const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(int64_t device_id) {
#ifdef PADDLE_WITH_CUDA
static int64_t num_cuda_devices = -1;
static std::once_flag num_devices_init_flag;
static std::deque<std::once_flag> cuda_device_flags;

static std::vector<std::shared_ptr<Generator>> default_cuda_generators;
#endif
static int64_t num_cuda_devices = -1;
static std::once_flag num_devices_init_flag;
static std::deque<std::once_flag> cuda_device_flags;
static std::vector<std::shared_ptr<Generator>> default_cuda_generators;

static void InitCUDAGenerators() {
#ifdef PADDLE_WITH_CUDA
num_cuda_devices = paddle::platform::GetCUDADeviceCount();
cuda_device_flags.resize(num_cuda_devices);
default_cuda_generators.resize(num_cuda_devices);
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"InitCUDAGenerators only support in CUDA place"));
#endif
}

const std::shared_ptr<Generator>& getDefaultCUDAGenerator(int64_t device_id) {
#ifdef PADDLE_WITH_CUDA
std::call_once(num_devices_init_flag, InitCUDAGenerators);
std::call_once(num_devices_init_flag, []() {
num_cuda_devices = paddle::platform::GetCUDADeviceCount();
cuda_device_flags.resize(num_cuda_devices);
default_cuda_generators.resize(num_cuda_devices);
});
platform::Place place;
if (device_id == -1)
device_id = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId();
Contributor:
In the case where the originally given device_id is -1, what is its value after this line?
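One way the fallback could be made explicit (a sketch only, not what this revision does; it assumes the plain CUDA runtime call cudaGetDevice and a hypothetical helper name):

#include <cuda_runtime.h>

// Hypothetical helper, not part of this PR: map device_id == -1 to the
// device currently bound to the calling thread.
static int64_t ResolveDeviceId(int64_t device_id) {
  if (device_id != -1) return device_id;
  int current_device = 0;
  cudaGetDevice(&current_device);  // error handling omitted in this sketch
  return static_cast<int64_t>(current_device);
}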

std::cout << "start device id: " << device_id << std::endl;
std::call_once(cuda_device_flags[device_id], [device_id]() {
default_cuda_generators[device_id] =
std::make_shared<Generator>(GetRandomSeed(), device_id);
VLOG(4) << "initial seed: "
<< default_cuda_generators[device_id]->GetCurrentSeed();
std::cout << "initial seed: "
<< default_cuda_generators[device_id]->GetCurrentSeed()
<< "device id : " << device_id << " ||| "
<< default_cuda_generators[device_id]->get_device_id()
<< std::endl;
});
// std::call_once(cuda_device_flags[device_id], initGlobalCUDAGeneratorState,
// device_id);
std::cout << "return device id: " << device_id << std::endl;
return default_cuda_generators[device_id];
#else
PADDLE_THROW(platform::errors::PermissionDenied(
@@ -176,25 +158,6 @@ uint64_t Generator::Random64() {
return (*engine)();
}

std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
uint64_t total_numel, uint64_t grid_size, uint64_t block_size,
uint64_t engine_calls_num) {
uint64_t cur_offset = this->state_.thread_offset;
#ifdef PADDLE_WITH_CUDA
std::lock_guard<std::mutex> lock(this->mu_);
uint64_t numel_per_thread =
(total_numel - 1) / (block_size * grid_size * 4) + 1;
uint64_t increment = numel_per_thread * engine_calls_num;

this->state_.thread_offset += increment;

#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Increment Offset only support in CUDA place"));
#endif
return std::make_pair(this->state_.current_seed, cur_offset);
}
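// Sketch (not part of this diff): how the retained single-argument
// IncrementOffset below is typically consumed on the device side, assuming a
// curand Philox state; kernel and variable names are illustrative only.
//
//   auto seed_offset = gen_cuda->IncrementOffset(1);
//   // inside a __global__ kernel, one state per thread:
//   curandStatePhilox4_32_10_t state;
//   curand_init(seed_offset.first /*seed*/, thread_id /*subsequence*/,
//               seed_offset.second /*offset*/, &state);
//   float u = curand_uniform(&state);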

std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
uint64_t increament_offset) {
uint64_t cur_offset = this->state_.thread_offset;
12 changes: 4 additions & 8 deletions paddle/fluid/framework/generator.h
@@ -38,7 +38,7 @@ static uint64_t GetRandomSeed() {
struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
uint64_t thread_offset;
uint64_t thread_offset = 0;
std::mt19937_64 cpu_engine;
};

@@ -67,7 +67,7 @@ struct Generator {
<< ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = true; // TODO(zhiqiu): remove it in future
}
explicit Generator(uint64_t seed, uint64_t device_id) {
Generator(uint64_t seed, uint64_t device_id) {
std::seed_seq seq({seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
this->state_.cpu_engine = *engine;
@@ -77,7 +77,7 @@ struct Generator {
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = true; // TODO(zhiqiu): remove it in future
this->is_init_py_ = false; // TODO(zhiqiu): remove it in future
Contributor:
Why is is_init_py_ false by default here?

}
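// Usage sketch (illustrative only, not part of this diff): a generator pinned
// to one device, seeded explicitly, then queried the same way generator.cc does.
//   auto gen = std::make_shared<Generator>(/*seed=*/42, /*device_id=*/0);
//   uint64_t seed = gen->GetCurrentSeed();  // 42
//   auto device   = gen->get_device_id();   // 0
//   uint64_t r    = gen->Random64();        // draw from the CPU-side engine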

Generator(const Generator& other) = delete;
@@ -99,10 +99,6 @@ struct Generator {

uint64_t Random64();

std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t total_numel,
uint64_t grid_size,
uint64_t block_size,
uint64_t engine_calls_num);
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t increament_offset);

void SetIsInitPy(bool);
@@ -128,7 +124,7 @@ std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine();

std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);

const std::shared_ptr<Generator>& getDefaultCUDAGenerator(
const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
int64_t device_id = -1);

} // namespace framework
21 changes: 0 additions & 21 deletions paddle/fluid/operators/bernoulli_op.cu
@@ -16,7 +16,6 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/bernoulli_op.h"
@@ -46,7 +45,6 @@ class BernoulliOpKernel<platform::CUDADeviceContext, T>
void Compute(const framework::ExecutionContext& ctx) const override {
std::random_device rd;
auto seed = rd();
bool seed_flag = false;
const auto x = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
auto* in_data = x->data<T>();
@@ -57,25 +55,6 @@ class BernoulliOpKernel<platform::CUDADeviceContext, T>
platform::Transform<platform::CUDADeviceContext> trans;
auto* context =
static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());

int64_t device_id = -1;
auto gen_cuda = framework::getDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
std::cout << ">>>>>>>>CUDA bernoulli GENERATOR" << std::endl;
// auto seed_offset = gen_cuda->IncrementOffset(1);
auto seed_gen = static_cast<int>(gen_cuda->GetCurrentSeed());
// int offset_step = 0;
// NOTE(xuefeng): Currently, we let offset step fixed to avoid
// unexpected results which may cause ut fail.
// we will fix this in future.
// int gen_offset = offset_step * seed_offset.second;
trans(*context, index_sequence_begin, index_sequence_begin + size,
in_data, out_data, BernoulliCudaFunctor<T>(seed_gen));
} else {
trans(*context, index_sequence_begin, index_sequence_begin + size,
in_data, out_data, BernoulliCudaFunctor<T>(seed));
}

trans(*context, index_sequence_begin, index_sequence_begin + size, in_data,
out_data, BernoulliCudaFunctor<T>(seed));
}
6 changes: 2 additions & 4 deletions paddle/fluid/operators/dropout_op.cu
@@ -186,10 +186,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
}

int64_t device_id = -1;
auto gen_cuda = framework::getDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && (context.Attr<bool>("fix_seed"))) {
std::cout << ">>>>>>>>CUDA DROPOUT GENERATOR" << std::endl;
auto gen_cuda = framework::GetDefaultCUDAGenerator(-1);
Contributor:
Why is the device_id -1 here? I think it should be obtained from the context.
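A sketch of how the device id could instead be derived from the kernel's context, reusing the BOOST_GET_CONST(platform::CUDAPlace, ...) pattern already used in generator.cc (illustrative only, not what this revision does):

  auto device_id =
      BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
  auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);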

if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
auto seed_offset = gen_cuda->IncrementOffset(1);
RandomGeneratorWithGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_offset.first, dropout_prob, x_data, mask_data, y_data,
65 changes: 22 additions & 43 deletions paddle/fluid/operators/gaussian_random_op.cu
@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fill_constant_op.h"
// #include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {
@@ -26,34 +25,20 @@ template <typename T>
struct GaussianGenerator {
T mean_, std_;
unsigned int seed_;
unsigned int offset_ = 0;

__host__ __device__ GaussianGenerator(T mean, T std, int seed)
: mean_(mean), std_(std), seed_(seed) {}

__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::normal_distribution<T> dist(mean_, std_);
rng.discard(n);
return dist(rng);
}
};

template <typename T>
struct GaussianGeneratorOffset {
T mean_, std_;
unsigned int seed_;
int offset_;

__host__ __device__ GaussianGeneratorOffset(T mean, T std, int seed,
int offset)
__host__ __device__ GaussianGenerator(T mean, T std, int seed, int offset)
: mean_(mean), std_(std), seed_(seed), offset_(offset) {}

__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::normal_distribution<T> dist(mean_, std_);
rng.discard(n);
unsigned int new_n = n + offset_;
rng.discard(new_n);
return dist(rng);
}
};
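// Illustration (not part of this diff): with the same seed, the extra offset
// shifts where a thread starts in the random stream, so successive launches
// do not reuse the same values.
//   GaussianGenerator<float> g0(0.f, 1.f, /*seed=*/7);                // offset_ == 0
//   GaussianGenerator<float> g1(0.f, 1.f, /*seed=*/7, /*offset=*/64); // starts 64 later
//   // g1(0) equals g0(64): both discard 64 engine positions of the seed-7 stream.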
@@ -64,11 +49,11 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = true;
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = false;
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
@@ -79,23 +64,20 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
T* data = tensor->mutable_data<T>(context.GetPlace());

int64_t size = tensor->numel();
auto gen_cuda = framework::getDefaultCUDAGenerator(-1);
auto gen_cuda = framework::GetDefaultCUDAGenerator(-1);

if (gen_cuda->GetIsInitPy() && seed_flag) {
std::cout << ">>>>>>>>CUDA GAUSSIAN GENERATOR" << std::endl;
// auto seed_offset = gen_cuda->IncrementOffset(1);
auto seed_gen = static_cast<unsigned int>(gen_cuda->GetCurrentSeed());
// int offset_step = 0;
auto seed_offset = gen_cuda->IncrementOffset(1);
int offset_step = 100;
// NOTE(xuefeng): Currently, we let offset step fixed to avoid
// unexpected results which may cause ut fail.
// we will fix this in future.
// int gen_offset = offset_step * seed_offset.second;
// std::cout << ">>>>>offset: " << gen_offset << std::endl;
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGeneratorOffset<T>(mean, std, seed_gen, 0));
int gen_offset = offset_step * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset));
} else {
std::cout << "COUNT ORIGIN" << std::endl;
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
@@ -110,33 +92,30 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = true;
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = false;
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();
auto gen_cuda = framework::getDefaultCUDAGenerator(-1);
auto gen_cuda = framework::GetDefaultCUDAGenerator(-1);

if (gen_cuda->GetIsInitPy() && seed_flag) {
std::cout << ">>>>>>>>CUDA GAUSSIAN GENERATOR" << std::endl;
// auto seed_offset = gen_cuda->IncrementOffset(1);
auto seed_gen = static_cast<unsigned int>(gen_cuda->GetCurrentSeed());
// int offset_step = 0;
auto seed_offset = gen_cuda->IncrementOffset(1);
int offset_step = 100;
// NOTE(xuefeng): Currently, we let offset step fixed to avoid
// unexpected results which may cause ut fail.
// we will fix this in future.
// int gen_offset = offset_step * seed_offset.second;
// std::cout << ">>>>>offset: " << gen_offset << std::endl;
int gen_offset = offset_step * seed_offset.second;
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGeneratorOffset<T>(mean, std, seed_gen, 0));
GaussianGenerator<T>(mean, std, seed_offset.first,
seed_offset.second));
} else {
std::cout << "COUNT ORIGIN" << std::endl;
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
1 change: 0 additions & 1 deletion paddle/fluid/operators/math/sample_prob.cu
@@ -135,7 +135,6 @@ void GPUSampleWithProb<T>::operator()(
Tensor s;
int64_t* s_data = s.mutable_data<int64_t>(s_dim, platform::CPUPlace());

std::cout << "####sample_prob" << std::endl;
math::LogUniformSampler sampler(dict_size, seed);

int range = dict_size;
3 changes: 2 additions & 1 deletion paddle/fluid/operators/randint_op.cu
@@ -50,6 +50,7 @@ class GPURandintKernel : public framework::OpKernel<T> {

int64_t size = out->numel();
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));

/*
std::minstd_rand engine;
if (seed == 0) {
@@ -58,14 +59,14 @@ class GPURandintKernel : public framework::OpKernel<T> {
}
engine.seed(seed);
*/

std::uniform_int_distribution<> dist(context.Attr<int>("low"),
context.Attr<int>("high") - 1);
auto engine = framework::GetCPURandomEngine(seed);

for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
// for (int64_t i = 0; i < size; ++i) data[i] = dist(engine);

if (platform::is_gpu_place(context.GetPlace())) {
// Copy tensor to out
1 change: 0 additions & 1 deletion paddle/fluid/operators/sample_logits_op.cu
@@ -169,7 +169,6 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
probabilities->mutable_data<T>(samples_dim, context.GetPlace());
// UNDERSTAND: sampling
const auto seed = context.Attr<int>("seed");
std::cout << "####SAMPLING" << std::endl;
auto sampler_with_prob = math::GPUSampleWithProb<T>();
sampler_with_prob(context.cuda_device_context(), seed, num_classes, uniq,
num_samples, labels, samples, probabilities);
16 changes: 7 additions & 9 deletions paddle/fluid/operators/truncated_gaussian_random_op.cu
@@ -86,32 +86,30 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
T* data = tensor->mutable_data<T>(context.GetPlace());

unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = true;
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = false;
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();

auto gen_cuda = framework::getDefaultCUDAGenerator(-1);
auto gen_cuda = framework::GetDefaultCUDAGenerator(-1);
if (gen_cuda->GetIsInitPy() && seed_flag) {
std::cout << ">>>>>>>>CUDA TRUNCATED NORMAL GENERATOR" << std::endl;
// auto seed_offset = gen_cuda->IncrementOffset(1);
auto seed_gen = static_cast<int>(gen_cuda->GetCurrentSeed());
// int offset_step = 0;
auto seed_offset = gen_cuda->IncrementOffset(1);
int offset_step = 100;
// NOTE(xuefeng): Currently, we let offset step fixed to avoid
// unexpected results which may cause ut fail.
// we will fix this in future.
// int gen_offset = offset_step * seed_offset.second;
int gen_offset = offset_step * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
seed_gen, 0));
seed_offset.first, seed_offset.second));
}

thrust::transform(