Skip to content

Commit

Permalink
PERF-modin-project#6749: Preserve partial dtype for the result of 'reset_index()'
Browse files Browse the repository at this point in the history

Signed-off-by: Dmitry Chigarev <[email protected]>
  • Loading branch information
dchigarev committed Nov 17, 2023
1 parent bee2c28 commit ae14ce3
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 13 deletions.
12 changes: 7 additions & 5 deletions modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,11 +1323,13 @@ def from_labels(self) -> "PandasDataframe":
if "index" not in self.columns
else "level_{}".format(0)
]
new_dtypes = None
if self.has_materialized_dtypes:
names = tuple(level_names) if len(level_names) > 1 else level_names[0]
new_dtypes = self.index.to_frame(name=names).dtypes
new_dtypes = pandas.concat([new_dtypes, self.dtypes])
names = tuple(level_names) if len(level_names) > 1 else level_names[0]
new_dtypes = self.index.to_frame(name=names).dtypes
try:
new_dtypes = ModinDtypes.concat([new_dtypes, self._dtypes])
except NotImplementedError:
# can raise on duplicated labels
new_dtypes = None

# We will also use the `new_column_names` in the calculation of the internal metadata, so this is a
# lightweight way of ensuring the metadata matches.
Expand Down
21 changes: 15 additions & 6 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
SeriesGroupByDefault,
)
from modin.core.dataframe.base.dataframe.utils import join_columns
from modin.core.dataframe.pandas.metadata import ModinDtypes
from modin.core.storage_formats.base.query_compiler import BaseQueryCompiler
from modin.error_message import ErrorMessage
from modin.utils import (
Expand Down Expand Up @@ -719,20 +720,28 @@ def _reset(df, *axis_lengths, partition_idx): # pragma: no cover
df.index = pandas.RangeIndex(start, stop)
return df

if self._modin_frame.has_columns_cache and kwargs["drop"]:
new_columns = self._modin_frame.copy_columns_cache(copy_lengths=True)
new_columns = None
if kwargs["drop"]:
dtypes = self._modin_frame.copy_dtypes_cache()
if self._modin_frame.has_columns_cache:
new_columns = self._modin_frame.copy_columns_cache(
copy_lengths=True
)
else:
new_columns = None
# concat index dtypes (None, since they're unknown) with column dtypes
try:
dtypes = ModinDtypes.concat([None, self._modin_frame._dtypes])
except NotImplementedError:
# may raise on duplicated names in materialized 'self.dtypes'
dtypes = None

return self.__constructor__(
self._modin_frame.apply_full_axis(
axis=1,
func=_reset,
enumerate_partitions=True,
new_columns=new_columns,
dtypes=(
self._modin_frame._dtypes if kwargs.get("drop", False) else None
),
dtypes=dtypes,
sync_labels=False,
pass_axis_lengths_to_partitions=True,
)
Expand Down
91 changes: 89 additions & 2 deletions modin/test/storage_formats/pandas/test_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import modin.pandas as pd
from modin.config import Engine, ExperimentalGroupbyImpl, MinPartitionSize, NPartitions
from modin.core.dataframe.pandas.dataframe.dataframe import PandasDataframe
from modin.core.dataframe.pandas.dataframe.utils import ColumnInfo, ShuffleSortFunctions
from modin.core.dataframe.pandas.metadata import (
DtypesDescriptor,
Expand Down Expand Up @@ -1947,8 +1948,6 @@ class TestZeroComputationDtypes:
"""

def test_get_dummies_case(self):
from modin.core.dataframe.pandas.dataframe.dataframe import PandasDataframe

with mock.patch.object(PandasDataframe, "_compute_dtypes") as patch:
df = pd.DataFrame(
{"items": [1, 2, 3, 4], "b": [3, 3, 4, 4], "c": [1, 0, 0, 1]}
Expand All @@ -1960,3 +1959,91 @@ def test_get_dummies_case(self):
assert res._query_compiler._modin_frame.has_materialized_dtypes

patch.assert_not_called()

@pytest.mark.parametrize("has_materialized_index", [True, False])
@pytest.mark.parametrize("drop", [True, False])
def test_preserve_dtypes_reset_index(self, drop, has_materialized_index):
    """
    Check that ``reset_index()`` carries dtype metadata over to the result.

    Two starting states are exercised: a frame whose dtypes are fully
    materialized, and a frame whose dtypes cache is only partially known.
    Each is combined with a materialized or a dropped index cache and with
    both values of ``drop``. ``_compute_dtypes`` is mocked for the whole
    test and must never be called.

    Parameters
    ----------
    drop : bool
        Value forwarded to ``reset_index(drop=...)``.
    has_materialized_index : bool
        Whether the index cache is kept (True) or reset to ``None`` (False)
        before ``reset_index()`` is called.
    """
    # Spell the expected dtype explicitly: ``np.dtype(int)`` follows the
    # platform's C long (32-bit on Windows with NumPy 1.x), while the integer
    # columns built by pandas and a RangeIndex are 64-bit.
    int64 = np.dtype("int64")

    with mock.patch.object(PandasDataframe, "_compute_dtypes") as patch:
        # case 1: 'df' has complete dtype by default
        df = pd.DataFrame({"a": [1, 2, 3]})
        if has_materialized_index:
            assert df._query_compiler._modin_frame.has_materialized_index
        else:
            df._query_compiler._modin_frame.set_index_cache(None)
            assert not df._query_compiler._modin_frame.has_materialized_index
        assert df._query_compiler._modin_frame.has_materialized_dtypes

        res = df.reset_index(drop=drop)
        if drop:
            # we dropped the index, so columns and dtypes shouldn't change
            assert res._query_compiler._modin_frame.has_materialized_dtypes
            assert res.dtypes.equals(df.dtypes)
        else:
            if has_materialized_index:
                # we should have inserted index dtype into the descriptor,
                # and since both of them are materialized, the result should be
                # materialized too
                assert res._query_compiler._modin_frame.has_materialized_dtypes
                assert res.dtypes.equals(
                    pandas.Series([int64, int64], index=["index", "a"])
                )
            else:
                # we now know that there are cols with unknown name and dtype in our dataframe,
                # so the resulting dtypes should contain information only about original column
                expected_dtypes = DtypesDescriptor(
                    {"a": int64},
                    know_all_names=False,
                )
                assert res._query_compiler._modin_frame._dtypes._value.equals(
                    expected_dtypes
                )

        # case 2: 'df' has partial dtype by default
        df = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
        df._query_compiler._modin_frame.set_dtypes_cache(
            ModinDtypes(
                DtypesDescriptor({"a": int64}, cols_with_unknown_dtypes=["b"])
            )
        )
        if has_materialized_index:
            assert df._query_compiler._modin_frame.has_materialized_index
        else:
            df._query_compiler._modin_frame.set_index_cache(None)
            assert not df._query_compiler._modin_frame.has_materialized_index

        res = df.reset_index(drop=drop)
        if drop:
            # we dropped the index, so columns and dtypes shouldn't change
            assert res._query_compiler._modin_frame._dtypes._value.equals(
                df._query_compiler._modin_frame._dtypes._value
            )
        else:
            if has_materialized_index:
                # we should have inserted index dtype into the descriptor,
                # the resulted dtype should have information about 'index' and 'a' columns,
                # and miss dtype info for 'b' column
                expected_dtypes = DtypesDescriptor(
                    {"index": int64, "a": int64},
                    cols_with_unknown_dtypes=["b"],
                    columns_order={0: "index", 1: "a", 2: "b"},
                )
                assert res._query_compiler._modin_frame._dtypes._value.equals(
                    expected_dtypes
                )
            else:
                # we miss info about the 'index' column since it wasn't materialized at
                # the time of 'reset_index()' and we're still missing dtype info for 'b' column
                expected_dtypes = DtypesDescriptor(
                    {"a": int64},
                    cols_with_unknown_dtypes=["b"],
                    know_all_names=False,
                )
                assert res._query_compiler._modin_frame._dtypes._value.equals(
                    expected_dtypes
                )

    # no dtype computation may have been triggered by any of the steps above
    patch.assert_not_called()

0 comments on commit ae14ce3

Please sign in to comment.