[arm-fp16] add hardswish fp16 implementation (#6837)
* add hardswish fp16 implementation. test=develop
chenjiaoAngel authored Sep 6, 2021
1 parent 5c1114c commit f28e992
Showing 6 changed files with 172 additions and 9 deletions.
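
For context, hard_swish computes out = x * min(max(0, x + offset), threshold) / scale; the FP16 kernel added here vectorizes this with NEON float16x8 intrinsics (32 elements per main-loop iteration, then an 8-wide loop, then a scalar tail). A minimal scalar reference of the same computation, included only as a sketch for comparison and not part of the commit:

#include <algorithm>

// Scalar reference of hard_swish, matching the remainder loop in the new
// act_hard_swish<float16_t> kernel below (sketch only).
void hard_swish_ref(const float* din, float* dout, int size,
                    float threshold, float scale, float offset) {
  const float scale_r = 1.f / scale;  // the kernel multiplies by 1/scale
  for (int i = 0; i < size; ++i) {
    dout[i] =
        std::min(std::max(0.f, din[i] + offset), threshold) * din[i] * scale_r;
  }
}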
67 changes: 67 additions & 0 deletions lite/backends/arm/math/fp16/activation_fp16.cc
@@ -241,6 +241,73 @@ void act_hard_sigmoid<float16_t>(const float16_t* din,
++dout;
}
}

template <>
void act_hard_swish<float16_t>(const float16_t* din,
float16_t* dout,
const int size,
const float threshold,
const float scale,
const float offset,
int threads) {
int cnt = size >> 5;
int remain = size & 31;
float scale_r = 1. / scale;

int cnt_8 = remain >> 3;
int remain_8 = remain & 7;

float16x8_t vzero_8 = vdupq_n_f16(float16_t(0));
float16x8_t vthreshold_8 = vdupq_n_f16(float16_t(threshold));
float16x8_t vscale_8 = vdupq_n_f16(float16_t(scale_r));
float16x8_t voffset_8 = vdupq_n_f16(float16_t(offset));

for (int i = 0; i < cnt; i++) {
float16x8_t vdin0 = vld1q_f16(din);
float16x8_t vdin1 = vld1q_f16(din + 8);
float16x8_t vdin2 = vld1q_f16(din + 16);
float16x8_t vdin3 = vld1q_f16(din + 24);
float16x8_t vtmp0 = vminq_f16(
vthreshold_8, vmaxq_f16(vzero_8, vaddq_f16(vdin0, voffset_8)));
float16x8_t vsum0 = vmulq_f16(vscale_8, vdin0);
float16x8_t vtmp1 = vminq_f16(
vthreshold_8, vmaxq_f16(vzero_8, vaddq_f16(vdin1, voffset_8)));
float16x8_t vsum1 = vmulq_f16(vscale_8, vdin1);
float16x8_t vtmp2 = vminq_f16(
vthreshold_8, vmaxq_f16(vzero_8, vaddq_f16(vdin2, voffset_8)));
float16x8_t vsum2 = vmulq_f16(vscale_8, vdin2);
float16x8_t vtmp3 = vminq_f16(
vthreshold_8, vmaxq_f16(vzero_8, vaddq_f16(vdin3, voffset_8)));
float16x8_t vsum3 = vmulq_f16(vscale_8, vdin3);
float16x8_t vres0 = vmulq_f16(vsum0, vtmp0);
float16x8_t vres1 = vmulq_f16(vsum1, vtmp1);
float16x8_t vres2 = vmulq_f16(vsum2, vtmp2);
float16x8_t vres3 = vmulq_f16(vsum3, vtmp3);
vst1q_f16(dout, vres0);
vst1q_f16(dout + 8, vres1);
vst1q_f16(dout + 16, vres2);
vst1q_f16(dout + 24, vres3);
din += 32;
dout += 32;
}
for (int i = 0; i < cnt_8; i++) {
float16x8_t vdin0 = vld1q_f16(din);
din += 8;
float16x8_t vtmp0 = vminq_f16(
vthreshold_8, vmaxq_f16(vzero_8, vaddq_f16(vdin0, voffset_8)));
float16x8_t vsum0 = vmulq_f16(vscale_8, vdin0);
float16x8_t vres0 = vmulq_f16(vsum0, vtmp0);
vst1q_f16(dout, vres0);
dout += 8;
}
for (int i = 0; i < remain_8; i++) {
dout[0] =
std::min(std::max(0.f, din[0] + offset), threshold) * din[0] * scale_r;
din++;
dout++;
}
}

template <>
void act_prelu<float16_t>(const float16_t* din,
float16_t* dout,
8 changes: 8 additions & 0 deletions lite/backends/arm/math/fp16/activation_fp16.h
@@ -33,6 +33,14 @@ void act_hard_sigmoid(const T* din,
int threads);

template <typename T>
void act_hard_swish(const T* din,
T* dout,
const int size,
const float threshold,
const float scale,
const float offset,
int threads);
template <typename T>
void act_prelu(const T* din,
T* dout,
int outer_size,
46 changes: 39 additions & 7 deletions lite/kernels/arm/activation_extra_compute.cc
@@ -101,6 +101,25 @@ void HardSigmoidCompute<PRECISION(kFP16)>::Run() {
lite::arm::math::fp16::act_hard_sigmoid<float16_t>(
x_data, output_data, x_dims.production(), slope, offset, ctx.threads());
}

template <>
void HardSwishCompute<PRECISION(kFP16)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
auto x_data = param.X->data<float16_t>();
auto output_data = param.Out->mutable_data<float16_t>();
float threshold = param.hard_swish_threshold;
float scale = param.hard_swish_scale;
float offset = param.hard_swish_offset;
lite::arm::math::fp16::act_hard_swish<float16_t>(x_data,
output_data,
x_dims.production(),
threshold,
scale,
offset,
ctx.threads());
}
#endif

void SqrtCompute::Run() {
@@ -133,7 +152,8 @@ void SquareCompute::Run() {
x_data, output_data, x_dims.production(), ctx.threads());
}

void HardSwishCompute::Run() {
template <>
void HardSwishCompute<PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
@@ -248,6 +268,17 @@ REGISTER_LITE_KERNEL(
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.Finalize();

REGISTER_LITE_KERNEL(
hard_swish,
kARM,
kFP16,
kNCHW,
paddle::lite::kernels::arm::HardSwishCompute<PRECISION(kFP16)>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.Finalize();
#endif // ENABLE_ARM_FP16

REGISTER_LITE_KERNEL(relu_clipped,
@@ -304,12 +335,13 @@ REGISTER_LITE_KERNEL(
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(hard_swish,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::HardSwishCompute,
def)
REGISTER_LITE_KERNEL(
hard_swish,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::HardSwishCompute<PRECISION(kFloat)>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
3 changes: 2 additions & 1 deletion lite/kernels/arm/activation_extra_compute.h
@@ -104,7 +104,8 @@ class SquareCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual ~SquareCompute() = default;
};

class HardSwishCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
template <PrecisionType PType>
class HardSwishCompute : public KernelLite<TARGET(kARM), PType> {
public:
using param_t = operators::ActivationParam;

2 changes: 1 addition & 1 deletion lite/tests/benchmark/src/get_activation_latency.cc
@@ -242,7 +242,7 @@ int main(int argc, char** argv) {
t0.Stop();
}
} else if (act_type == 10) {
paddle::lite::kernels::arm::HardSwishCompute act_compute;
paddle::lite::kernels::arm::HardSwishCompute<PRECISION(kFloat)> act_compute;
act_compute.SetParam(act_param);
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
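
The same driver pattern applies to the new FP16 kernel. A sketch (not part of the commit) mirroring the kFloat snippet above, assuming act_param is an operators::ActivationParam whose FP16 X/Out tensors and hard_swish_threshold/scale/offset fields have already been filled in:

// Sketch only: exercising the FP16 variant the way the benchmark drives the
// kFloat kernel. The surrounding setup (tensors, act_param) is assumed.
paddle::lite::kernels::arm::HardSwishCompute<PRECISION(kFP16)> act_compute_fp16;
act_compute_fp16.SetParam(act_param);
std::unique_ptr<paddle::lite::KernelContext> ctx_fp16(
    new paddle::lite::KernelContext);
ctx_fp16->As<paddle::lite::ARMContext>();
act_compute_fp16.SetContext(std::move(ctx_fp16));
act_compute_fp16.Run();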
55 changes: 55 additions & 0 deletions lite/tests/kernels/activation_compute_test.cc
@@ -1287,6 +1287,61 @@ TEST(Activation_prelu_fp16, performance) {
}
}

TEST(Activation_hard_swish_fp16, precision) {
Place place;
float abs_error = 2e-3;
#ifdef LITE_WITH_ARM
place = Place(TARGET(kARM), PRECISION(kFP16));
#else
return;
#endif
for (auto dims : std::vector<std::vector<int64_t>>{{1, 3, 32, 32},
{1, 2, 3, 4},
{1, 3, 2, 4},
{2, 3, 4},
{5, 4},
{8}}) {
TestAct<float16_t>(place,
"def",
0.01,
6.,
"all",
0.,
1.0,
DDim(dims),
"hard_swish",
HARD_SWISH,
abs_error);
}
}

TEST(Activation_hard_swish_fp16, performance) {
Place place;
float abs_error = 2e-3;
#ifdef LITE_WITH_ARM
place = Place(TARGET(kARM), PRECISION(kFP16));
#else
return;
#endif
for (auto dims : std::vector<std::vector<int64_t>>{{1, 3, 32, 32},
{1, 2, 3, 4},
{1, 3, 2, 4},
{2, 3, 4},
{5, 4},
{8}}) {
TestActPerformance<float16_t>(place,
"def",
0.01,
6.,
"all",
0.,
1.0,
DDim(dims),
"hard_swish",
HARD_SWISH,
abs_error);
}
}
#endif
} // namespace lite
} // namespace paddle
