Skip to content

Commit

Permalink
Allow displaying SVG images securely in gr.Image and gr.Gallery c…
Browse files Browse the repository at this point in the history
…omponents (#10269)

* changes

* changes

* add changeset

* changes

* add changeset

* changes

* changes

* changes

* add changeset

* add changeset

* add changeset

* format fe

* changes

* changes

* changes

* revert

* revert more

* revert

* add changeset

* more changes

* add changeset

* changes

* add changeset

* format

* add changeset

* changes

* changes

* svg

* changes

* format

* add changeset

* fix tests

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
abidlabs and gradio-pr-bot authored Jan 6, 2025
1 parent 99123e7 commit 890eaa3
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 53 deletions.
7 changes: 7 additions & 0 deletions .changeset/eleven-suits-itch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/gallery": patch
"@gradio/image": patch
"gradio": patch
---

fix:Allow displaying SVG images securely in `gr.Image` and `gr.Gallery` components
30 changes: 20 additions & 10 deletions gradio/components/gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Optional,
Union,
)
from urllib.parse import urlparse
from urllib.parse import quote, urlparse

import numpy as np
import PIL.Image
Expand All @@ -21,9 +21,9 @@
from gradio_client.documentation import document
from gradio_client.utils import is_http_url_like

from gradio import processing_utils, utils, wasm_utils
from gradio import image_utils, processing_utils, utils, wasm_utils
from gradio.components.base import Component
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.data_classes import FileData, GradioModel, GradioRootModel, ImageData
from gradio.events import Events
from gradio.exceptions import Error

Expand All @@ -35,7 +35,7 @@


class GalleryImage(GradioModel):
image: FileData
image: ImageData
caption: Optional[str] = None


Expand Down Expand Up @@ -188,7 +188,7 @@ def preprocess(
if isinstance(gallery_element, GalleryVideo):
file_path = gallery_element.video.path
else:
file_path = gallery_element.image.path
file_path = gallery_element.image.path or ""
if self.file_types and not client_utils.is_valid_file(
file_path, self.file_types
):
Expand Down Expand Up @@ -216,6 +216,10 @@ def postprocess(
"""
if value is None:
return GalleryData(root=[])
if isinstance(value, str):
raise ValueError(
"The `value` passed into `gr.Gallery` must be a list of images or videos, or list of (media, caption) tuples."
)
output = []

def _save(img):
Expand All @@ -236,14 +240,20 @@ def _save(img):
)
file_path = str(utils.abspath(file))
elif isinstance(img, str):
file_path = img
mime_type = client_utils.get_mimetype(file_path)
if is_http_url_like(img):
mime_type = client_utils.get_mimetype(img)
if img.lower().endswith(".svg"):
svg_content = image_utils.extract_svg_content(img)
orig_name = Path(img).name
url = f"data:image/svg+xml,{quote(svg_content)}"
file_path = None
elif is_http_url_like(img):
url = img
orig_name = Path(urlparse(img).path).name
file_path = img
else:
url = None
orig_name = Path(img).name
file_path = img
elif isinstance(img, Path):
file_path = str(img)
orig_name = img.name
Expand All @@ -253,7 +263,7 @@ def _save(img):
if mime_type is not None and "video" in mime_type:
return GalleryVideo(
video=FileData(
path=file_path,
path=file_path, # type: ignore
url=url,
orig_name=orig_name,
mime_type=mime_type,
Expand All @@ -262,7 +272,7 @@ def _save(img):
)
else:
return GalleryImage(
image=FileData(
image=ImageData(
path=file_path,
url=url,
orig_name=orig_name,
Expand Down
40 changes: 11 additions & 29 deletions gradio/components/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
import warnings
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from urllib.parse import quote

import numpy as np
import PIL.Image
from gradio_client import handle_file
from gradio_client.documentation import document
from PIL import ImageOps
from pydantic import ConfigDict, Field

from gradio import image_utils, utils
from gradio.components.base import Component, StreamingInput
from gradio.data_classes import GradioModel
from gradio.data_classes import Base64ImageData, ImageData
from gradio.events import Events
from gradio.exceptions import Error

Expand All @@ -26,28 +26,6 @@
PIL.Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843


class ImageData(GradioModel):
path: Optional[str] = Field(default=None, description="Path to a local file")
url: Optional[str] = Field(
default=None, description="Publicly available url or base64 encoded image"
)
size: Optional[int] = Field(default=None, description="Size of image in bytes")
orig_name: Optional[str] = Field(default=None, description="Original filename")
mime_type: Optional[str] = Field(default=None, description="mime type of image")
is_stream: bool = Field(default=False, description="Can always be set to False")
meta: dict = {"_type": "gradio.FileData"}

model_config = ConfigDict(
json_schema_extra={
"description": "For input, either path or url must be provided. For output, path is always provided."
}
)


class Base64ImageData(GradioModel):
url: str = Field(description="base64 encoded image")


@document()
class Image(StreamingInput, Component):
"""
Expand Down Expand Up @@ -112,7 +90,7 @@ def __init__(
width: The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed image file or numpy array, but will affect the displayed image.
image_mode: The pixel format and color depth that the image should be loaded and preprocessed as. "RGB" will load the image as a color image, or "L" as black-and-white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. This parameter has no effect on SVG or GIF files. If set to None, the image_mode will be inferred from the image file type (e.g. "RGBA" for a .png image, "RGB" in most other cases).
sources: List of sources for the image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "clipboard" allows users to paste an image from the clipboard. If None, defaults to ["upload", "webcam", "clipboard"] if streaming is False, otherwise defaults to ["webcam"].
type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned. To support animated GIFs in input, the `type` should be set to "filepath" or "pil".
type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. To support animated GIFs in input, the `type` should be set to "filepath" or "pil". To support SVGs, the `type` should be set to "filepath".
label: the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.
every: Continously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change.
Expand Down Expand Up @@ -198,7 +176,7 @@ def preprocess(
Parameters:
payload: image data in the form of a FileData object
Returns:
Passes the uploaded image as a `numpy.array`, `PIL.Image` or `str` filepath depending on `type`. For SVGs, the `type` parameter is ignored and the filepath of the SVG is returned.
Passes the uploaded image as a `numpy.array`, `PIL.Image` or `str` filepath depending on `type`.
"""
if payload is None:
return payload
Expand Down Expand Up @@ -227,7 +205,7 @@ def preprocess(
if suffix.lower() == "svg":
if self.type == "filepath":
return str(file_path)
raise Error("SVG files are not supported as input images.")
raise Error("SVG files are not supported as input images for this app.")

im = PIL.Image.open(file_path)
if self.type == "filepath" and (self.image_mode in [None, im.mode]):
Expand Down Expand Up @@ -267,7 +245,11 @@ def postprocess(
if value is None:
return None
if isinstance(value, str) and value.lower().endswith(".svg"):
return ImageData(path=value, orig_name=Path(value).name)
svg_content = image_utils.extract_svg_content(value)
return ImageData(
orig_name=Path(value).name,
url=f"data:image/svg+xml,{quote(svg_content)}",
)
if self.streaming:
if isinstance(value, np.ndarray):
return Base64ImageData(
Expand Down
24 changes: 24 additions & 0 deletions gradio/data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from gradio_client.utils import is_file_obj_with_meta, traverse
from pydantic import (
BaseModel,
ConfigDict,
Field,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
RootModel,
Expand Down Expand Up @@ -391,3 +393,25 @@ class MediaStreamChunk(TypedDict):
duration: float
extension: str
id: NotRequired[str]


class ImageData(GradioModel):
path: Optional[str] = Field(default=None, description="Path to a local file")
url: Optional[str] = Field(
default=None, description="Publicly available url or base64 encoded image"
)
size: Optional[int] = Field(default=None, description="Size of image in bytes")
orig_name: Optional[str] = Field(default=None, description="Original filename")
mime_type: Optional[str] = Field(default=None, description="mime type of image")
is_stream: bool = Field(default=False, description="Can always be set to False")
meta: dict = {"_type": "gradio.FileData"}

model_config = ConfigDict(
json_schema_extra={
"description": "For input, either path or url must be provided. For output, path is always provided."
}
)


class Base64ImageData(GradioModel):
url: str = Field(description="base64 encoded image")
22 changes: 21 additions & 1 deletion gradio/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from pathlib import Path
from typing import Literal, cast

import httpx
import numpy as np
import PIL.Image
from gradio_client.utils import get_mimetype
from gradio_client.utils import get_mimetype, is_http_url_like
from PIL import ImageOps

from gradio import processing_utils
Expand Down Expand Up @@ -152,3 +153,22 @@ def encode_image_file_to_base64(image_file: str | Path) -> str:
bytes_data = f.read()
base64_str = str(base64.b64encode(bytes_data), "utf-8")
return f"data:{mime_type};base64," + base64_str


def extract_svg_content(image_file: str | Path) -> str:
"""
Provided a path or URL to an SVG file, return the SVG content as a string.
Parameters:
image_file: Local file path or URL to an SVG file
Returns:
str: The SVG content as a string
"""
image_file = str(image_file)
if is_http_url_like(image_file):
response = httpx.get(image_file)
response.raise_for_status() # Raise an error for bad status codes
return response.text
else:
with open(image_file) as file:
svg_content = file.read()
return svg_content
33 changes: 27 additions & 6 deletions js/gallery/Index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

<script lang="ts">
import type { GalleryImage, GalleryVideo } from "./types";
import type { FileData } from "@gradio/client";
import type { Gradio, ShareData, SelectData } from "@gradio/utils";
import { Block, UploadText } from "@gradio/atoms";
import Gallery from "./shared/Gallery.svelte";
Expand Down Expand Up @@ -52,6 +53,30 @@
$: no_value = value === null ? true : value.length === 0;
$: selected_index, dispatch("prop_change", { selected_index });
async function process_upload_files(
files: FileData[]
): Promise<GalleryData[]> {
const processed_files = await Promise.all(
files.map(async (x) => {
if (x.path?.toLowerCase().endsWith(".svg") && x.url) {
const response = await fetch(x.url);
const svgContent = await response.text();
return {
...x,
url: `data:image/svg+xml,${encodeURIComponent(svgContent)}`
};
}
return x;
})
);
return processed_files.map((x) =>
x.mime_type?.includes("video")
? { video: x, caption: null }
: { image: x, caption: null }
);
}
</script>

<Block
Expand Down Expand Up @@ -83,13 +108,9 @@
i18n={gradio.i18n}
upload={(...args) => gradio.client.upload(...args)}
stream_handler={(...args) => gradio.client.stream(...args)}
on:upload={(e) => {
on:upload={async (e) => {
const files = Array.isArray(e.detail) ? e.detail : [e.detail];
value = files.map((x) =>
x.mime_type?.includes("video")
? { video: x, caption: null }
: { image: x, caption: null }
);
value = await process_upload_files(files);
gradio.dispatch("upload", value);
}}
on:error={({ detail }) => {
Expand Down
16 changes: 13 additions & 3 deletions js/image/shared/ImageUploader.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,20 @@
export let webcam_constraints: { [key: string]: any } | undefined = undefined;
function handle_upload({ detail }: CustomEvent<FileData>): void {
// only trigger streaming event if streaming
async function handle_upload({
detail
}: CustomEvent<FileData>): Promise<void> {
if (!streaming) {
value = detail;
if (detail.path?.toLowerCase().endsWith(".svg") && detail.url) {
const response = await fetch(detail.url);
const svgContent = await response.text();
value = {
...detail,
url: `data:image/svg+xml,${encodeURIComponent(svgContent)}`
};
} else {
value = detail;
}
dispatch("upload");
}
}
Expand Down
10 changes: 6 additions & 4 deletions test/components/test_gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import gradio as gr
from gradio.components.gallery import GalleryImage
from gradio.data_classes import FileData
from gradio.data_classes import ImageData


class TestGallery:
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_gallery_preprocess(self):
from gradio.components.gallery import GalleryData, GalleryImage

gallery = gr.Gallery()
img = GalleryImage(image=FileData(path="test/test_files/bus.png"))
img = GalleryImage(image=ImageData(path="test/test_files/bus.png"))
data = GalleryData(root=[img])

assert (preprocessed := gallery.preprocess(data))
Expand All @@ -115,7 +115,7 @@ def test_gallery_preprocess(self):
)

img_captions = GalleryImage(
image=FileData(path="test/test_files/bus.png"), caption="bus"
image=ImageData(path="test/test_files/bus.png"), caption="bus"
)
data = GalleryData(root=[img_captions])
assert (preprocess := gr.Gallery().preprocess(data))
Expand All @@ -127,4 +127,6 @@ def test_gallery_format(self):
[np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)]
)
if isinstance(output.root[0], GalleryImage):
assert output.root[0].image.path.endswith(".jpeg")
assert output.root[0].image.path and output.root[0].image.path.endswith(
".jpeg"
)
1 change: 1 addition & 0 deletions test/test_files/file_icon.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 890eaa3

Please sign in to comment.