diff --git a/dask_expr/_categorical.py b/dask_expr/_categorical.py index 697199242..1d17bc928 100644 --- a/dask_expr/_categorical.py +++ b/dask_expr/_categorical.py @@ -170,10 +170,18 @@ class Categorize(Blockwise): @functools.cached_property def _meta(self): - meta = _categorize_block( + return _categorize_block( self.frame._meta, self.operand("categories"), self.operand("index") ) - return meta + + def _simplify_up(self, parent, dependents): + result = super()._simplify_up(parent, dependents) + if result is None: + return result + # pop potentially dropped columns from categories + cats = self.operand("categories") + cats = {k: v for k, v in cats.items() if k in result.frame.columns} + return Categorize(result.frame, cats, result.operand("index")) class GetCategories(ApplyConcatApply): diff --git a/dask_expr/tests/test_categorical.py b/dask_expr/tests/test_categorical.py index 930df3884..91f586bc0 100644 --- a/dask_expr/tests/test_categorical.py +++ b/dask_expr/tests/test_categorical.py @@ -59,3 +59,12 @@ def test_categorical_set_index(): d1, d2 = b.get_partition(0), b.get_partition(1) assert list(d1.index.compute(fuse=False)) == ["a"] assert list(sorted(d2.index.compute())) == ["b", "b", "c"] + + +def test_categorize_drops_category_columns(): + pdf = pd.DataFrame({"a": [1, 2, 1, 2, 3], "b": 1}) + df = from_pandas(pdf) + df = df.categorize(columns=["a"]) + result = df["b"].to_frame() + expected = pdf["b"].to_frame() + assert_eq(result, expected)