diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py
index ff99b8a4..2f246f21 100644
--- a/dask_expr/_collection.py
+++ b/dask_expr/_collection.py
@@ -738,6 +738,13 @@ def isin(self, values):
         else:
             bad_types = (FrameBase,)
         if isinstance(values, bad_types):
+            if (
+                isinstance(values, FrameBase)
+                and values.ndim == 1
+                and values.npartitions == 1
+            ):
+                # Can broadcast
+                return new_collection(expr.Isin(self, values=values))
             raise NotImplementedError("Passing a %r to `isin`" % typename(type(values)))
 
         # We wrap values in a delayed for two reasons:
diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index 6a2f1fa3..6a6f7006 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -1362,6 +1362,9 @@ class Isin(Elemwise):
     def _meta(self):
         return make_meta(meta_nonempty(self.frame._meta).isin([1]))
 
+    def _broadcast_dep(self, dep: Expr):
+        return dep.npartitions == 1
+
 
 class Clip(Elemwise):
     _projection_passthrough = True
@@ -3044,12 +3047,12 @@ def are_co_aligned(*exprs):
             # Scalars are valid ancestors that are always broadcastable,
             # so don't walk through them
             continue
+        elif isinstance(e, (_DelayedExpr, Isin)):
+            continue
         elif isinstance(e, (Blockwise, CumulativeAggregations, Reduction)):
             # TODO: Capture this in inheritance logic
             dependencies = e.dependencies()
             stack.extend(dependencies)
-        elif isinstance(e, _DelayedExpr):
-            continue
         else:
             ancestors.append(e)
 
diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py
index 11abe032..7c0b0510 100644
--- a/dask_expr/_groupby.py
+++ b/dask_expr/_groupby.py
@@ -85,6 +85,8 @@ def _as_dict(key, value):
 
 
 def _adjust_split_out_for_group_keys(npartitions, by):
+    if len(by) == 1:
+        return math.ceil(npartitions / 15)
     return math.ceil(npartitions / (10 / (len(by) - 1)))
 
 
@@ -222,7 +224,7 @@ def _projection_columns(self):
         return self.frame.columns
 
     def _tune_down(self):
-        if len(self.by) > 1 and self.operand("split_out") is None:
+        if self.operand("split_out") is None:
             return self.substitute_parameters(
                 {
                     "split_out": functools.partial(
@@ -674,7 +676,7 @@ class GroupByReduction(Reduction, GroupByBase):
     _chunk_cls = GroupByChunk
 
     def _tune_down(self):
-        if len(self.by) > 1 and self.operand("split_out") is None:
+        if self.operand("split_out") is None:
             return self.substitute_parameters(
                 {
                     "split_out": functools.partial(
diff --git a/dask_expr/_shuffle.py b/dask_expr/_shuffle.py
index 083f580a..97936b7f 100644
--- a/dask_expr/_shuffle.py
+++ b/dask_expr/_shuffle.py
@@ -113,7 +113,9 @@ def _simplify_up(self, parent, dependents):
         if isinstance(parent, Projection):
             # Move the column projection to come
             # before the abstract Shuffle
-            projection = determine_column_projection(self, parent, dependents)
+            projection = _convert_to_list(
+                determine_column_projection(self, parent, dependents)
+            )
             partitioning_index = self._partitioning_index
             target = self.frame
 
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index d8bf915f..e89a605d 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -821,6 +821,8 @@ def sample_statistics(self, n=3):
         ixs = []
         for i in range(0, nfrags, stepsize):
             sort_ix = finfo_argsort[i]
+            # TODO: This is crude but the most conservative estimate
+            sort_ix = sort_ix if sort_ix < nfrags else 0
             ixs.append(sort_ix)
             finfos_sampled.append(finfos[sort_ix])
             frags_samples.append(frags[sort_ix])
@@ -1001,17 +1003,17 @@ def fragments_unsorted(self):
 
     @property
     def _fusion_compression_factor(self):
-        if self.operand("columns") is None:
-            return 1
         approx_stats = self.approx_statistics()
         total_uncompressed = 0
         after_projection = 0
-        col_op = self.operand("columns")
+        col_op = self.operand("columns") or self.columns
         for col in approx_stats["columns"]:
             total_uncompressed += col["total_uncompressed_size"]
             if col["path_in_schema"] in col_op:
                 after_projection += col["total_uncompressed_size"]
 
+        min_size = dask.config.get("dataframe.parquet.minimum-partition-size")
+        total_uncompressed = max(total_uncompressed, min_size)
         return max(after_projection / total_uncompressed, 0.001)
 
     def _filtered_task(self, index: int):
diff --git a/dask_expr/io/tests/test_parquet.py b/dask_expr/io/tests/test_parquet.py
index 61f619d0..ca12f464 100644
--- a/dask_expr/io/tests/test_parquet.py
+++ b/dask_expr/io/tests/test_parquet.py
@@ -1,6 +1,7 @@
 import os
 import pickle
 
+import dask
 import pandas as pd
 import pytest
 from dask.dataframe.utils import assert_eq
@@ -165,32 +166,33 @@ def test_pyarrow_filesystem_list_of_files(parquet_file, second_parquet_file):
 
 
 def test_partition_pruning(tmpdir):
-    filesystem = fs.LocalFileSystem()
-    df = from_pandas(
-        pd.DataFrame(
-            {
-                "a": [1, 2, 3, 4, 5] * 10,
-                "b": range(50),
-            }
-        ),
-        npartitions=2,
-    )
-    df.to_parquet(tmpdir, partition_on=["a"])
-    ddf = read_parquet(tmpdir, filesystem=filesystem)
-    ddf_filtered = read_parquet(
-        tmpdir, filters=[[("a", "==", 1)]], filesystem=filesystem
-    )
-    assert ddf_filtered.npartitions == ddf.npartitions // 5
-
-    ddf_optimize = read_parquet(tmpdir, filesystem=filesystem)
-    ddf_optimize = ddf_optimize[ddf_optimize.a == 1].optimize()
-    assert ddf_optimize.npartitions == ddf.npartitions // 5
-    assert_eq(
-        ddf_filtered,
-        ddf_optimize,
-        # FIXME ?
-        check_names=False,
-    )
+    with dask.config.set({"dataframe.parquet.minimum-partition-size": 1}):
+        filesystem = fs.LocalFileSystem()
+        df = from_pandas(
+            pd.DataFrame(
+                {
+                    "a": [1, 2, 3, 4, 5] * 10,
+                    "b": range(50),
+                }
+            ),
+            npartitions=2,
+        )
+        df.to_parquet(tmpdir, partition_on=["a"])
+        ddf = read_parquet(tmpdir, filesystem=filesystem)
+        ddf_filtered = read_parquet(
+            tmpdir, filters=[[("a", "==", 1)]], filesystem=filesystem
+        )
+        assert ddf_filtered.npartitions == ddf.npartitions // 5
+
+        ddf_optimize = read_parquet(tmpdir, filesystem=filesystem)
+        ddf_optimize = ddf_optimize[ddf_optimize.a == 1].optimize()
+        assert ddf_optimize.npartitions == ddf.npartitions // 5
+        assert_eq(
+            ddf_filtered,
+            ddf_optimize,
+            # FIXME ?
+            check_names=False,
+        )
 
 
 def test_predicate_pushdown(tmpdir):
diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py
index 23a86022..0b6cc1f5 100644
--- a/dask_expr/tests/test_groupby.py
+++ b/dask_expr/tests/test_groupby.py
@@ -319,12 +319,12 @@ def test_groupby_agg_column_projection(pdf, df):
 
 def test_groupby_split_every(pdf):
     df = from_pandas(pdf, npartitions=16)
-    query = df.groupby("x").sum()
+    query = df.groupby("x").sum(split_out=1)
     tree_reduce_node = list(query.optimize(fuse=False).find_operations(TreeReduce))
     assert len(tree_reduce_node) == 1
     assert tree_reduce_node[0].split_every == 8
 
-    query = df.groupby("x").aggregate({"y": "sum"})
+    query = df.groupby("x").aggregate({"y": "sum"}, split_out=1)
     tree_reduce_node = list(query.optimize(fuse=False).find_operations(TreeReduce))
     assert len(tree_reduce_node) == 1
     assert tree_reduce_node[0].split_every == 8
@@ -352,7 +352,7 @@ def test_split_out_automatically():
     pdf = pd.DataFrame({"a": [1, 2, 3] * 1_000, "b": 1, "c": 1, "d": 1})
     df = from_pandas(pdf, npartitions=500)
    q = df.groupby("a").sum()
-    assert q.optimize().npartitions == 1
+    assert q.optimize().npartitions == 34
     expected = pdf.groupby("a").sum()
     assert_eq(q, expected)
 
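A rough usage sketch of the user-facing behaviour this diff touches, assuming a local dask_expr install; the frames, values, and partition counts below are illustrative only and not taken from the test suite:

    import dask
    import pandas as pd

    from dask_expr import from_pandas

    pdf = pd.DataFrame({"x": range(20), "y": range(20)})
    df = from_pandas(pdf, npartitions=4)

    # isin with a single-partition, one-dimensional collection is now
    # broadcast through expr.Isin instead of raising NotImplementedError.
    values = from_pandas(pd.Series([1, 3, 5]), npartitions=1)
    result = df[df.x.isin(values)].compute()

    # split_out is now tuned automatically for single-key groupbys too,
    # so the optimized graph no longer collapses to a single partition.
    agg = from_pandas(pdf, npartitions=20).groupby("x").sum().optimize()

    # The parquet fusion factor consults this config value; the updated
    # test_partition_pruning pins it to 1 to keep the old partition counts.
    with dask.config.set({"dataframe.parquet.minimum-partition-size": 1}):
        pass  # read_parquet(...) calls here see the pre-change fusion behaviour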