Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass column names to write_csv instead of table_metadata pointer #11972

Merged
merged 8 commits into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions cpp/include/cudf/io/csv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1338,8 +1338,8 @@ class csv_writer_options {
std::string _true_value = std::string{"true"};
// string to use for values == 0 in INT8 types (default 'false')
std::string _false_value = std::string{"false"};
// Optional associated metadata
table_metadata const* _metadata = nullptr;
// Names of all columns; if empty, writer will generate column names
std::vector<std::string> _names;
ttnghia marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Constructor from sink and table.
Expand Down Expand Up @@ -1387,11 +1387,11 @@ class csv_writer_options {
[[nodiscard]] table_view const& get_table() const { return _table; }

/**
* @brief Returns optional associated metadata.
* @brief Returns names of the columns.
*
* @return Optional associated metadata
* @return Names of the columns in the output file
*/
[[nodiscard]] table_metadata const* get_metadata() const { return _metadata; }
[[nodiscard]] std::vector<std::string> const& get_names() const { return _names; }
vuule marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Returns string to used for null entries.
Expand Down Expand Up @@ -1444,11 +1444,11 @@ class csv_writer_options {

// Setter
/**
* @brief Sets optional associated metadata.
* @brief Sets optional associated column names.
*
@param metadata Associated metadata
@param names Associated column names
*/
void set_metadata(table_metadata* metadata) { _metadata = metadata; }
void set_names(std::vector<std::string> names) { _names = std::move(names); }

/**
* @brief Sets string to used for null entries.
Expand Down Expand Up @@ -1526,14 +1526,14 @@ class csv_writer_options_builder {
}

/**
* @brief Sets optional associated metadata.
* @brief Sets optional column names.
*
* @param metadata Associated metadata
* @param names Column names
* @return this for chaining
*/
csv_writer_options_builder& metadata(table_metadata* metadata)
csv_writer_options_builder& names(std::vector<std::string> names)
{
options._metadata = metadata;
options._names = names;
return *this;
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cudf/io/detail/csv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ table_with_metadata read_csv(std::unique_ptr<cudf::io::datasource>&& source,
*
* @param sink Output sink
* @param table The set of columns
* @param metadata The metadata associated with the table
* @param column_names Column names for the output CSV
* @param options Settings for controlling behavior
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource to use for device memory allocation
*/
void write_csv(data_sink* sink,
table_view const& table,
const table_metadata* metadata,
host_span<std::string const> column_names,
csv_writer_options const& options,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
Expand Down
15 changes: 7 additions & 8 deletions cpp/src/io/csv/writer_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -279,21 +279,21 @@ struct column_to_strings_fn {
//
void write_chunked_begin(data_sink* out_sink,
table_view const& table,
table_metadata const* metadata,
host_span<std::string const> user_column_names,
csv_writer_options const& options,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
if (options.is_enabled_include_header()) {
// need to generate column names if metadata is not provided
// need to generate column names if names are not provided
std::vector<std::string> generated_col_names;
if (metadata == nullptr) {
if (user_column_names.empty()) {
generated_col_names.resize(table.num_columns());
thrust::tabulate(generated_col_names.begin(), generated_col_names.end(), [](auto idx) {
return std::to_string(idx);
});
}
auto const& column_names = (metadata == nullptr) ? generated_col_names : metadata->column_names;
auto const& column_names = user_column_names.empty() ? generated_col_names : user_column_names;
CUDF_EXPECTS(column_names.size() == static_cast<size_t>(table.num_columns()),
"Mismatch between number of column headers and table columns.");

Expand Down Expand Up @@ -346,7 +346,6 @@ void write_chunked_begin(data_sink* out_sink,

void write_chunked(data_sink* out_sink,
strings_column_view const& str_column_view,
table_metadata const* metadata,
csv_writer_options const& options,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
Expand Down Expand Up @@ -399,15 +398,15 @@ void write_chunked(data_sink* out_sink,

void write_csv(data_sink* out_sink,
table_view const& table,
table_metadata const* metadata,
host_span<std::string const> user_column_names,
csv_writer_options const& options,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
// write header: column names separated by delimiter:
// (even for tables with no rows)
//
write_chunked_begin(out_sink, table, metadata, options, stream, mr);
write_chunked_begin(out_sink, table, user_column_names, options, stream, mr);

if (table.num_rows() > 0) {
// no need to check same-size columns constraint; auto-enforced by table_view
Expand Down Expand Up @@ -476,7 +475,7 @@ void write_csv(data_sink* out_sink,
return cudf::strings::detail::replace_nulls(str_table_view.column(0), narep, stream);
}();

write_chunked(out_sink, str_concat_col->view(), metadata, options, stream, mr);
write_chunked(out_sink, str_concat_col->view(), options, stream, mr);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/io/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ void write_csv(csv_writer_options const& options, rmm::mr::device_memory_resourc
return csv::write_csv( //
sinks[0].get(),
options.get_table(),
options.get_metadata(),
options.get_names(),
options,
cudf::get_default_stream(),
mr);
Expand Down
64 changes: 24 additions & 40 deletions cpp/tests/io/csv_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,30 +277,12 @@ void expect_column_data_equal(std::vector<T> const& lhs, cudf::column_view const

void write_csv_helper(std::string const& filename,
cudf::table_view const& table,
bool include_header,
std::vector<std::string> const& names = {})
{
// csv_writer_options only keeps a pointer to metadata (non-owning)
cudf::io::table_metadata metadata{};

if (not names.empty()) {
metadata.column_names = names;
} else {
// generate some dummy column names
int i = 0;
auto const num_columns = table.num_columns();
metadata.column_names.reserve(num_columns);
std::generate_n(std::back_inserter(metadata.column_names), num_columns, [&i]() {
return std::string("col") + std::to_string(i++);
});
}

cudf::io::csv_writer_options writer_options =
cudf::io::csv_writer_options::builder(cudf::io::sink_info(filename), table)
.include_header(include_header)
.rows_per_chunk(
1) // Note: this gets adjusted to multiple of 8 (per legacy code logic and requirements)
.metadata(&metadata);
.include_header(not names.empty())
.names(names);

cudf::io::write_csv(writer_options);
}
Expand Down Expand Up @@ -1509,7 +1491,7 @@ TYPED_TEST(CsvReaderNumericTypeTest, SingleColumnWithWriter)

auto filepath = temp_env->get_temp_filepath("SingleColumnWithWriter.csv");

write_csv_helper(filepath, input_table, false);
write_csv_helper(filepath, input_table);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath}).header(-1);
Expand Down Expand Up @@ -1577,7 +1559,7 @@ TEST_F(CsvReaderTest, MultiColumnWithWriter)

auto filepath = temp_env->get_temp_dir() + "MultiColumnWithWriter.csv";

write_csv_helper(filepath, input_table, false);
write_csv_helper(filepath, input_table);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
Expand Down Expand Up @@ -1625,7 +1607,7 @@ TEST_F(CsvReaderTest, DatesWithWriter)
cudf::table_view input_table(std::vector<cudf::column_view>{input_column});

// TODO need to add a dayfirst flag?
write_csv_helper(filepath, input_table, false);
write_csv_helper(filepath, input_table);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
Expand All @@ -1650,7 +1632,7 @@ TEST_F(CsvReaderTest, DatesStringWithWriter)

cudf::table_view input_table(std::vector<cudf::column_view>{input_column});

write_csv_helper(filepath, input_table, false);
write_csv_helper(filepath, input_table);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
Expand All @@ -1673,7 +1655,7 @@ TEST_F(CsvReaderTest, DatesStringWithWriter)

cudf::table_view input_table(std::vector<cudf::column_view>{input_column});

write_csv_helper(filepath, input_table, false);
write_csv_helper(filepath, input_table);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
Expand All @@ -1696,7 +1678,7 @@ TEST_F(CsvReaderTest, DatesStringWithWriter)

cudf::table_view input_table(std::vector<cudf::column_view>{input_column});

write_csv_helper(filepath, input_table, false);
write_csv_helper(filepath, input_table);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
Expand All @@ -1720,7 +1702,7 @@ TEST_F(CsvReaderTest, DatesStringWithWriter)

cudf::table_view input_table(std::vector<cudf::column_view>{input_column});

write_csv_helper(filepath, input_table, false);
write_csv_helper(filepath, input_table);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
Expand All @@ -1743,7 +1725,7 @@ TEST_F(CsvReaderTest, DatesStringWithWriter)

cudf::table_view input_table(std::vector<cudf::column_view>{input_column});

write_csv_helper(filepath, input_table, false);
write_csv_helper(filepath, input_table);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
Expand All @@ -1766,7 +1748,7 @@ TEST_F(CsvReaderTest, FloatingPointWithWriter)
cudf::table_view input_table(std::vector<cudf::column_view>{input_column});

// TODO add lineterminator=";"
write_csv_helper(filepath, input_table, false);
write_csv_helper(filepath, input_table);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
Expand All @@ -1792,18 +1774,18 @@ TEST_F(CsvReaderTest, StringsWithWriter)
cudf::table_view input_table(std::vector<cudf::column_view>{int_column, string_column});

// TODO add quoting style flag?
write_csv_helper(filepath, input_table, true, names);
write_csv_helper(filepath, input_table, names);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
.names(names)
.dtypes(std::vector<data_type>{dtype<int32_t>(), dtype<cudf::string_view>()})
.quoting(cudf::io::quote_style::NONE);
auto result = cudf::io::read_csv(in_opts);

const auto result_table = result.tbl->view();
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(input_table.column(0), result_table.column(0));
check_string_column(input_table.column(1), result_table.column(1));
ASSERT_EQ(names, result.metadata.column_names);
}

TEST_F(CsvReaderTest, StringsWithWriterSimple)
Expand All @@ -1817,18 +1799,18 @@ TEST_F(CsvReaderTest, StringsWithWriterSimple)
cudf::table_view input_table(std::vector<cudf::column_view>{int_column, string_column});

// TODO add quoting style flag?
write_csv_helper(filepath, input_table, true, names);
write_csv_helper(filepath, input_table, names);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
.names(names)
.dtypes(std::vector<data_type>{dtype<int32_t>(), dtype<cudf::string_view>()})
.quoting(cudf::io::quote_style::NONE);
auto result = cudf::io::read_csv(in_opts);

const auto result_table = result.tbl->view();
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(input_table.column(0), result_table.column(0));
check_string_column(input_table.column(1), result_table.column(1));
ASSERT_EQ(names, result.metadata.column_names);
}

TEST_F(CsvReaderTest, StringsEmbeddedDelimiter)
Expand All @@ -1841,15 +1823,15 @@ TEST_F(CsvReaderTest, StringsEmbeddedDelimiter)
auto string_column = column_wrapper<cudf::string_view>{"abc def ghi", "jkl,mno,pq", "stu vwx y"};
cudf::table_view input_table(std::vector<cudf::column_view>{int_column, string_column});

write_csv_helper(filepath, input_table, true, names);
write_csv_helper(filepath, input_table, names);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
.names(names)
.dtypes(std::vector<data_type>{dtype<int32_t>(), dtype<cudf::string_view>()});
auto result = cudf::io::read_csv(in_opts);

CUDF_TEST_EXPECT_TABLES_EQUIVALENT(input_table, result.tbl->view());
ASSERT_EQ(names, result.metadata.column_names);
}

TEST_F(CsvReaderTest, HeaderEmbeddedDelimiter)
Expand All @@ -1864,7 +1846,7 @@ TEST_F(CsvReaderTest, HeaderEmbeddedDelimiter)
cudf::table_view input_table(
std::vector<cudf::column_view>{int_column, string_column, int_column, int_column, int_column});

write_csv_helper(filepath, input_table, true, names);
write_csv_helper(filepath, input_table, names);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
Expand All @@ -1877,14 +1859,15 @@ TEST_F(CsvReaderTest, HeaderEmbeddedDelimiter)
auto result = cudf::io::read_csv(in_opts);

CUDF_TEST_EXPECT_TABLES_EQUIVALENT(input_table, result.tbl->view());
ASSERT_EQ(names, result.metadata.column_names);
}

TEST_F(CsvReaderTest, EmptyFileWithWriter)
{
auto filepath = temp_env->get_temp_dir() + "EmptyFileWithWriter.csv";

cudf::table_view empty_table;
write_csv_helper(filepath, empty_table, false);
write_csv_helper(filepath, empty_table);
cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath});
auto result = cudf::io::read_csv(in_opts);
Expand Down Expand Up @@ -1968,7 +1951,7 @@ TEST_F(CsvReaderTest, DurationsWithWriter)
durations_D, durations_s, durations_ms, durations_us, durations_ns});
std::vector<std::string> names{"D", "s", "ms", "us", "ns"};

write_csv_helper(filepath, input_table, true, names);
write_csv_helper(filepath, input_table, names);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath})
Expand All @@ -1982,6 +1965,7 @@ TEST_F(CsvReaderTest, DurationsWithWriter)

const auto result_table = result.tbl->view();
CUDF_TEST_EXPECT_TABLES_EQUIVALENT(input_table, result_table);
ASSERT_EQ(names, result.metadata.column_names);
}

TEST_F(CsvReaderTest, ParseInRangeIntegers)
Expand Down Expand Up @@ -2044,7 +2028,7 @@ TEST_F(CsvReaderTest, ParseInRangeIntegers)

auto filepath = temp_env->get_temp_filepath("ParseInRangeIntegers.csv");

write_csv_helper(filepath, input_table, false);
write_csv_helper(filepath, input_table);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath}).header(-1);
Expand Down Expand Up @@ -2123,7 +2107,7 @@ TEST_F(CsvReaderTest, ParseOutOfRangeIntegers)

auto filepath = temp_env->get_temp_filepath("ParseOutOfRangeIntegers.csv");

write_csv_helper(filepath, input_table, false);
write_csv_helper(filepath, input_table);

cudf::io::csv_reader_options in_opts =
cudf::io::csv_reader_options::builder(cudf::io::source_info{filepath}).header(-1);
Expand Down
Loading