Skip to content

Commit

Permalink
【PIR API adaptor No.51, 60】Migrate some ops into pir (PaddlePaddle#58684
Browse files Browse the repository at this point in the history
)
  • Loading branch information
longranger2 authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent d738599 commit 53368f2
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/paddle/vision/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ def distribute_fpn_proposals(
num_lvl < 100
), "Only support max to 100 levels, (max_level - min_level + 1 < 100)"

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
assert (
rois_num is not None
), "rois_num should not be None in dygraph mode."
Expand Down
24 changes: 19 additions & 5 deletions test/legacy_test/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from paddle.base import core
from paddle.base.dygraph import base as imperative_base
from paddle.base.framework import Program, program_guard
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()

Expand Down Expand Up @@ -66,9 +67,9 @@ def get_static_graph_result(
self, feed, fetch_list, with_lod=False, force_to_use_cpu=False
):
exe = base.Executor(self._get_place(force_to_use_cpu))
exe.run(base.default_startup_program())
exe.run(paddle.static.default_startup_program())
return exe.run(
base.default_main_program(),
paddle.static.default_main_program(),
feed=feed,
fetch_list=fetch_list,
return_numpy=(not with_lod),
Expand Down Expand Up @@ -183,9 +184,7 @@ def test_multiclass_nms2(self):


class TestDistributeFpnProposals(LayerTest):
def test_distribute_fpn_proposals(self):
rois_np = np.random.rand(10, 4).astype('float32')
rois_num_np = np.array([4, 6]).astype('int32')
def static_distribute_fpn_proposals(self, rois_np, rois_num_np):
with self.static_graph():
rois = paddle.static.data(
name='rois', shape=[10, 4], dtype='float32'
Expand Down Expand Up @@ -216,7 +215,9 @@ def test_distribute_fpn_proposals(self):
output_np = np.array(output)
if len(output_np) > 0:
output_stat_np.append(output_np)
return output_stat_np

def dynamic_distribute_fpn_proposals(self, rois_np, rois_num_np):
with self.dynamic_graph():
rois_dy = imperative_base.to_variable(rois_np)
rois_num_dy = imperative_base.to_variable(rois_num_np)
Expand All @@ -239,6 +240,19 @@ def test_distribute_fpn_proposals(self):
output_np = output.numpy()
if len(output_np) > 0:
output_dy_np.append(output_np)
return output_dy_np

@test_with_pir_api
def test_distribute_fpn_proposals(self):
rois_np = np.random.rand(10, 4).astype('float32')
rois_num_np = np.array([4, 6]).astype('int32')

output_stat_np = self.static_distribute_fpn_proposals(
rois_np, rois_num_np
)
output_dy_np = self.dynamic_distribute_fpn_proposals(
rois_np, rois_num_np
)

for res_stat, res_dy in zip(output_stat_np, output_dy_np):
np.testing.assert_array_equal(res_stat, res_dy)
Expand Down
4 changes: 3 additions & 1 deletion test/legacy_test/test_distribute_fpn_proposals_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from op_test import OpTest

import paddle
from paddle.pir_utils import test_with_pir_api


def distribute_fpn_proposals_wrapper(
Expand Down Expand Up @@ -142,7 +143,7 @@ def setUp(self):
self.set_data()

def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output(check_dygraph=False, check_pir=False)


class TestDistributeFPNProposalsOpWithRoisNum(TestDistributeFPNProposalsOp):
Expand Down Expand Up @@ -200,6 +201,7 @@ def setUp(self):
self.rois_np = np.random.rand(10, 4).astype('float32')
self.rois_num_np = np.array([4, 6]).astype('int32')

@test_with_pir_api
def test_dygraph_with_static(self):
paddle.enable_static()
rois = paddle.static.data(name='rois', shape=[10, 4], dtype='float32')
Expand Down

0 comments on commit 53368f2

Please sign in to comment.