
Commit

Add optional cache name override
TimoRoth committed May 3, 2017
1 parent 58bb901 commit a2bd81c
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions oggm/utils.py
@@ -128,15 +128,16 @@ def _cached_download_helper(cache_obj_name, dl_func):
     return cache_path
 
 
-def _urlretrieve(url, *args, **kwargs):
+def _urlretrieve(url, cache_obj_name=None, *args, **kwargs):
     """Wrapper around urlretrieve, to implement our caching logic.
 
     Instead of accepting a destination path, it decided where to store the file
     and returns the local path.
     """
 
-    cache_obj_name = urlparse(url)
-    cache_obj_name = cache_obj_name.netloc + cache_obj_name.path
+    if cache_obj_name is None:
+        cache_obj_name = urlparse(url)
+        cache_obj_name = cache_obj_name.netloc + cache_obj_name.path
 
     def _dlf(cache_path):
         logger.info("Downloading %s to %s..." % (url, cache_path))
@@ -146,7 +147,7 @@ def _dlf(cache_path):
     return _cached_download_helper(cache_obj_name, _dlf)
 
 
-def _progress_urlretrieve(url):
+def _progress_urlretrieve(url, cache_name=None):
     """Downloads a file, returns its local path, and shows a progressbar."""
 
     try:
@@ -162,22 +163,22 @@ def _upd(count, size, total):
                     pbar[0].start(UnknownLength)
             pbar[0].update(min(count * size, total))
             sys.stdout.flush()
-        res = _urlretrieve(url, reporthook=_upd)
+        res = _urlretrieve(url, cache_obj_name=cache_name, reporthook=_upd)
         try:
             pbar[0].finish()
         except:
             pass
         return res
     except ImportError:
-        return _urlretrieve(url)
+        return _urlretrieve(url, cache_obj_name=cache_name)
 
 
 def aws_file_download(aws_path):
     with _get_download_lock():
         return _aws_file_download_unlocked(aws_path)
 
 
-def _aws_file_download_unlocked(aws_path):
+def _aws_file_download_unlocked(aws_path, cache_name=None):
     """Download a file from the AWS drive s3://astgtmv2/
 
     **Note:** you need AWS credentials for this to work.
@@ -190,7 +191,10 @@ def _aws_file_download_unlocked(aws_path):
     while aws_path.startswith('/'):
         aws_path = aws_path[1:]
 
-    cache_obj_name = 'astgtmv2/' + aws_path
+    if cache_name is not None:
+        cache_obj_name = cache_name
+    else:
+        cache_obj_name = 'astgtmv2/' + aws_path
 
     def _dlf(cache_path):
         import boto3
@@ -209,7 +213,7 @@ def _dlf(cache_path):
     return _cached_download_helper(cache_obj_name, _dlf)
 
 
-def file_downloader(www_path, retry_max=5):
+def file_downloader(www_path, retry_max=5, cache_name=None):
     """A slightly better downloader: it tries more than once."""
 
     local_path = None
@@ -218,7 +222,7 @@ def file_downloader(www_path, retry_max=5):
         # Try to download
         try:
             retry_counter += 1
-            local_path = _progress_urlretrieve(www_path)
+            local_path = _progress_urlretrieve(www_path, cache_name=cache_name)
             # if no error, exit
             break
         except HTTPError as err:
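In practice, the new cache_name keyword lets a caller pin a download to a fixed key in OGGM's local file cache, instead of the key otherwise derived from the URL (netloc + path) or, for AWS downloads, from the 'astgtmv2/' prefix. A minimal usage sketch of the override introduced in this commit; the URL and cache key below are invented for illustration and are not part of the change:

    from oggm import utils

    # Default behaviour: the cache key is derived from the URL (netloc + path).
    local_a = utils.file_downloader('https://example.org/dem/srtm_39_03.zip')

    # With cache_name, the file is stored and looked up under a fixed key,
    # so the cached copy can be reused even if the download URL changes later.
    local_b = utils.file_downloader('https://example.org/dem/srtm_39_03.zip',
                                    cache_name='dem/srtm_39_03.zip')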
