From e4b4959d4ea505f96c38c4a83ec274e025f07852 Mon Sep 17 00:00:00 2001 From: Ashley Whetter Date: Fri, 14 Aug 2020 09:46:42 -0700 Subject: [PATCH] Fixed stubgen parsing generics from C extensions (#8939) pybind11 is capable of producing type signatures that use generics (for example https://github.com/pybind/pybind11/blob/4e3d9fea74ed50a042d98f68fa35a3133482289b/include/pybind11/stl.h#L140). A user may also opt to write a signature in the docstring that uses generics. Currently when stubgen parses one of these generics, it attempts to import a part of it. For example if a docstring had my_func(str, int) -> List[mypackage.module_being_parsed.MyClass], the resulting stub file tries to import List[mypackage.module_being_parsed. This change fixes this behaviour by breaking the found type down into the multiple types around [], characters, adding any imports from those types that are needed, and then stripping out the name of the module being parsed. --- mypy/stubgenc.py | 11 +++++- mypy/test/teststubgen.py | 75 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/mypy/stubgenc.py b/mypy/stubgenc.py index 72477a2ce300..905be239fc13 100755 --- a/mypy/stubgenc.py +++ b/mypy/stubgenc.py @@ -214,7 +214,16 @@ def strip_or_import(typ: str, module: ModuleType, imports: List[str]) -> str: imports: list of import statements (may be modified during the call) """ stripped_type = typ - if module and typ.startswith(module.__name__ + '.'): + if any(c in typ for c in '[,'): + for subtyp in re.split(r'[\[,\]]', typ): + strip_or_import(subtyp.strip(), module, imports) + if module: + stripped_type = re.sub( + r'(^|[\[, ]+)' + re.escape(module.__name__ + '.'), + r'\1', + typ, + ) + elif module and typ.startswith(module.__name__ + '.'): stripped_type = typ[len(module.__name__) + 1:] elif '.' in typ: arg_module = typ[:typ.rindex('.')] diff --git a/mypy/test/teststubgen.py b/mypy/test/teststubgen.py index 3566f03fb9a1..5cc9428a47e0 100644 --- a/mypy/test/teststubgen.py +++ b/mypy/test/teststubgen.py @@ -794,6 +794,81 @@ def get_attribute(self) -> None: generate_c_property_stub('attribute', TestClass.attribute, output, readonly=True) assert_equal(output, ['@property', 'def attribute(self) -> str: ...']) + def test_generate_c_type_with_single_arg_generic(self) -> None: + class TestClass: + def test(self, arg0: str) -> None: + """ + test(self: TestClass, arg0: List[int]) + """ + pass + output = [] # type: List[str] + imports = [] # type: List[str] + mod = ModuleType(TestClass.__module__, '') + generate_c_function_stub(mod, 'test', TestClass.test, output, imports, + self_var='self', class_name='TestClass') + assert_equal(output, ['def test(self, arg0: List[int]) -> Any: ...']) + assert_equal(imports, []) + + def test_generate_c_type_with_double_arg_generic(self) -> None: + class TestClass: + def test(self, arg0: str) -> None: + """ + test(self: TestClass, arg0: Dict[str, int]) + """ + pass + output = [] # type: List[str] + imports = [] # type: List[str] + mod = ModuleType(TestClass.__module__, '') + generate_c_function_stub(mod, 'test', TestClass.test, output, imports, + self_var='self', class_name='TestClass') + assert_equal(output, ['def test(self, arg0: Dict[str,int]) -> Any: ...']) + assert_equal(imports, []) + + def test_generate_c_type_with_nested_generic(self) -> None: + class TestClass: + def test(self, arg0: str) -> None: + """ + test(self: TestClass, arg0: Dict[str, List[int]]) + """ + pass + output = [] # type: List[str] + imports = [] # type: List[str] + mod = ModuleType(TestClass.__module__, '') + generate_c_function_stub(mod, 'test', TestClass.test, output, imports, + self_var='self', class_name='TestClass') + assert_equal(output, ['def test(self, arg0: Dict[str,List[int]]) -> Any: ...']) + assert_equal(imports, []) + + def test_generate_c_type_with_generic_using_other_module_first(self) -> None: + class TestClass: + def test(self, arg0: str) -> None: + """ + test(self: TestClass, arg0: Dict[argparse.Action, int]) + """ + pass + output = [] # type: List[str] + imports = [] # type: List[str] + mod = ModuleType(TestClass.__module__, '') + generate_c_function_stub(mod, 'test', TestClass.test, output, imports, + self_var='self', class_name='TestClass') + assert_equal(output, ['def test(self, arg0: Dict[argparse.Action,int]) -> Any: ...']) + assert_equal(imports, ['import argparse']) + + def test_generate_c_type_with_generic_using_other_module_last(self) -> None: + class TestClass: + def test(self, arg0: str) -> None: + """ + test(self: TestClass, arg0: Dict[str, argparse.Action]) + """ + pass + output = [] # type: List[str] + imports = [] # type: List[str] + mod = ModuleType(TestClass.__module__, '') + generate_c_function_stub(mod, 'test', TestClass.test, output, imports, + self_var='self', class_name='TestClass') + assert_equal(output, ['def test(self, arg0: Dict[str,argparse.Action]) -> Any: ...']) + assert_equal(imports, ['import argparse']) + def test_generate_c_type_with_overload_pybind11(self) -> None: class TestClass: def __init__(self, arg0: str) -> None: