Skip to content

Commit

Permalink
Merge pull request #10963 from alexhartl:master
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715745355
  • Loading branch information
The TensorFlow Datasets Authors committed Jan 15, 2025
2 parents 1322866 + ff89242 commit 855b1cd
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 11 deletions.
68 changes: 57 additions & 11 deletions tensorflow_datasets/core/download/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import urllib

from etils import epath
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core import units
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.download import checksums as checksums_lib
Expand Down Expand Up @@ -130,6 +131,43 @@ def _get_filename(response: Response) -> str:
return _basename_from_url(response.url)


def _process_gdrive_confirmation(original_url: str, contents: str) -> str:
  """Extracts the real download URL from a Google Drive confirmation page.

  Google Drive may serve an interstitial HTML form instead of the file
  itself; this parses that form and rebuilds the direct download link from
  the form's action URL plus its hidden input fields.

  Args:
    original_url: The URL the confirmation page was originally retrieved from.
    contents: The confirmation page's HTML.

  Returns:
    The URL for downloading the file.

  Raises:
    ValueError: If no form (or a form without an action URL) is found.
  """
  bs4 = lazy_imports_lib.lazy_imports.bs4
  page = bs4.BeautifulSoup(contents, 'html.parser')
  form = page.find('form')
  # Both a missing form and a form without an action are unrecoverable.
  action = form.get('action', '') if form else ''
  if not action:
    raise ValueError(
        f'Failed to obtain confirmation link for GDrive URL {original_url}.'
    )
  # Gather the hidden parameters the form would submit with its action.
  params = {}
  for field in ('uuid', 'export', 'id', 'confirm'):
    tag = form.find('input', {'name': field})
    if tag:
      params[field] = tag.get('value', '')
  query = urllib.parse.urlencode(params)
  target = f'{action}?{query}' if query else action
  # Resolve a relative action against the page's own URL.
  return urllib.parse.urljoin(original_url, target)


class _Downloader:
"""Class providing async download API with checksum validation.
Expand Down Expand Up @@ -318,11 +356,26 @@ def _open_with_requests(
session.mount(
'https://', requests.adapters.HTTPAdapter(max_retries=retries)
)
if _DRIVE_URL.match(url):
url = _normalize_drive_url(url)
with session.get(url, stream=True, **kwargs) as response:
_assert_status(response)
yield (response, response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE))
if (
_DRIVE_URL.match(url)
and 'Content-Disposition' not in response.headers
):
download_url = _process_gdrive_confirmation(url, response.text)
with session.get(
download_url, stream=True, **kwargs
) as download_response:
_assert_status(download_response)
yield (
download_response,
download_response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE),
)
else:
_assert_status(response)
yield (
response,
response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE),
)


@contextlib.contextmanager
Expand All @@ -338,13 +391,6 @@ def _open_with_urllib(
)


def _normalize_drive_url(url: str) -> str:
"""Returns Google Drive url with confirmation token."""
# This bypasses the "Google Drive can't scan this file for viruses" warning
# when dowloading large files.
return url + '&confirm=t'


def _assert_status(response: requests.Response) -> None:
"""Ensure the URL response is 200."""
if response.status_code != 200:
Expand Down
10 changes: 10 additions & 0 deletions tensorflow_datasets/core/download/downloader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Optional
from unittest import mock

import bs4
from etils import epath
import pytest
from tensorflow_datasets import testing
Expand All @@ -36,6 +37,7 @@ def __init__(self, url, content, cookies=None, headers=None, status_code=200):
self.status_code = status_code
# For urllib codepath
self.read = self.raw.read
self.text = ''

def __enter__(self):
return self
Expand Down Expand Up @@ -78,6 +80,14 @@ def setUp(self):
lambda *a, **kw: _FakeResponse(self.url, self.response, self.cookies),
).start()

bs_mock = mock.MagicMock(spec=bs4.BeautifulSoup)
form_mock = mock.MagicMock()
form_mock.get.return_value = 'x'
bs_mock.find.return_value = form_mock
mock.patch.object(
bs4, 'BeautifulSoup', autospec=True, return_value=bs_mock
).start()

def test_ok(self):
promise = self.downloader.download(self.url, self.tmp_dir)
future = promise.get()
Expand Down

0 comments on commit 855b1cd

Please sign in to comment.