From 890eaa3a9e53dab5bcb16c5d017ae0470109b8fb Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Mon, 6 Jan 2025 13:07:13 -0800 Subject: [PATCH] Allow displaying SVG images securely in `gr.Image` and `gr.Gallery` components (#10269) * changes * changes * add changeset * changes * add changeset * changes * changes * changes * add changeset * add changeset * add changeset * format fe * changes * changes * changes * revert * revert more * revert * add changeset * more changes * add changeset * changes * add changeset * format * add changeset * changes * changes * svg * changes * format * add changeset * fix tests --------- Co-authored-by: gradio-pr-bot --- .changeset/eleven-suits-itch.md | 7 +++++ gradio/components/gallery.py | 30 ++++++++++++++------- gradio/components/image.py | 40 ++++++++-------------------- gradio/data_classes.py | 24 +++++++++++++++++ gradio/image_utils.py | 22 ++++++++++++++- js/gallery/Index.svelte | 33 ++++++++++++++++++----- js/image/shared/ImageUploader.svelte | 16 ++++++++--- test/components/test_gallery.py | 10 ++++--- test/test_files/file_icon.svg | 1 + test/test_image_utils.py | 26 ++++++++++++++++++ 10 files changed, 156 insertions(+), 53 deletions(-) create mode 100644 .changeset/eleven-suits-itch.md create mode 100644 test/test_files/file_icon.svg create mode 100644 test/test_image_utils.py diff --git a/.changeset/eleven-suits-itch.md b/.changeset/eleven-suits-itch.md new file mode 100644 index 0000000000000..e38f4b3fe4b93 --- /dev/null +++ b/.changeset/eleven-suits-itch.md @@ -0,0 +1,7 @@ +--- +"@gradio/gallery": patch +"@gradio/image": patch +"gradio": patch +--- + +fix:Allow displaying SVG images securely in `gr.Image` and `gr.Gallery` components diff --git a/gradio/components/gallery.py b/gradio/components/gallery.py index e7a8759452b42..97a38af1cc9f7 100644 --- a/gradio/components/gallery.py +++ b/gradio/components/gallery.py @@ -12,7 +12,7 @@ Optional, Union, ) -from urllib.parse import urlparse +from urllib.parse import quote, urlparse import numpy as np import PIL.Image @@ -21,9 +21,9 @@ from gradio_client.documentation import document from gradio_client.utils import is_http_url_like -from gradio import processing_utils, utils, wasm_utils +from gradio import image_utils, processing_utils, utils, wasm_utils from gradio.components.base import Component -from gradio.data_classes import FileData, GradioModel, GradioRootModel +from gradio.data_classes import FileData, GradioModel, GradioRootModel, ImageData from gradio.events import Events from gradio.exceptions import Error @@ -35,7 +35,7 @@ class GalleryImage(GradioModel): - image: FileData + image: ImageData caption: Optional[str] = None @@ -188,7 +188,7 @@ def preprocess( if isinstance(gallery_element, GalleryVideo): file_path = gallery_element.video.path else: - file_path = gallery_element.image.path + file_path = gallery_element.image.path or "" if self.file_types and not client_utils.is_valid_file( file_path, self.file_types ): @@ -216,6 +216,10 @@ def postprocess( """ if value is None: return GalleryData(root=[]) + if isinstance(value, str): + raise ValueError( + "The `value` passed into `gr.Gallery` must be a list of images or videos, or list of (media, caption) tuples." + ) output = [] def _save(img): @@ -236,14 +240,20 @@ def _save(img): ) file_path = str(utils.abspath(file)) elif isinstance(img, str): - file_path = img - mime_type = client_utils.get_mimetype(file_path) - if is_http_url_like(img): + mime_type = client_utils.get_mimetype(img) + if img.lower().endswith(".svg"): + svg_content = image_utils.extract_svg_content(img) + orig_name = Path(img).name + url = f"data:image/svg+xml,{quote(svg_content)}" + file_path = None + elif is_http_url_like(img): url = img orig_name = Path(urlparse(img).path).name + file_path = img else: url = None orig_name = Path(img).name + file_path = img elif isinstance(img, Path): file_path = str(img) orig_name = img.name @@ -253,7 +263,7 @@ def _save(img): if mime_type is not None and "video" in mime_type: return GalleryVideo( video=FileData( - path=file_path, + path=file_path, # type: ignore url=url, orig_name=orig_name, mime_type=mime_type, @@ -262,7 +272,7 @@ def _save(img): ) else: return GalleryImage( - image=FileData( + image=ImageData( path=file_path, url=url, orig_name=orig_name, diff --git a/gradio/components/image.py b/gradio/components/image.py index da1a0534c784b..dc7335ee008e7 100644 --- a/gradio/components/image.py +++ b/gradio/components/image.py @@ -5,18 +5,18 @@ import warnings from collections.abc import Callable, Sequence from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, cast +from urllib.parse import quote import numpy as np import PIL.Image from gradio_client import handle_file from gradio_client.documentation import document from PIL import ImageOps -from pydantic import ConfigDict, Field from gradio import image_utils, utils from gradio.components.base import Component, StreamingInput -from gradio.data_classes import GradioModel +from gradio.data_classes import Base64ImageData, ImageData from gradio.events import Events from gradio.exceptions import Error @@ -26,28 +26,6 @@ PIL.Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843 -class ImageData(GradioModel): - path: Optional[str] = Field(default=None, description="Path to a local file") - url: Optional[str] = Field( - default=None, description="Publicly available url or base64 encoded image" - ) - size: Optional[int] = Field(default=None, description="Size of image in bytes") - orig_name: Optional[str] = Field(default=None, description="Original filename") - mime_type: Optional[str] = Field(default=None, description="mime type of image") - is_stream: bool = Field(default=False, description="Can always be set to False") - meta: dict = {"_type": "gradio.FileData"} - - model_config = ConfigDict( - json_schema_extra={ - "description": "For input, either path or url must be provided. For output, path is always provided." - } - ) - - -class Base64ImageData(GradioModel): - url: str = Field(description="base64 encoded image") - - @document() class Image(StreamingInput, Component): """ @@ -112,7 +90,7 @@ def __init__( width: The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed image file or numpy array, but will affect the displayed image. image_mode: The pixel format and color depth that the image should be loaded and preprocessed as. "RGB" will load the image as a color image, or "L" as black-and-white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. This parameter has no effect on SVG or GIF files. If set to None, the image_mode will be inferred from the image file type (e.g. "RGBA" for a .png image, "RGB" in most other cases). sources: List of sources for the image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "clipboard" allows users to paste an image from the clipboard. If None, defaults to ["upload", "webcam", "clipboard"] if streaming is False, otherwise defaults to ["webcam"]. - type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned. To support animated GIFs in input, the `type` should be set to "filepath" or "pil". + type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. To support animated GIFs in input, the `type` should be set to "filepath" or "pil". To support SVGs, the `type` should be set to "filepath". label: the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. every: Continously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer. inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change. @@ -198,7 +176,7 @@ def preprocess( Parameters: payload: image data in the form of a FileData object Returns: - Passes the uploaded image as a `numpy.array`, `PIL.Image` or `str` filepath depending on `type`. For SVGs, the `type` parameter is ignored and the filepath of the SVG is returned. + Passes the uploaded image as a `numpy.array`, `PIL.Image` or `str` filepath depending on `type`. """ if payload is None: return payload @@ -227,7 +205,7 @@ def preprocess( if suffix.lower() == "svg": if self.type == "filepath": return str(file_path) - raise Error("SVG files are not supported as input images.") + raise Error("SVG files are not supported as input images for this app.") im = PIL.Image.open(file_path) if self.type == "filepath" and (self.image_mode in [None, im.mode]): @@ -267,7 +245,11 @@ def postprocess( if value is None: return None if isinstance(value, str) and value.lower().endswith(".svg"): - return ImageData(path=value, orig_name=Path(value).name) + svg_content = image_utils.extract_svg_content(value) + return ImageData( + orig_name=Path(value).name, + url=f"data:image/svg+xml,{quote(svg_content)}", + ) if self.streaming: if isinstance(value, np.ndarray): return Base64ImageData( diff --git a/gradio/data_classes.py b/gradio/data_classes.py index a04ea90c0df2d..c324be300190e 100644 --- a/gradio/data_classes.py +++ b/gradio/data_classes.py @@ -24,6 +24,8 @@ from gradio_client.utils import is_file_obj_with_meta, traverse from pydantic import ( BaseModel, + ConfigDict, + Field, GetCoreSchemaHandler, GetJsonSchemaHandler, RootModel, @@ -391,3 +393,25 @@ class MediaStreamChunk(TypedDict): duration: float extension: str id: NotRequired[str] + + +class ImageData(GradioModel): + path: Optional[str] = Field(default=None, description="Path to a local file") + url: Optional[str] = Field( + default=None, description="Publicly available url or base64 encoded image" + ) + size: Optional[int] = Field(default=None, description="Size of image in bytes") + orig_name: Optional[str] = Field(default=None, description="Original filename") + mime_type: Optional[str] = Field(default=None, description="mime type of image") + is_stream: bool = Field(default=False, description="Can always be set to False") + meta: dict = {"_type": "gradio.FileData"} + + model_config = ConfigDict( + json_schema_extra={ + "description": "For input, either path or url must be provided. For output, path is always provided." + } + ) + + +class Base64ImageData(GradioModel): + url: str = Field(description="base64 encoded image") diff --git a/gradio/image_utils.py b/gradio/image_utils.py index a0a40efcfee5f..f71fdf8c3f624 100644 --- a/gradio/image_utils.py +++ b/gradio/image_utils.py @@ -5,9 +5,10 @@ from pathlib import Path from typing import Literal, cast +import httpx import numpy as np import PIL.Image -from gradio_client.utils import get_mimetype +from gradio_client.utils import get_mimetype, is_http_url_like from PIL import ImageOps from gradio import processing_utils @@ -152,3 +153,22 @@ def encode_image_file_to_base64(image_file: str | Path) -> str: bytes_data = f.read() base64_str = str(base64.b64encode(bytes_data), "utf-8") return f"data:{mime_type};base64," + base64_str + + +def extract_svg_content(image_file: str | Path) -> str: + """ + Provided a path or URL to an SVG file, return the SVG content as a string. + Parameters: + image_file: Local file path or URL to an SVG file + Returns: + str: The SVG content as a string + """ + image_file = str(image_file) + if is_http_url_like(image_file): + response = httpx.get(image_file) + response.raise_for_status() # Raise an error for bad status codes + return response.text + else: + with open(image_file) as file: + svg_content = file.read() + return svg_content diff --git a/js/gallery/Index.svelte b/js/gallery/Index.svelte index c23f0a7d23109..623c61f362bc8 100644 --- a/js/gallery/Index.svelte +++ b/js/gallery/Index.svelte @@ -4,6 +4,7 @@ gradio.client.upload(...args)} stream_handler={(...args) => gradio.client.stream(...args)} - on:upload={(e) => { + on:upload={async (e) => { const files = Array.isArray(e.detail) ? e.detail : [e.detail]; - value = files.map((x) => - x.mime_type?.includes("video") - ? { video: x, caption: null } - : { image: x, caption: null } - ); + value = await process_upload_files(files); gradio.dispatch("upload", value); }} on:error={({ detail }) => { diff --git a/js/image/shared/ImageUploader.svelte b/js/image/shared/ImageUploader.svelte index 415df761c898c..b1e6cab946549 100644 --- a/js/image/shared/ImageUploader.svelte +++ b/js/image/shared/ImageUploader.svelte @@ -45,10 +45,20 @@ export let webcam_constraints: { [key: string]: any } | undefined = undefined; - function handle_upload({ detail }: CustomEvent): void { - // only trigger streaming event if streaming + async function handle_upload({ + detail + }: CustomEvent): Promise { if (!streaming) { - value = detail; + if (detail.path?.toLowerCase().endsWith(".svg") && detail.url) { + const response = await fetch(detail.url); + const svgContent = await response.text(); + value = { + ...detail, + url: `data:image/svg+xml,${encodeURIComponent(svgContent)}` + }; + } else { + value = detail; + } dispatch("upload"); } } diff --git a/test/components/test_gallery.py b/test/components/test_gallery.py index eac9ce763a596..a5d79a6a77fe0 100644 --- a/test/components/test_gallery.py +++ b/test/components/test_gallery.py @@ -5,7 +5,7 @@ import gradio as gr from gradio.components.gallery import GalleryImage -from gradio.data_classes import FileData +from gradio.data_classes import ImageData class TestGallery: @@ -96,7 +96,7 @@ def test_gallery_preprocess(self): from gradio.components.gallery import GalleryData, GalleryImage gallery = gr.Gallery() - img = GalleryImage(image=FileData(path="test/test_files/bus.png")) + img = GalleryImage(image=ImageData(path="test/test_files/bus.png")) data = GalleryData(root=[img]) assert (preprocessed := gallery.preprocess(data)) @@ -115,7 +115,7 @@ def test_gallery_preprocess(self): ) img_captions = GalleryImage( - image=FileData(path="test/test_files/bus.png"), caption="bus" + image=ImageData(path="test/test_files/bus.png"), caption="bus" ) data = GalleryData(root=[img_captions]) assert (preprocess := gr.Gallery().preprocess(data)) @@ -127,4 +127,6 @@ def test_gallery_format(self): [np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)] ) if isinstance(output.root[0], GalleryImage): - assert output.root[0].image.path.endswith(".jpeg") + assert output.root[0].image.path and output.root[0].image.path.endswith( + ".jpeg" + ) diff --git a/test/test_files/file_icon.svg b/test/test_files/file_icon.svg new file mode 100644 index 0000000000000..8855359467805 --- /dev/null +++ b/test/test_files/file_icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/test/test_image_utils.py b/test/test_image_utils.py new file mode 100644 index 0000000000000..7666c3792dc9f --- /dev/null +++ b/test/test_image_utils.py @@ -0,0 +1,26 @@ +from gradio.image_utils import extract_svg_content + + +def test_extract_svg_content_local_file(): + svg_path = "test/test_files/file_icon.svg" + svg_content = extract_svg_content(svg_path) + assert ( + svg_content + == '' + ) + + +def test_extract_svg_content_from_url(monkeypatch): + class MockResponse: + def __init__(self): + self.text = "mock svg content" + + def raise_for_status(self): + pass + + def mock_get(*args, **kwargs): + return MockResponse() + + monkeypatch.setattr("httpx.get", mock_get) + svg_content = extract_svg_content("https://example.com/test.svg") + assert svg_content == "mock svg content"