Skip to content

Commit

Permalink
Use sorted indices
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Sep 7, 2024
1 parent 9753f32 commit d0813a3
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
31 changes: 31 additions & 0 deletions _unittests/ut_df/test_connex_split_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,37 @@ def test_cat_strat(self):
lambda: train_test_apart_stratify(df, group="b", test_size=0.5), ValueError
)

def test_cat_strat_sorted(self):
df = pandas.DataFrame(
[
dict(a=1, b="e"),
dict(a=2, b="e"),
dict(a=4, b="f"),
dict(a=8, b="f"),
dict(a=32, b="f"),
dict(a=16, b="f"),
]
)

train, test = train_test_apart_stratify(
df, group="a", stratify="b", test_size=0.5, sorted_indices=True
)
self.assertEqual(train.shape[1], test.shape[1])
self.assertEqual(train.shape[0] + test.shape[0], df.shape[0])
c1 = Counter(train["b"])
c2 = Counter(train["b"])
self.assertEqual(c1, c2)

self.assertRaise(
lambda: train_test_apart_stratify(
df, group=None, stratify="b", test_size=0.5, sorted_indices=True
),
ValueError,
)
self.assertRaise(
lambda: train_test_apart_stratify(df, group="b", test_size=0.5), ValueError
)

def test_cat_strat_multi(self):
df = pandas.DataFrame(
[
Expand Down
27 changes: 18 additions & 9 deletions pandas_streaming/df/connex_split.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import Counter
from logging import getLogger
from typing import Optional, Tuple
import pandas
import numpy
from .dataframe_helpers import dataframe_shuffle
Expand Down Expand Up @@ -449,14 +450,15 @@ def double_merge(d):


def train_test_apart_stratify(
df,
df: pandas.DataFrame,
group,
test_size=0.25,
train_size=None,
stratify=None,
force=False,
random_state=None,
):
test_size: Optional[float] = 0.25,
train_size: Optional[float] = None,
stratify: Optional[str] = None,
force: bool = False,
random_state: Optional[int] = None,
sorted_indices: bool = False,
) -> Tuple["StreamingDataFrame", "StreamingDataFrame"]: # noqa: F821
"""
This split is for a specific case where data is linked
in one way. Let's assume we have two ids as we have
Expand All @@ -474,6 +476,8 @@ def train_test_apart_stratify(
:param force: if True, tries to get at least one example on the test side
for each value of the column *stratify*
:param random_state: seed for random generators
:param sorted_indices: sort index first,
see issue `41 <https://github.com/sdpython/pandas-streaming/issues/41>`
:return: Two see :class:`StreamingDataFrame
<pandas_streaming.df.dataframe.StreamingDataFrame>`, one
for train, one for test.
Expand Down Expand Up @@ -540,10 +544,15 @@ def train_test_apart_stratify(

split = {}
for _, k in sorted_hist:
not_assigned = [c for c in ids[k] if c not in split]
indices = sorted(ids[k]) if sorted_indices else ids[k]
not_assigned, assigned = [], []
for c in indices:
if c in split:
assigned.append(c)
else:
not_assigned.append(c)
if len(not_assigned) == 0:
continue
assigned = [c for c in ids[k] if c in split]
nb_test = sum(split[c] for c in assigned)
expected = min(len(ids[k]), int(test_size * len(ids[k]) + 0.5)) - nb_test
if force and expected == 0 and nb_test == 0:
Expand Down

0 comments on commit d0813a3

Please sign in to comment.