diff --git a/xarray/tutorial.py b/xarray/tutorial.py index b4e606c0ee3..81a78d3b8ad 100644 --- a/xarray/tutorial.py +++ b/xarray/tutorial.py @@ -9,6 +9,8 @@ from __future__ import division from __future__ import print_function +import hashlib + import os as _os import shutil as _shutil @@ -19,9 +21,17 @@ _default_cache_dir = _os.sep.join(('~', '.xarray_tutorial_data')) +def _md5(fname): + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + hash_md5.update(f.read()) + return hash_md5.hexdigest() + + # idea borrowed from Seaborn def load_dataset(name, cache=True, cache_dir=_default_cache_dir, - github_url='https://github.com/pydata/xarray-data', **kws): + github_url='https://github.com/pydata/xarray-data', + branch='master', **kws): """ Load a dataset from the online repository (requires internet). @@ -38,15 +48,17 @@ def load_dataset(name, cache=True, cache_dir=_default_cache_dir, If True, then cache data locally for use on subsequent calls github_url : string Github repository where the data is stored + branch : string + The git branch to download from kws : dict, optional Passed to xarray.open_dataset """ longdir = _os.path.expanduser(cache_dir) - tmpdir = _os.sep.join((longdir, '.tmp')) fullname = name + '.nc' localfile = _os.sep.join((longdir, fullname)) - tmpfile = _os.sep.join((longdir, fullname)) + md5name = name + '.md5' + md5file = _os.sep.join((longdir, md5name)) if not _os.path.exists(localfile): @@ -54,16 +66,21 @@ def load_dataset(name, cache=True, cache_dir=_default_cache_dir, # May want to add an option to remove it. if not _os.path.isdir(longdir): _os.mkdir(longdir) - if not _os.path.isdir(tmpdir): - _os.mkdir(tmpdir) - - url = '/'.join((github_url, 'raw', 'master', fullname)) - _urlretrieve(url, tmpfile) - - if not _os.path.exists(tmpfile): - raise ValueError('File could not be downloaded, please try again') - _shutil.move(tmpfile, localfile) + url = '/'.join((github_url, 'raw', branch, fullname)) + _urlretrieve(url, localfile) + url = '/'.join((github_url, 'raw', branch, md5name)) + _urlretrieve(url, md5file) + + localmd5 = _md5(localfile) + with open(md5file, 'r') as f: + remotemd5 = f.read() + if localmd5 != remotemd5: + _os.remove(localfile) + msg = """ + MD5 checksum does not match, try downloading dataset again. + """ + raise IOError(msg) ds = _open_dataset(localfile, **kws).load()