Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing] 生成 tensor.pyi 时添加 overload 的方法 #68598

Merged
merged 3 commits into from
Oct 10, 2024

Conversation

megemini
Copy link
Contributor

@megemini megemini commented Oct 9, 2024

PR Category

User Experience

PR Types

Bug fixes

Description

生成 tensor.pyi 时添加 overload 的方法

涉及如下方法:

  • atleast_1d
  • atleast_2d
  • atleast_3d
  • lu
  • median
  • nanmedian
  • nonzero
  • qr
  • to
  • unique

修改相应的文件从 typing_extensions 导入 overload

自测在本地可以正常生成,vscode 提示正确:

image

生成的 pyi 文件只是有个小瑕疵:

  • 不带有 overload 的方法跟带有 overload 的方法不在一个地方 ~

比如:

class AbstractTensor:

    @overload
    def atleast_1d(self, *inputs: 'Tensor', name: 'str | None' = ...) -> 'list[Tensor]':
        ...

    @overload
    def atleast_1d(self, name: 'str | None' = ...) -> 'Tensor':
        ...

    @overload
    def atleast_2d(self, *inputs: 'Tensor', name: 'str | None' = ...) -> 'list[Tensor]':
        ...

    @overload
    def atleast_2d(self, name: 'str | None' = ...) -> 'Tensor':
        ...

    ...

    def atleast_1d(self, *inputs, name=None):
        r"""
        Convert inputs to tensors and return the view with at least 1-dimension. Scalar inputs are converted,
        one or high-dimensional inputs are preserved.

    ...

原因是,模板最后统一根据名称排序后插入,@overload 的方法会统一放到文件前面,没有 @overload 的方法排在后面 ~

没啥影响,只是看上去可能不太规整 ~

另外,docstring 统一只插入在不带 overload 的方法里面 ~

@SigureMo 请评审 ~

Copy link

paddle-bot bot commented Oct 9, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -511,6 +513,37 @@ def get_tensor_members(module: str = 'paddle.Tensor') -> dict[int, Member]:
or inspect.ismethod(member)
or inspect.ismethoddescriptor(member)
):
# try to get overloads
overload_signatures = []
try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 try 具体是为了捕获什么错误?以及具体是哪一行的错误呢?不然对于后续重构来说就很难轻易修改

建议将这种错误捕获限定到更小的区域,比如

def try_get_overloads(fn: types.FunctionType) -> list[types.FunctionType] | ConcreteException:
    try:
        get_overloads(fn)
    except ConcreteException as e:
        return e

def main_logic():
    ...
    maybe_overloads = try_get_overloads(fn)
    if isinstance(maybe_overloads, ConcreteException):
        warning(f"Cannot get overloads from `{fn}`, reason: `{e}`")
        return
    ...

这样语义就清晰很多了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

原来逻辑中 method descriptor 不能使用 get_overloads 函数 ~

Traceback (most recent call last):
  File "gen_tensor_stub.py", line 629, in <module>
    main()
  File "gen_tensor_stub.py", line 625, in main
    generate_stub_file(args.input_file, args.output_file)
  File "gen_tensor_stub.py", line 596, in generate_stub_file
    tensor_members = get_tensor_members(module)
  File "gen_tensor_stub.py", line 519, in get_tensor_members
    _overloads = get_overloads(member)
  File "/home/shun/venv38dev/lib/python3.8/site-packages/typing_extensions.py", line 390, in get_overloads
    if f.__module__ not in _overload_registry:
AttributeError: 'method_descriptor' object has no attribute '__module__'

昨天有点晚了,没仔细定位,结果就被抓了 🫠🫠🫠

现在把两个逻辑分开了已经 ~

@SigureMo
Copy link
Member

SigureMo commented Oct 9, 2024

cc @HydrogenSulfate

@paddle-bot paddle-bot bot added the contributor External developers label Oct 9, 2024
@luotao1 luotao1 added the HappyOpenSource 快乐开源活动issue与PR label Oct 10, 2024
Copy link
Member

@SigureMo SigureMo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTMeow 🐾

[
id(f),
["overload"],
f"{name}{_sig}".replace("Ellipsis", "..."),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

风险:方法参数里本来存在 Ellipsis 比如 XxxEllipsis,但目前可能性不大,暂时这样没啥问题

@SigureMo SigureMo merged commit 3865ce6 into PaddlePaddle:develop Oct 10, 2024
26 of 27 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers HappyOpenSource 快乐开源活动issue与PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants