From 6afe310319c37f966fc29e8eaa774cfb7076710e Mon Sep 17 00:00:00 2001 From: Andrew Grigorev Date: Tue, 21 May 2019 21:41:59 +0300 Subject: [PATCH 1/2] remote: implement Google Drive --- MANIFEST.in | 1 + dvc/config.py | 11 + dvc/data_cloud.py | 2 + dvc/remote/__init__.py | 2 + dvc/remote/gdrive/__init__.py | 202 +++++++++++++ dvc/remote/gdrive/client.py | 284 ++++++++++++++++++ dvc/remote/gdrive/exceptions.py | 17 ++ dvc/remote/gdrive/google-dvc-client-id.json | 1 + dvc/remote/gdrive/oauth2.py | 112 +++++++ dvc/remote/gdrive/utils.py | 78 +++++ dvc/remote/gdrive/waitable_lock.py | 29 ++ dvc/scheme.py | 1 + scripts/ci/decrypt_gdrive_oauth2.py | 39 +++ .../068b8e92002dd24414a9995a80726a14.enc | Bin 0 -> 496 bytes .../589e2f63a0de57566be6c247074399db.enc | Bin 0 -> 496 bytes scripts/ci/install.sh | 2 + setup.py | 3 + tests/conftest.py | 5 + tests/func/test_data_cloud.py | 44 +++ tests/func/test_gdrive.py | 79 +++++ tests/unit/remote/gdrive/__init__.py | 0 tests/unit/remote/gdrive/conftest.py | 140 +++++++++ tests/unit/remote/gdrive/test_client.py | 165 ++++++++++ tests/unit/remote/gdrive/test_gdrive.py | 108 +++++++ tests/unit/remote/gdrive/test_oauth2.py | 8 + tests/unit/remote/gdrive/test_utils.py | 20 ++ 26 files changed, 1353 insertions(+) create mode 100644 MANIFEST.in create mode 100644 dvc/remote/gdrive/__init__.py create mode 100644 dvc/remote/gdrive/client.py create mode 100644 dvc/remote/gdrive/exceptions.py create mode 100644 dvc/remote/gdrive/google-dvc-client-id.json create mode 100644 dvc/remote/gdrive/oauth2.py create mode 100644 dvc/remote/gdrive/utils.py create mode 100644 dvc/remote/gdrive/waitable_lock.py create mode 100644 scripts/ci/decrypt_gdrive_oauth2.py create mode 100644 scripts/ci/gdrive-oauth2/068b8e92002dd24414a9995a80726a14.enc create mode 100644 scripts/ci/gdrive-oauth2/589e2f63a0de57566be6c247074399db.enc create mode 100644 tests/func/test_gdrive.py create mode 100644 tests/unit/remote/gdrive/__init__.py create mode 100644 tests/unit/remote/gdrive/conftest.py create mode 100644 tests/unit/remote/gdrive/test_client.py create mode 100644 tests/unit/remote/gdrive/test_gdrive.py create mode 100644 tests/unit/remote/gdrive/test_oauth2.py create mode 100644 tests/unit/remote/gdrive/test_utils.py diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000..0f674bfb5e --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include dvc/remote/gdrive/google-dvc-client-id.json diff --git a/dvc/config.py b/dvc/config.py index 10b228b566..6059dea899 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -154,6 +154,8 @@ class Config(object): # pylint: disable=too-many-instance-attributes SECTION_CORE_ANALYTICS_SCHEMA = BOOL_SCHEMA SECTION_CORE_CHECKSUM_JOBS = "checksum_jobs" SECTION_CORE_CHECKSUM_JOBS_SCHEMA = And(Use(int), lambda x: x > 0) + SECTION_CORE_OAUTH2_FLOW_RUNNER = "oauth2_flow_runner" + SECTION_CORE_OAUTH2_FLOW_RUNNER_SCHEMA = Choices("console", "local") SECTION_CACHE = "cache" SECTION_CACHE_DIR = "dir" @@ -195,6 +197,9 @@ class Config(object): # pylint: disable=too-many-instance-attributes Optional( SECTION_CORE_CHECKSUM_JOBS, default=None ): SECTION_CORE_CHECKSUM_JOBS_SCHEMA, + Optional( + SECTION_CORE_OAUTH2_FLOW_RUNNER, default="console" + ): SECTION_CORE_OAUTH2_FLOW_RUNNER_SCHEMA, } # backward compatibility @@ -228,6 +233,10 @@ class Config(object): # pylint: disable=too-many-instance-attributes Optional(SECTION_GCP_PROJECTNAME): str, } + SECTION_GDRIVE_SCOPES = "scopes" + SECTION_GDRIVE_CREDENTIALPATH = SECTION_AWS_CREDENTIALPATH + SECTION_GDRIVE_OAUTH_ID = "oauth_id" + # backward compatibility SECTION_LOCAL = "local" SECTION_LOCAL_STORAGEPATH = SECTION_AWS_STORAGEPATH @@ -259,6 +268,8 @@ class Config(object): # pylint: disable=too-many-instance-attributes Optional(SECTION_AWS_USE_SSL, default=True): BOOL_SCHEMA, Optional(SECTION_AWS_SSE): str, Optional(SECTION_GCP_PROJECTNAME): str, + Optional(SECTION_GDRIVE_SCOPES): str, + Optional(SECTION_GDRIVE_OAUTH_ID, default="default"): str, Optional(SECTION_CACHE_TYPE): SECTION_CACHE_TYPE_SCHEMA, Optional(SECTION_CACHE_PROTECTED, default=False): BOOL_SCHEMA, Optional(SECTION_REMOTE_USER): str, diff --git a/dvc/data_cloud.py b/dvc/data_cloud.py index 3568ae4cbb..3c34223417 100644 --- a/dvc/data_cloud.py +++ b/dvc/data_cloud.py @@ -8,6 +8,7 @@ from dvc.remote import Remote from dvc.remote.s3 import RemoteS3 from dvc.remote.gs import RemoteGS +from dvc.remote.gdrive import RemoteGDrive from dvc.remote.azure import RemoteAZURE from dvc.remote.oss import RemoteOSS from dvc.remote.ssh import RemoteSSH @@ -33,6 +34,7 @@ class DataCloud(object): CLOUD_MAP = { "aws": RemoteS3, "gcp": RemoteGS, + "gdrive": RemoteGDrive, "azure": RemoteAZURE, "oss": RemoteOSS, "ssh": RemoteSSH, diff --git a/dvc/remote/__init__.py b/dvc/remote/__init__.py index e8ffe81f45..f14ed9f5d4 100644 --- a/dvc/remote/__init__.py +++ b/dvc/remote/__init__.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals from dvc.remote.azure import RemoteAZURE +from dvc.remote.gdrive import RemoteGDrive from dvc.remote.gs import RemoteGS from dvc.remote.hdfs import RemoteHDFS from dvc.remote.local import RemoteLOCAL @@ -15,6 +16,7 @@ REMOTES = [ RemoteAZURE, + RemoteGDrive, RemoteGS, RemoteHDFS, RemoteHTTP, diff --git a/dvc/remote/gdrive/__init__.py b/dvc/remote/gdrive/__init__.py new file mode 100644 index 0000000000..91de3f9828 --- /dev/null +++ b/dvc/remote/gdrive/__init__.py @@ -0,0 +1,202 @@ +from __future__ import unicode_literals + +import os +import logging + +try: + import google_auth_oauthlib + from dvc.remote.gdrive.client import GDriveClient +except ImportError: + google_auth_oauthlib = None + +from dvc.scheme import Schemes +from dvc.path_info import CloudURLInfo +from dvc.remote.base import RemoteBASE +from dvc.config import Config +from dvc.remote.gdrive.utils import ( + TrackFileReadProgress, + only_once, + metadata_isdir, + shared_token_warning, +) +from dvc.remote.gdrive.exceptions import GDriveError, GDriveResourceNotFound + + +logger = logging.getLogger(__name__) + + +class GDriveURLInfo(CloudURLInfo): + @property + def netloc(self): + return self.parsed.netloc + + +class RemoteGDrive(RemoteBASE): + """Google Drive remote implementation + + ## Some notes on Google Drive design + + Google Drive differs from S3 and GS remotes - it identifies the resources + by IDs instead of paths. + + Folders are regular resources with an `application/vnd.google-apps.folder` + MIME type. Resource can have multiple parent folders, and also there could + be multiple resources with the same name linked to a single folder, so + files could be duplicated. + + There are multiple root folders accessible from a single user account: + - `root` (special ID) - alias for the "My Drive" folder + - `appDataFolder` (special ID) - alias for the hidden application + space root folder + - shared drives root folders + + ## Example URLs + + - Datasets/my-dataset inside "My Drive" folder: + + gdrive://root/Datasets/my-dataset + + - Folder by ID (recommended): + + gdrive://1r3UbnmS5B4-7YZPZmyqJuCxLVps1mASC + + (get it https://drive.google.com/drive/folders/{here}) + + - Dataset named "my-dataset" in the hidden application folder: + + gdrive://appDataFolder/my-dataset + + (this one wouldn't be visible through Google Drive web UI and + couldn't be shared) + """ + + scheme = Schemes.GDRIVE + path_cls = GDriveURLInfo + REGEX = r"^gdrive://.*$" + REQUIRES = {"google-auth-oauthlib": google_auth_oauthlib} + PARAM_CHECKSUM = "md5Checksum" + SPACE_DRIVE = "drive" + SCOPE_DRIVE = "https://www.googleapis.com/auth/drive" + SPACE_APPDATA = "appDataFolder" + SCOPE_APPDATA = "https://www.googleapis.com/auth/drive.appdata" + DEFAULT_OAUTH_ID = "default" + + # Default credential is needed to show the string of "Data Version + # Control" in OAuth dialog application name and icon in authorized + # applications list in Google account security settings. Also, the + # quota usage is limited by the application defined by client_id. + # The good practice would be to suggest the user to create their + # own application credentials. + DEFAULT_CREDENTIALPATH = os.path.join( + os.path.dirname(__file__), "google-dvc-client-id.json" + ) + + def __init__(self, repo, config): + super(RemoteGDrive, self).__init__(repo, config) + self.path_info = self.path_cls(config[Config.SECTION_REMOTE_URL]) + self.root = self.path_info.netloc.lower() + if self.root == self.SPACE_APPDATA.lower(): + default_scopes = self.SCOPE_APPDATA + space = self.SPACE_APPDATA + else: + default_scopes = self.SCOPE_DRIVE + space = self.SPACE_DRIVE + if Config.SECTION_GDRIVE_CREDENTIALPATH not in config: + shared_token_warning() + credentialpath = config.get( + Config.SECTION_GDRIVE_CREDENTIALPATH, + self.DEFAULT_CREDENTIALPATH, + ) + scopes = config.get(Config.SECTION_GDRIVE_SCOPES, default_scopes) + # scopes should be a list and it is space-delimited in all + # configs, and `.split()` also works for a single-element list + scopes = scopes.split() + + core_config = self.repo.config.config[Config.SECTION_CORE] + oauth2_flow_runner = core_config.get( + Config.SECTION_CORE_OAUTH2_FLOW_RUNNER, "console" + ) + + self.client = GDriveClient( + space, + config.get(Config.SECTION_GDRIVE_OAUTH_ID, self.DEFAULT_OAUTH_ID), + credentialpath, + scopes, + oauth2_flow_runner, + ) + + def get_file_checksum(self, path_info): + metadata = self.client.get_metadata(path_info, fields=["md5Checksum"]) + return metadata["md5Checksum"] + + def exists(self, path_info): + return self.client.exists(path_info) + + def batch_exists(self, path_infos, callback): + results = [] + for path_info in path_infos: + results.append(self.exists(path_info)) + callback.update(str(path_info)) + return results + + def list_cache_paths(self): + try: + root = self.client.get_metadata(self.path_info) + except GDriveResourceNotFound as e: + logger.debug("list_cache_paths: {}".format(e)) + else: + prefix = self.path_info.path + for i in self.client.list_children(root["id"]): + yield prefix + "/" + i + + @only_once + def mkdir(self, parent, name): + return self.client.mkdir(parent, name) + + def makedirs(self, path_info): + parent = path_info.netloc + parts = iter(path_info.path.split("/")) + current_path = ["gdrive://" + path_info.netloc] + for part in parts: + try: + metadata = self.client.get_metadata( + self.path_cls.from_parts( + self.scheme, parent, path="/" + part + ) + ) + except GDriveResourceNotFound: + break + else: + current_path.append(part) + if not metadata_isdir(metadata): + raise GDriveError( + "{} is not a folder".format("/".join(current_path)) + ) + parent = metadata["id"] + to_create = [part] + list(parts) + for part in to_create: + parent = self.mkdir(parent, part)["id"] + return parent + + def _upload(self, from_file, to_info, name, no_progress_bar): + + dirname = to_info.parent.path + if dirname: + try: + parent = self.client.get_metadata(to_info.parent) + except GDriveResourceNotFound: + parent = self.makedirs(to_info.parent) + else: + parent = to_info.netloc + + from_file = open(from_file, "rb") + if not no_progress_bar: + from_file = TrackFileReadProgress(name, from_file) + + try: + self.client.upload(parent, to_info, from_file) + finally: + from_file.close() + + def _download(self, from_info, to_file, name, no_progress_bar): + self.client.download(from_info, to_file, name, no_progress_bar) diff --git a/dvc/remote/gdrive/client.py b/dvc/remote/gdrive/client.py new file mode 100644 index 0000000000..ca88744f8b --- /dev/null +++ b/dvc/remote/gdrive/client.py @@ -0,0 +1,284 @@ +from time import sleep +import logging +import posixpath +import os + +from funcy import cached_property + +from requests import ConnectionError + +from dvc.progress import progress +from dvc.remote.gdrive.utils import ( + metadata_isdir, + response_is_ratelimit, + MIME_GOOGLE_APPS_FOLDER, + response_error_message, +) +from dvc.remote.gdrive.exceptions import ( + GDriveError, + GDriveHTTPError, + GDriveResourceNotFound, +) +from dvc.remote.gdrive.oauth2 import OAuth2 + +logger = logging.getLogger(__name__) + + +class GDriveClient: + + GOOGLEAPIS_BASE_URL = "https://www.googleapis.com/" + TIMEOUT = (5, 60) + + def __init__( + self, + space, + oauth_id, + credentialpath, + scopes, + oauth2_flow_runner, + max_retries=10, + ): + self.space = space + self.oauth_id = oauth_id + self.credentialpath = credentialpath + self.scopes = scopes + self.oauth2_flow_runner = oauth2_flow_runner + self.max_retries = max_retries + self.oauth2 = OAuth2( + oauth_id, credentialpath, scopes, oauth2_flow_runner + ) + + @cached_property + def session(self): + """AuthorizedSession to communicate with https://googleapis.com + + Security notice: + + It always adds the Authorization header to the requests, not paying + attention is request is for googleapis.com or not. It is just how + AuthorizedSession from google-auth implements adding its headers. Don't + use RemoteGDrive.session() to send requests to domains other than + googleapis.com. + """ + return self.oauth2.get_session() + + def request(self, method, path, *args, **kwargs): + # Google Drive has tight rate limits, which strikes the + # performance and gives the 403 and 429 errors. + # See https://developers.google.com/drive/api/v3/handle-errors + retries = 0 + exponential_backoff = 1 + if "timeout" not in kwargs: + kwargs["timeout"] = self.TIMEOUT + while retries < self.max_retries: + retries += 1 + response = self.session.request( + method, self.GOOGLEAPIS_BASE_URL + path, *args, **kwargs + ) + if response_is_ratelimit(response) or response.status_code >= 500: + logger.debug( + "got {} response, will retry in {} sec".format( + response.status_code, exponential_backoff + ) + ) + sleep(exponential_backoff) + exponential_backoff *= 2 + else: + break + if response.status_code >= 400: + raise GDriveHTTPError(response) + return response + + def search(self, parent=None, name=None, add_params={}): + query = [] + if parent is not None: + query.append("'{}' in parents".format(parent)) + if name is not None: + query.append("name = '{}'".format(name)) + params = {"q": " and ".join(query), "spaces": self.space} + params.update(add_params) + while True: + data = self.request("GET", "drive/v3/files", params=params).json() + for i in data["files"]: + yield i + if not data.get("nextPageToken"): + break + params["pageToken"] = data["nextPageToken"] + + def get_metadata(self, path_info, fields=None): + parent = self.request( + "GET", "drive/v3/files/" + path_info.netloc + ).json() + current_path = ["gdrive://" + path_info.netloc] + parts = path_info.path.split("/") + kwargs = [{} for i in parts] + if fields is not None: + # only specify fields for the last search query + kwargs[-1]["add_params"] = { + "fields": "files({})".format(",".join(fields)) + } + for part, kwargs in zip(parts, kwargs): + if not metadata_isdir(parent): + raise GDriveError( + "{} is not a folder".format("/".join(current_path)) + ) + current_path.append(part) + files = list(self.search(parent["id"], part, **kwargs)) + if len(files) > 1: + raise GDriveError( + "path {} is duplicated".format("/".join(current_path)) + ) + elif len(files) == 0: + raise GDriveResourceNotFound("/".join(current_path)) + parent = files[0] + return parent + + def exists(self, path_info): + try: + self.get_metadata(path_info, fields=["id"]) + return True + except GDriveResourceNotFound: + return False + + def list_children(self, folder_id): + for i in self.search(parent=folder_id): + if metadata_isdir(i): + for j in self.list_children(i["id"]): + yield i["name"] + "/" + j + else: + yield i["name"] + + def mkdir(self, parent, name): + data = { + "name": name, + "mimeType": MIME_GOOGLE_APPS_FOLDER, + "parents": [parent], + "spaces": self.space, + } + return self.request("POST", "drive/v3/files", json=data).json() + + def _resumable_upload_initiate(self, parent, filename): + response = self.request( + "POST", + "upload/drive/v3/files", + params={"uploadType": "resumable"}, + json={"name": filename, "space": self.space, "parents": [parent]}, + ) + return response.headers["Location"] + + def _resumable_upload_first_request( + self, resumable_upload_url, from_file, to_info, file_size + ): + try: + # outside of self.request() because this process + # doesn't need it to handle errors and retries, + # they are handled in the next "while" loop + response = self.session.put( + resumable_upload_url, + data=from_file, + headers={"Content-Length": str(file_size)}, + timeout=self.TIMEOUT, + ) + return response.status_code in (200, 201) + except ConnectionError: + return False + + def _resumable_upload_resume( + self, resumable_upload_url, from_file, to_info, file_size + ): + # determine the offset + response = self.session.put( + resumable_upload_url, + headers={ + "Content-Length": str(0), + "Content-Range": "bytes */{}".format(file_size), + }, + timeout=self.TIMEOUT, + ) + if response.status_code in (200, 201): + # file has been already uploaded + return True + elif response.status_code == 404: + # restarting upload from the beginning wouldn't make a + # profit, so it is better to notify the user + raise GDriveError("upload failed, try again") + elif response.status_code != 308: + logger.error( + "upload resume failure: {}".format( + response_error_message(response) + ) + ) + return False + # ^^ response.status_code is 308 (Resume Incomplete) - continue + # the upload + + if "Range" in response.headers: + # if Range header contains a string "bytes 0-9/20" + # then the server has received the bytes from 0 to 9 + # (including the ends), so upload should be resumed from + # byte 10 + offset = int(response.headers["Range"].split("-")[-1]) + 1 + else: + # there could be no Range header in the server response, + # then upload should be resumed from start + offset = 0 + logger.debug( + "resuming {} upload from offset {}".format(to_info, offset) + ) + + # resume the upload + from_file.seek(offset, 0) + response = self.session.put( + resumable_upload_url, + data=from_file, + headers={ + "Content-Length": str(file_size - offset), + "Content-Range": "bytes {}-{}/{}".format( + offset, file_size - 1, file_size + ), + }, + timeout=self.TIMEOUT, + ) + return response.status_code in (200, 201) + + def upload(self, parent, to_info, from_file): + # Resumable upload protocol implementation + # https://developers.google.com/drive/api/v3/manage-uploads#resumable + resumable_upload_url = self._resumable_upload_initiate( + parent, posixpath.basename(to_info.path) + ) + file_size = os.fstat(from_file.fileno()).st_size + success = self._resumable_upload_first_request( + resumable_upload_url, from_file, to_info, file_size + ) + errors_count = 0 + while not success: + try: + success = self._resumable_upload_resume( + resumable_upload_url, from_file, to_info, file_size + ) + except ConnectionError: + errors_count += 1 + if errors_count >= 10: + raise + sleep(1.0) + + def download(self, from_info, to_file, name, no_progress_bar): + metadata = self.get_metadata( + from_info, fields=["id", "mimeType", "size"] + ) + response = self.request( + "GET", + "drive/v3/files/" + metadata["id"], + params={"alt": "media"}, + stream=True, + ) + current = 0 + if response.status_code != 200: + raise GDriveHTTPError(response) + with open(to_file, "wb") as f: + for chunk in response.iter_content(4096): + f.write(chunk) + if not no_progress_bar: + current += len(chunk) + progress.update_target(name, current, metadata["size"]) diff --git a/dvc/remote/gdrive/exceptions.py b/dvc/remote/gdrive/exceptions.py new file mode 100644 index 0000000000..f7721d4276 --- /dev/null +++ b/dvc/remote/gdrive/exceptions.py @@ -0,0 +1,17 @@ +from dvc.remote.gdrive.utils import response_error_message + + +class GDriveError(Exception): + pass + + +class GDriveHTTPError(GDriveError): + def __init__(self, response): + super(GDriveHTTPError, self).__init__(response_error_message(response)) + + +class GDriveResourceNotFound(GDriveError): + def __init__(self, path): + super(GDriveResourceNotFound, self).__init__( + "'{}' resource not found".format(path) + ) diff --git a/dvc/remote/gdrive/google-dvc-client-id.json b/dvc/remote/gdrive/google-dvc-client-id.json new file mode 100644 index 0000000000..8c566e466e --- /dev/null +++ b/dvc/remote/gdrive/google-dvc-client-id.json @@ -0,0 +1 @@ +{"installed":{"client_id":"719861249063-v4an78j9grdtuuuqg3lnm0sugna6v3lh.apps.googleusercontent.com","project_id":"data-version-control","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"2fy_HyzSwkxkGzEken7hThXb","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}} diff --git a/dvc/remote/gdrive/oauth2.py b/dvc/remote/gdrive/oauth2.py new file mode 100644 index 0000000000..f6fa29c1ce --- /dev/null +++ b/dvc/remote/gdrive/oauth2.py @@ -0,0 +1,112 @@ +from hashlib import md5 +import datetime +import json +import os + +from google_auth_oauthlib.flow import InstalledAppFlow +from google.auth.transport.requests import AuthorizedSession +import google.oauth2.credentials + +from dvc.config import Config +from dvc.remote.gdrive.waitable_lock import WaitableLock + + +class OAuth2(object): + + DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S" + + def __init__(self, oauth_id, credentialpath, scopes, flow_runner): + self.oauth_id = oauth_id + self.credentialpath = credentialpath + self.scopes = scopes + self.flow_runner = flow_runner + + def get_session(self): + creds_storage, creds_storage_lock = self._get_storage_lock() + with creds_storage_lock: + if os.path.exists(creds_storage): + creds = self._load_credentials(creds_storage) + else: + creds = self._acquire_credentials() + self._save_credentials(creds_storage, creds) + return AuthorizedSession(creds) + + def _get_creds_id(self, client_id): + plain_text = "|".join([self.oauth_id, client_id] + self.scopes) + hashed = md5(plain_text.encode("ascii")).hexdigest() + return hashed + + def _acquire_credentials(self): + # Create the flow using the client secrets file from the + # Google API Console. + flow = InstalledAppFlow.from_client_secrets_file( + self.credentialpath, scopes=self.scopes + ) + if self.flow_runner == "local": + creds = flow.run_local_server() + elif self.flow_runner == "console": + creds = flow.run_console() + else: + raise ValueError( + "oauth2 flow runner should be 'local' or 'console'" + ) + return creds + + def _load_credentials(self, creds_storage): + """Load credentials from json file and refresh them if needed + + Should be called under lock. + """ + info = json.load(open(creds_storage)) + creds = google.oauth2.credentials.Credentials( + token=info["token"], + refresh_token=info["refresh_token"], + token_uri=info["token_uri"], + client_id=info["client_id"], + client_secret=info["client_secret"], + scopes=self.scopes, + ) + creds.expiry = datetime.datetime.strptime( + info["expiry"], self.DATETIME_FORMAT + ) + if creds.expired: + creds.refresh(google.auth.transport.requests.Request()) + self._save_credentials(creds_storage, creds) + return creds + + def _save_credentials(self, creds_storage, creds): + """Save credentials to the json file + + Should be called under lock. + """ + with open(creds_storage, "w") as f: + info = { + "token": creds.token, + "refresh_token": creds.refresh_token, + "token_uri": creds.token_uri, + "client_id": creds.client_id, + "client_secret": creds.client_secret, + "scopes": creds.scopes, + "expiry": creds.expiry.strftime(self.DATETIME_FORMAT), + } + json.dump(info, f) + + def get_storage_filename(self): + creds_storage_dir = os.path.join( + Config.get_global_config_dir(), "gdrive-oauth2" + ) + if not os.path.exists(creds_storage_dir): + os.makedirs(creds_storage_dir) + info = json.load(open(self.credentialpath)) + creds_id = self._get_creds_id(info["installed"]["client_id"]) + return os.path.join(creds_storage_dir, creds_id) + + def _get_storage_lock(self): + creds_storage = self.get_storage_filename() + # 5 minutes timeout is needed to allow the user to get the + # token when accessing the remote first time + timeout = 5 * 60 + return ( + creds_storage, + WaitableLock(creds_storage + ".lock", timeout=timeout), + ) diff --git a/dvc/remote/gdrive/utils.py b/dvc/remote/gdrive/utils.py new file mode 100644 index 0000000000..516f59f60a --- /dev/null +++ b/dvc/remote/gdrive/utils.py @@ -0,0 +1,78 @@ +import functools +import os +import threading +import logging + +from dvc.progress import progress + + +logger = logging.getLogger(__name__) + + +MIME_GOOGLE_APPS_FOLDER = "application/vnd.google-apps.folder" + + +class TrackFileReadProgress(object): + def __init__(self, progress_name, fobj): + self.progress_name = progress_name + self.fobj = fobj + self.file_size = os.fstat(fobj.fileno()).st_size + + def read(self, size): + progress.update_target( + self.progress_name, self.fobj.tell(), self.file_size + ) + return self.fobj.read(size) + + def __getattr__(self, attr): + return getattr(self.fobj, attr) + + +def only_once(func): + lock = threading.Lock() + locks = {} + results = {} + + @functools.wraps(func) + def wrapped(*args, **kwargs): + key = (args, tuple(kwargs.items())) + # could do with just setdefault, but it would require + # create/delete a "default" Lock() object for each call, so it + # is better to lock a single one for a short time + with lock: + if key not in locks: + locks[key] = threading.Lock() + with locks[key]: + if key not in results: + results[key] = func(*args, **kwargs) + return results[key] + + return wrapped + + +def response_error_message(response): + try: + message = response.json()["error"]["message"] + except (TypeError, KeyError): + message = response.text + return "HTTP {}: {}".format(response.status_code, message) + + +def response_is_ratelimit(response): + if response.status_code not in (403, 429): + return False + errors = response.json()["error"]["errors"] + domains = [i["domain"] for i in errors] + return "usageLimits" in domains + + +def metadata_isdir(metadata): + return metadata["mimeType"] == MIME_GOOGLE_APPS_FOLDER + + +@only_once +def shared_token_warning(): + logger.warning( + "Warning: a shared GoogleAPI token is in use. " + "Please create your own token." + ) diff --git a/dvc/remote/gdrive/waitable_lock.py b/dvc/remote/gdrive/waitable_lock.py new file mode 100644 index 0000000000..6c2efbf709 --- /dev/null +++ b/dvc/remote/gdrive/waitable_lock.py @@ -0,0 +1,29 @@ +from time import time, sleep +from threading import Lock + +import zc.lockfile + + +class WaitableLock(object): + def __init__(self, lock_file, timeout=5): + self.lock_file = lock_file + self.timeout = timeout + self._thread_lock = Lock() + self._lock = None + + def __enter__(self): + t0 = time() + self._thread_lock.acquire() + while time() - t0 < self.timeout: + try: + self._lock = zc.lockfile.LockFile(self.lock_file) + except zc.lockfile.LockError: + sleep(0.5) + else: + return + self._lock = zc.lockfile.LockFile(self.lock_file) + + def __exit__(self, typ, value, tbck): + self._lock.close() + self._lock = None + self._thread_lock.release() diff --git a/dvc/scheme.py b/dvc/scheme.py index e12b768f58..5f7a8d1a28 100644 --- a/dvc/scheme.py +++ b/dvc/scheme.py @@ -9,5 +9,6 @@ class Schemes: HTTP = "http" HTTPS = "https" GS = "gs" + GDRIVE = "gdrive" LOCAL = "local" OSS = "oss" diff --git a/scripts/ci/decrypt_gdrive_oauth2.py b/scripts/ci/decrypt_gdrive_oauth2.py new file mode 100644 index 0000000000..7bd9150e76 --- /dev/null +++ b/scripts/ci/decrypt_gdrive_oauth2.py @@ -0,0 +1,39 @@ +from subprocess import check_call +import os +import sys + +from dvc.config import Config + +OAUTH2_TOKEN_FILE_KEY = os.getenv("OAUTH2_TOKEN_FILE_KEY") +OAUTH2_TOKEN_FILE_IV = os.getenv("OAUTH2_TOKEN_FILE_IV") +if OAUTH2_TOKEN_FILE_KEY is None or OAUTH2_TOKEN_FILE_IV is None: + print("{}:".format(sys.argv[0])) + print("OAUTH2_TOKEN_FILE_KEY or OAUTH2_TOKEN_FILE_IV are not defined.") + print("Skipping decrypt.") + sys.exit(0) + +src = os.path.join("scripts", "ci", "gdrive-oauth2") +dest = os.path.join(Config.get_global_config_dir(), "gdrive-oauth2") +if not os.path.exists(dest): + os.makedirs(dest) + +for enc_filename in os.listdir(src): + filename, ext = os.path.splitext(enc_filename) + if ext != ".enc": + print("Skipping {}".format(enc_filename)) + continue + print("Decrypting {}".format(enc_filename)) + cmd = [ + "openssl", + "aes-256-cbc", + "-d", + "-K", + OAUTH2_TOKEN_FILE_KEY, + "-iv", + OAUTH2_TOKEN_FILE_IV, + "-in", + os.path.join(src, enc_filename), + "-out", + os.path.join(dest, filename), + ] + check_call(cmd) diff --git a/scripts/ci/gdrive-oauth2/068b8e92002dd24414a9995a80726a14.enc b/scripts/ci/gdrive-oauth2/068b8e92002dd24414a9995a80726a14.enc new file mode 100644 index 0000000000000000000000000000000000000000..3a47a817242c55fc62c155bd465182fabf1f2b6a GIT binary patch literal 496 zcmVwBUS14r_Jbh#r1M8Yfl zJR0t-`c|~d>*|XEWVhe#czo-+u{0F&!c}~E!xAuuaTc@!E`+t}w?c{HMxPIJ<-Xg{ zbm6M`BwJ~A>WkIO>Q93LpuB^0nYO<7V$Y&Nw6FQ+ph>gw7U1q6(D7>lssE_L#~s77 zitG=j4{2MU7vS{BB2xEy9Bw62*vGv;qyuGkuAo#fuphY z9sdP>Zb?1RONN!Yh&)C4uGX*s(U1%+2VstyDGAeL?Ah}cO=~vJf`G8HJnJ;x#YceW z!T|j`vkBZ3#NykOHO?LghtR-hzWS_|{RfzAhe13!=jY~$x%_2P&(uzSxT5??sIQjD mV^mcdlLM9gYJQ}cUq~1lwa1kx6sD2ja^fXKgRWeEqP+-3WCL*k literal 0 HcmV?d00001 diff --git a/scripts/ci/gdrive-oauth2/589e2f63a0de57566be6c247074399db.enc b/scripts/ci/gdrive-oauth2/589e2f63a0de57566be6c247074399db.enc new file mode 100644 index 0000000000000000000000000000000000000000..87cf3f17245f37025cd67b408969d4483fb8112f GIT binary patch literal 496 zcmVOU!sb z)36R=yVQ?a?_2bM5FtVmLy)C6>|{lhAz`+JSiHi}&PBufkQ%nykNQ`&&yZHoc>_iy9m3~Y z8QD)xKLgRSmv~ny*{5~E5(O&s>u67<HDpwK*KR@5k{x=}iK$~`YcUYa3r!+Z z$>k6~5&4%+Zs~h9v4|Vi7nQC%1_5MpX&(3i=w<%*HupR zISF`}WfZ(u8}Y=bZTYscvFuii5CDj9%-9(Vp;K#pQ=`HVB`~%V5#}*p+^_F8y;d|% z4g9AHTyWJs)Mx7TC!`M{OczF*ZBTn;=7Y~kH4wvn0CGSzKJ%7EAzlw-aT>%nAf=$D zC5n;`!!ir|s^=ok#6DG5A-5{O4~X@>;yp2X6{o73a+u#O3t(2)Nr=4OSqB_uJt-~W m2`YK@k5$LYTew=>I=L+#w<^o-SXQU|{a^-9oN~q{4bTWSQ0wae literal 0 HcmV?d00001 diff --git a/scripts/ci/install.sh b/scripts/ci/install.sh index cefee81e3a..ef0d6b20de 100644 --- a/scripts/ci/install.sh +++ b/scripts/ci/install.sh @@ -28,4 +28,6 @@ if [[ "$TRAVIS_PULL_REQUEST" == "false" && \ aws configure set region us-east-2 openssl enc -d -aes-256-cbc -md md5 -k $GCP_CREDS -in scripts/ci/gcp-creds.json.enc -out scripts/ci/gcp-creds.json + + python scripts/ci/decrypt_gdrive_oauth2.py fi diff --git a/setup.py b/setup.py index 0e69701bfb..55a104a07c 100644 --- a/setup.py +++ b/setup.py @@ -71,6 +71,7 @@ def run(self): # Extra dependencies for remote integrations gs = ["google-cloud-storage==1.13.0"] +gdrive = ["google-auth-oauthlib==0.3.0"] s3 = ["boto3==1.9.115"] azure = ["azure-storage-blob==2.0.1"] oss = ["oss2==2.6.1"] @@ -98,6 +99,7 @@ def run(self): "xmltodict>=0.11.0", "awscli>=1.16.125", "google-compute-engine", + "google-auth-oauthlib", "pywin32; sys_platform == 'win32'", "Pygments", # required by collective.checkdocs, "collective.checkdocs", @@ -125,6 +127,7 @@ def run(self): extras_require={ "all": all_remotes, "gs": gs, + "gdrive": gdrive, "s3": s3, "azure": azure, "oss": oss, diff --git a/tests/conftest.py b/tests/conftest.py index 8010ff18f0..0b57277d9a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from git.exc import GitCommandNotFound from dvc.remote.config import RemoteConfig +from dvc.remote.gdrive import RemoteGDrive from dvc.utils.compat import cast_bytes_py2 from dvc.remote.ssh.connection import SSHConnection from dvc.repo import Repo as DvcRepo @@ -16,6 +17,10 @@ os.environ[cast_bytes_py2("DVC_TEST")] = cast_bytes_py2("true") +# Make DVC tests use separate OAuth token to access Google Drive +RemoteGDrive.DEFAULT_OAUTH_ID = "test" + + @pytest.fixture(autouse=True) def reset_loglevel(request, caplog): """ diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index a6092c61f8..9a4141c934 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -19,6 +19,7 @@ DataCloud, RemoteS3, RemoteGS, + RemoteGDrive, RemoteAZURE, RemoteOSS, RemoteLOCAL, @@ -69,6 +70,20 @@ def _should_test_aws(): return False +def _should_test_gdrive(): + if os.getenv("DVC_TEST_GDRIVE") == "true": + return True + elif os.getenv("DVC_TEST_GDRIVE") == "false": + return False + oauth_storage = os.path.join( + Config.get_global_config_dir(), + "gdrive-oauth2", + "068b8e92002dd24414a9995a80726a14", + ) + if os.path.exists(oauth_storage): + return True + + def _should_test_gcp(): if os.getenv("DVC_TEST_GCP") == "true": return True @@ -201,6 +216,10 @@ def get_aws_url(): return "s3://" + get_aws_storagepath() +def get_gdrive_url(): + return "gdrive://root/" + str(uuid.uuid4()) + + def get_gcp_storagepath(): return TEST_GCP_REPO_BUCKET + "/" + str(uuid.uuid4()) @@ -234,6 +253,7 @@ def test(self): clist = [ ("s3://mybucket/", RemoteS3), + ("gdrive://root/", RemoteGDrive), ("gs://mybucket/", RemoteGS), ("ssh://user@localhost:/", RemoteSSH), ("http://localhost:8000/", RemoteHTTP), @@ -375,6 +395,17 @@ def _get_cloud_class(self): return RemoteS3 +class TestRemoteGDrive(TestDataCloudBase): + def _should_test(self): + return _should_test_gdrive() + + def _get_url(self): + return get_gdrive_url() + + def _get_cloud_class(self): + return RemoteGDrive + + class TestRemoteGS(TestDataCloudBase): def _should_test(self): return _should_test_gcp() @@ -613,6 +644,19 @@ def _test(self): self._test_cloud(TEST_REMOTE) +class TestRemoteGDriveCLI(TestDataCloudCLIBase): + def _should_test(self): + return _should_test_gdrive() + + def _test(self): + url = get_gdrive_url() + + self.main(["remote", "add", TEST_REMOTE, url]) + self.main(["remote", "modify", TEST_REMOTE, "oauth_id", "test"]) + + self._test_cloud(TEST_REMOTE) + + class TestRemoteGSCLI(TestDataCloudCLIBase): def _should_test(self): return _should_test_gcp() diff --git a/tests/func/test_gdrive.py b/tests/func/test_gdrive.py new file mode 100644 index 0000000000..d12449cc0c --- /dev/null +++ b/tests/func/test_gdrive.py @@ -0,0 +1,79 @@ +from subprocess import check_call +import shutil +import os +import tempfile + +import pytest + +from dvc.main import main +from dvc.remote.gdrive import RemoteGDrive +from dvc.remote.gdrive.client import GDriveClient + + +if os.getenv("DVC_TEST_GDRIVE") != "true": + pytest.skip("Skipping long GDrive tests") + + +client = GDriveClient( + "drive", + "test", + RemoteGDrive.DEFAULT_CREDENTIALPATH, + [RemoteGDrive.SCOPE_DRIVE], + "console", +) +root_id = client.request("GET", "drive/v3/files/root").json()["id"] + + +@pytest.mark.parametrize( + "base_url", + ["gdrive://root/", "gdrive://" + root_id + "/", "gdrive://appDataFolder/"], +) +def test_gdrive_push_pull(repo_dir, dvc_repo, base_url): + + dirname = tempfile.mktemp("", "dvc_test_", "") + url = base_url + dirname + files = [repo_dir.FOO, repo_dir.DATA_SUB.split(os.path.sep)[0]] + + gdrive = RemoteGDrive(dvc_repo, {"url": url}) + + # push files + check_call(["dvc", "add"] + files) + check_call(["dvc", "remote", "add", "gdrive", url]) + check_call(["dvc", "remote", "modify", "gdrive", "oauth_id", "test"]) + assert main(["push", "-r", "gdrive"]) == 0 + + paths = dvc_repo.cache.local.list_cache_paths() + paths = [i.parts[-2:] for i in paths] + + # check that files are correctly uploaded + testdir_meta = gdrive.client.get_metadata(gdrive.path_info) + q = "'{}' in parents".format(testdir_meta["id"]) + found = list(gdrive.client.search(add_params={"q": q})) + assert set(i["name"] for i in found) == set([i[0] for i in paths]) + q = " or ".join("'{}' in parents".format(i["id"]) for i in found) + found = list(gdrive.client.search(add_params={"q": q})) + assert set(i["name"] for i in found) == set(i[1] for i in paths) + + # remove cache and files + shutil.rmtree(".dvc/cache") + for i in files: + if os.path.isdir(i): + shutil.rmtree(i) + else: + os.remove(i) + + # check that they are in list_cache_paths + assert set(gdrive.list_cache_paths()) == { + "/".join([dirname] + list(i)) for i in paths + } + + # pull them back from remote + assert main(["pull", "-r", "gdrive"]) == 0 + + assert set(files) < set(os.listdir(".")) + + # remove the temporary directory on Google Drive + resp = gdrive.client.request( + "DELETE", "drive/v3/files/" + testdir_meta["id"] + ) + print("Delete temp dir: HTTP {}".format(resp.status_code)) diff --git a/tests/unit/remote/gdrive/__init__.py b/tests/unit/remote/gdrive/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/remote/gdrive/conftest.py b/tests/unit/remote/gdrive/conftest.py new file mode 100644 index 0000000000..d83723c18b --- /dev/null +++ b/tests/unit/remote/gdrive/conftest.py @@ -0,0 +1,140 @@ +from datetime import datetime, timedelta +import json +import mock + +from google_auth_oauthlib.flow import InstalledAppFlow +import google.oauth2.credentials + +import pytest + +from dvc.repo import Repo +from dvc.remote.gdrive import RemoteGDrive +from dvc.remote.gdrive.client import GDriveClient +from dvc.remote.gdrive.utils import MIME_GOOGLE_APPS_FOLDER +from dvc.remote.gdrive.oauth2 import OAuth2 + + +AUTHORIZATION = {"authorization": "Bearer MOCK_token"} +FOLDER = {"mimeType": MIME_GOOGLE_APPS_FOLDER} +FILE = {"mimeType": "not-a-folder"} + +COMMON_KWARGS = { + "data": None, + "headers": AUTHORIZATION, + "timeout": GDriveClient.TIMEOUT, +} + + +class Response: + def __init__(self, data, status_code=200): + self._data = data + self.text = json.dumps(data) if isinstance(data, dict) else data + self.status_code = status_code + + def json(self): + return self._data + + +@pytest.fixture() +def repo(): + return Repo(".") + + +@pytest.fixture +def gdrive(repo, client): + ret = RemoteGDrive(repo, {"url": "gdrive://root/data"}) + ret.client = client + return ret + + +@pytest.fixture +def client(): + return GDriveClient( + RemoteGDrive.SPACE_DRIVE, + "test", + RemoteGDrive.DEFAULT_CREDENTIALPATH, + RemoteGDrive.SCOPE_DRIVE, + "console", + ) + + +@pytest.fixture(autouse=True) +def no_requests(monkeypatch): + mocked = mock.Mock(return_value=Response("test")) + monkeypatch.setattr("requests.sessions.Session.request", mocked) + return mocked + + +@pytest.fixture() +def mocked_get_metadata(client, monkeypatch): + mocked = mock.Mock( + client.get_metadata, + return_value=dict(id="root", name="root", **FOLDER), + ) + monkeypatch.setattr(client, "get_metadata", mocked) + return mocked + + +@pytest.fixture() +def mocked_search(client, monkeypatch): + mocked = mock.Mock(client.search) + monkeypatch.setattr(client, "search", mocked) + return mocked + + +def _url(url): + return GDriveClient.GOOGLEAPIS_BASE_URL + url + + +def _p(root, path): + return RemoteGDrive.path_cls.from_parts( + "gdrive", netloc=root, path="/" + path + ) + + +@pytest.fixture(autouse=True) +def fake_creds(monkeypatch): + + creds = google.oauth2.credentials.Credentials( + token="MOCK_token", + refresh_token="MOCK_refresh_token", + token_uri="MOCK_token_uri", + client_id="MOCK_client_id", + client_secret="MOCK_client_secret", + scopes=["MOCK_scopes"], + ) + creds.expiry = datetime.now() + timedelta(days=1) + + mocked_flow = mock.Mock() + mocked_flow.run_console.return_value = creds + mocked_flow.run_local_server.return_value = creds + + monkeypatch.setattr( + InstalledAppFlow, + "from_client_secrets_file", + classmethod(lambda *args, **kwargs: mocked_flow), + ) + + monkeypatch.setattr( + OAuth2, "_get_creds_id", mock.Mock(return_value="test") + ) + + +@pytest.fixture(autouse=True) +def no_refresh(monkeypatch): + expired_mock = mock.PropertyMock(return_value=False) + monkeypatch.setattr( + "google.oauth2.credentials.Credentials.expired", expired_mock + ) + refresh_mock = mock.Mock() + monkeypatch.setattr( + "google.oauth2.credentials.Credentials.refresh", refresh_mock + ) + return refresh_mock, expired_mock + + +@pytest.fixture() +def makedirs(gdrive, monkeypatch): + mocked = mock.Mock(gdrive.makedirs, return_value="FOLDER_ID") + monkeypatch.setattr(gdrive, "makedirs", mocked) + return mocked diff --git a/tests/unit/remote/gdrive/test_client.py b/tests/unit/remote/gdrive/test_client.py new file mode 100644 index 0000000000..dd6d5350e8 --- /dev/null +++ b/tests/unit/remote/gdrive/test_client.py @@ -0,0 +1,165 @@ +import mock + +import requests + +import pytest + +from dvc.remote.gdrive.exceptions import GDriveError, GDriveResourceNotFound + +from tests.unit.remote.gdrive.conftest import ( + COMMON_KWARGS, + FOLDER, + FILE, + Response, + _url, + _p, +) + + +def test_request(client, no_requests): + assert client.request("GET", "test").text == "test" + no_requests.assert_called_once_with("GET", _url("test"), **COMMON_KWARGS) + + +def test_request_refresh(client, no_requests, no_refresh): + refresh_mock, _ = no_refresh + no_requests.side_effect = [ + Response("error", 401), + Response("after_refresh", 200), + ] + assert client.request("GET", "test").text == "after_refresh" + refresh_mock.assert_called_once() + assert no_requests.mock_calls == [ + mock.call("GET", _url("test"), **COMMON_KWARGS), + mock.call("GET", _url("test"), **COMMON_KWARGS), + ] + + +def test_request_expired(client, no_requests, no_refresh): + refresh_mock, expired_mock = no_refresh + expired_mock.side_effect = [True, False] + no_requests.side_effect = [Response("test", 200)] + assert client.request("GET", "test").text == "test" + expired_mock.assert_called() + refresh_mock.assert_called_once() + assert no_requests.mock_calls == [ + mock.call("GET", _url("test"), **COMMON_KWARGS) + ] + + +def test_request_retry_and_backoff(client, no_requests, monkeypatch): + no_requests.side_effect = [ + Response("error", 500), + Response("error", 500), + Response("retry", 200), + ] + sleep_mock = mock.Mock() + monkeypatch.setattr("dvc.remote.gdrive.client.sleep", sleep_mock) + assert client.request("GET", "test").text == "retry" + assert no_requests.mock_calls == [ + mock.call("GET", _url("test"), **COMMON_KWARGS), + mock.call("GET", _url("test"), **COMMON_KWARGS), + mock.call("GET", _url("test"), **COMMON_KWARGS), + ] + assert sleep_mock.mock_calls == [mock.call(1), mock.call(2)] + + +def test_request_4xx(client, no_requests): + no_requests.return_value = Response("error", 400) + with pytest.raises(GDriveError): + client.request("GET", "test") + + +def test_search(client, no_requests): + no_requests.side_effect = [ + Response({"files": ["test1"], "nextPageToken": "TEST_nextPageToken"}), + Response({"files": ["test2"]}), + ] + assert list(client.search("test", "root")) == ["test1", "test2"] + + +def test_get_metadata(client, no_requests): + no_requests.side_effect = [ + Response(dict(id="root", name="root", **FOLDER)), + Response({"files": [dict(id="id1", name="path1", **FOLDER)]}), + Response({"files": [dict(id="id2", name="path2", **FOLDER)]}), + ] + client.get_metadata(_p("root", "path1/path2"), ["field1", "field2"]) + assert no_requests.mock_calls == [ + mock.call("GET", _url("drive/v3/files/root"), **COMMON_KWARGS), + mock.call( + "GET", + _url("drive/v3/files"), + params={ + "q": "'root' in parents and name = 'path1'", + "spaces": "drive", + }, + **COMMON_KWARGS + ), + mock.call( + "GET", + _url("drive/v3/files"), + params={ + "q": "'id1' in parents and name = 'path2'", + "spaces": "drive", + "fields": "files(field1,field2)", + }, + **COMMON_KWARGS + ), + ] + + +def test_get_metadata_not_a_folder(client, no_requests, mocked_search): + no_requests.return_value = Response(dict(id="id1", name="root", **FOLDER)) + mocked_search.return_value = [dict(id="id2", name="path1", **FILE)] + with pytest.raises(GDriveError): + client.get_metadata(_p("root", "path1/path2"), ["field1", "field2"]) + client.get_metadata(_p("root", "path1"), ["field1", "field2"]) + + +def test_get_metadata_duplicate(client, no_requests, mocked_search): + no_requests.return_value = Response(dict(id="id1", name="root", **FOLDER)) + mocked_search.return_value = [ + dict(id="id2", name="path1", **FOLDER), + dict(id="id3", name="path1", **FOLDER), + ] + with pytest.raises(GDriveError): + client.get_metadata(_p("root", "path1/path2"), ["field1", "field2"]) + + +def test_get_metadata_not_found(client, no_requests, mocked_search): + no_requests.return_value = Response(dict(id="root", name="root", **FOLDER)) + mocked_search.return_value = [] + with pytest.raises(GDriveResourceNotFound): + client.get_metadata(_p("root", "path1/path2"), ["field1", "field2"]) + + +def test_resumable_upload_first_request(client, no_requests): + resp = Response("", 201) + no_requests.return_value = resp + from_file = mock.Mock() + to_info = mock.Mock() + assert ( + client._resumable_upload_first_request("url", from_file, to_info, 100) + is True + ) + + +def test_resumable_upload_first_request_connection_error(client, no_requests): + no_requests.side_effect = requests.ConnectionError + from_file = mock.Mock() + to_info = mock.Mock() + assert ( + client._resumable_upload_first_request("url", from_file, to_info, 100) + is False + ) + + +def test_resumable_upload_first_request_failure(client, no_requests): + no_requests.return_value = Response("", 400) + from_file = mock.Mock() + to_info = mock.Mock() + assert ( + client._resumable_upload_first_request("url", from_file, to_info, 100) + is False + ) diff --git a/tests/unit/remote/gdrive/test_gdrive.py b/tests/unit/remote/gdrive/test_gdrive.py new file mode 100644 index 0000000000..cc51070821 --- /dev/null +++ b/tests/unit/remote/gdrive/test_gdrive.py @@ -0,0 +1,108 @@ +import mock + +import pytest + +from dvc.remote.gdrive import RemoteGDrive, GDriveError, GDriveResourceNotFound + +from tests.unit.remote.gdrive.conftest import ( + Response, + FOLDER, + FILE, + COMMON_KWARGS, + _p, + _url, +) + + +def test_init_drive(repo): + url = "gdrive://root/data" + gdrive = RemoteGDrive(repo, {"url": url}) + assert gdrive.root == "root" + assert str(gdrive.path_info) == url + assert gdrive.client.scopes == ["https://www.googleapis.com/auth/drive"] + assert gdrive.client.space == RemoteGDrive.SPACE_DRIVE + + +def test_init_appfolder(repo): + url = "gdrive://appdatafolder/data" + gdrive = RemoteGDrive(repo, {"url": url}) + assert gdrive.root == "appdatafolder" + assert str(gdrive.path_info) == url + assert gdrive.client.scopes == [ + "https://www.googleapis.com/auth/drive.appdata" + ] + assert gdrive.client.space == RemoteGDrive.SPACE_APPDATA + + +def test_init_folder_id(repo): + url = "gdrive://folder_id/data" + gdrive = RemoteGDrive(repo, {"url": url}) + assert gdrive.root == "folder_id" + assert str(gdrive.path_info) == url + assert gdrive.client.scopes == ["https://www.googleapis.com/auth/drive"] + assert gdrive.client.space == "drive" + + +def test_get_file_checksum(gdrive, mocked_get_metadata): + mocked_get_metadata.return_value = dict( + id="id1", name="path1", md5Checksum="checksum" + ) + checksum = gdrive.get_file_checksum(_p(gdrive.root, "path1")) + assert checksum == "checksum" + mocked_get_metadata.assert_called_once_with( + _p(gdrive.root, "path1"), fields=["md5Checksum"] + ) + + +def test_list_cache_paths(gdrive, mocked_get_metadata, mocked_search): + mocked_get_metadata.return_value = dict(id="root", name="root", **FOLDER) + mocked_search.side_effect = [ + [dict(id="f1", name="f1", **FOLDER), dict(id="f2", name="f2", **FILE)], + [dict(id="f3", name="f3", **FILE)], + ] + assert list(gdrive.list_cache_paths()) == ["data/f1/f3", "data/f2"] + mocked_get_metadata.assert_called_once_with(_p("root", "data")) + + +def test_list_cache_path_not_found(gdrive, mocked_get_metadata): + mocked_get_metadata.side_effect = GDriveResourceNotFound("test") + assert list(gdrive.list_cache_paths()) == [] + mocked_get_metadata.assert_called_once_with(_p("root", "data")) + + +def test_mkdir(gdrive, no_requests): + no_requests.return_value = Response("test") + assert gdrive.mkdir("root", "test") == "test" + no_requests.assert_called_once_with( + "POST", + _url("drive/v3/files"), + json={ + "name": "test", + "mimeType": FOLDER["mimeType"], + "parents": ["root"], + "spaces": "drive", + }, + **COMMON_KWARGS + ) + + +def test_makedirs(gdrive, monkeypatch, mocked_get_metadata): + mocked_get_metadata.side_effect = [ + dict(id="id1", name="test1", **FOLDER), + GDriveResourceNotFound("test1/test2"), + ] + monkeypatch.setattr( + gdrive, "mkdir", mock.Mock(side_effect=[{"id": "id2"}]) + ) + assert gdrive.makedirs(_p(gdrive.root, "test1/test2")) == "id2" + assert mocked_get_metadata.mock_calls == [ + mock.call(_p(gdrive.root, "test1")), + mock.call(_p("id1", "test2")), + ] + assert gdrive.mkdir.mock_calls == [mock.call("id1", "test2")] + + +def test_makedirs_error(gdrive, mocked_get_metadata): + mocked_get_metadata.side_effect = [dict(id="id1", name="test1", **FILE)] + with pytest.raises(GDriveError): + gdrive.makedirs(_p(gdrive.root, "test1/test2")) diff --git a/tests/unit/remote/gdrive/test_oauth2.py b/tests/unit/remote/gdrive/test_oauth2.py new file mode 100644 index 0000000000..5381ba0b16 --- /dev/null +++ b/tests/unit/remote/gdrive/test_oauth2.py @@ -0,0 +1,8 @@ +from tests.unit.remote.gdrive.conftest import AUTHORIZATION + + +def test_get_session(gdrive, no_requests): + session = gdrive.client.oauth2.get_session() + session.get("https://googleapis.com") + args, kwargs = no_requests.call_args + assert kwargs["headers"]["authorization"] == AUTHORIZATION["authorization"] diff --git a/tests/unit/remote/gdrive/test_utils.py b/tests/unit/remote/gdrive/test_utils.py new file mode 100644 index 0000000000..3457c03de0 --- /dev/null +++ b/tests/unit/remote/gdrive/test_utils.py @@ -0,0 +1,20 @@ +from dvc.remote.gdrive.utils import ( + response_error_message, + response_is_ratelimit, +) + +from tests.unit.remote.gdrive.conftest import Response + + +def test_response_is_ratelimit(gdrive): + assert response_is_ratelimit( + Response({"error": {"errors": [{"domain": "usageLimits"}]}}, 403) + ) + assert not response_is_ratelimit(Response("")) + + +def test_response_error_message(gdrive): + r = Response({"error": {"message": "test"}}) + assert response_error_message(r) == "HTTP 200: test" + r = Response("test") + assert response_error_message(r) == "HTTP 200: test" From 0c455eb7276cc51fdc36e37cc10f0ab31b16b4aa Mon Sep 17 00:00:00 2001 From: Andrew Grigorev Date: Mon, 29 Jul 2019 23:47:56 +0300 Subject: [PATCH 2/2] tests: fix gdrive module level pytest.skip --- tests/func/test_gdrive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/func/test_gdrive.py b/tests/func/test_gdrive.py index d12449cc0c..7fff7d31f8 100644 --- a/tests/func/test_gdrive.py +++ b/tests/func/test_gdrive.py @@ -11,7 +11,7 @@ if os.getenv("DVC_TEST_GDRIVE") != "true": - pytest.skip("Skipping long GDrive tests") + pytest.skip("Skipping long GDrive tests", allow_module_level=True) client = GDriveClient(