[Phi] Migrate squared_l2_norm_op to phi (PaddlePaddle#44492)
affectionlu authored and Aurelius84 committed Jul 29, 2022
1 parent 22ed27f commit 100568d
Showing 24 changed files with 340 additions and 160 deletions.
@@ -57,7 +57,7 @@ USE_OP_ITSELF(sqrt);
USE_OP_ITSELF(elementwise_max);
USE_OP_ITSELF(elementwise_div);
USE_OP_ITSELF(sgd);
-USE_OP(squared_l2_norm);
+USE_OP_ITSELF(squared_l2_norm);
USE_OP_ITSELF(memcpy_h2d);
USE_OP_ITSELF(memcpy_d2h);
USE_OP_ITSELF(fetch_v2);
@@ -87,6 +87,7 @@ PD_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(mean_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sigmoid, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sigmoid_grad, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(squared_l2_norm, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(reshape_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul_grad, GPU, ALL_LAYOUT);
1 change: 0 additions & 1 deletion paddle/fluid/operators/inplace_abn_op.cu
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/inplace_abn_op.h"
-#include <iostream>
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/phi/kernels/batch_norm_grad_kernel.h"
#include "paddle/phi/kernels/batch_norm_kernel.h"
17 changes: 9 additions & 8 deletions paddle/fluid/operators/optimizers/lamb_op.h
@@ -22,11 +22,11 @@ limitations under the License. */
#include "paddle/fluid/memory/buffer.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/math/squared_l2_norm.h"
#include "paddle/fluid/operators/tensor_to_string.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/algorithm.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
#include "paddle/phi/kernels/funcs/squared_l2_norm.h"

namespace paddle {
namespace operators {
@@ -756,13 +756,14 @@ class LambOpKernel : public framework::OpKernel<T> {
    // TODO(zengjinle): remove the following Eigen operations when
    // *skip_update == true.
    memory::Buffer buffer(dev_ctx.GetPlace());
-    math::SquaredL2Norm(dev_ctx,
-                        reinterpret_cast<const MT*>(
-                            IsMultiPrecision ? master_param_ptr : param_ptr),
-                        p_norm_ptr,
-                        numel,
-                        &buffer);
-    math::SquaredL2Norm(
+    phi::funcs::SquaredL2Norm(
+        dev_ctx,
+        reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
+                                                     : param_ptr),
+        p_norm_ptr,
+        numel,
+        &buffer);
+    phi::funcs::SquaredL2Norm(
        dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);

if (VLOG_IS_ON(1)) {
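
A note on the lamb_op.h change above: the only functional difference is where the squared-L2-norm helper comes from. The include moves from paddle/fluid/operators/math/squared_l2_norm.h to paddle/phi/kernels/funcs/squared_l2_norm.h, and both call sites switch from math::SquaredL2Norm to phi::funcs::SquaredL2Norm with the same argument order (device context, input pointer, single-element output pointer, element count, scratch buffer). As a CPU-only sketch of what such a helper computes (an illustration, not the phi implementation, which is templated on the device context and presumably uses the Buffer argument for temporary storage on the GPU path):

    #include <cstddef>

    // Sketch only: writes sum_i x[i]^2 into the single-element output *y.
    template <typename T>
    void SquaredL2NormSketch(const T* x, T* y, std::size_t numel) {
      T sum = static_cast<T>(0);
      for (std::size_t i = 0; i < numel; ++i) {
        sum += x[i] * x[i];
      }
      *y = sum;
    }
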
51 changes: 19 additions & 32 deletions paddle/fluid/operators/squared_l2_norm_op.cc
@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/squared_l2_norm_op.h"

#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
@@ -24,13 +25,6 @@ using framework::Tensor;
class SquaredL2NormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SquaredL2NormOp");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SquaredL2NormOp");
-
-    ctx->SetOutputDim("Out", {1});
-  }
};

template <typename T>
@@ -54,20 +48,6 @@ class SquaredL2NormGradOpMaker : public framework::SingleGradOpMaker<T> {
class SquaredL2NormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SquaredL2NormGradOp");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   "Out@GRAD",
-                   "SquaredL2NormGradOp");
-    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
-                   "Output",
-                   "X@GRAD",
-                   "SquaredL2NormGradOp");
-
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
-  }
};

class SquaredL2NormOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -90,15 +70,22 @@ Computes the squared L2 norm of a tensor.
} // namespace paddle

namespace ops = paddle::operators;

+DECLARE_INFER_SHAPE_FUNCTOR(squared_l2_norm,
+                            SquaredL2NormInferShapeFunctor,
+                            PD_INFER_META(phi::SquaredL2NormInferMeta));
+
+DECLARE_INFER_SHAPE_FUNCTOR(squared_l2_norm_grad,
+                            SquaredL2NormGradInferShapeFunctor,
+                            PD_INFER_META(phi::UnchangedInferMeta));
+
REGISTER_OPERATOR(squared_l2_norm,
                  ops::SquaredL2NormOp,
                  ops::SquaredL2NormOpMaker,
                  ops::SquaredL2NormGradOpMaker<paddle::framework::OpDesc>,
-                  ops::SquaredL2NormGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(squared_l2_norm_grad, ops::SquaredL2NormGradOp);
-REGISTER_OP_CPU_KERNEL(squared_l2_norm,
-                       ops::SquaredL2NormKernel<phi::CPUContext, float>,
-                       ops::SquaredL2NormKernel<phi::CPUContext, double>);
-REGISTER_OP_CPU_KERNEL(squared_l2_norm_grad,
-                       ops::SquaredL2NormGradKernel<phi::CPUContext, float>,
-                       ops::SquaredL2NormGradKernel<phi::CPUContext, double>);
+                  ops::SquaredL2NormGradOpMaker<paddle::imperative::OpBase>,
+                  SquaredL2NormInferShapeFunctor);
+
+REGISTER_OPERATOR(squared_l2_norm_grad,
+                  ops::SquaredL2NormGradOp,
+                  SquaredL2NormGradInferShapeFunctor);
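
With the kernels migrated to phi, squared_l2_norm_op.cc keeps only operator-level information: the hand-written InferShape methods are replaced by DECLARE_INFER_SHAPE_FUNCTOR bindings to phi::SquaredL2NormInferMeta and phi::UnchangedInferMeta, and the REGISTER_OP_CPU_KERNEL registrations are dropped because the kernels are now registered through phi's PD_REGISTER_KERNEL (see the new files under paddle/phi/kernels below). In shape terms the two bindings amount to the following sketch (hypothetical helper names; the real InferMeta functions operate on phi::MetaTensor):

    #include <cstdint>
    #include <vector>

    // squared_l2_norm: the output is a single-element tensor, dims = {1},
    // matching SquaredL2NormInferMeta added to paddle/phi/infermeta/unary.cc below.
    std::vector<int64_t> SquaredL2NormOutShape(const std::vector<int64_t>& /*x_dims*/) {
      return {1};
    }

    // squared_l2_norm_grad: x_grad keeps x's dims, which is what the removed
    // ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")) did.
    std::vector<int64_t> SquaredL2NormGradOutShape(const std::vector<int64_t>& x_dims) {
      return x_dims;
    }
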
24 changes: 0 additions & 24 deletions paddle/fluid/operators/squared_l2_norm_op.cu

This file was deleted.

71 changes: 0 additions & 71 deletions paddle/fluid/operators/squared_l2_norm_op.h

This file was deleted.

1 change: 0 additions & 1 deletion paddle/fluid/operators/squared_l2_norm_op_mlu.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/squared_l2_norm_op.h"
// #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"

2 changes: 1 addition & 1 deletion paddle/fluid/operators/squared_l2_norm_op_npu.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/squared_l2_norm_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"

namespace paddle {
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
@@ -2062,6 +2062,15 @@
    func : square
  backward : square_grad

- api : squared_l2_norm
  args : (Tensor x)
  output : Tensor
  infer_meta :
    func : SquaredL2NormInferMeta
  kernel :
    func : squared_l2_norm
  backward : squared_l2_norm_grad

- api : squeeze
  args : (Tensor x, int[] axes)
  output : Tensor(out), Tensor(xshape)
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -2009,6 +2009,16 @@
  backward : square_double_grad
  inplace : (out_grad -> x_grad)

- backward_api : squared_l2_norm_grad
  forward : squared_l2_norm(Tensor x) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : UnchangedInferMeta
    param: [x]
  kernel :
    func : squared_l2_norm_grad

- backward_api : squeeze_double_grad
  forward : squeeze_grad(Tensor xshape, Tensor grad_out, int[] axes) -> Tensor(grad_x)
  args : (Tensor grad_x_grad, int[] axes)
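
The new squared_l2_norm_grad entry mirrors the forward api: it takes x and out_grad and produces x_grad, whose shape is inferred from x via UnchangedInferMeta (hence param: [x]). Mathematically, out = sum_i x_i^2, so d out / d x_i = 2 * x_i, and the gradient kernel effectively computes x_grad = 2 * x * out_grad, with the single-element out_grad broadcast over x. A reference sketch of that math (not the phi kernel itself):

    #include <cstddef>
    #include <vector>

    // Reference math for squared_l2_norm_grad: x_grad[i] = 2 * x[i] * out_grad,
    // where out_grad is the single gradient value flowing back into `out`.
    std::vector<float> SquaredL2NormGradRef(const std::vector<float>& x, float out_grad) {
      std::vector<float> x_grad(x.size());
      for (std::size_t i = 0; i < x.size(); ++i) {
        x_grad[i] = 2.0f * x[i] * out_grad;
      }
      return x_grad;
    }
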
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -2489,6 +2489,10 @@ void SplitInferMeta(const MetaTensor& x,
}
}

void SquaredL2NormInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims({1});
}

void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out) {
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -345,6 +345,8 @@ void SplitInferMeta(const MetaTensor& x_meta,
std::vector<MetaTensor*> out,
MetaConfig config = MetaConfig());

void SquaredL2NormInferMeta(const MetaTensor& x, MetaTensor* out);

void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* out);
26 changes: 26 additions & 0 deletions paddle/phi/kernels/cpu/squared_l2_norm_grad_kernel.cc
@@ -0,0 +1,26 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/squared_l2_norm_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h"

PD_REGISTER_KERNEL(squared_l2_norm_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::SquaredL2NormGradKernel,
                   float,
                   double) {}
23 changes: 23 additions & 0 deletions paddle/phi/kernels/cpu/squared_l2_norm_kernel.cc
@@ -0,0 +1,23 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/squared_l2_norm_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h"

PD_REGISTER_KERNEL(
squared_l2_norm, CPU, ALL_LAYOUT, phi::SquaredL2NormKernel, float, double) {
}
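
The two registrations above pull in declarations from paddle/phi/kernels/squared_l2_norm_kernel.h and squared_l2_norm_grad_kernel.h and implementations from the impl/ headers, none of which are shown on this page. Judging from the api yaml entries in this diff and the usual phi convention of one (T, Context) template pair per kernel, the declarations presumably look roughly like this (an assumption, not a copy of the actual headers):

    // Assumed declaration shapes; DenseTensor is forward-declared only so the
    // sketch stands alone.
    namespace phi {

    class DenseTensor;

    template <typename T, typename Context>
    void SquaredL2NormKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             DenseTensor* out);

    template <typename T, typename Context>
    void SquaredL2NormGradKernel(const Context& dev_ctx,
                                 const DenseTensor& x,
                                 const DenseTensor& dout,
                                 DenseTensor* dx);

    }  // namespace phi
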
(The remaining changed files in this commit are not shown here.)
