Skip to content

Commit

Permalink
Add primitive for bytes join() method (#10929)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhance authored Aug 5, 2021
1 parent e734321 commit 97a1b3f
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 1 deletion.
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@ Py_ssize_t CPyStr_Size_size_t(PyObject *str);
// Bytes operations


PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);


// Set operations

Expand Down
10 changes: 10 additions & 0 deletions mypyc/lib-rt/bytes_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,13 @@

#include <Python.h>
#include "CPy.h"

// Like _PyBytes_Join but fallback to dynamic call if 'sep' is not bytes
// (mostly commonly, for bytearrays)
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter) {
if (PyBytes_CheckExact(sep)) {
return _PyBytes_Join(sep, iter);
} else {
return PyObject_CallMethod(sep, "join", "(O)", iter);
}
}
11 changes: 10 additions & 1 deletion mypyc/primitives/bytes_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
object_rprimitive, bytes_rprimitive, list_rprimitive, dict_rprimitive,
str_rprimitive, RUnion
)
from mypyc.primitives.registry import load_address_op, function_op
from mypyc.primitives.registry import load_address_op, function_op, method_op


# Get the 'bytes' type object.
Expand All @@ -29,3 +29,12 @@
return_type=bytes_rprimitive,
c_function_name='PyByteArray_FromObject',
error_kind=ERR_MAGIC)

# bytes.join(obj)
method_op(
name='join',
arg_types=[bytes_rprimitive, object_rprimitive],
return_type=bytes_rprimitive,
c_function_name='CPyBytes_Join',
error_kind=ERR_MAGIC
)
14 changes: 14 additions & 0 deletions mypyc/test-data/irbuild-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,17 @@ L0:
c = r6
return 1


[case testBytesJoin]
from typing import List

def f(b: List[bytes]) -> bytes:
return b" ".join(b)
[out]
def f(b):
b :: list
r0, r1 :: bytes
L0:
r0 = b' '
r1 = CPyBytes_Join(r0, b)
return r1
25 changes: 25 additions & 0 deletions mypyc/test-data/run-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,28 @@ def test_bytearray_passed_into_bytes() -> None:
assert f(bytearray(3))
brr1: Any = bytearray()
assert f(brr1)

[case testBytesJoin]
from typing import Any
from testutil import assertRaises
from a import bytes_subclass

def test_bytes_join() -> None:
assert b' '.join([b'a', b'b']) == b'a b'
assert b' '.join([]) == b''

x: bytes = bytearray(b' ')
assert x.join([b'a', b'b']) == b'a b'
assert type(x.join([b'a', b'b'])) == bytearray

y: bytes = bytes_subclass()
assert y.join([]) == b'spook'

n: Any = 5
with assertRaises(TypeError, "can only join an iterable"):
assert b' '.join(n)

[file a.py]
class bytes_subclass(bytes):
def join(self, iter):
return b'spook'

0 comments on commit 97a1b3f

Please sign in to comment.