Skip to content

Commit

Permalink
load_image - decode b64encode and encodebytes strings (#30192)
Browse files Browse the repository at this point in the history
* Decode b64encode and encodebytes strings

* Remove conditional encode -- image is always a string
  • Loading branch information
amyeroberts authored Apr 26, 2024
1 parent e7d52a1 commit c793b26
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =

# Try to load as base64
try:
b64 = base64.b64decode(image, validate=True)
b64 = base64.decodebytes(image.encode())
image = PIL.Image.open(BytesIO(b64))
except Exception as e:
raise ValueError(
Expand Down
18 changes: 18 additions & 0 deletions tests/utils/test_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import os
import tempfile
import unittest
Expand Down Expand Up @@ -544,6 +545,23 @@ def test_load_img_base64(self):

self.assertEqual(img_arr.shape, (64, 32, 3))

def test_load_img_base64_encoded_bytes(self):
try:
tmp_file = tempfile.mktemp()
with open(tmp_file, "wb") as f:
http_get(
"https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_2.txt", f
)

with codecs.open(tmp_file, encoding="unicode_escape") as b64:
img = load_image(b64.read())
img_arr = np.array(img)

finally:
os.remove(tmp_file)

self.assertEqual(img_arr.shape, (256, 256, 3))

def test_load_img_rgba(self):
# we use revision="refs/pr/1" until the PR is merged
# https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
Expand Down

0 comments on commit c793b26

Please sign in to comment.