Skip to content

Commit

Permalink
FIX-#6822: Do not propagate NotImplementedError to a user on a 'set_c…
Browse files Browse the repository at this point in the history
…olumns()' with dupl labels (#6823)

Signed-off-by: Dmitry Chigarev <[email protected]>
  • Loading branch information
dchigarev authored Dec 13, 2023
1 parent 324099d commit acfcf34
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
14 changes: 11 additions & 3 deletions modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,11 +709,19 @@ def _set_columns(self, new_columns):
return
new_columns = self._validate_set_axis(new_columns, self._columns_cache)
if isinstance(self._dtypes, ModinDtypes):
new_value = self._dtypes.set_index(new_columns)
self.set_dtypes_cache(new_value)
try:
new_dtypes = self._dtypes.set_index(new_columns)
except NotImplementedError:
# can raise on duplicated labels
new_dtypes = None
elif isinstance(self._dtypes, pandas.Series):
self.dtypes.index = new_columns
new_dtypes = self.dtypes.set_axis(new_columns)
else:
new_dtypes = None
self.set_columns_cache(new_columns)
# we have to set new dtypes cache after columns,
# so the 'self.columns' and 'new_dtypes.index' indices would match
self.set_dtypes_cache(new_dtypes)
self.synchronize_labels(axis=1)

columns = property(_get_columns, _set_columns)
Expand Down
5 changes: 5 additions & 0 deletions modin/core/dataframe/pandas/metadata/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,11 @@ def set_index(
Calling this method on a descriptor that returns ``None`` for ``.columns_order``
will result into information lose.
"""
if len(new_index) != len(set(new_index)):
raise NotImplementedError(
"Duplicated column names are not yet supported by DtypesDescriptor"
)

if self.columns_order is None:
# we can't map new columns to old columns and lost all dtypes :(
return DtypesDescriptor(
Expand Down
11 changes: 11 additions & 0 deletions modin/test/storage_formats/pandas/test_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,6 +2171,17 @@ def test_set_index_dataframe(self, initial_dtypes, result_dtypes):
assert df._dtypes._value.equals(result_dtypes)
assert df.dtypes.index.equals(pandas.Index(["col1", "col2", "col3"]))

def test_set_index_with_dupl_labels(self):
"""Verify that setting duplicated columns doesn't propagate any errors to a user."""
df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [3.5, 4.4, 5.5, 6.6]})
# making sure that dtypes are represented by an unmaterialized dtypes-descriptor
df._query_compiler._modin_frame.set_dtypes_cache(None)

df.columns = ["a", "a"]
assert df.dtypes.equals(
pandas.Series([np.dtype(int), np.dtype("float64")], index=["a", "a"])
)


class TestZeroComputationDtypes:
"""
Expand Down

0 comments on commit acfcf34

Please sign in to comment.