diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index e4a55b3455a..7d71fc982b4 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -320,7 +320,7 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = # Try to load as base64 try: - b64 = base64.b64decode(image, validate=True) + b64 = base64.decodebytes(image.encode()) image = PIL.Image.open(BytesIO(b64)) except Exception as e: raise ValueError( diff --git a/tests/utils/test_image_utils.py b/tests/utils/test_image_utils.py index d6bc9a37585..f360c4bb825 100644 --- a/tests/utils/test_image_utils.py +++ b/tests/utils/test_image_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import codecs import os import tempfile import unittest @@ -544,6 +545,23 @@ def test_load_img_base64(self): self.assertEqual(img_arr.shape, (64, 32, 3)) + def test_load_img_base64_encoded_bytes(self): + try: + tmp_file = tempfile.mktemp() + with open(tmp_file, "wb") as f: + http_get( + "https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_2.txt", f + ) + + with codecs.open(tmp_file, encoding="unicode_escape") as b64: + img = load_image(b64.read()) + img_arr = np.array(img) + + finally: + os.remove(tmp_file) + + self.assertEqual(img_arr.shape, (256, 256, 3)) + def test_load_img_rgba(self): # we use revision="refs/pr/1" until the PR is merged # https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1