[phi] migrate prelu (#47422)
* migrate prelu

* remove cache

* review fixes
sfraczek authored Nov 10, 2022
1 parent 5900129 commit cdd8c8a
Showing 4 changed files with 189 additions and 208 deletions.
208 changes: 0 additions & 208 deletions paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc

This file was deleted.

61 changes: 61 additions & 0 deletions paddle/phi/backends/onednn/onednn_reuse.h
@@ -1084,6 +1084,67 @@ class BroadcastDataOneDNNHandler
}
};

template <typename T>
class PReluOneDNNHandler
: public OneDNNHandlerNoCachingT<T,
dnnl::prelu_forward,
dnnl::prelu_backward> {
public:
PReluOneDNNHandler(const dnnl::engine engine,
Place cpu_place,
const DenseTensor& x,
const DenseTensor& weights,
const std::string& mode,
const std::string& data_format,
const bool is_test)
: OneDNNHandlerNoCachingT<T, dnnl::prelu_forward, dnnl::prelu_backward>(
engine, cpu_place) {
auto weights_dims = phi::vectorize(weights.dims());
    // weights must have the same size as X only in the "element" mode
if (weights.dims().size() != x.dims().size()) {
auto new_weights_dims = std::vector<int64_t>(x.dims().size(), 1);
if (mode == "channel") {
new_weights_dims[1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
}
weights_dims = std::move(new_weights_dims);
}
auto weights_md = memory::desc(
weights_dims, OneDNNGetDataType<T>(), memory::format_tag::any);

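    // Note: the forward descriptor is created with forward_training
    // unconditionally so it can also serve as the hint required when
    // building the backward primitive descriptor below.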
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_training, x.mem_desc(), weights_md);
if (!is_test) {
this->AcquireBackwardPrimitiveDescriptor(
x.mem_desc(), weights_md, x.mem_desc(), weights_md);
}
}

std::shared_ptr<memory> AcquireWeightsMemoryPossiblyWithReorder(
const DenseTensor* weights, const bool is_test) {
const T* weights_data = weights->data<T>();

    // if weights are 1D, every format tag describes the same layout, so the
    // memory created from format_tag::any can be used directly and no
    // reorder is needed
if (weights->dims().size() == 1) {
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(weights_data));
}

return this->AcquireMemoryWithReorder(weights->mem_desc(),
this->fwd_pd_->weights_desc(),
to_void_cast<T>(weights_data),
is_test);
}

std::shared_ptr<memory> AcquireDiffWeightsMemory(DenseTensor* output) {
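    // Allocate using the byte size of the oneDNN descriptor, which may
    // exceed the plain tensor size when a blocked or padded format is chosen.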
T* output_data = output->mutable_data<T>(
this->place_, this->bwd_pd_->diff_weights_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
output_data);
}
};

template <typename T>
class ReductionOneDNNHandler
: public OneDNNHandlerNoCachingT<T, dnnl::reduction> {
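Of the four changed files, the forward kernel (presumably paddle/phi/kernels/onednn/prelu_kernel.cc, accounting for the remaining 59 additions) is not shown in this excerpt. The following is a minimal sketch of how PReluOneDNNHandler would drive the forward pass, mirroring the structure of the gradient kernel below — a plausible reconstruction, not the committed code:

// Sketch only -- a reconstruction of the forward kernel under the
// assumptions stated above, not the code committed in this PR.
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void PReluKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 const DenseTensor& alpha,
                 const std::string& data_format,
                 const std::string& mode,
                 DenseTensor* out) {
  // Same optional-attribute lookup as in the gradient kernel below.
  const bool is_test =
      dev_ctx.HasDnnAttr("is_test")
          ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test"))
          : false;
  funcs::PReluOneDNNHandler<T> handler(dev_ctx.GetEngine(),
                                       dev_ctx.GetPlace(),
                                       x,
                                       alpha,
                                       mode,
                                       data_format,
                                       is_test);

  auto src_memory_p = handler.AcquireSrcMemory(&x);
  auto weights_memory_p =
      handler.AcquireWeightsMemoryPossiblyWithReorder(&alpha, is_test);
  auto dst_memory_p = handler.AcquireDstMemory(out);
  auto prelu_p = handler.AcquireForwardPrimitive();

  auto& astream = OneDNNContext::tls().get_stream();
  prelu_p->execute(astream,
                   {{DNNL_ARG_SRC, *src_memory_p},
                    {DNNL_ARG_WEIGHTS, *weights_memory_p},
                    {DNNL_ARG_DST, *dst_memory_p}});
  astream.wait();

  out->set_mem_desc(dst_memory_p->get_desc());
}

}  // namespace phi

PD_REGISTER_KERNEL(
    prelu, OneDNN, ONEDNN, phi::PReluKernel, float, phi::dtype::bfloat16) {}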
69 changes: 69 additions & 0 deletions paddle/phi/kernels/onednn/prelu_grad_kernel.cc
@@ -0,0 +1,69 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/prelu_grad_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void PReluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& alpha,
const DenseTensor& out_grad,
const std::string& data_format,
const std::string& mode,
DenseTensor* x_grad,
DenseTensor* alpha_grad) {
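  // "is_test" is an optional oneDNN extra attribute; when it is absent,
  // default to the training path so the backward primitive descriptor
  // is built by the handler.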
bool is_test = dev_ctx.HasDnnAttr("is_test")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test"))
: false;
funcs::PReluOneDNNHandler<T> handler(dev_ctx.GetEngine(),
dev_ctx.GetPlace(),
x,
alpha,
mode,
data_format,
is_test);

auto src_memory_p = handler.AcquireSrcMemory(&x);
auto weights_memory_p =
handler.AcquireWeightsMemoryPossiblyWithReorder(&alpha, is_test);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(x_grad);
auto diff_weights_memory_p = handler.AcquireDiffWeightsMemory(alpha_grad);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(&out_grad);
auto prelu_p = handler.AcquireBackwardPrimitive();

auto& astream = OneDNNContext::tls().get_stream();
prelu_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p},
{DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
astream.wait();

x_grad->set_mem_desc(diff_src_memory_p->get_desc());
}

} // namespace phi

PD_REGISTER_KERNEL(prelu_grad,
OneDNN,
ONEDNN,
phi::PReluGradKernel,
float,
phi::dtype::bfloat16) {}
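Registration note: PD_REGISTER_KERNEL is the phi-side replacement for the fluid-style operator registration that lived in the deleted prelu_mkldnn_op.cc; the supported dtypes (float and bfloat16) are presumably carried over unchanged from that file.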