Commit 9d70bb2

Merge branch 'branch-24.08' into pylibcudf-strings-slice

brandon-b-miller committed Jun 24, 2024
2 parents 238a583 + ac3c8dd commit 9d70bb2
Showing 12 changed files with 329 additions and 21 deletions.
13 changes: 13 additions & 0 deletions cpp/include/cudf/io/types.hpp
@@ -256,6 +256,19 @@ struct column_name_info {
}

column_name_info() = default;

/**
* @brief Compares two column name info structs for equality
*
* @param rhs column name info struct to compare against
* @return boolean indicating if this and rhs are equal
*/
bool operator==(column_name_info const& rhs) const
{
return ((name == rhs.name) && (is_nullable == rhs.is_nullable) &&
(is_binary == rhs.is_binary) && (type_length == rhs.type_length) &&
(children == rhs.children));
};
};

/**
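The new `operator==` compares every field, and because `children` is a `std::vector<column_name_info>`, nested schemas compare recursively via the vector's element-wise equality; the batched JSON reader later in this commit relies on exactly this to check that all batches produced the same `schema_info`. A minimal sketch of the behavior, assuming the usual `column_name_info(std::string)` constructor and hypothetical column names (illustration only, not part of the commit):

```cpp
#include <cudf/io/types.hpp>

#include <cassert>

int main()
{
  // Hypothetical nested schema: column "a" with one child "y".
  cudf::io::column_name_info parent("a");
  parent.children.emplace_back("y");

  cudf::io::column_name_info same("a");
  same.children.emplace_back("y");

  cudf::io::column_name_info different("a");
  different.children.emplace_back("z");  // child name differs

  // operator== compares name, is_nullable, is_binary, type_length, and children.
  assert(parent == same);
  assert(!(parent == different));
  return 0;
}
```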
121 changes: 105 additions & 16 deletions cpp/src/io/json/read_json.cu
@@ -18,7 +18,9 @@
#include "io/json/nested_json.hpp"
#include "read_json.hpp"

#include <cudf/concatenate.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/integer_utils.hpp>
#include <cudf/detail/utilities/stream_pool.hpp>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/io/detail/json.hpp>
@@ -76,15 +78,15 @@ device_span<char> ingest_raw_input(device_span<char> buffer,
auto constexpr num_delimiter_chars = 1;

if (compression == compression_type::NONE) {
-  std::vector<size_type> delimiter_map{};
std::vector<size_t> delimiter_map{};
std::vector<size_t> prefsum_source_sizes(sources.size());
std::vector<std::unique_ptr<datasource::buffer>> h_buffers;
delimiter_map.reserve(sources.size());
size_t bytes_read = 0;
std::transform_inclusive_scan(sources.begin(),
sources.end(),
prefsum_source_sizes.begin(),
-  std::plus<int>{},
std::plus<size_t>{},
[](std::unique_ptr<datasource> const& s) { return s->size(); });
auto upper =
std::upper_bound(prefsum_source_sizes.begin(), prefsum_source_sizes.end(), range_offset);
@@ -259,6 +261,33 @@ datasource::owning_buffer<rmm::device_uvector<char>> get_record_range_raw_input(
readbufspan.size() - first_delim_pos - shift_for_nonzero_offset);
}

table_with_metadata read_batch(host_span<std::unique_ptr<datasource>> sources,
json_reader_options const& reader_opts,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
datasource::owning_buffer<rmm::device_uvector<char>> bufview =
get_record_range_raw_input(sources, reader_opts, stream);

// If input JSON buffer has single quotes and option to normalize single quotes is enabled,
// invoke pre-processing FST
if (reader_opts.is_enabled_normalize_single_quotes()) {
normalize_single_quotes(bufview, stream, rmm::mr::get_current_device_resource());
}

// If input JSON buffer has unquoted spaces and tabs and option to normalize whitespaces is
// enabled, invoke pre-processing FST
if (reader_opts.is_enabled_normalize_whitespace()) {
normalize_whitespace(bufview, stream, rmm::mr::get_current_device_resource());
}

auto buffer =
cudf::device_span<char const>(reinterpret_cast<char const*>(bufview.data()), bufview.size());
stream.synchronize();
return device_parse_nested_json(buffer, reader_opts, stream, mr);
}

table_with_metadata read_json(host_span<std::unique_ptr<datasource>> sources,
json_reader_options const& reader_opts,
rmm::cuda_stream_view stream,
@@ -278,25 +307,85 @@ table_with_metadata read_json(host_span<std::unique_ptr<datasource>> sources,
"Multiple inputs are supported only for JSON Lines format");
}

-  datasource::owning_buffer<rmm::device_uvector<char>> bufview =
-    get_record_range_raw_input(sources, reader_opts, stream);
std::for_each(sources.begin(), sources.end(), [](auto const& source) {
CUDF_EXPECTS(source->size() < std::numeric_limits<int>::max(),
"The size of each source file must be less than INT_MAX bytes");
});

-  // If input JSON buffer has single quotes and option to normalize single quotes is enabled,
-  // invoke pre-processing FST
-  if (reader_opts.is_enabled_normalize_single_quotes()) {
-    normalize_single_quotes(bufview, stream, rmm::mr::get_current_device_resource());
-  }
constexpr size_t batch_size_ub = std::numeric_limits<int>::max();
size_t const chunk_offset = reader_opts.get_byte_range_offset();
size_t chunk_size = reader_opts.get_byte_range_size();
chunk_size = !chunk_size ? sources_size(sources, 0, 0) : chunk_size;

// Identify the position of starting source file from which to begin batching based on
// byte range offset. If the offset is larger than the sum of all source
// sizes, then start_source is total number of source files i.e. no file is read
size_t const start_source = [&]() {
size_t sum = 0;
for (size_t src_idx = 0; src_idx < sources.size(); ++src_idx) {
if (sum + sources[src_idx]->size() > chunk_offset) return src_idx;
sum += sources[src_idx]->size();
}
return sources.size();
}();

// Construct batches of source files, with starting position of batches indicated by
// batch_positions. The size of each batch i.e. the sum of sizes of the source files in the batch
// is capped at INT_MAX bytes.
size_t cur_size = 0;
std::vector<size_t> batch_positions;
std::vector<size_t> batch_sizes;
batch_positions.push_back(0);
for (size_t i = start_source; i < sources.size(); i++) {
cur_size += sources[i]->size();
if (cur_size >= batch_size_ub) {
batch_positions.push_back(i);
batch_sizes.push_back(cur_size - sources[i]->size());
cur_size = sources[i]->size();
}
}
batch_positions.push_back(sources.size());
batch_sizes.push_back(cur_size);

-  // If input JSON buffer has unquoted spaces and tabs and option to normalize whitespaces is
-  // enabled, invoke pre-processing FST
-  if (reader_opts.is_enabled_normalize_whitespace()) {
-    normalize_whitespace(bufview, stream, rmm::mr::get_current_device_resource());
-  }
// If there is a single batch, then we can directly return the table without the
// unnecessary concatenate
if (batch_sizes.size() == 1) return read_batch(sources, reader_opts, stream, mr);

std::vector<cudf::io::table_with_metadata> partial_tables;
json_reader_options batched_reader_opts{reader_opts};

// Dispatch individual batches to read_batch and push the resulting table into
// partial_tables array. Note that the reader options need to be updated for each
// batch to adjust byte range offset and byte range size.
for (size_t i = 0; i < batch_sizes.size(); i++) {
batched_reader_opts.set_byte_range_size(std::min(batch_sizes[i], chunk_size));
partial_tables.emplace_back(read_batch(
host_span<std::unique_ptr<datasource>>(sources.begin() + batch_positions[i],
batch_positions[i + 1] - batch_positions[i]),
batched_reader_opts,
stream,
rmm::mr::get_current_device_resource()));
if (chunk_size <= batch_sizes[i]) break;
chunk_size -= batch_sizes[i];
batched_reader_opts.set_byte_range_offset(0);
}

-  auto buffer =
-    cudf::device_span<char const>(reinterpret_cast<char const*>(bufview.data()), bufview.size());
-  stream.synchronize();
-  return device_parse_nested_json(buffer, reader_opts, stream, mr);
auto expects_schema_equality =
std::all_of(partial_tables.begin() + 1,
partial_tables.end(),
[&gt = partial_tables[0].metadata.schema_info](auto& ptbl) {
return ptbl.metadata.schema_info == gt;
});
CUDF_EXPECTS(expects_schema_equality,
"Mismatch in JSON schema across batches in multi-source multi-batch reading");

auto partial_table_views = std::vector<cudf::table_view>(partial_tables.size());
std::transform(partial_tables.begin(),
partial_tables.end(),
partial_table_views.begin(),
[](auto const& table) { return table.tbl->view(); });
return table_with_metadata{cudf::concatenate(partial_table_views, stream, mr),
{partial_tables[0].metadata.schema_info}};
}

} // namespace cudf::io::json::detail
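Since the greedy partitioning above is the core of the new multi-batch path, here is a self-contained sketch of the same `batch_positions`/`batch_sizes` bookkeeping, with hypothetical source sizes and a small cap standing in for `INT_MAX` (illustration only; the real loop reads `sources[i]->size()`):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
  // Hypothetical source sizes; batch i spans sources [batch_positions[i], batch_positions[i+1]).
  std::vector<std::size_t> const source_sizes{900, 900, 300, 1500, 200};
  std::size_t const batch_size_ub = 2000;  // stands in for INT_MAX

  // Greedily pack consecutive sources into batches whose total stays under the cap.
  std::size_t cur_size = 0;
  std::vector<std::size_t> batch_positions{0};
  std::vector<std::size_t> batch_sizes;
  for (std::size_t i = 0; i < source_sizes.size(); i++) {
    cur_size += source_sizes[i];
    if (cur_size >= batch_size_ub) {
      batch_positions.push_back(i);
      batch_sizes.push_back(cur_size - source_sizes[i]);
      cur_size = source_sizes[i];
    }
  }
  batch_positions.push_back(source_sizes.size());
  batch_sizes.push_back(cur_size);

  // Prints: batch 0: sources [0, 2) size 1800
  //         batch 1: sources [2, 4) size 1800
  //         batch 2: sources [4, 5) size 200
  for (std::size_t b = 0; b < batch_sizes.size(); b++) {
    std::cout << "batch " << b << ": sources [" << batch_positions[b] << ", "
              << batch_positions[b + 1] << ") size " << batch_sizes[b] << "\n";
  }
  return 0;
}
```

Note that because `read_json` first checks every individual source against `INT_MAX`, this flush-and-restart pattern keeps each recorded batch size under the cap: the running total is flushed as soon as it reaches the bound, and the count restarts from the single source that tipped it over.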
4 changes: 2 additions & 2 deletions cpp/src/io/text/byte_range_info.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ std::vector<byte_range_info> create_byte_range_infos_consecutive(int64_t total_b
auto range_size = util::div_rounding_up_safe(total_bytes, range_count);
auto ranges = std::vector<byte_range_info>();

-  ranges.reserve(range_size);
ranges.reserve(range_count);

for (int64_t i = 0; i < range_count; i++) {
auto offset = i * range_size;
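The one-word fix above is easy to miss: `std::vector::reserve` takes an element count, and `range_size` here is the byte width of each range, so the old call could request capacity for billions of entries on a large input while only `range_count` elements are ever appended. A sketch of the surrounding logic under that reading, with `div_rounding_up` as a hypothetical stand-in for cudf's `util::div_rounding_up_safe` and plain `(offset, size)` pairs in place of `byte_range_info`:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical stand-in for cudf::util::div_rounding_up_safe.
constexpr std::int64_t div_rounding_up(std::int64_t x, std::int64_t y) { return (x + y - 1) / y; }

int main()
{
  std::int64_t const total_bytes = 10;
  std::int64_t const range_count = 3;
  auto const range_size = div_rounding_up(total_bytes, range_count);  // 4 bytes per range

  std::vector<std::pair<std::int64_t, std::int64_t>> ranges;  // (offset, size) per range
  ranges.reserve(range_count);  // one slot per range, not one per byte
  for (std::int64_t i = 0; i < range_count; i++) {
    auto const offset = i * range_size;
    ranges.emplace_back(offset, std::min(range_size, total_bytes - offset));
  }
  for (auto const& [offset, size] : ranges) {
    std::cout << offset << " + " << size << "\n";  // prints 0 + 4, 4 + 4, 8 + 2
  }
  return 0;
}
```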
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
@@ -570,6 +570,7 @@ ConfigureTest(
LARGE_STRINGS_TEST
large_strings/concatenate_tests.cpp
large_strings/case_tests.cpp
large_strings/json_tests.cpp
large_strings/large_strings_fixture.cpp
large_strings/merge_tests.cpp
large_strings/parquet_tests.cpp
58 changes: 58 additions & 0 deletions cpp/tests/large_strings/json_tests.cpp
@@ -0,0 +1,58 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "large_strings_fixture.hpp"

#include <cudf/io/json.hpp>
#include <cudf/utilities/span.hpp>

struct JsonLargeReaderTest : public cudf::test::StringsLargeTest {};

TEST_F(JsonLargeReaderTest, MultiBatch)
{
std::string json_string = R"(
{ "a": { "y" : 6}, "b" : [1, 2, 3], "c": 11 }
{ "a": { "y" : 6}, "b" : [4, 5 ], "c": 12 }
{ "a": { "y" : 6}, "b" : [6 ], "c": 13 }
{ "a": { "y" : 6}, "b" : [7 ], "c": 14 })";
constexpr size_t expected_file_size = std::numeric_limits<int>::max() / 2;
std::size_t const log_repetitions =
static_cast<std::size_t>(std::ceil(std::log2(expected_file_size / json_string.size())));

json_string.reserve(json_string.size() * (1UL << log_repetitions));
std::size_t numrows = 4;
for (std::size_t i = 0; i < log_repetitions; i++) {
json_string += json_string;
numrows <<= 1;
}

constexpr int num_sources = 2;
std::vector<cudf::host_span<char>> hostbufs(
num_sources, cudf::host_span<char>(json_string.data(), json_string.size()));

// Initialize parsing options (reading json lines)
cudf::io::json_reader_options json_lines_options =
cudf::io::json_reader_options::builder(
cudf::io::source_info{
cudf::host_span<cudf::host_span<char>>(hostbufs.data(), hostbufs.size())})
.lines(true)
.compression(cudf::io::compression_type::NONE)
.recovery_mode(cudf::io::json_recovery_mode_t::FAIL);

// Read full test data via existing, nested JSON lines reader
cudf::io::table_with_metadata current_reader_table = cudf::io::read_json(json_lines_options);
ASSERT_EQ(current_reader_table.tbl->num_rows(), numrows * num_sources);
}
4 changes: 3 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
@@ -196,7 +196,9 @@ def __post_init__(self) -> None:
        if self.file_options.n_rows is not None:
            raise NotImplementedError("row limit in scan")
        if self.typ not in ("csv", "parquet"):
-           raise NotImplementedError(f"Unhandled scan type: {self.typ}")
            raise NotImplementedError(
                f"Unhandled scan type: {self.typ}"
            )  # pragma: no cover; polars raises on the rust side for now

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
34 changes: 33 additions & 1 deletion python/cudf_polars/cudf_polars/testing/asserts.py
@@ -11,6 +11,7 @@
from polars.testing.asserts import assert_frame_equal

from cudf_polars.callback import execute_with_cudf
from cudf_polars.dsl.translate import translate_ir

if TYPE_CHECKING:
from collections.abc import Mapping
@@ -19,7 +20,7 @@

from cudf_polars.typing import OptimizationArgs

-__all__: list[str] = ["assert_gpu_result_equal"]
__all__: list[str] = ["assert_gpu_result_equal", "assert_ir_translation_raises"]


def assert_gpu_result_equal(
@@ -84,3 +85,34 @@ def assert_gpu_result_equal(
        atol=atol,
        categorical_as_str=categorical_as_str,
    )


def assert_ir_translation_raises(q: pl.LazyFrame, *exceptions: type[Exception]) -> None:
"""
    Assert that translation of a query raises an exception.

    Parameters
    ----------
    q
        Query to translate.
    exceptions
        Exceptions that one expects might be raised.

    Returns
    -------
    None
        If translation successfully raised the specified exceptions.

    Raises
    ------
    AssertionError
        If the specified exceptions were not raised.
"""
    try:
        _ = translate_ir(q._ldf.visit())
    except exceptions:
        return
    except Exception as e:
        raise AssertionError(f"Translation DID NOT RAISE {exceptions}") from e
    else:
        raise AssertionError(f"Translation DID NOT RAISE {exceptions}")
18 changes: 18 additions & 0 deletions python/cudf_polars/docs/overview.md
@@ -224,6 +224,24 @@ def test_whatever():
    assert_gpu_result_equal(query)
```

## Test coverage and asserting failure modes

Where translation of a query should fail due to the feature being
unsupported, we should test this. To assert that _translation_ raises
an exception (usually `NotImplementedError`), use the utility function
`assert_ir_translation_raises`:

```python
from cudf_polars.testing.asserts import assert_ir_translation_raises


def test_whatever():
    unsupported_query = ...
    assert_ir_translation_raises(unsupported_query, NotImplementedError)
```

This test will fail if translation does not raise.

# Debugging

If the callback execution fails during the polars `collect` call, we
43 changes: 43 additions & 0 deletions python/cudf_polars/tests/test_dataframescan.py
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize(
    "subset",
    [
        None,
        ["a", "c"],
        ["b", "c", "d"],
        ["b", "d"],
        ["b", "c"],
        ["c", "e"],
        ["d", "e"],
        pl.selectors.string(),
        pl.selectors.integer(),
    ],
)
@pytest.mark.parametrize("predicate_pushdown", [False, True])
def test_scan_drop_nulls(subset, predicate_pushdown):
    df = pl.LazyFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [None, 4, 5, None],
            "c": [6, 7, None, None],
            "d": [8, None, 9, 10],
            "e": [None, None, "A", None],
        }
    )
    # Drop nulls are pushed into filters
    q = df.drop_nulls(subset)

    assert_gpu_result_equal(
        q, collect_kwargs={"predicate_pushdown": predicate_pushdown}
    )