From e68ca7f9aed876a1afcad81a417afb87c94ee951 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Fri, 17 May 2019 11:40:55 -0500 Subject: [PATCH] ARROW-3144: [C++/Python] Move "dictionary" member from DictionaryType to ArrayData to allow for variable dictionaries This patch moves the dictionary member out of DictionaryType to a new member on the internal ArrayData structure. As a result, serializing and deserializing schemas requires only a single IPC message, and schemas have no knowledge of what the dictionary values are. The objective of this change is to correct a long-standing Arrow C++ design problem with dictionary-encoded arrays where the dictionary values must be known at schema construction time. This has plagued us all over the codebase: * In reading Parquet files, reading directly to DictionaryArray is not simple because each row group may have a different dictionary * In IPC streams, delta dictionaries (not yet implemented) would invalidate the pre-existing schema, causing subsequent RecordBatch objects to be incompatible * In Arrow Flight, schema negotiation requires the dictionaries to be sent, having possibly unbounded size. * Not possible to have different dictionaries in a ChunkedArray * In CSV files, converting columns to dictionary in parallel would require an expensive type unification The summary of what can be learned from this is: do not put data in type objects, only metadata. Dictionaries are data, not metadata. There are a number of unavoidable API changes (straightforward for library users to fix) but otherwise no functional difference in the library. As you can see the change is quite complex as significant parts of IPC read/write, JSON integration testing, and Flight needed to be reworked to alter the control flow around schema resolution and handling the first record batch. Key APIs changed * `DictionaryType` constructor requires a `DataType` for the dictionary value type instead of the dictionary itself. The `dictionary` factory method is correspondingly changed. The `dictionary` accessor method on `DictionaryType` is replaced with `value_type`. * `DictionaryArray` constructor and `DictionaryArray::FromArrays` must be passed the dictionary values as an additional argument. * `DictionaryMemo` is exposed in the public API as it is now required for granular interactions with IPC messages with such functions as `ipc::ReadSchema` and `ipc::ReadRecordBatch` * A `DictionaryMemo*` argument is added to several low-level public functions in `ipc/writer.h` and `ipc/reader.h` Some other incidental changes: * Because DictionaryType objects could be reused previous in Schemas, such dictionaries would be "deduplicated" in IPC messages in passing. This is no longer possible by the same trick, so dictionary reuse will have to be handled in a different way (I opened ARROW-5340 to investigate) * As a result of this, an integration test that featured dictionary reuse has been changed to not reuse dictionaries. Technically this is a regression, but I didn't want to block the patch over it * R is added to allow_failures in Travis CI for now Author: Wes McKinney Author: Kouhei Sutou Author: Antoine Pitrou Closes #4316 from wesm/ARROW-3144 and squashes the following commits: 9f1ccfbf4 Follow DictionaryArray changes 89e274da5 Do not reuse dictionaries in integration tests for now until more follow on work around this can be done f62819f5b Support many fields referencing the same dictionary, fix integration tests 37e82b4da Fix CUDA and Duration issues 037075083 Add R to allow_failures for now bd04774e2 Code review comments b1cc52e62 Fix rest of Python unit tests, fix some incorrect code comments f1178b2a6 Fix all but 3 Python unit tests ab7fc1741 Fix up Cython compilation, haven't fixed unit tests yet though 6ce51ef79 Get everything compiling again e23c578fd Fix Parquet tests c73b2162f arrow-tests all passing again, huzzah! 04d40e8e6 Flat dictionary IPC test passing now 481f316dc Get JSON integration tests passing again 77a43dc9f Fix pretty_print-test f4ada6685 array-tests compilers again 8276dce6c libarrow compiles again 8ea0e260a Refactor IPC read path for new paradigm a1afe879a More refactoring to have correct logic in IPC paths, not yet done aed04304e More refactoring, regularize some type names 6bd72f946 Start porting changes 24f99f16b Initial boilerplate --- .travis.yml | 2 + c_glib/arrow-glib/composite-array.cpp | 30 +- c_glib/arrow-glib/composite-array.h | 5 +- c_glib/arrow-glib/composite-data-type.cpp | 22 +- c_glib/arrow-glib/composite-data-type.h | 7 +- c_glib/test/test-dictionary-array.rb | 10 +- c_glib/test/test-dictionary-data-type.rb | 8 +- cpp/src/arrow/array-dict-test.cc | 281 +++++++------ cpp/src/arrow/array.cc | 74 ++-- cpp/src/arrow/array.h | 34 +- cpp/src/arrow/array/builder_dict.cc | 29 +- cpp/src/arrow/array/builder_dict.h | 19 +- cpp/src/arrow/builder-benchmark.cc | 19 + cpp/src/arrow/builder.cc | 37 +- cpp/src/arrow/builder.h | 16 + cpp/src/arrow/compare.cc | 2 +- cpp/src/arrow/compute/kernels/cast-test.cc | 6 +- cpp/src/arrow/compute/kernels/cast.cc | 297 ++++++------- cpp/src/arrow/compute/kernels/hash-test.cc | 8 +- cpp/src/arrow/compute/kernels/hash.cc | 6 +- cpp/src/arrow/compute/kernels/take-test.cc | 7 +- cpp/src/arrow/compute/kernels/take.cc | 9 +- cpp/src/arrow/flight/flight-benchmark.cc | 3 +- cpp/src/arrow/flight/flight-test.cc | 9 +- cpp/src/arrow/flight/internal.cc | 4 +- cpp/src/arrow/flight/perf-server.cc | 6 + cpp/src/arrow/flight/server.cc | 136 ++++-- cpp/src/arrow/flight/server.h | 18 +- .../arrow/flight/test-integration-client.cc | 4 +- cpp/src/arrow/flight/types.cc | 8 +- cpp/src/arrow/flight/types.h | 17 +- cpp/src/arrow/gpu/cuda-test.cc | 4 +- cpp/src/arrow/gpu/cuda_arrow_ipc.cc | 4 +- cpp/src/arrow/ipc/dictionary.cc | 140 ++++++- cpp/src/arrow/ipc/dictionary.h | 47 ++- cpp/src/arrow/ipc/feather.cc | 15 +- cpp/src/arrow/ipc/json-integration.cc | 38 +- cpp/src/arrow/ipc/json-internal.cc | 396 ++++++++---------- cpp/src/arrow/ipc/json-internal.h | 45 +- cpp/src/arrow/ipc/json-test.cc | 13 +- cpp/src/arrow/ipc/metadata-internal.cc | 220 ++++------ cpp/src/arrow/ipc/metadata-internal.h | 13 +- cpp/src/arrow/ipc/read-write-benchmark.cc | 4 +- cpp/src/arrow/ipc/read-write-test.cc | 100 ++++- cpp/src/arrow/ipc/reader.cc | 280 +++++++------ cpp/src/arrow/ipc/reader.h | 38 +- cpp/src/arrow/ipc/test-common.cc | 48 ++- cpp/src/arrow/ipc/writer.cc | 120 +++--- cpp/src/arrow/ipc/writer.h | 29 +- cpp/src/arrow/json/converter-test.cc | 6 +- cpp/src/arrow/json/converter.cc | 79 ++-- cpp/src/arrow/json/parser.cc | 12 +- cpp/src/arrow/json/test-common.h | 2 +- cpp/src/arrow/pretty_print-test.cc | 27 +- cpp/src/arrow/pretty_print.cc | 23 +- cpp/src/arrow/python/arrow_to_pandas.cc | 6 +- cpp/src/arrow/python/flight.cc | 19 +- cpp/src/arrow/python/flight.h | 9 +- cpp/src/arrow/tensor.cc | 4 +- cpp/src/arrow/type-test.cc | 134 ++++-- cpp/src/arrow/type.cc | 34 +- cpp/src/arrow/type.h | 79 ++-- cpp/src/arrow/type_traits.h | 16 +- cpp/src/arrow/util/concatenate-test.cc | 5 +- cpp/src/arrow/util/concatenate.cc | 21 +- .../parquet/arrow/arrow-reader-writer-test.cc | 21 +- cpp/src/parquet/arrow/arrow-schema-test.cc | 38 +- cpp/src/parquet/arrow/schema.cc | 5 +- cpp/src/parquet/arrow/writer.cc | 6 +- integration/integration_test.py | 62 ++- python/pyarrow/__init__.py | 1 + python/pyarrow/_flight.pyx | 4 +- python/pyarrow/array.pxi | 7 +- python/pyarrow/includes/libarrow.pxd | 21 +- python/pyarrow/includes/libarrow_flight.pxd | 2 +- python/pyarrow/ipc.pxi | 30 +- python/pyarrow/lib.pxd | 5 + python/pyarrow/tests/test_array.py | 2 +- python/pyarrow/tests/test_compute.py | 3 +- python/pyarrow/tests/test_ipc.py | 2 +- python/pyarrow/tests/test_schema.py | 9 +- python/pyarrow/tests/test_types.py | 21 +- python/pyarrow/types.pxi | 42 +- .../lib/arrow/dictionary-data-type.rb | 35 +- .../test/test-dictionary-data-type.rb | 6 +- 85 files changed, 1985 insertions(+), 1500 deletions(-) diff --git a/.travis.yml b/.travis.yml index 92ba3b7a993da..82a88791b9652 100644 --- a/.travis.yml +++ b/.travis.yml @@ -44,6 +44,8 @@ before_install: matrix: + allow_failures: + - language: r fast_finish: true include: - name: "Lint C++, Python, R, Docker" diff --git a/c_glib/arrow-glib/composite-array.cpp b/c_glib/arrow-glib/composite-array.cpp index 4fba813a2aa0c..48dd3d3ac2a35 100644 --- a/c_glib/arrow-glib/composite-array.cpp +++ b/c_glib/arrow-glib/composite-array.cpp @@ -534,25 +534,37 @@ garrow_dictionary_array_class_init(GArrowDictionaryArrayClass *klass) /** * garrow_dictionary_array_new: - * @data_type: The data type of dictionary. + * @data_type: The data type of the dictionary array. * @indices: The indices of values in dictionary. + * @dictionary: The dictionary of the dictionary array. + * @error: (nullable): Return location for a #GError or %NULL. * - * Returns: A newly created #GArrowDictionaryArray. + * Returns: (nullable): A newly created #GArrowDictionaryArray + * or %NULL on error. * * Since: 0.8.0 */ GArrowDictionaryArray * garrow_dictionary_array_new(GArrowDataType *data_type, - GArrowArray *indices) + GArrowArray *indices, + GArrowArray *dictionary, + GError **error) { const auto arrow_data_type = garrow_data_type_get_raw(data_type); const auto arrow_indices = garrow_array_get_raw(indices); - auto arrow_dictionary_array = - std::make_shared(arrow_data_type, - arrow_indices); - auto arrow_array = - std::static_pointer_cast(arrow_dictionary_array); - return GARROW_DICTIONARY_ARRAY(garrow_array_new_raw(&arrow_array)); + const auto arrow_dictionary = garrow_array_get_raw(dictionary); + std::shared_ptr arrow_dictionary_array; + auto status = arrow::DictionaryArray::FromArrays(arrow_data_type, + arrow_indices, + arrow_dictionary, + &arrow_dictionary_array); + if (garrow_error_check(error, status, "[dictionary-array][new]")) { + auto arrow_array = + std::static_pointer_cast(arrow_dictionary_array); + return GARROW_DICTIONARY_ARRAY(garrow_array_new_raw(&arrow_array)); + } else { + return NULL; + } } /** diff --git a/c_glib/arrow-glib/composite-array.h b/c_glib/arrow-glib/composite-array.h index c54c2f8588479..40603ce3b228b 100644 --- a/c_glib/arrow-glib/composite-array.h +++ b/c_glib/arrow-glib/composite-array.h @@ -151,7 +151,10 @@ struct _GArrowDictionaryArrayClass }; GArrowDictionaryArray * -garrow_dictionary_array_new(GArrowDataType *data_type, GArrowArray *indices); +garrow_dictionary_array_new(GArrowDataType *data_type, + GArrowArray *indices, + GArrowArray *dictionary, + GError **error); GArrowArray * garrow_dictionary_array_get_indices(GArrowDictionaryArray *array); GArrowArray * diff --git a/c_glib/arrow-glib/composite-data-type.cpp b/c_glib/arrow-glib/composite-data-type.cpp index 675900a5becc2..b842f63649b7c 100644 --- a/c_glib/arrow-glib/composite-data-type.cpp +++ b/c_glib/arrow-glib/composite-data-type.cpp @@ -513,7 +513,7 @@ garrow_dictionary_data_type_class_init(GArrowDictionaryDataTypeClass *klass) /** * garrow_dictionary_data_type_new: * @index_data_type: The data type of index. - * @dictionary: The dictionary. + * @value_data_type: The data type of dictionary values. * @ordered: Whether dictionary contents are ordered or not. * * Returns: The newly created dictionary data type. @@ -522,13 +522,13 @@ garrow_dictionary_data_type_class_init(GArrowDictionaryDataTypeClass *klass) */ GArrowDictionaryDataType * garrow_dictionary_data_type_new(GArrowDataType *index_data_type, - GArrowArray *dictionary, + GArrowDataType *value_data_type, gboolean ordered) { auto arrow_index_data_type = garrow_data_type_get_raw(index_data_type); - auto arrow_dictionary = garrow_array_get_raw(dictionary); + auto arrow_value_data_type = garrow_data_type_get_raw(value_data_type); auto arrow_data_type = arrow::dictionary(arrow_index_data_type, - arrow_dictionary, + arrow_value_data_type, ordered); return GARROW_DICTIONARY_DATA_TYPE(garrow_data_type_new_raw(&arrow_data_type)); } @@ -552,21 +552,21 @@ garrow_dictionary_data_type_get_index_data_type(GArrowDictionaryDataType *dictio } /** - * garrow_dictionary_data_type_get_dictionary: + * garrow_dictionary_data_type_get_value_data_type: * @dictionary_data_type: The #GArrowDictionaryDataType. * - * Returns: (transfer full): The dictionary as #GArrowArray. + * Returns: (transfer full): The #GArrowDataType of dictionary values. * - * Since: 0.8.0 + * Since: 0.14.0 */ -GArrowArray * -garrow_dictionary_data_type_get_dictionary(GArrowDictionaryDataType *dictionary_data_type) +GArrowDataType * +garrow_dictionary_data_type_get_value_data_type(GArrowDictionaryDataType *dictionary_data_type) { auto arrow_data_type = garrow_data_type_get_raw(GARROW_DATA_TYPE(dictionary_data_type)); auto arrow_dictionary_data_type = std::static_pointer_cast(arrow_data_type); - auto arrow_dictionary = arrow_dictionary_data_type->dictionary(); - return garrow_array_new_raw(&arrow_dictionary); + auto arrow_value_data_type = arrow_dictionary_data_type->value_type(); + return garrow_data_type_new_raw(&arrow_value_data_type); } /** diff --git a/c_glib/arrow-glib/composite-data-type.h b/c_glib/arrow-glib/composite-data-type.h index f0eed87187d90..ed73b9a7ee332 100644 --- a/c_glib/arrow-glib/composite-data-type.h +++ b/c_glib/arrow-glib/composite-data-type.h @@ -145,12 +145,13 @@ struct _GArrowDictionaryDataTypeClass GArrowDictionaryDataType * garrow_dictionary_data_type_new(GArrowDataType *index_data_type, - GArrowArray *dictionary, + GArrowDataType *value_data_type, gboolean ordered); GArrowDataType * garrow_dictionary_data_type_get_index_data_type(GArrowDictionaryDataType *dictionary_data_type); -GArrowArray * -garrow_dictionary_data_type_get_dictionary(GArrowDictionaryDataType *dictionary_data_type); +GARROW_AVAILABLE_IN_0_14 +GArrowDataType * +garrow_dictionary_data_type_get_value_data_type(GArrowDictionaryDataType *dictionary_data_type); gboolean garrow_dictionary_data_type_is_ordered(GArrowDictionaryDataType *dictionary_data_type); diff --git a/c_glib/test/test-dictionary-array.rb b/c_glib/test/test-dictionary-array.rb index 7d07b45872256..0f5157869981e 100644 --- a/c_glib/test/test-dictionary-array.rb +++ b/c_glib/test/test-dictionary-array.rb @@ -23,14 +23,16 @@ def setup @dictionary = build_string_array(["C", "C++", "Ruby"]) @ordered = false @data_type = Arrow::DictionaryDataType.new(@index_data_type, - @dictionary, + @dictionary.value_data_type, @ordered) end sub_test_case(".new") do def test_new indices = build_int32_array([0, 2, 2, 1, 0]) - dictionary_array = Arrow::DictionaryArray.new(@data_type, indices) + dictionary_array = Arrow::DictionaryArray.new(@data_type, + indices, + @dictionary) assert_equal(<<-STRING.chomp, dictionary_array.to_s) -- dictionary: @@ -55,7 +57,9 @@ def test_new def setup super @indices = build_int32_array([0, 2, 2, 1, 0]) - @dictionary_array = Arrow::DictionaryArray.new(@data_type, @indices) + @dictionary_array = Arrow::DictionaryArray.new(@data_type, + @indices, + @dictionary) end def test_indices diff --git a/c_glib/test/test-dictionary-data-type.rb b/c_glib/test/test-dictionary-data-type.rb index 5530a0415cb28..2069c1f395b96 100644 --- a/c_glib/test/test-dictionary-data-type.rb +++ b/c_glib/test/test-dictionary-data-type.rb @@ -20,10 +20,10 @@ class TestDictionaryDataType < Test::Unit::TestCase def setup @index_data_type = Arrow::Int32DataType.new - @dictionary = build_string_array(["C", "C++", "Ruby"]) + @value_data_type = Arrow::StringDataType.new @ordered = true @data_type = Arrow::DictionaryDataType.new(@index_data_type, - @dictionary, + @value_data_type, @ordered) end @@ -44,8 +44,8 @@ def test_index_data_type assert_equal(@index_data_type, @data_type.index_data_type) end - def test_dictionary - assert_equal(@dictionary, @data_type.dictionary) + def test_value_data_type + assert_equal(@value_data_type, @data_type.value_data_type) end def test_ordered? diff --git a/cpp/src/arrow/array-dict-test.cc b/cpp/src/arrow/array-dict-test.cc index daa7b34376264..4d57ee3fef642 100644 --- a/cpp/src/arrow/array-dict-test.cc +++ b/cpp/src/arrow/array-dict-test.cc @@ -51,36 +51,40 @@ typedef ::testing::Types builder(default_memory_pool()); - ASSERT_OK(builder.Append(static_cast(1))); - ASSERT_OK(builder.Append(static_cast(2))); - ASSERT_OK(builder.Append(static_cast(1))); + ASSERT_OK(builder.Append(static_cast(1))); + ASSERT_OK(builder.Append(static_cast(2))); + ASSERT_OK(builder.Append(static_cast(1))); ASSERT_OK(builder.AppendNull()); ASSERT_EQ(builder.length(), 4); ASSERT_EQ(builder.null_count(), 1); // Build expected data - auto dict_array = ArrayFromJSON(std::make_shared(), "[1, 2]"); - auto dict_type = dictionary(int8(), dict_array); + auto value_type = std::make_shared(); + auto dict_type = dictionary(int8(), value_type); std::shared_ptr result; ASSERT_OK(builder.Finish(&result)); - auto int_array = ArrayFromJSON(int8(), "[0, 1, 0, null]"); - DictionaryArray expected(dict_type, int_array); - + DictionaryArray expected(dict_type, ArrayFromJSON(int8(), "[0, 1, 0, null]"), + ArrayFromJSON(value_type, "[1, 2]")); ASSERT_TRUE(expected.Equals(result)); } TYPED_TEST(TestDictionaryBuilder, ArrayInit) { - auto dict_array = ArrayFromJSON(std::make_shared(), "[1, 2]"); - auto dict_type = dictionary(int8(), dict_array); + using c_type = typename TypeParam::c_type; + + auto value_type = std::make_shared(); + auto dict_array = ArrayFromJSON(value_type, "[1, 2]"); + auto dict_type = dictionary(int8(), value_type); DictionaryBuilder builder(dict_array, default_memory_pool()); - ASSERT_OK(builder.Append(static_cast(1))); - ASSERT_OK(builder.Append(static_cast(2))); - ASSERT_OK(builder.Append(static_cast(1))); + ASSERT_OK(builder.Append(static_cast(1))); + ASSERT_OK(builder.Append(static_cast(2))); + ASSERT_OK(builder.Append(static_cast(1))); ASSERT_OK(builder.AppendNull()); ASSERT_EQ(builder.length(), 4); @@ -91,22 +95,25 @@ TYPED_TEST(TestDictionaryBuilder, ArrayInit) { std::shared_ptr result; ASSERT_OK(builder.Finish(&result)); - auto int_array = ArrayFromJSON(int8(), "[0, 1, 0, null]"); - DictionaryArray expected(dict_type, int_array); + auto indices = ArrayFromJSON(int8(), "[0, 1, 0, null]"); + DictionaryArray expected(dict_type, indices, dict_array); AssertArraysEqual(expected, *result); } TYPED_TEST(TestDictionaryBuilder, MakeBuilder) { - auto dict_array = ArrayFromJSON(std::make_shared(), "[1, 2]"); - auto dict_type = dictionary(int8(), dict_array); + using c_type = typename TypeParam::c_type; + + auto value_type = std::make_shared(); + auto dict_array = ArrayFromJSON(value_type, "[1, 2]"); + auto dict_type = dictionary(int8(), value_type); std::unique_ptr boxed_builder; ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &boxed_builder)); auto& builder = checked_cast&>(*boxed_builder); - ASSERT_OK(builder.Append(static_cast(1))); - ASSERT_OK(builder.Append(static_cast(2))); - ASSERT_OK(builder.Append(static_cast(1))); + ASSERT_OK(builder.Append(static_cast(1))); + ASSERT_OK(builder.Append(static_cast(2))); + ASSERT_OK(builder.Append(static_cast(1))); ASSERT_OK(builder.AppendNull()); ASSERT_EQ(builder.length(), 4); @@ -118,7 +125,7 @@ TYPED_TEST(TestDictionaryBuilder, MakeBuilder) { ASSERT_OK(builder.Finish(&result)); auto int_array = ArrayFromJSON(int8(), "[0, 1, 0, null]"); - DictionaryArray expected(dict_type, int_array); + DictionaryArray expected(dict_type, int_array, dict_array); AssertArraysEqual(expected, *result); } @@ -134,10 +141,10 @@ TYPED_TEST(TestDictionaryBuilder, ArrayConversion) { // Build expected data auto dict_array = ArrayFromJSON(type, "[1, 2]"); - auto dict_type = dictionary(int8(), dict_array); + auto dict_type = dictionary(int8(), type); auto int_array = ArrayFromJSON(int8(), "[0, 1, 0]"); - DictionaryArray expected(dict_type, int_array); + DictionaryArray expected(dict_type, int_array, dict_array); ASSERT_TRUE(expected.Equals(result)); } @@ -171,11 +178,12 @@ TYPED_TEST(TestDictionaryBuilder, DoubleTableSize) { // Finalize expected data std::shared_ptr dict_array; ASSERT_OK(dict_builder.Finish(&dict_array)); - auto dtype = std::make_shared(int16(), dict_array); + + auto dtype = dictionary(int16(), dict_array->type()); std::shared_ptr int_array; ASSERT_OK(int_builder.Finish(&int_array)); - DictionaryArray expected(dtype, int_array); + DictionaryArray expected(dtype, int_array, dict_array); ASSERT_TRUE(expected.Equals(result)); } } @@ -194,8 +202,9 @@ TYPED_TEST(TestDictionaryBuilder, DeltaDictionary) { FinishAndCheckPadding(&builder, &result); // Build expected data for the initial dictionary - auto dict_type1 = dictionary(int8(), ArrayFromJSON(type, "[1, 2]")); - DictionaryArray expected(dict_type1, ArrayFromJSON(int8(), "[0, 1, 0, 1]")); + auto ex_dict = ArrayFromJSON(type, "[1, 2]"); + auto dict_type1 = dictionary(int8(), type); + DictionaryArray expected(dict_type1, ArrayFromJSON(int8(), "[0, 1, 0, 1]"), ex_dict); ASSERT_TRUE(expected.Equals(result)); @@ -210,8 +219,10 @@ TYPED_TEST(TestDictionaryBuilder, DeltaDictionary) { ASSERT_OK(builder.Finish(&result_delta)); // Build expected data for the delta dictionary - auto dict_type2 = dictionary(int8(), ArrayFromJSON(type, "[3]")); - DictionaryArray expected_delta(dict_type2, ArrayFromJSON(int8(), "[1, 2, 2, 0, 2]")); + auto ex_dict2 = ArrayFromJSON(type, "[3]"); + auto dict_type2 = dictionary(int8(), type); + DictionaryArray expected_delta(dict_type2, ArrayFromJSON(int8(), "[1, 2, 2, 0, 2]"), + ex_dict2); ASSERT_TRUE(expected_delta.Equals(result_delta)); } @@ -219,6 +230,7 @@ TYPED_TEST(TestDictionaryBuilder, DeltaDictionary) { TYPED_TEST(TestDictionaryBuilder, DoubleDeltaDictionary) { using c_type = typename TypeParam::c_type; auto type = std::make_shared(); + auto dict_type = dictionary(int8(), type); DictionaryBuilder builder(default_memory_pool()); @@ -230,8 +242,8 @@ TYPED_TEST(TestDictionaryBuilder, DoubleDeltaDictionary) { FinishAndCheckPadding(&builder, &result); // Build expected data for the initial dictionary - auto dict_type1 = dictionary(int8(), ArrayFromJSON(type, "[1, 2]")); - DictionaryArray expected(dict_type1, ArrayFromJSON(int8(), "[0, 1, 0, 1]")); + auto ex_dict1 = ArrayFromJSON(type, "[1, 2]"); + DictionaryArray expected(dict_type, ArrayFromJSON(int8(), "[0, 1, 0, 1]"), ex_dict1); ASSERT_TRUE(expected.Equals(result)); @@ -246,24 +258,26 @@ TYPED_TEST(TestDictionaryBuilder, DoubleDeltaDictionary) { ASSERT_OK(builder.Finish(&result_delta1)); // Build expected data for the delta dictionary - auto dict_type2 = dictionary(int8(), ArrayFromJSON(type, "[3]")); - DictionaryArray expected_delta1(dict_type2, ArrayFromJSON(int8(), "[1, 2, 2, 0, 2]")); + auto ex_dict2 = ArrayFromJSON(type, "[3]"); + DictionaryArray expected_delta1(dict_type, ArrayFromJSON(int8(), "[1, 2, 2, 0, 2]"), + ex_dict2); ASSERT_TRUE(expected_delta1.Equals(result_delta1)); // extend the dictionary builder with new data again - ASSERT_OK(builder.Append(static_cast(1))); - ASSERT_OK(builder.Append(static_cast(2))); - ASSERT_OK(builder.Append(static_cast(3))); - ASSERT_OK(builder.Append(static_cast(4))); - ASSERT_OK(builder.Append(static_cast(5))); + ASSERT_OK(builder.Append(static_cast(1))); + ASSERT_OK(builder.Append(static_cast(2))); + ASSERT_OK(builder.Append(static_cast(3))); + ASSERT_OK(builder.Append(static_cast(4))); + ASSERT_OK(builder.Append(static_cast(5))); std::shared_ptr result_delta2; ASSERT_OK(builder.Finish(&result_delta2)); // Build expected data for the delta dictionary again - auto dict_type3 = dictionary(int8(), ArrayFromJSON(type, "[4, 5]")); - DictionaryArray expected_delta2(dict_type3, ArrayFromJSON(int8(), "[0, 1, 2, 3, 4]")); + auto ex_dict3 = ArrayFromJSON(type, "[4, 5]"); + DictionaryArray expected_delta2(dict_type, ArrayFromJSON(int8(), "[0, 1, 2, 3, 4]"), + ex_dict3); ASSERT_TRUE(expected_delta2.Equals(result_delta2)); } @@ -279,9 +293,10 @@ TEST(TestStringDictionaryBuilder, Basic) { ASSERT_OK(builder.Finish(&result)); // Build expected data - auto dtype = dictionary(int8(), ArrayFromJSON(utf8(), "[\"test\", \"test2\"]")); + auto ex_dict = ArrayFromJSON(utf8(), "[\"test\", \"test2\"]"); + auto dtype = dictionary(int8(), utf8()); auto int_array = ArrayFromJSON(int8(), "[0, 1, 0]"); - DictionaryArray expected(dtype, int_array); + DictionaryArray expected(dtype, int_array, ex_dict); ASSERT_TRUE(expected.Equals(result)); } @@ -300,14 +315,14 @@ TEST(TestStringDictionaryBuilder, ArrayInit) { ASSERT_OK(builder.Finish(&result)); // Build expected data - DictionaryArray expected(dictionary(int8(), dict_array), int_array); + DictionaryArray expected(dictionary(int8(), utf8()), int_array, dict_array); AssertArraysEqual(expected, *result); } TEST(TestStringDictionaryBuilder, MakeBuilder) { auto dict_array = ArrayFromJSON(utf8(), R"(["test", "test2"])"); - auto dict_type = dictionary(int8(), dict_array); + auto dict_type = dictionary(int8(), utf8()); auto int_array = ArrayFromJSON(int8(), "[0, 1, 0]"); std::unique_ptr boxed_builder; ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &boxed_builder)); @@ -322,7 +337,7 @@ TEST(TestStringDictionaryBuilder, MakeBuilder) { ASSERT_OK(builder.Finish(&result)); // Build expected data - DictionaryArray expected(dict_type, int_array); + DictionaryArray expected(dict_type, int_array, dict_array); AssertArraysEqual(expected, *result); } @@ -337,9 +352,10 @@ TEST(TestStringDictionaryBuilder, OnlyNull) { ASSERT_OK(builder.Finish(&result)); // Build expected data - auto dtype = dictionary(int8(), ArrayFromJSON(utf8(), "[]")); + auto dict = ArrayFromJSON(utf8(), "[]"); + auto dtype = dictionary(int8(), utf8()); auto int_array = ArrayFromJSON(int8(), "[null]"); - DictionaryArray expected(dtype, int_array); + DictionaryArray expected(dtype, int_array, dict); ASSERT_TRUE(expected.Equals(result)); } @@ -372,11 +388,11 @@ TEST(TestStringDictionaryBuilder, DoubleTableSize) { // Finalize expected data std::shared_ptr str_array; ASSERT_OK(str_builder.Finish(&str_array)); - auto dtype = std::make_shared(int16(), str_array); + auto dtype = dictionary(int16(), utf8()); std::shared_ptr int_array; ASSERT_OK(int_builder.Finish(&int_array)); - DictionaryArray expected(dtype, int_array); + DictionaryArray expected(dtype, int_array, str_array); ASSERT_TRUE(expected.Equals(result)); } @@ -391,9 +407,10 @@ TEST(TestStringDictionaryBuilder, DeltaDictionary) { ASSERT_OK(builder.Finish(&result)); // Build expected data - auto dtype = dictionary(int8(), ArrayFromJSON(utf8(), "[\"test\", \"test2\"]")); + auto dict = ArrayFromJSON(utf8(), "[\"test\", \"test2\"]"); + auto dtype = dictionary(int8(), utf8()); auto int_array = ArrayFromJSON(int8(), "[0, 1, 0]"); - DictionaryArray expected(dtype, int_array); + DictionaryArray expected(dtype, int_array, dict); ASSERT_TRUE(expected.Equals(result)); @@ -406,9 +423,8 @@ TEST(TestStringDictionaryBuilder, DeltaDictionary) { FinishAndCheckPadding(&builder, &result_delta); // Build expected data - auto dtype2 = dictionary(int8(), ArrayFromJSON(utf8(), "[\"test3\"]")); - auto int_array2 = ArrayFromJSON(int8(), "[1, 2, 1]"); - DictionaryArray expected_delta(dtype2, int_array2); + DictionaryArray expected_delta(dtype, ArrayFromJSON(int8(), "[1, 2, 1]"), + ArrayFromJSON(utf8(), "[\"test3\"]")); ASSERT_TRUE(expected_delta.Equals(result_delta)); } @@ -434,12 +450,13 @@ TEST(TestStringDictionaryBuilder, BigDeltaDictionary) { std::shared_ptr str_array1; ASSERT_OK(str_builder1.Finish(&str_array1)); - auto dtype1 = std::make_shared(int16(), str_array1); + + auto dtype1 = dictionary(int16(), utf8()); std::shared_ptr int_array1; ASSERT_OK(int_builder1.Finish(&int_array1)); - DictionaryArray expected(dtype1, int_array1); + DictionaryArray expected(dtype1, int_array1, str_array1); ASSERT_TRUE(expected.Equals(result)); // build delta 1 @@ -462,12 +479,12 @@ TEST(TestStringDictionaryBuilder, BigDeltaDictionary) { std::shared_ptr str_array2; ASSERT_OK(str_builder2.Finish(&str_array2)); - auto dtype2 = std::make_shared(int16(), str_array2); + auto dtype2 = dictionary(int16(), utf8()); std::shared_ptr int_array2; ASSERT_OK(int_builder2.Finish(&int_array2)); - DictionaryArray expected2(dtype2, int_array2); + DictionaryArray expected2(dtype2, int_array2, str_array2); ASSERT_TRUE(expected2.Equals(result2)); // build delta 2 @@ -490,12 +507,12 @@ TEST(TestStringDictionaryBuilder, BigDeltaDictionary) { std::shared_ptr str_array3; ASSERT_OK(str_builder3.Finish(&str_array3)); - auto dtype3 = std::make_shared(int16(), str_array3); + auto dtype3 = dictionary(int16(), utf8()); std::shared_ptr int_array3; ASSERT_OK(int_builder3.Finish(&int_array3)); - DictionaryArray expected3(dtype3, int_array3); + DictionaryArray expected3(dtype3, int_array3, str_array3); ASSERT_TRUE(expected3.Equals(result3)); } @@ -513,12 +530,14 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, Basic) { FinishAndCheckPadding(&builder, &result); // Build expected data - FixedSizeBinaryBuilder fsb_builder(arrow::fixed_size_binary(4)); + auto value_type = arrow::fixed_size_binary(4); + FixedSizeBinaryBuilder fsb_builder(value_type); ASSERT_OK(fsb_builder.Append(test.data())); ASSERT_OK(fsb_builder.Append(test2.data())); std::shared_ptr fsb_array; ASSERT_OK(fsb_builder.Finish(&fsb_array)); - auto dtype = std::make_shared(int8(), fsb_array); + + auto dtype = dictionary(int8(), value_type); Int8Builder int_builder; ASSERT_OK(int_builder.Append(0)); @@ -527,13 +546,14 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, Basic) { std::shared_ptr int_array; ASSERT_OK(int_builder.Finish(&int_array)); - DictionaryArray expected(dtype, int_array); + DictionaryArray expected(dtype, int_array, fsb_array); ASSERT_TRUE(expected.Equals(result)); } TEST(TestFixedSizeBinaryDictionaryBuilder, ArrayInit) { // Build the dictionary Array - auto dict_array = ArrayFromJSON(fixed_size_binary(4), R"(["abcd", "wxyz"])"); + auto value_type = fixed_size_binary(4); + auto dict_array = ArrayFromJSON(value_type, R"(["abcd", "wxyz"])"); util::string_view test = "abcd", test2 = "wxyz"; DictionaryBuilder builder(dict_array, default_memory_pool()); ASSERT_OK(builder.Append(test)); @@ -545,14 +565,16 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, ArrayInit) { // Build expected data auto indices = ArrayFromJSON(int8(), "[0, 1, 0]"); - DictionaryArray expected(dictionary(int8(), dict_array), indices); + DictionaryArray expected(dictionary(int8(), value_type), indices, dict_array); AssertArraysEqual(expected, *result); } TEST(TestFixedSizeBinaryDictionaryBuilder, MakeBuilder) { // Build the dictionary Array - auto dict_array = ArrayFromJSON(fixed_size_binary(4), R"(["abcd", "wxyz"])"); - auto dict_type = dictionary(int8(), dict_array); + auto value_type = fixed_size_binary(4); + auto dict_array = ArrayFromJSON(value_type, R"(["abcd", "wxyz"])"); + auto dict_type = dictionary(int8(), value_type); + std::unique_ptr boxed_builder; ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &boxed_builder)); auto& builder = checked_cast&>(*boxed_builder); @@ -566,14 +588,16 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, MakeBuilder) { // Build expected data auto indices = ArrayFromJSON(int8(), "[0, 1, 0]"); - DictionaryArray expected(dict_type, indices); + DictionaryArray expected(dict_type, indices, dict_array); AssertArraysEqual(expected, *result); } TEST(TestFixedSizeBinaryDictionaryBuilder, DeltaDictionary) { // Build the dictionary Array - DictionaryBuilder builder(arrow::fixed_size_binary(4), - default_memory_pool()); + auto value_type = arrow::fixed_size_binary(4); + auto dict_type = dictionary(int8(), value_type); + + DictionaryBuilder builder(value_type, default_memory_pool()); std::vector test{12, 12, 11, 12}; std::vector test2{12, 12, 11, 11}; std::vector test3{12, 12, 11, 10}; @@ -586,12 +610,11 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, DeltaDictionary) { FinishAndCheckPadding(&builder, &result1); // Build expected data - FixedSizeBinaryBuilder fsb_builder1(arrow::fixed_size_binary(4)); + FixedSizeBinaryBuilder fsb_builder1(value_type); ASSERT_OK(fsb_builder1.Append(test.data())); ASSERT_OK(fsb_builder1.Append(test2.data())); std::shared_ptr fsb_array1; ASSERT_OK(fsb_builder1.Finish(&fsb_array1)); - auto dtype1 = std::make_shared(int8(), fsb_array1); Int8Builder int_builder1; ASSERT_OK(int_builder1.Append(0)); @@ -600,7 +623,7 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, DeltaDictionary) { std::shared_ptr int_array1; ASSERT_OK(int_builder1.Finish(&int_array1)); - DictionaryArray expected1(dtype1, int_array1); + DictionaryArray expected1(dict_type, int_array1, fsb_array1); ASSERT_TRUE(expected1.Equals(result1)); // build delta dictionary @@ -612,11 +635,10 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, DeltaDictionary) { FinishAndCheckPadding(&builder, &result2); // Build expected data - FixedSizeBinaryBuilder fsb_builder2(arrow::fixed_size_binary(4)); + FixedSizeBinaryBuilder fsb_builder2(value_type); ASSERT_OK(fsb_builder2.Append(test3.data())); std::shared_ptr fsb_array2; ASSERT_OK(fsb_builder2.Finish(&fsb_array2)); - auto dtype2 = std::make_shared(int8(), fsb_array2); Int8Builder int_builder2; ASSERT_OK(int_builder2.Append(0)); @@ -625,16 +647,18 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, DeltaDictionary) { std::shared_ptr int_array2; ASSERT_OK(int_builder2.Finish(&int_array2)); - DictionaryArray expected2(dtype2, int_array2); + DictionaryArray expected2(dict_type, int_array2, fsb_array2); ASSERT_TRUE(expected2.Equals(result2)); } TEST(TestFixedSizeBinaryDictionaryBuilder, DoubleTableSize) { // Build the dictionary Array - DictionaryBuilder builder(arrow::fixed_size_binary(4), - default_memory_pool()); + auto value_type = arrow::fixed_size_binary(4); + auto dict_type = dictionary(int16(), value_type); + + DictionaryBuilder builder(value_type, default_memory_pool()); // Build expected data - FixedSizeBinaryBuilder fsb_builder(arrow::fixed_size_binary(4)); + FixedSizeBinaryBuilder fsb_builder(value_type); Int16Builder int_builder; // Fill with 1024 different values @@ -659,18 +683,17 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, DoubleTableSize) { // Finalize expected data std::shared_ptr fsb_array; ASSERT_OK(fsb_builder.Finish(&fsb_array)); - auto dtype = std::make_shared(int16(), fsb_array); std::shared_ptr int_array; ASSERT_OK(int_builder.Finish(&int_array)); - DictionaryArray expected(dtype, int_array); + DictionaryArray expected(dict_type, int_array, fsb_array); ASSERT_TRUE(expected.Equals(result)); } TEST(TestFixedSizeBinaryDictionaryBuilder, InvalidTypeAppend) { // Build the dictionary Array - DictionaryBuilder builder(arrow::fixed_size_binary(4), - default_memory_pool()); + auto value_type = arrow::fixed_size_binary(4); + DictionaryBuilder builder(value_type, default_memory_pool()); // Build an array with different byte width FixedSizeBinaryBuilder fsb_builder(arrow::fixed_size_binary(5)); std::vector value{100, 1, 1, 1, 1}; @@ -696,8 +719,9 @@ TEST(TestDecimalDictionaryBuilder, Basic) { ASSERT_OK(builder.Finish(&result)); // Build expected data - auto dtype = dictionary(int8(), ArrayFromJSON(decimal_type, "[\"12\", \"11\"]")); - DictionaryArray expected(dtype, ArrayFromJSON(int8(), "[0, 0, 1, 0]")); + DictionaryArray expected(dictionary(int8(), decimal_type), + ArrayFromJSON(int8(), "[0, 0, 1, 0]"), + ArrayFromJSON(decimal_type, "[\"12\", \"11\"]")); ASSERT_TRUE(expected.Equals(result)); } @@ -749,47 +773,25 @@ TEST(TestDecimalDictionaryBuilder, DoubleTableSize) { std::shared_ptr fsb_array; ASSERT_OK(fsb_builder.Finish(&fsb_array)); - auto dtype = std::make_shared(int16(), fsb_array); std::shared_ptr int_array; ASSERT_OK(int_builder.Finish(&int_array)); - DictionaryArray expected(dtype, int_array); + DictionaryArray expected(dictionary(int16(), decimal_type), int_array, fsb_array); ASSERT_TRUE(expected.Equals(result)); } // ---------------------------------------------------------------------- // DictionaryArray tests -TEST(TestDictionary, Basics) { - std::vector values = {100, 1000, 10000, 100000}; - std::shared_ptr dict; - ArrayFromVector(values, &dict); - - std::shared_ptr type1 = - std::dynamic_pointer_cast(dictionary(int16(), dict)); - - auto type2 = - std::dynamic_pointer_cast(::arrow::dictionary(int16(), dict, true)); - - ASSERT_TRUE(int16()->Equals(type1->index_type())); - ASSERT_TRUE(type1->dictionary()->Equals(dict)); - - ASSERT_TRUE(int16()->Equals(type2->index_type())); - ASSERT_TRUE(type2->dictionary()->Equals(dict)); - - ASSERT_EQ("dictionary", type1->ToString()); - ASSERT_EQ("dictionary", type2->ToString()); -} - TEST(TestDictionary, Equals) { std::vector is_valid = {true, true, false, true, true, true}; std::shared_ptr dict, dict2, indices, indices2, indices3; dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]"); - std::shared_ptr dict_type = dictionary(int16(), dict); + std::shared_ptr dict_type = dictionary(int16(), utf8()); dict2 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\", \"qux\"]"); - std::shared_ptr dict2_type = dictionary(int16(), dict2); + std::shared_ptr dict2_type = dictionary(int16(), utf8()); std::vector indices_values = {1, 2, -1, 0, 2, 0}; ArrayFromVector(is_valid, indices_values, &indices); @@ -800,10 +802,10 @@ TEST(TestDictionary, Equals) { std::vector indices3_values = {1, 1, 0, 0, 2, 0}; ArrayFromVector(is_valid, indices3_values, &indices3); - auto array = std::make_shared(dict_type, indices); - auto array2 = std::make_shared(dict_type, indices2); - auto array3 = std::make_shared(dict2_type, indices); - auto array4 = std::make_shared(dict_type, indices3); + auto array = std::make_shared(dict_type, indices, dict); + auto array2 = std::make_shared(dict_type, indices2, dict); + auto array3 = std::make_shared(dict2_type, indices, dict2); + auto array4 = std::make_shared(dict_type, indices3, dict); ASSERT_TRUE(array->Equals(array)); @@ -845,14 +847,23 @@ TEST(TestDictionary, Equals) { TEST(TestDictionary, Validate) { auto dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]"); - std::shared_ptr dict_type = dictionary(int16(), dict); + auto dict_type = dictionary(int16(), utf8()); auto indices = ArrayFromJSON(int16(), "[1, 2, null, 0, 2, 0]"); - std::shared_ptr arr = std::make_shared(dict_type, indices); + std::shared_ptr arr = + std::make_shared(dict_type, indices, dict); // Only checking index type for now ASSERT_OK(ValidateArray(*arr)); +#ifdef NDEBUG + std::shared_ptr null_dict_arr = + std::make_shared(dict_type, indices, nullptr); + + // Only checking index type for now + ASSERT_RAISES(Invalid, ValidateArray(*null_dict_arr)); +#endif + // TODO(wesm) In ARROW-1199, there is now a DCHECK to compare the indices // type with the dict_type. How can we test for this? @@ -870,7 +881,7 @@ TEST(TestDictionary, Validate) { TEST(TestDictionary, FromArray) { auto dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]"); - std::shared_ptr dict_type = dictionary(int16(), dict); + auto dict_type = dictionary(int16(), utf8()); auto indices1 = ArrayFromJSON(int16(), "[1, 2, 0, 0, 2, 0]"); auto indices2 = ArrayFromJSON(int16(), "[1, 2, 0, 3, 2, 0]"); @@ -885,46 +896,47 @@ TEST(TestDictionary, FromArray) { auto indices4 = ArrayFromJSON(int16(), "[1, 2, null, 3, 2, 0]"); std::shared_ptr arr1, arr2, arr3, arr4; - ASSERT_OK(DictionaryArray::FromArrays(dict_type, indices1, &arr1)); - ASSERT_RAISES(Invalid, DictionaryArray::FromArrays(dict_type, indices2, &arr2)); - ASSERT_OK(DictionaryArray::FromArrays(dict_type, indices3, &arr3)); - ASSERT_RAISES(Invalid, DictionaryArray::FromArrays(dict_type, indices4, &arr4)); + ASSERT_OK(DictionaryArray::FromArrays(dict_type, indices1, dict, &arr1)); + ASSERT_RAISES(Invalid, DictionaryArray::FromArrays(dict_type, indices2, dict, &arr2)); + ASSERT_OK(DictionaryArray::FromArrays(dict_type, indices3, dict, &arr3)); + ASSERT_RAISES(Invalid, DictionaryArray::FromArrays(dict_type, indices4, dict, &arr4)); } TEST(TestDictionary, TransposeBasic) { std::shared_ptr arr, out, expected; auto dict = ArrayFromJSON(utf8(), "[\"A\", \"B\", \"C\"]"); - auto dict_type = dictionary(int16(), dict); + auto dict_type = dictionary(int16(), utf8()); auto indices = ArrayFromJSON(int16(), "[1, 2, 0, 0]"); // ["B", "C", "A", "A"] - ASSERT_OK(DictionaryArray::FromArrays(dict_type, indices, &arr)); + ASSERT_OK(DictionaryArray::FromArrays(dict_type, indices, dict, &arr)); // Transpose to same index type { auto out_dict = ArrayFromJSON(utf8(), "[\"Z\", \"A\", \"C\", \"B\"]"); - auto out_dict_type = dictionary(int16(), out_dict); const std::vector transpose_map{1, 3, 2}; ASSERT_OK(internal::checked_cast(*arr).Transpose( - default_memory_pool(), out_dict_type, transpose_map, &out)); + default_memory_pool(), dict_type, out_dict, transpose_map, &out)); auto expected_indices = ArrayFromJSON(int16(), "[3, 2, 1, 1]"); - ASSERT_OK(DictionaryArray::FromArrays(out_dict_type, expected_indices, &expected)); + ASSERT_OK( + DictionaryArray::FromArrays(dict_type, expected_indices, out_dict, &expected)); AssertArraysEqual(*out, *expected); } // Transpose to other type { auto out_dict = ArrayFromJSON(utf8(), "[\"Z\", \"A\", \"C\", \"B\"]"); - auto out_dict_type = dictionary(int8(), out_dict); + auto out_dict_type = dictionary(int8(), utf8()); const std::vector transpose_map{1, 3, 2}; ASSERT_OK(internal::checked_cast(*arr).Transpose( - default_memory_pool(), out_dict_type, transpose_map, &out)); + default_memory_pool(), out_dict_type, out_dict, transpose_map, &out)); auto expected_indices = ArrayFromJSON(int8(), "[3, 2, 1, 1]"); - ASSERT_OK(DictionaryArray::FromArrays(out_dict_type, expected_indices, &expected)); + ASSERT_OK(DictionaryArray::FromArrays(out_dict_type, expected_indices, out_dict, + &expected)); AssertArraysEqual(*expected, *out); } } @@ -933,20 +945,21 @@ TEST(TestDictionary, TransposeNulls) { std::shared_ptr arr, out, expected; auto dict = ArrayFromJSON(utf8(), "[\"A\", \"B\", \"C\"]"); - auto dict_type = dictionary(int16(), dict); + auto dict_type = dictionary(int16(), utf8()); auto indices = ArrayFromJSON(int16(), "[1, 2, null, 0]"); // ["B", "C", null, "A"] - ASSERT_OK(DictionaryArray::FromArrays(dict_type, indices, &arr)); + ASSERT_OK(DictionaryArray::FromArrays(dict_type, indices, dict, &arr)); auto out_dict = ArrayFromJSON(utf8(), "[\"Z\", \"A\", \"C\", \"B\"]"); - auto out_dict_type = dictionary(int16(), out_dict); + auto out_dict_type = dictionary(int16(), utf8()); const std::vector transpose_map{1, 3, 2}; ASSERT_OK(internal::checked_cast(*arr).Transpose( - default_memory_pool(), out_dict_type, transpose_map, &out)); + default_memory_pool(), out_dict_type, out_dict, transpose_map, &out)); auto expected_indices = ArrayFromJSON(int16(), "[3, 2, null, 1]"); - ASSERT_OK(DictionaryArray::FromArrays(out_dict_type, expected_indices, &expected)); + ASSERT_OK( + DictionaryArray::FromArrays(out_dict_type, expected_indices, out_dict, &expected)); AssertArraysEqual(*expected, *out); } diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc index 4897d0653ddba..467a43ffc3df8 100644 --- a/cpp/src/arrow/array.cc +++ b/cpp/src/arrow/array.cc @@ -699,30 +699,47 @@ Status ValidateDictionaryIndices(const std::shared_ptr& indices, return Status::OK(); } +std::shared_ptr DictionaryArray::indices() const { return indices_; } + DictionaryArray::DictionaryArray(const std::shared_ptr& data) : dict_type_(checked_cast(data->type.get())) { DCHECK_EQ(data->type->id(), Type::DICTIONARY); + DCHECK(data->dictionary); SetData(data); } +void DictionaryArray::SetData(const std::shared_ptr& data) { + this->Array::SetData(data); + auto indices_data = data_->Copy(); + indices_data->type = dict_type_->index_type(); + indices_ = MakeArray(indices_data); +} + DictionaryArray::DictionaryArray(const std::shared_ptr& type, - const std::shared_ptr& indices) + const std::shared_ptr& indices, + const std::shared_ptr& dictionary) : dict_type_(checked_cast(type.get())) { DCHECK_EQ(type->id(), Type::DICTIONARY); + DCHECK(dict_type_->value_type()->Equals(*dictionary->type())); DCHECK_EQ(indices->type_id(), dict_type_->index_type()->id()); + DCHECK(dictionary); auto data = indices->data()->Copy(); data->type = type; + data->dictionary = dictionary; SetData(data); } +std::shared_ptr DictionaryArray::dictionary() const { return data_->dictionary; } + Status DictionaryArray::FromArrays(const std::shared_ptr& type, const std::shared_ptr& indices, + const std::shared_ptr& dictionary, std::shared_ptr* out) { DCHECK_EQ(type->id(), Type::DICTIONARY); const auto& dict = checked_cast(*type); DCHECK_EQ(indices->type_id(), dict.index_type()->id()); - int64_t upper_bound = dict.dictionary()->length(); + int64_t upper_bound = dictionary->length(); Status is_valid; switch (indices->type_id()) { @@ -747,36 +764,17 @@ Status DictionaryArray::FromArrays(const std::shared_ptr& type, return is_valid; } - *out = std::make_shared(type, indices); + *out = std::make_shared(type, indices, dictionary); return is_valid; } -void DictionaryArray::SetData(const std::shared_ptr& data) { - this->Array::SetData(data); - auto indices_data = data_->Copy(); - indices_data->type = dict_type_->index_type(); - indices_ = MakeArray(indices_data); -} - -std::shared_ptr DictionaryArray::indices() const { return indices_; } - -std::shared_ptr DictionaryArray::dictionary() const { - return dict_type_->dictionary(); -} - template static Status TransposeDictIndices(MemoryPool* pool, const ArrayData& in_data, - const std::shared_ptr& type, const std::vector& transpose_map, + const std::shared_ptr& out_data, std::shared_ptr* out) { using in_c_type = typename InType::c_type; using out_c_type = typename OutType::c_type; - - std::shared_ptr out_buffer; - RETURN_NOT_OK(AllocateBuffer(pool, in_data.length * sizeof(out_c_type), &out_buffer)); - // Null bitmap is unchanged - auto out_data = ArrayData::Make(type, in_data.length, {in_data.buffers[0], out_buffer}, - in_data.null_count); internal::TransposeInts(in_data.GetValues(1), out_data->GetMutableValues(1), in_data.length, transpose_map.data()); @@ -785,20 +783,30 @@ static Status TransposeDictIndices(MemoryPool* pool, const ArrayData& in_data, } Status DictionaryArray::Transpose(MemoryPool* pool, const std::shared_ptr& type, + const std::shared_ptr& dictionary, const std::vector& transpose_map, std::shared_ptr* out) const { DCHECK_EQ(type->id(), Type::DICTIONARY); const auto& out_dict_type = checked_cast(*type); - // XXX We'll probably want to make this operation a kernel when we - // implement dictionary-to-dictionary casting. + const auto& out_index_type = + static_cast(*out_dict_type.index_type()); + auto in_type_id = dict_type_->index_type()->id(); - auto out_type_id = out_dict_type.index_type()->id(); + auto out_type_id = out_index_type.id(); + + std::shared_ptr out_buffer; + RETURN_NOT_OK(AllocateBuffer( + pool, data_->length * out_index_type.bit_width() * CHAR_BIT, &out_buffer)); + // Null bitmap is unchanged + auto out_data = ArrayData::Make(type, data_->length, {data_->buffers[0], out_buffer}, + data_->null_count); + out_data->dictionary = dictionary; -#define TRANSPOSE_IN_OUT_CASE(IN_INDEX_TYPE, OUT_INDEX_TYPE) \ - case OUT_INDEX_TYPE::type_id: \ - return TransposeDictIndices(pool, *data(), type, \ - transpose_map, out); +#define TRANSPOSE_IN_OUT_CASE(IN_INDEX_TYPE, OUT_INDEX_TYPE) \ + case OUT_INDEX_TYPE::type_id: \ + return TransposeDictIndices( \ + pool, *data_, transpose_map, out_data, out); #define TRANSPOSE_IN_CASE(IN_INDEX_TYPE) \ case IN_INDEX_TYPE::type_id: \ @@ -819,9 +827,8 @@ Status DictionaryArray::Transpose(MemoryPool* pool, const std::shared_ptrdictionary) { + return Status::Invalid("Dictionary values must be non-null"); + } return Status::OK(); } diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h index 0c6b28a420808..de8df2bb031f3 100644 --- a/cpp/src/arrow/array.h +++ b/cpp/src/arrow/array.h @@ -143,7 +143,8 @@ struct ARROW_EXPORT ArrayData { null_count(other.null_count), offset(other.offset), buffers(std::move(other.buffers)), - child_data(std::move(other.child_data)) {} + child_data(std::move(other.child_data)), + dictionary(std::move(other.dictionary)) {} // Copy constructor ArrayData(const ArrayData& other) noexcept @@ -152,7 +153,8 @@ struct ARROW_EXPORT ArrayData { null_count(other.null_count), offset(other.offset), buffers(other.buffers), - child_data(other.child_data) {} + child_data(other.child_data), + dictionary(other.dictionary) {} // Move assignment ArrayData& operator=(ArrayData&& other) = default; @@ -206,6 +208,10 @@ struct ARROW_EXPORT ArrayData { int64_t offset; std::vector> buffers; std::vector> child_data; + + // The dictionary for this Array, if any. Only used for dictionary + // type + std::shared_ptr dictionary; }; /// \brief Create a strongly-typed Array instance from generic ArrayData @@ -973,9 +979,10 @@ class ARROW_EXPORT UnionArray : public Array { }; // ---------------------------------------------------------------------- -// DictionaryArray (categorical and dictionary-encoded in memory) +// DictionaryArray -/// \brief Concrete Array class for dictionary data +/// \brief Array type for dictionary-encoded data with a +/// data-dependent dictionary /// /// A dictionary array contains an array of non-negative integers (the /// "dictionary indices") along with a data type containing a "dictionary" @@ -999,19 +1006,24 @@ class ARROW_EXPORT DictionaryArray : public Array { explicit DictionaryArray(const std::shared_ptr& data); DictionaryArray(const std::shared_ptr& type, - const std::shared_ptr& indices); + const std::shared_ptr& indices, + const std::shared_ptr& dictionary); - /// \brief Construct DictionaryArray from dictionary data type and indices array + /// \brief Construct DictionaryArray from dictionary and indices + /// array and validate /// /// This function does the validation of the indices and input type. It checks if /// all indices are non-negative and smaller than the size of the dictionary /// /// \param[in] type a dictionary type + /// \param[in] dictionary the dictionary with same value type as the + /// type object /// \param[in] indices an array of non-negative signed /// integers smaller than the size of the dictionary /// \param[out] out the resulting DictionaryArray instance static Status FromArrays(const std::shared_ptr& type, const std::shared_ptr& indices, + const std::shared_ptr& dictionary, std::shared_ptr* out); /// \brief Transpose this DictionaryArray @@ -1022,23 +1034,25 @@ class ARROW_EXPORT DictionaryArray : public Array { /// DictionaryType::Unify. /// /// \param[in] pool a pool to allocate the array data from - /// \param[in] type a dictionary type + /// \param[in] type the new type object + /// \param[in] dictionary the new dictionary /// \param[in] transpose_map a vector transposing this array's indices /// into the target array's indices /// \param[out] out the resulting DictionaryArray instance Status Transpose(MemoryPool* pool, const std::shared_ptr& type, + const std::shared_ptr& dictionary, const std::vector& transpose_map, std::shared_ptr* out) const; - // XXX Do we also want an unsafe in-place Transpose? - std::shared_ptr indices() const; + /// \brief Return the dictionary for this array, which is stored as + /// a member of the ArrayData internal structure std::shared_ptr dictionary() const; + std::shared_ptr indices() const; const DictionaryType* dict_type() const { return dict_type_; } private: void SetData(const std::shared_ptr& data); - const DictionaryType* dict_type_; std::shared_ptr indices_; }; diff --git a/cpp/src/arrow/array/builder_dict.cc b/cpp/src/arrow/array/builder_dict.cc index 72bfebfb53297..e4267bfb9dd30 100644 --- a/cpp/src/arrow/array/builder_dict.cc +++ b/cpp/src/arrow/array/builder_dict.cc @@ -45,6 +45,7 @@ struct UnifyDictionaryValues { MemoryPool* pool_; std::shared_ptr value_type_; const std::vector& types_; + const std::vector& dictionaries_; std::shared_ptr* out_values_; std::vector>* out_transpose_maps_; @@ -73,8 +74,8 @@ struct UnifyDictionaryValues { out_transpose_maps_->reserve(types_.size()); } // Build up the unified dictionary values and the transpose maps - for (const auto& type : types_) { - const ArrayType& values = checked_cast(*type->dictionary()); + for (size_t i = 0; i < types_.size(); ++i) { + const ArrayType& values = checked_cast(*dictionaries_[i]); if (out_transpose_maps_ != nullptr) { std::vector transpose_map; transpose_map.reserve(values.length()); @@ -99,11 +100,16 @@ struct UnifyDictionaryValues { }; Status DictionaryType::Unify(MemoryPool* pool, const std::vector& types, + const std::vector& dictionaries, std::shared_ptr* out_type, + std::shared_ptr* out_dictionary, std::vector>* out_transpose_maps) { if (types.size() == 0) { return Status::Invalid("need at least one input type"); } + + DCHECK_EQ(types.size(), dictionaries.size()); + std::vector dict_types; dict_types.reserve(types.size()); for (const auto& type : types) { @@ -114,21 +120,21 @@ Status DictionaryType::Unify(MemoryPool* pool, const std::vectordictionary()->type(); - for (const auto& type : dict_types) { - auto values = type->dictionary(); - if (!values->type()->Equals(value_type)) { - return Status::TypeError("input types have different value types"); + auto value_type = dict_types[0]->value_type(); + for (size_t i = 0; i < types.size(); ++i) { + if (!(dictionaries[i]->type()->Equals(*value_type) && + dict_types[i]->value_type()->Equals(*value_type))) { + return Status::TypeError("dictionary value types were not all consistent"); } - if (values->null_count() != 0) { + if (dictionaries[i]->null_count() != 0) { return Status::TypeError("input types have null values"); } } std::shared_ptr values; { - UnifyDictionaryValues visitor{pool, value_type, dict_types, &values, - out_transpose_maps}; + UnifyDictionaryValues visitor{pool, value_type, dict_types, + dictionaries, &values, out_transpose_maps}; RETURN_NOT_OK(VisitTypeInline(*value_type, &visitor)); } @@ -143,7 +149,8 @@ Status DictionaryType::Unify(MemoryPool* pool, const std::vectortype()); + *out_dictionary = values; return Status::OK(); } diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index 84f2e87c35971..93cad2975a227 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -83,12 +83,13 @@ class ARROW_EXPORT DictionaryMemoTable { } // namespace internal -/// \brief Array builder for created encoded DictionaryArray from dense array +/// \brief Array builder for created encoded DictionaryArray from +/// dense array /// -/// Unlike other builders, dictionary builder does not completely reset the state -/// on Finish calls. The arrays built after the initial Finish call will reuse -/// the previously created encoding and build a delta dictionary when new terms -/// occur. +/// Unlike other builders, dictionary builder does not completely +/// reset the state on Finish calls. The arrays built after the +/// initial Finish call will reuse the previously created encoding and +/// build a delta dictionary when new terms occur. /// /// data template @@ -238,15 +239,14 @@ class DictionaryBuilder : public ArrayBuilder { ARROW_RETURN_NOT_OK(values_builder_.FinishInternal(out)); // Generate dictionary array from hash table contents - std::shared_ptr dictionary; std::shared_ptr dictionary_data; ARROW_RETURN_NOT_OK( memo_table_->GetArrayData(pool_, delta_offset_, &dictionary_data)); - dictionary = MakeArray(dictionary_data); // Set type of array data to the right dictionary type - (*out)->type = std::make_shared((*out)->type, dictionary); + (*out)->type = dictionary((*out)->type, type_); + (*out)->dictionary = MakeArray(dictionary_data); // Update internals for further uses of this DictionaryBuilder delta_offset_ = memo_table_->size(); @@ -321,7 +321,8 @@ class DictionaryBuilder : public ArrayBuilder { std::shared_ptr dictionary = std::make_shared(0); ARROW_RETURN_NOT_OK(values_builder_.FinishInternal(out)); - (*out)->type = std::make_shared((*out)->type, dictionary); + (*out)->type = std::make_shared((*out)->type, type_); + (*out)->dictionary = dictionary; return Status::OK(); } diff --git a/cpp/src/arrow/builder-benchmark.cc b/cpp/src/arrow/builder-benchmark.cc index e2c94c5b4279d..6ad860af0008e 100644 --- a/cpp/src/arrow/builder-benchmark.cc +++ b/cpp/src/arrow/builder-benchmark.cc @@ -364,11 +364,30 @@ static void BM_BuildStringDictionaryArray( state.SetBytesProcessed(state.iterations() * fodder_size); } +static void BM_ArrayDataConstructDestruct( + benchmark::State& state) { // NOLINT non-const reference + std::vector> arrays; + + const int kNumArrays = 1000; + auto InitArrays = [&]() { + for (int i = 0; i < kNumArrays; ++i) { + arrays.emplace_back(new ArrayData); + } + }; + + for (auto _ : state) { + InitArrays(); + arrays.clear(); + } +} + // ---------------------------------------------------------------------- // Benchmark declarations static constexpr int32_t kRepetitions = 2; +BENCHMARK(BM_ArrayDataConstructDestruct); + BENCHMARK(BM_BuildPrimitiveArrayNoNulls) ->Repetitions(kRepetitions) ->Unit(benchmark::kMicrosecond); diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc index 109ccabb6110a..2bf61781c9f80 100644 --- a/cpp/src/arrow/builder.cc +++ b/cpp/src/arrow/builder.cc @@ -35,11 +35,6 @@ class MemoryPool; // ---------------------------------------------------------------------- // Helper functions -#define BUILDER_CASE(ENUM, BuilderType) \ - case Type::ENUM: \ - out->reset(new BuilderType(type, pool)); \ - return Status::OK(); - struct DictionaryBuilderCase { template Status Visit(const ValueType&, typename ValueType::c_type* = nullptr) { @@ -65,19 +60,27 @@ struct DictionaryBuilderCase { template Status Create() { - out->reset(new BuilderType(dict_type.dictionary(), pool)); + if (dictionary != nullptr) { + out->reset(new BuilderType(dictionary, pool)); + } else { + out->reset(new BuilderType(value_type, pool)); + } return Status::OK(); } + Status Make() { return VisitTypeInline(*value_type, this); } + MemoryPool* pool; - const DictionaryType& dict_type; + const std::shared_ptr& value_type; + const std::shared_ptr& dictionary; std::unique_ptr* out; }; -// Initially looked at doing this with vtables, but shared pointers makes it -// difficult -// -// TODO(wesm): come up with a less monolithic strategy +#define BUILDER_CASE(ENUM, BuilderType) \ + case Type::ENUM: \ + out->reset(new BuilderType(type, pool)); \ + return Status::OK(); + Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, std::unique_ptr* out) { switch (type->id()) { @@ -109,8 +112,8 @@ Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, BUILDER_CASE(DECIMAL, Decimal128Builder); case Type::DICTIONARY: { const auto& dict_type = static_cast(*type); - DictionaryBuilderCase visitor = {pool, dict_type, out}; - return VisitTypeInline(*dict_type.dictionary()->type(), &visitor); + DictionaryBuilderCase visitor = {pool, dict_type.value_type(), nullptr, out}; + return visitor.Make(); } case Type::INTERVAL: { const auto& interval_type = internal::checked_cast(*type); @@ -163,4 +166,12 @@ Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, type->ToString()); } +Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr& type, + const std::shared_ptr& dictionary, + std::unique_ptr* out) { + const auto& dict_type = static_cast(*type); + DictionaryBuilderCase visitor = {pool, dict_type.value_type(), dictionary, out}; + return visitor.Make(); +} + } // namespace arrow diff --git a/cpp/src/arrow/builder.h b/cpp/src/arrow/builder.h index c0672c25a064e..56c3e2b3716a8 100644 --- a/cpp/src/arrow/builder.h +++ b/cpp/src/arrow/builder.h @@ -35,8 +35,24 @@ namespace arrow { class DataType; class MemoryPool; +/// \brief Construct an empty ArrayBuilder corresponding to the data +/// type +/// \param[in] pool the MemoryPool to use for allocations +/// \param[in] type an instance of DictionaryType +/// \param[out] out the created ArrayBuilder ARROW_EXPORT Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, std::unique_ptr* out); +/// \brief Construct an empty DictionaryBuilder initialized optionally +/// with a pre-existing dictionary +/// \param[in] pool the MemoryPool to use for allocations +/// \param[in] type an instance of DictionaryType +/// \param[in] dictionary the initial dictionary, if any. May be nullptr +/// \param[out] out the created ArrayBuilder +ARROW_EXPORT +Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr& type, + const std::shared_ptr& dictionary, + std::unique_ptr* out); + } // namespace arrow diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 8fdd3cb02407b..c82d4df5ee836 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -797,7 +797,7 @@ class TypeEqualsVisitor { Status Visit(const DictionaryType& left) { const auto& right = checked_cast(right_); result_ = left.index_type()->Equals(right.index_type()) && - left.dictionary()->Equals(right.dictionary()) && + left.value_type()->Equals(right.value_type()) && (left.ordered() == right.ordered()); return Status::OK(); } diff --git a/cpp/src/arrow/compute/kernels/cast-test.cc b/cpp/src/arrow/compute/kernels/cast-test.cc index aa5815b1e4ebc..778544cda2f50 100644 --- a/cpp/src/arrow/compute/kernels/cast-test.cc +++ b/cpp/src/arrow/compute/kernels/cast-test.cc @@ -1147,9 +1147,11 @@ TEST_F(TestCast, IdentityCasts) { CheckIdentityCast(timestamp(TimeUnit::SECOND), "[1, 2, 3, 4]"); { - auto dict_type = dictionary(int8(), ArrayFromJSON(int8(), "[1, 2, 3]")); + auto dict_values = ArrayFromJSON(int8(), "[1, 2, 3]"); + auto dict_type = dictionary(int8(), dict_values->type()); auto dict_indices = ArrayFromJSON(int8(), "[0, 1, 2, 0, null, 2]"); - auto dict_array = std::make_shared(dict_type, dict_indices); + auto dict_array = + std::make_shared(dict_type, dict_indices, dict_values); CheckZeroCopy(*dict_array, dict_type); } } diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 749e200d4a404..26276bc2d2944 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -123,7 +123,7 @@ struct CastFunctor> { // Number to Boolean template struct CastFunctor::value && + typename std::enable_if::value && !std::is_same::value>::type> { void operator()(FunctionContext* ctx, const CastOptions& options, const ArrayData& input, ArrayData* output) { @@ -154,8 +154,7 @@ struct is_number_downcast { template struct is_number_downcast< O, I, - typename std::enable_if::value && - std::is_base_of::value>::type> { + typename std::enable_if::value && is_number_type::value>::type> { using O_T = typename O::c_type; using I_T = typename I::c_type; @@ -177,8 +176,8 @@ struct is_integral_signed_to_unsigned { template struct is_integral_signed_to_unsigned< O, I, - typename std::enable_if::value && - std::is_base_of::value>::type> { + typename std::enable_if::value && + is_integer_type::value>::type> { using O_T = typename O::c_type; using I_T = typename I::c_type; @@ -195,8 +194,8 @@ struct is_integral_unsigned_to_signed { template struct is_integral_unsigned_to_signed< O, I, - typename std::enable_if::value && - std::is_base_of::value>::type> { + typename std::enable_if::value && + is_integer_type::value>::type> { using O_T = typename O::c_type; using I_T = typename I::c_type; @@ -321,10 +320,9 @@ struct is_float_truncate { template struct is_float_truncate< O, I, - typename std::enable_if<(std::is_base_of::value && - std::is_base_of::value) || - (std::is_base_of::value && - std::is_base_of::value)>::type> { + typename std::enable_if<(is_integer_type::value && is_floating_type::value) || + (is_integer_type::value && + is_floating_type::value)>::type> { static constexpr bool value = true; }; @@ -383,8 +381,7 @@ struct is_safe_numeric_cast { template struct is_safe_numeric_cast< O, I, - typename std::enable_if::value && - std::is_base_of::value>::type> { + typename std::enable_if::value && is_number_type::value>::type> { using O_T = typename O::c_type; using I_T = typename I::c_type; @@ -746,213 +743,159 @@ class FromNullCastKernel : public CastKernelBase { // ---------------------------------------------------------------------- // Dictionary to other things -template -void UnpackFixedSizeBinaryDictionary(FunctionContext* ctx, const Array& indices, - const FixedSizeBinaryArray& dictionary, - ArrayData* output) { - using index_c_type = typename IndexType::c_type; +template +struct UnpackHelper {}; - const index_c_type* in = indices.data()->GetValues(1); - int32_t byte_width = - checked_cast(*output->type).byte_width(); +template +struct UnpackHelper< + T, typename std::enable_if::value>::type> { + using ArrayType = typename TypeTraits::ArrayType; + + template + Status Unpack(FunctionContext* ctx, const ArrayData& indices, + const ArrayType& dictionary, ArrayData* output) { + using index_c_type = typename IndexType::c_type; + + const index_c_type* in = indices.GetValues(1); + int32_t byte_width = + checked_cast(*output->type).byte_width(); - uint8_t* out = output->buffers[1]->mutable_data() + byte_width * output->offset; + uint8_t* out = output->buffers[1]->mutable_data() + byte_width * output->offset; - if (indices.null_count() != 0) { - internal::BitmapReader valid_bits_reader(indices.null_bitmap_data(), indices.offset(), - indices.length()); + if (indices.GetNullCount() != 0) { + internal::BitmapReader valid_bits_reader(indices.GetValues(0), + indices.offset, indices.length); - for (int64_t i = 0; i < indices.length(); ++i) { - if (valid_bits_reader.IsSet()) { + for (int64_t i = 0; i < indices.length; ++i) { + if (valid_bits_reader.IsSet()) { + const uint8_t* value = dictionary.Value(in[i]); + memcpy(out + i * byte_width, value, byte_width); + } + valid_bits_reader.Next(); + } + } else { + for (int64_t i = 0; i < indices.length; ++i) { const uint8_t* value = dictionary.Value(in[i]); memcpy(out + i * byte_width, value, byte_width); } - valid_bits_reader.Next(); - } - } else { - for (int64_t i = 0; i < indices.length(); ++i) { - const uint8_t* value = dictionary.Value(in[i]); - memcpy(out + i * byte_width, value, byte_width); - } - } -} - -template -struct CastFunctor< - T, DictionaryType, - typename std::enable_if::value>::type> { - void operator()(FunctionContext* ctx, const CastOptions& options, - const ArrayData& input, ArrayData* output) { - DictionaryArray dict_array(input.Copy()); - - const DictionaryType& type = checked_cast(*input.type); - const DataType& values_type = *type.dictionary()->type(); - const FixedSizeBinaryArray& dictionary = - checked_cast(*type.dictionary()); - - // Check if values and output type match - DCHECK(values_type.Equals(*output->type)) - << "Dictionary type: " << values_type << " target type: " << (*output->type); - - const Array& indices = *dict_array.indices(); - switch (indices.type()->id()) { - case Type::INT8: - UnpackFixedSizeBinaryDictionary(ctx, indices, dictionary, output); - break; - case Type::INT16: - UnpackFixedSizeBinaryDictionary(ctx, indices, dictionary, output); - break; - case Type::INT32: - UnpackFixedSizeBinaryDictionary(ctx, indices, dictionary, output); - break; - case Type::INT64: - UnpackFixedSizeBinaryDictionary(ctx, indices, dictionary, output); - break; - default: - ctx->SetStatus( - Status::Invalid("Invalid index type: ", indices.type()->ToString())); - return; } + return Status::OK(); } }; -template -Status UnpackBinaryDictionary(FunctionContext* ctx, const Array& indices, - const BinaryArray& dictionary, ArrayData* output) { - using index_c_type = typename IndexType::c_type; - std::unique_ptr builder; - RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), output->type, &builder)); - BinaryBuilder* binary_builder = checked_cast(builder.get()); - - const index_c_type* in = indices.data()->GetValues(1); - if (indices.null_count() != 0) { - internal::BitmapReader valid_bits_reader(indices.null_bitmap_data(), indices.offset(), - indices.length()); - - for (int64_t i = 0; i < indices.length(); ++i) { - if (valid_bits_reader.IsSet()) { +template +struct UnpackHelper< + T, typename std::enable_if::value>::type> { + using ArrayType = typename TypeTraits::ArrayType; + + template + Status Unpack(FunctionContext* ctx, const ArrayData& indices, + const ArrayType& dictionary, ArrayData* output) { + using index_c_type = typename IndexType::c_type; + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), output->type, &builder)); + BinaryBuilder* binary_builder = checked_cast(builder.get()); + + const index_c_type* in = indices.GetValues(1); + if (indices.GetNullCount() != 0) { + internal::BitmapReader valid_bits_reader(indices.GetValues(0), + indices.offset, indices.length); + + for (int64_t i = 0; i < indices.length; ++i) { + if (valid_bits_reader.IsSet()) { + RETURN_NOT_OK(binary_builder->Append(dictionary.GetView(in[i]))); + } else { + RETURN_NOT_OK(binary_builder->AppendNull()); + } + valid_bits_reader.Next(); + } + } else { + for (int64_t i = 0; i < indices.length; ++i) { RETURN_NOT_OK(binary_builder->Append(dictionary.GetView(in[i]))); - } else { - RETURN_NOT_OK(binary_builder->AppendNull()); } - valid_bits_reader.Next(); - } - } else { - for (int64_t i = 0; i < indices.length(); ++i) { - RETURN_NOT_OK(binary_builder->Append(dictionary.GetView(in[i]))); } - } - - std::shared_ptr plain_array; - RETURN_NOT_OK(binary_builder->Finish(&plain_array)); - // Copy all buffer except the valid bitmap - for (size_t i = 1; i < plain_array->data()->buffers.size(); i++) { - output->buffers.push_back(plain_array->data()->buffers[i]); - } - - return Status::OK(); -} - -template -struct CastFunctor::value>::type> { - void operator()(FunctionContext* ctx, const CastOptions& options, - const ArrayData& input, ArrayData* output) { - DictionaryArray dict_array(input.Copy()); - - const DictionaryType& type = checked_cast(*input.type); - const DataType& values_type = *type.dictionary()->type(); - const BinaryArray& dictionary = checked_cast(*type.dictionary()); - - // Check if values and output type match - DCHECK(values_type.Equals(*output->type)) - << "Dictionary type: " << values_type << " target type: " << (*output->type); - const Array& indices = *dict_array.indices(); - switch (indices.type()->id()) { - case Type::INT8: - FUNC_RETURN_NOT_OK( - (UnpackBinaryDictionary(ctx, indices, dictionary, output))); - break; - case Type::INT16: - FUNC_RETURN_NOT_OK( - (UnpackBinaryDictionary(ctx, indices, dictionary, output))); - break; - case Type::INT32: - FUNC_RETURN_NOT_OK( - (UnpackBinaryDictionary(ctx, indices, dictionary, output))); - break; - case Type::INT64: - FUNC_RETURN_NOT_OK( - (UnpackBinaryDictionary(ctx, indices, dictionary, output))); - break; - default: - ctx->SetStatus( - Status::Invalid("Invalid index type: ", indices.type()->ToString())); - return; + std::shared_ptr plain_array; + RETURN_NOT_OK(binary_builder->Finish(&plain_array)); + // Copy all buffer except the valid bitmap + for (size_t i = 1; i < plain_array->data()->buffers.size(); i++) { + output->buffers.push_back(plain_array->data()->buffers[i]); } + + return Status::OK(); } }; -template -void UnpackPrimitiveDictionary(const Array& indices, const c_type* dictionary, - c_type* out) { - const auto& in = indices.data()->GetValues(1); - int64_t length = indices.length(); - - if (indices.null_count() == 0) { - for (int64_t i = 0; i < length; ++i) { - out[i] = dictionary[in[i]]; - } - } else { - auto null_bitmap = indices.null_bitmap_data(); - internal::BitmapReader valid_bits_reader(null_bitmap, indices.offset(), length); - for (int64_t i = 0; i < length; ++i) { - if (valid_bits_reader.IsSet()) { - out[i] = dictionary[in[i]]; +template +struct UnpackHelper::value || + is_temporal_type::value>::type> { + using ArrayType = typename TypeTraits::ArrayType; + + template + Status Unpack(FunctionContext* ctx, const ArrayData& indices, + const ArrayType& dictionary, ArrayData* output) { + using index_type = typename IndexType::c_type; + using value_type = typename T::c_type; + + const index_type* in = indices.GetValues(1); + value_type* out = output->GetMutableValues(1); + const value_type* dict_values = dictionary.data()->template GetValues(1); + + if (indices.GetNullCount() == 0) { + for (int64_t i = 0; i < indices.length; ++i) { + out[i] = dict_values[in[i]]; + } + } else { + internal::BitmapReader valid_bits_reader(indices.GetValues(0), + indices.offset, indices.length); + for (int64_t i = 0; i < indices.length; ++i) { + if (valid_bits_reader.IsSet()) { + // TODO(wesm): is it worth removing the branch here? + out[i] = dict_values[in[i]]; + } + valid_bits_reader.Next(); } - valid_bits_reader.Next(); } + return Status::OK(); } -} +}; -// Cast from dictionary to plain representation +// Dispatch dictionary casts to UnpackHelper template -struct CastFunctor::value>::type> { +struct CastFunctor { void operator()(FunctionContext* ctx, const CastOptions& options, const ArrayData& input, ArrayData* output) { - using c_type = typename T::c_type; - - DictionaryArray dict_array(input.Copy()); + using ArrayType = typename TypeTraits::ArrayType; const DictionaryType& type = checked_cast(*input.type); - const DataType& values_type = *type.dictionary()->type(); + const Array& dictionary = *input.dictionary; + const DataType& values_type = *dictionary.type(); // Check if values and output type match DCHECK(values_type.Equals(*output->type)) << "Dictionary type: " << values_type << " target type: " << (*output->type); - const c_type* dictionary = type.dictionary()->data()->GetValues(1); - - auto out = output->GetMutableValues(1); - const Array& indices = *dict_array.indices(); - switch (indices.type()->id()) { + UnpackHelper unpack_helper; + switch (type.index_type()->id()) { case Type::INT8: - UnpackPrimitiveDictionary(indices, dictionary, out); + FUNC_RETURN_NOT_OK(unpack_helper.template Unpack( + ctx, input, static_cast(dictionary), output)); break; case Type::INT16: - UnpackPrimitiveDictionary(indices, dictionary, out); + FUNC_RETURN_NOT_OK(unpack_helper.template Unpack( + ctx, input, static_cast(dictionary), output)); break; case Type::INT32: - UnpackPrimitiveDictionary(indices, dictionary, out); + FUNC_RETURN_NOT_OK(unpack_helper.template Unpack( + ctx, input, static_cast(dictionary), output)); break; case Type::INT64: - UnpackPrimitiveDictionary(indices, dictionary, out); + FUNC_RETURN_NOT_OK(unpack_helper.template Unpack( + ctx, input, static_cast(dictionary), output)); break; default: ctx->SetStatus( - Status::Invalid("Invalid index type: ", indices.type()->ToString())); + Status::TypeError("Invalid index type: ", type.index_type()->ToString())); return; } } diff --git a/cpp/src/arrow/compute/kernels/hash-test.cc b/cpp/src/arrow/compute/kernels/hash-test.cc index 6292f74fb9543..61553c133e065 100644 --- a/cpp/src/arrow/compute/kernels/hash-test.cc +++ b/cpp/src/arrow/compute/kernels/hash-test.cc @@ -116,7 +116,7 @@ void CheckDictEncode(FunctionContext* ctx, const std::shared_ptr& type std::shared_ptr ex_indices = _MakeArray(int32(), out_indices, in_is_valid); - DictionaryArray expected(dictionary(int32(), ex_dict), ex_indices); + DictionaryArray expected(dictionary(int32(), type), ex_indices, ex_dict); Datum datum_out; ASSERT_OK(DictionaryEncode(ctx, input, &datum_out)); @@ -447,13 +447,13 @@ TEST_F(TestHashKernel, ChunkedArrayInvoke) { ASSERT_ARRAYS_EQUAL(*ex_dict, *result); // Dictionary encode - auto dict_type = dictionary(int32(), ex_dict); + auto dict_type = dictionary(int32(), type); auto i1 = _MakeArray(int32(), {0, 1, 0}, {}); auto i2 = _MakeArray(int32(), {1, 2, 3, 0}, {}); - ArrayVector dict_arrays = {std::make_shared(dict_type, i1), - std::make_shared(dict_type, i2)}; + ArrayVector dict_arrays = {std::make_shared(dict_type, i1, ex_dict), + std::make_shared(dict_type, i2, ex_dict)}; auto dict_carr = std::make_shared(dict_arrays); // Unique counts diff --git a/cpp/src/arrow/compute/kernels/hash.cc b/cpp/src/arrow/compute/kernels/hash.cc index 2a3031fc7bcc6..dc9fc4b2813a1 100644 --- a/cpp/src/arrow/compute/kernels/hash.cc +++ b/cpp/src/arrow/compute/kernels/hash.cc @@ -519,13 +519,13 @@ Status DictionaryEncode(FunctionContext* ctx, const Datum& value, Datum* out) { // Create the dictionary type DCHECK_EQ(indices_outputs[0].kind(), Datum::ARRAY); std::shared_ptr dict_type = - ::arrow::dictionary(indices_outputs[0].array()->type, dictionary); + ::arrow::dictionary(indices_outputs[0].array()->type, dictionary->type()); // Create DictionaryArray for each piece yielded by the kernel invocations std::vector> dict_chunks; for (const Datum& datum : indices_outputs) { - dict_chunks.emplace_back( - std::make_shared(dict_type, MakeArray(datum.array()))); + dict_chunks.emplace_back(std::make_shared( + dict_type, MakeArray(datum.array()), dictionary)); } *out = detail::WrapArraysLike(value, dict_chunks); diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc index 4813203f6d5cc..b3de04d8cc762 100644 --- a/cpp/src/arrow/compute/kernels/take-test.cc +++ b/cpp/src/arrow/compute/kernels/take-test.cc @@ -138,12 +138,13 @@ class TestTakeKernelWithString : public TestTakeKernel { const std::string& dictionary_indices, const std::string& indices, TakeOptions options, const std::string& expected_indices) { - auto type = dictionary(int8(), ArrayFromJSON(utf8(), dictionary_values)); + auto dict = ArrayFromJSON(utf8(), dictionary_values); + auto type = dictionary(int8(), utf8()); std::shared_ptr values, actual, expected; ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), dictionary_indices), - &values)); + dict, &values)); ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), - &expected)); + dict, &expected)); auto take_indices = ArrayFromJSON(int8(), indices); this->AssertTakeArrays(values, take_indices, options, expected); } diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index 1d7a33719369c..f83139de93193 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -21,6 +21,7 @@ #include "arrow/builder.h" #include "arrow/compute/context.h" #include "arrow/compute/kernels/take.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" #include "arrow/visitor_inline.h" @@ -145,10 +146,11 @@ struct UnpackValues { Status Visit(const DictionaryType& t) { std::shared_ptr taken_indices; + const auto& values = internal::checked_cast(*params_.values); { // To take from a dictionary, apply the current kernel to the dictionary's // indices. (Use UnpackValues since IndexType is already unpacked) - auto indices = static_cast(params_.values.get())->indices(); + auto indices = values.indices(); TakeParameters params = params_; params.values = indices; params.out = &taken_indices; @@ -156,8 +158,9 @@ struct UnpackValues { RETURN_NOT_OK(VisitTypeInline(*t.index_type(), &unpack)); } // create output dictionary from taken indices - return DictionaryArray::FromArrays(dictionary(t.index_type(), t.dictionary()), - taken_indices, params_.out); + *params_.out = std::make_shared(values.type(), taken_indices, + values.dictionary()); + return Status::OK(); } Status Visit(const ExtensionType& t) { diff --git a/cpp/src/arrow/flight/flight-benchmark.cc b/cpp/src/arrow/flight/flight-benchmark.cc index 066b3683c112c..d6318eaadeff7 100644 --- a/cpp/src/arrow/flight/flight-benchmark.cc +++ b/cpp/src/arrow/flight/flight-benchmark.cc @@ -90,7 +90,8 @@ Status RunPerformanceTest(const std::string& hostname, const int port) { // Read the streams in parallel std::shared_ptr schema; - RETURN_NOT_OK(plan->GetSchema(&schema)); + ipc::DictionaryMemo dict_memo; + RETURN_NOT_OK(plan->GetSchema(&dict_memo, &schema)); PerformanceStats stats; auto ConsumeStream = [&stats, &hostname, &port](const FlightEndpoint& endpoint) { diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc index 504779db2e74b..4658b48382d83 100644 --- a/cpp/src/arrow/flight/flight-test.cc +++ b/cpp/src/arrow/flight/flight-test.cc @@ -88,8 +88,10 @@ void AssertEqual(const std::vector& expected, const std::vector& actual) { void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) { std::shared_ptr ex_schema, actual_schema; - ASSERT_OK(expected.GetSchema(&ex_schema)); - ASSERT_OK(actual.GetSchema(&actual_schema)); + ipc::DictionaryMemo expected_memo; + ipc::DictionaryMemo actual_memo; + ASSERT_OK(expected.GetSchema(&expected_memo, &ex_schema)); + ASSERT_OK(actual.GetSchema(&actual_memo, &actual_schema)); AssertSchemaEqual(*ex_schema, *actual_schema); ASSERT_EQ(expected.total_records(), actual.total_records()); @@ -181,7 +183,8 @@ class TestFlightClient : public ::testing::Test { check_endpoints(info->endpoints()); std::shared_ptr schema; - ASSERT_OK(info->GetSchema(&schema)); + ipc::DictionaryMemo dict_memo; + ASSERT_OK(info->GetSchema(&dict_memo, &schema)); AssertSchemaEqual(*expected_schema, *schema); // By convention, fetch the first endpoint diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc index 4fce14c7e39bc..2e335936e51fb 100644 --- a/cpp/src/arrow/flight/internal.cc +++ b/cpp/src/arrow/flight/internal.cc @@ -226,7 +226,9 @@ Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info) { Status SchemaToString(const Schema& schema, std::string* out) { // TODO(wesm): Do we care about better memory efficiency here? std::shared_ptr serialized_schema; - RETURN_NOT_OK(ipc::SerializeSchema(schema, default_memory_pool(), &serialized_schema)); + ipc::DictionaryMemo unused_dict_memo; + RETURN_NOT_OK(ipc::SerializeSchema(schema, &unused_dict_memo, default_memory_pool(), + &serialized_schema)); *out = std::string(reinterpret_cast(serialized_schema->data()), static_cast(serialized_schema->size())); return Status::OK(); diff --git a/cpp/src/arrow/flight/perf-server.cc b/cpp/src/arrow/flight/perf-server.cc index 5387bbd068a05..b2c268b73c0b2 100644 --- a/cpp/src/arrow/flight/perf-server.cc +++ b/cpp/src/arrow/flight/perf-server.cc @@ -71,6 +71,11 @@ class PerfDataStream : public FlightDataStream { std::shared_ptr schema() override { return schema_; } + Status GetSchemaPayload(FlightPayload* payload) override { + return ipc::internal::GetSchemaPayload(*schema_, &dictionary_memo_, + &payload->ipc_message); + } + Status Next(FlightPayload* payload) override { if (records_sent_ >= total_records_) { // Signal that iteration is over @@ -107,6 +112,7 @@ class PerfDataStream : public FlightDataStream { const int64_t total_records_; int64_t records_sent_; std::shared_ptr schema_; + ipc::DictionaryMemo dictionary_memo_; std::shared_ptr batch_; ArrayVector arrays_; }; diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 8cb6921c33e7b..51384096eeaa2 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -32,12 +32,14 @@ #endif #include "arrow/buffer.h" +#include "arrow/ipc/dictionary.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" #include "arrow/memory_pool.h" #include "arrow/record_batch.h" #include "arrow/status.h" #include "arrow/util/logging.h" +#include "arrow/util/stl.h" #include "arrow/flight/internal.h" #include "arrow/flight/serialization-internal.h" @@ -64,9 +66,11 @@ class FlightMessageReaderImpl : public FlightMessageReader { public: FlightMessageReaderImpl(const FlightDescriptor& descriptor, std::shared_ptr schema, + std::unique_ptr dict_memo, grpc::ServerReader* reader) : descriptor_(descriptor), schema_(schema), + dictionary_memo_(std::move(dict_memo)), reader_(reader), stream_finished_(false) {} @@ -87,7 +91,7 @@ class FlightMessageReaderImpl : public FlightMessageReader { // Validate IPC message RETURN_NOT_OK(data.OpenMessage(&message)); if (message->type() == ipc::Message::Type::RECORD_BATCH) { - return ipc::ReadRecordBatch(*message, schema_, out); + return ipc::ReadRecordBatch(*message, schema_, dictionary_memo_.get(), out); } else { return Status(StatusCode::Invalid, "Unrecognized message in Flight stream"); } @@ -102,6 +106,7 @@ class FlightMessageReaderImpl : public FlightMessageReader { private: FlightDescriptor descriptor_; std::shared_ptr schema_; + std::unique_ptr dictionary_memo_; grpc::ServerReader* reader_; bool stream_finished_; }; @@ -293,25 +298,15 @@ class FlightServiceImpl : public FlightService::Service { return grpc::Status(grpc::StatusCode::NOT_FOUND, "No data in this flight"); } - // Write the schema as the first message(s) in the stream - // (several messages may be required if there are dictionaries) - MemoryPool* pool = default_memory_pool(); - std::vector ipc_payloads; - GRPC_RETURN_NOT_OK( - ipc::internal::GetSchemaPayloads(*data_stream->schema(), pool, &ipc_payloads)); - - for (auto& ipc_payload : ipc_payloads) { - // For DoGet, descriptor doesn't need to be written out - FlightPayload schema_payload; - schema_payload.ipc_message = std::move(ipc_payload); - - if (!internal::WritePayload(schema_payload, writer)) { - // Connection terminated? XXX return error code? - return grpc::Status::OK; - } + // Write the schema as the first message in the stream + FlightPayload schema_payload; + GRPC_RETURN_NOT_OK(data_stream->GetSchemaPayload(&schema_payload)); + if (!internal::WritePayload(schema_payload, writer)) { + // Connection terminated? XXX return error code? + return grpc::Status::OK; } - // Write incoming data as individual messages + // Consume data stream and write out payloads while (true) { FlightPayload payload; GRPC_RETURN_NOT_OK(data_stream->Next(&payload)); @@ -343,10 +338,11 @@ class FlightServiceImpl : public FlightService::Service { Status(StatusCode::Invalid, "DoPut must start with non-null descriptor")); } else { std::shared_ptr schema; - GRPC_RETURN_NOT_OK(ipc::ReadSchema(*message, &schema)); + auto dictionary_memo = ::arrow::internal::make_unique(); + GRPC_RETURN_NOT_OK(ipc::ReadSchema(*message, dictionary_memo.get(), &schema)); - auto message_reader = std::unique_ptr( - new FlightMessageReaderImpl(*data.descriptor.get(), schema, reader)); + auto message_reader = ::arrow::internal::make_unique( + *data.descriptor.get(), schema, std::move(dictionary_memo), reader); return internal::ToGrpcStatus( server_->DoPut(flight_context, std::move(message_reader))); } @@ -545,23 +541,99 @@ Status FlightServerBase::ListActions(const ServerCallContext& context, // ---------------------------------------------------------------------- // Implement RecordBatchStream -RecordBatchStream::RecordBatchStream(const std::shared_ptr& reader) - : pool_(default_memory_pool()), reader_(reader) {} +class RecordBatchStream::RecordBatchStreamImpl { + public: + // Stages of the stream when producing paylaods + enum class Stage { + NEW, // The stream has been created, but Next has not been called yet + DICTIONARY, // Dictionaries have been collected, and are being sent + RECORD_BATCH // Initial have been sent + }; + + RecordBatchStreamImpl(const std::shared_ptr& reader, + MemoryPool* pool) + : pool_(pool), reader_(reader) {} + + std::shared_ptr schema() { return reader_->schema(); } + + Status GetSchemaPayload(FlightPayload* payload) { + return ipc::internal::GetSchemaPayload(*reader_->schema(), &dictionary_memo_, + &payload->ipc_message); + } + + Status Next(FlightPayload* payload) { + if (stage_ == Stage::NEW) { + RETURN_NOT_OK(reader_->ReadNext(¤t_batch_)); + if (!current_batch_) { + // Signal that iteration is over + payload->ipc_message.metadata = nullptr; + return Status::OK(); + } + RETURN_NOT_OK(CollectDictionaries(*current_batch_)); + stage_ = Stage::DICTIONARY; + } -std::shared_ptr RecordBatchStream::schema() { return reader_->schema(); } + if (stage_ == Stage::DICTIONARY) { + if (dictionary_index_ == static_cast(dictionaries_.size())) { + stage_ = Stage::RECORD_BATCH; + return ipc::internal::GetRecordBatchPayload(*current_batch_, pool_, + &payload->ipc_message); + } else { + return GetNextDictionary(payload); + } + } -Status RecordBatchStream::Next(FlightPayload* payload) { - std::shared_ptr batch; - RETURN_NOT_OK(reader_->ReadNext(&batch)); + RETURN_NOT_OK(reader_->ReadNext(¤t_batch_)); - if (!batch) { - // Signal that iteration is over - payload->ipc_message.metadata = nullptr; + // TODO(wesm): Delta dictionaries + if (!current_batch_) { + // Signal that iteration is over + payload->ipc_message.metadata = nullptr; + return Status::OK(); + } else { + return ipc::internal::GetRecordBatchPayload(*current_batch_, pool_, + &payload->ipc_message); + } + } + + private: + Status GetNextDictionary(FlightPayload* payload) { + const auto& it = dictionaries_[dictionary_index_++]; + return ipc::internal::GetDictionaryPayload(it.first, it.second, pool_, + &payload->ipc_message); + } + + Status CollectDictionaries(const RecordBatch& batch) { + RETURN_NOT_OK(ipc::CollectDictionaries(batch, &dictionary_memo_)); + for (auto& pair : dictionary_memo_.id_to_dictionary()) { + dictionaries_.push_back({pair.first, pair.second}); + } return Status::OK(); - } else { - return ipc::internal::GetRecordBatchPayload(*batch, pool_, &payload->ipc_message); } + + Stage stage_ = Stage::NEW; + MemoryPool* pool_; + std::shared_ptr reader_; + ipc::DictionaryMemo dictionary_memo_; + std::shared_ptr current_batch_; + std::vector>> dictionaries_; + + // Index of next dictionary to send + int dictionary_index_ = 0; +}; + +RecordBatchStream::RecordBatchStream(const std::shared_ptr& reader, + MemoryPool* pool) { + impl_.reset(new RecordBatchStreamImpl(reader, pool)); +} + +std::shared_ptr RecordBatchStream::schema() { return impl_->schema(); } + +Status RecordBatchStream::GetSchemaPayload(FlightPayload* payload) { + return impl_->GetSchemaPayload(payload); } +Status RecordBatchStream::Next(FlightPayload* payload) { return impl_->Next(payload); } + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index e9e45e03261c7..0a2b940f14bc8 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -24,10 +24,11 @@ #include #include -#include "arrow/util/visibility.h" - #include "arrow/flight/types.h" // IWYU pragma: keep +#include "arrow/ipc/dictionary.h" +#include "arrow/memory_pool.h" #include "arrow/record_batch.h" +#include "arrow/util/visibility.h" namespace arrow { @@ -45,9 +46,11 @@ class ARROW_EXPORT FlightDataStream { public: virtual ~FlightDataStream() = default; - // When the stream starts, send the schema. virtual std::shared_ptr schema() = 0; + /// \brief Compute FlightPayload containing serialized RecordBatch schema + virtual Status GetSchemaPayload(FlightPayload* payload) = 0; + // When the stream is completed, the last payload written will have null // metadata virtual Status Next(FlightPayload* payload) = 0; @@ -58,14 +61,17 @@ class ARROW_EXPORT FlightDataStream { class ARROW_EXPORT RecordBatchStream : public FlightDataStream { public: /// \param[in] reader produces a sequence of record batches - explicit RecordBatchStream(const std::shared_ptr& reader); + /// \param[in,out] pool a MemoryPool to use for allocations + explicit RecordBatchStream(const std::shared_ptr& reader, + MemoryPool* pool = default_memory_pool()); std::shared_ptr schema() override; + Status GetSchemaPayload(FlightPayload* payload) override; Status Next(FlightPayload* payload) override; private: - MemoryPool* pool_; - std::shared_ptr reader_; + class RecordBatchStreamImpl; + std::unique_ptr impl_; }; /// \brief A reader for IPC payloads uploaded by a client diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index a4a1510c436ad..66af90a25d449 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -28,6 +28,7 @@ #include #include "arrow/io/test-common.h" +#include "arrow/ipc/dictionary.h" #include "arrow/ipc/json-integration.h" #include "arrow/ipc/writer.h" #include "arrow/record_batch.h" @@ -127,7 +128,8 @@ int main(int argc, char** argv) { ABORT_NOT_OK(client->GetFlightInfo(descr, &info)); std::shared_ptr schema; - ABORT_NOT_OK(info->GetSchema(&schema)); + arrow::ipc::DictionaryMemo dict_memo; + ABORT_NOT_OK(info->GetSchema(&dict_memo, &schema)); if (info->endpoints().size() == 0) { std::cerr << "No endpoints returned from Flight server." << std::endl; diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 3625bc5633f0e..77c8009e8bfa1 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -22,6 +22,7 @@ #include #include "arrow/io/memory.h" +#include "arrow/ipc/dictionary.h" #include "arrow/ipc/reader.h" #include "arrow/status.h" @@ -69,15 +70,14 @@ std::string FlightDescriptor::ToString() const { return ss.str(); } -Status FlightInfo::GetSchema(std::shared_ptr* out) const { +Status FlightInfo::GetSchema(ipc::DictionaryMemo* dictionary_memo, + std::shared_ptr* out) const { if (reconstructed_schema_) { *out = schema_; return Status::OK(); } - /// XXX(wesm): arrow::ipc::ReadSchema in its current form will not suffice - /// for reading schemas with dictionaries. See ARROW-3144 io::BufferReader schema_reader(data_.schema); - RETURN_NOT_OK(ipc::ReadSchema(&schema_reader, &schema_)); + RETURN_NOT_OK(ipc::ReadSchema(&schema_reader, dictionary_memo, &schema_)); reconstructed_schema_ = true; *out = schema_; return Status::OK(); diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 0c09766298a68..8e85a41b3d728 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -35,6 +35,12 @@ class Buffer; class Schema; class Status; +namespace ipc { + +class DictionaryMemo; + +} + namespace flight { /// \brief A type of action that can be performed with the DoAction RPC @@ -151,9 +157,14 @@ class FlightInfo { explicit FlightInfo(Data&& data) : data_(std::move(data)), reconstructed_schema_(false) {} - /// Deserialize the Arrow schema of the dataset, to be passed to each call to - /// DoGet - Status GetSchema(std::shared_ptr* out) const; + /// \brief Deserialize the Arrow schema of the dataset, to be passed + /// to each call to DoGet. Populate any dictionary encoded fields + /// into a DictionaryMemo for bookkeeping + /// \param[in,out] dictionary_memo for dictionary bookkeeping, will + /// be modified + /// \param[out] out the reconstructed Schema + Status GetSchema(ipc::DictionaryMemo* dictionary_memo, + std::shared_ptr* out) const; const std::string& serialized_schema() const { return data_.schema; } diff --git a/cpp/src/arrow/gpu/cuda-test.cc b/cpp/src/arrow/gpu/cuda-test.cc index 9a10b2743c0be..5a1f63376e58e 100644 --- a/cpp/src/arrow/gpu/cuda-test.cc +++ b/cpp/src/arrow/gpu/cuda-test.cc @@ -22,6 +22,7 @@ #include "gtest/gtest.h" #include "arrow/ipc/api.h" +#include "arrow/ipc/dictionary.h" #include "arrow/ipc/test-common.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" @@ -339,7 +340,8 @@ TEST_F(TestCudaArrowIpc, BasicWriteRead) { std::shared_ptr cpu_batch; io::BufferReader cpu_reader(host_buffer); - ASSERT_OK(ipc::ReadRecordBatch(batch->schema(), &cpu_reader, &cpu_batch)); + ipc::DictionaryMemo unused_memo; + ASSERT_OK(ipc::ReadRecordBatch(batch->schema(), &unused_memo, &cpu_reader, &cpu_batch)); CompareBatch(*batch, *cpu_batch); } diff --git a/cpp/src/arrow/gpu/cuda_arrow_ipc.cc b/cpp/src/arrow/gpu/cuda_arrow_ipc.cc index b4d8744cb0bd0..34488a1513a55 100644 --- a/cpp/src/arrow/gpu/cuda_arrow_ipc.cc +++ b/cpp/src/arrow/gpu/cuda_arrow_ipc.cc @@ -24,6 +24,7 @@ #include "arrow/buffer.h" #include "arrow/ipc/Message_generated.h" +#include "arrow/ipc/dictionary.h" #include "arrow/ipc/message.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" @@ -102,7 +103,8 @@ Status ReadRecordBatch(const std::shared_ptr& schema, } // Zero-copy read on device memory - return ipc::ReadRecordBatch(*message, schema, out); + ipc::DictionaryMemo unused_memo; + return ipc::ReadRecordBatch(*message, schema, &unused_memo, out); } } // namespace cuda diff --git a/cpp/src/arrow/ipc/dictionary.cc b/cpp/src/arrow/ipc/dictionary.cc index aa0d9085f5a8f..a639f13644644 100644 --- a/cpp/src/arrow/ipc/dictionary.cc +++ b/cpp/src/arrow/ipc/dictionary.cc @@ -18,17 +18,33 @@ #include "arrow/ipc/dictionary.h" #include +#include #include #include #include +#include "arrow/array.h" +#include "arrow/record_batch.h" #include "arrow/status.h" +#include "arrow/type.h" namespace arrow { namespace ipc { +// ---------------------------------------------------------------------- + DictionaryMemo::DictionaryMemo() {} +Status DictionaryMemo::GetDictionaryType(int64_t id, + std::shared_ptr* type) const { + auto it = id_to_type_.find(id); + if (it == id_to_type_.end()) { + return Status::KeyError("No record of dictionary type with id ", id); + } + *type = it->second; + return Status::OK(); +} + // Returns KeyError if dictionary not found Status DictionaryMemo::GetDictionary(int64_t id, std::shared_ptr* dictionary) const { @@ -40,41 +56,129 @@ Status DictionaryMemo::GetDictionary(int64_t id, return Status::OK(); } -int64_t DictionaryMemo::GetId(const std::shared_ptr& dictionary) { - intptr_t address = reinterpret_cast(dictionary.get()); - auto it = dictionary_to_id_.find(address); - if (it != dictionary_to_id_.end()) { - // Dictionary already observed, return the id - return it->second; +Status DictionaryMemo::AddFieldInternal(int64_t id, const std::shared_ptr& field) { + field_to_id_[field.get()] = id; + + if (field->type()->id() != Type::DICTIONARY) { + return Status::Invalid("Field type was not DictionaryType", + field->type()->ToString()); + } + + std::shared_ptr value_type = + static_cast(*field->type()).value_type(); + + // Add the value type for the dictionary + auto it = id_to_type_.find(id); + if (it != id_to_type_.end()) { + if (!it->second->Equals(*value_type)) { + return Status::Invalid("Field with dictionary id 0 seen but had type ", + it->second->ToString(), "and not ", value_type->ToString()); + } + } else { + // Newly-observed dictionary id + id_to_type_[id] = value_type; + } + return Status::OK(); +} + +Status DictionaryMemo::GetOrAssignId(const std::shared_ptr& field, int64_t* out) { + auto it = field_to_id_.find(field.get()); + if (it != field_to_id_.end()) { + // Field already observed, return the id + *out = it->second; + } else { + int64_t new_id = *out = static_cast(field_to_id_.size()); + RETURN_NOT_OK(AddFieldInternal(new_id, field)); + } + return Status::OK(); +} + +Status DictionaryMemo::AddField(int64_t id, const std::shared_ptr& field) { + auto it = field_to_id_.find(field.get()); + if (it != field_to_id_.end()) { + return Status::KeyError("Field is already in memo: ", field->ToString()); } else { - int64_t new_id = static_cast(dictionary_to_id_.size()); - dictionary_to_id_[address] = new_id; - id_to_dictionary_[new_id] = dictionary; - return new_id; + RETURN_NOT_OK(AddFieldInternal(id, field)); + return Status::OK(); } } -bool DictionaryMemo::HasDictionary(const std::shared_ptr& dictionary) const { - intptr_t address = reinterpret_cast(dictionary.get()); - auto it = dictionary_to_id_.find(address); - return it != dictionary_to_id_.end(); +Status DictionaryMemo::GetId(const Field& field, int64_t* id) const { + auto it = field_to_id_.find(&field); + if (it != field_to_id_.end()) { + // Field recorded, return the id + *id = it->second; + return Status::OK(); + } else { + return Status::KeyError("Field with memory address ", + reinterpret_cast(&field), " not found"); + } } -bool DictionaryMemo::HasDictionaryId(int64_t id) const { +bool DictionaryMemo::HasDictionary(const Field& field) const { + auto it = field_to_id_.find(&field); + return it != field_to_id_.end(); +} + +bool DictionaryMemo::HasDictionary(int64_t id) const { auto it = id_to_dictionary_.find(id); return it != id_to_dictionary_.end(); } Status DictionaryMemo::AddDictionary(int64_t id, const std::shared_ptr& dictionary) { - if (HasDictionaryId(id)) { + if (HasDictionary(id)) { return Status::KeyError("Dictionary with id ", id, " already exists"); } - intptr_t address = reinterpret_cast(dictionary.get()); id_to_dictionary_[id] = dictionary; - dictionary_to_id_[address] = id; return Status::OK(); } +// ---------------------------------------------------------------------- +// CollectDictionaries implementation + +struct DictionaryCollector { + DictionaryMemo* dictionary_memo_; + + Status WalkChildren(const DataType& type, const Array& array) { + for (int i = 0; i < type.num_children(); ++i) { + auto boxed_child = MakeArray(array.data()->child_data[i]); + RETURN_NOT_OK(Visit(type.child(i), *boxed_child)); + } + return Status::OK(); + } + + Status Visit(const std::shared_ptr& field, const Array& array) { + auto type = array.type(); + if (type->id() == Type::DICTIONARY) { + const auto& dict_array = static_cast(array); + auto dictionary = dict_array.dictionary(); + int64_t id = -1; + RETURN_NOT_OK(dictionary_memo_->GetOrAssignId(field, &id)); + RETURN_NOT_OK(dictionary_memo_->AddDictionary(id, dictionary)); + + // Traverse the dictionary to gather any nested dictionaries + const auto& dict_type = static_cast(*type); + RETURN_NOT_OK(WalkChildren(*dict_type.value_type(), *dictionary)); + } else { + RETURN_NOT_OK(WalkChildren(*type, array)); + } + return Status::OK(); + } + + Status Collect(const RecordBatch& batch) { + const Schema& schema = *batch.schema(); + for (int i = 0; i < schema.num_fields(); ++i) { + RETURN_NOT_OK(Visit(schema.field(i), *batch.column(i))); + } + return Status::OK(); + } +}; + +Status CollectDictionaries(const RecordBatch& batch, DictionaryMemo* memo) { + DictionaryCollector collector{memo}; + return collector.Collect(batch); +} + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/dictionary.h b/cpp/src/arrow/ipc/dictionary.h index 69ea4855f7809..787cd0ddd5a03 100644 --- a/cpp/src/arrow/ipc/dictionary.h +++ b/cpp/src/arrow/ipc/dictionary.h @@ -31,31 +31,47 @@ namespace arrow { class Array; +class DataType; class Field; +class RecordBatch; namespace ipc { using DictionaryMap = std::unordered_map>; -using DictionaryTypeMap = std::unordered_map>; -/// \brief Memoization data structure for handling shared dictionaries +/// \brief Memoization data structure for assigning id numbers to +/// dictionaries and tracking their current state through possible +/// deltas in an IPC stream class ARROW_EXPORT DictionaryMemo { public: DictionaryMemo(); DictionaryMemo(DictionaryMemo&&) = default; DictionaryMemo& operator=(DictionaryMemo&&) = default; - /// \brief Returns KeyError if dictionary not found + /// \brief Return current dictionary corresponding to a particular + /// id. Returns KeyError if id not found Status GetDictionary(int64_t id, std::shared_ptr* dictionary) const; + /// \brief Return dictionary value type corresponding to a + /// particular dictionary id. This permits multiple fields to + /// reference the same dictionary in IPC and JSON + Status GetDictionaryType(int64_t id, std::shared_ptr* type) const; + /// \brief Return id for dictionary, computing new id if necessary - int64_t GetId(const std::shared_ptr& dictionary); + Status GetOrAssignId(const std::shared_ptr& field, int64_t* out); + + /// \brief Return id for dictionary if it exists, otherwise return + /// KeyError + Status GetId(const Field& type, int64_t* id) const; - /// \brief Return true if dictionary array object is in this memo - bool HasDictionary(const std::shared_ptr& dictionary) const; + /// \brief Return true if dictionary for type is in this memo + bool HasDictionary(const Field& type) const; /// \brief Return true if we have a dictionary for the input id - bool HasDictionaryId(int64_t id) const; + bool HasDictionary(int64_t id) const; + + /// \brief Add field to the memo, return KeyError if already present + Status AddField(int64_t id, const std::shared_ptr& field); /// \brief Add a dictionary to the memo with a particular id. Returns /// KeyError if that dictionary already exists @@ -63,20 +79,27 @@ class ARROW_EXPORT DictionaryMemo { const DictionaryMap& id_to_dictionary() const { return id_to_dictionary_; } - /// \brief The number of dictionaries stored in the memo - int size() const { return static_cast(id_to_dictionary_.size()); } + /// \brief The number of fields tracked in the memo + int num_fields() const { return static_cast(field_to_id_.size()); } + int num_dictionaries() const { return static_cast(id_to_dictionary_.size()); } private: - // Dictionary memory addresses, to track whether a dictionary has been seen - // before - std::unordered_map dictionary_to_id_; + Status AddFieldInternal(int64_t id, const std::shared_ptr& field); + + // Dictionary memory addresses, to track whether a particular + // dictionary-encoded field has been seen before + std::unordered_map field_to_id_; // Map of dictionary id to dictionary array DictionaryMap id_to_dictionary_; + std::unordered_map> id_to_type_; ARROW_DISALLOW_COPY_AND_ASSIGN(DictionaryMemo); }; +ARROW_EXPORT +Status CollectDictionaries(const RecordBatch& batch, DictionaryMemo* memo); + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/feather.cc b/cpp/src/arrow/ipc/feather.cc index d28bf7512999a..5965d3616314b 100644 --- a/cpp/src/arrow/ipc/feather.cc +++ b/cpp/src/arrow/ipc/feather.cc @@ -313,7 +313,8 @@ class TableReader::TableReaderImpl { } Status GetDataType(const fbs::PrimitiveArray* values, fbs::TypeMetadata metadata_type, - const void* metadata, std::shared_ptr* out) { + const void* metadata, std::shared_ptr* out, + std::shared_ptr* out_dictionary = nullptr) { #define PRIMITIVE_CASE(CAP_TYPE, FACTORY_FUNC) \ case fbs::Type_##CAP_TYPE: \ *out = FACTORY_FUNC(); \ @@ -326,11 +327,10 @@ class TableReader::TableReaderImpl { std::shared_ptr index_type; RETURN_NOT_OK(GetDataType(values, fbs::TypeMetadata_NONE, nullptr, &index_type)); - std::shared_ptr levels; RETURN_NOT_OK( - LoadValues(meta->levels(), fbs::TypeMetadata_NONE, nullptr, &levels)); + LoadValues(meta->levels(), fbs::TypeMetadata_NONE, nullptr, out_dictionary)); - *out = std::make_shared(index_type, levels, meta->ordered()); + *out = dictionary(index_type, (*out_dictionary)->type(), meta->ordered()); break; } case fbs::TypeMetadata_TimestampMetadata: { @@ -385,7 +385,8 @@ class TableReader::TableReaderImpl { Status LoadValues(const fbs::PrimitiveArray* meta, fbs::TypeMetadata metadata_type, const void* metadata, std::shared_ptr* out) { std::shared_ptr type; - RETURN_NOT_OK(GetDataType(meta, metadata_type, metadata, &type)); + std::shared_ptr dictionary; + RETURN_NOT_OK(GetDataType(meta, metadata_type, metadata, &type, &dictionary)); std::vector> buffers; @@ -415,6 +416,7 @@ class TableReader::TableReaderImpl { auto arr_data = ArrayData::Make(type, meta->length(), std::move(buffers), meta->null_count()); + arr_data->dictionary = dictionary; *out = MakeArray(arr_data); return Status::OK(); } @@ -772,8 +774,7 @@ class TableWriter::TableWriterImpl : public ArrayVisitor { ArrayMetadata levels_meta; std::shared_ptr sanitized_dictionary; - RETURN_NOT_OK( - SanitizeUnsupportedTypes(*dict_type.dictionary(), &sanitized_dictionary)); + RETURN_NOT_OK(SanitizeUnsupportedTypes(*values.dictionary(), &sanitized_dictionary)); RETURN_NOT_OK(WriteArray(*sanitized_dictionary, &levels_meta)); current_column_->SetCategory(levels_meta, dict_type.ordered()); return Status::OK(); diff --git a/cpp/src/arrow/ipc/json-integration.cc b/cpp/src/arrow/ipc/json-integration.cc index 839890ca263b9..3eb18e064ac5a 100644 --- a/cpp/src/arrow/ipc/json-integration.cc +++ b/cpp/src/arrow/ipc/json-integration.cc @@ -21,8 +21,10 @@ #include #include +#include "arrow/array.h" #include "arrow/buffer.h" #include "arrow/io/file.h" +#include "arrow/ipc/dictionary.h" #include "arrow/ipc/json-internal.h" #include "arrow/memory_pool.h" #include "arrow/record_batch.h" @@ -42,17 +44,34 @@ namespace json { class JsonWriter::JsonWriterImpl { public: - explicit JsonWriterImpl(const std::shared_ptr& schema) : schema_(schema) { + explicit JsonWriterImpl(const std::shared_ptr& schema) + : schema_(schema), first_batch_written_(false) { writer_.reset(new RjWriter(string_buffer_)); } Status Start() { writer_->StartObject(); - RETURN_NOT_OK(json::WriteSchema(*schema_, writer_.get())); + RETURN_NOT_OK(json::WriteSchema(*schema_, &dictionary_memo_, writer_.get())); + return Status::OK(); + } + + Status FirstRecordBatch(const RecordBatch& batch) { + RETURN_NOT_OK(CollectDictionaries(batch, &dictionary_memo_)); + + // Write dictionaries, if any + if (dictionary_memo_.num_dictionaries() > 0) { + writer_->Key("dictionaries"); + writer_->StartArray(); + for (const auto& entry : dictionary_memo_.id_to_dictionary()) { + RETURN_NOT_OK(WriteDictionary(entry.first, entry.second, writer_.get())); + } + writer_->EndArray(); + } // Record batches writer_->Key("batches"); writer_->StartArray(); + first_batch_written_ = true; return Status::OK(); } @@ -66,11 +85,18 @@ class JsonWriter::JsonWriterImpl { Status WriteRecordBatch(const RecordBatch& batch) { DCHECK_EQ(batch.num_columns(), schema_->num_fields()); + + if (!first_batch_written_) { + RETURN_NOT_OK(FirstRecordBatch(batch)); + } return json::WriteRecordBatch(batch, writer_.get()); } private: std::shared_ptr schema_; + DictionaryMemo dictionary_memo_; + + bool first_batch_written_; rj::StringBuffer string_buffer_; std::unique_ptr writer_; @@ -109,7 +135,7 @@ class JsonReader::JsonReaderImpl { return Status::IOError("JSON parsing failed"); } - RETURN_NOT_OK(json::ReadSchema(doc_, pool_, &schema_)); + RETURN_NOT_OK(json::ReadSchema(doc_, pool_, &dictionary_memo_, &schema_)); auto it = doc_.FindMember("batches"); RETURN_NOT_ARRAY("batches", it, doc_); @@ -118,12 +144,13 @@ class JsonReader::JsonReaderImpl { return Status::OK(); } - Status ReadRecordBatch(int i, std::shared_ptr* batch) const { + Status ReadRecordBatch(int i, std::shared_ptr* batch) { DCHECK_GE(i, 0) << "i out of bounds"; DCHECK_LT(i, static_cast(record_batches_->GetArray().Size())) << "i out of bounds"; - return json::ReadRecordBatch(record_batches_->GetArray()[i], schema_, pool_, batch); + return json::ReadRecordBatch(record_batches_->GetArray()[i], schema_, + &dictionary_memo_, pool_, batch); } std::shared_ptr schema() const { return schema_; } @@ -139,6 +166,7 @@ class JsonReader::JsonReaderImpl { const rj::Value* record_batches_; std::shared_ptr schema_; + DictionaryMemo dictionary_memo_; }; JsonReader::JsonReader(MemoryPool* pool, const std::shared_ptr& data) { diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc index a26eac016fbb2..87d5b911e8c66 100644 --- a/cpp/src/arrow/ipc/json-internal.cc +++ b/cpp/src/arrow/ipc/json-internal.cc @@ -60,9 +60,6 @@ namespace ipc { namespace internal { namespace json { -using ::arrow::ipc::DictionaryMemo; -using ::arrow::ipc::DictionaryTypeMap; - static std::string GetFloatingPrecisionName(FloatingPoint::Precision precision) { switch (precision) { case FloatingPoint::HALF: @@ -95,8 +92,9 @@ static std::string GetTimeUnitName(TimeUnit::type unit) { class SchemaWriter { public: - explicit SchemaWriter(const Schema& schema, RjWriter* writer) - : schema_(schema), writer_(writer) {} + explicit SchemaWriter(const Schema& schema, DictionaryMemo* dictionary_memo, + RjWriter* writer) + : schema_(schema), dictionary_memo_(dictionary_memo), writer_(writer) {} Status Write() { writer_->Key("schema"); @@ -104,45 +102,20 @@ class SchemaWriter { writer_->Key("fields"); writer_->StartArray(); for (const std::shared_ptr& field : schema_.fields()) { - RETURN_NOT_OK(VisitField(*field)); + RETURN_NOT_OK(VisitField(field)); } writer_->EndArray(); writer_->EndObject(); - - // Write dictionaries, if any - if (dictionary_memo_.size() > 0) { - writer_->Key("dictionaries"); - writer_->StartArray(); - for (const auto& entry : dictionary_memo_.id_to_dictionary()) { - RETURN_NOT_OK(WriteDictionary(entry.first, entry.second)); - } - writer_->EndArray(); - } - return Status::OK(); - } - - Status WriteDictionary(int64_t id, const std::shared_ptr& dictionary) { - writer_->StartObject(); - writer_->Key("id"); - writer_->Int(static_cast(id)); - writer_->Key("data"); - - // Make a dummy record batch. A bit tedious as we have to make a schema - auto schema = ::arrow::schema({arrow::field("dictionary", dictionary->type())}); - auto batch = RecordBatch::Make(schema, dictionary->length(), {dictionary}); - RETURN_NOT_OK(WriteRecordBatch(*batch, writer_)); - writer_->EndObject(); return Status::OK(); } - Status WriteDictionaryMetadata(const DictionaryType& type) { - int64_t dictionary_id = dictionary_memo_.GetId(type.dictionary()); + Status WriteDictionaryMetadata(int64_t id, const DictionaryType& type) { writer_->Key("dictionary"); // Emulate DictionaryEncoding from Schema.fbs writer_->StartObject(); writer_->Key("id"); - writer_->Int(static_cast(dictionary_id)); + writer_->Int(static_cast(id)); writer_->Key("indexType"); writer_->StartObject(); @@ -156,16 +129,16 @@ class SchemaWriter { return Status::OK(); } - Status VisitField(const Field& field) { + Status VisitField(const std::shared_ptr& field) { writer_->StartObject(); writer_->Key("name"); - writer_->String(field.name().c_str()); + writer_->String(field->name().c_str()); writer_->Key("nullable"); - writer_->Bool(field.nullable()); + writer_->Bool(field->nullable()); - const DataType& type = *field.type(); + const DataType& type = *field->type(); // Visit the type writer_->Key("type"); @@ -175,10 +148,10 @@ class SchemaWriter { if (type.id() == Type::DICTIONARY) { const auto& dict_type = checked_cast(type); - RETURN_NOT_OK(WriteDictionaryMetadata(dict_type)); - - const DataType& dictionary_type = *dict_type.dictionary()->type(); - RETURN_NOT_OK(WriteChildren(dictionary_type.children())); + int64_t dictionary_id = -1; + RETURN_NOT_OK(dictionary_memo_->GetOrAssignId(field, &dictionary_id)); + RETURN_NOT_OK(WriteDictionaryMetadata(dictionary_id, dict_type)); + RETURN_NOT_OK(WriteChildren(dict_type.value_type()->children())); } else { RETURN_NOT_OK(WriteChildren(type.children())); } @@ -316,7 +289,7 @@ class SchemaWriter { writer_->Key("children"); writer_->StartArray(); for (const std::shared_ptr& field : children) { - RETURN_NOT_OK(VisitField(*field)); + RETURN_NOT_OK(VisitField(field)); } writer_->EndArray(); return Status::OK(); @@ -367,17 +340,14 @@ class SchemaWriter { return Status::OK(); } - Status Visit(const DictionaryType& type) { - return VisitType(*type.dictionary()->type()); - } + Status Visit(const DictionaryType& type) { return VisitType(*type.value_type()); } // Default case Status Visit(const DataType& type) { return Status::NotImplemented(type.name()); } private: - DictionaryMemo dictionary_memo_; - const Schema& schema_; + DictionaryMemo* dictionary_memo_; RjWriter* writer_; }; @@ -942,11 +912,10 @@ static Status GetType(const RjObject& json_type, return Status::OK(); } -static Status GetField(const rj::Value& obj, const DictionaryMemo* dictionary_memo, +static Status GetField(const rj::Value& obj, DictionaryMemo* dictionary_memo, std::shared_ptr* field); -static Status GetFieldsFromArray(const rj::Value& obj, - const DictionaryMemo* dictionary_memo, +static Status GetFieldsFromArray(const rj::Value& obj, DictionaryMemo* dictionary_memo, std::vector>* fields) { const auto& values = obj.GetArray(); @@ -978,7 +947,7 @@ static Status ParseDictionary(const RjObject& obj, int64_t* id, bool* is_ordered return GetInteger(json_index_type, index_type); } -static Status GetField(const rj::Value& obj, const DictionaryMemo* dictionary_memo, +static Status GetField(const rj::Value& obj, DictionaryMemo* dictionary_memo, std::shared_ptr* field) { if (!obj.IsObject()) { return Status::Invalid("Field was not a JSON object"); @@ -991,10 +960,20 @@ static Status GetField(const rj::Value& obj, const DictionaryMemo* dictionary_me RETURN_NOT_OK(GetObjectBool(json_field, "nullable", &nullable)); std::shared_ptr type; + const auto& it_type = json_field.FindMember("type"); + RETURN_NOT_OBJECT("type", it_type, json_field); + + const auto& it_children = json_field.FindMember("children"); + RETURN_NOT_ARRAY("children", it_children, json_field); + + std::vector> children; + RETURN_NOT_OK(GetFieldsFromArray(it_children->value, dictionary_memo, &children)); + RETURN_NOT_OK(GetType(it_type->value.GetObject(), children, &type)); const auto& it_dictionary = json_field.FindMember("dictionary"); if (dictionary_memo != nullptr && it_dictionary != json_field.MemberEnd()) { - // Field is dictionary encoded. We must have already + // Parse dictionary id in JSON and add dictionary field to the + // memo, and parse the dictionaries later RETURN_NOT_OBJECT("dictionary", it_dictionary, json_field); int64_t dictionary_id = -1; bool is_ordered; @@ -1002,26 +981,13 @@ static Status GetField(const rj::Value& obj, const DictionaryMemo* dictionary_me RETURN_NOT_OK(ParseDictionary(it_dictionary->value.GetObject(), &dictionary_id, &is_ordered, &index_type)); - std::shared_ptr dictionary; - RETURN_NOT_OK(dictionary_memo->GetDictionary(dictionary_id, &dictionary)); - - type = std::make_shared(index_type, dictionary, is_ordered); + type = ::arrow::dictionary(index_type, type, is_ordered); + *field = ::arrow::field(name, type, nullable); + RETURN_NOT_OK(dictionary_memo->AddField(dictionary_id, *field)); } else { - // If the dictionary_memo was not passed, or if the field is not dictionary - // encoded, we are interested in the complete type including all children - - const auto& it_type = json_field.FindMember("type"); - RETURN_NOT_OBJECT("type", it_type, json_field); - - const auto& it_children = json_field.FindMember("children"); - RETURN_NOT_ARRAY("children", it_children, json_field); - - std::vector> children; - RETURN_NOT_OK(GetFieldsFromArray(it_children->value, dictionary_memo, &children)); - RETURN_NOT_OK(GetType(it_type->value.GetObject(), children, &type)); + *field = ::arrow::field(name, type, nullable); } - *field = std::make_shared(name, type, nullable); return Status::OK(); } @@ -1055,46 +1021,24 @@ UnboxValue(const rj::Value& val) { class ArrayReader { public: - explicit ArrayReader(const rj::Value& json_array, const std::shared_ptr& type, - MemoryPool* pool) - : json_array_(json_array), type_(type), pool_(pool) {} - - Status ParseTypeValues(const DataType& type); - - Status GetValidityBuffer(const std::vector& is_valid, int32_t* null_count, - std::shared_ptr* validity_buffer) { - int length = static_cast(is_valid.size()); - - std::shared_ptr out_buffer; - RETURN_NOT_OK(AllocateEmptyBitmap(pool_, length, &out_buffer)); - uint8_t* bitmap = out_buffer->mutable_data(); - - *null_count = 0; - for (int i = 0; i < length; ++i) { - if (!is_valid[i]) { - ++(*null_count); - continue; - } - BitUtil::SetBit(bitmap, i); - } - - *validity_buffer = out_buffer; - return Status::OK(); - } + ArrayReader(const RjObject& obj, MemoryPool* pool, const std::shared_ptr& field, + DictionaryMemo* dictionary_memo) + : obj_(obj), + pool_(pool), + field_(field), + type_(field->type()), + dictionary_memo_(dictionary_memo) {} template - typename std::enable_if< - std::is_base_of::value || std::is_base_of::value || - std::is_base_of::value || - std::is_base_of::value || std::is_base_of::value || - std::is_base_of::value || - std::is_base_of::value, - Status>::type + typename std::enable_if::value || + is_temporal_type::value || + std::is_base_of::value, + Status>::type Visit(const T& type) { typename TypeTraits::BuilderType builder(type_, pool_); - const auto& json_data = obj_->FindMember(kData); - RETURN_NOT_ARRAY(kData, json_data, *obj_); + const auto& json_data = obj_.FindMember(kData); + RETURN_NOT_ARRAY(kData, json_data, obj_); const auto& json_data_arr = json_data->value.GetArray(); @@ -1117,8 +1061,8 @@ class ArrayReader { const T& type) { typename TypeTraits::BuilderType builder(pool_); - const auto& json_data = obj_->FindMember(kData); - RETURN_NOT_ARRAY(kData, json_data, *obj_); + const auto& json_data = obj_.FindMember(kData); + RETURN_NOT_ARRAY(kData, json_data, obj_); const auto& json_data_arr = json_data->value.GetArray(); @@ -1158,8 +1102,8 @@ class ArrayReader { Status Visit(const DayTimeIntervalType& type) { DayTimeIntervalBuilder builder(pool_); - const auto& json_data = obj_->FindMember(kData); - RETURN_NOT_ARRAY(kData, json_data, *obj_); + const auto& json_data = obj_.FindMember(kData); + RETURN_NOT_ARRAY(kData, json_data, obj_); const auto& json_data_arr = json_data->value.GetArray(); @@ -1189,8 +1133,8 @@ class ArrayReader { Visit(const T& type) { typename TypeTraits::BuilderType builder(type_, pool_); - const auto& json_data = obj_->FindMember(kData); - RETURN_NOT_ARRAY(kData, json_data, *obj_); + const auto& json_data = obj_.FindMember(kData); + RETURN_NOT_ARRAY(kData, json_data, obj_); const auto& json_data_arr = json_data->value.GetArray(); @@ -1230,8 +1174,8 @@ class ArrayReader { const T& type) { typename TypeTraits::BuilderType builder(type_, pool_); - const auto& json_data = obj_->FindMember(kData); - RETURN_NOT_ARRAY(kData, json_data, *obj_); + const auto& json_data = obj_.FindMember(kData); + RETURN_NOT_ARRAY(kData, json_data, obj_); const auto& json_data_arr = json_data->value.GetArray(); @@ -1277,14 +1221,14 @@ class ArrayReader { std::shared_ptr validity_buffer; RETURN_NOT_OK(GetValidityBuffer(is_valid_, &null_count, &validity_buffer)); - const auto& json_offsets = obj_->FindMember("OFFSET"); - RETURN_NOT_ARRAY("OFFSET", json_offsets, *obj_); + const auto& json_offsets = obj_.FindMember("OFFSET"); + RETURN_NOT_ARRAY("OFFSET", json_offsets, obj_); std::shared_ptr offsets_buffer; RETURN_NOT_OK(GetIntArray(json_offsets->value.GetArray(), length_ + 1, &offsets_buffer)); std::vector> children; - RETURN_NOT_OK(GetChildren(*obj_, type, &children)); + RETURN_NOT_OK(GetChildren(obj_, type, &children)); DCHECK_EQ(children.size(), 1); result_ = std::make_shared(type_, length_, offsets_buffer, children[0], @@ -1299,7 +1243,7 @@ class ArrayReader { RETURN_NOT_OK(GetValidityBuffer(is_valid_, &null_count, &validity_buffer)); std::vector> children; - RETURN_NOT_OK(GetChildren(*obj_, type, &children)); + RETURN_NOT_OK(GetChildren(obj_, type, &children)); DCHECK_EQ(children.size(), 1); DCHECK_EQ(children[0]->length(), type.list_size() * length_); @@ -1315,7 +1259,7 @@ class ArrayReader { RETURN_NOT_OK(GetValidityBuffer(is_valid_, &null_count, &validity_buffer)); std::vector> fields; - RETURN_NOT_OK(GetChildren(*obj_, type, &fields)); + RETURN_NOT_OK(GetChildren(obj_, type, &fields)); result_ = std::make_shared(type_, length_, fields, validity_buffer, null_count); @@ -1332,20 +1276,20 @@ class ArrayReader { RETURN_NOT_OK(GetValidityBuffer(is_valid_, &null_count, &validity_buffer)); - const auto& json_type_ids = obj_->FindMember("TYPE_ID"); - RETURN_NOT_ARRAY("TYPE_ID", json_type_ids, *obj_); + const auto& json_type_ids = obj_.FindMember("TYPE_ID"); + RETURN_NOT_ARRAY("TYPE_ID", json_type_ids, obj_); RETURN_NOT_OK( GetIntArray(json_type_ids->value.GetArray(), length_, &type_id_buffer)); if (type.mode() == UnionMode::DENSE) { - const auto& json_offsets = obj_->FindMember("OFFSET"); - RETURN_NOT_ARRAY("OFFSET", json_offsets, *obj_); + const auto& json_offsets = obj_.FindMember("OFFSET"); + RETURN_NOT_ARRAY("OFFSET", json_offsets, obj_); RETURN_NOT_OK( GetIntArray(json_offsets->value.GetArray(), length_, &offsets_buffer)); } std::vector> children; - RETURN_NOT_OK(GetChildren(*obj_, type, &children)); + RETURN_NOT_OK(GetChildren(obj_, type, &children)); result_ = std::make_shared(type_, length_, children, type_id_buffer, offsets_buffer, validity_buffer, null_count); @@ -1359,20 +1303,47 @@ class ArrayReader { } Status Visit(const DictionaryType& type) { - // This stores the indices in result_ - // - // XXX(wesm): slight hack - auto dict_type = type_; - type_ = type.index_type(); - RETURN_NOT_OK(ParseTypeValues(*type_)); - type_ = dict_type; - result_ = std::make_shared(type_, result_); + std::shared_ptr indices; + + ArrayReader parser(obj_, pool_, ::arrow::field("indices", type.index_type()), + dictionary_memo_); + RETURN_NOT_OK(parser.Parse(&indices)); + + // Look up dictionary + int64_t dictionary_id = -1; + RETURN_NOT_OK(dictionary_memo_->GetId(*field_, &dictionary_id)); + + std::shared_ptr dictionary; + RETURN_NOT_OK(dictionary_memo_->GetDictionary(dictionary_id, &dictionary)); + + result_ = std::make_shared(field_->type(), indices, dictionary); return Status::OK(); } // Default case Status Visit(const DataType& type) { return Status::NotImplemented(type.name()); } + Status GetValidityBuffer(const std::vector& is_valid, int32_t* null_count, + std::shared_ptr* validity_buffer) { + int length = static_cast(is_valid.size()); + + std::shared_ptr out_buffer; + RETURN_NOT_OK(AllocateEmptyBitmap(pool_, length, &out_buffer)); + uint8_t* bitmap = out_buffer->mutable_data(); + + *null_count = 0; + for (int i = 0; i < length; ++i) { + if (!is_valid[i]) { + ++(*null_count); + continue; + } + BitUtil::SetBit(bitmap, i); + } + + *validity_buffer = out_buffer; + return Status::OK(); + } + Status GetChildren(const RjObject& obj, const DataType& type, std::vector>* array) { const auto& json_children = obj.FindMember("children"); @@ -1395,25 +1366,19 @@ class ArrayReader { DCHECK_EQ(it->value.GetString(), child_field->name()); std::shared_ptr child; - RETURN_NOT_OK(ReadArray(pool_, json_children_arr[i], child_field->type(), &child)); + RETURN_NOT_OK( + ReadArray(pool_, json_children_arr[i], child_field, dictionary_memo_, &child)); array->emplace_back(child); } return Status::OK(); } - Status GetArray(std::shared_ptr* out) { - if (!json_array_.IsObject()) { - return Status::Invalid("Array element was not a JSON object"); - } - - auto obj = json_array_.GetObject(); - obj_ = &obj; - - RETURN_NOT_OK(GetObjectInt(obj, "count", &length_)); + Status Parse(std::shared_ptr* out) { + RETURN_NOT_OK(GetObjectInt(obj_, "count", &length_)); - const auto& json_valid_iter = obj.FindMember("VALIDITY"); - RETURN_NOT_ARRAY("VALIDITY", json_valid_iter, obj); + const auto& json_valid_iter = obj_.FindMember("VALIDITY"); + RETURN_NOT_ARRAY("VALIDITY", json_valid_iter, obj_); const auto& json_validity = json_valid_iter->value.GetArray(); DCHECK_EQ(static_cast(json_validity.Size()), length_); @@ -1422,16 +1387,18 @@ class ArrayReader { is_valid_.push_back(val.GetInt() != 0); } - RETURN_NOT_OK(ParseTypeValues(*type_)); + RETURN_NOT_OK(VisitTypeInline(*type_, this)); + *out = result_; return Status::OK(); } private: - const rj::Value& json_array_; - const RjObject* obj_; - std::shared_ptr type_; + const RjObject& obj_; MemoryPool* pool_; + const std::shared_ptr& field_; + std::shared_ptr type_; + DictionaryMemo* dictionary_memo_; // Parsed common attributes std::vector is_valid_; @@ -1439,74 +1406,40 @@ class ArrayReader { std::shared_ptr result_; }; -Status ArrayReader::ParseTypeValues(const DataType& type) { - return VisitTypeInline(type, this); -} - -Status WriteSchema(const Schema& schema, RjWriter* json_writer) { - SchemaWriter converter(schema, json_writer); +Status WriteSchema(const Schema& schema, DictionaryMemo* dictionary_memo, + RjWriter* json_writer) { + SchemaWriter converter(schema, dictionary_memo, json_writer); return converter.Write(); } -static Status LookForDictionaries(const rj::Value& obj, DictionaryTypeMap* id_to_field) { - const auto& json_field = obj.GetObject(); - - const auto& it_dictionary = json_field.FindMember("dictionary"); - if (it_dictionary == json_field.MemberEnd()) { - // Not dictionary-encoded - return Status::OK(); - } - - // Dictionary encoded. Construct the field and set in the type map - std::shared_ptr dictionary_field; - RETURN_NOT_OK(GetField(obj, nullptr, &dictionary_field)); - - int id; - RETURN_NOT_OK(GetObjectInt(it_dictionary->value.GetObject(), "id", &id)); - (*id_to_field)[id] = dictionary_field; - return Status::OK(); -} - -static Status GetDictionaryTypes(const RjArray& fields, DictionaryTypeMap* id_to_field) { - for (rj::SizeType i = 0; i < fields.Size(); ++i) { - RETURN_NOT_OK(LookForDictionaries(fields[i], id_to_field)); - } - return Status::OK(); -} - -static Status ReadDictionary(const RjObject& obj, const DictionaryTypeMap& id_to_field, - MemoryPool* pool, int64_t* dictionary_id, - std::shared_ptr* out) { +static Status ReadDictionary(const RjObject& obj, MemoryPool* pool, + DictionaryMemo* dictionary_memo) { int id; RETURN_NOT_OK(GetObjectInt(obj, "id", &id)); const auto& it_data = obj.FindMember("data"); RETURN_NOT_OBJECT("data", it_data, obj); - auto it = id_to_field.find(id); - if (it == id_to_field.end()) { - return Status::Invalid("No dictionary with id ", id); - } - std::vector> fields = {it->second}; - - // We need a schema for the record batch - auto dummy_schema = std::make_shared(fields); + std::shared_ptr value_type; + RETURN_NOT_OK(dictionary_memo->GetDictionaryType(id, &value_type)); + auto value_field = ::arrow::field("dummy", value_type); - // The dictionary is embedded in a record batch with a single column + // We need placeholder schema and dictionary memo to read the record + // batch, because the dictionary is embedded in a record batch with + // a single column std::shared_ptr batch; - RETURN_NOT_OK(ReadRecordBatch(it_data->value, dummy_schema, pool, &batch)); + DictionaryMemo dummy_memo; + RETURN_NOT_OK(ReadRecordBatch(it_data->value, ::arrow::schema({value_field}), + &dummy_memo, pool, &batch)); if (batch->num_columns() != 1) { return Status::Invalid("Dictionary record batch must only contain one field"); } - - *dictionary_id = id; - *out = batch->column(0); - return Status::OK(); + return dictionary_memo->AddDictionary(id, batch->column(0)); } -static Status ReadDictionaries(const rj::Value& doc, const DictionaryTypeMap& id_to_field, - MemoryPool* pool, DictionaryMemo* dictionary_memo) { +static Status ReadDictionaries(const rj::Value& doc, MemoryPool* pool, + DictionaryMemo* dictionary_memo) { auto it = doc.FindMember("dictionaries"); if (it == doc.MemberEnd()) { // No dictionaries @@ -1518,18 +1451,13 @@ static Status ReadDictionaries(const rj::Value& doc, const DictionaryTypeMap& id for (const rj::Value& val : dictionary_array) { DCHECK(val.IsObject()); - int64_t dictionary_id = -1; - std::shared_ptr dictionary; - RETURN_NOT_OK( - ReadDictionary(val.GetObject(), id_to_field, pool, &dictionary_id, &dictionary)); - - RETURN_NOT_OK(dictionary_memo->AddDictionary(dictionary_id, dictionary)); + RETURN_NOT_OK(ReadDictionary(val.GetObject(), pool, dictionary_memo)); } return Status::OK(); } Status ReadSchema(const rj::Value& json_schema, MemoryPool* pool, - std::shared_ptr* schema) { + DictionaryMemo* dictionary_memo, std::shared_ptr* schema) { auto it = json_schema.FindMember("schema"); RETURN_NOT_OBJECT("schema", it, json_schema); const auto& obj_schema = it->value.GetObject(); @@ -1537,23 +1465,19 @@ Status ReadSchema(const rj::Value& json_schema, MemoryPool* pool, const auto& it_fields = obj_schema.FindMember("fields"); RETURN_NOT_ARRAY("fields", it_fields, obj_schema); - // Determine the dictionary types - DictionaryTypeMap dictionary_types; - RETURN_NOT_OK(GetDictionaryTypes(it_fields->value.GetArray(), &dictionary_types)); + std::vector> fields; + RETURN_NOT_OK(GetFieldsFromArray(it_fields->value, dictionary_memo, &fields)); // Read the dictionaries (if any) and cache in the memo - DictionaryMemo dictionary_memo; - RETURN_NOT_OK(ReadDictionaries(json_schema, dictionary_types, pool, &dictionary_memo)); + RETURN_NOT_OK(ReadDictionaries(json_schema, pool, dictionary_memo)); - std::vector> fields; - RETURN_NOT_OK(GetFieldsFromArray(it_fields->value, &dictionary_memo, &fields)); - - *schema = std::make_shared(fields); + *schema = ::arrow::schema(fields); return Status::OK(); } Status ReadRecordBatch(const rj::Value& json_obj, const std::shared_ptr& schema, - MemoryPool* pool, std::shared_ptr* batch) { + DictionaryMemo* dictionary_memo, MemoryPool* pool, + std::shared_ptr* batch) { DCHECK(json_obj.IsObject()); const auto& batch_obj = json_obj.GetObject(); @@ -1567,14 +1491,29 @@ Status ReadRecordBatch(const rj::Value& json_obj, const std::shared_ptr& std::vector> columns(json_columns.Size()); for (int i = 0; i < static_cast(columns.size()); ++i) { - const std::shared_ptr& type = schema->field(i)->type(); - RETURN_NOT_OK(ReadArray(pool, json_columns[i], type, &columns[i])); + RETURN_NOT_OK( + ReadArray(pool, json_columns[i], schema->field(i), dictionary_memo, &columns[i])); } *batch = RecordBatch::Make(schema, num_rows, columns); return Status::OK(); } +Status WriteDictionary(int64_t id, const std::shared_ptr& dictionary, + RjWriter* writer) { + writer->StartObject(); + writer->Key("id"); + writer->Int(static_cast(id)); + writer->Key("data"); + + // Make a dummy record batch. A bit tedious as we have to make a schema + auto schema = ::arrow::schema({arrow::field("dictionary", dictionary->type())}); + auto batch = RecordBatch::Make(schema, dictionary->length(), {dictionary}); + RETURN_NOT_OK(WriteRecordBatch(*batch, writer)); + writer->EndObject(); + return Status::OK(); +} + Status WriteRecordBatch(const RecordBatch& batch, RjWriter* writer) { writer->StartObject(); writer->Key("count"); @@ -1604,13 +1543,18 @@ Status WriteArray(const std::string& name, const Array& array, RjWriter* json_wr } Status ReadArray(MemoryPool* pool, const rj::Value& json_array, - const std::shared_ptr& type, std::shared_ptr* array) { - ArrayReader converter(json_array, type, pool); - return converter.GetArray(array); + const std::shared_ptr& field, DictionaryMemo* dictionary_memo, + std::shared_ptr* out) { + if (!json_array.IsObject()) { + return Status::Invalid("Array element was not a JSON object"); + } + auto obj = json_array.GetObject(); + ArrayReader parser(obj, pool, field, dictionary_memo); + return parser.Parse(out); } Status ReadArray(MemoryPool* pool, const rj::Value& json_array, const Schema& schema, - std::shared_ptr* array) { + DictionaryMemo* dictionary_memo, std::shared_ptr* array) { if (!json_array.IsObject()) { return Status::Invalid("Element was not a JSON object"); } @@ -1621,20 +1565,12 @@ Status ReadArray(MemoryPool* pool, const rj::Value& json_array, const Schema& sc RETURN_NOT_STRING("name", it_name, json_obj); std::string name = it_name->value.GetString(); - - std::shared_ptr result = nullptr; - for (const std::shared_ptr& field : schema.fields()) { - if (field->name() == name) { - result = field; - break; - } - } - + std::shared_ptr result = schema.GetFieldByName(name); if (result == nullptr) { return Status::KeyError("Field named ", name, " not found in schema"); } - return ReadArray(pool, json_array, result->type(), array); + return ReadArray(pool, json_array, result, dictionary_memo, array); } } // namespace json diff --git a/cpp/src/arrow/ipc/json-internal.h b/cpp/src/arrow/ipc/json-internal.h index b69c8bbac6928..a68e0f6c3ccf3 100644 --- a/cpp/src/arrow/ipc/json-internal.h +++ b/cpp/src/arrow/ipc/json-internal.h @@ -75,28 +75,43 @@ using RjObject = rj::Value::ConstObject; namespace arrow { namespace ipc { + +class DictionaryMemo; + namespace internal { namespace json { -ARROW_EXPORT Status WriteSchema(const Schema& schema, RjWriter* writer); -ARROW_EXPORT Status WriteRecordBatch(const RecordBatch& batch, RjWriter* writer); -ARROW_EXPORT Status WriteArray(const std::string& name, const Array& array, - RjWriter* writer); +/// \brief Append integration test Schema format to rapidjson writer +ARROW_EXPORT +Status WriteSchema(const Schema& schema, DictionaryMemo* dict_memo, RjWriter* writer); + +ARROW_EXPORT +Status WriteDictionary(int64_t id, const std::shared_ptr& dictionary, + RjWriter* writer); + +ARROW_EXPORT +Status WriteRecordBatch(const RecordBatch& batch, RjWriter* writer); + +ARROW_EXPORT +Status WriteArray(const std::string& name, const Array& array, RjWriter* writer); -ARROW_EXPORT Status ReadSchema(const rj::Value& json_obj, MemoryPool* pool, - std::shared_ptr* schema); +ARROW_EXPORT +Status ReadSchema(const rj::Value& json_obj, MemoryPool* pool, + DictionaryMemo* dictionary_memo, std::shared_ptr* schema); -ARROW_EXPORT Status ReadRecordBatch(const rj::Value& json_obj, - const std::shared_ptr& schema, - MemoryPool* pool, - std::shared_ptr* batch); +ARROW_EXPORT +Status ReadRecordBatch(const rj::Value& json_obj, const std::shared_ptr& schema, + DictionaryMemo* dict_memo, MemoryPool* pool, + std::shared_ptr* batch); -ARROW_EXPORT Status ReadArray(MemoryPool* pool, const rj::Value& json_obj, - const std::shared_ptr& type, - std::shared_ptr* array); +ARROW_EXPORT +Status ReadArray(MemoryPool* pool, const rj::Value& json_obj, + const std::shared_ptr& type, DictionaryMemo* dict_memo, + std::shared_ptr* array); -ARROW_EXPORT Status ReadArray(MemoryPool* pool, const rj::Value& json_obj, - const Schema& schema, std::shared_ptr* array); +ARROW_EXPORT +Status ReadArray(MemoryPool* pool, const rj::Value& json_obj, const Schema& schema, + DictionaryMemo* dict_memo, std::shared_ptr* array); } // namespace json } // namespace internal diff --git a/cpp/src/arrow/ipc/json-test.cc b/cpp/src/arrow/ipc/json-test.cc index df87671619c59..36f2d16a1363f 100644 --- a/cpp/src/arrow/ipc/json-test.cc +++ b/cpp/src/arrow/ipc/json-test.cc @@ -27,6 +27,7 @@ #include "arrow/array.h" #include "arrow/buffer.h" #include "arrow/builder.h" +#include "arrow/ipc/dictionary.h" #include "arrow/ipc/json-integration.h" #include "arrow/ipc/json-internal.h" #include "arrow/ipc/test-common.h" @@ -49,8 +50,10 @@ void TestSchemaRoundTrip(const Schema& schema) { rj::StringBuffer sb; rj::Writer writer(sb); + DictionaryMemo out_memo; + writer.StartObject(); - ASSERT_OK(WriteSchema(schema, &writer)); + ASSERT_OK(WriteSchema(schema, &out_memo, &writer)); writer.EndObject(); std::string json_schema = sb.GetString(); @@ -58,8 +61,9 @@ void TestSchemaRoundTrip(const Schema& schema) { rj::Document d; d.Parse(json_schema); + DictionaryMemo in_memo; std::shared_ptr out; - if (!ReadSchema(d, default_memory_pool(), &out).ok()) { + if (!ReadSchema(d, default_memory_pool(), &in_memo, &out).ok()) { FAIL() << "Unable to read JSON schema: " << json_schema; } @@ -85,8 +89,11 @@ void TestArrayRoundTrip(const Array& array) { FAIL() << "JSON parsing failed"; } + DictionaryMemo out_memo; + std::shared_ptr out; - ASSERT_OK(ReadArray(default_memory_pool(), d, array.type(), &out)); + ASSERT_OK(ReadArray(default_memory_pool(), d, ::arrow::field(name, array.type()), + &out_memo, &out)); // std::cout << array_as_json << std::endl; CompareArraysDetailed(0, *out, array); diff --git a/cpp/src/arrow/ipc/metadata-internal.cc b/cpp/src/arrow/ipc/metadata-internal.cc index 6195ca54c36e9..9837cbeb5eaf7 100644 --- a/cpp/src/arrow/ipc/metadata-internal.cc +++ b/cpp/src/arrow/ipc/metadata-internal.cc @@ -83,8 +83,9 @@ MetadataVersion GetMetadataVersion(flatbuf::MetadataVersion version) { } } -static Status IntFromFlatbuffer(const flatbuf::Int* int_data, - std::shared_ptr* out) { +namespace { + +Status IntFromFlatbuffer(const flatbuf::Int* int_data, std::shared_ptr* out) { if (int_data->bitWidth() > 64) { return Status::NotImplemented("Integers with more than 64 bits not implemented"); } @@ -111,8 +112,8 @@ static Status IntFromFlatbuffer(const flatbuf::Int* int_data, return Status::OK(); } -static Status FloatFromFlatbuffer(const flatbuf::FloatingPoint* float_data, - std::shared_ptr* out) { +Status FloatFromFlatbuffer(const flatbuf::FloatingPoint* float_data, + std::shared_ptr* out) { if (float_data->precision() == flatbuf::Precision_HALF) { *out = float16(); } else if (float_data->precision() == flatbuf::Precision_SINGLE) { @@ -124,23 +125,23 @@ static Status FloatFromFlatbuffer(const flatbuf::FloatingPoint* float_data, } // Forward declaration -static Status FieldToFlatbuffer(FBB& fbb, const Field& field, - DictionaryMemo* dictionary_memo, FieldOffset* offset); +Status FieldToFlatbuffer(FBB& fbb, const std::shared_ptr& field, + DictionaryMemo* dictionary_memo, FieldOffset* offset); -static Offset IntToFlatbuffer(FBB& fbb, int bitWidth, bool is_signed) { +Offset IntToFlatbuffer(FBB& fbb, int bitWidth, bool is_signed) { return flatbuf::CreateInt(fbb, bitWidth, is_signed).Union(); } -static Offset FloatToFlatbuffer(FBB& fbb, flatbuf::Precision precision) { +Offset FloatToFlatbuffer(FBB& fbb, flatbuf::Precision precision) { return flatbuf::CreateFloatingPoint(fbb, precision).Union(); } -static Status AppendChildFields(FBB& fbb, const DataType& type, - std::vector* out_children, - DictionaryMemo* dictionary_memo) { +Status AppendChildFields(FBB& fbb, const DataType& type, + std::vector* out_children, + DictionaryMemo* dictionary_memo) { FieldOffset field; for (int i = 0; i < type.num_children(); ++i) { - RETURN_NOT_OK(FieldToFlatbuffer(fbb, *type.child(i), dictionary_memo, &field)); + RETURN_NOT_OK(FieldToFlatbuffer(fbb, type.child(i), dictionary_memo, &field)); out_children->push_back(field); } return Status::OK(); @@ -149,9 +150,9 @@ static Status AppendChildFields(FBB& fbb, const DataType& type, // ---------------------------------------------------------------------- // Union implementation -static Status UnionFromFlatbuffer(const flatbuf::Union* union_data, - const std::vector>& children, - std::shared_ptr* out) { +Status UnionFromFlatbuffer(const flatbuf::Union* union_data, + const std::vector>& children, + std::shared_ptr* out) { UnionMode::type mode = (union_data->mode() == flatbuf::UnionMode_Sparse ? UnionMode::SPARSE : UnionMode::DENSE); @@ -212,9 +213,9 @@ static inline TimeUnit::type FromFlatbufferUnit(flatbuf::TimeUnit unit) { return TimeUnit::SECOND; } -static Status ConcreteTypeFromFlatbuffer( - flatbuf::Type type, const void* type_data, - const std::vector>& children, std::shared_ptr* out) { +Status ConcreteTypeFromFlatbuffer(flatbuf::Type type, const void* type_data, + const std::vector>& children, + std::shared_ptr* out) { switch (type) { case flatbuf::Type_NONE: return Status::Invalid("Type metadata cannot be none"); @@ -362,8 +363,8 @@ static Status TypeFromFlatbuffer(const flatbuf::Field* field, return Status::OK(); } -static Status TensorTypeToFlatbuffer(FBB& fbb, const DataType& type, - flatbuf::Type* out_type, Offset* offset) { +Status TensorTypeToFlatbuffer(FBB& fbb, const DataType& type, flatbuf::Type* out_type, + Offset* offset) { switch (type.id()) { case Type::UINT8: INT_TO_FB_CASE(8, false); @@ -400,9 +401,12 @@ static Status TensorTypeToFlatbuffer(FBB& fbb, const DataType& type, return Status::OK(); } -static DictionaryOffset GetDictionaryEncoding(FBB& fbb, const DictionaryType& type, - DictionaryMemo* memo) { - int64_t dictionary_id = memo->GetId(type.dictionary()); +Status GetDictionaryEncoding(FBB& fbb, const std::shared_ptr& field, + DictionaryMemo* memo, DictionaryOffset* out) { + int64_t dictionary_id = -1; + RETURN_NOT_OK(memo->GetOrAssignId(field, &dictionary_id)); + + const auto& type = checked_cast(*field->type()); // We assume that the dictionary index type (as an integer) has already been // validated elsewhere, and can safely assume we are dealing with signed @@ -412,8 +416,9 @@ static DictionaryOffset GetDictionaryEncoding(FBB& fbb, const DictionaryType& ty auto index_type_offset = flatbuf::CreateInt(fbb, fw_index_type.bit_width(), true); // TODO(wesm): ordered dictionaries - return flatbuf::CreateDictionaryEncoding(fbb, dictionary_id, index_type_offset, + *out = flatbuf::CreateDictionaryEncoding(fbb, dictionary_id, index_type_offset, type.ordered()); + return Status::OK(); } KeyValueOffset AppendKeyValue(FBB& fbb, const std::string& key, @@ -429,8 +434,8 @@ void AppendKeyValueMetadata(FBB& fbb, const KeyValueMetadata& metadata, } } -static Status KeyValueMetadataFromFlatbuffer(const KVVector* fb_metadata, - std::shared_ptr* out) { +Status KeyValueMetadataFromFlatbuffer(const KVVector* fb_metadata, + std::shared_ptr* out) { auto metadata = std::make_shared(); metadata->reserve(fb_metadata->size()); @@ -634,7 +639,7 @@ class FieldToFlatbufferVisitor { // In this library, the dictionary "type" is a logical construct. Here we // pass through to the value type, as we've already captured the index // type in the DictionaryEncoding metadata in the parent field - return VisitType(*checked_cast(type).dictionary()->type()); + return VisitType(*checked_cast(type).value_type()); } Status Visit(const ExtensionType& type) { @@ -644,18 +649,17 @@ class FieldToFlatbufferVisitor { return Status::OK(); } - Status GetResult(const Field& field, FieldOffset* offset) { - auto fb_name = fbb_.CreateString(field.name()); - RETURN_NOT_OK(VisitType(*field.type())); + Status GetResult(const std::shared_ptr& field, FieldOffset* offset) { + auto fb_name = fbb_.CreateString(field->name()); + RETURN_NOT_OK(VisitType(*field->type())); auto fb_children = fbb_.CreateVector(children_); DictionaryOffset dictionary = 0; - if (field.type()->id() == Type::DICTIONARY) { - dictionary = GetDictionaryEncoding( - fbb_, checked_cast(*field.type()), dictionary_memo_); + if (field->type()->id() == Type::DICTIONARY) { + RETURN_NOT_OK(GetDictionaryEncoding(fbb_, field, dictionary_memo_, &dictionary)); } - auto metadata = field.metadata(); + auto metadata = field->metadata(); flatbuffers::Offset fb_custom_metadata; std::vector key_values; @@ -671,7 +675,7 @@ class FieldToFlatbufferVisitor { fb_custom_metadata = fbb_.CreateVector(key_values); } *offset = - flatbuf::CreateField(fbb_, fb_name, field.nullable(), fb_type_, type_offset_, + flatbuf::CreateField(fbb_, fb_name, field->nullable(), fb_type_, type_offset_, dictionary, fb_children, fb_custom_metadata); return Status::OK(); } @@ -685,14 +689,14 @@ class FieldToFlatbufferVisitor { std::unordered_map extra_type_metadata_; }; -static Status FieldToFlatbuffer(FBB& fbb, const Field& field, - DictionaryMemo* dictionary_memo, FieldOffset* offset) { +Status FieldToFlatbuffer(FBB& fbb, const std::shared_ptr& field, + DictionaryMemo* dictionary_memo, FieldOffset* offset) { FieldToFlatbufferVisitor field_visitor(fbb, dictionary_memo); return field_visitor.GetResult(field, offset); } -static Status GetFieldMetadata(const flatbuf::Field* field, - std::shared_ptr* metadata) { +Status GetFieldMetadata(const flatbuf::Field* field, + std::shared_ptr* metadata) { auto fb_metadata = field->custom_metadata(); if (fb_metadata != nullptr) { RETURN_NOT_OK(KeyValueMetadataFromFlatbuffer(fb_metadata, metadata)); @@ -700,63 +704,36 @@ static Status GetFieldMetadata(const flatbuf::Field* field, return Status::OK(); } -static Status FieldFromFlatbuffer(const flatbuf::Field* field, - const DictionaryMemo& dictionary_memo, - std::shared_ptr* out) { +Status FieldFromFlatbuffer(const flatbuf::Field* field, DictionaryMemo* dictionary_memo, + std::shared_ptr* out) { std::shared_ptr type; - const flatbuf::DictionaryEncoding* encoding = field->dictionary(); - std::shared_ptr metadata; RETURN_NOT_OK(GetFieldMetadata(field, &metadata)); - if (encoding == nullptr) { - // The field is not dictionary encoded. We must potentially visit its - // children to fully reconstruct the data type - auto children = field->children(); - std::vector> child_fields(children->size()); - for (int i = 0; i < static_cast(children->size()); ++i) { - RETURN_NOT_OK( - FieldFromFlatbuffer(children->Get(i), dictionary_memo, &child_fields[i])); - } - RETURN_NOT_OK(TypeFromFlatbuffer(field, child_fields, metadata.get(), &type)); - } else { - // The field is dictionary encoded. The type of the dictionary values has - // been determined elsewhere, and is stored in the DictionaryMemo. Here we - // construct the logical DictionaryType object - - std::shared_ptr dictionary; - RETURN_NOT_OK(dictionary_memo.GetDictionary(encoding->id(), &dictionary)); - - std::shared_ptr index_type; - RETURN_NOT_OK(IntFromFlatbuffer(encoding->indexType(), &index_type)); - type = ::arrow::dictionary(index_type, dictionary, encoding->isOrdered()); - } - - *out = std::make_shared(field->name()->str(), type, field->nullable(), metadata); - - return Status::OK(); -} - -static Status FieldFromFlatbufferDictionary(const flatbuf::Field* field, - std::shared_ptr* out) { - // Need an empty memo to pass down for constructing children - DictionaryMemo dummy_memo; - - // Any DictionaryEncoding set is ignored here - - std::shared_ptr type; + // Reconstruct the data type auto children = field->children(); std::vector> child_fields(children->size()); for (int i = 0; i < static_cast(children->size()); ++i) { - RETURN_NOT_OK(FieldFromFlatbuffer(children->Get(i), dummy_memo, &child_fields[i])); + RETURN_NOT_OK( + FieldFromFlatbuffer(children->Get(i), dictionary_memo, &child_fields[i])); } + RETURN_NOT_OK(TypeFromFlatbuffer(field, child_fields, metadata.get(), &type)); - std::shared_ptr metadata; - RETURN_NOT_OK(GetFieldMetadata(field, &metadata)); + const flatbuf::DictionaryEncoding* encoding = field->dictionary(); - RETURN_NOT_OK(TypeFromFlatbuffer(field, child_fields, metadata.get(), &type)); - *out = std::make_shared(field->name()->str(), type, field->nullable(), metadata); + if (encoding != nullptr) { + // The field is dictionary-encoded. Construct the DictionaryType + // based on the DictionaryEncoding metadata and record in the + // dictionary_memo + std::shared_ptr index_type; + RETURN_NOT_OK(IntFromFlatbuffer(encoding->indexType(), &index_type)); + type = ::arrow::dictionary(index_type, type, encoding->isOrdered()); + *out = ::arrow::field(field->name()->str(), type, field->nullable(), metadata); + RETURN_NOT_OK(dictionary_memo->AddField(encoding->id(), *out)); + } else { + *out = ::arrow::field(field->name()->str(), type, field->nullable(), metadata); + } return Status::OK(); } @@ -771,14 +748,13 @@ flatbuf::Endianness endianness() { return bint.c[0] == 1 ? flatbuf::Endianness_Big : flatbuf::Endianness_Little; } -static Status SchemaToFlatbuffer(FBB& fbb, const Schema& schema, - DictionaryMemo* dictionary_memo, - flatbuffers::Offset* out) { +Status SchemaToFlatbuffer(FBB& fbb, const Schema& schema, DictionaryMemo* dictionary_memo, + flatbuffers::Offset* out) { /// Fields std::vector field_offsets; for (int i = 0; i < schema.num_fields(); ++i) { FieldOffset offset; - RETURN_NOT_OK(FieldToFlatbuffer(fbb, *schema.field(i), dictionary_memo, &offset)); + RETURN_NOT_OK(FieldToFlatbuffer(fbb, schema.field(i), dictionary_memo, &offset)); field_offsets.push_back(offset); } @@ -797,23 +773,15 @@ static Status SchemaToFlatbuffer(FBB& fbb, const Schema& schema, return Status::OK(); } -static Status WriteFBMessage(FBB& fbb, flatbuf::MessageHeader header_type, - flatbuffers::Offset header, int64_t body_length, - std::shared_ptr* out) { +Status WriteFBMessage(FBB& fbb, flatbuf::MessageHeader header_type, + flatbuffers::Offset header, int64_t body_length, + std::shared_ptr* out) { auto message = flatbuf::CreateMessage(fbb, kCurrentMetadataVersion, header_type, header, body_length); fbb.Finish(message); return WriteFlatbufferBuilder(fbb, out); } -Status WriteSchemaMessage(const Schema& schema, DictionaryMemo* dictionary_memo, - std::shared_ptr* out) { - FBB fbb; - flatbuffers::Offset fb_schema; - RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, dictionary_memo, &fb_schema)); - return WriteFBMessage(fbb, flatbuf::MessageHeader_Schema, fb_schema.Union(), 0, out); -} - using FieldNodeVector = flatbuffers::Offset>; using BufferVector = flatbuffers::Offset>; @@ -861,6 +829,16 @@ static Status MakeRecordBatch(FBB& fbb, int64_t length, int64_t body_length, return Status::OK(); } +} // namespace + +Status WriteSchemaMessage(const Schema& schema, DictionaryMemo* dictionary_memo, + std::shared_ptr* out) { + FBB fbb; + flatbuffers::Offset fb_schema; + RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, dictionary_memo, &fb_schema)); + return WriteFBMessage(fbb, flatbuf::MessageHeader_Schema, fb_schema.Union(), 0, out); +} + Status WriteRecordBatchMessage(int64_t length, int64_t body_length, const std::vector& nodes, const std::vector& buffers, @@ -1066,44 +1044,7 @@ Status WriteFileFooter(const Schema& schema, const std::vector& dicti // ---------------------------------------------------------------------- -static Status VisitField(const flatbuf::Field* field, DictionaryTypeMap* id_to_field) { - const flatbuf::DictionaryEncoding* dict_metadata = field->dictionary(); - if (dict_metadata == nullptr) { - // Field is not dictionary encoded. Visit children - auto children = field->children(); - if (children == nullptr) { - return Status::IOError("Children-pointer of flatbuffer-encoded Field is null."); - } - for (flatbuffers::uoffset_t i = 0; i < children->size(); ++i) { - RETURN_NOT_OK(VisitField(children->Get(i), id_to_field)); - } - } else { - // Field is dictionary encoded. Construct the data type for the - // dictionary (no descendents can be dictionary encoded) - std::shared_ptr dictionary_field; - RETURN_NOT_OK(FieldFromFlatbufferDictionary(field, &dictionary_field)); - (*id_to_field)[dict_metadata->id()] = dictionary_field; - } - return Status::OK(); -} - -Status GetDictionaryTypes(const void* opaque_schema, DictionaryTypeMap* id_to_field) { - auto schema = static_cast(opaque_schema); - if (schema->fields() == nullptr) { - return Status::IOError("Fields-pointer of flatbuffer-encoded Schema is null."); - } - int num_fields = static_cast(schema->fields()->size()); - for (int i = 0; i < num_fields; ++i) { - auto field = schema->fields()->Get(i); - if (field == nullptr) { - return Status::IOError("Field-pointer of flatbuffer-encoded Schema is null."); - } - RETURN_NOT_OK(VisitField(field, id_to_field)); - } - return Status::OK(); -} - -Status GetSchema(const void* opaque_schema, const DictionaryMemo& dictionary_memo, +Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo, std::shared_ptr* out) { auto schema = static_cast(opaque_schema); if (schema->fields() == nullptr) { @@ -1114,6 +1055,9 @@ Status GetSchema(const void* opaque_schema, const DictionaryMemo& dictionary_mem std::vector> fields(num_fields); for (int i = 0; i < num_fields; ++i) { const flatbuf::Field* field = schema->fields()->Get(i); + if (field == nullptr) { + return Status::IOError("Field-pointer of flatbuffer-encoded Schema is null."); + } RETURN_NOT_OK(FieldFromFlatbuffer(field, dictionary_memo, &fields[i])); } diff --git a/cpp/src/arrow/ipc/metadata-internal.h b/cpp/src/arrow/ipc/metadata-internal.h index c91983d91c08f..4563fb029d6f9 100644 --- a/cpp/src/arrow/ipc/metadata-internal.h +++ b/cpp/src/arrow/ipc/metadata-internal.h @@ -91,14 +91,11 @@ struct FileBlock { // individual fields metadata can be retrieved from very large schema without // -// Retrieve a list of all the dictionary ids and types required by the schema for -// reconstruction. The presumption is that these will be loaded either from -// the stream or file (or they may already be somewhere else in memory) -Status GetDictionaryTypes(const void* opaque_schema, DictionaryTypeMap* id_to_field); - -// Construct a complete Schema from the message. May be expensive for very -// large schemas if you are only interested in a few fields -Status GetSchema(const void* opaque_schema, const DictionaryMemo& dictionary_memo, +// Construct a complete Schema from the message and add +// dictinory-encoded fields to a DictionaryMemo instance. May be +// expensive for very large schemas if you are only interested in a +// few fields +Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo, std::shared_ptr* out); Status GetTensorMetadata(const Buffer& metadata, std::shared_ptr* type, diff --git a/cpp/src/arrow/ipc/read-write-benchmark.cc b/cpp/src/arrow/ipc/read-write-benchmark.cc index 1d60ef8a3a54e..66d45fb01272e 100644 --- a/cpp/src/arrow/ipc/read-write-benchmark.cc +++ b/cpp/src/arrow/ipc/read-write-benchmark.cc @@ -86,11 +86,13 @@ static void BM_ReadRecordBatch(benchmark::State& state) { // NOLINT non-const r state.SkipWithError("Failed to write!"); } + ipc::DictionaryMemo empty_memo; while (state.KeepRunning()) { std::shared_ptr result; io::BufferReader reader(buffer); - if (!ipc::ReadRecordBatch(record_batch->schema(), &reader, &result).ok()) { + if (!ipc::ReadRecordBatch(record_batch->schema(), &empty_memo, &reader, &result) + .ok()) { state.SkipWithError("Failed to read!"); } } diff --git a/cpp/src/arrow/ipc/read-write-test.cc b/cpp/src/arrow/ipc/read-write-test.cc index 5f76545e839e9..edae88ccc0b50 100644 --- a/cpp/src/arrow/ipc/read-write-test.cc +++ b/cpp/src/arrow/ipc/read-write-test.cc @@ -139,11 +139,12 @@ class TestSchemaMetadata : public ::testing::Test { void CheckRoundtrip(const Schema& schema) { std::shared_ptr buffer; - ASSERT_OK(SerializeSchema(schema, default_memory_pool(), &buffer)); + DictionaryMemo in_memo, out_memo; + ASSERT_OK(SerializeSchema(schema, &out_memo, default_memory_pool(), &buffer)); std::shared_ptr result; io::BufferReader reader(buffer); - ASSERT_OK(ReadSchema(&reader, &result)); + ASSERT_OK(ReadSchema(&reader, &in_memo, &result)); AssertSchemaEqual(schema, *result); } }; @@ -181,8 +182,7 @@ TEST_F(TestSchemaMetadata, NestedFields) { TEST_F(TestSchemaMetadata, DictionaryFields) { { - auto dict_type = - dictionary(int8(), ArrayFromJSON(int32(), "[6, 5, 4]"), true /* ordered */); + auto dict_type = dictionary(int8(), int32(), true /* ordered */); auto f0 = field("f0", dict_type); auto f1 = field("f1", list(dict_type)); @@ -190,7 +190,7 @@ TEST_F(TestSchemaMetadata, DictionaryFields) { CheckRoundtrip(schema); } { - auto dict_type = dictionary(int8(), ArrayFromJSON(list(int32()), "[[4, 5], [6]]")); + auto dict_type = dictionary(int8(), list(int32())); auto f0 = field("f0", dict_type); Schema schema({f0}); @@ -221,21 +221,24 @@ static int g_file_number = 0; class IpcTestFixture : public io::MemoryMapFixture { public: - Status DoSchemaRoundTrip(const Schema& schema, std::shared_ptr* result) { + void DoSchemaRoundTrip(const Schema& schema, DictionaryMemo* out_memo, + std::shared_ptr* result) { std::shared_ptr serialized_schema; - RETURN_NOT_OK(SerializeSchema(schema, pool_, &serialized_schema)); + ASSERT_OK(SerializeSchema(schema, out_memo, pool_, &serialized_schema)); + DictionaryMemo in_memo; io::BufferReader buf_reader(serialized_schema); - return ReadSchema(&buf_reader, result); + ASSERT_OK(ReadSchema(&buf_reader, &in_memo, result)); + ASSERT_EQ(out_memo->num_fields(), in_memo.num_fields()); } - Status DoStandardRoundTrip(const RecordBatch& batch, + Status DoStandardRoundTrip(const RecordBatch& batch, DictionaryMemo* dictionary_memo, std::shared_ptr* batch_result) { std::shared_ptr serialized_batch; RETURN_NOT_OK(SerializeRecordBatch(batch, pool_, &serialized_batch)); io::BufferReader buf_reader(serialized_batch); - return ReadRecordBatch(batch.schema(), &buf_reader, batch_result); + return ReadRecordBatch(batch.schema(), dictionary_memo, &buf_reader, batch_result); } Status DoLargeRoundTrip(const RecordBatch& batch, bool zero_data, @@ -274,15 +277,19 @@ class IpcTestFixture : public io::MemoryMapFixture { ss << "test-write-row-batch-" << g_file_number++; ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(buffer_size, ss.str(), &mmap_)); + DictionaryMemo dictionary_memo; + std::shared_ptr schema_result; - ASSERT_OK(DoSchemaRoundTrip(*batch.schema(), &schema_result)); + DoSchemaRoundTrip(*batch.schema(), &dictionary_memo, &schema_result); ASSERT_TRUE(batch.schema()->Equals(*schema_result)); + ASSERT_OK(CollectDictionaries(batch, &dictionary_memo)); + std::shared_ptr result; - ASSERT_OK(DoStandardRoundTrip(batch, &result)); + ASSERT_OK(DoStandardRoundTrip(batch, &dictionary_memo, &result)); CheckReadResult(*result, batch); - ASSERT_OK(DoLargeRoundTrip(batch, true, &result)); + ASSERT_OK(DoLargeRoundTrip(batch, /*zero_data=*/true, &result)); CheckReadResult(*result, batch); } @@ -550,8 +557,10 @@ TEST_F(RecursionLimits, ReadLimit) { io::BufferReader reader(message->body()); + DictionaryMemo empty_memo; std::shared_ptr result; - ASSERT_RAISES(Invalid, ReadRecordBatch(*message->metadata(), schema, &reader, &result)); + ASSERT_RAISES(Invalid, ReadRecordBatch(*message->metadata(), schema, &empty_memo, + &reader, &result)); } // Test fails with a structured exception on Windows + Debug @@ -568,10 +577,12 @@ TEST_F(RecursionLimits, StressLimit) { std::unique_ptr message; ASSERT_OK(ReadMessage(0, metadata_length, mmap_.get(), &message)); + DictionaryMemo empty_memo; + io::BufferReader reader(message->body()); std::shared_ptr result; - ASSERT_OK(ReadRecordBatch(*message->metadata(), schema, recursion_depth + 1, &reader, - &result)); + ASSERT_OK(ReadRecordBatch(*message->metadata(), schema, &empty_memo, + recursion_depth + 1, &reader, &result)); *it_works = result->Equals(*batch); }; @@ -697,7 +708,12 @@ class ReaderWriterMixin { ASSERT_OK(RoundTripHelper({batch}, &out_batches)); ASSERT_EQ(out_batches.size(), 1); - CheckBatchDictionaries(*out_batches[0]); + // TODO(wesm): This was broken in ARROW-3144. I'm not sure how to + // restore the deduplication logic yet because dictionaries are + // corresponded to the Schema using Field pointers rather than + // DataType as before + + // CheckDictionariesDeduplicated(*out_batches[0]); } void TestWriteDifferentSchema() { @@ -743,15 +759,15 @@ class ReaderWriterMixin { // Check that dictionaries that should be the same are the same auto schema = batch.schema(); - const auto& t0 = checked_cast(*schema->field(0)->type()); - const auto& t1 = checked_cast(*schema->field(1)->type()); + const auto& b0 = checked_cast(*batch.column(0)); + const auto& b1 = checked_cast(*batch.column(1)); - ASSERT_EQ(t0.dictionary().get(), t1.dictionary().get()); + ASSERT_EQ(b0.dictionary().get(), b1.dictionary().get()); // Same dictionary used for list values - const auto& t3 = checked_cast(*schema->field(3)->type()); - const auto& t3_value = checked_cast(*t3.value_type()); - ASSERT_EQ(t0.dictionary().get(), t3_value.dictionary().get()); + const auto& b3 = checked_cast(*batch.column(3)); + const auto& b3_value = checked_cast(*b3.values()); + ASSERT_EQ(b0.dictionary().get(), b3_value.dictionary().get()); } }; @@ -1014,6 +1030,44 @@ TEST(TestRecordBatchStreamReader, MalformedInput) { ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader, &batch_reader)); } +// ---------------------------------------------------------------------- +// DictionaryMemo miscellanea + +TEST(TestDictionaryMemo, ReusedDictionaries) { + DictionaryMemo memo; + + std::shared_ptr field1 = field("a", dictionary(int8(), utf8())); + std::shared_ptr field2 = field("b", dictionary(int16(), utf8())); + + // Two fields referencing the same dictionary_id + int64_t dictionary_id = 0; + auto dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]"); + + ASSERT_OK(memo.AddField(dictionary_id, field1)); + ASSERT_OK(memo.AddField(dictionary_id, field2)); + + std::shared_ptr value_type; + ASSERT_OK(memo.GetDictionaryType(dictionary_id, &value_type)); + ASSERT_TRUE(value_type->Equals(*utf8())); + + ASSERT_FALSE(memo.HasDictionary(dictionary_id)); + ASSERT_OK(memo.AddDictionary(dictionary_id, dict)); + ASSERT_TRUE(memo.HasDictionary(dictionary_id)); + + ASSERT_EQ(2, memo.num_fields()); + ASSERT_EQ(1, memo.num_dictionaries()); + + ASSERT_TRUE(memo.HasDictionary(*field1)); + ASSERT_TRUE(memo.HasDictionary(*field2)); + + int64_t returned_id = -1; + ASSERT_OK(memo.GetId(*field1, &returned_id)); + ASSERT_EQ(0, returned_id); + returned_id = -1; + ASSERT_OK(memo.GetId(*field2, &returned_id)); + ASSERT_EQ(0, returned_id); +} + } // namespace test } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index b236c518eea65..9dcf26e7297ac 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -137,28 +137,26 @@ class IpcComponentSource { /// reconstruction, for example) struct ArrayLoaderContext { IpcComponentSource* source; + const DictionaryMemo* dictionary_memo; int buffer_index; int field_index; int max_recursion_depth; }; -static Status LoadArray(const std::shared_ptr& type, - ArrayLoaderContext* context, ArrayData* out); +static Status LoadArray(const Field& field, ArrayLoaderContext* context, ArrayData* out); class ArrayLoader { public: - ArrayLoader(const std::shared_ptr& type, ArrayData* out, - ArrayLoaderContext* context) - : type_(type), context_(context), out_(out) {} + ArrayLoader(const Field& field, ArrayData* out, ArrayLoaderContext* context) + : field_(field), context_(context), out_(out) {} Status Load() { if (context_->max_recursion_depth <= 0) { return Status::Invalid("Max recursion depth reached"); } - out_->type = type_; - - RETURN_NOT_OK(VisitTypeInline(*type_, this)); + RETURN_NOT_OK(VisitTypeInline(*field_.type(), this)); + out_->type = field_.type(); return Status::OK(); } @@ -206,7 +204,7 @@ class ArrayLoader { } Status LoadChild(const Field& field, ArrayData* out) { - ArrayLoader loader(field.type(), out, context_); + ArrayLoader loader(field, out, context_); --context_->max_recursion_depth; RETURN_NOT_OK(loader.Load()); ++context_->max_recursion_depth; @@ -218,7 +216,7 @@ class ArrayLoader { for (const auto& child_field : child_fields) { auto field_array = std::make_shared(); - RETURN_NOT_OK(LoadChild(*child_field.get(), field_array.get())); + RETURN_NOT_OK(LoadChild(*child_field, field_array.get())); out_->child_data.emplace_back(field_array); } return Status::OK(); @@ -300,42 +298,48 @@ class ArrayLoader { } Status Visit(const DictionaryType& type) { - RETURN_NOT_OK(LoadArray(type.index_type(), context_, out_)); - out_->type = type_; + RETURN_NOT_OK( + LoadArray(*::arrow::field("indices", type.index_type()), context_, out_)); + + // Look up dictionary + int64_t id = -1; + RETURN_NOT_OK(context_->dictionary_memo->GetId(field_, &id)); + RETURN_NOT_OK(context_->dictionary_memo->GetDictionary(id, &out_->dictionary)); + return Status::OK(); } Status Visit(const ExtensionType& type) { - RETURN_NOT_OK(LoadArray(type.storage_type(), context_, out_)); - out_->type = type_; - return Status::OK(); + return LoadArray(*::arrow::field("storage", type.storage_type()), context_, out_); } private: - const std::shared_ptr type_; + const Field& field_; ArrayLoaderContext* context_; // Used in visitor pattern ArrayData* out_; }; -static Status LoadArray(const std::shared_ptr& type, - ArrayLoaderContext* context, ArrayData* out) { - ArrayLoader loader(type, out, context); +static Status LoadArray(const Field& field, ArrayLoaderContext* context, ArrayData* out) { + ArrayLoader loader(field, out, context); return loader.Load(); } Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& schema, - io::RandomAccessFile* file, std::shared_ptr* out) { - return ReadRecordBatch(metadata, schema, kMaxNestingDepth, file, out); + const DictionaryMemo* dictionary_memo, io::RandomAccessFile* file, + std::shared_ptr* out) { + return ReadRecordBatch(metadata, schema, dictionary_memo, kMaxNestingDepth, file, out); } Status ReadRecordBatch(const Message& message, const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, std::shared_ptr* out) { CHECK_MESSAGE_TYPE(message.type(), Message::RECORD_BATCH); CHECK_HAS_BODY(message); io::BufferReader reader(message.body()); - return ReadRecordBatch(*message.metadata(), schema, kMaxNestingDepth, &reader, out); + return ReadRecordBatch(*message.metadata(), schema, dictionary_memo, kMaxNestingDepth, + &reader, out); } // ---------------------------------------------------------------------- @@ -344,17 +348,15 @@ Status ReadRecordBatch(const Message& message, const std::shared_ptr& sc static Status LoadRecordBatchFromSource(const std::shared_ptr& schema, int64_t num_rows, int max_recursion_depth, IpcComponentSource* source, + const DictionaryMemo* dictionary_memo, std::shared_ptr* out) { - ArrayLoaderContext context; - context.source = source; - context.field_index = 0; - context.buffer_index = 0; - context.max_recursion_depth = max_recursion_depth; + ArrayLoaderContext context{source, dictionary_memo, /*field_index=*/0, + /*buffer_index=*/0, max_recursion_depth}; std::vector> arrays(schema->num_fields()); for (int i = 0; i < schema->num_fields(); ++i) { auto arr = std::make_shared(); - RETURN_NOT_OK(LoadArray(schema->field(i)->type(), &context, arr.get())); + RETURN_NOT_OK(LoadArray(*schema->field(i), &context, arr.get())); DCHECK_EQ(num_rows, arr->length) << "Array length did not match record batch length"; arrays[i] = std::move(arr); } @@ -365,16 +367,17 @@ static Status LoadRecordBatchFromSource(const std::shared_ptr& schema, static inline Status ReadRecordBatch(const flatbuf::RecordBatch* metadata, const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, int max_recursion_depth, io::RandomAccessFile* file, std::shared_ptr* out) { IpcComponentSource source(metadata, file); return LoadRecordBatchFromSource(schema, metadata->length(), max_recursion_depth, - &source, out); + &source, dictionary_memo, out); } Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& schema, - int max_recursion_depth, io::RandomAccessFile* file, - std::shared_ptr* out) { + const DictionaryMemo* dictionary_memo, int max_recursion_depth, + io::RandomAccessFile* file, std::shared_ptr* out) { auto message = flatbuf::GetMessage(metadata.data()); if (message->header_type() != flatbuf::MessageHeader_RecordBatch) { DCHECK_EQ(message->header_type(), flatbuf::MessageHeader_RecordBatch); @@ -383,56 +386,49 @@ Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& sc return Status::IOError("Header-pointer of flatbuffer-encoded Message is null."); } auto batch = reinterpret_cast(message->header()); - return ReadRecordBatch(batch, schema, max_recursion_depth, file, out); + return ReadRecordBatch(batch, schema, dictionary_memo, max_recursion_depth, file, out); } -Status ReadDictionary(const Buffer& metadata, const DictionaryTypeMap& dictionary_types, - io::RandomAccessFile* file, int64_t* dictionary_id, - std::shared_ptr* out) { +Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo, + io::RandomAccessFile* file) { auto message = flatbuf::GetMessage(metadata.data()); auto dictionary_batch = reinterpret_cast(message->header()); - int64_t id = *dictionary_id = dictionary_batch->id(); - auto it = dictionary_types.find(id); - if (it == dictionary_types.end()) { - return Status::KeyError("Do not have type metadata for dictionary with id: ", id); - } + int64_t id = dictionary_batch->id(); - std::vector> fields = {it->second}; + // Look up the field, which must have been added to the + // DictionaryMemo already prior to invoking this function + std::shared_ptr value_type; + RETURN_NOT_OK(dictionary_memo->GetDictionaryType(id, &value_type)); - // We need a schema for the record batch - auto dummy_schema = std::make_shared(fields); + auto value_field = ::arrow::field("dummy", value_type); // The dictionary is embedded in a record batch with a single column std::shared_ptr batch; auto batch_meta = reinterpret_cast(dictionary_batch->data()); - RETURN_NOT_OK( - ReadRecordBatch(batch_meta, dummy_schema, kMaxNestingDepth, file, &batch)); + RETURN_NOT_OK(ReadRecordBatch(batch_meta, ::arrow::schema({value_field}), + dictionary_memo, kMaxNestingDepth, file, &batch)); if (batch->num_columns() != 1) { return Status::Invalid("Dictionary record batch must only contain one field"); } - - *out = batch->column(0); - return Status::OK(); + auto dictionary = batch->column(0); + return dictionary_memo->AddDictionary(id, dictionary); } -static Status ReadMessageAndValidate(MessageReader* reader, Message::Type expected_type, - bool allow_null, std::unique_ptr* message) { +static Status ReadMessageAndValidate(MessageReader* reader, bool allow_null, + std::unique_ptr* message) { RETURN_NOT_OK(reader->ReadNextMessage(message)); if (!(*message) && !allow_null) { - return Status::Invalid("Expected ", FormatMessageType(expected_type), - " message in stream, was null or length 0"); + return Status::Invalid("Expected message in stream, was null or length 0"); } if ((*message) == nullptr) { // End of stream? return Status::OK(); } - - CHECK_MESSAGE_TYPE((*message)->type(), expected_type); return Status::OK(); } @@ -453,64 +449,72 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { return ReadSchema(); } - Status ReadNextDictionary() { - std::unique_ptr message; - RETURN_NOT_OK(ReadMessageAndValidate(message_reader_.get(), Message::DICTIONARY_BATCH, - false, &message)); - if (message == nullptr) { - // End of stream - return Status::IOError( - "End of IPC stream when attempting to read dictionary batch"); - } - - CHECK_HAS_BODY(*message); - io::BufferReader reader(message->body()); - - std::shared_ptr dictionary; - int64_t id; - RETURN_NOT_OK(ReadDictionary(*message->metadata(), dictionary_types_, &reader, &id, - &dictionary)); - return dictionary_memo_.AddDictionary(id, dictionary); - } - Status ReadSchema() { std::unique_ptr message; RETURN_NOT_OK( - ReadMessageAndValidate(message_reader_.get(), Message::SCHEMA, false, &message)); - if (message == nullptr) { - // End of stream - return Status::IOError("End of IPC stream when attempting to read schema"); - } + ReadMessageAndValidate(message_reader_.get(), /*allow_null=*/false, &message)); + CHECK_MESSAGE_TYPE(message->type(), Message::SCHEMA); CHECK_HAS_NO_BODY(*message); if (message->header() == nullptr) { return Status::IOError("Header-pointer of flatbuffer-encoded Message is null."); } - RETURN_NOT_OK(internal::GetDictionaryTypes(message->header(), &dictionary_types_)); + return internal::GetSchema(message->header(), &dictionary_memo_, &schema_); + } + + Status ParseDictionary(const Message& message) { + // Only invoke this method if we already know we have a dictionary message + DCHECK_EQ(message.type(), Message::DICTIONARY_BATCH); + CHECK_HAS_BODY(message); + io::BufferReader reader(message.body()); + return ReadDictionary(*message.metadata(), &dictionary_memo_, &reader); + } + + Status ReadInitialDictionaries() { + // We must receive all dictionaries before reconstructing the + // first record batch. Subsequent dictionary deltas modify the memo + std::unique_ptr message; // TODO(wesm): In future, we may want to reconcile the ids in the stream with // those found in the schema - int num_dictionaries = static_cast(dictionary_types_.size()); - for (int i = 0; i < num_dictionaries; ++i) { - RETURN_NOT_OK(ReadNextDictionary()); + for (int i = 0; i < dictionary_memo_.num_fields(); ++i) { + RETURN_NOT_OK( + ReadMessageAndValidate(message_reader_.get(), /*allow_null=*/false, &message)); + if (message->type() != Message::DICTIONARY_BATCH) { + return Status::Invalid( + "IPC stream did not find the expected number of " + "dictionaries at the start of the stream"); + } + RETURN_NOT_OK(ParseDictionary(*message)); } - return internal::GetSchema(message->header(), dictionary_memo_, &schema_); + read_initial_dictionaries_ = true; + return Status::OK(); } Status ReadNext(std::shared_ptr* batch) { + if (!read_initial_dictionaries_) { + RETURN_NOT_OK(ReadInitialDictionaries()); + } + std::unique_ptr message; - RETURN_NOT_OK(ReadMessageAndValidate(message_reader_.get(), Message::RECORD_BATCH, - true, &message)); + RETURN_NOT_OK( + ReadMessageAndValidate(message_reader_.get(), /*allow_null=*/true, &message)); if (message == nullptr) { // End of stream *batch = nullptr; return Status::OK(); } - CHECK_HAS_BODY(*message); - io::BufferReader reader(message->body()); - return ReadRecordBatch(*message->metadata(), schema_, &reader, batch); + if (message->type() == Message::DICTIONARY_BATCH) { + // TODO(wesm): implement delta dictionaries + return Status::NotImplemented("Delta dictionaries not yet implemented"); + } else { + CHECK_HAS_BODY(*message); + io::BufferReader reader(message->body()); + return ReadRecordBatch(*message->metadata(), schema_, &dictionary_memo_, &reader, + batch); + } } std::shared_ptr schema() const { return schema_; } @@ -518,8 +522,8 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { private: std::unique_ptr message_reader_; - // dictionary_id -> type - DictionaryTypeMap dictionary_types_; + bool read_initial_dictionaries_ = false; + DictionaryMemo dictionary_memo_; std::shared_ptr schema_; }; @@ -571,9 +575,7 @@ Status RecordBatchStreamReader::ReadNext(std::shared_ptr* batch) { class RecordBatchFileReader::RecordBatchFileReaderImpl { public: - RecordBatchFileReaderImpl() : file_(NULLPTR), footer_offset_(0), footer_(NULLPTR) { - dictionary_memo_ = std::make_shared(); - } + RecordBatchFileReaderImpl() : file_(NULLPTR), footer_offset_(0), footer_(NULLPTR) {} Status ReadFooter() { int magic_size = static_cast(strlen(kArrowMagicBytes)); @@ -619,61 +621,58 @@ class RecordBatchFileReader::RecordBatchFileReaderImpl { return internal::GetMetadataVersion(footer_->version()); } - FileBlock record_batch(int i) const { + FileBlock GetRecordBatchBlock(int i) const { return FileBlockFromFlatbuffer(footer_->recordBatches()->Get(i)); } - FileBlock dictionary(int i) const { + FileBlock GetDictionaryBlock(int i) const { return FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i)); } - Status ReadRecordBatch(int i, std::shared_ptr* batch) { - DCHECK_GE(i, 0); - DCHECK_LT(i, num_record_batches()); - FileBlock block = record_batch(i); - + Status ReadMessageFromBlock(const FileBlock& block, std::unique_ptr* out) { DCHECK(BitUtil::IsMultipleOf8(block.offset)); DCHECK(BitUtil::IsMultipleOf8(block.metadata_length)); DCHECK(BitUtil::IsMultipleOf8(block.body_length)); - std::unique_ptr message; - RETURN_NOT_OK(ReadMessage(block.offset, block.metadata_length, file_, &message)); + RETURN_NOT_OK(ReadMessage(block.offset, block.metadata_length, file_, out)); // TODO(wesm): this breaks integration tests, see ARROW-3256 - // DCHECK_EQ(message->body_length(), block.body_length); - - io::BufferReader reader(message->body()); - return ::arrow::ipc::ReadRecordBatch(*message->metadata(), schema_, &reader, batch); + // DCHECK_EQ((*out)->body_length(), block.body_length); + return Status::OK(); } - Status ReadSchema() { - RETURN_NOT_OK(internal::GetDictionaryTypes(footer_->schema(), &dictionary_fields_)); - + Status ReadDictionaries() { // Read all the dictionaries for (int i = 0; i < num_dictionaries(); ++i) { - FileBlock block = dictionary(i); - - DCHECK(BitUtil::IsMultipleOf8(block.offset)); - DCHECK(BitUtil::IsMultipleOf8(block.metadata_length)); - DCHECK(BitUtil::IsMultipleOf8(block.body_length)); - std::unique_ptr message; - RETURN_NOT_OK(ReadMessage(block.offset, block.metadata_length, file_, &message)); - - // TODO(wesm): this breaks integration tests, see ARROW-3256 - // DCHECK_EQ(message->body_length(), block.body_length); + RETURN_NOT_OK(ReadMessageFromBlock(GetDictionaryBlock(i), &message)); io::BufferReader reader(message->body()); + RETURN_NOT_OK(ReadDictionary(*message->metadata(), &dictionary_memo_, &reader)); + } + return Status::OK(); + } - std::shared_ptr dictionary; - int64_t dictionary_id; - RETURN_NOT_OK(ReadDictionary(*message->metadata(), dictionary_fields_, &reader, - &dictionary_id, &dictionary)); - RETURN_NOT_OK(dictionary_memo_->AddDictionary(dictionary_id, dictionary)); + Status ReadRecordBatch(int i, std::shared_ptr* batch) { + DCHECK_GE(i, 0); + DCHECK_LT(i, num_record_batches()); + + if (!read_dictionaries_) { + RETURN_NOT_OK(ReadDictionaries()); + read_dictionaries_ = true; } - // Get the schema - return internal::GetSchema(footer_->schema(), *dictionary_memo_, &schema_); + std::unique_ptr message; + RETURN_NOT_OK(ReadMessageFromBlock(GetRecordBatchBlock(i), &message)); + + io::BufferReader reader(message->body()); + return ::arrow::ipc::ReadRecordBatch(*message->metadata(), schema_, &dictionary_memo_, + &reader, batch); + } + + Status ReadSchema() { + // Get the schema and record any observed dictionaries + return internal::GetSchema(footer_->schema(), &dictionary_memo_, &schema_); } Status Open(const std::shared_ptr& file, int64_t footer_offset) { @@ -703,8 +702,8 @@ class RecordBatchFileReader::RecordBatchFileReaderImpl { std::shared_ptr footer_buffer_; const flatbuf::Footer* footer_; - DictionaryTypeMap dictionary_fields_; - std::shared_ptr dictionary_memo_; + bool read_dictionaries_ = false; + DictionaryMemo dictionary_memo_; // Reconstructed schema, including any read dictionaries std::shared_ptr schema_; @@ -765,26 +764,29 @@ static Status ReadContiguousPayload(io::InputStream* file, return Status::OK(); } -Status ReadSchema(io::InputStream* stream, std::shared_ptr* out) { - std::shared_ptr reader; - RETURN_NOT_OK(RecordBatchStreamReader::Open(stream, &reader)); - *out = reader->schema(); - return Status::OK(); +Status ReadSchema(io::InputStream* stream, DictionaryMemo* dictionary_memo, + std::shared_ptr* out) { + std::unique_ptr reader = MessageReader::Open(stream); + std::unique_ptr message; + RETURN_NOT_OK(ReadMessageAndValidate(reader.get(), /*allow_null=*/false, &message)); + CHECK_MESSAGE_TYPE(message->type(), Message::SCHEMA); + return ReadSchema(*message, dictionary_memo, out); } -Status ReadSchema(const Message& message, std::shared_ptr* out) { +Status ReadSchema(const Message& message, DictionaryMemo* dictionary_memo, + std::shared_ptr* out) { std::shared_ptr reader; - DictionaryMemo dictionary_memo; return internal::GetSchema(message.header(), dictionary_memo, &*out); } -Status ReadRecordBatch(const std::shared_ptr& schema, io::InputStream* file, +Status ReadRecordBatch(const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, io::InputStream* file, std::shared_ptr* out) { std::unique_ptr message; RETURN_NOT_OK(ReadContiguousPayload(file, &message)); io::BufferReader buffer_reader(message->body()); - return ReadRecordBatch(*message->metadata(), schema, kMaxNestingDepth, &buffer_reader, - out); + return ReadRecordBatch(*message->metadata(), schema, dictionary_memo, kMaxNestingDepth, + &buffer_reader, out); } Status ReadTensor(io::InputStream* file, std::shared_ptr* out) { diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h index 8fe310f5b7745..34a0eefbbb59d 100644 --- a/cpp/src/arrow/ipc/reader.h +++ b/cpp/src/arrow/ipc/reader.h @@ -23,6 +23,7 @@ #include #include +#include "arrow/ipc/dictionary.h" #include "arrow/ipc/message.h" #include "arrow/record_batch.h" #include "arrow/util/visibility.h" @@ -166,70 +167,91 @@ class ARROW_EXPORT RecordBatchFileReader { // Generic read functions; does not copy data if the input supports zero copy reads -/// \brief Read Schema from stream serialized as a sequence of one or more IPC -/// messages +/// \brief Read Schema from stream serialized as a single IPC message +/// and populate any dictionary-encoded fields into a DictionaryMemo /// /// \param[in] stream an InputStream +/// \param[in] dictionary_memo for recording dictionary-encoded fields /// \param[out] out the output Schema /// \return Status /// /// If record batches follow the schema, it is better to use /// RecordBatchStreamReader ARROW_EXPORT -Status ReadSchema(io::InputStream* stream, std::shared_ptr* out); +Status ReadSchema(io::InputStream* stream, DictionaryMemo* dictionary_memo, + std::shared_ptr* out); /// \brief Read Schema from encapsulated Message /// /// \param[in] message a message instance containing metadata +/// \param[in] dictionary_memo DictionaryMemo for recording dictionary-encoded +/// fields. Can be nullptr if you are sure there are no +/// dictionary-encoded fields /// \param[out] out the resulting Schema /// \return Status ARROW_EXPORT -Status ReadSchema(const Message& message, std::shared_ptr* out); +Status ReadSchema(const Message& message, DictionaryMemo* dictionary_memo, + std::shared_ptr* out); /// Read record batch as encapsulated IPC message with metadata size prefix and /// header /// /// \param[in] schema the record batch schema +/// \param[in] dictionary_memo DictionaryMemo which has any +/// dictionaries. Can be nullptr if you are sure there are no +/// dictionary-encoded fields /// \param[in] stream the file where the batch is located /// \param[out] out the read record batch /// \return Status ARROW_EXPORT -Status ReadRecordBatch(const std::shared_ptr& schema, io::InputStream* stream, +Status ReadRecordBatch(const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, io::InputStream* stream, std::shared_ptr* out); /// \brief Read record batch from file given metadata and schema /// /// \param[in] metadata a Message containing the record batch metadata /// \param[in] schema the record batch schema +/// \param[in] dictionary_memo DictionaryMemo which has any +/// dictionaries. Can be nullptr if you are sure there are no +/// dictionary-encoded fields /// \param[in] file a random access file /// \param[out] out the read record batch /// \return Status ARROW_EXPORT Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& schema, - io::RandomAccessFile* file, std::shared_ptr* out); + const DictionaryMemo* dictionary_memo, io::RandomAccessFile* file, + std::shared_ptr* out); /// \brief Read record batch from encapsulated Message /// /// \param[in] message a message instance containing metadata and body /// \param[in] schema the record batch schema +/// \param[in] dictionary_memo DictionaryMemo which has any +/// dictionaries. Can be nullptr if you are sure there are no +/// dictionary-encoded fields /// \param[out] out the resulting RecordBatch /// \return Status ARROW_EXPORT Status ReadRecordBatch(const Message& message, const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, std::shared_ptr* out); /// Read record batch from file given metadata and schema /// /// \param[in] metadata a Message containing the record batch metadata /// \param[in] schema the record batch schema +/// \param[in] dictionary_memo DictionaryMemo which has any +/// dictionaries. Can be nullptr if you are sure there are no +/// dictionary-encoded fields /// \param[in] file a random access file /// \param[in] max_recursion_depth the maximum permitted nesting depth /// \param[out] out the read record batch /// \return Status ARROW_EXPORT Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& schema, - int max_recursion_depth, io::RandomAccessFile* file, - std::shared_ptr* out); + const DictionaryMemo* dictionary_memo, int max_recursion_depth, + io::RandomAccessFile* file, std::shared_ptr* out); /// \brief Read arrow::Tensor as encapsulated IPC message in file /// diff --git a/cpp/src/arrow/ipc/test-common.cc b/cpp/src/arrow/ipc/test-common.cc index ff7ce051cef1a..abf27a113b4ab 100644 --- a/cpp/src/arrow/ipc/test-common.cc +++ b/cpp/src/arrow/ipc/test-common.cc @@ -456,12 +456,14 @@ Status MakeDictionary(std::shared_ptr* out) { std::vector is_valid = {true, true, false, true, true, true}; - auto dict1 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]"); - auto dict2 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\", \"qux\"]"); + auto dict_ty = utf8(); - auto f0_type = arrow::dictionary(arrow::int32(), dict1); - auto f1_type = arrow::dictionary(arrow::int8(), dict1, true); - auto f2_type = arrow::dictionary(arrow::int32(), dict2); + auto dict1 = ArrayFromJSON(dict_ty, "[\"foo\", \"bar\", \"baz\"]"); + auto dict2 = ArrayFromJSON(dict_ty, "[\"fo\", \"bap\", \"bop\", \"qup\"]"); + + auto f0_type = arrow::dictionary(arrow::int32(), dict_ty); + auto f1_type = arrow::dictionary(arrow::int8(), dict_ty, true); + auto f2_type = arrow::dictionary(arrow::int32(), dict_ty); std::shared_ptr indices0, indices1, indices2; std::vector indices0_values = {1, 2, -1, 0, 2, 0}; @@ -472,9 +474,9 @@ Status MakeDictionary(std::shared_ptr* out) { ArrayFromVector(is_valid, indices1_values, &indices1); ArrayFromVector(is_valid, indices2_values, &indices2); - auto a0 = std::make_shared(f0_type, indices0); - auto a1 = std::make_shared(f1_type, indices1); - auto a2 = std::make_shared(f2_type, indices2); + auto a0 = std::make_shared(f0_type, indices0, dict1); + auto a1 = std::make_shared(f1_type, indices1, dict1); + auto a2 = std::make_shared(f2_type, indices2, dict2); // Lists of dictionary-encoded strings auto f3_type = list(f1_type); @@ -487,23 +489,22 @@ Status MakeDictionary(std::shared_ptr* out) { std::shared_ptr a3 = std::make_shared( f3_type, length, std::static_pointer_cast(offsets3)->values(), - std::make_shared(f1_type, indices3), null_bitmap, 1); + std::make_shared(f1_type, indices3, dict1), null_bitmap, 1); // Dictionary-encoded lists of integers - auto dict4 = ArrayFromJSON(list(int8()), "[[44, 55], [], [66]]"); - auto f4_type = dictionary(int8(), dict4); + auto dict4_ty = list(int8()); + auto f4_type = dictionary(int8(), dict4_ty); auto indices4 = ArrayFromJSON(int8(), "[0, 1, 2, 0, 2, 2]"); - auto a4 = std::make_shared(f4_type, indices4); + auto dict4 = ArrayFromJSON(dict4_ty, "[[44, 55], [], [66]]"); + auto a4 = std::make_shared(f4_type, indices4, dict4); // construct batch auto schema = ::arrow::schema( {field("dict1", f0_type), field("dict2", f1_type), field("dict3", f2_type), field("list", f3_type), field("encoded list", f4_type)}); - std::vector> arrays = {a0, a1, a2, a3, a4}; - - *out = RecordBatch::Make(schema, length, arrays); + *out = RecordBatch::Make(schema, length, {a0, a1, a2, a3, a4}); return Status::OK(); } @@ -512,12 +513,13 @@ Status MakeDictionaryFlat(std::shared_ptr* out) { std::vector is_valid = {true, true, false, true, true, true}; - auto dict1 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]"); - auto dict2 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\", \"qux\"]"); + auto dict_ty = utf8(); + auto dict1 = ArrayFromJSON(dict_ty, "[\"foo\", \"bar\", \"baz\"]"); + auto dict2 = ArrayFromJSON(dict_ty, "[\"foo\", \"bar\", \"baz\", \"qux\"]"); - auto f0_type = arrow::dictionary(arrow::int32(), dict1); - auto f1_type = arrow::dictionary(arrow::int8(), dict1); - auto f2_type = arrow::dictionary(arrow::int32(), dict2); + auto f0_type = arrow::dictionary(arrow::int32(), dict_ty); + auto f1_type = arrow::dictionary(arrow::int8(), dict_ty); + auto f2_type = arrow::dictionary(arrow::int32(), dict_ty); std::shared_ptr indices0, indices1, indices2; std::vector indices0_values = {1, 2, -1, 0, 2, 0}; @@ -528,9 +530,9 @@ Status MakeDictionaryFlat(std::shared_ptr* out) { ArrayFromVector(is_valid, indices1_values, &indices1); ArrayFromVector(is_valid, indices2_values, &indices2); - auto a0 = std::make_shared(f0_type, indices0); - auto a1 = std::make_shared(f1_type, indices1); - auto a2 = std::make_shared(f2_type, indices2); + auto a0 = std::make_shared(f0_type, indices0, dict1); + auto a1 = std::make_shared(f1_type, indices1, dict1); + auto a2 = std::make_shared(f2_type, indices2, dict2); // construct batch auto schema = ::arrow::schema( diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index 4089cdd2a9e03..94edb4d518511 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -535,41 +535,19 @@ Status WriteIpcPayload(const IpcPayload& payload, io::OutputStream* dst, return Status::OK(); } -Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool, DictionaryMemo* out_memo, - std::vector* out_payloads) { - DictionaryMemo dictionary_memo; - IpcPayload payload; - - out_payloads->clear(); - payload.type = Message::SCHEMA; - RETURN_NOT_OK(WriteSchemaMessage(schema, &dictionary_memo, &payload.metadata)); - out_payloads->push_back(std::move(payload)); - out_payloads->reserve(dictionary_memo.size() + 1); - - // Append dictionaries - for (auto& pair : dictionary_memo.id_to_dictionary()) { - int64_t dictionary_id = pair.first; - const auto& dictionary = pair.second; - - // Frame of reference is 0, see ARROW-384 - const int64_t buffer_start_offset = 0; - payload.type = Message::DICTIONARY_BATCH; - DictionaryWriter writer(dictionary_id, pool, buffer_start_offset, kMaxNestingDepth, - true /* allow_64bit */, &payload); - RETURN_NOT_OK(writer.Assemble(dictionary)); - out_payloads->push_back(std::move(payload)); - } - - if (out_memo != nullptr) { - *out_memo = std::move(dictionary_memo); - } - - return Status::OK(); +Status GetSchemaPayload(const Schema& schema, DictionaryMemo* dictionary_memo, + IpcPayload* out) { + out->type = Message::SCHEMA; + return WriteSchemaMessage(schema, dictionary_memo, &out->metadata); } -Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool, - std::vector* out_payloads) { - return GetSchemaPayloads(schema, pool, nullptr, out_payloads); +Status GetDictionaryPayload(int64_t id, const std::shared_ptr& dictionary, + MemoryPool* pool, IpcPayload* out) { + out->type = Message::DICTIONARY_BATCH; + // Frame of reference is 0, see ARROW-384 + DictionaryWriter writer(id, pool, /*buffer_start_offset=*/0, ipc::kMaxNestingDepth, + true /* allow_64bit */, out); + return writer.Assemble(dictionary); } Status GetRecordBatchPayload(const RecordBatch& batch, MemoryPool* pool, @@ -825,9 +803,7 @@ Status WriteDictionary(int64_t dictionary_id, const std::shared_ptr& dict int64_t buffer_start_offset, io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length, MemoryPool* pool) { internal::IpcPayload payload; - internal::DictionaryWriter writer(dictionary_id, pool, buffer_start_offset, - kMaxNestingDepth, true, &payload); - RETURN_NOT_OK(writer.Assemble(dictionary)); + RETURN_NOT_OK(GetDictionaryPayload(dictionary_id, dictionary, pool, &payload)); // The body size is computed in the payload *body_length = payload.body_length; @@ -899,20 +875,23 @@ class RecordBatchPayloadWriter : public RecordBatchWriter { ~RecordBatchPayloadWriter() override = default; RecordBatchPayloadWriter(std::unique_ptr payload_writer, - const Schema& schema) + const Schema& schema, DictionaryMemo* out_memo = nullptr) : payload_writer_(std::move(payload_writer)), schema_(schema), pool_(default_memory_pool()), - started_(false) {} + dictionary_memo_(out_memo) { + if (out_memo == nullptr) { + dictionary_memo_ = &internal_dict_memo_; + } + } // A Schema-owning constructor variant RecordBatchPayloadWriter(std::unique_ptr payload_writer, - const std::shared_ptr& schema) - : payload_writer_(std::move(payload_writer)), - shared_schema_(schema), - schema_(*schema), - pool_(default_memory_pool()), - started_(false) {} + const std::shared_ptr& schema, + DictionaryMemo* out_memo = nullptr) + : RecordBatchPayloadWriter(std::move(payload_writer), *schema, out_memo) { + shared_schema_ = schema; + } Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { if (!batch.schema()->Equals(schema_, false /* check_metadata */)) { @@ -920,6 +899,15 @@ class RecordBatchPayloadWriter : public RecordBatchWriter { } RETURN_NOT_OK(CheckStarted()); + + if (!wrote_dictionaries_) { + RETURN_NOT_OK(WriteDictionaries(batch)); + wrote_dictionaries_ = true; + } + + // TODO(wesm): Check for delta dictionaries. Can we scan for + // deltas while computing the RecordBatch payload to save time? + internal::IpcPayload payload; RETURN_NOT_OK(GetRecordBatchPayload(batch, pool_, &payload)); return payload_writer_->WritePayload(payload); @@ -936,15 +924,9 @@ class RecordBatchPayloadWriter : public RecordBatchWriter { started_ = true; RETURN_NOT_OK(payload_writer_->Start()); - // Write out schema payloads - std::vector payloads; - // XXX should we have a GetSchemaPayloads() variant that generates them - // one by one, to minimize memory usage? - RETURN_NOT_OK(GetSchemaPayloads(schema_, pool_, &payloads)); - for (const auto& payload : payloads) { - RETURN_NOT_OK(payload_writer_->WritePayload(payload)); - } - return Status::OK(); + internal::IpcPayload payload; + RETURN_NOT_OK(GetSchemaPayload(schema_, dictionary_memo_, &payload)); + return payload_writer_->WritePayload(payload); } protected: @@ -955,12 +937,29 @@ class RecordBatchPayloadWriter : public RecordBatchWriter { return Status::OK(); } + Status WriteDictionaries(const RecordBatch& batch) { + RETURN_NOT_OK(CollectDictionaries(batch, dictionary_memo_)); + + for (auto& pair : dictionary_memo_->id_to_dictionary()) { + internal::IpcPayload payload; + int64_t dictionary_id = pair.first; + const auto& dictionary = pair.second; + + RETURN_NOT_OK(GetDictionaryPayload(dictionary_id, dictionary, pool_, &payload)); + RETURN_NOT_OK(payload_writer_->WritePayload(payload)); + } + return Status::OK(); + } + protected: std::unique_ptr payload_writer_; std::shared_ptr shared_schema_; const Schema& schema_; MemoryPool* pool_; - bool started_; + DictionaryMemo* dictionary_memo_; + DictionaryMemo internal_dict_memo_; + bool started_ = false; + bool wrote_dictionaries_ = false; }; // ---------------------------------------------------------------------- @@ -1207,20 +1206,15 @@ Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, kMaxNestingDepth, true); } -// TODO: this function also serializes dictionaries. This is suboptimal for -// the purpose of transmitting working set metadata without actually sending -// the data (e.g. ListFlights() in Flight RPC). - -Status SerializeSchema(const Schema& schema, MemoryPool* pool, - std::shared_ptr* out) { +Status SerializeSchema(const Schema& schema, DictionaryMemo* dictionary_memo, + MemoryPool* pool, std::shared_ptr* out) { std::shared_ptr stream; RETURN_NOT_OK(io::BufferOutputStream::Create(1024, pool, &stream)); auto payload_writer = make_unique(stream.get()); - RecordBatchPayloadWriter writer(std::move(payload_writer), schema); - // Write out schema and dictionaries + RecordBatchPayloadWriter writer(std::move(payload_writer), schema, dictionary_memo); + // Write schema and populate fields (but not dictionaries) in dictionary_memo RETURN_NOT_OK(writer.Start()); - return stream->Finish(out); } diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h index 75034ea9ae9df..6bb55dbc1a53b 100644 --- a/cpp/src/arrow/ipc/writer.h +++ b/cpp/src/arrow/ipc/writer.h @@ -29,6 +29,7 @@ namespace arrow { +class Array; class Buffer; class MemoryPool; class RecordBatch; @@ -215,16 +216,16 @@ ARROW_EXPORT Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, io::OutputStream* out); -/// \brief Serialize schema using stream writer as a sequence of one or more -/// IPC messages +/// \brief Serialize schema as encapsulated IPC message /// /// \param[in] schema the schema to write +/// \param[in] dictionary_memo a DictionaryMemo for recording dictionary ids /// \param[in] pool a MemoryPool to allocate memory from /// \param[out] out the serialized schema /// \return Status ARROW_EXPORT -Status SerializeSchema(const Schema& schema, MemoryPool* pool, - std::shared_ptr* out); +Status SerializeSchema(const Schema& schema, DictionaryMemo* dictionary_memo, + MemoryPool* pool, std::shared_ptr* out); /// \brief Write multiple record batches to OutputStream, including schema /// \param[in] batches a vector of batches. Must all have same schema @@ -331,19 +332,23 @@ Status OpenRecordBatchWriter(std::unique_ptr sink, const std::shared_ptr& schema, std::unique_ptr* out); -/// \brief Compute IpcPayloads for the given schema +/// \brief Compute IpcPayload for the given schema /// \param[in] schema the Schema that is being serialized -/// \param[in,out] pool for any required temporary memory allocations -/// \param[in,out] dictionary_memo class for tracking dictionaries and assigning -/// dictionary ids +/// \param[in,out] dictionary_memo class to populate with assigned dictionary ids /// \param[out] out the returned vector of IpcPayloads /// \return Status ARROW_EXPORT -Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool, - DictionaryMemo* dictionary_memo, std::vector* out); +Status GetSchemaPayload(const Schema& schema, DictionaryMemo* dictionary_memo, + IpcPayload* out); + +/// \brief Compute IpcPayload for a dictionary +/// \param[in] id the dictionary id +/// \param[in] dictionary the dictionary values +/// \param[out] payload the output IpcPayload +/// \return Status ARROW_EXPORT -Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool, - std::vector* out); +Status GetDictionaryPayload(int64_t id, const std::shared_ptr& dictionary, + MemoryPool* pool, IpcPayload* payload); /// \brief Compute IpcPayload for the given record batch /// \param[in] batch the RecordBatch that is being serialized diff --git a/cpp/src/arrow/json/converter-test.cc b/cpp/src/arrow/json/converter-test.cc index fcf8b8797b516..86e8e8dc84a54 100644 --- a/cpp/src/arrow/json/converter-test.cc +++ b/cpp/src/arrow/json/converter-test.cc @@ -45,8 +45,10 @@ void AssertConvert(const std::shared_ptr& expected_type, } std::shared_ptr indices, unconverted, converted; ASSERT_OK(indices_builder.Finish(&indices)); - ASSERT_OK(DictionaryArray::FromArrays(dictionary(int32(), scalar_values), indices, - &unconverted)); + + auto unconverted_type = dictionary(int32(), scalar_values->type()); + unconverted = + std::make_shared(unconverted_type, indices, scalar_values); // convert the array std::shared_ptr converter; diff --git a/cpp/src/arrow/json/converter.cc b/cpp/src/arrow/json/converter.cc index d3698aa89d8c7..078e314186932 100644 --- a/cpp/src/arrow/json/converter.cc +++ b/cpp/src/arrow/json/converter.cc @@ -39,25 +39,33 @@ Status GenericConversionError(const DataType& type, Args&&... args) { std::forward(args)...); } -const DictionaryArray* GetDictionaryArray(const std::shared_ptr& in) { +namespace { + +const DictionaryArray& GetDictionaryArray(const std::shared_ptr& in) { DCHECK_EQ(in->type_id(), Type::DICTIONARY); auto dict_type = static_cast(in->type().get()); DCHECK_EQ(dict_type->index_type()->id(), Type::INT32); - DCHECK_EQ(dict_type->dictionary()->type_id(), Type::STRING); - return static_cast(in.get()); + DCHECK_EQ(dict_type->value_type()->id(), Type::STRING); + return static_cast(*in); } -template -Status VisitDictionaryEntries(const DictionaryArray* dict_array, Vis&& vis) { - const StringArray& dict = static_cast(*dict_array->dictionary()); - const Int32Array& indices = static_cast(*dict_array->indices()); +template +Status VisitDictionaryEntries(const DictionaryArray& dict_array, + ValidVisitor&& visit_valid, NullVisitor&& visit_null) { + const StringArray& dict = static_cast(*dict_array.dictionary()); + const Int32Array& indices = static_cast(*dict_array.indices()); for (int64_t i = 0; i < indices.length(); ++i) { - bool is_valid = indices.IsValid(i); - RETURN_NOT_OK(vis(is_valid, is_valid ? dict.GetView(indices.GetView(i)) : "")); + if (indices.IsValid(i)) { + RETURN_NOT_OK(visit_valid(dict.GetView(indices.GetView(i)))); + } else { + RETURN_NOT_OK(visit_null()); + } } return Status::OK(); } +} // namespace + // base class for types which accept and output non-nested types class PrimitiveConverter : public Converter { public: @@ -127,18 +135,13 @@ class NumericConverter : public PrimitiveConverter { if (in->type_id() == Type::NA) { return PrimitiveFromNull(pool_, out_type_, *in, out); } - auto dict_array = GetDictionaryArray(in); + const auto& dict_array = GetDictionaryArray(in); using Builder = typename TypeTraits::BuilderType; Builder builder(out_type_, pool_); - RETURN_NOT_OK(builder.Resize(dict_array->indices()->length())); - - auto visit = [&](bool is_valid, string_view repr) { - if (!is_valid) { - builder.UnsafeAppendNull(); - return Status::OK(); - } + RETURN_NOT_OK(builder.Resize(dict_array.indices()->length())); + auto visit_valid = [&](string_view repr) { value_type value; if (!convert_one_(repr.data(), repr.size(), &value)) { return GenericConversionError(*out_type_, ", couldn't parse:", repr); @@ -148,7 +151,12 @@ class NumericConverter : public PrimitiveConverter { return Status::OK(); }; - RETURN_NOT_OK(VisitDictionaryEntries(dict_array, visit)); + auto visit_null = [&]() { + builder.UnsafeAppendNull(); + return Status::OK(); + }; + + RETURN_NOT_OK(VisitDictionaryEntries(dict_array, visit_valid, visit_null)); return builder.Finish(out); } @@ -193,32 +201,39 @@ class BinaryConverter : public PrimitiveConverter { if (in->type_id() == Type::NA) { return BinaryFromNull(pool_, out_type_, *in, out); } - auto dict_array = GetDictionaryArray(in); + const auto& dict_array = GetDictionaryArray(in); using Builder = typename TypeTraits::BuilderType; Builder builder(out_type_, pool_); - RETURN_NOT_OK(builder.Resize(dict_array->indices()->length())); + RETURN_NOT_OK(builder.Resize(dict_array.indices()->length())); // TODO(bkietz) this can be computed during parsing at low cost int64_t data_length = 0; - auto visit_lengths = [&](bool is_valid, string_view value) { - if (is_valid) { - data_length += value.size(); - } + auto visit_lengths_valid = [&](string_view value) { + data_length += value.size(); return Status::OK(); }; - RETURN_NOT_OK(VisitDictionaryEntries(dict_array, visit_lengths)); + + auto visit_lengths_null = [&]() { + // no-op + return Status::OK(); + }; + + RETURN_NOT_OK( + VisitDictionaryEntries(dict_array, visit_lengths_valid, visit_lengths_null)); RETURN_NOT_OK(builder.ReserveData(data_length)); - auto visit = [&](bool is_valid, string_view value) { - if (is_valid) { - builder.UnsafeAppend(value); - } else { - builder.UnsafeAppendNull(); - } + auto visit_valid = [&](string_view value) { + builder.UnsafeAppend(value); return Status::OK(); }; - RETURN_NOT_OK(VisitDictionaryEntries(dict_array, visit)); + + auto visit_null = [&]() { + builder.UnsafeAppendNull(); + return Status::OK(); + }; + + RETURN_NOT_OK(VisitDictionaryEntries(dict_array, visit_valid, visit_null)); return builder.Finish(out); } }; diff --git a/cpp/src/arrow/json/parser.cc b/cpp/src/arrow/json/parser.cc index ac48a16dd1878..83f125b49e8d8 100644 --- a/cpp/src/arrow/json/parser.cc +++ b/cpp/src/arrow/json/parser.cc @@ -106,7 +106,7 @@ Status Kind::ForType(const DataType& type, Kind::type* kind) { Status Visit(const BinaryType&) { return SetKind(Kind::kString); } Status Visit(const FixedSizeBinaryType&) { return SetKind(Kind::kString); } Status Visit(const DictionaryType& dict_type) { - return Kind::ForType(*dict_type.dictionary()->type(), kind_); + return Kind::ForType(*dict_type.value_type(), kind_); } Status Visit(const ListType&) { return SetKind(Kind::kArray); } Status Visit(const StructType&) { return SetKind(Kind::kObject); } @@ -208,9 +208,9 @@ class RawArrayBuilder { /// an index refers). This means building is faster since we don't do /// allocation for string/number characters but accessing is strided. /// -/// On completion the indices and the character storage are combined into -/// a DictionaryArray, which is a convenient container for indices referring -/// into another array. +/// On completion the indices and the character storage are combined +/// into a dictionary-encoded array, which is a convenient container +/// for indices referring into another array. class ScalarBuilder { public: explicit ScalarBuilder(MemoryPool* pool) @@ -541,7 +541,9 @@ class RawBuilderSet { std::shared_ptr indices; // TODO(bkietz) embed builder->values_length() in this output somehow RETURN_NOT_OK(builder->Finish(&indices)); - return DictionaryArray::FromArrays(dictionary(int32(), scalar_values), indices, out); + auto ty = dictionary(int32(), scalar_values->type()); + *out = std::make_shared(ty, indices, scalar_values); + return Status::OK(); } template diff --git a/cpp/src/arrow/json/test-common.h b/cpp/src/arrow/json/test-common.h index 5d57e04fa664b..b4d20694ce525 100644 --- a/cpp/src/arrow/json/test-common.h +++ b/cpp/src/arrow/json/test-common.h @@ -94,7 +94,7 @@ struct GenerateImpl { } template Status Visit( - T const& t, typename std::enable_if::value>::type* = nullptr, + T const& t, typename std::enable_if::value>::type* = nullptr, typename std::enable_if::value>::type* = nullptr) { return Status::Invalid("can't generate a value of type " + t.name()); } diff --git a/cpp/src/arrow/pretty_print-test.cc b/cpp/src/arrow/pretty_print-test.cc index 39982825e85e8..fd8e0938b5419 100644 --- a/cpp/src/arrow/pretty_print-test.cc +++ b/cpp/src/arrow/pretty_print-test.cc @@ -427,12 +427,12 @@ TEST_F(TestPrettyPrint, DictionaryType) { std::shared_ptr dict; std::vector dict_values = {"foo", "bar", "baz"}; ArrayFromVector(dict_values, &dict); - std::shared_ptr dict_type = dictionary(int16(), dict); + std::shared_ptr dict_type = dictionary(int16(), utf8()); std::shared_ptr indices; std::vector indices_values = {1, 2, -1, 0, 2, 0}; ArrayFromVector(is_valid, indices_values, &indices); - auto arr = std::make_shared(dict_type, indices); + auto arr = std::make_shared(dict_type, indices, dict); static const char* expected = R"expected( -- dictionary: @@ -563,38 +563,19 @@ TEST_F(TestPrettyPrint, SchemaWithDictionary) { ArrayFromVector(dict_values, &dict); auto simple = field("one", int32()); - auto simple_dict = field("two", dictionary(int16(), dict)); + auto simple_dict = field("two", dictionary(int16(), utf8())); auto list_of_dict = field("three", list(simple_dict)); - auto struct_with_dict = field("four", struct_({simple, simple_dict})); auto sch = schema({simple, simple_dict, list_of_dict, struct_with_dict}); static const char* expected = R"expected(one: int32 two: dictionary - dictionary: - [ - "foo", - "bar", - "baz" - ] three: list> child 0, two: dictionary - dictionary: - [ - "foo", - "bar", - "baz" - ] four: struct> child 0, one: int32 - child 1, two: dictionary - dictionary: - [ - "foo", - "bar", - "baz" - ])expected"; + child 1, two: dictionary)expected"; PrettyPrintOptions options{0}; diff --git a/cpp/src/arrow/pretty_print.cc b/cpp/src/arrow/pretty_print.cc index db3795a4d52cb..5c6f870c2f303 100644 --- a/cpp/src/arrow/pretty_print.cc +++ b/cpp/src/arrow/pretty_print.cc @@ -553,25 +553,16 @@ class SchemaPrinter : public PrettyPrinter { Status SchemaPrinter::PrintType(const DataType& type) { Write(type.ToString()); - if (type.id() == Type::DICTIONARY) { - indent_ += indent_size_; + for (int i = 0; i < type.num_children(); ++i) { Newline(); - Write("dictionary:\n"); - const auto& dict_type = checked_cast(type); - RETURN_NOT_OK(PrettyPrint(*dict_type.dictionary(), indent_ + indent_size_, sink_)); - indent_ -= indent_size_; - } else { - for (int i = 0; i < type.num_children(); ++i) { - Newline(); - std::stringstream ss; - ss << "child " << i << ", "; + std::stringstream ss; + ss << "child " << i << ", "; - indent_ += indent_size_; - WriteIndented(ss.str()); - RETURN_NOT_OK(PrintField(*type.child(i))); - indent_ -= indent_size_; - } + indent_ += indent_size_; + WriteIndented(ss.str()); + RETURN_NOT_OK(PrintField(*type.child(i))); + indent_ -= indent_size_; } return Status::OK(); } diff --git a/cpp/src/arrow/python/arrow_to_pandas.cc b/cpp/src/arrow/python/arrow_to_pandas.cc index 01fb29d7ff400..f0e4b921a575b 100644 --- a/cpp/src/arrow/python/arrow_to_pandas.cc +++ b/cpp/src/arrow/python/arrow_to_pandas.cc @@ -1178,9 +1178,13 @@ class CategoricalBlock : public PandasBlock { } } + // TODO(wesm): variable dictionaries + auto arr = converted_col->data()->chunk(0); + const auto& dict_arr = checked_cast(*arr); + placement_data_[rel_placement] = abs_placement; PyObject* dict; - RETURN_NOT_OK(ConvertArrayToPandas(options_, dict_type.dictionary(), nullptr, &dict)); + RETURN_NOT_OK(ConvertArrayToPandas(options_, dict_arr.dictionary(), nullptr, &dict)); dictionary_.reset(dict); ordered_ = dict_type.ordered(); diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc index 861aa8d3dc8dd..b033b341ff214 100644 --- a/cpp/src/arrow/python/flight.cc +++ b/cpp/src/arrow/python/flight.cc @@ -22,6 +22,8 @@ #include "arrow/python/flight.h" #include "arrow/util/logging.h" +using arrow::flight::FlightPayload; + namespace arrow { namespace py { namespace flight { @@ -182,12 +184,14 @@ PyFlightDataStream::PyFlightDataStream( data_source_.reset(data_source); } -std::shared_ptr PyFlightDataStream::schema() { return stream_->schema(); } +std::shared_ptr PyFlightDataStream::schema() { return stream_->schema(); } -Status PyFlightDataStream::Next(arrow::flight::FlightPayload* payload) { - return stream_->Next(payload); +Status PyFlightDataStream::GetSchemaPayload(FlightPayload* payload) { + return stream_->GetSchemaPayload(payload); } +Status PyFlightDataStream::Next(FlightPayload* payload) { return stream_->Next(payload); } + PyGeneratorFlightDataStream::PyGeneratorFlightDataStream( PyObject* generator, std::shared_ptr schema, PyGeneratorFlightDataStreamCallback callback) @@ -196,9 +200,14 @@ PyGeneratorFlightDataStream::PyGeneratorFlightDataStream( generator_.reset(generator); } -std::shared_ptr PyGeneratorFlightDataStream::schema() { return schema_; } +std::shared_ptr PyGeneratorFlightDataStream::schema() { return schema_; } + +Status PyGeneratorFlightDataStream::GetSchemaPayload(FlightPayload* payload) { + return ipc::internal::GetSchemaPayload(*schema_, &dictionary_memo_, + &payload->ipc_message); +} -Status PyGeneratorFlightDataStream::Next(arrow::flight::FlightPayload* payload) { +Status PyGeneratorFlightDataStream::Next(FlightPayload* payload) { return SafeCallIntoPython([=] { callback_(generator_.obj(), payload); return CheckPyError(); diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h index 7f02c2494196d..19fbb02c592d3 100644 --- a/cpp/src/arrow/python/flight.h +++ b/cpp/src/arrow/python/flight.h @@ -23,6 +23,7 @@ #include #include "arrow/flight/api.h" +#include "arrow/ipc/dictionary.h" #include "arrow/python/common.h" #include "arrow/python/config.h" @@ -158,7 +159,9 @@ class ARROW_PYTHON_EXPORT PyFlightDataStream : public arrow::flight::FlightDataS /// Must only be called while holding the GIL. explicit PyFlightDataStream(PyObject* data_source, std::unique_ptr stream); - std::shared_ptr schema() override; + + std::shared_ptr schema() override; + Status GetSchemaPayload(arrow::flight::FlightPayload* payload) override; Status Next(arrow::flight::FlightPayload* payload) override; private: @@ -179,12 +182,14 @@ class ARROW_PYTHON_EXPORT PyGeneratorFlightDataStream explicit PyGeneratorFlightDataStream(PyObject* generator, std::shared_ptr schema, PyGeneratorFlightDataStreamCallback callback); - std::shared_ptr schema() override; + std::shared_ptr schema() override; + Status GetSchemaPayload(arrow::flight::FlightPayload* payload) override; Status Next(arrow::flight::FlightPayload* payload) override; private: OwnedRefNoGIL generator_; std::shared_ptr schema_; + ipc::DictionaryMemo dictionary_memo_; PyGeneratorFlightDataStreamCallback callback_; }; diff --git a/cpp/src/arrow/tensor.cc b/cpp/src/arrow/tensor.cc index 7cd4a3263c98e..8c1c58a5d1cf3 100644 --- a/cpp/src/arrow/tensor.cc +++ b/cpp/src/arrow/tensor.cc @@ -169,14 +169,14 @@ struct NonZeroCounter { : tensor_(tensor), result_(result) {} template - typename std::enable_if::value, Status>::type Visit( + typename std::enable_if::value, Status>::type Visit( const TYPE& type) { DCHECK(!is_tensor_supported(type.id())); return Status::NotImplemented("Tensor of ", type.ToString(), " is not implemented"); } template - typename std::enable_if::value, Status>::type Visit( + typename std::enable_if::value, Status>::type Visit( const TYPE& type) { *result_ = TensorCountNonZero(tensor_); return Status::OK(); diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc index aeffb8f42ff30..3c4adc9ab9165 100644 --- a/cpp/src/arrow/type-test.cc +++ b/cpp/src/arrow/type-test.cc @@ -632,38 +632,70 @@ TEST(TestStructType, GetFieldDuplicates) { ASSERT_EQ(results.size(), 0); } +TEST(TestDictionaryType, Basics) { + auto value_type = int32(); + + std::shared_ptr type1 = + std::dynamic_pointer_cast(dictionary(int16(), value_type)); + + auto type2 = std::dynamic_pointer_cast( + ::arrow::dictionary(int16(), type1, true)); + + ASSERT_TRUE(int16()->Equals(type1->index_type())); + ASSERT_TRUE(type1->value_type()->Equals(value_type)); + + ASSERT_TRUE(int16()->Equals(type2->index_type())); + ASSERT_TRUE(type2->value_type()->Equals(type1)); + + ASSERT_EQ("dictionary", type1->ToString()); + ASSERT_EQ( + "dictionary, " + "indices=int16, ordered=1>", + type2->ToString()); +} + TEST(TestDictionaryType, Equals) { - auto t1 = dictionary(int8(), ArrayFromJSON(int32(), "[3, 4, 5, 6]")); - auto t2 = dictionary(int8(), ArrayFromJSON(int32(), "[3, 4, 5, 6]")); - auto t3 = dictionary(int16(), ArrayFromJSON(int32(), "[3, 4, 5, 6]")); - auto t4 = dictionary(int8(), ArrayFromJSON(int16(), "[3, 4, 5, 6]")); - auto t5 = dictionary(int8(), ArrayFromJSON(int32(), "[3, 4, 7, 6]")); + auto t1 = dictionary(int8(), int32()); + auto t2 = dictionary(int8(), int32()); + auto t3 = dictionary(int16(), int32()); + auto t4 = dictionary(int8(), int16()); ASSERT_TRUE(t1->Equals(t2)); // Different index type ASSERT_FALSE(t1->Equals(t3)); // Different value type ASSERT_FALSE(t1->Equals(t4)); - // Different values - ASSERT_FALSE(t1->Equals(t5)); } TEST(TestDictionaryType, UnifyNumeric) { - auto t1 = dictionary(int8(), ArrayFromJSON(int64(), "[3, 4, 7]")); - auto t2 = dictionary(int8(), ArrayFromJSON(int64(), "[1, 7, 4, 8]")); - auto t3 = dictionary(int8(), ArrayFromJSON(int64(), "[1, -200]")); + auto dict_ty = int64(); + + auto t1 = dictionary(int8(), dict_ty); + auto d1 = ArrayFromJSON(dict_ty, "[3, 4, 7]"); - auto expected = dictionary(int8(), ArrayFromJSON(int64(), "[3, 4, 7, 1, 8, -200]")); + auto t2 = dictionary(int8(), dict_ty); + auto d2 = ArrayFromJSON(dict_ty, "[1, 7, 4, 8]"); - std::shared_ptr dict_type; + auto t3 = dictionary(int8(), dict_ty); + auto d3 = ArrayFromJSON(dict_ty, "[1, -200]"); + + auto expected = dictionary(int8(), dict_ty); + auto expected_dict = ArrayFromJSON(dict_ty, "[3, 4, 7, 1, 8, -200]"); + + std::shared_ptr out_type; + std::shared_ptr out_dict; ASSERT_OK(DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get(), t3.get()}, - &dict_type)); - ASSERT_TRUE(dict_type->Equals(expected)); + {d1.get(), d2.get(), d3.get()}, &out_type, &out_dict)); + ASSERT_TRUE(out_type->Equals(*expected)); + ASSERT_TRUE(out_dict->Equals(*expected_dict)); std::vector> transpose_maps; ASSERT_OK(DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get(), t3.get()}, - &dict_type, &transpose_maps)); - ASSERT_TRUE(dict_type->Equals(expected)); + {d1.get(), d2.get(), d3.get()}, &out_type, &out_dict, + &transpose_maps)); + ASSERT_TRUE(out_type->Equals(*expected)); + ASSERT_TRUE(out_dict->Equals(*expected_dict)); ASSERT_EQ(transpose_maps.size(), 3); ASSERT_EQ(transpose_maps[0], std::vector({0, 1, 2})); ASSERT_EQ(transpose_maps[1], std::vector({3, 2, 1, 4})); @@ -671,21 +703,30 @@ TEST(TestDictionaryType, UnifyNumeric) { } TEST(TestDictionaryType, UnifyString) { - auto t1 = dictionary(int16(), ArrayFromJSON(utf8(), "[\"foo\", \"bar\"]")); - auto t2 = dictionary(int32(), ArrayFromJSON(utf8(), "[\"quux\", \"foo\"]")); + auto dict_ty = utf8(); + + auto t1 = dictionary(int16(), dict_ty); + auto d1 = ArrayFromJSON(dict_ty, "[\"foo\", \"bar\"]"); - auto expected = - dictionary(int8(), ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"quux\"]")); + auto t2 = dictionary(int32(), dict_ty); + auto d2 = ArrayFromJSON(dict_ty, "[\"quux\", \"foo\"]"); - std::shared_ptr dict_type; - ASSERT_OK( - DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get()}, &dict_type)); - ASSERT_TRUE(dict_type->Equals(expected)); + auto expected = dictionary(int8(), dict_ty); + auto expected_dict = ArrayFromJSON(dict_ty, "[\"foo\", \"bar\", \"quux\"]"); + + std::shared_ptr out_type; + std::shared_ptr out_dict; + ASSERT_OK(DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get()}, + {d1.get(), d2.get()}, &out_type, &out_dict)); + ASSERT_TRUE(out_type->Equals(*expected)); + ASSERT_TRUE(out_dict->Equals(*expected_dict)); std::vector> transpose_maps; - ASSERT_OK(DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get()}, &dict_type, + ASSERT_OK(DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get()}, + {d1.get(), d2.get()}, &out_type, &out_dict, &transpose_maps)); - ASSERT_TRUE(dict_type->Equals(expected)); + ASSERT_TRUE(out_type->Equals(*expected)); + ASSERT_TRUE(out_dict->Equals(*expected_dict)); ASSERT_EQ(transpose_maps.size(), 2); ASSERT_EQ(transpose_maps[0], std::vector({0, 1})); @@ -699,24 +740,28 @@ TEST(TestDictionaryType, UnifyFixedSizeBinary) { auto buf = std::make_shared(data); // ["foo", "bar"] auto dict1 = std::make_shared(type, 2, SliceBuffer(buf, 0, 6)); - auto t1 = dictionary(int16(), dict1); + auto t1 = dictionary(int16(), type); // ["bar", "baz", "qux"] auto dict2 = std::make_shared(type, 3, SliceBuffer(buf, 3, 9)); - auto t2 = dictionary(int16(), dict2); + auto t2 = dictionary(int16(), type); // ["foo", "bar", "baz", "qux"] auto expected_dict = std::make_shared(type, 4, buf); - auto expected = dictionary(int8(), expected_dict); + auto expected = dictionary(int8(), type); - std::shared_ptr dict_type; - ASSERT_OK( - DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get()}, &dict_type)); - ASSERT_TRUE(dict_type->Equals(expected)); + std::shared_ptr out_type; + std::shared_ptr out_dict; + ASSERT_OK(DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get()}, + {dict1.get(), dict2.get()}, &out_type, &out_dict)); + ASSERT_TRUE(out_type->Equals(*expected)); + ASSERT_TRUE(out_dict->Equals(*expected_dict)); std::vector> transpose_maps; - ASSERT_OK(DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get()}, &dict_type, + ASSERT_OK(DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get()}, + {dict1.get(), dict2.get()}, &out_type, &out_dict, &transpose_maps)); - ASSERT_TRUE(dict_type->Equals(expected)); + ASSERT_TRUE(out_type->Equals(*expected)); + ASSERT_TRUE(out_dict->Equals(*expected_dict)); ASSERT_EQ(transpose_maps.size(), 2); ASSERT_EQ(transpose_maps[0], std::vector({0, 1})); ASSERT_EQ(transpose_maps[1], std::vector({1, 2, 3})); @@ -733,7 +778,7 @@ TEST(TestDictionaryType, UnifyLarge) { } ASSERT_OK(builder.Finish(&dict1)); ASSERT_EQ(dict1->length(), 120); - auto t1 = dictionary(int8(), dict1); + auto t1 = dictionary(int8(), int32()); ASSERT_OK(builder.Reserve(30)); for (int32_t i = 110; i < 140; ++i) { @@ -741,7 +786,7 @@ TEST(TestDictionaryType, UnifyLarge) { } ASSERT_OK(builder.Finish(&dict2)); ASSERT_EQ(dict2->length(), 30); - auto t2 = dictionary(int8(), dict2); + auto t2 = dictionary(int8(), int32()); ASSERT_OK(builder.Reserve(140)); for (int32_t i = 0; i < 140; ++i) { @@ -749,13 +794,16 @@ TEST(TestDictionaryType, UnifyLarge) { } ASSERT_OK(builder.Finish(&expected_dict)); ASSERT_EQ(expected_dict->length(), 140); - // int8 would be too narrow to hold all possible index values - auto expected = dictionary(int16(), expected_dict); - std::shared_ptr dict_type; - ASSERT_OK( - DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get()}, &dict_type)); - ASSERT_TRUE(dict_type->Equals(expected)); + // int8 would be too narrow to hold all possible index values + auto expected = dictionary(int16(), int32()); + + std::shared_ptr out_type; + std::shared_ptr out_dict; + ASSERT_OK(DictionaryType::Unify(default_memory_pool(), {t1.get(), t2.get()}, + {dict1.get(), dict2.get()}, &out_type, &out_dict)); + ASSERT_TRUE(out_type->Equals(*expected)); + ASSERT_TRUE(out_dict->Equals(*expected_dict)); } TEST(TypesTest, TestDecimal128Small) { diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 58b8cb3d15084..67f07ea45c78d 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -58,17 +58,21 @@ std::vector> Field::Flatten() const { std::vector> flattened; if (type_->id() == Type::STRUCT) { for (const auto& child : type_->children()) { - auto flattened_child = std::make_shared(*child); + auto flattened_child = child->Copy(); flattened.push_back(flattened_child); flattened_child->name_.insert(0, name() + "."); flattened_child->nullable_ |= nullable_; } } else { - flattened.push_back(std::make_shared(*this)); + flattened.push_back(this->Copy()); } return flattened; } +std::shared_ptr Field::Copy() const { + return ::arrow::field(name_, type_, nullable_, metadata_); +} + bool Field::Equals(const Field& other, bool check_metadata) const { if (this == &other) { return true; @@ -149,7 +153,7 @@ std::string FixedSizeBinaryType::ToString() const { // ---------------------------------------------------------------------- // Date types -DateType::DateType(Type::type type_id) : FixedWidthType(type_id) {} +DateType::DateType(Type::type type_id) : TemporalType(type_id) {} Date32Type::Date32Type() : DateType(Type::DATE32) {} @@ -163,7 +167,7 @@ std::string Date32Type::ToString() const { return std::string("date32[day]"); } // Time types TimeType::TimeType(Type::type type_id, TimeUnit::type unit) - : FixedWidthType(type_id), unit_(unit) {} + : TemporalType(type_id), unit_(unit) {} Time32Type::Time32Type(TimeUnit::type unit) : TimeType(Type::TIME32, unit) { DCHECK(unit == TimeUnit::SECOND || unit == TimeUnit::MILLI) @@ -334,13 +338,17 @@ Decimal128Type::Decimal128Type(int32_t precision, int32_t scale) } // ---------------------------------------------------------------------- -// DictionaryType +// Dictionary-encoded type + +int DictionaryType::bit_width() const { + return checked_cast(*index_type_).bit_width(); +} DictionaryType::DictionaryType(const std::shared_ptr& index_type, - const std::shared_ptr& dictionary, bool ordered) + const std::shared_ptr& value_type, bool ordered) : FixedWidthType(Type::DICTIONARY), index_type_(index_type), - dictionary_(dictionary), + value_type_(value_type), ordered_(ordered) { #ifndef NDEBUG const auto& int_type = checked_cast(*index_type); @@ -348,15 +356,9 @@ DictionaryType::DictionaryType(const std::shared_ptr& index_type, #endif } -int DictionaryType::bit_width() const { - return checked_cast(*index_type_).bit_width(); -} - -std::shared_ptr DictionaryType::dictionary() const { return dictionary_; } - std::string DictionaryType::ToString() const { std::stringstream ss; - ss << "dictionarytype()->ToString() + ss << this->name() << "ToString() << ", indices=" << index_type_->ToString() << ", ordered=" << ordered_ << ">"; return ss.str(); } @@ -632,9 +634,9 @@ std::shared_ptr union_(const std::vector>& chil } std::shared_ptr dictionary(const std::shared_ptr& index_type, - const std::shared_ptr& dict_values, + const std::shared_ptr& dict_type, bool ordered) { - return std::make_shared(index_type, dict_values, ordered); + return std::make_shared(index_type, dict_type, ordered); } std::shared_ptr field(const std::string& name, diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index dff2bbe84694d..75ee674ec1bc3 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -129,7 +129,10 @@ struct Type { /// Unions of logical types UNION, - /// Dictionary aka Category type + /// Dictionary-encoded type, also called "categorical" or "factor" + /// in other programming languages. Holds the dictionary value + /// type but not the dictionary itself, which is part of the + /// ArrayData struct DICTIONARY, /// Map, a repeated struct logical type @@ -292,6 +295,8 @@ class ARROW_EXPORT Field { /// \brief Return whether the field is nullable bool nullable() const { return nullable_; } + std::shared_ptr Copy() const; + private: // Field name std::string name_; @@ -304,6 +309,8 @@ class ARROW_EXPORT Field { // The field's metadata, if any std::shared_ptr metadata_; + + ARROW_DISALLOW_COPY_AND_ASSIGN(Field); }; namespace detail { @@ -628,8 +635,14 @@ class ARROW_EXPORT UnionType : public NestedType { enum class DateUnit : char { DAY = 0, MILLI = 1 }; +/// \brief Base type for all date and time types +class ARROW_EXPORT TemporalType : public FixedWidthType { + public: + using FixedWidthType::FixedWidthType; +}; + /// \brief Base type class for date data -class ARROW_EXPORT DateType : public FixedWidthType { +class ARROW_EXPORT DateType : public TemporalType { public: virtual DateUnit unit() const = 0; @@ -697,7 +710,7 @@ static inline std::ostream& operator<<(std::ostream& os, TimeUnit::type unit) { } /// Base type class for time data -class ARROW_EXPORT TimeType : public FixedWidthType, public ParametricType { +class ARROW_EXPORT TimeType : public TemporalType, public ParametricType { public: TimeUnit::type unit() const { return unit_; } @@ -734,7 +747,7 @@ class ARROW_EXPORT Time64Type : public TimeType { std::string name() const override { return "time64"; } }; -class ARROW_EXPORT TimestampType : public FixedWidthType, public ParametricType { +class ARROW_EXPORT TimestampType : public TemporalType, public ParametricType { public: using Unit = TimeUnit; @@ -744,10 +757,10 @@ class ARROW_EXPORT TimestampType : public FixedWidthType, public ParametricType int bit_width() const override { return static_cast(sizeof(int64_t) * CHAR_BIT); } explicit TimestampType(TimeUnit::type unit = TimeUnit::MILLI) - : FixedWidthType(Type::TIMESTAMP), unit_(unit) {} + : TemporalType(Type::TIMESTAMP), unit_(unit) {} explicit TimestampType(TimeUnit::type unit, const std::string& timezone) - : FixedWidthType(Type::TIMESTAMP), unit_(unit), timezone_(timezone) {} + : TemporalType(Type::TIMESTAMP), unit_(unit), timezone_(timezone) {} std::string ToString() const override; std::string name() const override { return "timestamp"; } @@ -760,11 +773,11 @@ class ARROW_EXPORT TimestampType : public FixedWidthType, public ParametricType std::string timezone_; }; -// Holds different types of intervals. -class ARROW_EXPORT IntervalType : public FixedWidthType, public ParametricType { +// Base class for the different kinds of intervals. +class ARROW_EXPORT IntervalType : public TemporalType, public ParametricType { public: enum type { MONTHS, DAY_TIME }; - IntervalType() : FixedWidthType(Type::INTERVAL) {} + IntervalType() : TemporalType(Type::INTERVAL) {} virtual type interval_type() const = 0; virtual ~IntervalType() = default; @@ -783,7 +796,7 @@ class ARROW_EXPORT MonthIntervalType : public IntervalType { int bit_width() const override { return static_cast(sizeof(c_type) * CHAR_BIT); } - MonthIntervalType() {} + MonthIntervalType() : IntervalType() {} std::string ToString() const override { return name(); } std::string name() const override { return "month_interval"; } @@ -806,7 +819,7 @@ class ARROW_EXPORT DayTimeIntervalType : public IntervalType { static constexpr Type::type type_id = Type::INTERVAL; IntervalType::type interval_type() const override { return IntervalType::DAY_TIME; } - DayTimeIntervalType() {} + DayTimeIntervalType() : IntervalType() {} int bit_width() const override { return static_cast(sizeof(c_type) * CHAR_BIT); } @@ -816,7 +829,7 @@ class ARROW_EXPORT DayTimeIntervalType : public IntervalType { // \brief Represents an amount of elapsed time without any relation to a calendar // artifact. -class ARROW_EXPORT DurationType : public FixedWidthType, public ParametricType { +class ARROW_EXPORT DurationType : public TemporalType, public ParametricType { public: using Unit = TimeUnit; @@ -826,7 +839,7 @@ class ARROW_EXPORT DurationType : public FixedWidthType, public ParametricType { int bit_width() const override { return static_cast(sizeof(int64_t) * CHAR_BIT); } explicit DurationType(TimeUnit::type unit = TimeUnit::MILLI) - : FixedWidthType(Type::DURATION), unit_(unit) {} + : TemporalType(Type::DURATION), unit_(unit) {} std::string ToString() const override; std::string name() const override { return "duration"; } @@ -838,48 +851,54 @@ class ARROW_EXPORT DurationType : public FixedWidthType, public ParametricType { }; // ---------------------------------------------------------------------- -// DictionaryType (for categorical or dictionary-encoded data) +// Dictionary type (for representing categorical or dictionary-encoded +// in memory) -/// Concrete type class for dictionary data +/// \brief Dictionary-encoded value type with data-dependent +/// dictionary class ARROW_EXPORT DictionaryType : public FixedWidthType { public: static constexpr Type::type type_id = Type::DICTIONARY; DictionaryType(const std::shared_ptr& index_type, - const std::shared_ptr& dictionary, bool ordered = false); + const std::shared_ptr& value_type, bool ordered = false); + + std::string ToString() const override; + std::string name() const override { return "dictionary"; } int bit_width() const override; std::shared_ptr index_type() const { return index_type_; } - - std::shared_ptr dictionary() const; - - std::string ToString() const override; - std::string name() const override { return "dictionary"; } + std::shared_ptr value_type() const { return value_type_; } bool ordered() const { return ordered_; } - /// \brief Unify several dictionary types + /// \brief Unify dictionaries types /// /// Compute a resulting dictionary that will allow the union of values /// of all input dictionary types. The input types must all have the /// same value type. /// \param[in] pool Memory pool to allocate dictionary values from /// \param[in] types A sequence of input dictionary types + /// \param[in] dictionaries A sequence of input dictionaries + /// corresponding to each type /// \param[out] out_type The unified dictionary type + /// \param[out] out_dictionary The unified dictionary /// \param[out] out_transpose_maps (optionally) A sequence of integer vectors, /// one per input type. Each integer vector represents the transposition /// of input type indices into unified type indices. // XXX Should we return something special (an empty transpose map?) when // the transposition is the identity function? static Status Unify(MemoryPool* pool, const std::vector& types, + const std::vector& dictionaries, std::shared_ptr* out_type, + std::shared_ptr* out_dictionary, std::vector>* out_transpose_maps = NULLPTR); - private: + protected: // Must be an integer type (not currently checked) std::shared_ptr index_type_; - std::shared_ptr dictionary_; + std::shared_ptr value_type_; bool ordered_; }; @@ -1050,9 +1069,15 @@ union_(const std::vector>& children, } /// \brief Create a DictionaryType instance -std::shared_ptr ARROW_EXPORT -dictionary(const std::shared_ptr& index_type, - const std::shared_ptr& values, bool ordered = false); +/// \param[in] index_type the type of the dictionary indices (must be +/// a signed integer) +/// \param[in] dict_type the type of the values in the variable dictionary +/// \param[in] ordered true if the order of the dictionary values has +/// semantic meaning and should be preserved where possible +ARROW_EXPORT +std::shared_ptr dictionary(const std::shared_ptr& index_type, + const std::shared_ptr& dict_type, + bool ordered = false); /// @} diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 13612ba2ea5a4..a0f461f32626f 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -312,8 +312,7 @@ struct TypeTraits { template <> struct TypeTraits { using ArrayType = DictionaryArray; - // TODO(wesm): Not sure what to do about this - // using ScalarType = DictionaryScalar; + using ScalarType = DictionaryScalar; constexpr static bool is_parameter_free = false; }; @@ -328,7 +327,16 @@ struct TypeTraits { // template -using is_number = std::is_base_of; +using is_number_type = std::is_base_of; + +template +using is_integer_type = std::is_base_of; + +template +using is_floating_type = std::is_base_of; + +template +using is_temporal_type = std::is_base_of; template struct has_c_type { @@ -426,7 +434,7 @@ using enable_if_fixed_size_list = typename std::enable_if::value, R>::type; template -using enable_if_number = typename std::enable_if::value, R>::type; +using enable_if_number = typename std::enable_if::value, R>::type; namespace detail { diff --git a/cpp/src/arrow/util/concatenate-test.cc b/cpp/src/arrow/util/concatenate-test.cc index 8d9e9d6d62d61..0a4d851ceac39 100644 --- a/cpp/src/arrow/util/concatenate-test.cc +++ b/cpp/src/arrow/util/concatenate-test.cc @@ -167,8 +167,9 @@ TEST_F(ConcatenateTest, StructType) { TEST_F(ConcatenateTest, DictionaryType) { Check([this](int32_t size, double null_probability, std::shared_ptr* out) { auto indices = this->GeneratePrimitive(size, null_probability); - auto type = dictionary(int32(), this->GeneratePrimitive(128, 0)); - *out = std::make_shared(type, indices); + auto dict = this->GeneratePrimitive(128, 0); + auto type = dictionary(int32(), dict->type()); + *out = std::make_shared(type, indices, dict); }); } diff --git a/cpp/src/arrow/util/concatenate.cc b/cpp/src/arrow/util/concatenate.cc index 73a6c4920dc0f..9a77501ac6c85 100644 --- a/cpp/src/arrow/util/concatenate.cc +++ b/cpp/src/arrow/util/concatenate.cc @@ -210,7 +210,24 @@ class ConcatenateImpl { Status Visit(const DictionaryType& d) { auto fixed = internal::checked_cast(d.index_type().get()); - return ConcatenateBuffers(Buffers(1, *fixed), pool_, &out_.buffers[1]); + + // Two cases: all the dictionaries are the same, or unification is + // required + bool dictionaries_same = true; + const Array& dictionary0 = *in_[0].dictionary; + for (size_t i = 1; i < in_.size(); ++i) { + if (!in_[i].dictionary->Equals(dictionary0)) { + dictionaries_same = false; + break; + } + } + + if (dictionaries_same) { + out_.dictionary = in_[0].dictionary; + return ConcatenateBuffers(Buffers(1, *fixed), pool_, &out_.buffers[1]); + } else { + return Status::NotImplemented("Concat with dictionary unification NYI"); + } } Status Visit(const UnionType& u) { @@ -313,7 +330,7 @@ Status Concatenate(const ArrayVector& arrays, MemoryPool* pool, *arrays[0]->type(), " and ", *arrays[i]->type(), " were encountered."); } - data[i] = ArrayData(*arrays[i]->data()); + data[i] = *arrays[i]->data(); } ArrayData out_data; diff --git a/cpp/src/parquet/arrow/arrow-reader-writer-test.cc b/cpp/src/parquet/arrow/arrow-reader-writer-test.cc index c8c644252572f..d9fd2d3df9dd9 100644 --- a/cpp/src/parquet/arrow/arrow-reader-writer-test.cc +++ b/cpp/src/parquet/arrow/arrow-reader-writer-test.cc @@ -131,7 +131,7 @@ LogicalType::type get_logical_type(const ::DataType& type) { case ArrowId::DICTIONARY: { const ::arrow::DictionaryType& dict_type = static_cast(type); - return get_logical_type(*dict_type.dictionary()->type()); + return get_logical_type(*dict_type.value_type()); } case ArrowId::DECIMAL: return LogicalType::DECIMAL; @@ -180,7 +180,7 @@ ParquetType::type get_physical_type(const ::DataType& type) { case ArrowId::DICTIONARY: { const ::arrow::DictionaryType& dict_type = static_cast(type); - return get_physical_type(*dict_type.dictionary()->type()); + return get_physical_type(*dict_type.value_type()); } default: break; @@ -406,7 +406,7 @@ static std::shared_ptr MakeSimpleSchema(const ::DataType& type, switch (type.id()) { case ::arrow::Type::DICTIONARY: { const auto& dict_type = static_cast(type); - const ::DataType& values_type = *dict_type.dictionary()->type(); + const ::DataType& values_type = *dict_type.value_type(); switch (values_type.id()) { case ::arrow::Type::FIXED_SIZE_BINARY: byte_width = @@ -1077,13 +1077,14 @@ TEST_F(TestNullParquetIO, NullListColumn) { } TEST_F(TestNullParquetIO, NullDictionaryColumn) { - std::shared_ptr values = std::make_shared<::arrow::NullArray>(0); std::shared_ptr indices = std::make_shared<::arrow::Int8Array>(SMALL_SIZE, nullptr, nullptr, SMALL_SIZE); std::shared_ptr<::arrow::DictionaryType> dict_type = - std::make_shared<::arrow::DictionaryType>(::arrow::int8(), values); + std::make_shared<::arrow::DictionaryType>(::arrow::int8(), ::arrow::null()); + + std::shared_ptr dict = std::make_shared<::arrow::NullArray>(0); std::shared_ptr dict_values = - std::make_shared<::arrow::DictionaryArray>(dict_type, indices); + std::make_shared<::arrow::DictionaryArray>(dict_type, indices, dict); std::shared_ptr table = MakeSimpleTable(dict_values, true); this->sink_ = std::make_shared(); ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), this->sink_, @@ -1897,7 +1898,9 @@ TEST(TestArrowReadWrite, DictionaryColumnChunkedWrite) { std::shared_ptr dict_values; ArrayFromVector<::arrow::StringType, std::string>(values, &dict_values); - auto dict_type = ::arrow::dictionary(::arrow::int32(), dict_values); + auto value_type = ::arrow::utf8(); + auto dict_type = ::arrow::dictionary(::arrow::int32(), value_type); + auto f0 = field("dictionary", dict_type); std::vector> fields; fields.emplace_back(f0); @@ -1907,8 +1910,8 @@ TEST(TestArrowReadWrite, DictionaryColumnChunkedWrite) { ArrayFromVector<::arrow::Int32Type, int32_t>({0, 1, 0, 2, 1}, &f0_values); ArrayFromVector<::arrow::Int32Type, int32_t>({2, 0, 1, 0, 2}, &f1_values); ::arrow::ArrayVector dict_arrays = { - std::make_shared<::arrow::DictionaryArray>(dict_type, f0_values), - std::make_shared<::arrow::DictionaryArray>(dict_type, f1_values)}; + std::make_shared<::arrow::DictionaryArray>(dict_type, f0_values, dict_values), + std::make_shared<::arrow::DictionaryArray>(dict_type, f1_values, dict_values)}; std::vector> columns; auto column = MakeColumn("dictionary", dict_arrays, true); diff --git a/cpp/src/parquet/arrow/arrow-schema-test.cc b/cpp/src/parquet/arrow/arrow-schema-test.cc index b399eb46d6839..b806782a09dca 100644 --- a/cpp/src/parquet/arrow/arrow-schema-test.cc +++ b/cpp/src/parquet/arrow/arrow-schema-test.cc @@ -721,53 +721,43 @@ TEST_F(TestConvertArrowSchema, ParquetFlatPrimitivesAsDictionaries) { parquet_fields.push_back( PrimitiveNode::Make("int32", Repetition::REQUIRED, ParquetType::INT32)); - ArrayFromVector<::arrow::Int32Type, int32_t>(std::vector(), &dict); - arrow_fields.push_back( - ::arrow::field("int32", ::arrow::dictionary(::arrow::int8(), dict), false)); + arrow_fields.push_back(::arrow::field( + "int32", ::arrow::dictionary(::arrow::int8(), ::arrow::int32()), false)); parquet_fields.push_back( PrimitiveNode::Make("int64", Repetition::REQUIRED, ParquetType::INT64)); - ArrayFromVector<::arrow::Int64Type, int64_t>(std::vector(), &dict); arrow_fields.push_back(std::make_shared( - "int64", ::arrow::dictionary(::arrow::int8(), dict), false)); + "int64", ::arrow::dictionary(::arrow::int8(), ::arrow::int64()), false)); parquet_fields.push_back(PrimitiveNode::Make("date", Repetition::REQUIRED, ParquetType::INT32, LogicalType::DATE)); - ArrayFromVector<::arrow::Date32Type, int32_t>(std::vector(), &dict); - arrow_fields.push_back( - std::make_shared("date", ::arrow::dictionary(::arrow::int8(), dict), false)); + arrow_fields.push_back(std::make_shared( + "date", ::arrow::dictionary(::arrow::int8(), ::arrow::date32()), false)); parquet_fields.push_back(PrimitiveNode::Make("date64", Repetition::REQUIRED, ParquetType::INT32, LogicalType::DATE)); - ArrayFromVector<::arrow::Date64Type, int64_t>(std::vector(), &dict); arrow_fields.push_back(std::make_shared( - "date64", ::arrow::dictionary(::arrow::int8(), dict), false)); + "date64", ::arrow::dictionary(::arrow::int8(), ::arrow::date64()), false)); parquet_fields.push_back( PrimitiveNode::Make("float", Repetition::OPTIONAL, ParquetType::FLOAT)); - ArrayFromVector<::arrow::FloatType, float>(std::vector(), &dict); - arrow_fields.push_back( - std::make_shared("float", ::arrow::dictionary(::arrow::int8(), dict))); + arrow_fields.push_back(std::make_shared( + "float", ::arrow::dictionary(::arrow::int8(), ::arrow::float32()))); parquet_fields.push_back( PrimitiveNode::Make("double", Repetition::OPTIONAL, ParquetType::DOUBLE)); - ArrayFromVector<::arrow::DoubleType, double>(std::vector(), &dict); - arrow_fields.push_back( - std::make_shared("double", ::arrow::dictionary(::arrow::int8(), dict))); + arrow_fields.push_back(std::make_shared( + "double", ::arrow::dictionary(::arrow::int8(), ::arrow::float64()))); parquet_fields.push_back(PrimitiveNode::Make( "string", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY, LogicalType::UTF8)); - ::arrow::StringBuilder string_builder(::arrow::default_memory_pool()); - ASSERT_OK(string_builder.Finish(&dict)); - arrow_fields.push_back( - std::make_shared("string", ::arrow::dictionary(::arrow::int8(), dict))); + arrow_fields.push_back(std::make_shared( + "string", ::arrow::dictionary(::arrow::int8(), ::arrow::utf8()))); parquet_fields.push_back(PrimitiveNode::Make( "binary", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY, LogicalType::NONE)); - ::arrow::BinaryBuilder binary_builder(::arrow::default_memory_pool()); - ASSERT_OK(binary_builder.Finish(&dict)); - arrow_fields.push_back( - std::make_shared("binary", ::arrow::dictionary(::arrow::int8(), dict))); + arrow_fields.push_back(std::make_shared( + "binary", ::arrow::dictionary(::arrow::int8(), ::arrow::binary()))); ASSERT_OK(ConvertSchema(arrow_fields)); diff --git a/cpp/src/parquet/arrow/schema.cc b/cpp/src/parquet/arrow/schema.cc index 1b03398f987e3..a6ee8f485a6ac 100644 --- a/cpp/src/parquet/arrow/schema.cc +++ b/cpp/src/parquet/arrow/schema.cc @@ -614,9 +614,8 @@ Status FieldToNode(const std::shared_ptr& field, // the encoding, not the schema level. const ::arrow::DictionaryType& dict_type = static_cast(*field->type()); - std::shared_ptr<::arrow::Field> unpacked_field = - ::arrow::field(field->name(), dict_type.dictionary()->type(), field->nullable(), - field->metadata()); + std::shared_ptr<::arrow::Field> unpacked_field = ::arrow::field( + field->name(), dict_type.value_type(), field->nullable(), field->metadata()); return FieldToNode(unpacked_field, properties, arrow_properties, out); } default: { diff --git a/cpp/src/parquet/arrow/writer.cc b/cpp/src/parquet/arrow/writer.cc index aefbdaea52c84..29e00fefdd6cd 100644 --- a/cpp/src/parquet/arrow/writer.cc +++ b/cpp/src/parquet/arrow/writer.cc @@ -1023,7 +1023,7 @@ class FileWriter::Impl { // TODO(ARROW-1648): Remove this special handling once we require an Arrow // version that has this fixed. - if (dict_type.dictionary()->type()->id() == ::arrow::Type::NA) { + if (dict_type.value_type()->id() == ::arrow::Type::NA) { auto null_array = std::make_shared<::arrow::NullArray>(data->length()); return WriteColumnChunk(*null_array); } @@ -1031,8 +1031,8 @@ class FileWriter::Impl { FunctionContext ctx(this->memory_pool()); ::arrow::compute::Datum cast_input(data); ::arrow::compute::Datum cast_output; - RETURN_NOT_OK(Cast(&ctx, cast_input, dict_type.dictionary()->type(), CastOptions(), - &cast_output)); + RETURN_NOT_OK( + Cast(&ctx, cast_input, dict_type.value_type(), CastOptions(), &cast_output)); return WriteColumnChunk(cast_output.chunked_array(), offset, size); } diff --git a/integration/integration_test.py b/integration/integration_test.py index 4c3a354d31daa..9aafb6cd73755 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -973,22 +973,25 @@ def generate_nested_case(): def generate_dictionary_case(): + dict_type0 = StringType('dictionary1') dict_type1 = StringType('dictionary1') dict_type2 = get_field('dictionary2', 'int64') - dict1 = Dictionary(0, dict_type1, - dict_type1.generate_column(10, name='DICT0')) - dict2 = Dictionary(1, dict_type2, - dict_type2.generate_column(50, name='DICT1')) + dict0 = Dictionary(0, dict_type0, + dict_type0.generate_column(10, name='DICT0')) + dict1 = Dictionary(1, dict_type1, + dict_type1.generate_column(5, name='DICT1')) + dict2 = Dictionary(2, dict_type2, + dict_type2.generate_column(50, name='DICT2')) fields = [ - DictionaryType('dict1_0', get_field('', 'int8'), dict1), - DictionaryType('dict1_1', get_field('', 'int32'), dict1), - DictionaryType('dict2_0', get_field('', 'int16'), dict2) + DictionaryType('dict0', get_field('', 'int8'), dict0), + DictionaryType('dict1', get_field('', 'int32'), dict1), + DictionaryType('dict2', get_field('', 'int16'), dict2) ] batch_sizes = [7, 10] return _generate_file("dictionary", fields, batch_sizes, - dictionaries=[dict1, dict2]) + dictionaries=[dict0, dict1, dict2]) def get_generated_json_files(tempdir=None, flight=False): @@ -1144,8 +1147,9 @@ class Tester(object): FLIGHT_CLIENT = False FLIGHT_PORT = 31337 - def __init__(self, debug=False): - self.debug = debug + def __init__(self, args): + self.args = args + self.debug = args.debug def json_to_file(self, json_path, arrow_path): raise NotImplementedError @@ -1416,20 +1420,28 @@ def get_static_json_files(): for p in glob.glob(glob_pattern)] -def run_all_tests(run_flight=False, debug=False, tempdir=None): - testers = [CPPTester(debug=debug), - JavaTester(debug=debug), - JSTester(debug=debug)] +def run_all_tests(args): + testers = [] + + if args.enable_cpp: + testers.append(CPPTester(args)) + + if args.enable_java: + testers.append(JavaTester(args)) + + if args.enable_js: + testers.append(JSTester(args)) + static_json_files = get_static_json_files() - generated_json_files = get_generated_json_files(tempdir=tempdir, - flight=run_flight) + generated_json_files = get_generated_json_files(tempdir=args.tempdir, + flight=args.run_flight) json_files = static_json_files + generated_json_files runner = IntegrationRunner(json_files, testers, - tempdir=tempdir, debug=debug) + tempdir=args.tempdir, debug=args.debug) failures = [] failures.extend(runner.run()) - if run_flight: + if args.run_flight: failures.extend(runner.run_flight()) fail_count = 0 @@ -1463,6 +1475,17 @@ def write_js_test_json(directory): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Arrow integration test CLI') + + parser.add_argument('--enable-c++', dest='enable_cpp', + action='store', type=int, default=1, + help='Include C++ in integration tests') + parser.add_argument('--enable-java', dest='enable_java', + action='store', type=int, default=1, + help='Include Java in integration tests') + parser.add_argument('--enable-js', dest='enable_js', + action='store', type=int, default=1, + help='Include JavaScript in integration tests') + parser.add_argument('--write_generated_json', dest='generated_json_path', action='store', default=False, help='Generate test JSON') @@ -1485,5 +1508,4 @@ def write_js_test_json(directory): raise write_js_test_json(args.generated_json_path) else: - run_all_tests(run_flight=args.run_flight, - debug=args.debug, tempdir=args.tempdir) + run_all_tests(args) diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index e74a9a168fb01..17916df8a0bb1 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -56,6 +56,7 @@ def parse_git(root, **kwargs): DataType, DictionaryType, ListType, StructType, UnionType, TimestampType, Time32Type, Time64Type, FixedSizeBinaryType, Decimal128Type, + DictionaryMemo, Field, Schema, schema, diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 271d13566e0c3..474e0076a1498 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -265,7 +265,9 @@ cdef class FlightInfo: """The schema of the data in this flight.""" cdef: shared_ptr[CSchema] schema - check_status(self.info.get().GetSchema(&schema)) + CDictionaryMemo dummy_memo + + check_status(self.info.get().GetSchema(&dummy_memo, &schema)) return pyarrow_wrap_schema(schema) @property diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index c8be6dd0be877..957c87549f944 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1303,15 +1303,18 @@ cdef class DictionaryArray(Array): cdef c_bool c_ordered = ordered c_type.reset(new CDictionaryType(_indices.type.sp_type, - _dictionary.sp_array, c_ordered)) + _dictionary.sp_array.get().type(), + c_ordered)) if safe: with nogil: check_status( CDictionaryArray.FromArrays(c_type, _indices.sp_array, + _dictionary.sp_array, &c_result)) else: - c_result.reset(new CDictionaryArray(c_type, _indices.sp_array)) + c_result.reset(new CDictionaryArray(c_type, _indices.sp_array, + _dictionary.sp_array)) return pyarrow_wrap_array(c_result) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 46905e0799a88..11141a15acf60 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -149,11 +149,13 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CDictionaryArray" arrow::DictionaryArray"(CArray): CDictionaryArray(const shared_ptr[CDataType]& type, - const shared_ptr[CArray]& indices) + const shared_ptr[CArray]& indices, + const shared_ptr[CArray]& dictionary) @staticmethod CStatus FromArrays(const shared_ptr[CDataType]& type, const shared_ptr[CArray]& indices, + const shared_ptr[CArray]& dictionary, shared_ptr[CArray]* out) shared_ptr[CArray] indices() @@ -180,11 +182,11 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CDictionaryType" arrow::DictionaryType"(CFixedWidthType): CDictionaryType(const shared_ptr[CDataType]& index_type, - const shared_ptr[CArray]& dictionary, + const shared_ptr[CDataType]& value_type, c_bool ordered) shared_ptr[CDataType] index_type() - shared_ptr[CArray] dictionary() + shared_ptr[CDataType] value_type() c_bool ordered() shared_ptr[CDataType] ctimestamp" arrow::timestamp"(TimeUnit unit) @@ -860,6 +862,9 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: MessageType_V3" arrow::ipc::MetadataVersion::V3" MessageType_V4" arrow::ipc::MetadataVersion::V4" + cdef cppclass CDictionaryMemo" arrow::ipc::DictionaryMemo": + pass + cdef cppclass CMessage" arrow::ipc::Message": CStatus Open(const shared_ptr[CBuffer]& metadata, const shared_ptr[CBuffer]& body, @@ -942,18 +947,22 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: CStatus ReadRecordBatch(const CMessage& message, const shared_ptr[CSchema]& schema, + CDictionaryMemo* dictionary_memo, shared_ptr[CRecordBatch]* out) - CStatus SerializeSchema(const CSchema& schema, CMemoryPool* pool, - shared_ptr[CBuffer]* out) + CStatus SerializeSchema(const CSchema& schema, + CDictionaryMemo* dictionary_memo, + CMemoryPool* pool, shared_ptr[CBuffer]* out) CStatus SerializeRecordBatch(const CRecordBatch& schema, CMemoryPool* pool, shared_ptr[CBuffer]* out) - CStatus ReadSchema(InputStream* stream, shared_ptr[CSchema]* out) + CStatus ReadSchema(InputStream* stream, CDictionaryMemo* dictionary_memo, + shared_ptr[CSchema]* out) CStatus ReadRecordBatch(const shared_ptr[CSchema]& schema, + CDictionaryMemo* dictionary_memo, InputStream* stream, shared_ptr[CRecordBatch]* out) diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 2d083e3de3e7d..bac7de54c89cd 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -94,7 +94,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CFlightInfo(CFlightInfo info) uint64_t total_records() uint64_t total_bytes() - CStatus GetSchema(shared_ptr[CSchema]* out) + CStatus GetSchema(CDictionaryMemo* memo, shared_ptr[CSchema]* out) CFlightDescriptor& descriptor() const vector[CFlightEndpoint]& endpoints() diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index e8573022bb4ca..cfd1cd71ac216 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -509,13 +509,16 @@ def read_message(source): return result -def read_schema(obj): +def read_schema(obj, DictionaryMemo dictionary_memo=None): """ Read Schema from message or buffer Parameters ---------- obj : buffer or Message + dictionary_memo : DictionaryMemo, optional + Needed to be able to reconstruct dictionary-encoded fields + with read_record_batch Returns ------- @@ -524,19 +527,27 @@ def read_schema(obj): cdef: shared_ptr[CSchema] result shared_ptr[RandomAccessFile] cpp_file + CDictionaryMemo temp_memo + CDictionaryMemo* arg_dict_memo if isinstance(obj, Message): raise NotImplementedError(type(obj)) get_reader(obj, True, &cpp_file) + if dictionary_memo is not None: + arg_dict_memo = &dictionary_memo.memo + else: + arg_dict_memo = &temp_memo + with nogil: - check_status(ReadSchema(cpp_file.get(), &result)) + check_status(ReadSchema(cpp_file.get(), arg_dict_memo, &result)) return pyarrow_wrap_schema(result) -def read_record_batch(obj, Schema schema): +def read_record_batch(obj, Schema schema, + DictionaryMemo dictionary_memo=None): """ Read RecordBatch from message, given a known schema @@ -544,6 +555,9 @@ def read_record_batch(obj, Schema schema): ---------- obj : Message or Buffer-like schema : Schema + dictionary_memo : DictionaryMemo, optional + If message contains dictionaries, must pass a populated + DictionaryMemo Returns ------- @@ -552,14 +566,22 @@ def read_record_batch(obj, Schema schema): cdef: shared_ptr[CRecordBatch] result Message message + CDictionaryMemo temp_memo + CDictionaryMemo* arg_dict_memo if isinstance(obj, Message): message = obj else: message = read_message(obj) + if dictionary_memo is not None: + arg_dict_memo = &dictionary_memo.memo + else: + arg_dict_memo = &temp_memo + with nogil: check_status(ReadRecordBatch(deref(message.message.get()), - schema.sp_schema, &result)) + schema.sp_schema, + arg_dict_memo, &result)) return pyarrow_wrap_batch(result) diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 49f7b3f60f88f..cb5c73290f366 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -69,6 +69,11 @@ cdef class StructType(DataType): cdef Field field_by_name(self, name) +cdef class DictionaryMemo: + cdef: + CDictionaryMemo memo + + cdef class DictionaryType(DataType): cdef: const CDictionaryType* dict_type diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index b70dbcaf44651..f59301c1ca1a0 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -697,7 +697,7 @@ def test_cast_from_null(): _check_cast_case((in_data, in_type, in_data, out_type)) out_types = [ - pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])), + pa.dictionary(pa.int32(), pa.string()), pa.union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE), pa.union([pa.field('a', pa.binary(10)), diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 499af7b01d5b2..655dd38d1f02e 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -92,8 +92,7 @@ def test_take_indices_types(): arr.take(indices) -@pytest.mark.parametrize('ordered', [ - False, pytest.param(True, marks=pytest.mark.xfail(strict=True))]) +@pytest.mark.parametrize('ordered', [False, True]) def test_take_dictionary(ordered): arr = pa.DictionaryArray.from_arrays([0, 1, 2, 0, 1, 2], ['a', 'b', 'c'], ordered=ordered) diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index 467e26d3cd0db..3eb2cdc543b92 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -219,7 +219,7 @@ def test_stream_categorical_roundtrip(stream_fixture): }) batch = pa.RecordBatch.from_pandas(df) writer = stream_fixture._get_writer(stream_fixture.sink, batch.schema) - writer.write_batch(pa.RecordBatch.from_pandas(df)) + writer.write_batch(batch) writer.close() table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source())) diff --git a/python/pyarrow/tests/test_schema.py b/python/pyarrow/tests/test_schema.py index d922443bedd0f..fadb901977c21 100644 --- a/python/pyarrow/tests/test_schema.py +++ b/python/pyarrow/tests/test_schema.py @@ -415,9 +415,8 @@ def test_schema_negative_indexing(): def test_schema_repr_with_dictionaries(): - dct = pa.array(['foo', 'bar', 'baz'], type=pa.string()) fields = [ - pa.field('one', pa.dictionary(pa.int16(), dct)), + pa.field('one', pa.dictionary(pa.int16(), pa.string())), pa.field('two', pa.int32()) ] sch = pa.schema(fields) @@ -425,12 +424,6 @@ def test_schema_repr_with_dictionaries(): expected = ( """\ one: dictionary - dictionary: - [ - "foo", - "bar", - "baz" - ] two: int32""") assert repr(sch) == expected diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index 251f4f6a87806..7d9abf7956b21 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -62,7 +62,7 @@ def get_many_types(): pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE), pa.union([pa.field('a', pa.binary(10), nullable=False), pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE), - pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])) + pa.dictionary(pa.int32(), pa.string()) ) @@ -113,9 +113,7 @@ def test_is_list(): def test_is_dictionary(): - assert types.is_dictionary( - pa.dictionary(pa.int32(), - pa.array(['a', 'b', 'c']))) + assert types.is_dictionary(pa.dictionary(pa.int32(), pa.string())) assert not types.is_dictionary(pa.int32()) @@ -308,23 +306,20 @@ def check_fields(ty, fields): def test_dictionary_type(): - ty0 = pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])) + ty0 = pa.dictionary(pa.int32(), pa.string()) assert ty0.index_type == pa.int32() - assert isinstance(ty0.dictionary, pa.Array) - assert ty0.dictionary.to_pylist() == ['a', 'b', 'c'] + assert ty0.value_type == pa.string() assert ty0.ordered is False - ty1 = pa.dictionary(pa.int8(), pa.array([1.0, 2.0]), ordered=True) + ty1 = pa.dictionary(pa.int8(), pa.float64(), ordered=True) assert ty1.index_type == pa.int8() - assert isinstance(ty0.dictionary, pa.Array) - assert ty1.dictionary.to_pylist() == [1.0, 2.0] + assert ty1.value_type == pa.float64() assert ty1.ordered is True # construct from non-arrow objects - ty2 = pa.dictionary('int8', ['a', 'b', 'c', 'd']) + ty2 = pa.dictionary('int8', 'string') assert ty2.index_type == pa.int8() - assert isinstance(ty2.dictionary, pa.Array) - assert ty2.dictionary.to_pylist() == ['a', 'b', 'c', 'd'] + assert ty2.value_type == pa.string() assert ty2.ordered is False diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 148f5823d7151..24feec758e835 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -176,6 +176,13 @@ cdef class DataType: raise NotImplementedError(str(self)) +cdef class DictionaryMemo: + """ + Tracking container for dictionary-encoded fields + """ + pass + + cdef class DictionaryType(DataType): """ Concrete class for dictionary data types. @@ -186,7 +193,7 @@ cdef class DictionaryType(DataType): self.dict_type = type.get() def __reduce__(self): - return dictionary, (self.index_type, self.dictionary, self.ordered) + return dictionary, (self.index_type, self.value_type, self.ordered) @property def ordered(self): @@ -204,11 +211,12 @@ cdef class DictionaryType(DataType): return pyarrow_wrap_data_type(self.dict_type.index_type()) @property - def dictionary(self): + def value_type(self): """ - The dictionary array, mapping dictionary indices to values. + The dictionary value type. The dictionary values are found in an + instance of DictionaryArray """ - return pyarrow_wrap_array(self.dict_type.dictionary()) + return pyarrow_wrap_data_type(self.dict_type.value_type()) cdef class ListType(DataType): @@ -893,7 +901,7 @@ cdef class Schema: return pyarrow_wrap_schema(c_schema) - def serialize(self, memory_pool=None): + def serialize(self, DictionaryMemo dictionary_memo=None, memory_pool=None): """ Write Schema to Buffer as encapsulated IPC message @@ -901,6 +909,10 @@ cdef class Schema: ---------- memory_pool : MemoryPool, default None Uses default memory pool if not specified + dictionary_memo : DictionaryMemo, optional + If schema contains dictionaries, must pass a + DictionaryMemo to be able to deserialize RecordBatch + objects Returns ------- @@ -909,9 +921,16 @@ cdef class Schema: cdef: shared_ptr[CBuffer] buffer CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + CDictionaryMemo temp_memo + CDictionaryMemo* arg_dict_memo + + if dictionary_memo is not None: + arg_dict_memo = &dictionary_memo.memo + else: + arg_dict_memo = &temp_memo with nogil: - check_status(SerializeSchema(deref(self.schema), + check_status(SerializeSchema(deref(self.schema), arg_dict_memo, pool, &buffer)) return pyarrow_wrap_buffer(buffer) @@ -1430,14 +1449,14 @@ cpdef ListType list_(value_type): return out -cpdef DictionaryType dictionary(index_type, dict_values, bint ordered=False): +cpdef DictionaryType dictionary(index_type, value_type, bint ordered=False): """ Dictionary (categorical, or simply encoded) type Parameters ---------- index_type : DataType - dictionary : Array + value_type : DataType ordered : boolean Returns @@ -1446,15 +1465,12 @@ cpdef DictionaryType dictionary(index_type, dict_values, bint ordered=False): """ cdef: DataType _index_type = ensure_type(index_type, allow_none=False) + DataType _value_type = ensure_type(value_type, allow_none=False) DictionaryType out = DictionaryType.__new__(DictionaryType) shared_ptr[CDataType] dict_type - if not isinstance(dict_values, Array): - dict_values = array(dict_values) - dict_type.reset(new CDictionaryType(_index_type.sp_type, - ( dict_values).sp_array, - ordered == 1)) + _value_type.sp_type, ordered == 1)) out.init(dict_type) return out diff --git a/ruby/red-arrow/lib/arrow/dictionary-data-type.rb b/ruby/red-arrow/lib/arrow/dictionary-data-type.rb index e799fdfac799e..c90f4586730d9 100644 --- a/ruby/red-arrow/lib/arrow/dictionary-data-type.rb +++ b/ruby/red-arrow/lib/arrow/dictionary-data-type.rb @@ -22,7 +22,7 @@ class DictionaryDataType # Creates a new {Arrow::DictionaryDataType}. # - # @overload initialize(index_data_type, dictionary, ordered) + # @overload initialize(index_data_type, value_data_type, ordered) # # @param index_data_type [Arrow::DataType, Hash, String, Symbol] # The index data type of the dictionary data type. It must be @@ -39,18 +39,23 @@ class DictionaryDataType # See {Arrow::DataType.resolve} how to specify data type # description. # - # @param dictionary [Arrow::Array] The real values of the - # dictionary data type. + # @param value_data_type [Arrow::DataType, Hash, String, Symbol] + # The value data type of the dictionary data type. + # + # You can specify data type as a description by `Hash`. + # + # See {Arrow::DataType.resolve} how to specify data type + # description. # # @param ordered [Boolean] Whether dictionary contents are # ordered or not. # # @example Create a dictionary data type for {0: "Hello", 1: "World"} # index_data_type = :int8 - # dictionary = Arrow::StringArray.new(["Hello", "World"]) + # value_data_type = :string # ordered = true # Arrow::DictionaryDataType.new(index_data_type, - # dictionary, + # value_data_type, # ordered) # # @overload initialize(description) @@ -74,16 +79,21 @@ class DictionaryDataType # See {Arrow::DataType.resolve} how to specify data type # description. # - # @option description [Arrow::Array] :dictionary The real values - # of the dictionary data type. + # @option description [Arrow::DataType, Hash, String, Symbol] + # :value_data_type + # The value data type of the dictionary data type. + # + # You can specify data type as a description by `Hash`. + # + # See {Arrow::DataType.resolve} how to specify data type + # description. # # @option description [Boolean] :ordered Whether dictionary # contents are ordered or not. # # @example Create a dictionary data type for {0: "Hello", 1: "World"} - # dictionary = Arrow::StringArray.new(["Hello", "World"]) # Arrow::DictionaryDataType.new(index_data_type: :int8, - # dictionary: dictionary, + # value_data_type: :string, # ordered: true) def initialize(*args) n_args = args.size @@ -91,16 +101,17 @@ def initialize(*args) when 1 description = args[0] index_data_type = description[:index_data_type] - dictionary = description[:dictionary] + value_data_type = description[:value_data_type] ordered = description[:ordered] when 3 - index_data_type, dictionary, ordered = args + index_data_type, value_data_type, ordered = args else message = "wrong number of arguments (given, #{n_args}, expected 1 or 3)" raise ArgumentError, message end index_data_type = DataType.resolve(index_data_type) - initialize_raw(index_data_type, dictionary, ordered) + value_data_type = DataType.resolve(value_data_type) + initialize_raw(index_data_type, value_data_type, ordered) end end end diff --git a/ruby/red-arrow/test/test-dictionary-data-type.rb b/ruby/red-arrow/test/test-dictionary-data-type.rb index be9cd6f301035..c5b6dd1bfb5d8 100644 --- a/ruby/red-arrow/test/test-dictionary-data-type.rb +++ b/ruby/red-arrow/test/test-dictionary-data-type.rb @@ -19,21 +19,21 @@ class DictionaryDataTypeTest < Test::Unit::TestCase sub_test_case(".new") do def setup @index_data_type = :int8 - @dictionary = Arrow::StringArray.new(["Hello", "World"]) + @value_data_type = :string @ordered = true end test("ordered arguments") do assert_equal("dictionary", Arrow::DictionaryDataType.new(@index_data_type, - @dictionary, + @value_data_type, @ordered).to_s) end test("description") do assert_equal("dictionary", Arrow::DictionaryDataType.new(index_data_type: @index_data_type, - dictionary: @dictionary, + value_data_type: @value_data_type, ordered: @ordered).to_s) end end