diff --git a/python/pyarrow/pandas_compat.py b/python/pyarrow/pandas_compat.py index 5a930a41f0300..d0582f825b529 100644 --- a/python/pyarrow/pandas_compat.py +++ b/python/pyarrow/pandas_compat.py @@ -181,11 +181,10 @@ def get_column_metadata(column, name, arrow_type, field_name): ) ) - assert field_name is None or isinstance(field_name, str), \ - str(type(field_name)) + assert isinstance(field_name, str), str(type(field_name)) return { 'name': name, - 'field_name': 'None' if field_name is None else field_name, + 'field_name': field_name, 'pandas_type': logical_type, 'numpy_type': string_dtype, 'metadata': extra_metadata, @@ -193,7 +192,8 @@ def get_column_metadata(column, name, arrow_type, field_name): def construct_metadata(columns_to_convert, df, column_names, index_levels, - index_descriptors, preserve_index, types): + index_descriptors, preserve_index, types, + column_field_names=None): """Returns a dictionary containing enough metadata to reconstruct a pandas DataFrame as an Arrow Table, including index columns. @@ -201,6 +201,8 @@ def construct_metadata(columns_to_convert, df, column_names, index_levels, ---------- columns_to_convert : list[pd.Series] df : pandas.DataFrame + column_names : list[str | None] + column_field_names: list[str] index_levels : List[pd.Index] index_descriptors : List[Dict] preserve_index : bool @@ -210,6 +212,12 @@ def construct_metadata(columns_to_convert, df, column_names, index_levels, ------- dict """ + if column_field_names is None: + # backwards compatibility for external projects that are using + # `construct_metadata` such as cudf + # see https://github.com/apache/arrow/pull/44963#discussion_r1875771953 + column_field_names = [str(name) for name in column_names] + num_serialized_index_levels = len([descr for descr in index_descriptors if not isinstance(descr, dict)]) # Use ntypes instead of Python shorthand notation [:-len(x)] as [:-0] @@ -219,11 +227,11 @@ def construct_metadata(columns_to_convert, df, column_names, index_levels, index_types = types[ntypes - num_serialized_index_levels:] column_metadata = [] - for col, sanitized_name, arrow_type in zip(columns_to_convert, - column_names, df_types): - metadata = get_column_metadata(col, name=sanitized_name, + for col, name, field_name, arrow_type in zip(columns_to_convert, column_names, + column_field_names, df_types): + metadata = get_column_metadata(col, name=name, arrow_type=arrow_type, - field_name=sanitized_name) + field_name=field_name) column_metadata.append(metadata) index_column_metadata = [] @@ -368,6 +376,7 @@ def _get_columns_to_convert(df, schema, preserve_index, columns): return _get_columns_to_convert_given_schema(df, schema, preserve_index) column_names = [] + column_field_names = [] index_levels = ( _get_index_level_values(df.index) if preserve_index is not False @@ -388,6 +397,7 @@ def _get_columns_to_convert(df, schema, preserve_index, columns): columns_to_convert.append(col) convert_fields.append(None) column_names.append(name) + column_field_names.append(str(name)) index_descriptors = [] index_column_names = [] @@ -403,7 +413,7 @@ def _get_columns_to_convert(df, schema, preserve_index, columns): index_column_names.append(name) index_descriptors.append(descr) - all_names = column_names + index_column_names + all_names = column_field_names + index_column_names # all_names : all of the columns in the resulting table including the data # columns and serialized index columns @@ -416,8 +426,8 @@ def _get_columns_to_convert(df, schema, preserve_index, columns): # to be converted to Arrow format # columns_fields : specified column to use for coercion / casting # during serialization, if a Schema was provided - return (all_names, column_names, index_column_names, index_descriptors, - index_levels, columns_to_convert, convert_fields) + return (all_names, column_names, column_field_names, index_column_names, + index_descriptors, index_levels, columns_to_convert, convert_fields) def _get_columns_to_convert_given_schema(df, schema, preserve_index): @@ -462,8 +472,6 @@ def _get_columns_to_convert_given_schema(df, schema, preserve_index): "specified schema".format(name)) is_index = True - name = _column_name_to_strings(name) - if _pandas_api.is_sparse(col): raise TypeError( "Sparse pandas data (column {}) not supported.".format(name)) @@ -480,8 +488,8 @@ def _get_columns_to_convert_given_schema(df, schema, preserve_index): all_names = column_names + index_column_names - return (all_names, column_names, index_column_names, index_descriptors, - index_levels, columns_to_convert, convert_fields) + return (all_names, column_names, column_names, index_column_names, + index_descriptors, index_levels, columns_to_convert, convert_fields) def _get_index_level(df, name): @@ -539,6 +547,7 @@ def _resolve_columns_of_interest(df, schema, columns): def dataframe_to_types(df, preserve_index, columns=None): (all_names, column_names, + column_field_names, _, index_descriptors, index_columns, @@ -563,8 +572,8 @@ def dataframe_to_types(df, preserve_index, columns=None): types.append(type_) metadata = construct_metadata( - columns_to_convert, df, column_names, index_columns, - index_descriptors, preserve_index, types + columns_to_convert, df, column_names, index_columns, index_descriptors, + preserve_index, types, column_field_names=column_field_names ) return all_names, types, metadata @@ -574,6 +583,7 @@ def dataframe_to_arrays(df, schema, preserve_index, nthreads=1, columns=None, safe=True): (all_names, column_names, + column_field_names, index_column_names, index_descriptors, index_columns, @@ -642,13 +652,12 @@ def _can_definitely_zero_copy(arr): if schema is None: fields = [] for name, type_ in zip(all_names, types): - name = name if name is not None else 'None' fields.append(pa.field(name, type_)) schema = pa.schema(fields) pandas_metadata = construct_metadata( - columns_to_convert, df, column_names, index_columns, - index_descriptors, preserve_index, types + columns_to_convert, df, column_names, index_columns, index_descriptors, + preserve_index, types, column_field_names=column_field_names ) metadata = deepcopy(schema.metadata) if schema.metadata else dict() metadata.update(pandas_metadata)