Skip to content

Commit

Permalink
migrate dirichlet op kernel to phi
Browse files Browse the repository at this point in the history
  • Loading branch information
cxxly committed Jul 28, 2022
1 parent 9a3e1bc commit 76cedd7
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 160 deletions.
104 changes: 11 additions & 93 deletions paddle/fluid/operators/dirichlet_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,83 +11,14 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/dirichlet_op.h"

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
template <typename T, typename UniformSamplerT, typename NormalSamplerT>
struct GammaCPUFunctor {
GammaCPUFunctor(const T* alpha,
T* gamma,
BaseSampler<T, UniformSamplerT> uniform,
BaseSampler<T, NormalSamplerT> normal)
: alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {}

HOST void operator()(int64_t index) {
auto sample = sample_gamma<T, T, UniformSamplerT, NormalSamplerT>(
alpha_[index], uniform_, normal_);
gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
}

const T* alpha_;
T* gamma_;
BaseSampler<T, UniformSamplerT> uniform_;
BaseSampler<T, NormalSamplerT> normal_;
};

template <typename T>
struct DirichletSampler<phi::CPUContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const Tensor* alpha,
Tensor* out) {
auto& dev_ctx = ctx.device_context<phi::CPUContext>();

auto p_gen = framework::DefaultCPUGenerator();
auto generator = p_gen->GetCPUEngine();

auto uniform = [&generator]() -> T {
std::uniform_real_distribution<T> u(0.0, 1.0);
return u(*generator);
};
BaseSampler<T, decltype(uniform)> standard_uniform(uniform);

auto normal = [&generator]() {
std::normal_distribution<T> n(0.0, 1.0);
return n(*generator);
};
BaseSampler<T, decltype(normal)> standard_normal(normal);

// sample from K gamma distributions, where K=alpha.numel()
framework::Tensor gamma_samples;
gamma_samples.mutable_data<T>(alpha->dims(), dev_ctx.GetPlace());
GammaCPUFunctor<T, decltype(uniform), decltype(normal)> gamma_functor(
alpha->data<T>(),
gamma_samples.data<T>(),
standard_uniform,
standard_normal);
platform::ForRange<phi::CPUContext> for_range(dev_ctx, alpha->numel());
for_range(gamma_functor);

// normalize them into a simplex, along the last axis
framework::Tensor gamma_sum;
auto new_shape = gamma_samples.dims();
new_shape[new_shape.size() - 1] = 1;
gamma_sum.mutable_data<T>(new_shape, dev_ctx.GetPlace());

ReduceKernelFunctor<phi::CPUContext, T, SumFunctor>(
&gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx)
.template apply<T>();
ElementwiseComputeEx<DivFunctor<T>, phi::CPUContext, T, T>(
ctx, &gamma_samples, &gamma_sum, -1, DivFunctor<T>(), out);
}
};

class DirichletOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
Expand All @@ -100,29 +31,16 @@ class DirichletOpMaker : public framework::OpProtoAndCheckerMaker {
class DirichletOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Alpha"), "Input", "Alpha", "dirichlet");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "dirichlet");
const auto alpha_dim = ctx->GetInputDim("Alpha");
PADDLE_ENFORCE_GE(alpha_dim.size(),
1,
platform::errors::InvalidArgument(
"ShapeError: The number of dimensions of 'Alpha' "
"must be greater than or euqal to 1. "
"But received Alpha's dimensions = %d,",
alpha_dim.size()));
ctx->ShareDim("Alpha", /*->*/ "Out");
}
};

} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(dirichlet,
DirichletInferShapeFunctor,
PD_INFER_META(phi::DirichletInferMeta));

REGISTER_OP_WITHOUT_GRADIENT(dirichlet,
paddle::operators::DirichletOp,
paddle::operators::DirichletOpMaker);
REGISTER_OP_CPU_KERNEL(
dirichlet,
paddle::operators::DirichletKernel<phi::CPUContext, float>,
paddle::operators::DirichletKernel<phi::CPUContext, double>);
paddle::operators::DirichletOpMaker,
DirichletInferShapeFunctor);
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2531,6 +2531,15 @@
kernel:
func: broadcast_tensors
backward: broadcast_tensors_grad

# dirichlet
- api: dirichlet
args: (Tensor alpha)
output: Tensor
infer_meta:
func: DirichletInferMeta
kernel:
func: dirichlet

# eig
- api: eig
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3677,6 +3677,18 @@ void IdentityLossInferMeta(const MetaTensor& x,
}
}

void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) {
const auto alpha_dim = alpha.dims();
PADDLE_ENFORCE_GE(alpha_dim.size(),
1,
phi::errors::InvalidArgument(
"ShapeError: The number of dimensions of 'Alpha' "
"must be greater than or euqal to 1. "
"But received Alpha's dimensions = %d,",
alpha_dim.size()));
out->set_dims(alpha_dim);
}

} // namespace phi

PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -528,4 +528,5 @@ void ChannelShuffleInferMeta(const MetaTensor& x,

void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out);

void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out);
} // namespace phi
108 changes: 108 additions & 0 deletions paddle/phi/kernels/cpu/dirichlet_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/dirichlet.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h"

namespace phi {

template <typename T, typename UniformSamplerT, typename NormalSamplerT>
struct GammaCPUFunctor {
GammaCPUFunctor(const T* alpha,
T* gamma,
funcs::BaseSampler<T, UniformSamplerT> uniform,
funcs::BaseSampler<T, NormalSamplerT> normal)
: alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {}

HOST void operator()(int64_t index) {
auto sample = funcs::sample_gamma<T, T, UniformSamplerT, NormalSamplerT>(
alpha_[index], uniform_, normal_);
gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
}

const T* alpha_;
T* gamma_;
funcs::BaseSampler<T, UniformSamplerT> uniform_;
funcs::BaseSampler<T, NormalSamplerT> normal_;
};

template <typename T>
struct DirichletSampler<CPUContext, T> {
void operator()(const CPUContext& dev_ctx,
const DenseTensor& alpha,
DenseTensor* out) {
auto generator = dev_ctx.GetGenerator()->GetCPUEngine();

auto uniform = [&generator]() -> T {
std::uniform_real_distribution<T> u(0.0, 1.0);
return u(*generator);
};
funcs::BaseSampler<T, decltype(uniform)> standard_uniform(uniform);

auto normal = [&generator]() {
std::normal_distribution<T> n(0.0, 1.0);
return n(*generator);
};
funcs::BaseSampler<T, decltype(normal)> standard_normal(normal);

// sample from K gamma distributions, where K=alpha.numel()
DenseTensor* gamma_samples = new DenseTensor();
gamma_samples->Resize(alpha.dims());
dev_ctx.template Alloc<T>(gamma_samples);

GammaCPUFunctor<T, decltype(uniform), decltype(normal)> gamma_functor(
alpha.data<T>(),
gamma_samples->data<T>(),
standard_uniform,
standard_normal);
funcs::ForRange<CPUContext> for_range(dev_ctx, alpha.numel());
for_range(gamma_functor);

// normalize them into a simplex, along the last axis
DenseTensor* gamma_sum = new DenseTensor();
auto new_shape = gamma_samples->dims();
new_shape[new_shape.size() - 1] = 1;
gamma_sum->Resize(new_shape);
dev_ctx.template Alloc<T>(gamma_sum);

ReduceKernelImpl<CPUContext, T, T, funcs::SumFunctor>(
dev_ctx,
*gamma_samples,
gamma_sum,
{new_shape.size() - 1},
true,
false);

funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T, T>(
dev_ctx,
*gamma_samples,
*gamma_sum,
-1,
funcs::DivideFunctor<T>(),
out);
}
};

} // namespace phi

PD_REGISTER_KERNEL(
dirichlet, CPU, ALL_LAYOUT, phi::Dirichletkernel, float, double) {}
25 changes: 25 additions & 0 deletions paddle/phi/kernels/dirichlet_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void Dirichletkernel(const Context& dev_ctx,
const DenseTensor& alpha,
DenseTensor* out);
} // namespace phi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -16,9 +16,6 @@
#include <cmath>
#include <random>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"

// ROCM hcc doesn't work well with using std:: in kernel functions
#if defined(PADDLE_WITH_CUDA)
#define COMPAT_EXP exp
Expand All @@ -42,10 +39,8 @@
#define COMPAT_LOG1P std::log1p
#endif

namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
struct DirichletSampler;
namespace phi {
namespace funcs {

template <typename ScalarT, typename SamplerT>
struct BaseSampler {
Expand Down Expand Up @@ -116,18 +111,5 @@ sample_gamma(ScalarT alpha,
return static_cast<ScalarT>(scale * d * v);
}
}

template <typename DeviceContext, typename T>
class DirichletKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* alpha = ctx.Input<framework::Tensor>("Alpha");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());

DirichletSampler<DeviceContext, T> sampler;
sampler(ctx, alpha, out);
}
};
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
Loading

0 comments on commit 76cedd7

Please sign in to comment.