Skip to content

Commit

Permalink
fix dirichlet sample memory leak
Browse files Browse the repository at this point in the history
  • Loading branch information
cxxly committed Jul 28, 2022
1 parent 76cedd7 commit 51ced14
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 175 deletions.
25 changes: 13 additions & 12 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,19 @@ void DiagonalInferMeta(const MetaTensor& input,
out->set_dims(phi::make_ddim(out_dims));
}

void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) {
const auto alpha_dim = alpha.dims();
PADDLE_ENFORCE_GE(alpha_dim.size(),
1,
phi::errors::InvalidArgument(
"ShapeError: The number of dimensions of 'Alpha' "
"must be greater than or euqal to 1. "
"But received Alpha's dimensions = %d,",
alpha_dim.size()));
out->set_dims(alpha_dim);
out->set_dtype(alpha.dtype());
}

void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) {
auto x_dims = x.dims();
int rank = x_dims.size();
Expand Down Expand Up @@ -3677,18 +3690,6 @@ void IdentityLossInferMeta(const MetaTensor& x,
}
}

void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) {
const auto alpha_dim = alpha.dims();
PADDLE_ENFORCE_GE(alpha_dim.size(),
1,
phi::errors::InvalidArgument(
"ShapeError: The number of dimensions of 'Alpha' "
"must be greater than or euqal to 1. "
"But received Alpha's dimensions = %d,",
alpha_dim.size()));
out->set_dims(alpha_dim);
}

} // namespace phi

PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ void DiagInferMeta(const MetaTensor& x,
void DiagonalInferMeta(
const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out);

void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out);

void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v);

void EighInferMeta(const MetaTensor& x,
Expand Down Expand Up @@ -527,6 +529,4 @@ void ChannelShuffleInferMeta(const MetaTensor& x,
MetaTensor* out);

void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out);

void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out);
} // namespace phi
42 changes: 18 additions & 24 deletions paddle/phi/kernels/cpu/dirichlet_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/dirichlet.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
Expand All @@ -29,20 +28,20 @@ template <typename T, typename UniformSamplerT, typename NormalSamplerT>
struct GammaCPUFunctor {
GammaCPUFunctor(const T* alpha,
T* gamma,
funcs::BaseSampler<T, UniformSamplerT> uniform,
funcs::BaseSampler<T, NormalSamplerT> normal)
BaseSampler<T, UniformSamplerT> uniform,
BaseSampler<T, NormalSamplerT> normal)
: alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {}

HOST void operator()(int64_t index) {
auto sample = funcs::sample_gamma<T, T, UniformSamplerT, NormalSamplerT>(
auto sample = sample_gamma<T, T, UniformSamplerT, NormalSamplerT>(
alpha_[index], uniform_, normal_);
gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
}

const T* alpha_;
T* gamma_;
funcs::BaseSampler<T, UniformSamplerT> uniform_;
funcs::BaseSampler<T, NormalSamplerT> normal_;
BaseSampler<T, UniformSamplerT> uniform_;
BaseSampler<T, NormalSamplerT> normal_;
};

template <typename T>
Expand All @@ -56,49 +55,44 @@ struct DirichletSampler<CPUContext, T> {
std::uniform_real_distribution<T> u(0.0, 1.0);
return u(*generator);
};
funcs::BaseSampler<T, decltype(uniform)> standard_uniform(uniform);
BaseSampler<T, decltype(uniform)> standard_uniform(uniform);

auto normal = [&generator]() {
std::normal_distribution<T> n(0.0, 1.0);
return n(*generator);
};
funcs::BaseSampler<T, decltype(normal)> standard_normal(normal);
BaseSampler<T, decltype(normal)> standard_normal(normal);

// sample from K gamma distributions, where K=alpha.numel()
DenseTensor* gamma_samples = new DenseTensor();
gamma_samples->Resize(alpha.dims());
dev_ctx.template Alloc<T>(gamma_samples);
DenseTensor gamma_samples;
gamma_samples.Resize(alpha.dims());
dev_ctx.template Alloc<T>(&gamma_samples);

GammaCPUFunctor<T, decltype(uniform), decltype(normal)> gamma_functor(
alpha.data<T>(),
gamma_samples->data<T>(),
gamma_samples.data<T>(),
standard_uniform,
standard_normal);
funcs::ForRange<CPUContext> for_range(dev_ctx, alpha.numel());
for_range(gamma_functor);

// normalize them into a simplex, along the last axis
DenseTensor* gamma_sum = new DenseTensor();
auto new_shape = gamma_samples->dims();
DenseTensor gamma_sum;
auto new_shape = gamma_samples.dims();
new_shape[new_shape.size() - 1] = 1;
gamma_sum->Resize(new_shape);
dev_ctx.template Alloc<T>(gamma_sum);
gamma_sum.Resize(new_shape);
dev_ctx.template Alloc<T>(&gamma_sum);

ReduceKernelImpl<CPUContext, T, T, funcs::SumFunctor>(
dev_ctx,
*gamma_samples,
gamma_sum,
gamma_samples,
&gamma_sum,
{new_shape.size() - 1},
true,
false);

funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T, T>(
dev_ctx,
*gamma_samples,
*gamma_sum,
-1,
funcs::DivideFunctor<T>(),
out);
dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor<T>(), out);
}
};

Expand Down
115 changes: 0 additions & 115 deletions paddle/phi/kernels/funcs/dirichlet.h

This file was deleted.

36 changes: 14 additions & 22 deletions paddle/phi/kernels/gpu/dirichlet_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/dirichlet.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
Expand Down Expand Up @@ -57,13 +56,11 @@ struct GammaCUDAFunctor {

// sample
auto uniform_lambda = [&state]() { return COMPAT_RAND_UNIFORM(&state); };
funcs::BaseSampler<T, decltype(uniform_lambda)> standard_uniform(
uniform_lambda);
BaseSampler<T, decltype(uniform_lambda)> standard_uniform(uniform_lambda);
auto normal_lambda = [&state]() { return COMPAT_RAND_NORMAL(&state); };
funcs::BaseSampler<T, decltype(normal_lambda)> standard_normal(
normal_lambda);
BaseSampler<T, decltype(normal_lambda)> standard_normal(normal_lambda);

auto sample = funcs::
auto sample =
sample_gamma<T, T, decltype(uniform_lambda), decltype(normal_lambda)>(
alpha_[index], standard_uniform, standard_normal);
gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
Expand All @@ -86,36 +83,31 @@ struct DirichletSampler<GPUContext, T> {
auto offset = seed_and_offset.second;

// sample from K gamma distributions, where K=alpha.numel()
DenseTensor* gamma_samples = new DenseTensor();
gamma_samples->Resize(alpha.dims());
dev_ctx.template Alloc<T>(gamma_samples);
DenseTensor gamma_samples;
gamma_samples.Resize(alpha.dims());
dev_ctx.template Alloc<T>(&gamma_samples);

GammaCUDAFunctor<T> gamma_functor(
alpha.data<T>(), gamma_samples->data<T>(), seed, offset);
alpha.data<T>(), gamma_samples.data<T>(), seed, offset);
funcs::ForRange<GPUContext> for_range(dev_ctx, out->numel());
for_range(gamma_functor);

// normalize them into a simplex, along the last axis
DenseTensor* gamma_sum = new DenseTensor();
auto new_shape = gamma_samples->dims();
DenseTensor gamma_sum;
auto new_shape = gamma_samples.dims();
new_shape[new_shape.size() - 1] = 1;
gamma_sum->Resize(new_shape);
dev_ctx.template Alloc<T>(gamma_sum);
gamma_sum.Resize(new_shape);
dev_ctx.template Alloc<T>(&gamma_sum);

ReduceKernelImpl<GPUContext, T, T, funcs::SumFunctor>(
dev_ctx,
*gamma_samples,
gamma_sum,
gamma_samples,
&gamma_sum,
{new_shape.size() - 1},
true,
false);
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T, T>(
dev_ctx,
*gamma_samples,
*gamma_sum,
-1,
funcs::DivideFunctor<T>(),
out);
dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor<T>(), out);
}
};
} // namespace phi
Expand Down
Loading

0 comments on commit 51ced14

Please sign in to comment.