diff --git a/oggm/utils.py b/oggm/utils.py
index 72fcaf4ff..351f989f4 100644
--- a/oggm/utils.py
+++ b/oggm/utils.py
@@ -128,15 +128,16 @@ def _cached_download_helper(cache_obj_name, dl_func):
     return cache_path
 
 
-def _urlretrieve(url, *args, **kwargs):
+def _urlretrieve(url, cache_obj_name=None, *args, **kwargs):
     """Wrapper around urlretrieve, to implement our caching logic.
 
     Instead of accepting a destination path, it decided where to store the file
     and returns the local path.
     """
 
-    cache_obj_name = urlparse(url)
-    cache_obj_name = cache_obj_name.netloc + cache_obj_name.path
+    if cache_obj_name is None:
+        cache_obj_name = urlparse(url)
+        cache_obj_name = cache_obj_name.netloc + cache_obj_name.path
 
     def _dlf(cache_path):
         logger.info("Downloading %s to %s..." % (url, cache_path))
@@ -146,7 +147,7 @@ def _dlf(cache_path):
     return _cached_download_helper(cache_obj_name, _dlf)
 
 
-def _progress_urlretrieve(url):
+def _progress_urlretrieve(url, cache_name=None):
    """Downloads a file, returns its local path, and shows a progressbar."""
 
     try:
@@ -162,14 +163,14 @@ def _upd(count, size, total):
                     pbar[0].start(UnknownLength)
             pbar[0].update(min(count * size, total))
             sys.stdout.flush()
-        res = _urlretrieve(url, reporthook=_upd)
+        res = _urlretrieve(url, cache_obj_name=cache_name, reporthook=_upd)
         try:
             pbar[0].finish()
         except:
             pass
         return res
     except ImportError:
-        return _urlretrieve(url)
+        return _urlretrieve(url, cache_obj_name=cache_name)
 
 
 def aws_file_download(aws_path):
@@ -177,7 +178,7 @@
         return _aws_file_download_unlocked(aws_path)
 
 
-def _aws_file_download_unlocked(aws_path):
+def _aws_file_download_unlocked(aws_path, cache_name=None):
     """Download a file from the AWS drive s3://astgtmv2/
 
     **Note:** you need AWS credentials for this to work.
@@ -190,7 +191,10 @@
     while aws_path.startswith('/'):
         aws_path = aws_path[1:]
 
-    cache_obj_name = 'astgtmv2/' + aws_path
+    if cache_name is not None:
+        cache_obj_name = cache_name
+    else:
+        cache_obj_name = 'astgtmv2/' + aws_path
 
     def _dlf(cache_path):
         import boto3
@@ -209,7 +213,7 @@ def _dlf(cache_path):
     return _cached_download_helper(cache_obj_name, _dlf)
 
 
-def file_downloader(www_path, retry_max=5):
+def file_downloader(www_path, retry_max=5, cache_name=None):
     """A slightly better downloader: it tries more than once."""
 
     local_path = None
@@ -218,7 +222,7 @@
         # Try to download
         try:
             retry_counter += 1
-            local_path = _progress_urlretrieve(www_path)
+            local_path = _progress_urlretrieve(www_path, cache_name=cache_name)
             # if no error, exit
             break
         except HTTPError as err: