Skip to content

Commit

Permalink
[mypyc] Compile away NewType type calls (#14398)
Browse files Browse the repository at this point in the history
For example, here the call to ID is simply a no-op at runtime, returning
1 unchanged.

    ID = NewType('ID', int)
    person = ID(1)

Resolves mypyc/mypyc#958
  • Loading branch information
ichard26 authored Jan 5, 2023
1 parent 2413578 commit ca66805
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
9 changes: 8 additions & 1 deletion mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,20 @@ def transform_super_expr(builder: IRBuilder, o: SuperExpr) -> Value:


def transform_call_expr(builder: IRBuilder, expr: CallExpr) -> Value:
callee = expr.callee
if isinstance(expr.analyzed, CastExpr):
return translate_cast_expr(builder, expr.analyzed)
elif isinstance(expr.analyzed, AssertTypeExpr):
# Compile to a no-op.
return builder.accept(expr.analyzed.expr)
elif (
isinstance(callee, (NameExpr, MemberExpr))
and isinstance(callee.node, TypeInfo)
and callee.node.is_newtype
):
# A call to a NewType type is a no-op at runtime.
return builder.accept(expr.args[0])

callee = expr.callee
if isinstance(callee, IndexExpr) and isinstance(callee.analyzed, TypeApplication):
callee = callee.analyzed.expr # Unwrap type application

Expand Down
32 changes: 20 additions & 12 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -2338,11 +2338,8 @@ def __top_level__():
r92, r93, r94, r95 :: ptr
r96 :: dict
r97 :: str
r98, r99 :: object
r100 :: dict
r101 :: str
r102 :: int32
r103 :: bit
r98 :: int32
r99 :: bit
L0:
r0 = builtins :: module
r1 = load_address _Py_NoneStruct
Expand Down Expand Up @@ -2454,13 +2451,9 @@ L2:
set_mem r95, r91 :: builtins.object*
keep_alive r88
r96 = __main__.globals :: static
r97 = 'Bar'
r98 = CPyDict_GetItem(r96, r97)
r99 = PyObject_CallFunctionObjArgs(r98, r88, 0)
r100 = __main__.globals :: static
r101 = 'y'
r102 = CPyDict_SetItem(r100, r101, r99)
r103 = r102 >= 0 :: signed
r97 = 'y'
r98 = CPyDict_SetItem(r96, r97, r88)
r99 = r98 >= 0 :: signed
return 1

[case testChainedConditional]
Expand Down Expand Up @@ -3584,3 +3577,18 @@ L0:
r3 = 0.0
i__redef____redef__ = r3
return 1

[case testNewType]
from typing import NewType

class A: pass

N = NewType("N", A)

def f(arg: A) -> N:
return N(arg)
[out]
def f(arg):
arg :: __main__.A
L0:
return arg

0 comments on commit ca66805

Please sign in to comment.