-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.17】 为 Paddle 新增 pdist API -part #57869
Changes from 20 commits
3e19c90
51b908a
ced239c
9c651d9
13d04ad
8bcbe47
212bdf7
7bbc22e
20f82de
b06dd06
2e259a6
e9dfa5a
cddec3b
23d5f7e
3a9fbfb
e41f9dc
008ae5b
114453a
171339d
8a281b4
a9d053c
b3816df
3f7b607
e3e5abf
8172f28
0146a37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,122 @@ | ||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||||
# | ||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||
# you may not use this file except in compliance with the License. | ||||
# You may obtain a copy of the License at | ||||
# | ||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||
# | ||||
# Unless required by applicable law or agreed to in writing, software | ||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# See the License for the specific language governing permissions and | ||||
# limitations under the License. | ||||
|
||||
import unittest | ||||
|
||||
import numpy as np | ||||
|
||||
import paddle | ||||
|
||||
|
||||
def ref_pdist(x, p=2.0): | ||||
dist = np.linalg.norm(x[..., None, :] - x[None, :, :], ord=p, axis=-1) | ||||
res = [] | ||||
rows, cols = dist.shape | ||||
for i in range(rows): | ||||
for j in range(cols): | ||||
if i >= j: | ||||
continue | ||||
res.append(dist[i][j]) | ||||
return np.array(res) | ||||
|
||||
|
||||
class TestPdistAPI(unittest.TestCase): | ||||
def setUp(self): | ||||
self.x = np.random.rand(10, 20).astype('float32') | ||||
self.p = 2.0 | ||||
self.init_input() | ||||
self.place = ( | ||||
paddle.CUDAPlace(0) | ||||
if paddle.is_compiled_with_cuda() | ||||
else paddle.CPUPlace() | ||||
) | ||||
|
||||
def init_input(self): | ||||
pass | ||||
|
||||
def test_static_api(self): | ||||
paddle.enable_static() | ||||
with paddle.static.program_guard(paddle.static.Program()): | ||||
x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype) | ||||
out = paddle.pdist( | ||||
x, | ||||
self.p, | ||||
) | ||||
exe = paddle.static.Executor(self.place) | ||||
res = exe.run(feed={'x': self.x}, fetch_list=[out]) | ||||
out_ref = ref_pdist(self.x, self.p) | ||||
np.testing.assert_allclose(out_ref, res[0], rtol=1e-5, atol=1e-5) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里应该不需要使用rtol, atol? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
如果不用rtol atol的话精度过不了。我进行了一些尝试,发现norm看样子像是没和numpy对齐,在 Paddle/test/legacy_test/test_cdist.py Line 56 in 907e425
|
||||
|
||||
def test_dygraph_api(self): | ||||
paddle.disable_static(self.place) | ||||
x = paddle.to_tensor(self.x) | ||||
out = paddle.pdist( | ||||
x, | ||||
self.p, | ||||
) | ||||
out_ref = ref_pdist(self.x, self.p) | ||||
np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5, atol=1e-5) | ||||
paddle.enable_static() | ||||
|
||||
|
||||
class TestPdistAPICase1_param_p1(TestPdistAPI): | ||||
def init_input(self): | ||||
self.p = 0 | ||||
|
||||
|
||||
class TestPdistAPICase2_param_p2(TestPdistAPI): | ||||
def init_input(self): | ||||
self.p = 1.0 | ||||
|
||||
|
||||
class TestPdistAPICase3_param_p3(TestPdistAPI): | ||||
def init_input(self): | ||||
self.p = 3.0 | ||||
|
||||
|
||||
class TestPdistAPICase4_param_p4(TestPdistAPI): | ||||
def init_input(self): | ||||
self.p = 1.5 | ||||
|
||||
|
||||
class TestPdistAPICase5_param_p5(TestPdistAPI): | ||||
def init_input(self): | ||||
self.p = 2.5 | ||||
|
||||
|
||||
class TestPdistAPICase6_param_p6(TestPdistAPI): | ||||
def init_input(self): | ||||
self.p = float('inf') | ||||
|
||||
|
||||
class TestPdistAPICase7_input_x1(TestPdistAPI): | ||||
def init_input(self): | ||||
self.x = np.random.rand(50, 20).astype('float64') | ||||
|
||||
|
||||
class TestPdistShapeError(unittest.TestCase): | ||||
def test_error(self): | ||||
with self.assertRaises(AssertionError): | ||||
self.x = np.random.rand(50, 10, 20).astype('float64') | ||||
self.p = 2.0 | ||||
x = paddle.to_tensor(self.x) | ||||
out0 = paddle.pdist( | ||||
x, | ||||
self.p, | ||||
) | ||||
|
||||
|
||||
if __name__ == '__main__': | ||||
paddle.enable_static() | ||||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
which path do we recommend users to use? If
paddle.pdist
is recommended, it cannot be added to this __all__ list. ifpaddle.nn.functional.pdist
, it cannot be added to the __all__ list inpython/paddle/__init__.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tks for you review, I think I prefer
paddle.pdist
path, I removed this line then.