
Commit

merged upstream develop and resolved conflict
lfchener committed Aug 21, 2020
2 parents 013ab2e + 4f25935 · commit 3eb7fa1
Showing 194 changed files with 13,120 additions and 1,468 deletions.
4 changes: 4 additions & 0 deletions cmake/cuda.cmake
@@ -61,6 +61,10 @@ function(detect_installed_gpus out_variable)
  if(NOT CUDA_gpu_detect_output)
    message(STATUS "Automatic GPU detection failed. Building for all known architectures.")
    set(${out_variable} ${paddle_known_gpu_archs} PARENT_SCOPE)
+    # TODO: fix automatic GPU detection failure on Windows
+    if(WIN32)
+      set(${out_variable} "61 75" PARENT_SCOPE)
+    endif()
  else()
    set(${out_variable} ${CUDA_gpu_detect_output} PARENT_SCOPE)
  endif()
180 changes: 123 additions & 57 deletions paddle/fluid/framework/op_desc.cc

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions paddle/fluid/inference/tensorrt/op_teller.cc
@@ -24,6 +24,8 @@ struct SimpleOpTypeSetTeller : public Teller {
#if IS_TRT_VERSION_GE(5130)
    teller_set.insert("relu6");
    teller_set.insert("hard_sigmoid");
+    int8_teller_set.insert("relu6");
+    int8_teller_set.insert("hard_sigmoid");
#endif
#if IS_TRT_VERSION_GE(6000)
    teller_set.insert("fused_embedding_eltwise_layernorm");
@@ -53,11 +55,11 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_add",
"leaky_relu",
"fc",
"relu6",
"concat",
"scale",
"elementwise_mul",
"conv2d_transpose"};
"conv2d_transpose",
"hard_swish"};
std::unordered_set<std::string> teller_set{
"mul",
"conv2d",
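As an editorial aside, not part of the diff: the teller pattern above amounts to a set-membership test on the op type, with int8_teller_set consulted when the engine runs in INT8 mode. A minimal sketch of that assumed logic (names hypothetical):

#include <string>
#include <unordered_set>

// Minimal sketch of the set-based teller: an op converts to TensorRT only if
// its type name is registered; the int8 switch here is an assumption.
struct SimpleTellerSketch {
  std::unordered_set<std::string> int8_teller_set{"mul", "conv2d", "hard_swish"};
  std::unordered_set<std::string> teller_set{"mul", "conv2d", "hard_swish"};

  bool operator()(const std::string& op_type, bool use_int8) const {
    const auto& set = use_int8 ? int8_teller_set : teller_set;
    return set.count(op_type) > 0;
  }
};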
38 changes: 30 additions & 8 deletions paddle/fluid/operators/activation_op.cc
@@ -317,13 +317,6 @@ The OP square each elements of the inputs.
)DOC";

-UNUSED constexpr char SoftplusDoc[] = R"DOC(
-Softplus Activation Operator.
-
-$$out = \ln(1 + e^{x})$$
-
-)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.
@@ -396,6 +389,36 @@ LeakyRelu Activation Operator.
  }
};

+class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "Input of Softplus operator, an N-D Tensor, with data type "
+             "float32, float64 or float16.");
+    AddOutput(
+        "Out",
+        "Output of Softplus operator, a Tensor with shape same as input.");
+    AddAttr<float>("beta", "The value of beta for Softplus.").SetDefault(1.0f);
+    AddAttr<float>("threshold", "The value of threshold for Softplus.")
+        .SetDefault(20.0f);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel.")
+        .SetDefault(false);
+    AddAttr<bool>(
+        "use_cudnn",
+        "(bool, default false) Only used in cudnn kernel, need install cudnn.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+:strong:`Softplus Activation Operator`
+
+..  math::
+    out = \frac{1}{\beta} * \log(1 + \exp(\beta * x))
+
+For numerical stability, the implementation reverts to the linear function
+when :math:`x \times \beta > threshold`.
+)DOC");
+  }
+};
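An editorial check, not part of the diff: the default threshold of 20 is conservative, because above it the smooth form and the linear fallback agree to well below single precision:

\operatorname{softplus}(x) = x + \frac{1}{\beta}\log\left(1 + e^{-\beta x}\right)
                           < x + \frac{e^{-t}}{\beta}
\quad \text{for } \beta x > t,
\qquad e^{-20} \approx 2.1 \times 10^{-9}.

Near x = 20 that gap is below one float32 ulp (about 1.9e-6), so returning x is exact at single precision while avoiding overflow in exp().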

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
@@ -672,7 +695,6 @@ REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Log1p, Log1pDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
-REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

template <ActBwdOpFwdDeps kDepValue>
52 changes: 33 additions & 19 deletions paddle/fluid/operators/activation_op.h
@@ -388,9 +388,9 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
-    auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>();
-    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>();
-    out.device(d) = x * (temp1 + temp2);
+    auto temp1 = x < static_cast<T>(threshold * -1.f);
+    auto temp2 = x > static_cast<T>(threshold);
+    out.device(d) = x * (temp1 + temp2 > 0).template cast<T>();
  }
};

@@ -405,9 +405,9 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>();
-    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>();
-    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
+    auto temp1 = x < static_cast<T>(threshold * -1.f);
+    auto temp2 = x > static_cast<T>(threshold);
+    dx.device(d) = dout * (temp1 + temp2 > 0).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
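As a reading aid, not part of the diff, the functor pair above computes the following element-wise scalar function; a minimal C++ sketch:

// Scalar reference for HardShrink: pass x through outside [-threshold,
// threshold], zero inside; the gradient masks dout over the same support.
float hard_shrink(float x, float threshold) {
  return (x < -threshold || x > threshold) ? x : 0.0f;
}

float hard_shrink_grad(float x, float dout, float threshold) {
  return (x < -threshold || x > threshold) ? dout : 0.0f;
}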
@@ -975,32 +975,46 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

-// softplus(x) = log(1 + exp(x))
-// When x is a very large positive number, exp(x) may explode to inf,
-// Using trick below for numerical stability
-// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
-// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
+// For numerical stability, use the following formula instead of
+// softplus(x) = log(1 + exp(x)):
+// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold
+// (beta = 1, threshold = 20 by default), otherwise x.
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
+  float beta;
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) {
-    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
-    out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
+    auto x_beta = static_cast<T>(beta) * x;
+    out.device(d) = (x_beta > static_cast<T>(threshold))
+                        .select(x, (static_cast<T>(1) + x_beta.exp()).log() /
+                                       static_cast<T>(beta));
  }
};

-// d(softplus(x))/dx = exp(x) / (1 + exp(x))
-// For numerical stability:
-// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) +
-//                     exp(x - max(x, 0)))
+// For numerical stability, use the following formula instead of
+// d(softplus(x))/dx = 1 / (1 + exp(-x)):
+// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold
+// (beta = 1, threshold = 20 by default), otherwise 1.
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
+  float beta;
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) {
-    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
-    dx.device(d) =
-        dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
+    auto x_beta = static_cast<T>(beta) * x;
+    dx.device(d) =
+        (x_beta > static_cast<T>(threshold))
+            .select(dout, dout / (static_cast<T>(1) + (-x_beta).exp()));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
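An editorial scalar reference for the new Softplus semantics, not part of the diff (log1p is used here as a slightly more stable equivalent of the functor's log(1 + exp(.))):

#include <cmath>
#include <iostream>

// Smooth form below the threshold, the identity above it.
float softplus_ref(float x, float beta = 1.0f, float threshold = 20.0f) {
  float x_beta = beta * x;
  return x_beta > threshold ? x : std::log1p(std::exp(x_beta)) / beta;
}

// Gradient: the logistic sigmoid of beta * x below the threshold, 1 above it.
float softplus_grad_ref(float x, float dout, float beta = 1.0f,
                        float threshold = 20.0f) {
  float x_beta = beta * x;
  return x_beta > threshold ? dout : dout / (1.0f + std::exp(-x_beta));
}

int main() {
  std::cout << softplus_ref(0.0f) << "\n";    // log(2) ~ 0.693147
  std::cout << softplus_ref(100.0f) << "\n";  // 100 (exp(100) would overflow)
}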
51 changes: 22 additions & 29 deletions paddle/fluid/operators/arg_max_op.cu
@@ -1,29 +1,22 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/arg_min_max_op_base.h"
-
-REGISTER_OP_CUDA_KERNEL(
-    arg_max,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    double>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    int64_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    int32_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    int16_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    uint8_t>);
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
+
+REGISTER_OP_CUDA_KERNEL(
+    arg_max, paddle::operators::ArgMinMaxOpCUDAKernel<float, cub::ArgMax>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<double, cub::ArgMax>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<int64_t, cub::ArgMax>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<int32_t, cub::ArgMax>,
+    paddle::operators::ArgMinMaxOpCUDAKernel<int8_t, cub::ArgMax>);
163 changes: 163 additions & 0 deletions paddle/fluid/operators/arg_min_max_op_base.cu.h
@@ -0,0 +1,163 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef __NVCC__

#include <cub/cub.cuh>
#include <limits>
#include <string>
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

namespace { // NOLINT
template <typename K, typename V>
using KeyValuePair = cub::KeyValuePair<K, V>;
using Tensor = framework::Tensor;

} // end namespace

#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
  } break

#define FIXED_BLOCK_DIM_CASE(...)               \
  FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);

template <typename T, typename IndType, class Reducer, size_t BlockDim>
__global__ void ArgCUDAKernel(const IndType height,     // n * h
                              const IndType width,      // c
                              const IndType post_size,  // h
                              const Reducer reducer, const T init, const T* in,
                              IndType* out) {
  typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
    KeyValuePair<int, T> kv_pair = {-1, init};
    int h = idx / post_size;
    int w = idx % post_size;
    for (int k = threadIdx.x; k < width; k += blockDim.x) {
      kv_pair =
          reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
    }
    kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
    if (threadIdx.x == 0) {
      out[idx] = static_cast<IndType>(kv_pair.key);
    }
    __syncthreads();
  }
}
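An editorial CPU reference, not part of this header, may help when reading the kernel: each block reduces one output element (h, w) over the axis of length width using cub::BlockReduce, with the same indexing in[h * width * post_size + k * post_size + w].

#include <cstdint>

// CPU reference (hypothetical helper) mirroring the kernel's layout:
// in[(h * n + k) * post + w] scans the reduced axis k for each output (h, w).
void arg_max_ref(const float* in, int64_t* out, int64_t pre, int64_t n,
                 int64_t post) {
  for (int64_t h = 0; h < pre; ++h) {
    for (int64_t w = 0; w < post; ++w) {
      int64_t best = 0;
      for (int64_t k = 1; k < n; ++k) {
        if (in[(h * n + k) * post + w] > in[(h * n + best) * post + w]) {
          best = k;
        }
      }
      out[h * post + w] = best;
    }
  }
}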

template <typename T, typename IndType, class Reducer>
void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
                    Tensor* indices, const IndType pre, const IndType post,
                    const IndType n) {
  auto cu_stream = ctx.stream();
  auto ComputeBlockSize = [](IndType col) {
    if (col > 512)
      return 1024;
    else if (col > 256)
      return 512;
    else if (col > 128)
      return 256;
    else if (col > 64)
      return 128;
    else if (col > 32)
      return 64;
    else if (col > 16)
      return 32;
    else if (col > 8)
      return 16;
    else
      return 8;
  };

  int max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;
  int height = pre * post;
  int width = n;
  int grid_size = height < max_grid_dimx ? height : max_grid_dimx;

  const T* in_data = input.data<T>();
  IndType* out_data = indices->mutable_data<IndType>(ctx.GetPlace());

  if (typeid(Reducer) == typeid(cub::ArgMax)) {
    switch (ComputeBlockSize(width)) {
      FIXED_BLOCK_DIM_CASE(
          ArgCUDAKernel<T, IndType, Reducer,
                        kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
              height, width, post, Reducer(), std::numeric_limits<T>::lowest(),
              in_data, out_data));
    }
  } else {
    switch (ComputeBlockSize(width)) {
      FIXED_BLOCK_DIM_CASE(
          ArgCUDAKernel<T, IndType, Reducer,
                        kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
              height, width, post, Reducer(), std::numeric_limits<T>::max(),
              in_data, out_data));
    }
  }
}
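The ComputeBlockSize ladder above rounds the reduced width up to a power of two and clamps it to [8, 1024], so the FIXED_BLOCK_DIM_CASE switch can instantiate the kernel with a compile-time block dimension. An equivalent closed form, as an editorial sketch (C++20 std::bit_ceil, not in the original):

#include <algorithm>
#include <bit>

// Equivalent to the if/else ladder for col >= 0: round col up to a power of
// two, then clamp the result to the supported range [8, 1024].
int compute_block_size(int col) {
  int p = static_cast<int>(std::bit_ceil(static_cast<unsigned>(col)));
  return std::clamp(p, 8, 1024);
}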

template <typename T, class Reducer>
class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    int axis = ctx.Attr<int64_t>("axis");
    auto in_dims = input->dims();
    axis = (axis < 0) ? (in_dims.size() + axis) : axis;

    int64_t numel = input->numel();
    int64_t groups = numel / in_dims[axis];
    int64_t pre = 1;
    int64_t post = 1;
    int64_t n = in_dims[axis];

    for (int i = 0; i < axis; i++) {
      pre *= in_dims[i];
    }

    for (int i = axis + 1; i < in_dims.size(); i++) {
      post *= in_dims[i];
    }

    const auto& dev_ctx = ctx.cuda_device_context();
    ComputeFullArg<T, int64_t, Reducer>(dev_ctx, *input, output, pre, post, n);
  }
};

#endif

} // namespace operators
} // namespace paddle
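A worked example of the pre/n/post factorization in Compute(), as an editorial note rather than part of the diff: for an input of shape {2, 3, 4} with axis = -2, the axis normalizes to 1, giving pre = 2, n = 3, post = 4, and the kernel writes pre * post = 8 indices.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical illustration of the factorization: pre is the product of dims
// before the axis, n = dims[axis], and post the product of dims after it.
int main() {
  std::vector<int64_t> dims = {2, 3, 4};
  int axis = -2;
  if (axis < 0) axis += static_cast<int>(dims.size());  // normalizes to 1
  int64_t pre = 1, post = 1, n = dims[axis];
  for (int i = 0; i < axis; ++i) pre *= dims[i];
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) post *= dims[i];
  std::cout << pre << " " << n << " " << post << "\n";  // prints: 2 3 4
}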