diff --git a/cpp/benchmarks/ndsh/utilities.cpp b/cpp/benchmarks/ndsh/utilities.cpp index b1b7c720c7c..c19c1361ea2 100644 --- a/cpp/benchmarks/ndsh/utilities.cpp +++ b/cpp/benchmarks/ndsh/utilities.cpp @@ -402,12 +402,12 @@ void generate_parquet_data_sources(double scale_factor, // memory. } - std::vector const requested_table_names = [&table_names]() { + std::unordered_set const requested_table_names = [&table_names]() { if (table_names.empty()) { - return std::vector{ + return std::unordered_set{ "orders", "lineitem", "part", "partsupp", "supplier", "customer", "nation", "region"}; } - return table_names; + return std::unordered_set(table_names.begin(), table_names.end()); }(); std::for_each( requested_table_names.begin(), requested_table_names.end(), [&](auto const& table_name) { @@ -418,40 +418,50 @@ void generate_parquet_data_sources(double scale_factor, if (sources.count("orders") or sources.count("lineitem") or sources.count("part")) { auto [orders, lineitem, part] = cudf::datagen::generate_orders_lineitem_part( scale_factor, cudf::get_default_stream(), cudf::get_current_device_resource_ref()); - if (sources.count("orders")) tables["orders"] = std::move(orders); - if (sources.count("lineitem")) tables["lineitem"] = std::move(lineitem); - if (sources.count("part")) tables["part"] = std::move(part); + if (sources.count("orders")) { + write_to_parquet_device_buffer(orders, SCHEMAS.at("orders"), sources.at("orders")); + orders = {}; + } + if (sources.count("part")) { + write_to_parquet_device_buffer(part, SCHEMAS.at("part"), sources.at("part")); + part = {}; + } + if (sources.count("lineitem")) { + write_to_parquet_device_buffer(lineitem, SCHEMAS.at("lineitem"), sources.at("lineitem")); + lineitem = {}; + } } if (sources.count("partsupp")) { - tables["partsupp"] = cudf::datagen::generate_partsupp( + auto partsupp = cudf::datagen::generate_partsupp( scale_factor, cudf::get_default_stream(), cudf::get_current_device_resource_ref()); + write_to_parquet_device_buffer(partsupp, SCHEMAS.at("partsupp"), sources.at("partsupp")); } if (sources.count("supplier")) { - tables["supplier"] = cudf::datagen::generate_supplier( + auto supplier = cudf::datagen::generate_supplier( scale_factor, cudf::get_default_stream(), cudf::get_current_device_resource_ref()); + write_to_parquet_device_buffer(supplier, SCHEMAS.at("supplier"), sources.at("supplier")); } if (sources.count("customer")) { - tables["customer"] = cudf::datagen::generate_customer( + auto customer = cudf::datagen::generate_customer( scale_factor, cudf::get_default_stream(), cudf::get_current_device_resource_ref()); + write_to_parquet_device_buffer(customer, SCHEMAS.at("customer"), sources.at("customer")); } if (sources.count("nation")) { - tables["nation"] = cudf::datagen::generate_nation(cudf::get_default_stream(), - cudf::get_current_device_resource_ref()); + auto nation = cudf::datagen::generate_nation(cudf::get_default_stream(), + cudf::get_current_device_resource_ref()); + write_to_parquet_device_buffer(nation, SCHEMAS.at("nation"), sources.at("nation")); } if (sources.count("region")) { - tables["region"] = cudf::datagen::generate_region(cudf::get_default_stream(), - cudf::get_current_device_resource_ref()); + auto region = cudf::datagen::generate_region(cudf::get_default_stream(), + cudf::get_current_device_resource_ref()); + write_to_parquet_device_buffer(region, SCHEMAS.at("region"), sources.at("region")); } - for (auto const& table_name : requested_table_names) { - write_to_parquet_device_buffer( - tables.at(table_name), SCHEMAS.at(table_name), sources.at(table_name)); - } // Restore the original memory resource if (!is_managed) { cudf::set_current_device_resource(old_mr); } }