diff --git a/sopel/irc/utils.py b/sopel/irc/utils.py index 12ba09461..3d3bf450f 100644 --- a/sopel/irc/utils.py +++ b/sopel/irc/utils.py @@ -25,10 +25,11 @@ def safe(string): :rtype: str :raises TypeError: when ``string`` is ``None`` - This function removes newlines from a string and always returns a unicode - string (``str``), but doesn't strip or alter it in any other way:: + This function removes newlines and null-bytes from a string. It will always + return a Unicode ``str``, even if given non-Unicode input, but doesn't strip + or alter the string in any other way:: - >>> safe('some text\\r\\n') + >>> safe('some \x00text\\r\\n') 'some text' This is useful to ensure a string can be used in a IRC message. @@ -45,6 +46,7 @@ def safe(string): string = string.decode("utf8") string = string.replace('\n', '') string = string.replace('\r', '') + string = string.replace('\x00', '') return string diff --git a/test/irc/test_irc_utils.py b/test/irc/test_irc_utils.py index 94eab6298..3204444e2 100644 --- a/test/irc/test_irc_utils.py +++ b/test/irc/test_irc_utils.py @@ -1,6 +1,8 @@ """Tests for core ``sopel.irc.utils``""" from __future__ import annotations +from itertools import permutations + import pytest from sopel.irc import utils @@ -8,15 +10,19 @@ def test_safe(): text = 'some text' - assert utils.safe(text + '\r\n') == text - assert utils.safe(text + '\n') == text - assert utils.safe(text + '\r') == text - assert utils.safe('\r\n' + text) == text - assert utils.safe('\n' + text) == text - assert utils.safe('\r' + text) == text - assert utils.safe('some \r\ntext') == text - assert utils.safe('some \ntext') == text - assert utils.safe('some \rtext') == text + variants = permutations(('\n', '\r', '\x00')) + for variant in variants: + seq = ''.join(variant) + assert utils.safe(text + seq) == text + assert utils.safe(seq + text) == text + assert utils.safe('some ' + seq + 'text') == text + assert utils.safe( + variant[0] + + 'some ' + + variant[1] + + 'text' + + variant[2] + ) == text def test_safe_empty(): @@ -24,20 +30,23 @@ def test_safe_empty(): assert utils.safe(text) == text -def test_safe_null(): +def test_safe_none(): with pytest.raises(TypeError): utils.safe(None) def test_safe_bytes(): text = b'some text' - assert utils.safe(text) == text.decode('utf-8') - assert utils.safe(text + b'\r\n') == text.decode('utf-8') - assert utils.safe(text + b'\n') == text.decode('utf-8') - assert utils.safe(text + b'\r') == text.decode('utf-8') - assert utils.safe(b'\r\n' + text) == text.decode('utf-8') - assert utils.safe(b'\n' + text) == text.decode('utf-8') - assert utils.safe(b'\r' + text) == text.decode('utf-8') - assert utils.safe(b'some \r\ntext') == text.decode('utf-8') - assert utils.safe(b'some \ntext') == text.decode('utf-8') - assert utils.safe(b'some \rtext') == text.decode('utf-8') + variants = permutations((b'\n', b'\r', b'\x00')) + for variant in variants: + seq = b''.join(variant) + assert utils.safe(text + seq) == text.decode('utf-8') + assert utils.safe(seq + text) == text.decode('utf-8') + assert utils.safe(b'some ' + seq + b'text') == text.decode('utf-8') + assert utils.safe( + variant[0] + + b'some ' + + variant[1] + + b'text' + + variant[2] + ) == text.decode('utf-8')