Skip to content

Commit

Permalink
✨ Add symdiff (pwwang/datar#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
pwwang committed Oct 6, 2023
1 parent e99d474 commit 87698b3
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 41 deletions.
85 changes: 51 additions & 34 deletions datar_pandas/api/dplyr/sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
https://github.com/tidyverse/dplyr/blob/master/R/sets.r
"""
import numpy as np
from datar.apis.dplyr import (
ungroup,
bind_rows,
Expand All @@ -10,11 +11,17 @@
setdiff,
union_all,
setequal,
symdiff,
)

from ... import pandas as pd
from ...utils import meta_kwargs
from ...pandas import DataFrame
from ...common import setdiff as _setdiff
from ...common import (
setdiff as _setdiff,
union as _union,
intersect as _intersect,
)
from ...tibble import TibbleGrouped, reconstruct_tibble


Expand Down Expand Up @@ -52,13 +59,8 @@ def _intersect_df(x: DataFrame, y: DataFrame) -> DataFrame:
from .distinct import distinct

out = distinct(
pd.merge(
x,
ungroup(y, __ast_fallback="normal", __backend="pandas"),
how="inner",
),
__ast_fallback="normal",
__backend="pandas",
pd.merge(x, ungroup(y, **meta_kwargs), how="inner"),
**meta_kwargs,
)
if isinstance(y, TibbleGrouped):
return reconstruct_tibble(out, y)
Expand All @@ -67,8 +69,8 @@ def _intersect_df(x: DataFrame, y: DataFrame) -> DataFrame:

@intersect.register(TibbleGrouped, backend="pandas")
def _intersect_grouped(x, y):
    """Intersect two frames when `x` is a grouped tibble.

    Both inputs are ungrouped, the implementation returned by
    `intersect.dispatch(DataFrame)` is applied to them, and the result is
    passed to `reconstruct_tibble` together with the original `x`.

    Args:
        x: The grouped tibble used as the template for the result.
        y: The other frame.

    Returns:
        Whatever `reconstruct_tibble(out, x)` yields (presumably the
        intersection regrouped like `x` -- confirm against reconstruct_tibble).
    """
    # Ungroup each input exactly once; meta_kwargs carries the
    # __backend/__ast_fallback pair so it is not repeated per call.
    newx = ungroup(x, **meta_kwargs)
    newy = ungroup(y, **meta_kwargs)
    out = intersect.dispatch(DataFrame)(newx, newy)
    return reconstruct_tibble(out, x)

Expand All @@ -88,13 +90,8 @@ def _union_df(x, y):
from .distinct import distinct

out = distinct(
pd.merge(
x,
ungroup(y, __ast_fallback="normal", __backend="pandas"),
how="outer",
),
__ast_fallback="normal",
__backend="pandas",
pd.merge(x, ungroup(y, **meta_kwargs), how="outer"),
**meta_kwargs,
)
out.reset_index(drop=True, inplace=True)
if isinstance(y, TibbleGrouped):
Expand All @@ -105,8 +102,8 @@ def _union_df(x, y):
@union.register(TibbleGrouped, backend="pandas")
def _union_grouped(x, y):
    """Union of two frames when `x` is a grouped tibble.

    Both inputs are ungrouped, the implementation returned by
    `union.dispatch(DataFrame)` is applied to them, and the result is passed
    to `reconstruct_tibble` together with the original `x`.

    Args:
        x: The grouped tibble used as the template for the result.
        y: The other frame.

    Returns:
        Whatever `reconstruct_tibble(out, x)` yields (presumably the union
        regrouped like `x` -- confirm against reconstruct_tibble).
    """
    # Exactly two frames go to the DataFrame implementation: one ungrouped
    # copy of each input.
    newx = ungroup(x, **meta_kwargs)
    newy = ungroup(y, **meta_kwargs)
    out = union.dispatch(DataFrame)(newx, newy)
    return reconstruct_tibble(out, x)

Expand All @@ -126,7 +123,7 @@ def _setdiff_df(x, y):
indicator = "__datar_setdiff__"
out = pd.merge(
x,
ungroup(y, __ast_fallback="normal", __backend="pandas"),
ungroup(y, **meta_kwargs),
how="left",
indicator=indicator,
)
Expand All @@ -137,8 +134,7 @@ def _setdiff_df(x, y):
out[out[indicator] == "left_only"]
.drop(columns=[indicator])
.reset_index(drop=True),
__ast_fallback="normal",
__backend="pandas",
**meta_kwargs,
)
if isinstance(y, TibbleGrouped):
return reconstruct_tibble(out, y)
Expand All @@ -148,12 +144,17 @@ def _setdiff_df(x, y):
@setdiff.register(TibbleGrouped, backend="pandas")
def _setdiff_grouped(x, y):
    """Set difference of two frames when `x` is a grouped tibble.

    Both inputs are ungrouped, the implementation returned by
    `setdiff.dispatch(DataFrame)` is applied to them, and the result is
    passed to `reconstruct_tibble` together with the original `x`.

    Args:
        x: The grouped tibble used as the template for the result.
        y: The other frame.

    Returns:
        Whatever `reconstruct_tibble(out, x)` yields (presumably the
        difference regrouped like `x` -- confirm against reconstruct_tibble).
    """
    # Exactly two frames go to the DataFrame implementation: one ungrouped
    # copy of each input.
    newx = ungroup(x, **meta_kwargs)
    newy = ungroup(y, **meta_kwargs)
    out = setdiff.dispatch(DataFrame)(newx, newy)
    return reconstruct_tibble(out, x)


@union_all.register(object, backend="pandas")
def _union_all_obj(x, y):
    """Concatenate two vectors, keeping duplicates.

    Args:
        x: The first vector (or scalar).
        y: The second vector (or scalar).

    Returns:
        A numpy array holding the elements of `x` followed by those of `y`.
    """
    # np.concatenate refuses zero-dimensional inputs, so promote scalars to
    # length-1 arrays first; inputs that already have a dimension are
    # unaffected.
    return np.concatenate([np.atleast_1d(x), np.atleast_1d(y)])


@union_all.register(DataFrame, backend="pandas")
def _union_all(x, y):
"""Union of all rows of two dataframes
Expand All @@ -166,12 +167,7 @@ def _union_all(x, y):
The dataframe of union of all rows of input dataframes
"""
_check_xy(x, y)
out = bind_rows(
x,
ungroup(y, __ast_fallback="normal", __backend="pandas"),
__ast_fallback="normal",
__backend="pandas",
)
out = bind_rows(x, ungroup(y, **meta_kwargs), **meta_kwargs)
if isinstance(y, TibbleGrouped):
return reconstruct_tibble(out, y)
return out
Expand All @@ -180,8 +176,8 @@ def _union_all(x, y):
@union_all.register(TibbleGrouped, backend="pandas")
def _union_all_grouped(x, y):
    """Union of all rows of two frames when `x` is a grouped tibble.

    Both inputs are ungrouped, the implementation returned by
    `union_all.dispatch(DataFrame)` is applied to them, and the result is
    passed to `reconstruct_tibble` together with the original `x`.

    Args:
        x: The grouped tibble used as the template for the result.
        y: The other frame.

    Returns:
        Whatever `reconstruct_tibble(out, x)` yields (presumably the combined
        rows regrouped like `x` -- confirm against reconstruct_tibble).
    """
    # Exactly two frames go to the DataFrame implementation: one ungrouped
    # copy of each input.
    newx = ungroup(x, **meta_kwargs)
    newy = ungroup(y, **meta_kwargs)
    out = union_all.dispatch(DataFrame)(newx, newy)
    return reconstruct_tibble(out, x)

Expand All @@ -199,10 +195,31 @@ def _set_equal_df(x, y, equal_na=True):
Returns:
True if they equal else False
"""
x = ungroup(x, __ast_fallback="normal", __backend="pandas")
y = ungroup(y, __ast_fallback="normal", __backend="pandas")
x = ungroup(x, **meta_kwargs)
y = ungroup(y, **meta_kwargs)
_check_xy(x, y)

x = x.sort_values(by=x.columns.to_list()).reset_index(drop=True)
y = y.sort_values(by=y.columns.to_list()).reset_index(drop=True)
return x.equals(y)


@symdiff.register(object, backend="pandas")
def _symdiff(x, y):
    """Symmetric difference of two vectors.

    Computed as the union of `x` and `y` with their intersection removed.
    """
    in_either = _union(x, y)
    in_both = _intersect(x, y)
    return _setdiff(in_either, in_both)


@symdiff.register(DataFrame, backend="pandas")
def _symdiff_df(x, y):
    """Symmetric difference of two dataframes.

    Both inputs are ungrouped and validated with `_check_xy`; the result is
    the union of the two frames with their intersection removed, passed
    through `reconstruct_tibble` with the original `x`.
    """
    flat_x = ungroup(x, **meta_kwargs)
    flat_y = ungroup(y, **meta_kwargs)
    _check_xy(flat_x, flat_y)

    in_either = union(flat_x, flat_y, **meta_kwargs)
    in_both = intersect(flat_x, flat_y, **meta_kwargs)
    only_one_side = setdiff(in_either, in_both, **meta_kwargs)

    # Reconstruct from the caller's `x`, not from its ungrouped copy.
    return reconstruct_tibble(only_one_side, x)
15 changes: 9 additions & 6 deletions datar_pandas/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,33 @@
from .pandas import unique # noqa: F401
from .typing import Data, Bool

# Meta keyword arguments unpacked (`**np_meta_kwargs`) into the wrapper calls
# in this module, so the "numpy" backend / "normal" AST-fallback pair is
# written in one place only.
np_meta_kwargs = {"__backend": "numpy", "__ast_fallback": "normal"}


def is_null(x: Any) -> Data[Bool]:
    """Flag the missing values in `x` by delegating to `pd.isnull`."""
    null_flags = pd.isnull(x)
    return null_flags


def is_factor(x: Any) -> bool:
    """Tell whether `x` is a factor.

    Delegates to `_is_factor`, forwarding the shared `meta_kwargs`.

    Args:
        x: The object to test.

    Returns:
        The result of `_is_factor` for `x`.
    """
    # Imported at call time rather than at module level; presumably this
    # avoids a circular import with `.utils` -- TODO confirm.
    from .utils import meta_kwargs

    return _is_factor(x, **meta_kwargs)


def is_integer(x: Any) -> bool:
    """Tell whether `x` is an integer.

    Delegates to `_is_integer`, forwarding `np_meta_kwargs`.

    Args:
        x: The object to test.

    Returns:
        The result of `_is_integer` for `x`.
    """
    return _is_integer(x, **np_meta_kwargs)


def is_logical(x: Any) -> bool:
    """Tell whether `x` is logical (boolean).

    Delegates to `_is_logical`, forwarding `np_meta_kwargs`.

    Args:
        x: The object to test.

    Returns:
        The result of `_is_logical` for `x`.
    """
    return _is_logical(x, **np_meta_kwargs)


def intersect(x: Any, y: Any) -> Any:
    """Intersection of `x` and `y`.

    Delegates to `_intersect`, forwarding `np_meta_kwargs`.

    Args:
        x: The first operand.
        y: The second operand.

    Returns:
        The result of `_intersect` for the two operands.
    """
    return _intersect(x, y, **np_meta_kwargs)


def setdiff(x: Any, y: Any) -> Any:
    """Set difference of `x` and `y`.

    Delegates to `_setdiff`, forwarding `np_meta_kwargs`.

    Args:
        x: The first operand.
        y: The second operand.

    Returns:
        The result of `_setdiff` for the two operands.
    """
    return _setdiff(x, y, **np_meta_kwargs)


def union(x: Any, y: Any) -> Any:
    """Union of `x` and `y`.

    Delegates to `_union`, forwarding `np_meta_kwargs`.

    Args:
        x: The first operand.
        y: The second operand.

    Returns:
        The result of `_union` for the two operands.
    """
    return _union(x, y, **np_meta_kwargs)
2 changes: 2 additions & 0 deletions datar_pandas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

DEFAULT_COLUMN_PREFIX = "_VAR_"

# Shared meta keyword arguments: callers unpack them (`**meta_kwargs`) into
# calls instead of repeating the "pandas" backend / "normal" AST-fallback pair.
meta_kwargs = {"__backend": "pandas", "__ast_fallback": "normal"}


class ExpressionWrapper:
"""A wrapper around an expression to bypass evaluation"""
Expand Down
36 changes: 35 additions & 1 deletion tests/dplyr/test_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,43 @@
distinct,
union_all,
filter,
symdiff,
)
from datar.tibble import tibble
from datar_pandas.pandas import assert_frame_equal

from ..conftest import assert_equal, assert_
from ..conftest import assert_equal, assert_, assert_iterable_equal


def test_also_works_with_vectors():
    """The set verbs accept plain vectors, not only data frames."""
    expectations = [
        (intersect, [3]),
        (union, [1, 2, 3, 4]),
        (union_all, [1, 2, 3, 3, 4]),
        (setdiff, [1, 2]),
        (symdiff, [1, 2, 4]),
    ]
    for verb, wanted in expectations:
        # Fresh list literals per call, exactly as in a hand-written assert.
        assert_iterable_equal(verb([1, 2, 3], [3, 4]), wanted)

    # Repeated values inside either input do not show up in the result.
    assert_iterable_equal(symdiff([1, 1, 2], [2, 2, 3]), [1, 3])


def test_x_used_as_basis_of_output():
    """Results follow the column order of `x`, even when `y` orders them differently."""
    df1 = tibble(x=[1, 2, 3, 4], y=1)
    df2 = tibble(y=1, x=[4, 2])

    expected_x_by_verb = [
        (intersect, [2, 4]),
        (union, [1, 2, 3, 4]),
        (union_all, [1, 2, 3, 4, 4, 2]),
        (setdiff, [1, 3]),
        (symdiff, [1, 3]),
    ]
    for verb, expected_x in expected_x_by_verb:
        assert_frame_equal(verb(df1, df2), tibble(x=expected_x, y=1))


def test_set_removes_duplicates_except_union_all():
    """Every set verb drops duplicate rows, except `union_all`, which keeps them."""
    df1 = tibble(x=[1, 1, 2])
    df2 = tibble(x=2)

    expected_x_by_verb = [
        (intersect, 2),
        (union, [1, 2]),
        (union_all, [1, 1, 2, 2]),
        (setdiff, 1),
        (symdiff, 1),
    ]
    for verb, expected_x in expected_x_by_verb:
        assert_frame_equal(verb(df1, df2), tibble(x=expected_x))


def test_set_uses_coercion_rules():
Expand All @@ -46,6 +78,7 @@ def test_set_uses_coercion_rules():
assert_equal(nrow(union(df1, df2)), 3)
assert_equal(nrow(intersect(df1, df2)), 1)
assert_equal(nrow(setdiff(df1, df2)), 1)
assert_equal(nrow(symdiff(df1, df2)), 2)

df1 = tibble(x=factor(letters[:10]))
df2 = tibble(x=letters[5:15])
Expand Down Expand Up @@ -103,6 +136,7 @@ def test_set_operations_reconstruct_grouping_metadata():
exp = tibble(x=seq(1, 6), g=rep([1, 2, 3], each=2)) >> group_by(f.g)
assert out.equals(exp)
assert_equal(group_vars(out), group_vars(exp))
assert_equal(group_vars(symdiff(df1, df2)), ["g"])

out = setdiff(df1, df2) >> group_rows()
assert out == [[0, 1]]
Expand Down

0 comments on commit 87698b3

Please sign in to comment.