diff --git a/python/cudf/cudf/core/reshape.py b/python/cudf/cudf/core/reshape.py index edcc4ab3e63..160c88f0684 100644 --- a/python/cudf/cudf/core/reshape.py +++ b/python/cudf/cudf/core/reshape.py @@ -438,13 +438,16 @@ def concat(objs, axis=0, join="outer", ignore_index=False, sort=None): else: df[col_label] = col - if ignore_index: - # with ignore_index the column names change to numbers - df.columns = pd.RangeIndex(len(result_columns)) - elif not only_series: - df.columns = cudf.MultiIndex.from_tuples(df._column_names) + if keys is None: + df.columns = result_columns.unique() + if ignore_index: + df.columns = pd.RangeIndex(len(result_columns.unique())) else: - pass + if ignore_index: + # with ignore_index the column names change to numbers + df.columns = pd.RangeIndex(len(result_columns)) + elif not only_series: + df.columns = cudf.MultiIndex.from_tuples(df._column_names) if empty_inner: # if join is inner and it contains an empty df diff --git a/python/cudf/cudf/tests/test_concat.py b/python/cudf/cudf/tests/test_concat.py index ede8c620d2d..e4c6ab7daa4 100644 --- a/python/cudf/cudf/tests/test_concat.py +++ b/python/cudf/cudf/tests/test_concat.py @@ -218,7 +218,8 @@ def test_concat_columns(axis): assert_eq(expect, got, check_index_type=True) -def test_concat_multiindex_dataframe(): +@pytest.mark.parametrize("axis", [0, 1]) +def test_concat_multiindex_dataframe(axis): gdf = cudf.DataFrame( { "w": np.arange(4), @@ -233,14 +234,11 @@ def test_concat_multiindex_dataframe(): pdg2 = pdg.iloc[:, 1:] gdg1 = cudf.from_pandas(pdg1) gdg2 = cudf.from_pandas(pdg2) + expected = pd.concat([pdg1, pdg2], axis=axis) + result = cudf.concat([gdg1, gdg2], axis=axis) assert_eq( - cudf.concat([gdg1, gdg2]).astype("float64"), - pd.concat([pdg1, pdg2]), - check_index_type=True, - ) - assert_eq( - cudf.concat([gdg1, gdg2], axis=1), - pd.concat([pdg1, pdg2], axis=1), + expected, + result, check_index_type=True, )