Skip to content
forked from pydata/xarray

Commit

Permalink
pyupgrade one-off run (pydata#3190)
Browse files Browse the repository at this point in the history
* pyupgrade (manually vetted and tweaked)

* pyupgrade

* Tweaks to Dataset.drop_dims()

* mypy

* More concise code
  • Loading branch information
crusaderky authored and max-sixty committed Aug 7, 2019
1 parent 04597a8 commit 8a9c471
Show file tree
Hide file tree
Showing 33 changed files with 202 additions and 156 deletions.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-seaborn.*]
ignore_missing_imports = True
[mypy-sparse.*]
ignore_missing_imports = True
[mypy-toolz.*]
ignore_missing_imports = True
[mypy-zarr.*]
Expand Down
23 changes: 11 additions & 12 deletions versioneer.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
stderr=(subprocess.PIPE if hide_stderr
else None))
break
except EnvironmentError:
except OSError:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT:
continue
Expand All @@ -421,7 +421,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
return stdout, p.returncode


LONG_VERSION_PY['git'] = '''
LONG_VERSION_PY['git'] = r'''
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
Expand Down Expand Up @@ -968,7 +968,7 @@ def git_get_keywords(versionfile_abs):
if mo:
keywords["date"] = mo.group(1)
f.close()
except EnvironmentError:
except OSError:
pass
return keywords

Expand All @@ -992,11 +992,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = set([r.strip() for r in refnames.strip("()").split(",")])
refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
Expand All @@ -1005,7 +1005,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = set([r for r in refs if re.search(r'\d', r)])
tags = {r for r in refs if re.search(r'\d', r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
Expand Down Expand Up @@ -1148,7 +1148,7 @@ def do_vcs_install(manifest_in, versionfile_source, ipy):
if "export-subst" in line.strip().split()[1:]:
present = True
f.close()
except EnvironmentError:
except OSError:
pass
if not present:
f = open(".gitattributes", "a+")
Expand Down Expand Up @@ -1206,7 +1206,7 @@ def versions_from_file(filename):
try:
with open(filename) as f:
contents = f.read()
except EnvironmentError:
except OSError:
raise NotThisMethod("unable to read _version.py")
mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON",
contents, re.M | re.S)
Expand Down Expand Up @@ -1702,8 +1702,7 @@ def do_setup():
root = get_root()
try:
cfg = get_config_from_root(root)
except (EnvironmentError, configparser.NoSectionError,
configparser.NoOptionError) as e:
except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e:
if isinstance(e, (EnvironmentError, configparser.NoSectionError)):
print("Adding sample versioneer config to setup.cfg",
file=sys.stderr)
Expand All @@ -1728,7 +1727,7 @@ def do_setup():
try:
with open(ipy, "r") as f:
old = f.read()
except EnvironmentError:
except OSError:
old = ""
if INIT_PY_SNIPPET not in old:
print(" appending to %s" % ipy)
Expand All @@ -1752,7 +1751,7 @@ def do_setup():
if line.startswith("include "):
for include in line.split()[1:]:
simple_includes.add(include)
except EnvironmentError:
except OSError:
pass
# That doesn't cover everything MANIFEST.in can do
# (http://docs.python.org/2/distutils/sourcedist.html#commands), so
Expand Down
10 changes: 5 additions & 5 deletions xarray/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
stderr=(subprocess.PIPE if hide_stderr
else None))
break
except EnvironmentError:
except OSError:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT:
continue
Expand Down Expand Up @@ -153,7 +153,7 @@ def git_get_keywords(versionfile_abs):
if mo:
keywords["date"] = mo.group(1)
f.close()
except EnvironmentError:
except OSError:
pass
return keywords

Expand All @@ -177,11 +177,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = set([r.strip() for r in refnames.strip("()").split(",")])
refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
Expand All @@ -190,7 +190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = set([r for r in refs if re.search(r'\d', r)])
tags = {r for r in refs if re.search(r'\d', r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def open_mfdataset(paths, chunks=None, concat_dim='_not_supplied',
paths = [str(p) if isinstance(p, Path) else p for p in paths]

if not paths:
raise IOError('no files to open')
raise OSError('no files to open')

# If combine='by_coords' then this is unnecessary, but quick.
# If combine='nested' then this creates a flat list which is easier to
Expand Down Expand Up @@ -1051,7 +1051,7 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None,
if groups is None:
groups = [None] * len(datasets)

if len(set([len(datasets), len(paths), len(groups)])) > 1:
if len({len(datasets), len(paths), len(groups)}) > 1:
raise ValueError('must supply lists of the same length for the '
'datasets, paths and groups arguments to '
'save_mfdataset')
Expand Down
12 changes: 7 additions & 5 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _netcdf4_create_group(dataset, name):


def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group):
if group in set([None, '', '/']):
if group in {None, '', '/'}:
# use the root group
return ds
else:
Expand All @@ -155,7 +155,7 @@ def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group):
ds = create_group(ds, key)
else:
# wrap error to provide slightly more helpful message
raise IOError('group not found: %s' % key, e)
raise OSError('group not found: %s' % key, e)
return ds


Expand Down Expand Up @@ -195,9 +195,11 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False,

encoding = variable.encoding.copy()

safe_to_drop = set(['source', 'original_shape'])
valid_encodings = set(['zlib', 'complevel', 'fletcher32', 'contiguous',
'chunksizes', 'shuffle', '_FillValue', 'dtype'])
safe_to_drop = {'source', 'original_shape'}
valid_encodings = {
'zlib', 'complevel', 'fletcher32', 'contiguous',
'chunksizes', 'shuffle', '_FillValue', 'dtype'
}
if lsd_okay:
valid_encodings.add('least_significant_digit')
if h5py_okay:
Expand Down
9 changes: 5 additions & 4 deletions xarray/backends/netcdf3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@

# The following are reserved names in CDL and may not be used as names of
# variables, dimension, attributes
_reserved_names = set(['byte', 'char', 'short', 'ushort', 'int', 'uint',
'int64', 'uint64', 'float' 'real', 'double', 'bool',
'string'])
_reserved_names = {
'byte', 'char', 'short', 'ushort', 'int', 'uint', 'int64', 'uint64',
'float' 'real', 'double', 'bool', 'string'
}

# These data-types aren't supported by netCDF3, so they are automatically
# coerced instead as indicated by the "coerce_nc3_dtype" function
Expand Down Expand Up @@ -108,4 +109,4 @@ def is_valid_nc3_name(s):
('/' not in s) and
(s[-1] != ' ') and
(_isalnumMUTF8(s[0]) or (s[0] == '_')) and
all((_isalnumMUTF8(c) or c in _specialchars for c in s)))
all(_isalnumMUTF8(c) or c in _specialchars for c in s))
14 changes: 7 additions & 7 deletions xarray/backends/pseudonetcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,18 @@ def get_variables(self):
for k, v in self.ds.variables.items())

def get_attrs(self):
return Frozen(dict([(k, getattr(self.ds, k))
for k in self.ds.ncattrs()]))
return Frozen({k: getattr(self.ds, k) for k in self.ds.ncattrs()})

def get_dimensions(self):
return Frozen(self.ds.dimensions)

def get_encoding(self):
encoding = {}
encoding['unlimited_dims'] = set(
[k for k in self.ds.dimensions
if self.ds.dimensions[k].isunlimited()])
return encoding
return {
'unlimited_dims': {
k for k in self.ds.dimensions
if self.ds.dimensions[k].isunlimited()
}
}

def close(self):
self._manager.close()
10 changes: 6 additions & 4 deletions xarray/backends/pynio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,12 @@ def get_dimensions(self):
return Frozen(self.ds.dimensions)

def get_encoding(self):
encoding = {}
encoding['unlimited_dims'] = set(
[k for k in self.ds.dimensions if self.ds.unlimited(k)])
return encoding
return {
'unlimited_dims': {
k for k in self.ds.dimensions
if self.ds.unlimited(k)
}
}

def close(self):
self._manager.close()
9 changes: 5 additions & 4 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,7 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key):
def _extract_zarr_variable_encoding(variable, raise_on_invalid=False):
encoding = variable.encoding.copy()

valid_encodings = set(['chunks', 'compressor', 'filters',
'cache_metadata'])
valid_encodings = {'chunks', 'compressor', 'filters', 'cache_metadata'}

if raise_on_invalid:
invalid = [k for k in encoding if k not in valid_encodings]
Expand Down Expand Up @@ -340,8 +339,10 @@ def store(self, variables, attributes, check_encoding_set=frozenset(),
only needed in append mode
"""

existing_variables = set([vn for vn in variables
if _encode_variable_name(vn) in self.ds])
existing_variables = {
vn for vn in variables
if _encode_variable_name(vn) in self.ds
}
new_variables = set(variables) - existing_variables
variables_without_encoding = OrderedDict([(vn, variables[vn])
for vn in new_variables])
Expand Down
2 changes: 1 addition & 1 deletion xarray/coding/cftime_offsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def __apply__(self, other):


_FREQUENCY_CONDITION = '|'.join(_FREQUENCIES.keys())
_PATTERN = r'^((?P<multiple>\d+)|())(?P<freq>({0}))$'.format(
_PATTERN = r'^((?P<multiple>\d+)|())(?P<freq>({}))$'.format(
_FREQUENCY_CONDITION)


Expand Down
2 changes: 1 addition & 1 deletion xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


# standard calendars recognized by cftime
_STANDARD_CALENDARS = set(['standard', 'gregorian', 'proleptic_gregorian'])
_STANDARD_CALENDARS = {'standard', 'gregorian', 'proleptic_gregorian'}

_NS_PER_TIME_DELTA = {'us': int(1e3),
'ms': int(1e6),
Expand Down
2 changes: 1 addition & 1 deletion xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def ensure_dtype_not_object(var, name=None):
if strings.is_bytes_dtype(inferred_dtype):
fill_value = b''
elif strings.is_unicode_dtype(inferred_dtype):
fill_value = u''
fill_value = ''
else:
# insist on using float for numeric values
if not np.issubdtype(inferred_dtype, np.floating):
Expand Down
6 changes: 3 additions & 3 deletions xarray/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def encode(var):
def _filter_attrs(attrs, ignored_attrs):
""" Return attrs that are not in ignored_attrs
"""
return dict((k, v) for k, v in attrs.items() if k not in ignored_attrs)
return {k: v for k, v in attrs.items() if k not in ignored_attrs}


def from_cdms2(variable):
Expand Down Expand Up @@ -119,7 +119,7 @@ def set_cdms2_attrs(var, attrs):
def _pick_attrs(attrs, keys):
""" Return attrs with keys in keys list
"""
return dict((k, v) for k, v in attrs.items() if k in keys)
return {k: v for k, v in attrs.items() if k in keys}


def _get_iris_args(attrs):
Expand Down Expand Up @@ -188,7 +188,7 @@ def _iris_obj_to_attrs(obj):
if obj.units.origin != '1' and not obj.units.is_unknown():
attrs['units'] = obj.units.origin
attrs.update(obj.attributes)
return dict((k, v) for k, v in attrs.items() if v is not None)
return {k: v for k, v in attrs.items() if v is not None}


def _iris_cell_methods_to_str(cell_methods_obj):
Expand Down
8 changes: 4 additions & 4 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,10 +341,10 @@ def reindex_variables(
for dim, indexer in indexers.items():
if isinstance(indexer, DataArray) and indexer.dims != (dim,):
warnings.warn(
"Indexer has dimensions {0:s} that are different "
"from that to be indexed along {1:s}. "
"This will behave differently in the future.".format(
str(indexer.dims), dim),
"Indexer has dimensions {:s} that are different "
"from that to be indexed along {:s}. "
"This will behave differently in the future."
.format(str(indexer.dims), dim),
FutureWarning, stacklevel=3)

target = new_indexes[dim] = utils.safe_cast_to_index(indexers[dim])
Expand Down
13 changes: 7 additions & 6 deletions xarray/core/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ def _infer_tile_ids_from_nested_list(entry, current_pos):

if isinstance(entry, list):
for i, item in enumerate(entry):
for result in _infer_tile_ids_from_nested_list(item,
current_pos + (i,)):
yield result
yield from _infer_tile_ids_from_nested_list(
item, current_pos + (i,))
else:
yield current_pos, entry

Expand Down Expand Up @@ -735,10 +734,12 @@ def _auto_concat(datasets, dim=None, data_vars='all', coords='different',
concat_dims = set(ds0.dims)
if ds0.dims != ds1.dims:
dim_tuples = set(ds0.dims.items()) - set(ds1.dims.items())
concat_dims = set(i for i, _ in dim_tuples)
concat_dims = {i for i, _ in dim_tuples}
if len(concat_dims) > 1:
concat_dims = set(d for d in concat_dims
if not ds0[d].equals(ds1[d]))
concat_dims = {
d for d in concat_dims
if not ds0[d].equals(ds1[d])
}
if len(concat_dims) > 1:
raise ValueError('too many different dimensions to '
'concatenate: %s' % concat_dims)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,7 +1389,7 @@ def expand_dims(self, dim: Union[None, Hashable, Sequence[Hashable],
elif isinstance(dim, Sequence) and not isinstance(dim, str):
if len(dim) != len(set(dim)):
raise ValueError('dims should not contain duplicate values.')
dim = OrderedDict(((d, 1) for d in dim))
dim = OrderedDict((d, 1) for d in dim)
elif dim is not None and not isinstance(dim, Mapping):
dim = OrderedDict(((cast(Hashable, dim), 1),))

Expand Down
Loading

0 comments on commit 8a9c471

Please sign in to comment.