From 620bb4e6859a2839a32f7dd59284c34a152d7174 Mon Sep 17 00:00:00 2001 From: Andrew Grigorev Date: Tue, 21 May 2019 21:41:59 +0300 Subject: [PATCH] remote: implement Google Drive --- .appveyor.yml | 4 + MANIFEST.in | 1 + dvc/config.py | 9 + dvc/data_cloud.py | 2 + dvc/path/gdrive.py | 10 + dvc/remote/__init__.py | 10 +- dvc/remote/base.py | 2 +- dvc/remote/gdrive/__init__.py | 552 ++++++++++++++++++ dvc/remote/gdrive/google-dvc-client-id.json | 1 + dvc/remote/gdrive/oauth2.py | 105 ++++ dvc/scheme.py | 1 + scripts/ci/before_install.sh | 10 + .../160f0daf067a7511bd8b8e3b35e3a8bd.enc | Bin 0 -> 496 bytes .../81562e04895c64d8c65a45710a0a8a6b.enc | Bin 0 -> 496 bytes setup.py | 3 + tests/func/test_gdrive.py | 128 ++++ tests/unit/remote/test_gdrive.py | 327 +++++++++++ 17 files changed, 1160 insertions(+), 5 deletions(-) create mode 100644 MANIFEST.in create mode 100644 dvc/path/gdrive.py create mode 100644 dvc/remote/gdrive/__init__.py create mode 100644 dvc/remote/gdrive/google-dvc-client-id.json create mode 100644 dvc/remote/gdrive/oauth2.py create mode 100644 scripts/ci/gdrive-oauth2/160f0daf067a7511bd8b8e3b35e3a8bd.enc create mode 100644 scripts/ci/gdrive-oauth2/81562e04895c64d8c65a45710a0a8a6b.enc create mode 100644 tests/func/test_gdrive.py create mode 100644 tests/unit/remote/test_gdrive.py diff --git a/.appveyor.yml b/.appveyor.yml index c7c530d3f1..7368998767 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -18,6 +18,10 @@ environment: secure: 96fJ3r2i2GohbXHwnSs5N4EplQ7q8YmLpPWM0AC+f4s= CODECOV_TOKEN: secure: XN4jRtmGE5Bqg8pPZkwNs7kn3UEI73Rkldqc0MGsQISZBm5TNJZOPofDMc1QnUsf + OAUTH2_TOKEN_FILE_KEY: + secure: cL2KgINnnWhfNVy+lEI/QmA31cKwYojGhFRFs1x0PAkA5QRvZxYTlBX+XpjQYF8k2A6tqYMUWztZt0dXlMKqLVf5aCtg7z2I5AWW5dhGeTY= + OAUTH2_TOKEN_FILE_IV: + secure: 6lKZF5KNCpP80lFZ7a8yFkvcwllt2eoEC7LpSS5Mlwg4ilFsjFVSm+zw2GqA7NKw AZURE_STORAGE_CONTAINER_NAME: appveyor-tests AZURE_STORAGE_CONNECTION_STRING: DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1; OSS_ENDPOINT: localhost:50004 diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000..0f674bfb5e --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include dvc/remote/gdrive/google-dvc-client-id.json diff --git a/dvc/config.py b/dvc/config.py index 40e705e27e..0be31b1856 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -221,6 +221,9 @@ class Config(object): # pylint: disable=too-many-instance-attributes Optional(SECTION_GCP_PROJECTNAME): str, } + SECTION_GDRIVE_SCOPES = "gdrive_scopes" + SECTION_GDRIVE_CREDENTIALPATH = SECTION_AWS_CREDENTIALPATH + # backward compatibility SECTION_LOCAL = "local" SECTION_LOCAL_STORAGEPATH = SECTION_AWS_STORAGEPATH @@ -250,6 +253,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes Optional(SECTION_AWS_LIST_OBJECTS, default=False): BOOL_SCHEMA, Optional(SECTION_AWS_USE_SSL, default=True): BOOL_SCHEMA, Optional(SECTION_GCP_PROJECTNAME): str, + Optional(SECTION_GDRIVE_SCOPES): str, Optional(SECTION_CACHE_TYPE): SECTION_CACHE_TYPE_SCHEMA, Optional(SECTION_CACHE_PROTECTED, default=False): BOOL_SCHEMA, Optional(SECTION_REMOTE_USER): str, @@ -273,11 +277,16 @@ class Config(object): # pylint: disable=too-many-instance-attributes Optional(SECTION_STATE_ROW_CLEANUP_QUOTA): And(Use(int), is_percent), } + SECTION_OAUTH2 = "oauth2" + SECTION_OAUTH2_FLOW_RUNNER = "flow_runner" + SECTION_OAUTH2_SCHEMA = 
{Optional(SECTION_OAUTH2_FLOW_RUNNER): str} + SCHEMA = { Optional(SECTION_CORE, default={}): SECTION_CORE_SCHEMA, Optional(Regex(SECTION_REMOTE_REGEX)): SECTION_REMOTE_SCHEMA, Optional(SECTION_CACHE, default={}): SECTION_CACHE_SCHEMA, Optional(SECTION_STATE, default={}): SECTION_STATE_SCHEMA, + Optional(SECTION_OAUTH2, default={}): SECTION_OAUTH2_SCHEMA, # backward compatibility Optional(SECTION_AWS, default={}): SECTION_AWS_SCHEMA, Optional(SECTION_GCP, default={}): SECTION_GCP_SCHEMA, diff --git a/dvc/data_cloud.py b/dvc/data_cloud.py index 5802826af3..c518fc6ee7 100644 --- a/dvc/data_cloud.py +++ b/dvc/data_cloud.py @@ -8,6 +8,7 @@ from dvc.remote import Remote from dvc.remote.s3 import RemoteS3 from dvc.remote.gs import RemoteGS +from dvc.remote.gdrive import RemoteGDrive from dvc.remote.azure import RemoteAZURE from dvc.remote.oss import RemoteOSS from dvc.remote.ssh import RemoteSSH @@ -34,6 +35,7 @@ class DataCloud(object): CLOUD_MAP = { "aws": RemoteS3, "gcp": RemoteGS, + "gdrive": RemoteGDrive, "azure": RemoteAZURE, "oss": RemoteOSS, "ssh": RemoteSSH, diff --git a/dvc/path/gdrive.py b/dvc/path/gdrive.py new file mode 100644 index 0000000000..2e89d36419 --- /dev/null +++ b/dvc/path/gdrive.py @@ -0,0 +1,10 @@ +from dvc.scheme import Schemes +from .base import PathBASE + + +class PathGDrive(PathBASE): + scheme = Schemes.GDRIVE + + def __init__(self, root, url=None, path=None): + super(PathGDrive, self).__init__(url, path) + self.root = root diff --git a/dvc/remote/__init__.py b/dvc/remote/__init__.py index 4eb124199a..f6ebc2ae65 100644 --- a/dvc/remote/__init__.py +++ b/dvc/remote/__init__.py @@ -1,25 +1,27 @@ from __future__ import unicode_literals from dvc.remote.azure import RemoteAZURE +from dvc.remote.gdrive import RemoteGDrive from dvc.remote.gs import RemoteGS from dvc.remote.hdfs import RemoteHDFS -from dvc.remote.local import RemoteLOCAL -from dvc.remote.s3 import RemoteS3 -from dvc.remote.ssh import RemoteSSH from dvc.remote.http import RemoteHTTP from dvc.remote.https import RemoteHTTPS +from dvc.remote.local import RemoteLOCAL from dvc.remote.oss import RemoteOSS +from dvc.remote.s3 import RemoteS3 +from dvc.remote.ssh import RemoteSSH REMOTES = [ RemoteAZURE, + RemoteGDrive, RemoteGS, RemoteHDFS, RemoteHTTP, RemoteHTTPS, + RemoteOSS, RemoteS3, RemoteSSH, - RemoteOSS, # NOTE: RemoteLOCAL is the default ] diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 30f13a8f47..36ca2b61f2 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -415,7 +415,7 @@ def download( from_infos, to_infos, no_progress_bar=False, - name=None, + names=None, resume=False, ): raise RemoteActionNotImplemented("download", self.scheme) diff --git a/dvc/remote/gdrive/__init__.py b/dvc/remote/gdrive/__init__.py new file mode 100644 index 0000000000..fd877d4feb --- /dev/null +++ b/dvc/remote/gdrive/__init__.py @@ -0,0 +1,552 @@ +from __future__ import unicode_literals + +from time import sleep +import threading +import functools +import posixpath +import os +import logging + +try: + import google_auth_oauthlib + from .oauth2 import OAuth2 +except ImportError: + google_auth_oauthlib = None + +from requests import ConnectionError + +from dvc.scheme import Schemes +from dvc.path.gdrive import PathGDrive +from dvc.utils import tmp_fname, move +from dvc.utils.compat import urlparse, makedirs +from dvc.remote.base import RemoteBASE +from dvc.config import Config +from dvc.progress import progress + + +logger = logging.getLogger(__name__) + + +class GDriveError(Exception): + pass + + +class 
GDriveResourceNotFound(GDriveError):
+    def __init__(self, path):
+        super(GDriveResourceNotFound, self).__init__(
+            "'{}' resource not found".format(path)
+        )
+
+
+class track_progress(object):
+    def __init__(self, progress_name, fobj):
+        self.progress_name = progress_name
+        self.fobj = fobj
+        self.file_size = os.fstat(fobj.fileno()).st_size
+
+    def read(self, size):
+        progress.update_target(
+            self.progress_name, self.fobj.tell(), self.file_size
+        )
+        return self.fobj.read(size)
+
+    def __getattr__(self, attr):
+        return getattr(self.fobj, attr)
+
+
+def only_once(func):
+    lock = threading.Lock()
+    locks = {}
+    results = {}
+
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        key = (args, tuple(kwargs.items()))
+        # setdefault would also work, but it would require creating
+        # and deleting a "default" Lock() object on every call, so it
+        # is cheaper to hold a single lock for a short time instead
+        with lock:
+            if key not in locks:
+                locks[key] = threading.Lock()
+        with locks[key]:
+            if key not in results:
+                results[key] = func(*args, **kwargs)
+        return results[key]
+
+    return wrapped
+
+
+class RemoteGDrive(RemoteBASE):
+    """Google Drive remote implementation.
+
+    Example URLs:
+
+    Datasets/my-dataset inside the "My Drive" folder:
+
+        gdrive://root/Datasets/my-dataset
+
+    Folder by ID (recommended):
+
+        gdrive://1r3UbnmS5B4-7YZPZmyqJuCxLVps1mASC
+
+        (get it from https://drive.google.com/drive/folders/{here})
+
+    Dataset named "my-dataset" in the hidden application folder:
+
+        gdrive://appDataFolder/my-dataset
+
+        (this folder is not visible through the Google Drive web UI
+        and cannot be shared)
+    """
+
+    scheme = Schemes.GDRIVE
+    REGEX = r"^gdrive://.*$"
+    REQUIRES = {"google-auth-oauthlib": google_auth_oauthlib}
+    PARAM_CHECKSUM = "md5Checksum"
+    GOOGLEAPIS_BASE_URL = "https://www.googleapis.com/"
+    MIME_GOOGLE_APPS_FOLDER = "application/vnd.google-apps.folder"
+    SPACE_DRIVE = "drive"
+    SCOPE_DRIVE = "https://www.googleapis.com/auth/drive"
+    SPACE_APPDATA = "appDataFolder"
+    SCOPE_APPDATA = "https://www.googleapis.com/auth/drive.appdata"
+    TIMEOUT = (5, 60)
+
+    # The default credentials are needed to show "Data Version Control"
+    # as the application name in the OAuth dialog and in the authorized
+    # applications list in the Google account security settings. Also,
+    # quota usage is accounted to the application defined by client_id.
+    # Good practice would be to suggest that users create their own
+    # application credentials.
+    DEFAULT_CREDENTIALPATH = os.path.join(
+        os.path.dirname(__file__), "google-dvc-client-id.json"
+    )
+
+    def __init__(self, repo, config):
+
+        super(RemoteGDrive, self).__init__(repo, config)
+
+        self.url = config[Config.SECTION_REMOTE_URL].rstrip("/")
+
+        parsed = urlparse(self.url)
+
+        self.root = parsed.netloc
+
+        if self.root == self.SPACE_APPDATA:
+            default_scopes = self.SCOPE_APPDATA
+            self.space = self.SPACE_APPDATA
+        else:
+            default_scopes = self.SCOPE_DRIVE
+            self.space = self.SPACE_DRIVE
+
+        credentialpath = config.get(
+            Config.SECTION_GDRIVE_CREDENTIALPATH, self.DEFAULT_CREDENTIALPATH
+        )
+        scopes = config.get(Config.SECTION_GDRIVE_SCOPES, default_scopes)
+        # scopes must be a list; they are space-delimited in the
+        # config, and `.split()` also handles the single-scope case
+        scopes = scopes.split()
+
+        self.oauth2 = OAuth2(
+            credentialpath,
+            scopes,
+            self.repo.config.config[Config.SECTION_OAUTH2],
+        )
+
+        self.prefix = parsed.path.strip("/")
+
+        self.max_retries = 10
+
+    @property
+    def path_info(self):
+        return PathGDrive(root=self.root)
+
+    @property
+    def session(self):
+        if not hasattr(self, "_session"):
+            self._session = self.oauth2.get_session()
+        return self._session
+
+    def response_is_ratelimit(self, response):
+        if response.status_code not in (403, 429):
+            return False
+        errors = response.json()["error"]["errors"]
+        domains = [i["domain"] for i in errors]
+        return "usageLimits" in domains
+
+    def response_error_message(self, response):
+        try:
+            message = response.json()["error"]["message"]
+        except Exception:
+            message = response.text
+        return "HTTP {}: {}".format(response.status_code, message)
+
+    def request(self, method, path, *args, **kwargs):
+        # Google Drive has tight rate limits, which hurt performance
+        # and cause 403 and 429 errors.
+        # See https://developers.google.com/drive/api/v3/handle-errors
+        retries = 0
+        exponential_backoff = 1
+        if "timeout" not in kwargs:
+            kwargs["timeout"] = self.TIMEOUT
+        while retries < self.max_retries:
+            retries += 1
+            response = self.session.request(
+                method, self.GOOGLEAPIS_BASE_URL + path, *args, **kwargs
+            )
+            if response.status_code == 401:
+                # try to renew the access token
+                self._session = self.oauth2.get_session()
+            elif (
+                self.response_is_ratelimit(response)
+                or response.status_code >= 500
+            ):
+                logger.debug(
+                    "got {} response, will retry in {} sec".format(
+                        response.status_code, exponential_backoff
+                    )
+                )
+                sleep(exponential_backoff)
+                exponential_backoff *= 2
+            else:
+                break
+        if response.status_code >= 400:
+            raise GDriveError(self.response_error_message(response))
+        return response
+
+    def get_metadata_by_id(self, file_id, **kwargs):
+        return self.request(
+            "GET", "drive/v3/files/" + file_id, **kwargs
+        ).json()
+
+    def search(self, parent=None, name=None, add_params={}):
+        query = []
+        if parent is not None:
+            query.append("'{}' in parents".format(parent))
+        if name is not None:
+            query.append("name = '{}'".format(name))
+        params = {"q": " and ".join(query), "spaces": self.space}
+        params.update(add_params)
+        while True:
+            data = self.request("GET", "drive/v3/files", params=params).json()
+            for i in data["files"]:
+                yield i
+            if not data.get("nextPageToken"):
+                break
+            params["pageToken"] = data["nextPageToken"]
+
+    def get_metadata_by_path(self, root, path, fields=[]):
+        parent = self.get_metadata_by_id(root)
+        current_path = ["gdrive://" + parent["id"]]
+        parts = path.split("/")
+        # only specify fields for the last search query
+        kwargs = [{}] * (len(parts) - 1) + [
+            {"add_params": {"fields": "files({})".format(",".join(fields))}}
+            if fields
+            else {}
+        ]
+        for part, kwargs in zip(parts, kwargs):
+            if not self.metadata_isdir(parent):
+                raise GDriveError(
+                    "{} is not a folder".format("/".join(current_path))
+                )
+            current_path.append(part)
+            files = list(self.search(parent["id"], part, **kwargs))
+            if len(files) > 1:
+                raise GDriveError(
+                    "path {} is duplicated".format("/".join(current_path))
+                )
+            elif len(files) == 0:
+                raise GDriveResourceNotFound("/".join(current_path))
+            parent = files[0]
+        return parent
+
+    def metadata_isdir(self, metadata):
+        return metadata["mimeType"] == self.MIME_GOOGLE_APPS_FOLDER
+
+    def get_file_checksum(self, path_info):
+        metadata = self.get_metadata_by_path(
+            path_info.root, path_info.path, fields=["md5Checksum"]
+        )
+        return metadata["md5Checksum"]
+
+    def _list_files(self, folder_id):
+        for i in self.search(parent=folder_id):
+            if self.metadata_isdir(i):
+                for j in self._list_files(i["id"]):
+                    yield i["name"] + "/" + j
+            else:
+                yield i["name"]
+
+    def list_cache_paths(self):
+        try:
+            root = self.get_metadata_by_path(self.root, self.prefix)
+        except GDriveResourceNotFound as e:
+            logger.debug("list_cache_paths: {}".format(e))
+        else:
+            for i in self._list_files(root["id"]):
+                yield self.prefix + "/" + i
+
+    @only_once
+    def mkdir(self, parent, name):
+        data = {
+            "name": name,
+            "mimeType": self.MIME_GOOGLE_APPS_FOLDER,
+            "parents": [parent],
+            "spaces": self.space,
+        }
+        return self.request("POST", "drive/v3/files", json=data).json()
+
+    @only_once
+    def makedirs(self, parent, path):
+        current_path = []
+        for part in path.split("/"):
+            current_path.append(part)
+            try:
+                metadata = self.get_metadata_by_path(parent, part)
+                if not self.metadata_isdir(metadata):
+                    raise GDriveError(
+                        "{} is not a folder".format("/".join(current_path))
+                    )
+            except GDriveResourceNotFound:
+                metadata = self.mkdir(parent, part)
+            parent = metadata["id"]
+        return parent
+
+    def _resumable_upload_initiate(self, parent, filename):
+        response = self.request(
+            "POST",
+            "upload/drive/v3/files",
+            params={"uploadType": "resumable"},
+            json={"name": filename, "space": self.space, "parents": [parent]},
+        )
+        return response.headers["Location"]
+
+    def _resumable_upload_first_request(
+        self, resumable_upload_url, from_file, to_info, file_size
+    ):
+        try:
+            # this request is made outside of self.request(), because
+            # errors and retries for it are handled by the "while"
+            # loop that calls _resumable_upload_resume()
+            response = self.session.put(
+                resumable_upload_url,
+                data=from_file,
+                headers={"Content-Length": str(file_size)},
+                timeout=self.TIMEOUT,
+            )
+            return response.status_code in (200, 201)
+        # XXX: which exceptions should be handled here?
+        except ConnectionError:
+            logger.info(
+                "got connection error while uploading '{}/{}', "
+                "will resume".format(self.url, to_info.path),
+                exc_info=True,
+            )
+            return False
+
+    def _resumable_upload_resume(
+        self, resumable_upload_url, from_file, to_info, file_size
+    ):
+        try:
+            # determine the offset
+            response = self.session.put(
+                resumable_upload_url,
+                headers={
+                    "Content-Length": str(0),
+                    "Content-Range": "bytes */{}".format(file_size),
+                },
+                timeout=self.TIMEOUT,
+            )
+            if response.status_code in (200, 201):
+                # the file has already been uploaded
+                return True
+            elif response.status_code == 404:
+                # restarting the upload from the beginning wouldn't
+                # help here, so it is better to notify the user
+                raise GDriveError("resumable upload URL has expired")
+            elif response.status_code != 308:
+                logger.error(
+                    "upload resume failure: {}".format(
+                        self.response_error_message(response)
+                    )
+                )
+                return False
+            # response.status_code is 308 (Resume Incomplete) - continue
+            if "Range" in response.headers:
+                # if the Range header contains a string "bytes=0-9",
+                # the server has received bytes 0 to 9 (inclusive),
+                # so the upload should be resumed from byte 10
+                offset = int(response.headers["Range"].split("-")[-1]) + 1
+            else:
+                # there could be no Range header in the server
+                # response, in which case the upload should be resumed
+                # from the start
+                offset = 0
+            logger.debug(
+                "resuming {} upload from offset {}".format(
+                    to_info.path, offset
+                )
+            )
+            # resume the upload
+            from_file.seek(offset, 0)
+            response = self.session.put(
+                resumable_upload_url,
+                data=from_file,
+                headers={
+                    "Content-Length": str(file_size - offset),
+                    "Content-Range": "bytes {}-{}/{}".format(
+                        offset, file_size - 1, file_size
+                    ),
+                },
+                timeout=self.TIMEOUT,
+            )
+            return response.status_code in (200, 201)
+        except ConnectionError:
+            # don't overload the CPU on persistent network failure
+            sleep(1.0)
+        # XXX: should we add some break condition and raise exception?
+        return False
+
+    def upload_file(self, from_info, to_info, progress_name):
+        """Implements the resumable upload protocol
+
+        https://developers.google.com/drive/api/v3/manage-uploads#resumable
+        """
+
+        dirname = posixpath.dirname(to_info.path).strip("/")
+        if dirname:
+            parent = self.makedirs(to_info.root, dirname)
+        else:
+            parent = to_info.root
+
+        # initiate the resumable upload
+        resumable_upload_url = self._resumable_upload_initiate(
+            parent, posixpath.basename(to_info.path)
+        )
+
+        from_file = open(from_info.path, "rb")
+        if progress_name is not None:
+            from_file = track_progress(progress_name, from_file)
+
+        file_size = os.fstat(from_file.fileno()).st_size
+
+        success = self._resumable_upload_first_request(
+            resumable_upload_url, from_file, to_info, file_size
+        )
+        while not success:
+            success = self._resumable_upload_resume(
+                resumable_upload_url, from_file, to_info, file_size
+            )
+
+        from_file.close()
+
+    def upload(self, from_infos, to_infos, names=None, no_progress_bar=False):
+
+        names = self._verify_path_args(to_infos, from_infos, names)
+
+        for from_info, to_info, name in zip(from_infos, to_infos, names):
+
+            if from_info.scheme != Schemes.LOCAL:
+                raise NotImplementedError
+
+            if to_info.scheme != self.scheme:
+                raise NotImplementedError
+
+            logger.debug(
+                "Uploading '{}' to '{}/{}'".format(
+                    from_info.path, self.url, to_info.path
+                )
+            )
+
+            if not name:
+                name = os.path.basename(from_info.path)
+
+            if not no_progress_bar:
+                progress.update_target(name, 0, None)
+
+            try:
+                self.upload_file(
+                    from_info,
+                    to_info,
+                    progress_name=None if no_progress_bar else name,
+                )
+            except Exception:
+                msg = "failed to upload '{}' to '{}/{}'"
+                logger.exception(
+                    msg.format(from_info.path, self.url, to_info.path)
+                )
+                continue
+
+            if not no_progress_bar:
+                progress.finish_target(name)
+
+    def download_file(self, from_info, to_info, progress_name=None):
+        metadata = self.get_metadata_by_path(
+            from_info.root, from_info.path, fields=["id", "mimeType", "size"]
+        )
+        response = self.request(
+            "GET",
+            "drive/v3/files/" + metadata["id"],
+            params={"alt": "media"},
+            stream=True,
+        )
+        current = 0
+        if response.status_code != 200:
+            raise GDriveError(self.response_error_message(response))
+        makedirs(os.path.dirname(to_info.path), exist_ok=True)
+        tmp_file = tmp_fname(to_info.path)
+        with open(tmp_file, "wb") as f:
+            block_size = os.fstat(f.fileno()).st_blksize
+            for chunk in response.iter_content(block_size):
+                f.write(chunk)
+                if progress_name is not None:
+                    current += len(chunk)
+                    # the API returns "size" as a string
+                    progress.update_target(
+                        progress_name, current, int(metadata["size"])
+                    )
+        move(tmp_file, to_info.path)
+
+    def download(
+        self,
+        from_infos,
+        to_infos,
+        no_progress_bar=False,
+        names=None,
+        resume=False,
+    ):
+
+        names = self._verify_path_args(from_infos, to_infos, names)
+
+        for to_info, from_info, name in zip(to_infos, from_infos, names):
+
+            if from_info.scheme != self.scheme:
+                raise NotImplementedError
+
+            if to_info.scheme != Schemes.LOCAL:
+                raise NotImplementedError
+
+            msg = "Downloading '{}/{}' to '{}'".format(
+                from_info.root, from_info.path, to_info.path
+            )
+            logger.debug(msg)
+
+            if not name:
+                name = os.path.basename(to_info.path)
+
+            if not no_progress_bar:
+                progress.update_target(name, 0, None)
+
+            try:
+                self.download_file(
+                    from_info,
+                    to_info,
+                    progress_name=None if no_progress_bar else name,
+                )
+            except Exception:
+                msg = "failed to download '{}/{}' to '{}'"
+                logger.exception(
+                    msg.format(from_info.root, from_info.path, to_info.path)
) + continue + + if not no_progress_bar: + progress.finish_target(name) diff --git a/dvc/remote/gdrive/google-dvc-client-id.json b/dvc/remote/gdrive/google-dvc-client-id.json new file mode 100644 index 0000000000..8c566e466e --- /dev/null +++ b/dvc/remote/gdrive/google-dvc-client-id.json @@ -0,0 +1 @@ +{"installed":{"client_id":"719861249063-v4an78j9grdtuuuqg3lnm0sugna6v3lh.apps.googleusercontent.com","project_id":"data-version-control","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"2fy_HyzSwkxkGzEken7hThXb","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}} diff --git a/dvc/remote/gdrive/oauth2.py b/dvc/remote/gdrive/oauth2.py new file mode 100644 index 0000000000..22c5cfcbd7 --- /dev/null +++ b/dvc/remote/gdrive/oauth2.py @@ -0,0 +1,105 @@ +from hashlib import md5 +import datetime +import json +import os + +from google_auth_oauthlib.flow import InstalledAppFlow +from google.auth.transport.requests import AuthorizedSession +import google.oauth2.credentials + +from dvc.lock import Lock +from dvc.config import Config + + +class OAuth2(object): + + DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S" + + def __init__(self, credentialpath, scopes, config): + self.credentialpath = credentialpath + self.scopes = scopes + self.config = config + + def get_session(self): + creds_storage, creds_storage_lock = self._get_storage_lock() + with creds_storage_lock: + if os.path.exists(creds_storage): + creds = self._load_credentials(creds_storage) + else: + creds = self._acquire_credentials() + self._save_credentials(creds_storage, creds) + return AuthorizedSession(creds) + + def _get_creds_id(self, client_id): + plain_text = "|".join([client_id] + self.scopes) + hashed = md5(plain_text.encode("ascii")).hexdigest() + return hashed + + def _acquire_credentials(self): + # Create the flow using the client secrets file from the + # Google API Console. + flow = InstalledAppFlow.from_client_secrets_file( + self.credentialpath, scopes=self.scopes + ) + flow_runner = self.config.get( + Config.SECTION_OAUTH2_FLOW_RUNNER, "console" + ) + if flow_runner == "local": + creds = flow.run_local_server() + elif flow_runner == "console": + creds = flow.run_console() + else: + raise Exception( + "oauth2_flow_runner should be 'local' or 'console'" + ) + return creds + + def _load_credentials(self, creds_storage): + """Load credentials from json file and refresh them if needed + + Should be called under lock. + """ + info = json.load(open(creds_storage)) + creds = google.oauth2.credentials.Credentials( + token=info["token"], + refresh_token=info["refresh_token"], + token_uri=info["token_uri"], + client_id=info["client_id"], + client_secret=info["client_secret"], + scopes=self.scopes, + ) + creds.expiry = datetime.datetime.strptime( + info["expiry"], self.DATETIME_FORMAT + ) + if creds.expired: + creds.refresh(google.auth.transport.requests.Request()) + self._save_credentials(creds_storage, creds) + return creds + + def _save_credentials(self, creds_storage, creds): + """Save credentials to the json file + + Should be called under lock. 
+ """ + with open(creds_storage, "w") as f: + info = { + "token": creds.token, + "refresh_token": creds.refresh_token, + "token_uri": creds.token_uri, + "client_id": creds.client_id, + "client_secret": creds.client_secret, + "scopes": creds.scopes, + "expiry": creds.expiry.strftime(self.DATETIME_FORMAT), + } + json.dump(info, f) + + def _get_storage_lock(self): + creds_storage_dir = os.path.join( + Config.get_global_config_dir(), "oauth2" + ) + if not os.path.exists(creds_storage_dir): + os.makedirs(creds_storage_dir) + info = json.load(open(self.credentialpath)) + creds_id = self._get_creds_id(info["installed"]["client_id"]) + creds_storage = os.path.join(creds_storage_dir, creds_id) + return creds_storage, Lock(creds_storage_dir, creds_storage + ".lock") diff --git a/dvc/scheme.py b/dvc/scheme.py index e12b768f58..5f7a8d1a28 100644 --- a/dvc/scheme.py +++ b/dvc/scheme.py @@ -9,5 +9,6 @@ class Schemes: HTTP = "http" HTTPS = "https" GS = "gs" + GDRIVE = "gdrive" LOCAL = "local" OSS = "oss" diff --git a/scripts/ci/before_install.sh b/scripts/ci/before_install.sh index db200d8780..fd0fde3869 100644 --- a/scripts/ci/before_install.sh +++ b/scripts/ci/before_install.sh @@ -64,3 +64,13 @@ if [[ -n "$PYTHON_VER" ]]; then echo 'eval "$(pyenv init -)"' >> env.sh echo "pyenv shell $PYTHON_VER" >> env.sh fi + +if [[ -n "$OAUTH2_TOKEN_FILE_KEY" && -n "$OAUTH2_TOKEN_FILE_IV" ]]; then +SRC=scripts/ci/gdrive-oauth2 +DST=~/.config/dvc/oauth2 +[ -d "$DST" ] || mkdir -p $DST +openssl aes-256-cbc -K $OAUTH2_TOKEN_FILE_KEY -iv $OAUTH2_TOKEN_FILE_IV \ + -in $SRC/81562e04895c64d8c65a45710a0a8a6b.enc -out $DST/81562e04895c64d8c65a45710a0a8a6b -d +openssl aes-256-cbc -K $OAUTH2_TOKEN_FILE_KEY -iv $OAUTH2_TOKEN_FILE_IV \ + -in $SRC/160f0daf067a7511bd8b8e3b35e3a8bd.enc -out $DST/160f0daf067a7511bd8b8e3b35e3a8bd -d +fi diff --git a/scripts/ci/gdrive-oauth2/160f0daf067a7511bd8b8e3b35e3a8bd.enc b/scripts/ci/gdrive-oauth2/160f0daf067a7511bd8b8e3b35e3a8bd.enc new file mode 100644 index 0000000000000000000000000000000000000000..87cf3f17245f37025cd67b408969d4483fb8112f GIT binary patch literal 496 zcmVOU!sb z)36R=yVQ?a?_2bM5FtVmLy)C6>|{lhAz`+JSiHi}&PBufkQ%nykNQ`&&yZHoc>_iy9m3~Y z8QD)xKLgRSmv~ny*{5~E5(O&s>u67<HDpwK*KR@5k{x=}iK$~`YcUYa3r!+Z z$>k6~5&4%+Zs~h9v4|Vi7nQC%1_5MpX&(3i=w<%*HupR zISF`}WfZ(u8}Y=bZTYscvFuii5CDj9%-9(Vp;K#pQ=`HVB`~%V5#}*p+^_F8y;d|% z4g9AHTyWJs)Mx7TC!`M{OczF*ZBTn;=7Y~kH4wvn0CGSzKJ%7EAzlw-aT>%nAf=$D zC5n;`!!ir|s^=ok#6DG5A-5{O4~X@>;yp2X6{o73a+u#O3t(2)Nr=4OSqB_uJt-~W m2`YK@k5$LYTew=>I=L+#w<^o-SXQU|{a^-9oN~q{4bTWSQ0wae literal 0 HcmV?d00001 diff --git a/scripts/ci/gdrive-oauth2/81562e04895c64d8c65a45710a0a8a6b.enc b/scripts/ci/gdrive-oauth2/81562e04895c64d8c65a45710a0a8a6b.enc new file mode 100644 index 0000000000000000000000000000000000000000..3a47a817242c55fc62c155bd465182fabf1f2b6a GIT binary patch literal 496 zcmVwBUS14r_Jbh#r1M8Yfl zJR0t-`c|~d>*|XEWVhe#czo-+u{0F&!c}~E!xAuuaTc@!E`+t}w?c{HMxPIJ<-Xg{ zbm6M`BwJ~A>WkIO>Q93LpuB^0nYO<7V$Y&Nw6FQ+ph>gw7U1q6(D7>lssE_L#~s77 zitG=j4{2MU7vS{BB2xEy9Bw62*vGv;qyuGkuAo#fuphY z9sdP>Zb?1RONN!Yh&)C4uGX*s(U1%+2VstyDGAeL?Ah}cO=~vJf`G8HJnJ;x#YceW z!T|j`vkBZ3#NykOHO?LghtR-hzWS_|{RfzAhe13!=jY~$x%_2P&(uzSxT5??sIQjD mV^mcdlLM9gYJQ}cUq~1lwa1kx6sD2ja^fXKgRWeEqP+-3WCL*k literal 0 HcmV?d00001 diff --git a/setup.py b/setup.py index 16011e881b..f455e2c4f4 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,7 @@ def run(self): # Extra dependencies for remote integrations gs = ["google-cloud-storage==1.13.0"] +gdrive = ["google-auth-oauthlib==0.3.0"] s3 = ["boto3==1.9.115"] azure = 
["azure-storage-blob==1.3.0"] oss = ["oss2==2.6.1"] @@ -87,6 +88,7 @@ def run(self): "xmltodict>=0.11.0", "awscli>=1.16.125", "google-compute-engine", + "google-auth-oauthlib", "pywin32; sys_platform == 'win32'", "Pygments", # required by collective.checkdocs, "collective.checkdocs", @@ -112,6 +114,7 @@ def run(self): extras_require={ "all": all_remotes, "gs": gs, + "gdrive": gdrive, "s3": s3, "azure": azure, "oss": oss, diff --git a/tests/func/test_gdrive.py b/tests/func/test_gdrive.py new file mode 100644 index 0000000000..db21a58abb --- /dev/null +++ b/tests/func/test_gdrive.py @@ -0,0 +1,128 @@ +from subprocess import check_call +import shutil +import os +import tempfile + +import pytest + +from dvc.main import main +from dvc.config import Config +from dvc.remote.gdrive import RemoteGDrive + + +oauth2_storage = os.path.join( + Config.get_global_config_dir(), + "oauth2", + "81562e04895c64d8c65a45710a0a8a6b", +) +if not os.path.exists(oauth2_storage): + pytest.skip( + "skipping GDrive tests: could decrypt access token only in Travis", + allow_module_level=True, + ) + + +def test_gdrive_push_pull(repo_dir, dvc_repo): + dirname = tempfile.mktemp("", "dvc_test_", "") + url = "gdrive://root/" + dirname + files = [repo_dir.FOO, repo_dir.DATA_SUB.split(os.path.sep)[0]] + + gdrive = RemoteGDrive(dvc_repo, {"url": url}) + + # push files + check_call(["dvc", "add"] + files) + check_call(["dvc", "remote", "add", "gdrive", url]) + assert main(["push", "-j1", "-r", "gdrive"]) == 0 + + # check that files are correctly uploaded + testdir = gdrive.get_metadata_by_path(gdrive.root, gdrive.prefix) + subdirs = ["ac", "2b", "1a", "bf"] + q = "'{}' in parents".format(testdir["id"]) + found = list(gdrive.search(add_params={"q": q})) + assert set(i["name"] for i in found) == set(subdirs) + q = " or ".join("'{}' in parents".format(i["id"]) for i in found) + found = list(gdrive.search(add_params={"q": q})) + assert set(i["name"] for i in found) == { + "bd18db4cc2f85cedef654fccc4a4d8", + "7235bae9a59ef5602ad01d5719aabc", + "f5a787d16460f32e9f2e62b183b1cc.dir", + "d1c255771ec00b7cee20a136250065", + } + + # remove cache and files + shutil.rmtree(".dvc/cache") + for i in files: + if os.path.isdir(i): + shutil.rmtree(i) + else: + os.remove(i) + + # check that they are in list_cache_paths + assert set(gdrive.list_cache_paths()) == { + dirname + "/1a/d1c255771ec00b7cee20a136250065", + dirname + "/2b/7235bae9a59ef5602ad01d5719aabc", + dirname + "/ac/bd18db4cc2f85cedef654fccc4a4d8", + dirname + "/bf/f5a787d16460f32e9f2e62b183b1cc.dir", + } + + # pull them back from remote + assert main(["pull", "-j1", "-r", "gdrive"]) == 0 + + assert set(files) < set(os.listdir(".")) + + # remove the temporary directory on Google Drive + resp = gdrive.request("DELETE", "drive/v3/files/" + testdir["id"]) + print("Delete temp dir: HTTP {}".format(resp.status_code)) + + +def test_gdrive_push_pull_appfolder(repo_dir, dvc_repo): + dirname = tempfile.mktemp("", "dvc_test_", "") + url = "gdrive://appDataFolder/" + dirname + files = [repo_dir.FOO, repo_dir.DATA_SUB.split(os.path.sep)[0]] + + gdrive = RemoteGDrive(dvc_repo, {"url": url}) + + # push files + check_call(["dvc", "add"] + files) + check_call(["dvc", "remote", "add", "gdrive", url]) + assert main(["push", "-j1", "-r", "gdrive"]) == 0 + + # check that files are correctly uploaded + testdir = gdrive.get_metadata_by_path(gdrive.root, gdrive.prefix) + subdirs = ["ac", "2b", "1a", "bf"] + q = "'{}' in parents".format(testdir["id"]) + found = list(gdrive.search(add_params={"q": q})) + 
assert set(i["name"] for i in found) == set(subdirs) + q = " or ".join("'{}' in parents".format(i["id"]) for i in found) + found = list(gdrive.search(add_params={"q": q})) + assert set(i["name"] for i in found) == { + "bd18db4cc2f85cedef654fccc4a4d8", + "7235bae9a59ef5602ad01d5719aabc", + "f5a787d16460f32e9f2e62b183b1cc.dir", + "d1c255771ec00b7cee20a136250065", + } + + # remove cache and files + shutil.rmtree(".dvc/cache") + for i in files: + if os.path.isdir(i): + shutil.rmtree(i) + else: + os.remove(i) + + # check that they are in list_cache_paths + assert set(gdrive.list_cache_paths()) == { + dirname + "/1a/d1c255771ec00b7cee20a136250065", + dirname + "/2b/7235bae9a59ef5602ad01d5719aabc", + dirname + "/ac/bd18db4cc2f85cedef654fccc4a4d8", + dirname + "/bf/f5a787d16460f32e9f2e62b183b1cc.dir", + } + + # pull them back from remote + assert main(["pull", "-j1", "-r", "gdrive"]) == 0 + + assert set(files) < set(os.listdir(".")) + + # remove the temporary directory on Google Drive + resp = gdrive.request("DELETE", "drive/v3/files/" + testdir["id"]) + print("Delete temp dir: HTTP {}".format(resp.status_code)) diff --git a/tests/unit/remote/test_gdrive.py b/tests/unit/remote/test_gdrive.py new file mode 100644 index 0000000000..5d8c37c4b0 --- /dev/null +++ b/tests/unit/remote/test_gdrive.py @@ -0,0 +1,327 @@ +import mock +from datetime import datetime, timedelta + +import pytest + +import google.oauth2.credentials +from google_auth_oauthlib.flow import InstalledAppFlow + +from dvc.remote.gdrive import RemoteGDrive, GDriveError, GDriveResourceNotFound +from dvc.remote.gdrive.oauth2 import OAuth2 +from dvc.path.gdrive import PathGDrive +from dvc.repo import Repo + + +GDRIVE_URL = "gdrive://root/data" +GDRIVE_APPFOLDER_URL = "gdrive://appDataFolder/data" +AUTHORIZATION = {"authorization": "Bearer MOCK_token"} +FOLDER = {"mimeType": RemoteGDrive.MIME_GOOGLE_APPS_FOLDER} +FILE = {"mimeType": "not-a-folder"} + +COMMON_KWARGS = { + "data": None, + "headers": AUTHORIZATION, + "timeout": RemoteGDrive.TIMEOUT, +} + + +class Response: + def __init__(self, data, status_code=200): + self._data = data + self.status_code = status_code + + def json(self): + return self._data + + +@pytest.fixture() +def repo(): + return Repo(".") + + +@pytest.fixture +def gdrive(repo): + gdrive = RemoteGDrive(repo, {"url": GDRIVE_URL}) + return gdrive + + +@pytest.fixture +def gdrive_appfolder(repo): + gdrive = RemoteGDrive(repo, {"url": GDRIVE_APPFOLDER_URL}) + return gdrive + + +@pytest.fixture(autouse=True) +def no_requests(monkeypatch): + req_mock = mock.Mock(return_value=Response("")) + monkeypatch.setattr("requests.sessions.Session.request", req_mock) + return req_mock + + +def _url(url): + return RemoteGDrive.GOOGLEAPIS_BASE_URL + url + + +@pytest.fixture(autouse=True) +def fake_creds(monkeypatch): + + creds = google.oauth2.credentials.Credentials( + token="MOCK_token", + refresh_token="MOCK_refresh_token", + token_uri="MOCK_token_uri", + client_id="MOCK_client_id", + client_secret="MOCK_client_secret", + scopes=["MOCK_scopes"], + ) + creds.expiry = datetime.now() + timedelta(days=1) + + mocked_flow = mock.Mock() + mocked_flow.run_console.return_value = creds + mocked_flow.run_local_server.return_value = creds + + monkeypatch.setattr( + InstalledAppFlow, + "from_client_secrets_file", + classmethod(lambda *args, **kwargs: mocked_flow), + ) + + monkeypatch.setattr( + OAuth2, "_get_creds_id", mock.Mock(return_value="test") + ) + + +@pytest.fixture(autouse=True) +def no_refresh(monkeypatch): + monkeypatch.setattr( + 
"google.oauth2.credentials.Credentials.refresh", mock.Mock() + ) + + +@pytest.fixture() +def makedirs(monkeypatch): + mocked = mock.Mock(return_value="FOLDER_ID") + monkeypatch.setattr(RemoteGDrive, "makedirs", mocked) + return mocked + + +def test_init_drive(gdrive): + assert gdrive.root == "root" + assert gdrive.url == GDRIVE_URL + assert gdrive.oauth2.scopes == ["https://www.googleapis.com/auth/drive"] + assert gdrive.space == RemoteGDrive.SPACE_DRIVE + + +def test_init_appfolder(gdrive_appfolder): + assert gdrive_appfolder.root == RemoteGDrive.SPACE_APPDATA + assert gdrive_appfolder.url == GDRIVE_APPFOLDER_URL + assert gdrive_appfolder.oauth2.scopes == [ + "https://www.googleapis.com/auth/drive.appdata" + ] + assert gdrive_appfolder.space == RemoteGDrive.SPACE_APPDATA + + +def test_init_folder_id(repo): + url = "gdrive://FOLDER_ID/data" + remote = RemoteGDrive(repo, {"url": url}) + assert remote.root == "FOLDER_ID" + assert remote.url == url + assert remote.oauth2.scopes == ["https://www.googleapis.com/auth/drive"] + assert remote.space == "drive" + + +def test_path_info(repo): + remote = RemoteGDrive(repo, {"url": "gdrive://root"}) + assert remote.path_info.root == "root" + + +def test_get_session(gdrive, no_requests): + session = gdrive.oauth2.get_session() + # XXX(ei-grad): it actually shouldn't work with domains different + # to googleapis.com, but the world is not perfect... + # maybe the AuthorizedSession should be replaced with the + # requests-idiomatic Adapter + session.get("http://httpbin.org/get") + args, kwargs = no_requests.call_args + assert kwargs["headers"]["authorization"] == AUTHORIZATION["authorization"] + + +def test_request(gdrive, no_requests): + gdrive.request("GET", "test") + no_requests.assert_called_once_with("GET", _url("test"), **COMMON_KWARGS) + + +def test_get_metadata_by_id(gdrive, no_requests): + gdrive.get_metadata_by_id("test") + no_requests.assert_called_once_with( + "GET", _url("drive/v3/files/test"), **COMMON_KWARGS + ) + + +def test_search(gdrive, no_requests): + no_requests.side_effect = [ + Response({"files": ["test1"], "nextPageToken": "TEST_nextPageToken"}), + Response({"files": ["test2"]}), + ] + assert list(gdrive.search("test", "root")) == ["test1", "test2"] + + +def test_metadata_by_path(gdrive, no_requests, monkeypatch): + no_requests.side_effect = [ + Response(dict(id="root", name="root", **FOLDER)), + Response({"files": [dict(id="id1", name="path1", **FOLDER)]}), + Response({"files": [dict(id="id2", name="path2", **FOLDER)]}), + ] + gdrive.get_metadata_by_path("root", "path1/path2", ["field1", "field2"]) + assert no_requests.mock_calls == [ + mock.call("GET", _url("drive/v3/files/root"), **COMMON_KWARGS), + mock.call( + "GET", + _url("drive/v3/files"), + params={ + "q": "'root' in parents and name = 'path1'", + "spaces": "drive", + }, + **COMMON_KWARGS + ), + mock.call( + "GET", + _url("drive/v3/files"), + params={ + "q": "'id1' in parents and name = 'path2'", + "spaces": "drive", + "fields": "files(field1,field2)", + }, + **COMMON_KWARGS + ), + ] + + +def test_metadata_by_path_not_a_folder(gdrive, monkeypatch): + monkeypatch.setattr( + gdrive, + "get_metadata_by_id", + mock.Mock(return_value=dict(id="id1", name="root", **FOLDER)), + ) + monkeypatch.setattr( + gdrive, + "search", + mock.Mock(return_value=[dict(id="id2", name="path1", **FILE)]), + ) + with pytest.raises(GDriveError): + gdrive.get_metadata_by_path( + "root", "path1/path2", ["field1", "field2"] + ) + gdrive.get_metadata_by_path("root", "path1", ["field1", "field2"]) + + 
+def test_metadata_by_path_duplicate(gdrive, monkeypatch):
+    monkeypatch.setattr(
+        gdrive,
+        "get_metadata_by_id",
+        mock.Mock(return_value=dict(id="id1", name="root", **FOLDER)),
+    )
+    monkeypatch.setattr(
+        gdrive,
+        "search",
+        mock.Mock(
+            return_value=[
+                dict(id="id2", name="path1", **FOLDER),
+                dict(id="id3", name="path1", **FOLDER),
+            ]
+        ),
+    )
+    with pytest.raises(GDriveError):
+        gdrive.get_metadata_by_path(
+            "root", "path1/path2", ["field1", "field2"]
+        )
+
+
+def test_metadata_by_path_not_found(gdrive, monkeypatch):
+    monkeypatch.setattr(
+        gdrive,
+        "get_metadata_by_id",
+        mock.Mock(return_value=dict(id="root", name="root", **FOLDER)),
+    )
+    monkeypatch.setattr(gdrive, "search", mock.Mock(return_value=[]))
+    with pytest.raises(GDriveResourceNotFound):
+        gdrive.get_metadata_by_path(
+            "root", "path1/path2", ["field1", "field2"]
+        )
+
+
+def test_get_file_checksum(gdrive, monkeypatch):
+    monkeypatch.setattr(
+        gdrive,
+        "get_metadata_by_path",
+        mock.Mock(
+            return_value=dict(id="id1", name="path1", md5Checksum="checksum")
+        ),
+    )
+    checksum = gdrive.get_file_checksum(PathGDrive(gdrive.root, path="path1"))
+    assert checksum == "checksum"
+    gdrive.get_metadata_by_path.assert_called_once_with(
+        gdrive.root, "path1", fields=["md5Checksum"]
+    )
+
+
+def test_list_cache_paths(gdrive, monkeypatch):
+    monkeypatch.setattr(
+        gdrive,
+        "get_metadata_by_path",
+        mock.Mock(return_value=dict(id="root", name="root", **FOLDER)),
+    )
+    files_lists = [
+        [dict(id="f1", name="f1", **FOLDER), dict(id="f2", name="f2", **FILE)],
+        [dict(id="f3", name="f3", **FILE)],
+    ]
+    monkeypatch.setattr(gdrive, "search", mock.Mock(side_effect=files_lists))
+    assert list(gdrive.list_cache_paths()) == ["data/f1/f3", "data/f2"]
+    gdrive.get_metadata_by_path.assert_called_once_with("root", "data")
+
+
+def test_mkdir(gdrive, no_requests):
+    no_requests.return_value = Response("test")
+    assert gdrive.mkdir("root", "test") == "test"
+    no_requests.assert_called_once_with(
+        "POST",
+        _url("drive/v3/files"),
+        json={
+            "name": "test",
+            "mimeType": FOLDER["mimeType"],
+            "parents": ["root"],
+            "spaces": "drive",
+        },
+        **COMMON_KWARGS
+    )
+
+
+def test_makedirs(gdrive, monkeypatch):
+    monkeypatch.setattr(
+        gdrive,
+        "get_metadata_by_path",
+        mock.Mock(
+            side_effect=[
+                dict(id="id1", name="test1", **FOLDER),
+                GDriveResourceNotFound("test1/test2"),
+            ]
+        ),
+    )
+    monkeypatch.setattr(
+        gdrive, "mkdir", mock.Mock(side_effect=[{"id": "id2"}])
+    )
+    assert gdrive.makedirs(gdrive.root, "test1/test2") == "id2"
+    assert gdrive.get_metadata_by_path.mock_calls == [
+        mock.call(gdrive.root, "test1"),
+        mock.call("id1", "test2"),
+    ]
+    assert gdrive.mkdir.mock_calls == [mock.call("id1", "test2")]
+
+
+def test_makedirs_error(gdrive, monkeypatch):
+    monkeypatch.setattr(
+        gdrive,
+        "get_metadata_by_path",
+        mock.Mock(side_effect=[dict(id="id1", name="test1", **FILE)]),
+    )
+    with pytest.raises(GDriveError):
+        gdrive.makedirs(gdrive.root, "test1/test2")
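
Reviewer note: below is a minimal, self-contained sketch of the resume-offset
rule implemented in _resumable_upload_resume() above. The helper name
resume_offset() is illustrative only and is not part of this patch. Per the
resumable upload protocol, a "Range: bytes=0-9" header in a 308 response means
the server has stored bytes 0..9 inclusive, so the next request must start at
byte 10; a missing Range header means nothing has been stored yet.

    def resume_offset(headers):
        # "Range: bytes=0-9" => bytes 0..9 (inclusive) were stored,
        # so the upload resumes at byte 10
        if "Range" not in headers:
            return 0  # nothing stored yet, start from the beginning
        return int(headers["Range"].split("-")[-1]) + 1

    assert resume_offset({"Range": "bytes=0-9"}) == 10
    assert resume_offset({}) == 0

End-to-end usage mirrors the functional tests: "dvc remote add gdrive
gdrive://root/<folder>", then "dvc push -r gdrive" and "dvc pull -r gdrive".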