From 124ca7f495b23d3f1d6037c8e868100b71058271 Mon Sep 17 00:00:00 2001 From: coco <1228759711@qq.com> Date: Sun, 12 Nov 2023 17:42:08 +0000 Subject: [PATCH 1/5] pir fit for matrix_nms --- python/paddle/vision/ops.py | 17 +++++++++-------- test/legacy_test/test_matrix_nms_op.py | 6 ++++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 5a8b433cea52e..56193290ff1ef 100755 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -20,7 +20,7 @@ from ..base import core from ..base.data_feeder import check_type, check_variable_and_dtype -from ..base.framework import Variable, in_dygraph_mode +from ..base.framework import Variable, in_dygraph_mode, in_dynamic_or_pir_mode from ..base.layer_helper import LayerHelper from ..framework import _current_expected_place from ..nn import BatchNorm2D, Conv2D, Layer, ReLU, Sequential @@ -2284,7 +2284,14 @@ def matrix_nms( ... score_threshold=0.5, post_threshold=0.1, ... nms_top_k=400, keep_top_k=200, normalized=False) """ - if in_dygraph_mode(): + check_variable_and_dtype( + bboxes, 'BBoxes', ['float32', 'float64'], 'matrix_nms' + ) + check_variable_and_dtype( + scores, 'Scores', ['float32', 'float64'], 'matrix_nms' + ) + + if in_dynamic_or_pir_mode(): out, index, rois_num = _C_ops.matrix_nms( bboxes, scores, @@ -2303,12 +2310,6 @@ def matrix_nms( rois_num = None return out, rois_num, index else: - check_variable_and_dtype( - bboxes, 'BBoxes', ['float32', 'float64'], 'matrix_nms' - ) - check_variable_and_dtype( - scores, 'Scores', ['float32', 'float64'], 'matrix_nms' - ) check_type(score_threshold, 'score_threshold', float, 'matrix_nms') check_type(post_threshold, 'post_threshold', float, 'matrix_nms') check_type(nms_top_k, 'nums_top_k', int, 'matrix_nms') diff --git a/test/legacy_test/test_matrix_nms_op.py b/test/legacy_test/test_matrix_nms_op.py index 08d1ec20653c8..bb984610d51c9 100644 --- a/test/legacy_test/test_matrix_nms_op.py +++ b/test/legacy_test/test_matrix_nms_op.py @@ -20,6 +20,7 @@ import paddle from paddle.base import Program, program_guard +from paddle.pir_utils import test_with_pir_api def python_matrix_nms( @@ -296,7 +297,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestMatrixNMSOpNoOutput(TestMatrixNMSOp): @@ -311,6 +312,7 @@ def set_argument(self): class TestMatrixNMSError(unittest.TestCase): + @test_with_pir_api def test_errors(self): M = 1200 N = 7 @@ -327,7 +329,7 @@ def test_errors(self): scores = np.reshape(scores, (N, M, C)) scores_np = np.transpose(scores, (0, 2, 1)) - with program_guard(Program(), Program()): + with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): boxes_data = paddle.static.data( name='bboxes', shape=[M, C, BOX_SIZE], dtype='float32' ) From b7293fbf53268005bf148b3f33e8c584ae7bd0e7 Mon Sep 17 00:00:00 2001 From: coco <1228759711@qq.com> Date: Sun, 12 Nov 2023 17:47:17 +0000 Subject: [PATCH 2/5] fix bug --- test/legacy_test/test_matrix_nms_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_matrix_nms_op.py b/test/legacy_test/test_matrix_nms_op.py index bb984610d51c9..dfb2d90efacf1 100644 --- a/test/legacy_test/test_matrix_nms_op.py +++ b/test/legacy_test/test_matrix_nms_op.py @@ -331,7 +331,7 @@ def test_errors(self): with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): boxes_data = paddle.static.data( - name='bboxes', shape=[M, C, BOX_SIZE], dtype='float32' + name='bboxes', shape=[M, M, BOX_SIZE], dtype='float32' ) scores_data = paddle.static.data( name='scores', shape=[N, C, M], dtype='float32' From f941e13d3af80b876a411b270378ff8fb6271b98 Mon Sep 17 00:00:00 2001 From: coco <1228759711@qq.com> Date: Mon, 13 Nov 2023 02:02:40 +0800 Subject: [PATCH 3/5] codestyle --- test/legacy_test/test_matrix_nms_op.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/legacy_test/test_matrix_nms_op.py b/test/legacy_test/test_matrix_nms_op.py index dfb2d90efacf1..feb475784f7eb 100644 --- a/test/legacy_test/test_matrix_nms_op.py +++ b/test/legacy_test/test_matrix_nms_op.py @@ -19,7 +19,6 @@ from op_test import OpTest import paddle -from paddle.base import Program, program_guard from paddle.pir_utils import test_with_pir_api @@ -329,7 +328,9 @@ def test_errors(self): scores = np.reshape(scores, (N, M, C)) scores_np = np.transpose(scores, (0, 2, 1)) - with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): boxes_data = paddle.static.data( name='bboxes', shape=[M, M, BOX_SIZE], dtype='float32' ) From 794d1e23a192259b9425097db25e2a3e423aa057 Mon Sep 17 00:00:00 2001 From: coco <1228759711@qq.com> Date: Mon, 13 Nov 2023 17:46:17 +0000 Subject: [PATCH 4/5] reset check dtype --- python/paddle/vision/ops.py | 13 ++++++------- test/legacy_test/test_matrix_nms_op.py | 7 +++---- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 56193290ff1ef..2a95521a3b63e 100755 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -2284,13 +2284,6 @@ def matrix_nms( ... score_threshold=0.5, post_threshold=0.1, ... nms_top_k=400, keep_top_k=200, normalized=False) """ - check_variable_and_dtype( - bboxes, 'BBoxes', ['float32', 'float64'], 'matrix_nms' - ) - check_variable_and_dtype( - scores, 'Scores', ['float32', 'float64'], 'matrix_nms' - ) - if in_dynamic_or_pir_mode(): out, index, rois_num = _C_ops.matrix_nms( bboxes, @@ -2310,6 +2303,12 @@ def matrix_nms( rois_num = None return out, rois_num, index else: + check_variable_and_dtype( + bboxes, 'BBoxes', ['float32', 'float64'], 'matrix_nms' + ) + check_variable_and_dtype( + scores, 'Scores', ['float32', 'float64'], 'matrix_nms' + ) check_type(score_threshold, 'score_threshold', float, 'matrix_nms') check_type(post_threshold, 'post_threshold', float, 'matrix_nms') check_type(nms_top_k, 'nums_top_k', int, 'matrix_nms') diff --git a/test/legacy_test/test_matrix_nms_op.py b/test/legacy_test/test_matrix_nms_op.py index dfb2d90efacf1..8324465d4fb29 100644 --- a/test/legacy_test/test_matrix_nms_op.py +++ b/test/legacy_test/test_matrix_nms_op.py @@ -19,8 +19,6 @@ from op_test import OpTest import paddle -from paddle.base import Program, program_guard -from paddle.pir_utils import test_with_pir_api def python_matrix_nms( @@ -312,7 +310,6 @@ def set_argument(self): class TestMatrixNMSError(unittest.TestCase): - @test_with_pir_api def test_errors(self): M = 1200 N = 7 @@ -329,7 +326,9 @@ def test_errors(self): scores = np.reshape(scores, (N, M, C)) scores_np = np.transpose(scores, (0, 2, 1)) - with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): boxes_data = paddle.static.data( name='bboxes', shape=[M, M, BOX_SIZE], dtype='float32' ) From 80409a53ea065f2bcb034cf2f210186f6b5d01db Mon Sep 17 00:00:00 2001 From: coco <69197635+cocoshe@users.noreply.github.com> Date: Tue, 14 Nov 2023 19:27:50 +0800 Subject: [PATCH 5/5] Update test_matrix_nms_op.py --- test/legacy_test/test_matrix_nms_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_matrix_nms_op.py b/test/legacy_test/test_matrix_nms_op.py index 8324465d4fb29..6ed82b336d109 100644 --- a/test/legacy_test/test_matrix_nms_op.py +++ b/test/legacy_test/test_matrix_nms_op.py @@ -330,7 +330,7 @@ def test_errors(self): paddle.static.Program(), paddle.static.Program() ): boxes_data = paddle.static.data( - name='bboxes', shape=[M, M, BOX_SIZE], dtype='float32' + name='bboxes', shape=[M, C, BOX_SIZE], dtype='float32' ) scores_data = paddle.static.data( name='scores', shape=[N, C, M], dtype='float32'