Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR API adaptor No.160、204、216 】 Migrate NLLLoss/BCEWithLogitsLoss/MarginRankingLoss into pir #58832

Merged
merged 14 commits into from
Nov 27, 2023
10 changes: 5 additions & 5 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
# TODO: define loss functions of neural network
import paddle
from paddle import _C_ops, base, in_dynamic_mode
from paddle.framework import core
from paddle.static.nn.control_flow import Assert
from paddle.utils import deprecated

from ...base.data_feeder import check_variable_and_dtype
from ...base.framework import (
_current_expected_place,
core,
in_dynamic_or_pir_mode,
in_pir_mode,
)
Expand Down Expand Up @@ -800,7 +800,7 @@ def binary_cross_entropy_with_logits(
% reduction
)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
one = _C_ops.full(
[1],
1.0,
Expand Down Expand Up @@ -1197,11 +1197,11 @@ def margin_ranking_loss(
"The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction
)
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
out = _C_ops.subtract(other, input)
out = _C_ops.multiply(out, label)
if margin != 0.0:
margin = base.dygraph.base.to_variable([margin], dtype=out.dtype)
margin = paddle.to_tensor([margin], dtype=out.dtype)
out = _C_ops.add(out, margin)
out = _C_ops.relu(out)
if reduction == 'sum':
Expand Down Expand Up @@ -1440,7 +1440,7 @@ def nll_loss(

n = input_shape[0]
c = input_shape[1]
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
if input_dims != 2 and input_dims != 4:
input = _C_ops.reshape(input, [n, c, 1, -1])
label = _C_ops.reshape(label, [n, 1, -1])
Expand Down
160 changes: 95 additions & 65 deletions test/legacy_test/test_bce_with_logits_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import paddle
from paddle import base
from paddle.pir_utils import test_with_pir_api


def call_bce_layer(
Expand Down Expand Up @@ -49,9 +50,10 @@ def test_static(
functional=False,
):
paddle.enable_static()
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):

with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
logit = paddle.static.data(
name='logit', shape=logit_np.shape, dtype='float64'
)
Expand Down Expand Up @@ -79,7 +81,7 @@ def test_static(
else:
res = call_bce_layer(logit, label, weight, reduction, pos_weight)
exe = paddle.static.Executor(place)
(static_result,) = exe.run(prog, feed=feed_dict, fetch_list=[res])
(static_result,) = exe.run(feed=feed_dict, fetch_list=[res])
return static_result


Expand Down Expand Up @@ -147,41 +149,51 @@ def test_BCEWithLogitsLoss(self):
reductions = ['sum', 'mean', 'none']
for place in places:
for reduction in reductions:
static_result = test_static(
place, logit_np, label_np, reduction=reduction
)
dy_result = test_dygraph(
place, logit_np, label_np, reduction=reduction
)
expected = calc_bce_with_logits_loss(
logit_np, label_np, reduction
)
np.testing.assert_allclose(static_result, expected, rtol=1e-05)
np.testing.assert_allclose(static_result, dy_result, rtol=1e-05)
np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
static_functional = test_static(
place,
logit_np,
label_np,
reduction=reduction,
functional=True,
)
dy_functional = test_dygraph(
place,
logit_np,
label_np,
reduction=reduction,
functional=True,
)

np.testing.assert_allclose(
static_functional, expected, rtol=1e-05
)
np.testing.assert_allclose(
static_functional, dy_functional, rtol=1e-05
expected = calc_bce_with_logits_loss(
logit_np, label_np, reduction
)

np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
np.testing.assert_allclose(dy_functional, expected, rtol=1e-05)

@test_with_pir_api
def test_dynamic_or_pir_mode():
static_result = test_static(
place, logit_np, label_np, reduction=reduction
)
static_functional = test_static(
place,
logit_np,
label_np,
reduction=reduction,
functional=True,
)
np.testing.assert_allclose(
static_result, expected, rtol=1e-05
)
np.testing.assert_allclose(
static_result, dy_result, rtol=1e-05
)

np.testing.assert_allclose(
static_functional, expected, rtol=1e-05
)
np.testing.assert_allclose(
static_functional, dy_functional, rtol=1e-05
)

test_dynamic_or_pir_mode()

def test_BCEWithLogitsLoss_weight(self):
logit_np = np.random.uniform(0.1, 0.8, size=(2, 3, 4, 10)).astype(
np.float64
Expand All @@ -196,13 +208,6 @@ def test_BCEWithLogitsLoss_weight(self):
else base.CPUPlace()
)
for reduction in ['sum', 'mean', 'none']:
static_result = test_static(
place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction,
)
dy_result = test_dygraph(
place,
logit_np,
Expand All @@ -213,17 +218,6 @@ def test_BCEWithLogitsLoss_weight(self):
expected = calc_bce_with_logits_loss(
logit_np, label_np, reduction, weight_np=weight_np
)
np.testing.assert_allclose(static_result, expected, rtol=1e-05)
np.testing.assert_allclose(static_result, dy_result, rtol=1e-05)
np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
static_functional = test_static(
place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction,
functional=True,
)
dy_functional = test_dygraph(
place,
logit_np,
Expand All @@ -232,12 +226,39 @@ def test_BCEWithLogitsLoss_weight(self):
reduction=reduction,
functional=True,
)
np.testing.assert_allclose(static_functional, expected, rtol=1e-05)
np.testing.assert_allclose(
static_functional, dy_functional, rtol=1e-05
)
np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
np.testing.assert_allclose(dy_functional, expected, rtol=1e-05)

@test_with_pir_api
def test_dynamic_or_pir_mode():
static_result = test_static(
place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction,
)

static_functional = test_static(
place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction,
functional=True,
)
np.testing.assert_allclose(static_result, expected, rtol=1e-05)
np.testing.assert_allclose(static_result, dy_result, rtol=1e-05)

np.testing.assert_allclose(
static_functional, expected, rtol=1e-05
)
np.testing.assert_allclose(
static_functional, dy_functional, rtol=1e-05
)

test_dynamic_or_pir_mode()

def test_BCEWithLogitsLoss_pos_weight(self):
logit_np = np.random.uniform(0.1, 0.8, size=(2, 3, 4, 10)).astype(
np.float64
Expand All @@ -253,27 +274,13 @@ def test_BCEWithLogitsLoss_pos_weight(self):
else base.CPUPlace()
)
reduction = "mean"
static_result = test_static(
place, logit_np, label_np, weight_np, reduction, pos_weight_np
)

dy_result = test_dygraph(
place, logit_np, label_np, weight_np, reduction, pos_weight_np
)
expected = calc_bce_with_logits_loss(
logit_np, label_np, reduction, weight_np, pos_weight_np
)
np.testing.assert_allclose(static_result, expected, rtol=1e-05)
np.testing.assert_allclose(static_result, dy_result, rtol=1e-05)
np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
static_functional = test_static(
place,
logit_np,
label_np,
weight_np,
reduction,
pos_weight_np,
functional=True,
)
dy_functional = test_dygraph(
place,
logit_np,
Expand All @@ -283,10 +290,33 @@ def test_BCEWithLogitsLoss_pos_weight(self):
pos_weight_np,
functional=True,
)
np.testing.assert_allclose(static_functional, expected, rtol=1e-05)
np.testing.assert_allclose(static_functional, dy_functional, rtol=1e-05)
np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
np.testing.assert_allclose(dy_functional, expected, rtol=1e-05)

@test_with_pir_api
def test_dynamic_or_pir_mode():
static_result = test_static(
place, logit_np, label_np, weight_np, reduction, pos_weight_np
)
static_functional = test_static(
place,
logit_np,
label_np,
weight_np,
reduction,
pos_weight_np,
functional=True,
)

np.testing.assert_allclose(static_result, expected, rtol=1e-05)
np.testing.assert_allclose(static_result, dy_result, rtol=1e-05)
np.testing.assert_allclose(static_functional, expected, rtol=1e-05)
np.testing.assert_allclose(
static_functional, dy_functional, rtol=1e-05
)

test_dynamic_or_pir_mode()

def test_BCEWithLogitsLoss_error(self):
paddle.disable_static()
self.assertRaises(
Expand Down
Loading