diff --git a/_unittests/ut_df/test_connex_split_cat.py b/_unittests/ut_df/test_connex_split_cat.py index cf72d20..7aa8264 100644 --- a/_unittests/ut_df/test_connex_split_cat.py +++ b/_unittests/ut_df/test_connex_split_cat.py @@ -39,6 +39,37 @@ def test_cat_strat(self): lambda: train_test_apart_stratify(df, group="b", test_size=0.5), ValueError ) + def test_cat_strat_sorted(self): + df = pandas.DataFrame( + [ + dict(a=1, b="e"), + dict(a=2, b="e"), + dict(a=4, b="f"), + dict(a=8, b="f"), + dict(a=32, b="f"), + dict(a=16, b="f"), + ] + ) + + train, test = train_test_apart_stratify( + df, group="a", stratify="b", test_size=0.5, sorted_indices=True + ) + self.assertEqual(train.shape[1], test.shape[1]) + self.assertEqual(train.shape[0] + test.shape[0], df.shape[0]) + c1 = Counter(train["b"]) + c2 = Counter(train["b"]) + self.assertEqual(c1, c2) + + self.assertRaise( + lambda: train_test_apart_stratify( + df, group=None, stratify="b", test_size=0.5, sorted_indices=True + ), + ValueError, + ) + self.assertRaise( + lambda: train_test_apart_stratify(df, group="b", test_size=0.5), ValueError + ) + def test_cat_strat_multi(self): df = pandas.DataFrame( [ diff --git a/pandas_streaming/df/connex_split.py b/pandas_streaming/df/connex_split.py index ce9a3a2..21b82b0 100644 --- a/pandas_streaming/df/connex_split.py +++ b/pandas_streaming/df/connex_split.py @@ -1,5 +1,6 @@ from collections import Counter from logging import getLogger +from typing import Optional, Tuple import pandas import numpy from .dataframe_helpers import dataframe_shuffle @@ -449,14 +450,15 @@ def double_merge(d): def train_test_apart_stratify( - df, + df: pandas.DataFrame, group, - test_size=0.25, - train_size=None, - stratify=None, - force=False, - random_state=None, -): + test_size: Optional[float] = 0.25, + train_size: Optional[float] = None, + stratify: Optional[str] = None, + force: bool = False, + random_state: Optional[int] = None, + sorted_indices: bool = False, +) -> Tuple["StreamingDataFrame", "StreamingDataFrame"]: # noqa: F821 """ This split is for a specific case where data is linked in one way. Let's assume we have two ids as we have @@ -474,6 +476,8 @@ def train_test_apart_stratify( :param force: if True, tries to get at least one example on the test side for each value of the column *stratify* :param random_state: seed for random generators + :param sorted_indices: sort index first, + see issue `41 ` :return: Two see :class:`StreamingDataFrame `, one for train, one for test. @@ -540,10 +544,15 @@ def train_test_apart_stratify( split = {} for _, k in sorted_hist: - not_assigned = [c for c in ids[k] if c not in split] + indices = sorted(ids[k]) if sorted_indices else ids[k] + not_assigned, assigned = [], [] + for c in indices: + if c in split: + assigned.append(c) + else: + not_assigned.append(c) if len(not_assigned) == 0: continue - assigned = [c for c in ids[k] if c in split] nb_test = sum(split[c] for c in assigned) expected = min(len(ids[k]), int(test_size * len(ids[k]) + 0.5)) - nb_test if force and expected == 0 and nb_test == 0: