Skip to content

Commit

Permalink
fix resnet bug (PaddlePaddle#224)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn authored Sep 21, 2020
1 parent c4f02ac commit f79970f
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 12 deletions.
7 changes: 1 addition & 6 deletions cinn/backends/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -876,15 +876,10 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_LoweredFunc_ *op) {
/*Parent=*/function,
/*InsertBefore=*/nullptr);

llvm::Value *old_args = GetVar("_args"); // store _args
SetVar("_args", args[0]);
b_->SetInsertPoint(entry);
Visit(&function_body);
if (old_args) {
SetVar("_args", old_args); // restore _args
} else {
symbol_table_->Erase("_args");
}
symbol_table_->Erase("_args");
RetVoid();
return function;
}
Expand Down
4 changes: 1 addition & 3 deletions cinn/hlir/pe/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ ir::Tensor BatchNorm_NCHW(const ir::Tensor &input,
auto res = Compute(
input->shape,
[=](Expr n, Expr c, Expr h, Expr w) {
//! TODO(haozech) Add Sqrt will cause bug
//! return (input(n, c, h, w) - mean(c))* scale(c) / Sqrt(variance(c) + Expr(epsilon)) + bias(c);
return (input(n, c, h, w) - mean(c)) * scale(c) / (variance(c) + Expr(epsilon)) + bias(c);
return (input(n, c, h, w) - mean(c)) * scale(c) / Sqrt(variance(c) + Expr(epsilon)) + bias(c);
},
output_name);
return res;
Expand Down
1 change: 1 addition & 0 deletions python/tests/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

class TestMamul(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.target = Target()
self.target.arch = Target.Arch.X86
self.target.bits = Target.Bit.k32
Expand Down
3 changes: 0 additions & 3 deletions python/tests/test_op_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,7 @@ def create_target_data(self, inputs_data, attrs):
[X, Scale, Bias, Mean, Variance] = inputs_data
c = X.shape[1]
for i in range(0, c):
""" TODO(haozech) This should be the correct compute function(with sqrt)
X[:, i, :, :] = (X[:, i, :, :] - Mean[i]) / math.sqrt(
Variance[i] + 0.00001) * Scale[i] + Bias[i] """
X[:, i, :, :] = (X[:, i, :, :] - Mean[i]) / (
Variance[i] + 0.00001) * Scale[i] + Bias[i]
return X

Expand Down

0 comments on commit f79970f

Please sign in to comment.