diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 739456a1..fdb0bd8f 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -3193,15 +3193,14 @@ def _fusion_pass(expr): dependents[next._name] = set() expr_mapping[next._name] = next - for operand in next.operands: - if isinstance(operand, Expr): - stack.append(operand) - if is_valid_blockwise_op(operand): - if next._name in dependencies: - dependencies[next._name].add(operand._name) - dependents[operand._name].add(next._name) - expr_mapping[operand._name] = operand - expr_mapping[next._name] = next + for operand in next.dependencies(): + stack.append(operand) + if is_valid_blockwise_op(operand): + if next._name in dependencies: + dependencies[next._name].add(operand._name) + dependents[operand._name].add(next._name) + expr_mapping[operand._name] = operand + expr_mapping[next._name] = next # Traverse each "root" until we find a fusable sub-group. # Here we use root to refer to a Blockwise Expr node that diff --git a/dask_expr/_indexing.py b/dask_expr/_indexing.py index 2bb12a23..f1abfb9e 100644 --- a/dask_expr/_indexing.py +++ b/dask_expr/_indexing.py @@ -98,6 +98,8 @@ def _loc(self, iindexer, cindexer): elif is_series_like(iindexer) and not is_bool_dtype(iindexer.dtype): return new_collection(LocList(self.obj, iindexer.values, cindexer)) elif isinstance(iindexer, list) or is_arraylike(iindexer): + if len(iindexer) == 0: + return new_collection(LocEmpty(self.obj._meta, cindexer)) return new_collection(LocList(self.obj, iindexer, cindexer)) else: # element should raise KeyError @@ -250,6 +252,26 @@ def _layer(self) -> dict: return self._layer_information[0] +class LocEmpty(LocList): + _parameters = ["meta", "cindexer"] + + def _lower(self): + return None + + @functools.cached_property + def _meta(self): + if self.cindexer is None: + return self.operand("meta") + else: + return self.operand("meta").loc[:, self.cindexer] + + @functools.cached_property + def _layer_information(self): + divisions = [None, None] + dsk = {(self._name, 0): DataNode((self._name, 0), self._meta)} + return dsk, divisions + + class LocSlice(LocBase): @functools.cached_property def start(self): diff --git a/dask_expr/tests/test_indexing.py b/dask_expr/tests/test_indexing.py index b152fdf6..bf378945 100644 --- a/dask_expr/tests/test_indexing.py +++ b/dask_expr/tests/test_indexing.py @@ -35,15 +35,6 @@ def test_iloc_errors(df): df.iloc[(1, 2, 3)] -def test_loc(df, pdf): - assert_eq(df.loc[:, "x"], pdf.loc[:, "x"]) - assert_eq(df.loc[:, ["x"]], pdf.loc[:, ["x"]]) - assert_eq(df.loc[:, []], pdf.loc[:, []]) - - assert_eq(df.loc[df.y == 20, "x"], pdf.loc[pdf.y == 20, "x"]) - assert_eq(df.loc[df.y == 20, ["x"]], pdf.loc[pdf.y == 20, ["x"]]) - - def test_loc_slice(pdf, df): pdf.columns = [10, 20] df.columns = [10, 20] @@ -86,6 +77,12 @@ def test_columns_dtype_on_empty_slice(df, pdf, loc, update): def test_loc(df, pdf): + assert_eq(df.loc[:, "x"], pdf.loc[:, "x"]) + assert_eq(df.loc[:, ["x"]], pdf.loc[:, ["x"]]) + assert_eq(df.loc[:, []], pdf.loc[:, []]) + + assert_eq(df.loc[df.y == 20, "x"], pdf.loc[pdf.y == 20, "x"]) + assert_eq(df.loc[df.y == 20, ["x"]], pdf.loc[pdf.y == 20, ["x"]]) assert df.loc[3:8].divisions[0] == 3 assert df.loc[3:8].divisions[-1] == 8