diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index a5603765727c..6398dbc682cd 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -400,6 +400,8 @@ Py_ssize_t CPyStr_Size_size_t(PyObject *str); // Bytes operations +PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter); + // Set operations diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c index 99771bdf926e..4bc014b5fd45 100644 --- a/mypyc/lib-rt/bytes_ops.c +++ b/mypyc/lib-rt/bytes_ops.c @@ -4,3 +4,13 @@ #include #include "CPy.h" + +// Like _PyBytes_Join but fallback to dynamic call if 'sep' is not bytes +// (mostly commonly, for bytearrays) +PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter) { + if (PyBytes_CheckExact(sep)) { + return _PyBytes_Join(sep, iter); + } else { + return PyObject_CallMethod(sep, "join", "(O)", iter); + } +} diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index a5963d2ad8fa..0ddc2e550bdb 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -5,7 +5,7 @@ object_rprimitive, bytes_rprimitive, list_rprimitive, dict_rprimitive, str_rprimitive, RUnion ) -from mypyc.primitives.registry import load_address_op, function_op +from mypyc.primitives.registry import load_address_op, function_op, method_op # Get the 'bytes' type object. @@ -29,3 +29,12 @@ return_type=bytes_rprimitive, c_function_name='PyByteArray_FromObject', error_kind=ERR_MAGIC) + +# bytes.join(obj) +method_op( + name='join', + arg_types=[bytes_rprimitive, object_rprimitive], + return_type=bytes_rprimitive, + c_function_name='CPyBytes_Join', + error_kind=ERR_MAGIC +) diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test index 479f97872f5e..a0c84014edc0 100644 --- a/mypyc/test-data/irbuild-bytes.test +++ b/mypyc/test-data/irbuild-bytes.test @@ -62,3 +62,17 @@ L0: c = r6 return 1 + +[case testBytesJoin] +from typing import List + +def f(b: List[bytes]) -> bytes: + return b" ".join(b) +[out] +def f(b): + b :: list + r0, r1 :: bytes +L0: + r0 = b' ' + r1 = CPyBytes_Join(r0, b) + return r1 diff --git a/mypyc/test-data/run-bytes.test b/mypyc/test-data/run-bytes.test index 17ebe6085fec..fe248834497d 100644 --- a/mypyc/test-data/run-bytes.test +++ b/mypyc/test-data/run-bytes.test @@ -97,3 +97,28 @@ def test_bytearray_passed_into_bytes() -> None: assert f(bytearray(3)) brr1: Any = bytearray() assert f(brr1) + +[case testBytesJoin] +from typing import Any +from testutil import assertRaises +from a import bytes_subclass + +def test_bytes_join() -> None: + assert b' '.join([b'a', b'b']) == b'a b' + assert b' '.join([]) == b'' + + x: bytes = bytearray(b' ') + assert x.join([b'a', b'b']) == b'a b' + assert type(x.join([b'a', b'b'])) == bytearray + + y: bytes = bytes_subclass() + assert y.join([]) == b'spook' + + n: Any = 5 + with assertRaises(TypeError, "can only join an iterable"): + assert b' '.join(n) + +[file a.py] +class bytes_subclass(bytes): + def join(self, iter): + return b'spook'