Add batchnorm op (PaddlePaddle#196)
haozech authored Sep 1, 2020
1 parent 59f2b7a commit e400a0c
Showing 5 changed files with 119 additions and 14 deletions.
5 changes: 4 additions & 1 deletion cinn/common/cas.cc
@@ -1142,7 +1142,10 @@ Expr ConvertCasToCinn(Expr expr) {
*expr = init;
} else {
// some case like a^-2
CINN_NOT_IMPLEMENTED
auto new_expr = make_const(a.type(), 1.f) / (ir::Power::Make(a, make_const(b.type(), -b.get_constant())));
Visit(&new_expr);
*expr = new_expr;
return;
}
} else {
CINN_NOT_IMPLEMENTED
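The new branch in cas.cc handles a power with a negative constant exponent (previously a CINN_NOT_IMPLEMENTED path) by rewriting a^(-b) as 1 / a^b and re-visiting the result. A minimal sketch of the identity the rewrite relies on, in plain Python rather than CINN's IR:

def rewrite_negative_power(a, neg_exp):
    # neg_exp plays the role of b.get_constant(), a negative constant exponent;
    # the rewrite builds 1 / a**(-neg_exp), mirroring make_const(1.f) / Power(a, -b).
    return 1.0 / (a ** (-neg_exp))

assert abs(rewrite_negative_power(3.0, -2.0) - 3.0 ** -2.0) < 1e-12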
57 changes: 57 additions & 0 deletions cinn/hlir/op/nn.cc
@@ -191,6 +191,55 @@ std::vector<Type> InferDtypeForConv2d(const std::vector<Type> &inputs_type, cons
return res;
}

std::shared_ptr<OpStrategy> StrategyForBatchNorm(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const Target &target) {
float epsilon = 0.00001f;
if (attrs.attr_store.find("epsilon") != attrs.attr_store.end()) {
epsilon = std::get<float>(attrs.attr_store.at("epsilon"));
}
framework::CINNCompute batchnorm_compute([=](lang::Args args, lang::RetValue *ret) {
CINNValuePack a = args[0];
ir::Expr A = a[0];
ir::Expr B = a[1];
CHECK(A.as_tensor());
CHECK(B.as_tensor());
auto out = pe::BatchNorm_NCHW(A.as_tensor_ref(), B.as_tensor_ref(), epsilon, UniqName("BatchNorm_output"));
auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(ir::Expr(out.get())), CINNValue(stages)}};
});

framework::CINNSchedule batchnorm_schedule([](lang::Args args, lang::RetValue *ret) {
CINNValuePack arg_pack = args[0];
ir::Expr A [[maybe_unused]] = arg_pack[0];
CHECK_EQ(arg_pack.size(), 2UL);
*ret = arg_pack;
});

auto strategy = std::make_shared<framework::OpStrategy>();
CHECK(out_type.size()) << "Out_type of batchnorm op is empty! Please check.";
if (out_type[0] == Float(32)) {
strategy->AddImpl(batchnorm_compute, batchnorm_schedule, "strategy.batchnorm.x86", 1);
} else {
LOG(INFO) << "BatchNorm op with dtype != float32 is not implemented yet!";
}
return strategy;
}

std::vector<std::vector<int>> InferShapeForBatchNorm(const std::vector<std::vector<int>> &inputs_shape,
const framework::NodeAttr &attrs) {
CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again.";
std::vector<std::vector<int>> res{inputs_shape[0]};
return res;
}

std::vector<Type> InferDtypeForBatchNorm(const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
std::vector<Type> res{inputs_type[0]};
return res;
}

} // namespace op
} // namespace hlir
} // namespace cinn
@@ -220,4 +269,12 @@ CINN_REGISTER_HELPER(nn_ops) {
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForConv2d))
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForConv2d))
.set_support_level(4);
CINN_REGISTER_OP(batchnorm)
.describe("Can be used as a normalizer function for convolution or fully_connected operations.")
.set_num_inputs(2) // batchnorm's 4 params (mean, variance, scale, bias) are packed into a single extra input
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForBatchNorm)
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForBatchNorm))
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForBatchNorm))
.set_support_level(4);
}
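The registration gives batchnorm two inputs (the data tensor plus one [4, C] tensor packing mean, variance, scale and bias) and one output, and both InferShapeForBatchNorm and InferDtypeForBatchNorm simply propagate the first input. A small Python sketch of that inference contract (illustrative only, not the CINN API):

def infer_shape_for_batchnorm(inputs_shape):
    # The output shape follows the data input; the packed [4, C] weights input is ignored.
    assert inputs_shape and inputs_shape[0], "The input's shape size is 0!"
    return [inputs_shape[0]]

def infer_dtype_for_batchnorm(inputs_type):
    assert inputs_type, "The input's type size is 0!"
    return [inputs_type[0]]

assert infer_shape_for_batchnorm([[1, 3, 2, 2], [4, 3]]) == [[1, 3, 2, 2]]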
27 changes: 27 additions & 0 deletions cinn/hlir/pe/nn.cc
@@ -80,6 +80,33 @@ std::vector<ir::Tensor> Conv2d_NCHW(const ir::Tensor& input,
return {input_pad, weights_dilation, res};
}

/**
* Can be used as a normalizer function for convolution or fully_connected operations.
* Specified for NCHW layout.
* Math: Y = (X - mean) / sqrt(variance + epsilon) * scale + bias
* @param input The input variable.
* @param weights The weights containing mean, variance, scale and bias.
* @param epsilon A small constant added to the variance to avoid division by zero.
* @param output_name The name of output tensor.
* @return The calculated output tensor.
*/
ir::Tensor BatchNorm_NCHW(const ir::Tensor& input,
const ir::Tensor& weights,
float epsilon,
const std::string& output_name) {
CHECK_EQ(4, input->shape.size()) << "Input's dimension of BatchNorm op is not 4! Please check.";
CHECK_EQ(2, weights->shape.size()) << "Weight's dimension of BatchNorm op is not 2! Please check.";
auto res = Compute(
input->shape,
[=](Expr n, Expr c, Expr h, Expr w) {
return (((input(n, c, h, w) - weights(Expr(0), c)) / ir::Sqrt(weights(Expr(1), c) + Expr(epsilon))) *
weights(Expr(2), c) +
weights(Expr(3), c));
},
output_name);
return res;
}

} // namespace pe
} // namespace hlir
} // namespace cinn
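For reference, the same per-channel computation in NumPy (a sketch for illustration, not part of the commit), assuming the [4, C] weights packing that BatchNorm_NCHW indexes with weights(0..3, c):

import numpy as np

def batchnorm_nchw_reference(x, weights, epsilon=1e-5):
    # x: [N, C, H, W]; weights: [4, C], rows are mean, variance, scale, bias.
    mean, variance, scale, bias = weights[0], weights[1], weights[2], weights[3]
    # Broadcast the per-channel statistics over the N, H and W axes.
    return (x - mean[None, :, None, None]) / np.sqrt(
        variance[None, :, None, None] + epsilon) * scale[None, :, None, None] \
        + bias[None, :, None, None]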
5 changes: 5 additions & 0 deletions cinn/hlir/pe/nn.h
@@ -61,6 +61,11 @@ std::vector<ir::Tensor> Conv2d_NCHW(const ir::Tensor& input,
int groups,
const std::string& output_name);

ir::Tensor BatchNorm_NCHW(const ir::Tensor& input,
const ir::Tensor& weights,
float epsilon,
const std::string& output_name);

} // namespace pe
} // namespace hlir
} // namespace cinn
39 changes: 26 additions & 13 deletions python/tests/test_ops.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
import unittest
import math
import numpy as np
import cinn
from cinn import frontend
@@ -54,7 +55,7 @@ def to_test_op(self, input_shapes, output_shape, op_name, attrs):
for i_shape in input_shapes:
expr_shape = []
inputs_data.append(
np.around(np.ones(i_shape).astype("float32"), 3))
np.around(np.random.random(i_shape).astype("float32"), 3))

for dim_shape in i_shape:
expr_shape.append(ir.Expr(dim_shape))
@@ -83,14 +84,10 @@ def to_test_op(self, input_shapes, output_shape, op_name, attrs):
args.append(runtime.cinn_pod_value_t(out_data))

fn(args)
print("test op output is:")
out_result = out[len(out) - 1]
print(out_result.numpy())
self.assertTrue(
np.allclose(
out_result.numpy(),
self.create_target_data(inputs_data),
atol=1e-4))

out_result = out[len(out) - 1].numpy()
correct_result = self.create_target_data(inputs_data)
self.assertTrue(np.allclose(out_result, correct_result, atol=1e-4))

def __codegen(self, op_name, inputs, attrs):
types = [common.Float(32)]
@@ -111,7 +108,7 @@ def __gen_var_name(self):

class OpTest_add(SingleOpTester):
def create_target_data(self, inputs_data):
X, Y = inputs_data
[X, Y] = inputs_data
return X + Y

def test_op(self):
@@ -121,12 +118,12 @@ def test_op(self):

class OpTest_relu(SingleOpTester):
def create_target_data(self, inputs_data):
X = inputs_data
return np.maximum(X, np.zeros(np.array(X).shape).astype("float32"))
[X] = inputs_data
return np.maximum(X, np.zeros(X.shape).astype("float32"))

def test_op(self):
attrs = framework.NodeAttr()
self.to_test_op([[32, 32]], [[32, 32]], "relu", attrs)
self.to_test_op([[32]], [[32]], "relu", attrs)


""" class OpTest_conv2d(SingleOpTester):
@@ -142,5 +139,21 @@ def test_op(self):
self.to_test_op([[1, 3, 10, 10], [2, 3, 2, 2]],
[[1, 3, 12, 12], [2, 3, 3, 3], [1, 2, 5, 5]], "conv2d", attrs) """


class OpTest_batchnorm(SingleOpTester):
def create_target_data(self, inputs_data):
[X, Y] = inputs_data
c = X.shape[1]
for i in range(0, c):
X[:, i, :, :] = (X[:, i, :, :] - Y[0, i]) / math.sqrt(
Y[1, i] + 0.00001) * Y[2, i] + Y[3, i]
return X

def test_op(self):
attrs = framework.NodeAttr()
self.to_test_op([[1, 3, 2, 2], [4, 3]], [[1, 3, 2, 2]], "batchnorm",
attrs)


if __name__ == "__main__":
unittest.main()
