From 74b41906d247887f5116cc48398ce62e62cdf80e Mon Sep 17 00:00:00 2001
From: Guido Imperiale <guido.imperiale@amphorainc.com>
Date: Fri, 2 Aug 2019 15:47:28 +0100
Subject: [PATCH 1/8] Annotations for Dataset.drop et al

---
 xarray/core/dataset.py         | 83 ++++++++++++++++++++++++----------
 xarray/tests/test_dataarray.py |  4 +-
 2 files changed, 61 insertions(+), 26 deletions(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index b00dad965ed..6f242da394c 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -315,10 +315,10 @@ class _LocIndexer:
     def __init__(self, dataset: 'Dataset'):
         self.dataset = dataset
 
-    def __getitem__(self, key: Mapping[str, Any]) -> 'Dataset':
+    def __getitem__(self, key: Mapping[Hashable, Any]) -> 'Dataset':
         if not utils.is_dict_like(key):
             raise TypeError('can only lookup dictionaries from Dataset.loc')
-        return self.dataset.sel(**key)
+        return self.dataset.sel(key)
 
 
 class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
@@ -3261,7 +3261,8 @@ def merge(
         return self._replace_vars_and_dims(variables, coord_names, dims,
                                            inplace=inplace)
 
-    def _assert_all_in_dataset(self, names, virtual_okay=False):
+    def _assert_all_in_dataset(self, names: Iterable[Hashable],
+                               virtual_okay: bool = False) -> None:
         bad_names = set(names) - set(self._variables)
         if virtual_okay:
             bad_names -= self.virtual_variables
@@ -3269,14 +3270,20 @@ def _assert_all_in_dataset(self, names, virtual_okay=False):
             raise ValueError('One or more of the specified variables '
                              'cannot be found in this dataset')
 
-    def drop(self, labels, dim=None, *, errors='raise'):
+    def drop(
+        self,
+        labels: Union[Hashable, Iterable[Hashable]],
+        dim: Hashable = None,
+        *,
+        errors: str = 'raise'
+    ) -> 'Dataset':
         """Drop variables or index labels from this dataset.
 
         Parameters
         ----------
-        labels : scalar or list of scalars
-            Name(s) of variables or index labels to drop.
-        dim : None or str, optional
+        labels : hashable or iterable of hashables
+            Name(s) of variables or index labels to drop
+        dim : None or hashable, optional
             Dimension along which to drop index labels. By default (if
             ``dim is None``), drops variables rather than index labels.
         errors: {'raise', 'ignore'}, optional
@@ -3291,11 +3298,18 @@ def drop(self, labels, dim=None, *, errors='raise'):
         """
         if errors not in ['raise', 'ignore']:
             raise ValueError('errors must be either "raise" or "ignore"')
-        if utils.is_scalar(labels):
-            labels = [labels]
+
         if dim is None:
+            if isinstance(labels, str) or not isinstance(labels, Iterable):
+                labels = {labels}
+            else:
+                labels = set(labels)
+
             return self._drop_vars(labels, errors=errors)
         else:
+            if utils.is_scalar(labels):
+                labels = [labels]
+
             try:
                 index = self.indexes[dim]
             except KeyError:
@@ -3304,25 +3318,38 @@ def drop(self, labels, dim=None, *, errors='raise'):
             new_index = index.drop(labels, errors=errors)
             return self.loc[{dim: new_index}]
 
-    def _drop_vars(self, names, errors='raise'):
+    def _drop_vars(
+        self,
+        names: set,
+        errors: str = 'raise'
+    ) -> 'Dataset':
         if errors == 'raise':
             self._assert_all_in_dataset(names)
-        drop = set(names)
+
         variables = OrderedDict((k, v) for k, v in self._variables.items()
-                                if k not in drop)
+                                if k not in names)
         coord_names = set(k for k in self._coord_names if k in variables)
         indexes = OrderedDict((k, v) for k, v in self.indexes.items()
-                              if k not in drop)
+                              if k not in names)
         return self._replace_with_new_dims(
             variables, coord_names=coord_names, indexes=indexes)
 
-    def drop_dims(self, drop_dims, *, errors='raise'):
+    def drop_dims(
+        self,
+        drop_dims: Union[Hashable, Iterable[Hashable]],
+        *,
+        errors: str = 'raise'
+    ) -> 'Dataset':
         """Drop dimensions and associated variables from this dataset.
 
         Parameters
         ----------
         drop_dims : str or list
             Dimension or dimensions to drop.
+        errors: {'raise', 'ignore'}, optional
+            If 'raise' (default), raises a ValueError error if any of the
+            dimensions passed are not in the dataset. If 'ignore', any given
+            labels that are in the dataset are dropped and no error is raised.
 
         Returns
         -------
@@ -3338,8 +3365,10 @@ def drop_dims(self, drop_dims, *, errors='raise'):
         if errors not in ['raise', 'ignore']:
             raise ValueError('errors must be either "raise" or "ignore"')
 
-        if utils.is_scalar(drop_dims):
+        if isinstance(drop_dims, str) or not isinstance(drop_dims, Iterable):
             drop_dims = [drop_dims]
+        else:
+            drop_dims = list(drop_dims)
 
         if errors == 'raise':
             missing_dimensions = [d for d in drop_dims if d not in self.dims]
@@ -3351,7 +3380,7 @@ def drop_dims(self, drop_dims, *, errors='raise'):
                         for d in v.dims if d in drop_dims)
         return self._drop_vars(drop_vars)
 
-    def transpose(self, *dims):
+    def transpose(self, *dims: Hashable) -> 'Dataset':
         """Return a new Dataset object with all array dimensions transposed.
 
         Although the order of dimensions on each array will change, the dataset
@@ -3359,7 +3388,7 @@ def transpose(self, *dims):
 
         Parameters
         ----------
-        *dims : str, optional
+        *dims : Hashable, optional
             By default, reverse the dimensions on each array. Otherwise,
             reorder the dimensions to this order.
 
@@ -3391,13 +3420,19 @@ def transpose(self, *dims):
             ds._variables[name] = var.transpose(*var_dims)
         return ds
 
-    def dropna(self, dim, how='any', thresh=None, subset=None):
+    def dropna(
+        self,
+        dim: Hashable,
+        how: str = 'any',
+        thresh: int = None,
+        subset: Iterable[Hashable] = None
+    ):
         """Returns a new dataset with dropped labels for missing values along
         the provided dimension.
 
         Parameters
         ----------
-        dim : str
+        dim : Hashable
             Dimension along which to drop missing values. Dropping along
             multiple dimensions simultaneously is not yet supported.
         how : {'any', 'all'}, optional
@@ -3405,8 +3440,8 @@ def dropna(self, dim, how='any', thresh=None, subset=None):
             * all : if all values are NA, drop that label
         thresh : int, default None
             If supplied, require this many non-NA values.
-        subset : sequence, optional
-            Subset of variables to check for missing values. By default, all
+        subset : iterable of hashable, optional
+            Which variables to check for missing values. By default, all
             variables in the dataset are checked.
 
         Returns
@@ -3421,7 +3456,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None):
             raise ValueError('%s must be a single dataset dimension' % dim)
 
         if subset is None:
-            subset = list(self.data_vars)
+            subset = iter(self.data_vars)
 
         count = np.zeros(self.dims[dim], dtype=np.int64)
         size = 0
@@ -3430,7 +3465,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None):
             array = self._variables[k]
             if dim in array.dims:
                 dims = [d for d in array.dims if d != dim]
-                count += np.asarray(array.count(dims))
+                count += np.asarray(array.count(dims))  # type: ignore
                 size += np.prod([self.dims[d] for d in dims])
 
         if thresh is not None:
@@ -3446,7 +3481,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None):
 
         return self.isel({dim: mask})
 
-    def fillna(self, value):
+    def fillna(self, value: Any) -> 'Dataset':
         """Fill missing values in this object.
 
         This operation follows the normal broadcasting and alignment rules that
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 5697704bdbc..000469f24bf 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -1904,9 +1904,9 @@ def test_drop_coordinates(self):
         assert_identical(actual, expected)
 
         with raises_regex(ValueError, 'cannot be found'):
-            arr.drop(None)
+            arr.drop('w')
 
-        actual = expected.drop(None, errors='ignore')
+        actual = expected.drop('w', errors='ignore')
         assert_identical(actual, expected)
 
         renamed = arr.rename('foo')

From 3eaa6b2d74a3aa7968e73b783827361add3c44b8 Mon Sep 17 00:00:00 2001
From: Guido Imperiale <guido.imperiale@amphorainc.com>
Date: Fri, 2 Aug 2019 16:49:38 +0100
Subject: [PATCH 2/8] Annotations for Dataset.interpolate et al

---
 xarray/core/common.py  |  9 +++--
 xarray/core/dataset.py | 92 +++++++++++++++++++++++++++---------------
 2 files changed, 66 insertions(+), 35 deletions(-)

diff --git a/xarray/core/common.py b/xarray/core/common.py
index bae3b6cd73d..93a5bb71b07 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -20,6 +20,7 @@
 ALL_DIMS = ReprObject('<all-dims>')
 
 
+C = TypeVar('C')
 T = TypeVar('T')
 
 
@@ -297,9 +298,11 @@ def get_index(self, key: Hashable) -> pd.Index:
             # need to ensure dtype=int64 in case range is empty on Python 2
             return pd.Index(range(self.sizes[key]), name=key, dtype=np.int64)
 
-    def _calc_assign_results(self, kwargs: Mapping[str, T]
-                             ) -> MutableMapping[str, T]:
-        results = SortedKeysDict()  # type: SortedKeysDict[str, T]
+    def _calc_assign_results(
+            self: C,
+            kwargs: Mapping[Hashable, Union[T, Callable[[C], T]]]
+    ) -> MutableMapping[Hashable, T]:
+        results = SortedKeysDict()  # type: SortedKeysDict[Hashable, T]
         for k, v in kwargs.items():
             if callable(v):
                 results[k] = v(self)
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 6f242da394c..8b34be6b527 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -6,9 +6,9 @@
 from distutils.version import LooseVersion
 from numbers import Number
 from pathlib import Path
-from typing import (Any, DefaultDict, Dict, Hashable, Iterable, Iterator, List,
-                    Mapping, MutableMapping, Optional, Sequence, Set, Tuple,
-                    Union, cast, TYPE_CHECKING)
+from typing import (Any, Callable, DefaultDict, Dict, Hashable, Iterable,
+                    Iterator, List, Mapping, MutableMapping, Optional,
+                    Sequence, Set, Tuple, Union, cast, TYPE_CHECKING)
 
 import numpy as np
 import pandas as pd
@@ -792,7 +792,7 @@ def _replace_with_new_dims(  # type: ignore
         self,
         variables: 'OrderedDict[Any, Variable]',
         coord_names: set = None,
-        attrs: 'OrderedDict' = __default,
+        attrs: Optional['OrderedDict'] = __default,
         indexes: 'OrderedDict[Any, pd.Index]' = __default,
         inplace: bool = False,
     ) -> 'Dataset':
@@ -3510,14 +3510,19 @@ def fillna(self, value: Any) -> 'Dataset':
         out = ops.fillna(self, value)
         return out
 
-    def interpolate_na(self, dim=None, method='linear', limit=None,
-                       use_coordinate=True,
-                       **kwargs):
+    def interpolate_na(
+        self,
+        dim: Hashable = None,
+        method: str = 'linear',
+        limit: int = None,
+        use_coordinate: Union[bool, Hashable] = True,
+        **kwargs: Any
+    ) -> 'Dataset':
         """Interpolate values according to different methods.
 
         Parameters
         ----------
-        dim : str
+        dim : Hashable
             Specifies the dimension along which to interpolate.
         method : {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
                   'polynomial', 'barycentric', 'krog', 'pchip',
@@ -3541,6 +3546,8 @@ def interpolate_na(self, dim=None, method='linear', limit=None,
         limit : int, default None
             Maximum number of consecutive NaNs to fill. Must be greater than 0
             or None for no limit.
+        kwargs : any
+            parameters passed verbatim to the underlying interplation function
 
         Returns
         -------
@@ -3559,14 +3566,14 @@ def interpolate_na(self, dim=None, method='linear', limit=None,
                                         **kwargs)
         return new
 
-    def ffill(self, dim, limit=None):
-        '''Fill NaN values by propogating values forward
+    def ffill(self, dim: Hashable, limit: int = None) -> 'Dataset':
+        """Fill NaN values by propogating values forward
 
         *Requires bottleneck.*
 
         Parameters
         ----------
-        dim : str
+        dim : Hashable
             Specifies the dimension along which to propagate values when
             filling.
         limit : int, default None
@@ -3578,14 +3585,14 @@ def ffill(self, dim, limit=None):
         Returns
         -------
         Dataset
-        '''
+        """
         from .missing import ffill, _apply_over_vars_with_dim
 
         new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit)
         return new
 
-    def bfill(self, dim, limit=None):
-        '''Fill NaN values by propogating values backward
+    def bfill(self, dim: Hashable, limit: int = None) -> 'Dataset':
+        """Fill NaN values by propogating values backward
 
         *Requires bottleneck.*
 
@@ -3603,13 +3610,13 @@ def bfill(self, dim, limit=None):
         Returns
         -------
         Dataset
-        '''
+        """
         from .missing import bfill, _apply_over_vars_with_dim
 
         new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit)
         return new
 
-    def combine_first(self, other):
+    def combine_first(self, other: 'Dataset') -> 'Dataset':
         """Combine two Datasets, default to data_vars of self.
 
         The new coordinates follow the normal broadcasting and alignment rules
@@ -3618,7 +3625,7 @@ def combine_first(self, other):
 
         Parameters
         ----------
-        other : DataArray
+        other : Dataset
             Used to fill all matching missing values in this array.
 
         Returns
@@ -3628,13 +3635,21 @@ def combine_first(self, other):
         out = ops.fillna(self, other, join="outer", dataset_join="outer")
         return out
 
-    def reduce(self, func, dim=None, keep_attrs=None, keepdims=False,
-               numeric_only=False, allow_lazy=False, **kwargs):
+    def reduce(
+        self,
+        func: Callable,
+        dim: Union[Hashable, Iterable[Hashable]] = None,
+        keep_attrs: bool = None,
+        keepdims: bool = False,
+        numeric_only: bool = False,
+        allow_lazy: bool = False,
+        **kwargs: Any
+    ) -> 'Dataset':
         """Reduce this dataset by applying `func` along some dimension(s).
 
         Parameters
         ----------
-        func : function
+        func : callable
             Function which can be called in the form
             `f(x, axis=axis, **kwargs)` to return the result of reducing an
             np.ndarray over an integer valued axis.
@@ -3651,7 +3666,7 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False,
             are removed.
         numeric_only : bool, optional
             If True, only apply ``func`` to variables with a numeric dtype.
-        **kwargs : dict
+        **kwargs : Any
             Additional keyword arguments passed on to ``func``.
 
         Returns
@@ -3662,10 +3677,10 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False,
         """
         if dim is ALL_DIMS:
             dim = None
-        if isinstance(dim, str):
-            dims = set([dim])
-        elif dim is None:
+        if dim is None:
             dims = set(self.dims)
+        elif isinstance(dim, str) or not isinstance(dim, Iterable):
+            dims = {dim}
         else:
             dims = set(dim)
 
@@ -3677,9 +3692,12 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False,
         if keep_attrs is None:
             keep_attrs = _get_keep_attrs(default=False)
 
-        variables = OrderedDict()
+        variables = OrderedDict()  # type: OrderedDict[Hashable, Variable]
         for name, var in self._variables.items():
-            reduce_dims = [d for d in var.dims if d in dims]
+            reduce_dims = [
+                d for d in var.dims
+                if d in dims
+            ]
             if name in self.coords:
                 if not reduce_dims:
                     variables[name] = var
@@ -3695,7 +3713,7 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False,
                         # prefer to aggregate over axis=None rather than
                         # axis=(0, 1) if they will be equivalent, because
                         # the former is often more efficient
-                        reduce_dims = None
+                        reduce_dims = None  # type: ignore
                     variables[name] = var.reduce(func, dim=reduce_dims,
                                                  keep_attrs=keep_attrs,
                                                  keepdims=keepdims,
@@ -3709,12 +3727,18 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False,
         return self._replace_with_new_dims(
             variables, coord_names=coord_names, attrs=attrs, indexes=indexes)
 
-    def apply(self, func, keep_attrs=None, args=(), **kwargs):
+    def apply(
+        self,
+        func: Callable,
+        keep_attrs: bool = None,
+        args: Iterable[Any] = (),
+        **kwargs: Any
+    ) -> 'Dataset':
         """Apply a function over the data variables in this dataset.
 
         Parameters
         ----------
-        func : function
+        func : callable
             Function which can be called in the form `func(x, *args, **kwargs)`
             to transform each DataArray `x` in this dataset into another
             DataArray.
@@ -3724,7 +3748,7 @@ def apply(self, func, keep_attrs=None, args=(), **kwargs):
             be returned without attributes.
         args : tuple, optional
             Positional arguments passed on to `func`.
-        **kwargs : dict
+        **kwargs : Any
             Keyword arguments passed on to `func`.
 
         Returns
@@ -3759,7 +3783,11 @@ def apply(self, func, keep_attrs=None, args=(), **kwargs):
         attrs = self.attrs if keep_attrs else None
         return type(self)(variables, attrs=attrs)
 
-    def assign(self, variables=None, **variables_kwargs):
+    def assign(
+        self,
+        variables: Mapping[Hashable, Any] = None,
+        **variables_kwargs: Hashable
+    ) -> 'Dataset':
         """Assign new data variables to a Dataset, returning a new object
         with all the original variables in addition to the new ones.
 
@@ -3772,7 +3800,7 @@ def assign(self, variables=None, **variables_kwargs):
             scalar, or array), they are simply assigned.
         **variables_kwargs:
             The keyword arguments form of ``variables``.
-            One of variables or variables_kwarg must be provided.
+            One of variables or variables_kwargs must be provided.
 
         Returns
         -------

From d1fc70c34eeb625d15f13bfc535abfb4bc433bb8 Mon Sep 17 00:00:00 2001
From: Guido Imperiale <guido.imperiale@amphorainc.com>
Date: Fri, 2 Aug 2019 17:30:08 +0100
Subject: [PATCH 3/8] Dataset.drop(DataArray)

---
 xarray/core/dataset.py       | 12 ++++++++++--
 xarray/tests/test_dataset.py |  6 ++++++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 8b34be6b527..f91a84b6578 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -3272,7 +3272,7 @@ def _assert_all_in_dataset(self, names: Iterable[Hashable],
 
     def drop(
         self,
-        labels: Union[Hashable, Iterable[Hashable]],
+        labels: Union[Hashable, Iterable[Hashable], DataArray],
         dim: Hashable = None,
         *,
         errors: str = 'raise'
@@ -3282,7 +3282,9 @@ def drop(
         Parameters
         ----------
         labels : hashable or iterable of hashables
-            Name(s) of variables or index labels to drop
+            Name(s) of variables or index labels to drop.
+            If dim is not None, it can be anything that can be cast to
+            a numpy array (e.g. a DataArray).
         dim : None or hashable, optional
             Dimension along which to drop index labels. By default (if
             ``dim is None``), drops variables rather than index labels.
@@ -3302,13 +3304,19 @@ def drop(
         if dim is None:
             if isinstance(labels, str) or not isinstance(labels, Iterable):
                 labels = {labels}
+            elif isinstance(labels, DataArray):
+                raise ValueError(
+                    "DataArray labels are only supported when dropping indices")
             else:
                 labels = set(labels)
 
             return self._drop_vars(labels, errors=errors)
         else:
+            # Don't cast to set, as it would harm performance when labels
+            # is a large numpy array
             if utils.is_scalar(labels):
                 labels = [labels]
+            labels = np.asarray(labels)
 
             try:
                 index = self.indexes[dim]
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index fc6f7f36938..fc15393f269 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -2000,6 +2000,12 @@ def test_drop_index_labels(self):
         expected = data.isel(x=slice(0, 0))
         assert_identical(expected, actual)
 
+        # DataArrays as labels are a nasty corner case as they are not
+        # Iterable[Hashable] - DataArray.__iter__ yields scalar DataArrays.
+        actual = data.drop(DataArray(['a', 'b', 'c']), 'x', errors='ignore')
+        expected = data.isel(x=slice(0, 0))
+        assert_identical(expected, actual)
+
         with raises_regex(
                 ValueError, 'does not have coordinate labels'):
             data.drop(1, 'y')

From 66cf5921e4766a5fa671229a4d50f65865ea72a5 Mon Sep 17 00:00:00 2001
From: Guido Imperiale <guido.imperiale@amphorainc.com>
Date: Fri, 2 Aug 2019 17:31:16 +0100
Subject: [PATCH 4/8] flake8

---
 xarray/core/dataset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index f91a84b6578..22b7500e580 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -3305,8 +3305,8 @@ def drop(
             if isinstance(labels, str) or not isinstance(labels, Iterable):
                 labels = {labels}
             elif isinstance(labels, DataArray):
-                raise ValueError(
-                    "DataArray labels are only supported when dropping indices")
+                raise ValueError("DataArray labels are only supported when "
+                                 "dropping indices")
             else:
                 labels = set(labels)
 

From 197ddc2062418f47ce4983c831cb5124b49bd7be Mon Sep 17 00:00:00 2001
From: Guido Imperiale <guido.imperiale@amphorainc.com>
Date: Fri, 2 Aug 2019 17:38:35 +0100
Subject: [PATCH 5/8] trivial

---
 xarray/core/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 22b7500e580..82bfd4a7764 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -3272,7 +3272,7 @@ def _assert_all_in_dataset(self, names: Iterable[Hashable],
 
     def drop(
         self,
-        labels: Union[Hashable, Iterable[Hashable], DataArray],
+        labels: Union[Hashable, Iterable[Hashable], 'DataArray'],
         dim: Hashable = None,
         *,
         errors: str = 'raise'

From 355c3901cad6d5964b7a7684041d44890c4c5ff4 Mon Sep 17 00:00:00 2001
From: Guido Imperiale <guido.imperiale@amphorainc.com>
Date: Fri, 2 Aug 2019 17:52:18 +0100
Subject: [PATCH 6/8] @overload Dataset.drop

---
 xarray/core/dataarray.py | 30 ++++++++++++++++++++++++------
 xarray/core/dataset.py   | 25 +++++++++++++++++++------
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 0e28613323e..c96eaef51ee 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -4,7 +4,8 @@
 from collections import OrderedDict
 from numbers import Number
 from typing import (Any, Callable, Dict, Hashable, Iterable, List, Mapping,
-                    Optional, Sequence, Tuple, Union, cast, TYPE_CHECKING)
+                    Optional, Sequence, Tuple, Union, cast, overload,
+                    TYPE_CHECKING)
 
 import numpy as np
 import pandas as pd
@@ -1752,11 +1753,28 @@ def transpose(self,
     def T(self) -> 'DataArray':
         return self.transpose()
 
-    def drop(self,
-             labels: Union[Hashable, Sequence[Hashable]],
-             dim: Hashable = None,
-             *,
-             errors: str = 'raise') -> 'DataArray':
+    # Drop coords
+    @overload
+    def drop(
+        self,
+        labels: Union[Hashable, Iterable[Hashable]],
+        *,
+        errors: str = 'raise'
+    ) -> 'DataArray':
+        ...
+
+    # Drop index labels along dimension
+    @overload  # noqa: F811
+    def drop(
+        self,
+        labels: Any,  # array-like
+        dim: Hashable,
+        *,
+        errors: str = 'raise'
+    ) -> 'DataArray':
+        ...
+
+    def drop(self, labels, dim=None, *, errors='raise'):  # noqa: F811
         """Drop coordinates or index labels from this DataArray.
 
         Parameters
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 82bfd4a7764..f90f9946cce 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -8,7 +8,8 @@
 from pathlib import Path
 from typing import (Any, Callable, DefaultDict, Dict, Hashable, Iterable,
                     Iterator, List, Mapping, MutableMapping, Optional,
-                    Sequence, Set, Tuple, Union, cast, TYPE_CHECKING)
+                    Sequence, Set, Tuple, Union, cast, overload,
+                    TYPE_CHECKING)
 
 import numpy as np
 import pandas as pd
@@ -3270,13 +3271,28 @@ def _assert_all_in_dataset(self, names: Iterable[Hashable],
             raise ValueError('One or more of the specified variables '
                              'cannot be found in this dataset')
 
+    # Drop variables
+    @overload
     def drop(
         self,
-        labels: Union[Hashable, Iterable[Hashable], 'DataArray'],
-        dim: Hashable = None,
+        labels: Union[Hashable, Iterable[Hashable]],
         *,
         errors: str = 'raise'
     ) -> 'Dataset':
+        ...
+
+    # Drop index labels along dimension
+    @overload  # noqa: F811
+    def drop(
+        self,
+        labels: Any,  # array-like
+        dim: Hashable,
+        *,
+        errors: str = 'raise'
+    ) -> 'Dataset':
+        ...
+
+    def drop(self, labels, dim=None, *, errors='raise'):  # noqa: F811
         """Drop variables or index labels from this dataset.
 
         Parameters
@@ -3304,9 +3320,6 @@ def drop(
         if dim is None:
             if isinstance(labels, str) or not isinstance(labels, Iterable):
                 labels = {labels}
-            elif isinstance(labels, DataArray):
-                raise ValueError("DataArray labels are only supported when "
-                                 "dropping indices")
             else:
                 labels = set(labels)
 

From fdc205bd4d629409e3b32b8f1ccf2553f73d7dfd Mon Sep 17 00:00:00 2001
From: Guido Imperiale <guido.imperiale@amphorainc.com>
Date: Fri, 2 Aug 2019 17:55:20 +0100
Subject: [PATCH 7/8] docstring tweaks

---
 xarray/core/dataarray.py | 3 ++-
 xarray/core/dataset.py   | 3 +--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index c96eaef51ee..3b44e51d0f5 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -1780,7 +1780,8 @@ def drop(self, labels, dim=None, *, errors='raise'):  # noqa: F811
         Parameters
         ----------
         labels : hashable or sequence of hashables
-            Name(s) of coordinate variables or index labels to drop.
+            Name(s) of coordinates or index labels to drop.
+            If dim is not None, labels can be any array-like.
         dim : hashable, optional
             Dimension along which to drop index labels. By default (if
             ``dim is None``), drops coordinates rather than index labels.
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index f90f9946cce..5d3ca932ccc 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -3299,8 +3299,7 @@ def drop(self, labels, dim=None, *, errors='raise'):  # noqa: F811
         ----------
         labels : hashable or iterable of hashables
             Name(s) of variables or index labels to drop.
-            If dim is not None, it can be anything that can be cast to
-            a numpy array (e.g. a DataArray).
+            If dim is not None, labels can be any array-like.
         dim : None or hashable, optional
             Dimension along which to drop index labels. By default (if
             ``dim is None``), drops variables rather than index labels.

From 816e13cb273d35739d85e14a6708e5248a7c064b Mon Sep 17 00:00:00 2001
From: Guido Imperiale <crusaderky@gmail.com>
Date: Mon, 5 Aug 2019 23:11:25 +0100
Subject: [PATCH 8/8] Clean up redundant code

---
 xarray/core/dataarray.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 3b44e51d0f5..40966f684a2 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -1794,8 +1794,6 @@ def drop(self, labels, dim=None, *, errors='raise'):  # noqa: F811
         -------
         dropped : DataArray
         """
-        if utils.is_scalar(labels):
-            labels = [labels]
         ds = self._to_temp_dataset().drop(labels, dim, errors=errors)
         return self._from_temp_dataset(ds)