Skip to content

Commit

Permalink
Call dynamic class hook on generic classes (#16052)
Browse files Browse the repository at this point in the history
Fixes: #8359

CC @sobolevn 

`get_dynamic_class_hook()` will now additionally be called for generic
classes with parameters. e.g.

```python
y = SomeGenericClass[type, ...].method()
```
  • Loading branch information
Petter Friberg authored Sep 19, 2023
1 parent 1dcff0d commit ba978f4
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 2 deletions.
7 changes: 7 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3205,6 +3205,13 @@ def apply_dynamic_class_hook(self, s: AssignmentStmt) -> None:
if isinstance(callee_expr, RefExpr) and callee_expr.fullname:
method_name = call.callee.name
fname = callee_expr.fullname + "." + method_name
elif (
isinstance(callee_expr, IndexExpr)
and isinstance(callee_expr.base, RefExpr)
and isinstance(callee_expr.analyzed, TypeApplication)
):
method_name = call.callee.name
fname = callee_expr.base.fullname + "." + method_name
elif isinstance(callee_expr, CallExpr):
# check if chain call
call = callee_expr
Expand Down
12 changes: 11 additions & 1 deletion test-data/unit/check-custom-plugin.test
Original file line number Diff line number Diff line change
Expand Up @@ -684,12 +684,16 @@ plugins=<ROOT>/test-data/unit/plugins/dyn_class.py
[case testDynamicClassHookFromClassMethod]
# flags: --config-file tmp/mypy.ini

from mod import QuerySet, Manager
from mod import QuerySet, Manager, GenericQuerySet

MyManager = Manager.from_queryset(QuerySet)
ManagerFromGenericQuerySet = GenericQuerySet[int].as_manager()

reveal_type(MyManager()) # N: Revealed type is "__main__.MyManager"
reveal_type(MyManager().attr) # N: Revealed type is "builtins.str"
reveal_type(ManagerFromGenericQuerySet()) # N: Revealed type is "__main__.ManagerFromGenericQuerySet"
reveal_type(ManagerFromGenericQuerySet().attr) # N: Revealed type is "builtins.int"
queryset: GenericQuerySet[int] = ManagerFromGenericQuerySet()

def func(manager: MyManager) -> None:
reveal_type(manager) # N: Revealed type is "__main__.MyManager"
Expand All @@ -704,6 +708,12 @@ class QuerySet:
class Manager:
@classmethod
def from_queryset(cls, queryset_cls: Type[QuerySet]): ...
T = TypeVar("T")
class GenericQuerySet(Generic[T]):
attr: T

@classmethod
def as_manager(cls): ...

[builtins fixtures/classmethod.pyi]
[file mypy.ini]
Expand Down
40 changes: 39 additions & 1 deletion test-data/unit/plugins/dyn_class_from_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,19 @@

from typing import Callable

from mypy.nodes import GDEF, Block, ClassDef, RefExpr, SymbolTable, SymbolTableNode, TypeInfo
from mypy.nodes import (
GDEF,
Block,
ClassDef,
IndexExpr,
MemberExpr,
NameExpr,
RefExpr,
SymbolTable,
SymbolTableNode,
TypeApplication,
TypeInfo,
)
from mypy.plugin import DynamicClassDefContext, Plugin
from mypy.types import Instance

Expand All @@ -13,6 +25,8 @@ def get_dynamic_class_hook(
) -> Callable[[DynamicClassDefContext], None] | None:
if "from_queryset" in fullname:
return add_info_hook
if "as_manager" in fullname:
return as_manager_hook
return None


Expand All @@ -34,5 +48,29 @@ def add_info_hook(ctx: DynamicClassDefContext) -> None:
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))


def as_manager_hook(ctx: DynamicClassDefContext) -> None:
class_def = ClassDef(ctx.name, Block([]))
class_def.fullname = ctx.api.qualified_name(ctx.name)

info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id)
class_def.info = info
assert isinstance(ctx.call.callee, MemberExpr)
assert isinstance(ctx.call.callee.expr, IndexExpr)
assert isinstance(ctx.call.callee.expr.analyzed, TypeApplication)
assert isinstance(ctx.call.callee.expr.analyzed.expr, NameExpr)

queryset_type_fullname = ctx.call.callee.expr.analyzed.expr.fullname
queryset_node = ctx.api.lookup_fully_qualified_or_none(queryset_type_fullname)
assert queryset_node is not None
queryset_info = queryset_node.node
assert isinstance(queryset_info, TypeInfo)
parameter_type = ctx.call.callee.expr.analyzed.types[0]

obj = ctx.api.named_type("builtins.object")
info.mro = [info, queryset_info, obj.type]
info.bases = [Instance(queryset_info, [parameter_type])]
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))


def plugin(version: str) -> type[DynPlugin]:
return DynPlugin

0 comments on commit ba978f4

Please sign in to comment.