Skip to content

Commit

Permalink
Merge pull request RustPython#3558 from fanninpm/codecs-3.10
Browse files Browse the repository at this point in the history
Update codecs.py to CPython 3.10
  • Loading branch information
youknowone authored Feb 24, 2022
2 parents ef90d09 + ead652b commit ab50217
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 37 deletions.
6 changes: 3 additions & 3 deletions Lib/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
class CodecInfo(tuple):
"""Codec details when looking up the codec registry"""

# Private API to allow Python 3.4 to blacklist the known non-Unicode
# Private API to allow Python 3.4 to denylist the known non-Unicode
# codecs in the standard library. A more general mechanism to
# reliably distinguish test encodings from other codecs will hopefully
# be defined for Python 3.5
Expand Down Expand Up @@ -386,7 +386,7 @@ def writelines(self, list):

def reset(self):

""" Flushes and resets the codec buffers used for keeping state.
""" Resets the codec buffers used for keeping internal state.
Calling this method should ensure that the data on the
output is put into a clean state, that allows appending
Expand Down Expand Up @@ -620,7 +620,7 @@ def readlines(self, sizehint=None, keepends=True):

def reset(self):

""" Resets the codec buffers used for keeping state.
""" Resets the codec buffers used for keeping internal state.
Note that no stream repositioning should take place.
This method is primarily intended to be able to recover
Expand Down
194 changes: 161 additions & 33 deletions Lib/test/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

from test import support
from test.support import os_helper
from test.support import warnings_helper

try:
import _testcapi
except ImportError as exc:
except ImportError:
_testcapi = None

try:
Expand Down Expand Up @@ -113,7 +114,7 @@ def check_partial(self, input, partialresults):
q = Queue(b"")
r = codecs.getreader(self.encoding)(q)
result = ""
for (c, partialresult) in zip(input.encode(self.encoding), partialresults):
for (c, partialresult) in zip(input.encode(self.encoding), partialresults, strict=True):
q.write(bytes([c]))
result += r.read()
self.assertEqual(result, partialresult)
Expand All @@ -124,7 +125,7 @@ def check_partial(self, input, partialresults):
# do the check again, this time using an incremental decoder
d = codecs.getincrementaldecoder(self.encoding)()
result = ""
for (c, partialresult) in zip(input.encode(self.encoding), partialresults):
for (c, partialresult) in zip(input.encode(self.encoding), partialresults, strict=True):
result += d.decode(bytes([c]))
self.assertEqual(result, partialresult)
# check that there's nothing left in the buffers
Expand All @@ -134,7 +135,7 @@ def check_partial(self, input, partialresults):
# Check whether the reset method works properly
d.reset()
result = ""
for (c, partialresult) in zip(input.encode(self.encoding), partialresults):
for (c, partialresult) in zip(input.encode(self.encoding), partialresults, strict=True):
result += d.decode(bytes([c]))
self.assertEqual(result, partialresult)
# check that there's nothing left in the buffers
Expand Down Expand Up @@ -843,7 +844,7 @@ def test_bug691291(self):
self.addCleanup(os_helper.unlink, os_helper.TESTFN)
with open(os_helper.TESTFN, 'wb') as fp:
fp.write(s)
with support.check_warnings(('', DeprecationWarning)):
with warnings_helper.check_warnings(('', DeprecationWarning)):
reader = codecs.open(os_helper.TESTFN, 'U', encoding=self.encoding)
with reader:
self.assertEqual(reader.read(), s1)
Expand Down Expand Up @@ -1814,6 +1815,22 @@ def test_register(self):
self.assertRaises(TypeError, codecs.register)
self.assertRaises(TypeError, codecs.register, 42)

def test_unregister(self):
# Register a mock search function, check that a lookup consults it, then
# unregister it and check that later lookups no longer reach it.
name = "nonexistent_codec_name"
search_function = mock.Mock()
codecs.register(search_function)
# NOTE(review): the TypeError presumably comes from the Mock's return value
# not being a valid codec-info result -- confirm against codecs.lookup.
self.assertRaises(TypeError, codecs.lookup, name)
search_function.assert_called_with(name)
# Forget the recorded call so assert_not_called() below only sees calls
# made after unregistering.
search_function.reset_mock()

codecs.unregister(search_function)
# With the function removed the lookup is expected to fail with
# LookupError, and the mock must not have been invoked.
self.assertRaises(LookupError, codecs.lookup, name)
search_function.assert_not_called()

# TODO: RUSTPYTHON, AttributeError: module '_winapi' has no attribute 'GetACP'
if sys.platform == "win32":
test_unregister = unittest.expectedFailure(test_unregister)

def test_lookup(self):
self.assertRaises(TypeError, codecs.lookup)
self.assertRaises(LookupError, codecs.lookup, "__spam__")
Expand Down Expand Up @@ -2544,7 +2561,16 @@ def test_unicode_escape(self):
(r"\x5c\x55\x30\x30\x31\x31\x30\x30\x30\x30", 10))


class UnicodeEscapeTest(unittest.TestCase):
class UnicodeEscapeTest(ReadTest, unittest.TestCase):
encoding = "unicode-escape"

test_lone_surrogates = None

# TODO: RUSTPYTHON, TypeError: Expected type 'str', not 'bytes'
@unittest.expectedFailure
def test_incremental_surrogatepass(self): # TODO: RUSTPYTHON, remove when this passes
super().test_incremental_surrogatepass() # TODO: RUSTPYTHON, remove when this passes

def test_empty(self):
self.assertEqual(codecs.unicode_escape_encode(""), (b"", 0))
self.assertEqual(codecs.unicode_escape_decode(b""), ("", 0))
Expand Down Expand Up @@ -2631,8 +2657,57 @@ def test_decode_errors(self):
self.assertEqual(decode(br"\U00110000", "ignore"), ("", 10))
self.assertEqual(decode(br"\U00110000", "replace"), ("\ufffd", 10))

# TODO: RUSTPYTHON, UnicodeDecodeError: ('unicodeescape', b'\\', 0, 1, '\\ at end of string')
@unittest.expectedFailure
def test_partial(self):
# check_partial() feeds the encoded input one byte at a time and zips the
# bytes with this list using strict=True, so the list must hold exactly one
# cumulative expected result per encoded byte.
# A repeated entry means the decoder is expected to emit nothing new for
# that byte.
self.check_partial(
"\x00\t\n\r\\\xff\uffff\U00010000",
[
'',
'',
'',
'\x00',
'\x00',
'\x00\t',
'\x00\t',
'\x00\t\n',
'\x00\t\n',
'\x00\t\n\r',
'\x00\t\n\r',
'\x00\t\n\r\\',
'\x00\t\n\r\\',
'\x00\t\n\r\\',
'\x00\t\n\r\\',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff\U00010000',
]
)

class RawUnicodeEscapeTest(ReadTest, unittest.TestCase):
encoding = "raw-unicode-escape"

test_lone_surrogates = None

# TODO: RUSTPYTHON, AssertionError: '\\' != ''
@unittest.expectedFailure
def test_incremental_surrogatepass(self): # TODO: RUSTPYTHON, remove when this passes
super().test_incremental_surrogatepass() # TODO: RUSTPYTHON, remove when this passes

class RawUnicodeEscapeTest(unittest.TestCase):
def test_empty(self):
self.assertEqual(codecs.raw_unicode_escape_encode(""), (b"", 0))
self.assertEqual(codecs.raw_unicode_escape_decode(b""), ("", 0))
Expand Down Expand Up @@ -2681,6 +2756,37 @@ def test_decode_errors(self):
self.assertEqual(decode(br"\U00110000", "ignore"), ("", 10))
self.assertEqual(decode(br"\U00110000", "replace"), ("\ufffd", 10))

# TODO: RUSTPYTHON, AssertionError: '\x00\t\n\r\\' != '\x00\t\n\r'
@unittest.expectedFailure
def test_partial(self):
# check_partial() feeds the encoded input one byte at a time and zips the
# bytes with this list using strict=True, so the list must hold exactly one
# cumulative expected result per encoded byte.
# A repeated entry means the decoder is expected to emit nothing new for
# that byte; note that the backslash only shows up in the expected output
# together with the following '\xff'.
self.check_partial(
"\x00\t\n\r\\\xff\uffff\U00010000",
[
'\x00',
'\x00\t',
'\x00\t\n',
'\x00\t\n\r',
'\x00\t\n\r',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff',
'\x00\t\n\r\\\xff\uffff\U00010000',
]
)


class EscapeEncodeTest(unittest.TestCase):

Expand Down Expand Up @@ -2889,7 +2995,7 @@ def test_buffer_api_usage(self):
view_decoded = codecs.decode(view, encoding)
self.assertEqual(view_decoded, data)

def test_text_to_binary_blacklists_binary_transforms(self):
def test_text_to_binary_denylists_binary_transforms(self):
# Check binary -> binary codecs give a good error for str input
bad_input = "bad input type"
for encoding in bytes_transform_encodings:
Expand All @@ -2901,14 +3007,14 @@ def test_text_to_binary_blacklists_binary_transforms(self):
bad_input.encode(encoding)
self.assertIsNone(failure.exception.__cause__)

def test_text_to_binary_blacklists_text_transforms(self):
def test_text_to_binary_denylists_text_transforms(self):
# Check str.encode gives a good error message for str -> str codecs
msg = (r"^'rot_13' is not a text encoding; "
r"use codecs.encode\(\) to handle arbitrary codecs")
with self.assertRaisesRegex(LookupError, msg):
"just an example message".encode("rot_13")

def test_binary_to_text_blacklists_binary_transforms(self):
def test_binary_to_text_denylists_binary_transforms(self):
# Check bytes.decode and bytearray.decode give a good error
# message for binary -> binary codecs
data = b"encode first to ensure we meet any format restrictions"
Expand All @@ -2923,7 +3029,7 @@ def test_binary_to_text_blacklists_binary_transforms(self):
with self.assertRaisesRegex(LookupError, msg):
bytearray(encoded_data).decode(encoding)

def test_binary_to_text_blacklists_text_transforms(self):
def test_binary_to_text_denylists_text_transforms(self):
# Check str -> str codec gives a good error for binary input
for bad_input in (b"immutable", bytearray(b"mutable")):
with self.subTest(bad_input=bad_input):
Expand Down Expand Up @@ -2991,29 +3097,14 @@ def test_uu_invalid(self):

def _get_test_codec(codec_name):
return _TEST_CODECS.get(codec_name)
codecs.register(_get_test_codec) # Returns None, not usable as a decorator

try:
# Issue #22166: Also need to clear the internal cache in CPython
from _codecs import _forget_codec
except ImportError:
def _forget_codec(codec_name):
pass


class ExceptionChainingTest(unittest.TestCase):

def setUp(self):
# There's no way to unregister a codec search function, so we just
# ensure we render this one fairly harmless after the test
# case finishes by using the test case repr as the codec name
# The codecs module normalizes codec names, although this doesn't
# appear to be formally documented...
# We also make sure we use a truly unique id for the custom codec
# to avoid issues with the codec cache when running these tests
# multiple times (e.g. when hunting for refleaks)
unique_id = repr(self) + str(id(self))
self.codec_name = encodings.normalize_encoding(unique_id).lower()
self.codec_name = 'exception_chaining_test'
codecs.register(_get_test_codec)
self.addCleanup(codecs.unregister, _get_test_codec)

# We store the object to raise on the instance because of a bad
# interaction between the codec caching (which means we can't
Expand All @@ -3028,10 +3119,6 @@ def tearDown(self):
_TEST_CODECS.pop(self.codec_name, None)
# Issue #22166: Also pop from caches to avoid appearance of ref leaks
encodings._cache.pop(self.codec_name, None)
try:
_forget_codec(self.codec_name)
except KeyError:
pass

def set_codec(self, encode, decode):
codec_info = codecs.CodecInfo(encode, decode,
Expand Down Expand Up @@ -3710,5 +3797,46 @@ def test_rot13_func(self):
'To be, or not to be, that is the question')


class CodecNameNormalizationTest(unittest.TestCase):
"""Test codec name normalization"""
# TODO: RUSTPYTHON, AssertionError: Tuples differ: (1, 2, 3, 4) != (None, None, None, None)
@unittest.expectedFailure
def test_codecs_lookup(self):
# Sentinel 4-tuples returned by the search function; which one comes back
# from codecs.lookup() shows whether the name it received was "aaa_8".
FOUND = (1, 2, 3, 4)
NOT_FOUND = (None, None, None, None)
def search_function(encoding):
if encoding == "aaa_8":
return FOUND
else:
return NOT_FOUND

codecs.register(search_function)
# Remove the search function again once the test finishes.
self.addCleanup(codecs.unregister, search_function)
# Upper case, '-', runs of '-', spaces and non-ASCII characters are all
# expected to end up as the name "aaa_8".
self.assertEqual(FOUND, codecs.lookup('aaa_8'))
self.assertEqual(FOUND, codecs.lookup('AAA-8'))
self.assertEqual(FOUND, codecs.lookup('AAA---8'))
self.assertEqual(FOUND, codecs.lookup('AAA 8'))
self.assertEqual(FOUND, codecs.lookup('aaa\xe9\u20ac-8'))
# Names containing '.' or a different ASCII stem are expected not to
# reach the search function as "aaa_8".
self.assertEqual(NOT_FOUND, codecs.lookup('AAA.8'))
self.assertEqual(NOT_FOUND, codecs.lookup('AAA...8'))
self.assertEqual(NOT_FOUND, codecs.lookup('BBB-8'))
self.assertEqual(NOT_FOUND, codecs.lookup('BBB.8'))
self.assertEqual(NOT_FOUND, codecs.lookup('a\xe9\u20ac-8'))

# TODO: RUSTPYTHON, AssertionError
@unittest.expectedFailure
def test_encodings_normalize_encoding(self):
# encodings.normalize_encoding() ignores non-ASCII characters.
normalize = encodings.normalize_encoding
self.assertEqual(normalize('utf_8'), 'utf_8')
self.assertEqual(normalize('utf\xE9\u20AC\U0010ffff-8'), 'utf_8')
self.assertEqual(normalize('utf 8'), 'utf_8')
# encodings.normalize_encoding() doesn't convert
# characters to lower case.
self.assertEqual(normalize('UTF 8'), 'UTF_8')
# Dots are expected to pass through unchanged, including runs of them.
self.assertEqual(normalize('utf.8'), 'utf.8')
self.assertEqual(normalize('utf...8'), 'utf...8')


if __name__ == "__main__":
unittest.main()
20 changes: 19 additions & 1 deletion vm/src/codecs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
builtins::{PyBaseExceptionRef, PyBytesRef, PyStr, PyStrRef, PyTuple, PyTupleRef},
common::{ascii, lock::PyRwLock},
function::IntoPyObject,
PyContext, PyObject, PyObjectRef, PyResult, PyValue, TryFromObject, TypeProtocol,
IdProtocol, PyContext, PyObject, PyObjectRef, PyResult, PyValue, TryFromObject, TypeProtocol,
VirtualMachine,
};
use std::borrow::Cow;
Expand Down Expand Up @@ -195,6 +195,24 @@ impl CodecsRegistry {
Ok(())
}

/// Remove a codec search function from the registry.
///
/// The search path is scanned for the first entry that is the same object
/// as `search_function` (compared by object id). If one is found, the whole
/// lookup cache is cleared and that entry is removed. If none is found,
/// including when the search path is empty, nothing changes.
///
/// Always returns `Ok(())`.
pub fn unregister(&self, search_function: PyObjectRef) -> PyResult<()> {
    // Hold the write lock for the whole operation so the scan and the
    // removal see the same search path.
    let mut registry = self.inner.write();
    let target_id = search_function.get_id();
    let found = registry
        .search_path
        .iter()
        .position(|candidate| candidate.get_id() == target_id);
    if let Some(index) = found {
        // Presumably cached results may stem from the function being
        // removed, so every cached lookup is dropped before removal.
        registry.search_cache.clear();
        registry.search_path.remove(index);
    }
    Ok(())
}

pub fn lookup(&self, encoding: &str, vm: &VirtualMachine) -> PyResult<PyCodec> {
let encoding = normalize_encoding_name(encoding);
let inner = self.inner.read();
Expand Down
5 changes: 5 additions & 0 deletions vm/src/stdlib/codecs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ mod _codecs {
vm.state.codec_registry.register(search_function, vm)
}

// `unregister` entry point of this module: forwards `search_function` to the
// codec registry stored in the VM state, which removes it from the search
// path if it is present. The registry call's result is returned unchanged.
#[pyfunction]
fn unregister(search_function: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
vm.state.codec_registry.unregister(search_function)
}

#[pyfunction]
fn lookup(encoding: PyStrRef, vm: &VirtualMachine) -> PyResult {
vm.state
Expand Down

0 comments on commit ab50217

Please sign in to comment.