ARROW-16000: [C++][Python] Dataset: Alternative implementation for adding transcoding function option to CSV scanner (#13820)

This is an alternative version of #13709, to compare which approach is best.

Instead of extending the C++ ReadOptions struct with an `encoding` field, this implementation adds a Python version of the ReadOptions object to both `CsvFileFormat` and `CsvFragmentScanOptions`. It is needed in both places to prevent inconsistencies like this:
```
>>> import pyarrow.dataset as ds
>>> import pyarrow.csv as csv
>>> ro = csv.ReadOptions(encoding='iso8859')
>>> fo = ds.CsvFileFormat(read_options=ro)
>>> fo.default_fragment_scan_options.read_options.encoding
'utf8'
```
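
With the Python-side ReadOptions stored on both objects, the same round-trip preserves the encoding (the behavior this patch is expected to produce):
```
>>> ro = csv.ReadOptions(encoding='iso8859')
>>> fo = ds.CsvFileFormat(read_options=ro)
>>> fo.default_fragment_scan_options.read_options.encoding
'iso8859'
```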

Authored-by: Joost Hoozemans <[email protected]>
Signed-off-by: David Li <[email protected]>
joosthooz authored Sep 6, 2022
1 parent a5ecb0f commit cbf0ec0
Showing 10 changed files with 157 additions and 5 deletions.
17 changes: 15 additions & 2 deletions cpp/src/arrow/dataset/file_csv.cc
@@ -183,9 +183,15 @@ static inline Future<std::shared_ptr<csv::StreamingReader>> OpenReaderAsync(
auto tracer = arrow::internal::tracing::GetTracer();
auto span = tracer->StartSpan("arrow::dataset::CsvFileFormat::OpenReaderAsync");
#endif
ARROW_ASSIGN_OR_RAISE(
auto fragment_scan_options,
GetFragmentScanOptions<CsvFragmentScanOptions>(
kCsvTypeName, scan_options.get(), format.default_fragment_scan_options));
ARROW_ASSIGN_OR_RAISE(auto reader_options, GetReadOptions(format, scan_options));

ARROW_ASSIGN_OR_RAISE(auto input, source.OpenCompressed());
if (fragment_scan_options->stream_transform_func) {
ARROW_ASSIGN_OR_RAISE(input, fragment_scan_options->stream_transform_func(input));
}
const auto& path = source.path();
ARROW_ASSIGN_OR_RAISE(
input, io::BufferedInputStream::Create(reader_options.block_size,
@@ -289,8 +295,15 @@ Future<util::optional<int64_t>> CsvFileFormat::CountRows(
return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
}
auto self = checked_pointer_cast<CsvFileFormat>(shared_from_this());
ARROW_ASSIGN_OR_RAISE(
auto fragment_scan_options,
GetFragmentScanOptions<CsvFragmentScanOptions>(
kCsvTypeName, options.get(), self->default_fragment_scan_options));
ARROW_ASSIGN_OR_RAISE(auto read_options, GetReadOptions(*self, options));
ARROW_ASSIGN_OR_RAISE(auto input, file->source().OpenCompressed());
if (fragment_scan_options->stream_transform_func) {
ARROW_ASSIGN_OR_RAISE(input, fragment_scan_options->stream_transform_func(input));
}
return csv::CountRowsAsync(options->io_context, std::move(input),
::arrow::internal::GetCpuThreadPool(), read_options,
self->parse_options)
10 changes: 10 additions & 0 deletions cpp/src/arrow/dataset/file_csv.h
@@ -73,13 +73,23 @@ class ARROW_DS_EXPORT CsvFileFormat : public FileFormat {
struct ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions {
std::string type_name() const override { return kCsvTypeName; }

using StreamWrapFunc = std::function<Result<std::shared_ptr<io::InputStream>>(
std::shared_ptr<io::InputStream>)>;

/// CSV conversion options
csv::ConvertOptions convert_options = csv::ConvertOptions::Defaults();

/// CSV reading options
///
/// Note that use_threads is always ignored.
csv::ReadOptions read_options = csv::ReadOptions::Defaults();

/// Optional stream wrapping function
///
/// If defined, all open dataset file fragments will be passed
/// through this function. One possible use case is to transparently
/// transcode all input files from a given character set to utf8.
StreamWrapFunc stream_transform_func{};
};

class ARROW_DS_EXPORT CsvFileWriteOptions : public FileWriteOptions {
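A sketch of the kind of wrapper the `stream_transform_func` doc comment describes, expressed with pyarrow's public `transcoding_input_stream` helper (illustration only, not part of this diff; the dataset layer installs an equivalent C++ wrapper automatically when a non-utf8 encoding is set):
```
import pyarrow as pa

def wrap_latin1_as_utf8(stream):
    # Decode latin-1 bytes and re-encode as utf8 before the CSV reader
    # consumes the stream (the use case stream_transform_func targets).
    return pa.transcoding_input_stream(stream, 'latin-1', 'utf-8')
```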
30 changes: 28 additions & 2 deletions python/pyarrow/_dataset.pyx
@@ -21,6 +21,7 @@

from cython.operator cimport dereference as deref

import codecs
import collections
import os
import warnings
@@ -831,8 +832,14 @@ cdef class FileFormat(_Weakrefable):

@property
def default_fragment_scan_options(self):
dfso = FragmentScanOptions.wrap(
self.wrapped.get().default_fragment_scan_options)
# CsvFileFormat stores a Python-specific encoding field that needs
# to be restored because it does not exist in the C++ struct
if isinstance(self, CsvFileFormat):
if self._read_options_py is not None:
dfso.read_options = self._read_options_py
return dfso

@default_fragment_scan_options.setter
def default_fragment_scan_options(self, FragmentScanOptions options):
@@ -1178,6 +1185,10 @@ cdef class CsvFileFormat(FileFormat):
"""
cdef:
CCsvFileFormat* csv_format
# The encoding field in ReadOptions does not exist in the C++ struct.
# We need to store it here and override it when reading
# default_fragment_scan_options.read_options
public ReadOptions _read_options_py

# Avoid mistakenly creating attributes
__slots__ = ()
@@ -1205,6 +1216,8 @@
raise TypeError('`default_fragment_scan_options` must be either '
'a dictionary or an instance of '
'CsvFragmentScanOptions')
if read_options is not None:
self._read_options_py = read_options

cdef void init(self, const shared_ptr[CFileFormat]& sp):
FileFormat.init(self, sp)
@@ -1227,6 +1240,8 @@
cdef _set_default_fragment_scan_options(self, FragmentScanOptions options):
if options.type_name == 'csv':
self.csv_format.default_fragment_scan_options = options.wrapped
self.default_fragment_scan_options.read_options = options.read_options
self._read_options_py = options.read_options
else:
super()._set_default_fragment_scan_options(options)

@@ -1258,6 +1273,9 @@ cdef class CsvFragmentScanOptions(FragmentScanOptions):

cdef:
CCsvFragmentScanOptions* csv_options
# The encoding field in ReadOptions does not exist in the C++ struct.
# We need to store it here and override it when reading read_options
ReadOptions _read_options_py

# Avoid mistakenly creating attributes
__slots__ = ()
@@ -1270,6 +1288,7 @@ cdef class CsvFragmentScanOptions(FragmentScanOptions):
self.convert_options = convert_options
if read_options is not None:
self.read_options = read_options
self._read_options_py = read_options

cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp):
FragmentScanOptions.init(self, sp)
@@ -1285,11 +1304,18 @@

@property
def read_options(self):
read_options = ReadOptions.wrap(self.csv_options.read_options)
if self._read_options_py is not None:
read_options.encoding = self._read_options_py.encoding
return read_options

@read_options.setter
def read_options(self, ReadOptions read_options not None):
self.csv_options.read_options = deref(read_options.options)
self._read_options_py = read_options
if codecs.lookup(read_options.encoding).name != 'utf-8':
self.csv_options.stream_transform_func = deref(
make_streamwrap_func(read_options.encoding, 'utf-8'))

def equals(self, CsvFragmentScanOptions other):
return (
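Together, these setters keep both entry points consistent; a minimal usage sketch (the file name is hypothetical, behavior as exercised by the tests below):
```
import pyarrow.dataset as ds
import pyarrow.csv as csv

ro = csv.ReadOptions(encoding='latin-1')
# Either pass read_options to the format directly...
fmt = ds.CsvFileFormat(read_options=ro)
# ...or via the fragment scan options; both now report the same encoding.
fso = ds.CsvFragmentScanOptions(read_options=ro)
dataset = ds.dataset('data.csv', format=ds.CsvFileFormat(
    default_fragment_scan_options=fso))
```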
8 changes: 8 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -1229,6 +1229,9 @@ cdef extern from "arrow/builder.h" namespace "arrow" nogil:
ctypedef void CallbackTransform(object, const shared_ptr[CBuffer]& src,
shared_ptr[CBuffer]* dest)

ctypedef CResult[shared_ptr[CInputStream]] StreamWrapFunc(
shared_ptr[CInputStream])


cdef extern from "arrow/util/cancel.h" namespace "arrow" nogil:
cdef cppclass CStopToken "arrow::StopToken":
@@ -1396,6 +1399,11 @@ cdef extern from "arrow/io/api.h" namespace "arrow::io" nogil:
shared_ptr[CInputStream] wrapped, CTransformInputStreamVTable vtable,
object method_arg)

shared_ptr[function[StreamWrapFunc]] MakeStreamTransformFunc \
"arrow::py::MakeStreamTransformFunc"(
CTransformInputStreamVTable vtable,
object method_arg)

# ----------------------------------------------------------------------
# HDFS

1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow_dataset.pxd
@@ -279,6 +279,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
"arrow::dataset::CsvFragmentScanOptions"(CFragmentScanOptions):
CCSVConvertOptions convert_options
CCSVReadOptions read_options
function[StreamWrapFunc] stream_transform_func

cdef cppclass CPartitioning "arrow::dataset::Partitioning":
c_string type_name() const
29 changes: 28 additions & 1 deletion python/pyarrow/io.pxi
@@ -1607,6 +1607,33 @@ class Transcoder:
return self._encoder.encode(self._decoder.decode(buf, final), final)


cdef shared_ptr[function[StreamWrapFunc]] make_streamwrap_func(
src_encoding, dest_encoding) except *:
"""
Create a function that will add a transcoding transformation to a stream.
Data from that stream will be decoded according to ``src_encoding`` and
then re-encoded according to ``dest_encoding``.
The created function can be used to wrap streams.
Parameters
----------
src_encoding : str
The codec to use when reading data.
dest_encoding : str
The codec to use for emitted data.
"""
cdef:
shared_ptr[function[StreamWrapFunc]] empty_func
CTransformInputStreamVTable vtable

vtable.transform = _cb_transform
src_codec = codecs.lookup(src_encoding)
dest_codec = codecs.lookup(dest_encoding)
return MakeStreamTransformFunc(move(vtable),
Transcoder(src_codec.incrementaldecoder(),
dest_codec.incrementalencoder()))


def transcoding_input_stream(stream, src_encoding, dest_encoding):
"""
Add a transcoding transformation to the stream.
@@ -1618,7 +1645,7 @@ def transcoding_input_stream(stream, src_encoding, dest_encoding):
stream : NativeFile
The stream to which the transformation should be applied.
src_encoding : str
The codec to use when reading data.
dest_encoding : str
The codec to use for emitted data.
"""
3 changes: 3 additions & 0 deletions python/pyarrow/lib.pxd
@@ -536,6 +536,9 @@ cdef shared_ptr[CInputStream] native_transcoding_input_stream(
shared_ptr[CInputStream] stream, src_encoding,
dest_encoding) except *

cdef shared_ptr[function[StreamWrapFunc]] make_streamwrap_func(
src_encoding, dest_encoding) except *

# Default is allow_none=False
cpdef DataType ensure_type(object type, bint allow_none=*)

10 changes: 10 additions & 0 deletions python/pyarrow/src/io.cc
@@ -370,5 +370,15 @@ std::shared_ptr<::arrow::io::InputStream> MakeTransformInputStream(
return std::make_shared<TransformInputStream>(std::move(wrapped), std::move(transform));
}

std::shared_ptr<StreamWrapFunc> MakeStreamTransformFunc(TransformInputStreamVTable vtable,
PyObject* handler) {
TransformInputStream::TransformFunc transform(
TransformFunctionWrapper{std::move(vtable.transform), handler});
StreamWrapFunc func = [transform](std::shared_ptr<::arrow::io::InputStream> wrapped) {
return std::make_shared<TransformInputStream>(wrapped, transform);
};
return std::make_shared<StreamWrapFunc>(func);
}

} // namespace py
} // namespace arrow
5 changes: 5 additions & 0 deletions python/pyarrow/src/io.h
@@ -112,5 +112,10 @@ std::shared_ptr<::arrow::io::InputStream> MakeTransformInputStream(
std::shared_ptr<::arrow::io::InputStream> wrapped, TransformInputStreamVTable vtable,
PyObject* arg);

using StreamWrapFunc = std::function<Result<std::shared_ptr<io::InputStream>>(
std::shared_ptr<io::InputStream>)>;
ARROW_PYTHON_EXPORT
std::shared_ptr<StreamWrapFunc> MakeStreamTransformFunc(TransformInputStreamVTable vtable,
PyObject* handler);
} // namespace py
} // namespace arrow
49 changes: 49 additions & 0 deletions python/pyarrow/tests/test_dataset.py
@@ -3137,6 +3137,55 @@ def test_csv_fragment_options(tempdir, dataset_reader):
pa.table({'col0': pa.array(['foo', 'spam', 'MYNULL'])}))


def test_encoding(tempdir, dataset_reader):
path = str(tempdir / 'test.csv')

for encoding, input_rows in [
('latin-1', b"a,b\nun,\xe9l\xe9phant"),
('utf16', b'\xff\xfea\x00,\x00b\x00\n\x00u\x00n\x00,'
b'\x00\xe9\x00l\x00\xe9\x00p\x00h\x00a\x00n\x00t\x00'),
]:

with open(path, 'wb') as sink:
sink.write(input_rows)

# Interpret as utf8:
expected_schema = pa.schema([("a", pa.string()), ("b", pa.string())])
expected_table = pa.table({'a': ["un"],
'b': ["éléphant"]}, schema=expected_schema)

read_options = pa.csv.ReadOptions(encoding=encoding)
file_format = ds.CsvFileFormat(read_options=read_options)
dataset_transcoded = ds.dataset(path, format=file_format)
assert dataset_transcoded.schema.equals(expected_schema)
assert dataset_transcoded.to_table().equals(expected_table)


# Test if a dataset with non-utf8 chars in the column names is properly handled
def test_column_names_encoding(tempdir, dataset_reader):
path = str(tempdir / 'test.csv')

with open(path, 'wb') as sink:
sink.write(b"\xe9,b\nun,\xe9l\xe9phant")

# Interpret as utf8:
expected_schema = pa.schema([("é", pa.string()), ("b", pa.string())])
expected_table = pa.table({'é': ["un"],
'b': ["éléphant"]}, schema=expected_schema)

# Reading as string without specifying encoding should produce an error
dataset = ds.dataset(path, format='csv', schema=expected_schema)
with pytest.raises(pyarrow.lib.ArrowInvalid, match="invalid UTF8"):
dataset_reader.to_table(dataset)

# Setting the encoding in the read_options should transcode the data
read_options = pa.csv.ReadOptions(encoding='latin-1')
file_format = ds.CsvFileFormat(read_options=read_options)
dataset_transcoded = ds.dataset(path, format=file_format)
assert dataset_transcoded.schema.equals(expected_schema)
assert dataset_transcoded.to_table().equals(expected_table)


def test_feather_format(tempdir, dataset_reader):
from pyarrow.feather import write_feather

