diff --git a/tests/test_parser.py b/tests/test_parser.py index bc609e1..453b7ab 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -180,6 +180,16 @@ def test_patchset_from_bytes_string(self): self.assertEqual(ps1, ps2) + def test_patchset_string_input(self): + with codecs.open(self.sample_file, 'r', encoding='utf-8') as diff_file: + diff_data = diff_file.read() + ps1 = PatchSet(diff_data) + + with codecs.open(self.sample_file, 'r', encoding='utf-8') as diff_file: + ps2 = PatchSet(diff_file) + + self.assertEqual(ps1, ps2) + def test_parse_malformed_diff(self): """Parse malformed file.""" with open(self.sample_bad_file) as diff_file: diff --git a/unidiff/patch.py b/unidiff/patch.py index 1aa09ad..70f0284 100644 --- a/unidiff/patch.py +++ b/unidiff/patch.py @@ -62,7 +62,8 @@ def implements_to_string(cls): make_str = str implements_to_string = lambda x: x unicode = str - + basestring = str + @implements_to_string class Line(object): @@ -294,6 +295,11 @@ class PatchSet(list): def __init__(self, f, encoding=None): super(PatchSet, self).__init__() + + # convert string inputs to StringIO objects + if isinstance(f, basestring): + f = self._convert_string(f, encoding) + # make sure we pass an iterator object to parse data = iter(f) # if encoding is None, assume we are reading unicode data @@ -355,13 +361,17 @@ def from_filename(cls, filename, encoding=DEFAULT_ENCODING, errors=None): instance = cls(f) return instance - @classmethod - def from_string(cls, data, encoding=None, errors='strict'): + @staticmethod + def _convert_string(data, encoding=None, errors='strict'): """Return a PatchSet instance given a diff string.""" if encoding is not None: # if encoding is given, assume bytes and decode data = unicode(data, encoding=encoding, errors=errors) - return cls(StringIO(data)) + return StringIO(data) + + @classmethod + def from_string(cls, data, encoding=None, errors='strict'): + return cls(cls._convert_string(data, encoding, errors)) @property def added_files(self):