diff --git a/datar_pandas/api/dplyr/sets.py b/datar_pandas/api/dplyr/sets.py index d530043..b6fd52e 100644 --- a/datar_pandas/api/dplyr/sets.py +++ b/datar_pandas/api/dplyr/sets.py @@ -2,6 +2,7 @@ https://github.com/tidyverse/dplyr/blob/master/R/sets.r """ +import numpy as np from datar.apis.dplyr import ( ungroup, bind_rows, @@ -10,11 +11,17 @@ setdiff, union_all, setequal, + symdiff, ) from ... import pandas as pd +from ...utils import meta_kwargs from ...pandas import DataFrame -from ...common import setdiff as _setdiff +from ...common import ( + setdiff as _setdiff, + union as _union, + intersect as _intersect, +) from ...tibble import TibbleGrouped, reconstruct_tibble @@ -52,13 +59,8 @@ def _intersect_df(x: DataFrame, y: DataFrame) -> DataFrame: from .distinct import distinct out = distinct( - pd.merge( - x, - ungroup(y, __ast_fallback="normal", __backend="pandas"), - how="inner", - ), - __ast_fallback="normal", - __backend="pandas", + pd.merge(x, ungroup(y, **meta_kwargs), how="inner"), + **meta_kwargs, ) if isinstance(y, TibbleGrouped): return reconstruct_tibble(out, y) @@ -67,8 +69,8 @@ def _intersect_df(x: DataFrame, y: DataFrame) -> DataFrame: @intersect.register(TibbleGrouped, backend="pandas") def _intersect_grouped(x, y): - newx = ungroup(x, __ast_fallback="normal", __backend="pandas") - newy = ungroup(y, __ast_fallback="normal", __backend="pandas") + newx = ungroup(x, **meta_kwargs) + newy = ungroup(y, **meta_kwargs) out = intersect.dispatch(DataFrame)(newx, newy) return reconstruct_tibble(out, x) @@ -88,13 +90,8 @@ def _union_df(x, y): from .distinct import distinct out = distinct( - pd.merge( - x, - ungroup(y, __ast_fallback="normal", __backend="pandas"), - how="outer", - ), - __ast_fallback="normal", - __backend="pandas", + pd.merge(x, ungroup(y, **meta_kwargs), how="outer"), + **meta_kwargs, ) out.reset_index(drop=True, inplace=True) if isinstance(y, TibbleGrouped): @@ -105,8 +102,8 @@ def _union_df(x, y): @union.register(TibbleGrouped, backend="pandas") def _union_grouped(x, y): out = union.dispatch(DataFrame)( - ungroup(x, __ast_fallback="normal", __backend="pandas"), - ungroup(y, __ast_fallback="normal", __backend="pandas"), + ungroup(x, **meta_kwargs), + ungroup(y, **meta_kwargs), ) return reconstruct_tibble(out, x) @@ -126,7 +123,7 @@ def _setdiff_df(x, y): indicator = "__datar_setdiff__" out = pd.merge( x, - ungroup(y, __ast_fallback="normal", __backend="pandas"), + ungroup(y, **meta_kwargs), how="left", indicator=indicator, ) @@ -137,8 +134,7 @@ def _setdiff_df(x, y): out[out[indicator] == "left_only"] .drop(columns=[indicator]) .reset_index(drop=True), - __ast_fallback="normal", - __backend="pandas", + **meta_kwargs, ) if isinstance(y, TibbleGrouped): return reconstruct_tibble(out, y) @@ -148,12 +144,17 @@ def _setdiff_df(x, y): @setdiff.register(TibbleGrouped, backend="pandas") def _setdiff_grouped(x, y): out = setdiff.dispatch(DataFrame)( - ungroup(x, __ast_fallback="normal", __backend="pandas"), - ungroup(y, __ast_fallback="normal", __backend="pandas"), + ungroup(x, **meta_kwargs), + ungroup(y, **meta_kwargs), ) return reconstruct_tibble(out, x) +@union_all.register(object, backend="pandas") +def _union_all_obj(x, y): + return np.concatenate([x, y]) + + @union_all.register(DataFrame, backend="pandas") def _union_all(x, y): """Union of all rows of two dataframes @@ -166,12 +167,7 @@ def _union_all(x, y): The dataframe of union of all rows of input dataframes """ _check_xy(x, y) - out = bind_rows( - x, - ungroup(y, __ast_fallback="normal", __backend="pandas"), - __ast_fallback="normal", - __backend="pandas", - ) + out = bind_rows(x, ungroup(y, **meta_kwargs), **meta_kwargs) if isinstance(y, TibbleGrouped): return reconstruct_tibble(out, y) return out @@ -180,8 +176,8 @@ def _union_all(x, y): @union_all.register(TibbleGrouped, backend="pandas") def _union_all_grouped(x, y): out = union_all.dispatch(DataFrame)( - ungroup(x, __ast_fallback="normal", __backend="pandas"), - ungroup(y, __ast_fallback="normal", __backend="pandas"), + ungroup(x, **meta_kwargs), + ungroup(y, **meta_kwargs), ) return reconstruct_tibble(out, x) @@ -199,10 +195,31 @@ def _set_equal_df(x, y, equal_na=True): Returns: True if they equal else False """ - x = ungroup(x, __ast_fallback="normal", __backend="pandas") - y = ungroup(y, __ast_fallback="normal", __backend="pandas") + x = ungroup(x, **meta_kwargs) + y = ungroup(y, **meta_kwargs) _check_xy(x, y) x = x.sort_values(by=x.columns.to_list()).reset_index(drop=True) y = y.sort_values(by=y.columns.to_list()).reset_index(drop=True) return x.equals(y) + + +@symdiff.register(object, backend="pandas") +def _symdiff(x, y): + """Symmetric difference of two vectors""" + return _setdiff(_union(x, y), _intersect(x, y)) + + +@symdiff.register(DataFrame, backend="pandas") +def _symdiff_df(x, y): + """Symmetric difference of two dataframes""" + _x = ungroup(x, **meta_kwargs) + _y = ungroup(y, **meta_kwargs) + _check_xy(_x, _y) + + out = setdiff( + union(_x, _y, **meta_kwargs), + intersect(_x, _y, **meta_kwargs), + **meta_kwargs, + ) + return reconstruct_tibble(out, x) diff --git a/datar_pandas/common.py b/datar_pandas/common.py index 0849f28..88081d1 100644 --- a/datar_pandas/common.py +++ b/datar_pandas/common.py @@ -17,30 +17,33 @@ from .pandas import unique # noqa: F401 from .typing import Data, Bool +np_meta_kwargs = {"__backend": "numpy", "__ast_fallback": "normal"} + def is_null(x: Any) -> Data[Bool]: return pd.isnull(x) def is_factor(x: Any) -> bool: - return _is_factor(x, __backend="pandas", __ast_fallback="normal") + from .utils import meta_kwargs + return _is_factor(x, **meta_kwargs) def is_integer(x: Any) -> bool: - return _is_integer(x, __ast_fallback="normal", __backend="numpy") + return _is_integer(x, **np_meta_kwargs) def is_logical(x: Any) -> bool: - return _is_logical(x, __ast_fallback="normal", __backend="numpy") + return _is_logical(x, **np_meta_kwargs) def intersect(x: Any, y: Any) -> Any: - return _intersect(x, y, __ast_fallback="normal", __backend="numpy") + return _intersect(x, y, **np_meta_kwargs) def setdiff(x: Any, y: Any) -> Any: - return _setdiff(x, y, __ast_fallback="normal", __backend="numpy") + return _setdiff(x, y, **np_meta_kwargs) def union(x: Any, y: Any) -> Any: - return _union(x, y, __ast_fallback="normal", __backend="numpy") + return _union(x, y, **np_meta_kwargs) diff --git a/datar_pandas/utils.py b/datar_pandas/utils.py index e376777..18d5e81 100644 --- a/datar_pandas/utils.py +++ b/datar_pandas/utils.py @@ -27,6 +27,8 @@ DEFAULT_COLUMN_PREFIX = "_VAR_" +meta_kwargs = {"__backend": "pandas", "__ast_fallback": "normal"} + class ExpressionWrapper: """A wrapper around an expression to bypass evaluation""" diff --git a/tests/dplyr/test_sets.py b/tests/dplyr/test_sets.py index 5934a76..ad55e65 100644 --- a/tests/dplyr/test_sets.py +++ b/tests/dplyr/test_sets.py @@ -32,11 +32,43 @@ distinct, union_all, filter, + symdiff, ) from datar.tibble import tibble from datar_pandas.pandas import assert_frame_equal -from ..conftest import assert_equal, assert_ +from ..conftest import assert_equal, assert_, assert_iterable_equal + + +def test_also_works_with_vectors(): + assert_iterable_equal(intersect([1, 2, 3], [3, 4]), [3]) + assert_iterable_equal(union([1, 2, 3], [3, 4]), [1, 2, 3, 4]) + assert_iterable_equal(union_all([1, 2, 3], [3, 4]), [1, 2, 3, 3, 4]) + assert_iterable_equal(setdiff([1, 2, 3], [3, 4]), [1, 2]) + assert_iterable_equal(symdiff([1, 2, 3], [3, 4]), [1, 2, 4]) + assert_iterable_equal(symdiff([1, 1, 2], [2, 2, 3]), [1, 3]) + + +def test_x_used_as_basis_of_output(): + df1 = tibble(x=[1, 2, 3, 4], y=1) + df2 = tibble(y=1, x=[4, 2]) + + assert_frame_equal(intersect(df1, df2), tibble(x=[2, 4], y=1)) + assert_frame_equal(union(df1, df2), tibble(x=[1, 2, 3, 4], y=1)) + assert_frame_equal(union_all(df1, df2), tibble(x=[1, 2, 3, 4, 4, 2], y=1)) + assert_frame_equal(setdiff(df1, df2), tibble(x=[1, 3], y=1)) + assert_frame_equal(symdiff(df1, df2), tibble(x=[1, 3], y=1)) + + +def test_set_removes_duplicates_except_union_all(): + df1 = tibble(x=[1, 1, 2]) + df2 = tibble(x=2) + + assert_frame_equal(intersect(df1, df2), tibble(x=2)) + assert_frame_equal(union(df1, df2), tibble(x=[1, 2])) + assert_frame_equal(union_all(df1, df2), tibble(x=[1, 1, 2, 2])) + assert_frame_equal(setdiff(df1, df2), tibble(x=1)) + assert_frame_equal(symdiff(df1, df2), tibble(x=1)) def test_set_uses_coercion_rules(): @@ -46,6 +78,7 @@ def test_set_uses_coercion_rules(): assert_equal(nrow(union(df1, df2)), 3) assert_equal(nrow(intersect(df1, df2)), 1) assert_equal(nrow(setdiff(df1, df2)), 1) + assert_equal(nrow(symdiff(df1, df2)), 2) df1 = tibble(x=factor(letters[:10])) df2 = tibble(x=letters[5:15]) @@ -103,6 +136,7 @@ def test_set_operations_reconstruct_grouping_metadata(): exp = tibble(x=seq(1, 6), g=rep([1, 2, 3], each=2)) >> group_by(f.g) assert out.equals(exp) assert_equal(group_vars(out), group_vars(exp)) + assert_equal(group_vars(symdiff(df1, df2)), ["g"]) out = setdiff(df1, df2) >> group_rows() assert out == [[0, 1]]