Skip to content

Commit

Permalink
[mypyc] Add 'range' primitive type (#10307)
Browse files Browse the repository at this point in the history
  • Loading branch information
97littleleaf11 authored Jun 8, 2021
1 parent 0af616e commit 0a6a48c
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 10 deletions.
12 changes: 7 additions & 5 deletions mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
is_list_rprimitive, is_dict_rprimitive, is_set_rprimitive, is_tuple_rprimitive,
is_none_rprimitive, is_object_rprimitive, object_rprimitive, is_str_rprimitive,
int_rprimitive, is_optional_type, optional_value_type, is_int32_rprimitive,
is_int64_rprimitive, is_bit_rprimitive
is_int64_rprimitive, is_bit_rprimitive, is_range_rprimitive
)
from mypyc.ir.func_ir import FuncDecl
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
Expand Down Expand Up @@ -410,8 +410,8 @@ def emit_cast(self, src: str, dest: str, typ: RType, declare_dest: bool = False,

# TODO: Verify refcount handling.
if (is_list_rprimitive(typ) or is_dict_rprimitive(typ) or is_set_rprimitive(typ)
or is_float_rprimitive(typ) or is_str_rprimitive(typ) or is_int_rprimitive(typ)
or is_bool_rprimitive(typ)):
or is_str_rprimitive(typ) or is_range_rprimitive(typ) or is_float_rprimitive(typ)
or is_int_rprimitive(typ) or is_bool_rprimitive(typ)):
if declare_dest:
self.emit_line('PyObject *{};'.format(dest))
if is_list_rprimitive(typ):
Expand All @@ -420,10 +420,12 @@ def emit_cast(self, src: str, dest: str, typ: RType, declare_dest: bool = False,
prefix = 'PyDict'
elif is_set_rprimitive(typ):
prefix = 'PySet'
elif is_float_rprimitive(typ):
prefix = 'CPyFloat'
elif is_str_rprimitive(typ):
prefix = 'PyUnicode'
elif is_range_rprimitive(typ):
prefix = 'PyRange'
elif is_float_rprimitive(typ):
prefix = 'CPyFloat'
elif is_int_rprimitive(typ):
prefix = 'PyLong'
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
Expand Down
10 changes: 9 additions & 1 deletion mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ def __hash__(self) -> int:
tuple_rprimitive = RPrimitive('builtins.tuple', is_unboxed=False,
is_refcounted=True) # type: Final

# Python range object.
range_rprimitive = RPrimitive('builtins.range', is_unboxed=False,
is_refcounted=True) # type: Final


def is_tagged(rtype: RType) -> bool:
return rtype is int_rprimitive or rtype is short_int_rprimitive
Expand Down Expand Up @@ -405,6 +409,10 @@ def is_tuple_rprimitive(rtype: RType) -> bool:
return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.tuple'


def is_range_rprimitive(rtype: RType) -> bool:
return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.range'


def is_sequence_rprimitive(rtype: RType) -> bool:
return isinstance(rtype, RPrimitive) and (
is_list_rprimitive(rtype) or is_tuple_rprimitive(rtype) or is_str_rprimitive(rtype)
Expand Down Expand Up @@ -805,5 +813,5 @@ def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'RArray':
PyListObject = RStruct(
name='PyListObject',
names=['ob_base', 'ob_item', 'allocated'],
types=[PyObject, pointer_rprimitive, c_pyssize_t_rprimitive]
types=[PyVarObject, pointer_rprimitive, c_pyssize_t_rprimitive]
)
4 changes: 3 additions & 1 deletion mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mypyc.ir.rtypes import (
RType, RUnion, RTuple, RInstance, object_rprimitive, dict_rprimitive, tuple_rprimitive,
none_rprimitive, int_rprimitive, float_rprimitive, str_rprimitive, bool_rprimitive,
list_rprimitive, set_rprimitive
list_rprimitive, set_rprimitive, range_rprimitive
)
from mypyc.ir.func_ir import FuncSignature, FuncDecl, RuntimeArg
from mypyc.ir.class_ir import ClassIR
Expand Down Expand Up @@ -58,6 +58,8 @@ def type_to_rtype(self, typ: Optional[Type]) -> RType:
return set_rprimitive
elif typ.type.fullname == 'builtins.tuple':
return tuple_rprimitive # Varying-length tuple
elif typ.type.fullname == 'builtins.range':
return range_rprimitive
elif typ.type in self.type_to_ir:
inst = RInstance(self.type_to_ir[typ.type])
# Treat protocols as Union[protocol, object], so that we can do fast
Expand Down
6 changes: 6 additions & 0 deletions mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
type=object_rprimitive,
src='PyBool_Type')

# Get the 'range' type object.
load_address_op(
name='builtins.range',
type=object_rprimitive,
src='PyRange_Type')

# Get the boxed Python 'None' object
none_object_op = load_address_op(
name='Py_None',
Expand Down
7 changes: 6 additions & 1 deletion mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ def __or__(self, s: Set[S]) -> Set[Union[T, S]]: ...

class slice: pass

class range(Iterable[int]):
def __init__(self, x: int, y: int = ..., z: int = ...) -> None: pass
def __iter__(self) -> Iterator[int]: pass
def __len__(self) -> int: pass
def __next__(self) -> int: pass

class property:
def __init__(self, fget: Optional[Callable[[Any], Any]] = ...,
fset: Optional[Callable[[Any, Any], None]] = ...,
Expand Down Expand Up @@ -245,7 +251,6 @@ def id(o: object) -> int: pass
# This type is obviously wrong but the test stubs don't have Sized anymore
def len(o: object) -> int: pass
def print(*object) -> None: pass
def range(x: int, y: int = ..., z: int = ...) -> Iterator[int]: pass
def isinstance(x: object, t: object) -> bool: pass
def iter(i: Iterable[T]) -> Iterator[T]: pass
@overload
Expand Down
68 changes: 68 additions & 0 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -3649,3 +3649,71 @@ L0:
r2 = r1 >= 0 :: signed
r3 = truncate r1: int32 to builtins.bool
return r3

[case testRangeObject]
def range_object() -> None:
r = range(4, 12, 2)
sum = 0
for i in r:
sum += i

def range_in_loop() -> None:
sum = 0
for i in range(4, 12, 2):
sum += i
[out]
def range_object():
r0, r1, r2, r3, r4 :: object
r5, r :: range
sum :: int
r6, r7 :: object
r8, i, r9 :: int
r10 :: bit
L0:
r0 = load_address PyRange_Type
r1 = box(short_int, 8)
r2 = box(short_int, 24)
r3 = box(short_int, 4)
r4 = PyObject_CallFunctionObjArgs(r0, r1, r2, r3, 0)
r5 = cast(range, r4)
r = r5
sum = 0
r6 = PyObject_GetIter(r)
L1:
r7 = PyIter_Next(r6)
if is_error(r7) goto L4 else goto L2
L2:
r8 = unbox(int, r7)
i = r8
r9 = CPyTagged_Add(sum, i)
sum = r9
L3:
goto L1
L4:
r10 = CPy_NoErrOccured()
L5:
return 1
def range_in_loop():
sum :: int
r0 :: short_int
i :: int
r1 :: bit
r2 :: int
r3 :: short_int
L0:
sum = 0
r0 = 8
i = r0
L1:
r1 = r0 < 24 :: signed
if r1 goto L2 else goto L4 :: bool
L2:
r2 = CPyTagged_Add(sum, i)
sum = r2
L3:
r3 = r0 + 4
r0 = r3
i = r3
goto L1
L4:
return 1
2 changes: 1 addition & 1 deletion mypyc/test-data/run-dicts.test
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Dict test cases (compile and run)
# Test cases for dicts (compile and run)

[case testDictStuff]
from typing import Dict, Any, List, Set, Tuple
Expand Down
2 changes: 2 additions & 0 deletions mypyc/test-data/run-floats.test
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Test cases for floats (compile and run)

[case testStrToFloat]
def str_to_float(x: str) -> float:
return float(x)
Expand Down
33 changes: 32 additions & 1 deletion mypyc/test-data/run-loops.test
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Test cases for "for" and "while" loops (compile and run)
# Test cases for "range" objects, "for" and "while" loops (compile and run)

[case testFor]
from typing import List, Tuple
Expand Down Expand Up @@ -452,3 +452,34 @@ def bar(x: Optional[str]) -> None:
[file driver.py]
from native import bar
bar(None)

[case testRangeObject]
from typing import Any

def f(x: range) -> int:
sum = 0
for i in x:
sum += i
return sum

def test_range_object() -> None:
r1 = range(4, 12, 2)
tmp_list = [x for x in r1]
assert tmp_list == [4, 6, 8, 10]
assert f(r1) == 28
r2: Any = range(10)
assert f(r2) == 45
r3: Any = 'x'
try:
f(r3)
except TypeError as e:
assert "range object expected; got str" in str(e)
try:
ff: Any = f
ff(r3)
except TypeError as e:
assert "range object expected; got str" in str(e)
try:
r4 = range(4, 12, 0)
except ValueError as e:
assert "range() arg 3 must not be zero" in str(e)

0 comments on commit 0a6a48c

Please sign in to comment.