Skip to content

Commit

Permalink
[mypyc] Add primitives and specialization for ord() (#18240)
Browse files Browse the repository at this point in the history
This makes a microbenchmark adapted from an internal production codebase
that heavily uses `ord()` over 10x faster.

Work on mypyc/mypyc#644 and mypyc/mypyc#880.
  • Loading branch information
JukkaL authored Dec 4, 2024
1 parent cc45bec commit ee19ea7
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 10 deletions.
6 changes: 6 additions & 0 deletions mypyc/doc/str_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,9 @@ Methods
* ``s.split(sep: str)``
* ``s.split(sep: str, maxsplit: int)``
* ``s1.startswith(s2: str)``

Functions
---------

* ``len(s: str)``
* ``ord(s: str)``
11 changes: 11 additions & 0 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mypy.nodes import (
ARG_NAMED,
ARG_POS,
BytesExpr,
CallExpr,
DictExpr,
Expression,
Expand Down Expand Up @@ -877,3 +878,13 @@ def translate_float(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Valu
# No-op float conversion.
return builder.accept(arg)
return None


@specialize_function("builtins.ord")
def translate_ord(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
    """Fold ``ord(<literal>)`` into an integer constant when possible.

    The call is only handled when it has exactly one argument, passed
    positionally, and that argument is a ``StrExpr`` or ``BytesExpr`` whose
    ``value`` has length 1. In that case an ``Integer`` holding
    ``ord(value)`` is returned.

    In every other case ``None`` is returned.
    NOTE(review): ``None`` presumably tells the caller that this
    specialization does not apply so another code path handles the call --
    confirm against ``specialize_function``.
    """
    has_single_positional_arg = len(expr.args) == 1 and expr.arg_kinds[0] == ARG_POS
    if not has_single_positional_arg:
        return None

    literal = expr.args[0]
    if not isinstance(literal, (StrExpr, BytesExpr)):
        return None

    text = literal.value
    if len(text) != 1:
        # ord() is only defined for a single character; leave other lengths alone.
        return None
    return Integer(ord(text))
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,7 @@ bool CPyStr_IsTrue(PyObject *obj);
Py_ssize_t CPyStr_Size_size_t(PyObject *str);
PyObject *CPy_Decode(PyObject *obj, PyObject *encoding, PyObject *errors);
PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors);
CPyTagged CPyStr_Ord(PyObject *obj);


// Bytes operations
Expand All @@ -740,6 +741,7 @@ PyObject *CPyBytes_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);
CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index);
PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
CPyTagged CPyBytes_Ord(PyObject *obj);


int CPyBytes_Compare(PyObject *left, PyObject *right);
Expand Down
17 changes: 17 additions & 0 deletions mypyc/lib-rt/bytes_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,20 @@ PyObject *CPyBytes_Build(Py_ssize_t len, ...) {

return (PyObject *)ret;
}


// ord() for a bytes or bytearray object.
//
// If obj is a bytes or bytearray object of length exactly 1, return its single
// byte (read as unsigned char) shifted left by one bit.
// NOTE(review): the shift presumably produces the short tagged-int encoding of
// CPyTagged -- confirm against CPy.h.
//
// Otherwise set TypeError and return CPY_INT_TAG.
CPyTagged CPyBytes_Ord(PyObject *obj) {
    if (PyBytes_Check(obj) && PyBytes_GET_SIZE(obj) == 1) {
        unsigned char byte = (unsigned char)(PyBytes_AS_STRING(obj)[0]);
        return byte << 1;
    }
    if (PyByteArray_Check(obj) && PyByteArray_GET_SIZE(obj) == 1) {
        unsigned char byte = (unsigned char)(PyByteArray_AS_STRING(obj)[0]);
        return byte << 1;
    }
    // Wrong type or wrong length: report the error to the caller.
    PyErr_SetString(PyExc_TypeError, "ord() expects a character");
    return CPY_INT_TAG;
}
12 changes: 12 additions & 0 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,15 @@ PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors) {
return NULL;
}
}


// ord() for a str object.
//
// If obj has length exactly 1, return its only code point shifted left by one
// bit. Otherwise set a TypeError that reports the actual length and return
// CPY_INT_TAG.
//
// NOTE(review): obj is accessed with PyUnicode_* macros without a type check,
// so the caller is assumed to always pass a str object -- confirm.
CPyTagged CPyStr_Ord(PyObject *obj) {
    Py_ssize_t length = PyUnicode_GET_LENGTH(obj);
    if (length != 1) {
        PyErr_Format(
            PyExc_TypeError, "ord() expected a character, but a string of length %zd found", length);
        return CPY_INT_TAG;
    }
    int kind = PyUnicode_KIND(obj);
    return PyUnicode_READ(kind, PyUnicode_DATA(obj), 0) << 1;
}
8 changes: 8 additions & 0 deletions mypyc/primitives/bytes_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,11 @@
error_kind=ERR_MAGIC,
var_arg_type=bytes_rprimitive,
)

# ord(b) for a bytes argument: registers the C function CPyBytes_Ord as the
# implementation of builtins.ord, with an int result.
# NOTE(review): ERR_MAGIC presumably means failure is signalled through a
# special return value (CPyBytes_Ord returns CPY_INT_TAG on error) -- confirm.
function_op(
name="builtins.ord",
arg_types=[bytes_rprimitive],
return_type=int_rprimitive,
c_function_name="CPyBytes_Ord",
error_kind=ERR_MAGIC,
)
8 changes: 8 additions & 0 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,11 @@
c_function_name="CPy_Encode",
error_kind=ERR_MAGIC,
)

# ord(s) for a str argument: registers the C function CPyStr_Ord as the
# implementation of builtins.ord, with an int result.
# NOTE(review): ERR_MAGIC presumably means failure is signalled through a
# special return value (CPyStr_Ord returns CPY_INT_TAG on error) -- confirm.
function_op(
name="builtins.ord",
arg_types=[str_rprimitive],
return_type=int_rprimitive,
c_function_name="CPyStr_Ord",
error_kind=ERR_MAGIC,
)
43 changes: 43 additions & 0 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,46 @@ L0:
r37 = 'latin2'
r38 = CPy_Encode(s, r37, 0)
return 1

[case testOrd]
# ord() on a str-typed argument; the expected IR calls CPyStr_Ord.
def str_ord(x: str) -> int:
return ord(x)
# ord() on a one-character str literal; the expected IR returns a constant.
def str_ord_literal() -> int:
return ord("a")
# ord() on a bytes-typed argument; the expected IR calls CPyBytes_Ord.
def bytes_ord(x: bytes) -> int:
return ord(x)
# ord() on a one-byte bytes literal; the expected IR returns a constant.
def bytes_ord_literal() -> int:
return ord(b"a")
# ord() on an unannotated argument; the expected IR looks up 'ord' on the
# builtins module and calls it as a generic object call.
def any_ord(x) -> int:
return ord(x)
[out]
def str_ord(x):
x :: str
r0 :: int
L0:
r0 = CPyStr_Ord(x)
return r0
def str_ord_literal():
L0:
return 194
def bytes_ord(x):
x :: bytes
r0 :: int
L0:
r0 = CPyBytes_Ord(x)
return r0
def bytes_ord_literal():
L0:
return 194
def any_ord(x):
x, r0 :: object
r1 :: str
r2, r3 :: object
r4 :: int
L0:
r0 = builtins :: module
r1 = 'ord'
r2 = CPyObject_GetAttr(r0, r1)
r3 = PyObject_CallFunctionObjArgs(r2, x, 0)
r4 = unbox(int, r3)
return r4
23 changes: 23 additions & 0 deletions mypyc/test-data/run-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,29 @@ def test_len() -> None:
assert len(b) == 3
assert len(bytes()) == 0

def test_ord() -> None:
# Check ord() on one-byte bytes values, each written both as a bare literal
# and as `literal + bytes()`.
# NOTE(review): the concatenated form is presumably there so the argument is
# not a literal expression and the call is evaluated at run time -- confirm.
assert ord(b'a') == ord('a')
assert ord(b'a' + bytes()) == ord('a')
assert ord(b'\x00') == 0
assert ord(b'\x00' + bytes()) == 0
assert ord(b'\xfe') == 254
assert ord(b'\xfe' + bytes()) == 254

# Bytes values whose length is not exactly 1 must raise TypeError.
with assertRaises(TypeError):
ord(b'aa')
with assertRaises(TypeError):
ord(b'')

def test_ord_bytesarray() -> None:
# Check ord() on bytearray values of length 1.
assert ord(bytearray(b'a')) == ord('a')
assert ord(bytearray(b'\x00')) == 0
assert ord(bytearray(b'\xfe')) == 254

# bytearray values whose length is not exactly 1 must raise TypeError.
with assertRaises(TypeError):
ord(bytearray(b'aa'))
with assertRaises(TypeError):
ord(bytearray(b''))

[case testBytesSlicing]
def test_bytes_slicing() -> None:
b = b'abcdefg'
Expand Down
27 changes: 17 additions & 10 deletions mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -565,25 +565,32 @@ def test_chr() -> None:
assert try_invalid(1114112)

[case testOrd]
from testutil import assertRaises

def test_ord() -> None:
assert ord(' ') == 32
assert ord(' ' + str()) == 32
assert ord('\x00') == 0
assert ord('\x00' + str()) == 0
assert ord('\ue000') == 57344
s = "a\xac\u1234\u20ac\U00008000"
# ^^^^ two-digit hex escape
# ^^^^^^ four-digit Unicode escape
# ^^^^^^^^^^ eight-digit Unicode escape
assert ord('\ue000' + str()) == 57344
s = "a\xac\u1234\u20ac\U00010000"
# ^^^^ two-digit hex escape
# ^^^^^^ four-digit Unicode escape
# ^^^^^^^^^^ eight-digit Unicode escape
l1 = [ord(c) for c in s]
assert l1 == [97, 172, 4660, 8364, 32768]
assert l1 == [97, 172, 4660, 8364, 65536]
u = 'abcdé'
assert ord(u[-1]) == 233
assert ord(b'a') == 97
assert ord(b'a' + bytes()) == 97
u2 = '\U0010ffff'
u2 = '\U0010ffff' + str()
assert ord(u2) == 1114111
try:
assert ord('\U0010ffff') == 1114111
with assertRaises(TypeError, "ord() expected a character, but a string of length 2 found"):
ord('aa')
assert False
except TypeError:
pass
with assertRaises(TypeError):
ord('')

[case testDecode]
def test_decode() -> None:
Expand Down

0 comments on commit ee19ea7

Please sign in to comment.