Skip to content

Commit

Permalink
[mypyc] Add a special case for len() of a str value (#10710)
Browse files Browse the repository at this point in the history
Closes mypyc/mypyc#835

* Add a branch for str_rprimitive in builtin_len
* Reduce redundant code
* Faster list/tuple built from str
  • Loading branch information
97littleleaf11 authored Jul 1, 2021
1 parent 1985928 commit 49bb90a
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 91 deletions.
5 changes: 3 additions & 2 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from mypyc.ir.rtypes import (
RType, is_short_int_rprimitive, is_list_rprimitive, is_sequence_rprimitive,
is_tuple_rprimitive, is_dict_rprimitive,
is_tuple_rprimitive, is_dict_rprimitive, is_str_rprimitive,
RTuple, short_int_rprimitive, int_rprimitive
)
from mypyc.primitives.registry import CFunctionDescription
Expand Down Expand Up @@ -164,7 +164,8 @@ def sequence_from_generator_preallocate_helper(
"""
if len(gen.sequences) == 1 and len(gen.indices) == 1 and len(gen.condlists[0]) == 0:
rtype = builder.node_type(gen.sequences[0])
if is_list_rprimitive(rtype) or is_tuple_rprimitive(rtype):
if (is_list_rprimitive(rtype) or is_tuple_rprimitive(rtype)
or is_str_rprimitive(rtype)):
sequence = builder.accept(gen.sequences[0])
length = builder.builder.builtin_len(sequence, gen.line, use_pyssize_t=True)
target_op = empty_op_llbuilder(length, gen.line)
Expand Down
38 changes: 18 additions & 20 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
list_tuple_op, new_tuple_op, new_tuple_with_length_op
)
from mypyc.primitives.dict_ops import (
dict_update_in_display_op, dict_new_op, dict_build_op, dict_size_op
dict_update_in_display_op, dict_new_op, dict_build_op, dict_ssize_t_size_op
)
from mypyc.primitives.generic_ops import (
py_getattr_op, py_call_op, py_call_with_kwargs_op, py_method_call_op,
Expand All @@ -64,7 +64,9 @@
)
from mypyc.primitives.int_ops import int_comparison_op_mapping
from mypyc.primitives.exc_ops import err_occurred_op, keep_propagating_op
from mypyc.primitives.str_ops import unicode_compare, str_check_if_true
from mypyc.primitives.str_ops import (
unicode_compare, str_check_if_true, str_ssize_t_size_op
)
from mypyc.primitives.set_ops import new_set_op
from mypyc.rt_subtype import is_runtime_subtype
from mypyc.subtype import is_subtype
Expand Down Expand Up @@ -1125,32 +1127,28 @@ def builtin_len(self, val: Value, line: int, use_pyssize_t: bool = False) -> Val
Return c_pyssize_t if use_pyssize_t is true (unshifted).
"""
typ = val.type
size_value = None
if is_list_rprimitive(typ) or is_tuple_rprimitive(typ):
elem_address = self.add(GetElementPtr(val, PyVarObject, 'ob_size'))
size_value = self.add(LoadMem(c_pyssize_t_rprimitive, elem_address))
self.add(KeepAlive([val]))
if use_pyssize_t:
return size_value
offset = Integer(1, c_pyssize_t_rprimitive, line)
return self.int_op(short_int_rprimitive, size_value, offset,
IntOp.LEFT_SHIFT, line)
elif is_dict_rprimitive(typ):
size_value = self.call_c(dict_size_op, [val], line)
if use_pyssize_t:
return size_value
offset = Integer(1, c_pyssize_t_rprimitive, line)
return self.int_op(short_int_rprimitive, size_value, offset,
IntOp.LEFT_SHIFT, line)
elif is_set_rprimitive(typ):
elem_address = self.add(GetElementPtr(val, PySetObject, 'used'))
size_value = self.add(LoadMem(c_pyssize_t_rprimitive, elem_address))
self.add(KeepAlive([val]))
elif is_dict_rprimitive(typ):
size_value = self.call_c(dict_ssize_t_size_op, [val], line)
elif is_str_rprimitive(typ):
size_value = self.call_c(str_ssize_t_size_op, [val], line)

if size_value is not None:
if use_pyssize_t:
return size_value
offset = Integer(1, c_pyssize_t_rprimitive, line)
return self.int_op(short_int_rprimitive, size_value, offset,
IntOp.LEFT_SHIFT, line)
elif isinstance(typ, RInstance):

if isinstance(typ, RInstance):
# TODO: Support use_pyssize_t
assert not use_pyssize_t
length = self.gen_method_call(val, '__len__', [], int_rprimitive, line)
Expand All @@ -1164,12 +1162,12 @@ def builtin_len(self, val: Value, line: int, use_pyssize_t: bool = False) -> Val
self.add(Unreachable())
self.activate_block(ok)
return length

# generic case
if use_pyssize_t:
return self.call_c(generic_ssize_t_len_op, [val], line)
else:
# generic case
if use_pyssize_t:
return self.call_c(generic_ssize_t_len_op, [val], line)
else:
return self.call_c(generic_len_op, [val], line)
return self.call_c(generic_len_op, [val], line)

def new_tuple(self, items: List[Value], line: int) -> Value:
size: Value = Integer(len(items), c_pyssize_t_rprimitive)
Expand Down
4 changes: 2 additions & 2 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def dict_methods_fast_path(
def translate_list_from_generator_call(
builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]:
# Special case for simplest list comprehension, for example
# list(f(x) for x in other_list/other_tuple)
# list(f(x) for x in some_list/some_tuple/some_str)
# translate_list_comprehension() would take care of other cases if this fails.
if (len(expr.args) == 1
and expr.arg_kinds[0] == ARG_POS
Expand All @@ -142,7 +142,7 @@ def translate_list_from_generator_call(
def translate_tuple_from_generator_call(
builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]:
# Special case for simplest tuple creation from a generator, for example
# tuple(f(x) for x in other_list/other_tuple)
# tuple(f(x) for x in some_list/some_tuple/some_str)
# translate_safe_generator_call() would take care of other cases if this fails.
if (len(expr.args) == 1
and expr.arg_kinds[0] == ARG_POS
Expand Down
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);
bool CPyStr_Startswith(PyObject *self, PyObject *subobj);
bool CPyStr_Endswith(PyObject *self, PyObject *subobj);
bool CPyStr_IsTrue(PyObject *obj);
Py_ssize_t CPyStr_Size_size_t(PyObject *str);


// Set operations
Expand Down
8 changes: 8 additions & 0 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,16 @@ PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
}
return CPyObject_GetSlice(obj, start, end);
}

/* Return true iff the given str object is non-empty, i.e. its length
 * is not zero. */
bool CPyStr_IsTrue(PyObject *obj) {
    return PyUnicode_GET_LENGTH(obj) != 0;
}

/* Return the length of the given str object as a Py_ssize_t.
 * Returns -1 when PyUnicode_READY() reports failure (-1). */
Py_ssize_t CPyStr_Size_size_t(PyObject *str) {
    if (PyUnicode_READY(str) == -1) {
        return -1;
    }
    return PyUnicode_GET_LENGTH(str);
}
2 changes: 1 addition & 1 deletion mypyc/primitives/dict_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@
c_function_name='CPyDict_CheckSize',
error_kind=ERR_FALSE)

dict_size_op = custom_op(
dict_ssize_t_size_op = custom_op(
arg_types=[dict_rprimitive],
return_type=c_pyssize_t_rprimitive,
c_function_name='PyDict_Size',
Expand Down
19 changes: 13 additions & 6 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
from mypyc.ir.rtypes import (
RType, object_rprimitive, str_rprimitive, int_rprimitive, list_rprimitive,
c_int_rprimitive, pointer_rprimitive, bool_rprimitive, bit_rprimitive
c_int_rprimitive, pointer_rprimitive, bool_rprimitive, bit_rprimitive,
c_pyssize_t_rprimitive
)
from mypyc.primitives.registry import (
method_op, binary_op, function_op,
load_address_op, custom_op
load_address_op, custom_op, ERR_NEG_INT
)


Expand Down Expand Up @@ -89,7 +90,7 @@

# str1 += str2
#
# PyUnicodeAppend makes an effort to reuse the LHS when the refcount
# PyUnicode_Append makes an effort to reuse the LHS when the refcount
# is 1. This is super dodgy but oh well, the interpreter does it.
binary_op(name='+=',
arg_types=[str_rprimitive, str_rprimitive],
Expand All @@ -116,7 +117,7 @@
name='replace',
arg_types=[str_rprimitive, str_rprimitive, str_rprimitive],
return_type=str_rprimitive,
c_function_name="PyUnicode_Replace",
c_function_name='PyUnicode_Replace',
error_kind=ERR_MAGIC,
extra_int_constants=[(-1, c_int_rprimitive)])

Expand All @@ -125,13 +126,19 @@
name='replace',
arg_types=[str_rprimitive, str_rprimitive, str_rprimitive, int_rprimitive],
return_type=str_rprimitive,
c_function_name="CPyStr_Replace",
c_function_name='CPyStr_Replace',
error_kind=ERR_MAGIC)

# check if a string is true (isn't an empty string)
str_check_if_true = custom_op(
arg_types=[str_rprimitive],
return_type=bit_rprimitive,
c_function_name="CPyStr_IsTrue",
c_function_name='CPyStr_IsTrue',
error_kind=ERR_NEVER,
)

# len(str) as an unshifted C Py_ssize_t, computed by CPyStr_Size_size_t;
# a negative return value is treated as an error (ERR_NEG_INT).
str_ssize_t_size_op = custom_op(
    c_function_name='CPyStr_Size_size_t',
    arg_types=[str_rprimitive],
    return_type=c_pyssize_t_rprimitive,
    error_kind=ERR_NEG_INT,
)
1 change: 0 additions & 1 deletion mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,3 @@ L2:
return 0
L3:
unreachable

91 changes: 32 additions & 59 deletions mypyc/test-data/irbuild-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,13 @@ L4:
return 1


[case testTupleBuiltFromList2]
[case testTupleBuiltFromStr]
def f2(val: str) -> str:
return val + "f2"

def test() -> None:
source = ["a", "b", "c"]
source = "abc"
a = tuple(f2(x) for x in source)
print(a)
[out]
def f2(val):
val, r0, r1 :: str
Expand All @@ -366,71 +365,45 @@ L0:
r1 = PyUnicode_Concat(val, r0)
return r1
def test():
r0, r1, r2 :: str
r3 :: list
r4, r5, r6, r7 :: ptr
source :: list
r8 :: ptr
r9 :: native_int
r10 :: tuple
r11 :: short_int
r12 :: ptr
r13 :: native_int
r14 :: short_int
r15 :: bit
r16 :: object
r17, x, r18 :: str
r19 :: bit
r20 :: short_int
r0, source :: str
r1 :: native_int
r2 :: bit
r3 :: tuple
r4 :: short_int
r5 :: native_int
r6 :: bit
r7 :: short_int
r8 :: bit
r9, x, r10 :: str
r11 :: bit
r12 :: short_int
a :: tuple
r21 :: object
r22 :: str
r23, r24 :: object
L0:
r0 = 'a'
r1 = 'b'
r2 = 'c'
r3 = PyList_New(3)
r4 = get_element_ptr r3 ob_item :: PyListObject
r5 = load_mem r4 :: ptr*
set_mem r5, r0 :: builtins.object*
r6 = r5 + WORD_SIZE*1
set_mem r6, r1 :: builtins.object*
r7 = r5 + WORD_SIZE*2
set_mem r7, r2 :: builtins.object*
keep_alive r3
source = r3
r8 = get_element_ptr source ob_size :: PyVarObject
r9 = load_mem r8 :: native_int*
keep_alive source
r10 = PyTuple_New(r9)
r11 = 0
r0 = 'abc'
source = r0
r1 = CPyStr_Size_size_t(source)
r2 = r1 >= 0 :: signed
r3 = PyTuple_New(r1)
r4 = 0
L1:
r12 = get_element_ptr source ob_size :: PyVarObject
r13 = load_mem r12 :: native_int*
keep_alive source
r14 = r13 << 1
r15 = r11 < r14 :: signed
if r15 goto L2 else goto L4 :: bool
r5 = CPyStr_Size_size_t(source)
r6 = r5 >= 0 :: signed
r7 = r5 << 1
r8 = r4 < r7 :: signed
if r8 goto L2 else goto L4 :: bool
L2:
r16 = CPyList_GetItemUnsafe(source, r11)
r17 = cast(str, r16)
x = r17
r18 = f2(x)
r19 = CPySequenceTuple_SetItemUnsafe(r10, r11, r18)
r9 = CPyStr_GetItem(source, r4)
x = r9
r10 = f2(x)
r11 = CPySequenceTuple_SetItemUnsafe(r3, r4, r10)
L3:
r20 = r11 + 2
r11 = r20
r12 = r4 + 2
r4 = r12
goto L1
L4:
a = r10
r21 = builtins :: module
r22 = 'print'
r23 = CPyObject_GetAttr(r21, r22)
r24 = PyObject_CallFunctionObjArgs(r23, a, 0)
a = r3
return 1


[case testTupleBuiltFromVariableLengthTuple]
from typing import Tuple

Expand Down
4 changes: 4 additions & 0 deletions mypyc/test-data/run-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,7 @@ def test() -> None:
source_e = [0, 1, 2]
e = list((x ** 2) for x in (y + 2 for y in source_e))
assert e == [4, 9, 16]
source_str = "abcd"
f = list("str:" + x for x in source_str)
assert f == ["str:a", "str:b", "str:c", "str:d"]

4 changes: 4 additions & 0 deletions mypyc/test-data/run-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ def test_sequence_generator() -> None:
a = tuple(f8(x) for x in source_fixed_length_tuple)
assert a == (False, True, False, True)

source_str = 'abbc'
b = tuple('s:' + x for x in source_str)
assert b == ('s:a', 's:b', 's:b', 's:c')

TUPLE: Final[Tuple[str, ...]] = ('x', 'y')

def test_final_boxed_tuple() -> None:
Expand Down

0 comments on commit 49bb90a

Please sign in to comment.