[Prim][PIR] add leaky_relu, sigmoid, instance_norm op forward prim (#60564)

* hardswish op prim sink

* hardswish op prim

* add composite

* add leaky_relu, sigmoid op forward prim

* remove hardswish op forward

* add instance_norm op forward prim
kevincheng2 authored Jan 8, 2024
1 parent 1646a83 commit 385ec43
Showing 6 changed files with 183 additions and 22 deletions.
@@ -24,11 +24,14 @@
"dropout",
"full_like",
"gelu",
"instance_norm",
"layer_norm",
"leaky_relu",
"mean",
"pow",
"relu",
"rsqrt",
"sigmoid",
"silu",
"softmax",
"sqrt",
@@ -44,11 +44,14 @@
"dropout",
"full_like",
"gelu",
"instance_norm",
"layer_norm",
"leaky_relu",
"mean",
"pow",
"relu",
"rsqrt",
"sigmoid",
"silu",
"softmax",
"sqrt",
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -344,6 +344,7 @@
kernel :
func : hardswish_grad
inplace : (out_grad -> x_grad)
composite : hardswish_grad(x, out_grad, x_grad)

- backward_op : hsigmoid_loss_grad
forward : hsigmoid_loss (Tensor x, Tensor label, Tensor w, Tensor bias, Tensor path, Tensor code, int num_classes, bool is_sparse) -> Tensor(out), Tensor(pre_out), Tensor(w_out)
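For reference, the composite : hardswish_grad(x, out_grad, x_grad) entry added above registers a composite backward rule for hardswish. Below is a minimal NumPy sketch of the gradient that rule is expected to express, assuming Paddle's usual hardswish definition x * min(max(x + 3, 0), 6) / 6; the actual primitive calls live in Paddle's primitive backward rules and are not shown in this diff.

import numpy as np

def hardswish_grad_reference(x, out_grad):
    """Illustrative gradient of hardswish(x) = x * min(max(x + 3, 0), 6) / 6.

    Piecewise derivative: 0 for x < -3, 1 for x > 3, (2 * x + 3) / 6 otherwise.
    The behaviour exactly at x == +/-3 is a subgradient choice and may differ
    from the kernel's convention.
    """
    dydx = np.where(x < -3.0, 0.0,
                    np.where(x > 3.0, 1.0, (2.0 * x + 3.0) / 6.0))
    return out_grad * dydx

# Example: x_grad = hardswish_grad_reference(np.linspace(-5, 5, 11), np.ones(11))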
109 changes: 109 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -528,6 +528,115 @@ Tensor gelu_decomp(const Tensor& x, bool approximate) {
}
}

template <typename T>
Tensor sigmoid_decomp(const Tensor& x) {
  auto org_dtype = x.dtype();
  Tensor x_cast = x;

  bool need_cast = is_half_dtype(org_dtype);
  if (need_cast) {
    x_cast = cast<T>(x, phi::DataType::FLOAT32);
  }

  // res = 1 / (1 + exp(-x))
  auto one = full<T>(common::vectorize(x_cast.dims()), 1, x_cast.dtype());
  auto exp_tmp = exp<T>(
      full<T>(common::vectorize(x_cast.dims()), -1, x_cast.dtype()) * x_cast);
  auto res = one / (one + exp_tmp);
  if (need_cast) {
    return cast<T>(res, org_dtype);
  } else {
    return res;
  }
}

template <typename T>
Tensor leaky_relu_decomp(const Tensor& x, float negative_slope) {
  auto multiply_tmp =
      full<T>(phi::vectorize(x.dims()), negative_slope, x.dtype()) * x;
  if (negative_slope < 1.0) {
    return maximum<T>(x, multiply_tmp);
  } else {
    return minimum<T>(x, multiply_tmp);
  }
}

template <typename T>
std::tuple<Tensor, Tensor, Tensor> instance_norm_decomp(
    const Tensor& x,
    const paddle::optional<Tensor>& scale,
    const paddle::optional<Tensor>& bias,
    float epsilon) {
  auto org_dtype = x.dtype();
  Tensor x_cast = x;

  bool need_cast = is_half_dtype(org_dtype);
  if (need_cast) {
    x_cast = cast<T>(x, phi::DataType::FLOAT32);
  }

  std::vector<int64_t> axis;
  auto x_dim = common::vectorize<int64_t>(x.dims());
  for (size_t i = 2; i < x_dim.size(); i++) {
    axis.push_back(static_cast<int64_t>(i));
  }

  // out = (x - mean(x)) / sqrt(var + epsilon)
  // var = mean((x-mean(x))^2)
  auto mean_ = mean_decomp<T>(x_cast, IntArray(axis), true);
  auto difference = x_cast - mean_;
  auto var_tmp1 = difference * difference;
  auto variance = mean_decomp<T>(var_tmp1, IntArray(axis), true);
  auto var_tmp3 = variance + epsilon;
  auto rsqrt_var = elementwise_pow<T>(
      var_tmp3,
      full<T>(common::vectorize(var_tmp3.dims()), 0.5, var_tmp3.dtype()));
  auto out = difference / rsqrt_var;

  auto scale_ptr = scale.get_ptr();
  auto bias_ptr = bias.get_ptr();
  std::vector<int64_t> slice_shape(x_dim.size(), 1);
  slice_shape[1] = x_dim[1];

  Tensor scale_cast;
  if (scale_ptr) {
    if (slice_shape != scale_ptr->shape()) {
      scale_cast = reshape<T>(*scale_ptr, slice_shape);
    } else {
      scale_cast = *scale_ptr;
    }
    if (need_cast) {
      scale_cast = cast<T>(scale_cast, phi::DataType::FLOAT32);
    }
    out = out * scale_cast;
  }
  Tensor bias_cast;
  if (bias_ptr) {
    if (slice_shape != bias_ptr->shape()) {
      bias_cast = reshape<T>(*bias_ptr, slice_shape);
    } else {
      bias_cast = *bias_ptr;
    }
    if (need_cast) {
      bias_cast = cast<T>(bias_cast, phi::DataType::FLOAT32);
    }
    out = out + bias_cast;
  }

  std::vector<int64_t> res_shape(1, -1);
  auto mean_out = reshape<T>(mean_, res_shape);
  auto variance_out = reshape<T>(1 / rsqrt_var, res_shape);

  Tensor res;
  if (need_cast) {
    res = cast<T>(out, org_dtype);
  } else {
    res = out;
  }

  return std::make_tuple(res, mean_out, variance_out);
}

} // namespace details

} // namespace primitive
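As a sanity check on sigmoid_decomp and leaky_relu_decomp above, here is a small NumPy sketch that mirrors the same elementwise formulas; it is illustrative only and does not call the C++ primitive API.

import numpy as np

def sigmoid_decomp_reference(x):
    # Mirrors sigmoid_decomp: res = 1 / (1 + exp(-x)).
    return 1.0 / (1.0 + np.exp(-x))

def leaky_relu_decomp_reference(x, negative_slope):
    # Mirrors leaky_relu_decomp: compare x against negative_slope * x and
    # take the maximum when the slope is < 1, the minimum otherwise.
    scaled = negative_slope * x
    if negative_slope < 1.0:
        return np.maximum(x, scaled)
    return np.minimum(x, scaled)

x = np.linspace(-4.0, 4.0, 9)
assert np.allclose(sigmoid_decomp_reference(x), 1.0 / (1.0 + np.exp(-x)))
assert np.allclose(leaky_relu_decomp_reference(x, 0.02), np.where(x >= 0, x, 0.02 * x))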
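Similarly, a NumPy mirror of instance_norm_decomp for an NCHW input, again as an illustrative sketch rather than the shipped implementation. Note that the C++ variable named rsqrt_var actually holds sqrt(var + epsilon), and the returned variance_out is its reciprocal (the saved inverse standard deviation) flattened to 1-D.

import numpy as np

def instance_norm_decomp_reference(x, scale=None, bias=None, epsilon=1e-5):
    # x: (N, C, H, W). Statistics are computed per sample and per channel
    # over the spatial axes, matching the axis list built in the C++ code.
    axes = tuple(range(2, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    diff = x - mean
    var = (diff * diff).mean(axis=axes, keepdims=True)
    std = np.sqrt(var + epsilon)  # called "rsqrt_var" in the C++ helper
    out = diff / std
    bcast = (1, -1) + (1,) * (x.ndim - 2)  # broadcast scale/bias as (1, C, 1, ..., 1)
    if scale is not None:
        out = out * scale.reshape(bcast)
    if bias is not None:
        out = out + bias.reshape(bcast)
    return out, mean.reshape(-1), (1.0 / std).reshape(-1)

# Example:
# out, saved_mean, saved_inv_std = instance_norm_decomp_reference(
#     np.random.randn(2, 3, 4, 4).astype("float32"),
#     scale=np.ones(3, "float32"), bias=np.zeros(3, "float32"))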
65 changes: 47 additions & 18 deletions test/legacy_test/test_activation_op.py
@@ -390,7 +390,7 @@ def if_enable_cinn(self):
pass

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
if self.dtype == np.float16:
@@ -411,7 +411,7 @@ def init_dtype(self):

def test_check_output(self):
with paddle.static.scope_guard(paddle.static.Scope()):
self.check_output(check_prim=False)
self.check_output(check_prim=False, check_prim_pir=False)

def test_check_grad(self):
self.check_grad(
@@ -420,6 +420,7 @@ def test_check_grad(self):
max_relative_error=0.006,
check_prim=False,
check_pir=True,
check_prim_pir=False,
)


@@ -428,7 +429,9 @@ def init_dtype(self):
self.dtype = np.complex128

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=False, check_pir=True)
self.check_grad(
['X'], 'Out', check_prim=False, check_pir=True, check_prim_pir=False
)


class TestSigmoid_ZeroDim(TestSigmoid):
@@ -469,7 +472,9 @@ def if_enable_cinn(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, check_prim=True, check_pir=True)
self.check_output_with_place(
place, check_prim=True, check_pir=True, check_prim_pir=True
)

def test_check_grad(self):
place = core.CUDAPlace(0)
@@ -2555,7 +2560,7 @@ def if_enable_cinn(self):
pass

def test_check_output(self):
self.check_output(check_prim=True, check_pir=True)
self.check_output(check_prim=True, check_pir=True, check_prim_pir=True)

def test_check_grad(self):
if self.dtype == np.float16:
@@ -3038,7 +3043,9 @@ def test_check_grad(self):
else False,
only_check_prim=self.if_only_check_prim(),
check_pir=True,
check_prim_pir=True,
check_prim_pir=True
if self.dtype not in [np.complex64, np.complex128]
else False,
)

def test_check_output(self):
@@ -4832,7 +4839,11 @@ def test_check_grad(self):
)
create_test_act_fp16_class(TestExpm1)
create_test_act_fp16_class(
TestSigmoid, check_prim=True, enable_cinn=True, check_pir=True
TestSigmoid,
check_prim=True,
enable_cinn=True,
check_pir=True,
check_prim_pir=True,
)
create_test_act_fp16_class(
TestSilu, check_prim=True, enable_cinn=True, check_prim_pir=True
@@ -4929,18 +4940,24 @@ def test_check_grad(self):
create_test_act_fp16_class(TestHardSwish, check_prim=True, check_pir=True)
create_test_act_fp16_class(TestMish, check_pir=True)
create_test_act_fp16_class(
TestLeakyRelu, check_prim=True, enable_cinn=True, check_pir=True
TestLeakyRelu,
check_prim=True,
enable_cinn=True,
check_pir=True,
check_prim_pir=True,
)
create_test_act_fp16_class(
TestLeakyReluAlpha1, check_prim=True, enable_cinn=True, check_prim_pir=True
)
create_test_act_fp16_class(
TestLeakyReluAlpha1, check_prim=True, enable_cinn=True
TestLeakyReluAlpha2, check_prim=True, enable_cinn=True, check_prim_pir=True
)
create_test_act_fp16_class(
TestLeakyReluAlpha2, check_prim=True, enable_cinn=True
TestLeakyReluAlpha3, check_prim=True, enable_cinn=True, check_prim_pir=True
)
create_test_act_fp16_class(
TestLeakyReluAlpha3, check_prim=True, enable_cinn=True
TestLeakyRelu_ZeroDim, check_prim=True, check_prim_pir=True
)
create_test_act_fp16_class(TestLeakyRelu_ZeroDim, check_prim=True)
create_test_act_fp16_class(
TestRsqrt,
check_prim=True,
@@ -5017,7 +5034,9 @@ def test_check_grad(self):
TestExpFp32_Prim, check_prim=True, check_prim_pir=True
)
create_test_act_bf16_class(TestExpm1)
create_test_act_bf16_class(TestSigmoid, check_prim=True, check_pir=True)
create_test_act_bf16_class(
TestSigmoid, check_prim=True, check_pir=True, check_prim_pir=True
)
create_test_act_bf16_class(TestSilu, check_prim=True, check_prim_pir=True)
create_test_act_bf16_class(TestLogSigmoid)
create_test_act_bf16_class(TestTanh, check_prim=True, check_prim_pir=True)
@@ -5089,11 +5108,21 @@ def test_check_grad(self):
create_test_act_bf16_class(TestSwish)
create_test_act_bf16_class(TestHardSwish, check_prim=True, check_pir=True)
create_test_act_bf16_class(TestMish, check_pir=True)
create_test_act_bf16_class(TestLeakyRelu, check_prim=True, check_pir=True)
create_test_act_bf16_class(TestLeakyReluAlpha1, check_prim=True)
create_test_act_bf16_class(TestLeakyReluAlpha2, check_prim=True)
create_test_act_bf16_class(TestLeakyReluAlpha3, check_prim=True)
create_test_act_bf16_class(TestLeakyRelu_ZeroDim, check_prim=True)
create_test_act_bf16_class(
TestLeakyRelu, check_prim=True, check_pir=True, check_prim_pir=True
)
create_test_act_bf16_class(
TestLeakyReluAlpha1, check_prim=True, check_prim_pir=True
)
create_test_act_bf16_class(
TestLeakyReluAlpha2, check_prim=True, check_prim_pir=True
)
create_test_act_bf16_class(
TestLeakyReluAlpha3, check_prim=True, check_prim_pir=True
)
create_test_act_bf16_class(
TestLeakyRelu_ZeroDim, check_prim=True, check_prim_pir=True
)
create_test_act_bf16_class(
TestRsqrt, check_prim=True, check_pir=True, check_prim_pir=True
)
2 changes: 1 addition & 1 deletion test/legacy_test/test_instance_norm_op.py
@@ -130,7 +130,7 @@ def setUp(self):
}

def test_check_output(self):
self.check_output(check_prim=True, check_pir=True)
self.check_output(check_prim=True, check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(
22 changes: 19 additions & 3 deletions test/legacy_test/test_instance_norm_op_v2.py
@@ -220,7 +220,12 @@ def setUp(self):

def test_check_output(self):
self.check_output(
atol=self.atol, check_prim=self.check_prim, check_pir=True
atol=self.atol,
check_prim=self.check_prim,
check_pir=True,
check_prim_pir=False
if os.getenv("FLAGS_enable_pir_in_executor")
else True,
)

def test_check_grad(self):
@@ -275,7 +280,13 @@ def set_err_thre(self):
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, atol=self.atol, check_prim=self.check_prim, check_pir=True
place,
atol=self.atol,
check_prim=self.check_prim,
check_pir=True,
check_prim_pir=False
if os.getenv("FLAGS_enable_pir_in_executor")
else True,
)

def test_check_grad(self):
@@ -350,7 +361,12 @@ def init_shape(self):
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, check_prim=self.check_prim, check_pir=True
place,
check_prim=self.check_prim,
check_pir=True,
check_prim_pir=False
if os.getenv("FLAGS_enable_pir_in_executor")
else True,
)

def test_check_grad(self):
Expand Down
