Skip to content

Commit

Permalink
Improve URL downloads for file objects (#8978)
Browse files Browse the repository at this point in the history
* changes

* changes

* add changeset

* add changeset

* Ci security tweaks (#9010)

* asd

* asd

* asd

* asd

* asd

* asd

* asd

* asd

* asd

* asd

* asd

* asd

* asd

* asd

* change

* changes

* changes

* changes

* changes

* changes

* changes

---------

Co-authored-by: Ali Abid <[email protected]>
Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: pngwn <[email protected]>
  • Loading branch information
4 people authored Aug 6, 2024
1 parent 8805fc4 commit fe9d1cb
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .changeset/cyan-spies-check.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

feat:Improve url downloads for file objects
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ gradio/templates/frontend/cdn
*.db
*.sqlite3
gradio/launches.json
gradio/hash_seed.txt

# Tests
.coverage
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ js/gradio-preview/test/*
*.db
*.sqlite3
gradio/launches.json
gradio/hash_seed.txt
flagged/
gradio_cached_examples/
tmp.zip
Expand Down
52 changes: 49 additions & 3 deletions gradio/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

import base64
import hashlib
import ipaddress
import json
import logging
import os
import shutil
import socket
import subprocess
import tempfile
import warnings
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

import aiofiles
import httpx
Expand All @@ -22,7 +25,7 @@
from gradio import utils, wasm_utils
from gradio.data_classes import FileData, GradioModel, GradioRootModel, JsonData
from gradio.exceptions import Error
from gradio.utils import abspath, get_upload_folder, is_in_or_equal
from gradio.utils import abspath, get_hash_seed, get_upload_folder, is_in_or_equal

with warnings.catch_warnings():
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
Expand Down Expand Up @@ -177,8 +180,12 @@ def encode_pil_to_bytes(pil_image, format="png"):
return output_bytes.getvalue()


hash_seed = get_hash_seed().encode("utf-8")


def hash_file(file_path: str | Path, chunk_num_blocks: int = 128) -> str:
sha1 = hashlib.sha1()
sha1.update(hash_seed)
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
sha1.update(chunk)
Expand All @@ -187,18 +194,21 @@ def hash_file(file_path: str | Path, chunk_num_blocks: int = 128) -> str:

def hash_url(url: str) -> str:
    """Return the seeded SHA-1 hex digest of a URL string."""
    # Passing the concatenated bytes to the constructor is equivalent to the
    # two sequential update() calls: seed first, then the UTF-8 encoded URL.
    return hashlib.sha1(hash_seed + url.encode("utf-8")).hexdigest()


def hash_bytes(bytes: bytes):
    """Return the seeded SHA-1 hex digest of a bytes payload."""
    # NOTE(review): the parameter name shadows the builtin `bytes`; it is kept
    # unchanged so existing keyword callers are unaffected.
    digest = hashlib.sha1()
    digest.update(hash_seed + bytes)
    return digest.hexdigest()


def hash_base64(base64_encoding: str, chunk_num_blocks: int = 128) -> str:
sha1 = hashlib.sha1()
sha1.update(hash_seed)
for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
sha1.update(data.encode("utf-8"))
Expand Down Expand Up @@ -260,20 +270,51 @@ def save_file_to_cache(file_path: str | Path, cache_dir: str) -> str:
return full_temp_file_path


def check_public_url(url: str):
    """Validate that *url* is http(s) and resolves only to public addresses.

    SSRF guard for URL downloads: rejects non-http(s) schemes, URLs without a
    hostname, hostnames that cannot be resolved, and hostnames for which any
    resolved address is not globally routable (loopback, private ranges,
    link-local, etc.).

    Raises:
        httpx.RequestError: if any of the checks above fails.

    Returns:
        True when every resolved address is public.
    """
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.hostname:
        raise httpx.RequestError(f"Invalid URL: {url}")
    host = parsed.hostname

    try:
        resolved = socket.getaddrinfo(host, None)
    except socket.gaierror:
        raise httpx.RequestError(f"Cannot resolve hostname: {host}") from None

    for family, _, _, _, sockaddr in resolved:
        address = sockaddr[0]
        if family == socket.AF_INET6:
            # Drop the zone/scope identifier (e.g. "fe80::1%eth0") so that
            # ipaddress can parse the literal.
            address = address.partition("%")[0]
        if not ipaddress.ip_address(address).is_global:
            raise httpx.RequestError(
                f"Non-public IP address found: {address} for URL: {url}"
            )

    return True


def save_url_to_cache(url: str, cache_dir: str) -> str:
    """Download a file from ``url`` into ``cache_dir`` and return the local path.

    The file is stored under a sub-directory named after the seeded hash of the
    URL; if that cached copy already exists, no download occurs and its path is
    returned as-is.

    Raises:
        httpx.RequestError: if the URL (or any redirect hop) is invalid or
            resolves to a non-public address.
    """
    # SSRF guard: reject non-http(s) schemes and hosts on private networks.
    check_public_url(url)

    temp_dir = Path(cache_dir) / hash_url(url)
    temp_dir.mkdir(exist_ok=True, parents=True)
    name = client_utils.strip_invalid_filename_characters(Path(url).name)
    full_temp_file_path = str(abspath(temp_dir / name))

    if not Path(full_temp_file_path).exists():
        with sync_client.stream("GET", url, follow_redirects=True) as response:
            # Validate every redirect hop BEFORE creating the destination file.
            # Opening the file first meant a rejected redirect left behind an
            # empty cache entry, which later calls would return as though it
            # were the real download. This ordering also matches
            # async_save_url_to_cache.
            for redirect in response.history:
                check_public_url(str(redirect.url))

            with open(full_temp_file_path, "wb") as f:
                for chunk in response.iter_raw():
                    f.write(chunk)

    return full_temp_file_path
Expand All @@ -282,6 +323,8 @@ def save_url_to_cache(url: str, cache_dir: str) -> str:
async def async_save_url_to_cache(url: str, cache_dir: str) -> str:
"""Downloads a file and makes a temporary file path for a copy if does not already
exist. Otherwise returns the path to the existing temp file. Uses async httpx."""
check_public_url(url)

temp_dir = hash_url(url)
temp_dir = Path(cache_dir) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True)
Expand All @@ -290,6 +333,9 @@ async def async_save_url_to_cache(url: str, cache_dir: str) -> str:

if not Path(full_temp_file_path).exists():
async with async_client.stream("GET", url, follow_redirects=True) as response:
for redirect in response.history:
check_public_url(str(redirect.url))

async with aiofiles.open(full_temp_file_path, "wb") as f:
async for chunk in response.aiter_raw():
await f.write(chunk)
Expand Down
1 change: 1 addition & 0 deletions gradio/route_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def __init__(
) -> None:
super().__init__(file, size=size, filename=filename, headers=headers)
self.sha = hashlib.sha1()
self.sha.update(processing_utils.hash_seed)


@python_dataclass(frozen=True)
Expand Down
19 changes: 19 additions & 0 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import traceback
import typing
import urllib.parse
import uuid
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
Expand Down Expand Up @@ -51,6 +52,7 @@
from gradio_client.documentation import document
from typing_extensions import ParamSpec

import gradio
from gradio.context import get_blocks_context
from gradio.data_classes import BlocksConfigDict, FileData
from gradio.exceptions import Error
Expand Down Expand Up @@ -435,6 +437,23 @@ def download_if_url(article: str) -> str:
return article


HASH_SEED_PATH = os.path.join(os.path.dirname(gradio.__file__), "hash_seed.txt")


def get_hash_seed() -> str:
    """Return a persistent random seed used to salt gradio's content hashes.

    The seed lives in ``hash_seed.txt`` next to the gradio package so it
    survives process restarts. If the file cannot be read or written (for
    example, a read-only install), a fresh ephemeral seed is returned instead.
    """
    try:
        if not os.path.exists(HASH_SEED_PATH):
            seed = uuid.uuid4().hex
            with open(HASH_SEED_PATH, "w") as f:
                f.write(seed)
            return seed
        with open(HASH_SEED_PATH) as f:
            return f.read().strip()
    except Exception:
        # Best-effort persistence: fall back to a per-process seed on any
        # filesystem error rather than failing import.
        return uuid.uuid4().hex


def get_default_args(func: Callable) -> list[Any]:
signature = inspect.signature(func)
return [
Expand Down
29 changes: 29 additions & 0 deletions test/test_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from unittest.mock import patch

import ffmpy
import httpx
import numpy as np
import pytest
from gradio_client import media_data
Expand Down Expand Up @@ -404,3 +405,31 @@ async def test_json_data_not_moved_to_cache():
)
== data
)


@pytest.mark.parametrize(
    "url",
    [
        "https://localhost",
        "http://127.0.0.1/file/a/b/c",
        "http://[::1]",
        "https://192.168.0.1",
        "http://10.0.0.1?q=a",
        "http://192.168.1.250.nip.io",
    ],
)
def test_local_urls_fail(url):
    # Each URL resolves to a loopback or private-range address, so the SSRF
    # guard must reject it. NOTE(review): the nip.io case relies on external
    # DNS (nip.io echoes the embedded IP) — may fail offline; confirm in CI.
    with pytest.raises(httpx.RequestError, match="Non-public IP address found"):
        processing_utils.check_public_url(url)


@pytest.mark.parametrize(
    "url",
    [
        "https://google.com",
        "https://8.8.8.8/",
        "http://93.184.215.14.nip.io/",
    ],
)
def test_public_urls_pass(url):
    # Hosts resolving to globally routable addresses must be allowed.
    # NOTE(review): the hostname cases require outbound DNS — may be flaky
    # in offline environments.
    assert processing_utils.check_public_url(url)

0 comments on commit fe9d1cb

Please sign in to comment.