diff --git a/.asf.yaml b/.asf.yaml index d2522ecae0b43..ae2709a8b0acd 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -19,13 +19,13 @@ github: description: "Apache Arrow is a multi-language toolbox for accelerated data interchange and in-memory processing" homepage: https://arrow.apache.org/ collaborators: + - amoeba - anjakefala - benibus - danepitkin - davisusanibar - - felipecrv - js8544 - - amoeba + - vibhatha notifications: commits: commits@arrow.apache.org diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index 3d4fb10b10c39..bd14f1b895bf6 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -101,7 +101,7 @@ jobs: fetch-depth: 0 submodules: recursive - name: Cache Docker Volumes - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # v3.3.1 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: .docker key: ${{ matrix.image }}-${{ hashFiles('cpp/**') }} @@ -214,7 +214,7 @@ jobs: run: | echo "cache-dir=$(ccache --get-config cache_dir)" >> $GITHUB_OUTPUT - name: Cache ccache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ${{ steps.ccache-info.outputs.cache-dir }} key: cpp-ccache-macos-${{ hashFiles('cpp/**') }} @@ -310,7 +310,7 @@ jobs: run: | echo "cache-dir=$(ccache --get-config cache_dir)" >> $GITHUB_OUTPUT - name: Cache ccache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ${{ steps.ccache-info.outputs.cache-dir }} key: cpp-ccache-windows-${{ env.CACHE_VERSION }}-${{ hashFiles('cpp/**') }} @@ -402,7 +402,7 @@ jobs: shell: msys2 {0} run: ci/scripts/msys2_setup.sh cpp - name: Cache ccache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ccache key: cpp-ccache-${{ matrix.msystem_lower}}-${{ hashFiles('cpp/**') }} diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 098b5ff29df5a..e394347e95261 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -45,7 +45,7 @@ jobs: run: | ci/scripts/util_free_space.sh - name: Cache Docker Volumes - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # v3.3.1 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: .docker key: ubuntu-docs-${{ hashFiles('cpp/**') }} diff --git a/.github/workflows/docs_light.yml b/.github/workflows/docs_light.yml index 8d10060c9d8a0..5303531f34350 100644 --- a/.github/workflows/docs_light.yml +++ b/.github/workflows/docs_light.yml @@ -51,7 +51,7 @@ jobs: with: fetch-depth: 0 - name: Cache Docker Volumes - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # v3.3.1 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: .docker key: conda-docs-${{ hashFiles('cpp/**') }} diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 4960b4dbd61e8..adb6fb2b57c75 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -75,7 +75,7 @@ jobs: run: | ci/scripts/util_free_space.sh - name: Cache Docker Volumes - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # v3.3.1 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: .docker key: conda-${{ hashFiles('cpp/**') }} diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index ee4c1b21c37d4..1f1fc1b47a3c8 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -69,7 +69,7 @@ jobs: fetch-depth: 0 submodules: recursive - name: Cache Docker Volumes - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # v3.3.1 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: .docker key: maven-${{ hashFiles('java/**') }} diff --git a/.github/workflows/java_jni.yml b/.github/workflows/java_jni.yml index 9f05a357a11d3..45de57f360a42 100644 --- a/.github/workflows/java_jni.yml +++ b/.github/workflows/java_jni.yml @@ -63,7 +63,7 @@ jobs: run: | ci/scripts/util_free_space.sh - name: Cache Docker Volumes - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # v3.3.1 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: .docker key: java-jni-manylinux-2014-${{ hashFiles('cpp/**', 'java/**') }} @@ -103,7 +103,7 @@ jobs: fetch-depth: 0 submodules: recursive - name: Cache Docker Volumes - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # v3.3.1 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: .docker key: maven-${{ hashFiles('java/**') }} diff --git a/.github/workflows/js.yml b/.github/workflows/js.yml index e2c76a3d1cb24..0d09e30d6eab5 100644 --- a/.github/workflows/js.yml +++ b/.github/workflows/js.yml @@ -91,7 +91,7 @@ jobs: with: fetch-depth: 0 - name: Jest Cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: js/.jest-cache key: js-jest-cache-${{ runner.os }}-${{ hashFiles('js/src/**/*.ts', 'js/test/**/*.ts', 'js/yarn.lock') }} @@ -121,7 +121,7 @@ jobs: with: fetch-depth: 0 - name: Jest Cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: js/.jest-cache key: js-jest-cache-${{ runner.os }}-${{ hashFiles('js/src/**/*.ts', 'js/test/**/*.ts', 'js/yarn.lock') }} diff --git a/.github/workflows/matlab.yml b/.github/workflows/matlab.yml index 6921e12213b5b..512ff2bb929b3 100644 --- a/.github/workflows/matlab.yml +++ b/.github/workflows/matlab.yml @@ -65,7 +65,7 @@ jobs: shell: bash run: echo "cache-dir=$(ccache --get-config cache_dir)" >> $GITHUB_OUTPUT - name: Cache ccache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ${{ steps.ccache-info.outputs.cache-dir }} key: matlab-ccache-ubuntu-${{ hashFiles('cpp/**', 'matlab/**') }} @@ -113,7 +113,7 @@ jobs: shell: bash run: echo "cache-dir=$(ccache --get-config cache_dir)" >> $GITHUB_OUTPUT - name: Cache ccache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ${{ steps.ccache-info.outputs.cache-dir }} key: matlab-ccache-macos-${{ hashFiles('cpp/**', 'matlab/**') }} @@ -155,7 +155,7 @@ jobs: shell: bash run: echo "cache-dir=$(ccache --get-config cache_dir)" >> $GITHUB_OUTPUT - name: Cache ccache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: | ${{ steps.ccache-info.outputs.cache-dir }} diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index d9979da0ee12a..6e3797b29c21e 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -94,7 +94,7 @@ jobs: fetch-depth: 0 submodules: recursive - name: Cache Docker Volumes - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # v3.3.1 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: .docker key: ${{ matrix.cache }}-${{ hashFiles('cpp/**') }} diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index 4fc308a28d4d6..2a801b6040ec8 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -73,7 +73,7 @@ jobs: fetch-depth: 0 submodules: recursive - name: Cache Docker Volumes - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # v3.3.1 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: .docker # As this key is identical on both matrix builds only one will be able to successfully cache, @@ -206,7 +206,7 @@ jobs: ci/scripts/ccache_setup.sh echo "CCACHE_DIR=$(cygpath --absolute --windows ccache)" >> $GITHUB_ENV - name: Cache ccache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ccache key: r-${{ matrix.config.rtools }}-ccache-mingw-${{ matrix.config.arch }}-${{ hashFiles('cpp/src/**/*.cc','cpp/src/**/*.h)') }}-${{ github.run_id }} diff --git a/.github/workflows/r_nightly.yml b/.github/workflows/r_nightly.yml index 27a32d22f90c0..a57a8cddea3c0 100644 --- a/.github/workflows/r_nightly.yml +++ b/.github/workflows/r_nightly.yml @@ -86,7 +86,7 @@ jobs: exit 1 fi - name: Cache Repo - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # v3.3.1 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: repo key: r-nightly-${{ github.run_id }} diff --git a/.github/workflows/ruby.yml b/.github/workflows/ruby.yml index be30865ac7ac6..74d56895f4c34 100644 --- a/.github/workflows/ruby.yml +++ b/.github/workflows/ruby.yml @@ -76,7 +76,7 @@ jobs: fetch-depth: 0 submodules: recursive - name: Cache Docker Volumes - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 # v3.3.1 + uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: path: .docker key: ubuntu-${{ matrix.ubuntu }}-ruby-${{ hashFiles('cpp/**') }} @@ -167,7 +167,7 @@ jobs: run: | echo "cache-dir=$(ccache --get-config cache_dir)" >> $GITHUB_OUTPUT - name: Cache ccache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ${{ steps.ccache-info.outputs.cache-dir }} key: ruby-ccache-macos-${{ hashFiles('cpp/**') }} @@ -252,7 +252,7 @@ jobs: run: | ridk exec bash ci\scripts\msys2_setup.sh ruby - name: Cache ccache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ccache key: ruby-ccache-ucrt${{ matrix.mingw-n-bits }}-${{ hashFiles('cpp/**') }} @@ -277,7 +277,7 @@ jobs: Write-Output "gem-dir=$(ridk exec gem env gemdir)" | ` Out-File -FilePath $env:GITHUB_OUTPUT -Encoding utf8 -Append - name: Cache RubyGems - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ${{ steps.rubygems-info.outputs.gem-dir }} key: ruby-rubygems-ucrt${{ matrix.mingw-n-bits }}-${{ hashFiles('**/Gemfile', 'ruby/*/*.gemspec') }} diff --git a/c_glib/README.md b/c_glib/README.md index d571053c3dce8..2a4d6b8a6628c 100644 --- a/c_glib/README.md +++ b/c_glib/README.md @@ -142,6 +142,17 @@ $ meson compile -C c_glib.build $ sudo meson install -C c_glib.build ``` +> [!WARNING] +> +> When building Arrow GLib, it typically uses the Arrow C++ installed via Homebrew. However, this can lead to build failures +> if there are mismatches between the changes in Arrow's GLib and C++ libraries. To resolve this, you may need to +> reference the Arrow C++ library built locally. In such cases, use the `--cmake-prefix-path` option with the `meson setup` +> command to explicitly specify the library path. +> +> ```console +> $ meson setup c_glib.build c_glib --cmake-prefix-path=${arrow_cpp_install_prefix} -Dgtk_doc=true +> ``` + Others: ```console @@ -231,9 +242,18 @@ Now, you can run unit tests by the followings: ```console $ cd c_glib.build -$ bundle exec ../c_glib/test/run-test.sh +$ BUNDLE_GEMFILE=../c_glib/Gemfile bundle exec ../c_glib/test/run-test.sh ``` + +> [!NOTE] +> +> If debugging is necessary, you can proceed using the `DEBUGGER` option as follows: +> +> ```console +> $ DEBUGGER=lldb BUNDLE_GEMFILE=../c_glib/Gemfile bundle exec ../c_glib/test/run-test.sh +> ``` + ## Common build problems ### build failed - /usr/bin/ld: cannot find -larrow diff --git a/c_glib/arrow-glib/basic-data-type.cpp b/c_glib/arrow-glib/basic-data-type.cpp index 0697646e5806d..98b2c92104507 100644 --- a/c_glib/arrow-glib/basic-data-type.cpp +++ b/c_glib/arrow-glib/basic-data-type.cpp @@ -125,9 +125,9 @@ G_BEGIN_DECLS * data types. */ -typedef struct GArrowDataTypePrivate_ { +struct GArrowDataTypePrivate { std::shared_ptr data_type; -} GArrowDataTypePrivate; +}; enum { PROP_DATA_TYPE = 1 @@ -1113,9 +1113,71 @@ garrow_date64_data_type_new(void) } -G_DEFINE_TYPE(GArrowTimestampDataType, - garrow_timestamp_data_type, - GARROW_TYPE_TEMPORAL_DATA_TYPE) +struct GArrowTimestampDataTypePrivate { + GTimeZone *time_zone; +}; + +enum { + PROP_TIME_ZONE = 1 +}; + +G_DEFINE_TYPE_WITH_PRIVATE(GArrowTimestampDataType, + garrow_timestamp_data_type, + GARROW_TYPE_TEMPORAL_DATA_TYPE) + +#define GARROW_TIMESTAMP_DATA_TYPE_GET_PRIVATE(object) \ + static_cast( \ + garrow_timestamp_data_type_get_instance_private( \ + GARROW_TIMESTAMP_DATA_TYPE(object))) + +static void +garrow_timestamp_data_type_dispose(GObject *object) +{ + auto priv = GARROW_TIMESTAMP_DATA_TYPE_GET_PRIVATE(object); + + if (priv->time_zone) { + g_time_zone_unref(priv->time_zone); + priv->time_zone = nullptr; + } + + G_OBJECT_CLASS(garrow_timestamp_data_type_parent_class)->dispose(object); +} + +static void +garrow_timestamp_data_type_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GARROW_TIMESTAMP_DATA_TYPE_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_TIME_ZONE: + priv->time_zone = static_cast(g_value_dup_boxed(value)); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +garrow_timestamp_data_type_get_property(GObject *object, + guint prop_id, + GValue *value, + GParamSpec *pspec) +{ + auto priv = GARROW_TIMESTAMP_DATA_TYPE_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_TIME_ZONE: + g_value_set_boxed(value, priv->time_zone); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} static void garrow_timestamp_data_type_init(GArrowTimestampDataType *object) @@ -1125,11 +1187,33 @@ garrow_timestamp_data_type_init(GArrowTimestampDataType *object) static void garrow_timestamp_data_type_class_init(GArrowTimestampDataTypeClass *klass) { + auto gobject_class = G_OBJECT_CLASS(klass); + gobject_class->dispose = garrow_timestamp_data_type_dispose; + gobject_class->set_property = garrow_timestamp_data_type_set_property; + gobject_class->get_property = garrow_timestamp_data_type_get_property; + + GParamSpec *spec; + /** + * GArrowTimestampDataType:time-zone: + * + * The time zone of this data type. + * + * Since: 16.0.0 + */ + spec = g_param_spec_boxed("time-zone", + "Time zone", + "The time zone of this data type", + G_TYPE_TIME_ZONE, + static_cast(G_PARAM_READWRITE | + G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_TIME_ZONE, spec); } /** * garrow_timestamp_data_type_new: * @unit: The unit of the timestamp data. + * @time_zone: (nullable): The time zone of the timestamp data. If based GLib + * is less than 2.58, this is ignored. * * Returns: A newly created the number of * seconds/milliseconds/microseconds/nanoseconds since UNIX epoch in @@ -1138,30 +1222,38 @@ garrow_timestamp_data_type_class_init(GArrowTimestampDataTypeClass *klass) * Since: 0.7.0 */ GArrowTimestampDataType * -garrow_timestamp_data_type_new(GArrowTimeUnit unit) +garrow_timestamp_data_type_new(GArrowTimeUnit unit, + GTimeZone *time_zone) { auto arrow_unit = garrow_time_unit_to_raw(unit); - auto arrow_data_type = arrow::timestamp(arrow_unit); + std::string arrow_timezone; +#if GLIB_CHECK_VERSION(2, 58, 0) + if (time_zone) { + arrow_timezone = g_time_zone_get_identifier(time_zone); + } +#endif + auto arrow_data_type = arrow::timestamp(arrow_unit, arrow_timezone); auto data_type = GARROW_TIMESTAMP_DATA_TYPE(g_object_new(GARROW_TYPE_TIMESTAMP_DATA_TYPE, "data-type", &arrow_data_type, + "time-zone", time_zone, NULL)); return data_type; } /** * garrow_timestamp_data_type_get_unit: - * @timestamp_data_type: The #GArrowTimestampDataType. + * @data_type: The #GArrowTimestampDataType. * * Returns: The unit of the timestamp data type. * * Since: 0.8.0 */ GArrowTimeUnit -garrow_timestamp_data_type_get_unit(GArrowTimestampDataType *timestamp_data_type) +garrow_timestamp_data_type_get_unit(GArrowTimestampDataType *data_type) { const auto arrow_data_type = - garrow_data_type_get_raw(GARROW_DATA_TYPE(timestamp_data_type)); + garrow_data_type_get_raw(GARROW_DATA_TYPE(data_type)); const auto arrow_timestamp_data_type = std::static_pointer_cast(arrow_data_type); return garrow_time_unit_from_raw(arrow_timestamp_data_type->unit()); diff --git a/c_glib/arrow-glib/basic-data-type.h b/c_glib/arrow-glib/basic-data-type.h index affbfcf13c283..f1c5af409c9da 100644 --- a/c_glib/arrow-glib/basic-data-type.h +++ b/c_glib/arrow-glib/basic-data-type.h @@ -425,9 +425,11 @@ struct _GArrowTimestampDataTypeClass GArrowTemporalDataTypeClass parent_class; }; -GArrowTimestampDataType *garrow_timestamp_data_type_new (GArrowTimeUnit unit); +GArrowTimestampDataType * +garrow_timestamp_data_type_new(GArrowTimeUnit unit, + GTimeZone *time_zone); GArrowTimeUnit -garrow_timestamp_data_type_get_unit (GArrowTimestampDataType *timestamp_data_type); +garrow_timestamp_data_type_get_unit(GArrowTimestampDataType *data_type); #define GARROW_TYPE_TIME_DATA_TYPE (garrow_time_data_type_get_type()) diff --git a/c_glib/arrow-glib/version.h.in b/c_glib/arrow-glib/version.h.in index abb8ba08708de..01760fbfed1ff 100644 --- a/c_glib/arrow-glib/version.h.in +++ b/c_glib/arrow-glib/version.h.in @@ -110,6 +110,15 @@ # define GARROW_UNAVAILABLE(major, minor) G_UNAVAILABLE(major, minor) #endif +/** + * GARROW_VERSION_16_0: + * + * You can use this macro value for compile time API version check. + * + * Since: 16.0.0 + */ +#define GARROW_VERSION_16_0 G_ENCODE_VERSION(16, 0) + /** * GARROW_VERSION_15_0: * @@ -355,6 +364,20 @@ #define GARROW_AVAILABLE_IN_ALL +#if GARROW_VERSION_MIN_REQUIRED >= GARROW_VERSION_16_0 +# define GARROW_DEPRECATED_IN_16_0 GARROW_DEPRECATED +# define GARROW_DEPRECATED_IN_16_0_FOR(function) GARROW_DEPRECATED_FOR(function) +#else +# define GARROW_DEPRECATED_IN_16_0 +# define GARROW_DEPRECATED_IN_16_0_FOR(function) +#endif + +#if GARROW_VERSION_MAX_ALLOWED < GARROW_VERSION_16_0 +# define GARROW_AVAILABLE_IN_16_0 GARROW_UNAVAILABLE(16, 0) +#else +# define GARROW_AVAILABLE_IN_16_0 +#endif + #if GARROW_VERSION_MIN_REQUIRED >= GARROW_VERSION_15_0 # define GARROW_DEPRECATED_IN_15_0 GARROW_DEPRECATED # define GARROW_DEPRECATED_IN_15_0_FOR(function) GARROW_DEPRECATED_FOR(function) diff --git a/c_glib/doc/arrow-glib/arrow-glib-docs.xml b/c_glib/doc/arrow-glib/arrow-glib-docs.xml index 57b4b98701686..e92eb955675ed 100644 --- a/c_glib/doc/arrow-glib/arrow-glib-docs.xml +++ b/c_glib/doc/arrow-glib/arrow-glib-docs.xml @@ -193,6 +193,10 @@ Index of deprecated API + + Index of new symbols in 16.0.0 + + Index of new symbols in 13.0.0 diff --git a/c_glib/meson.build b/c_glib/meson.build index 7c495d2567d72..ffd41d4d574a7 100644 --- a/c_glib/meson.build +++ b/c_glib/meson.build @@ -24,7 +24,7 @@ project('arrow-glib', 'c', 'cpp', 'cpp_std=c++17', ]) -version = '15.0.0-SNAPSHOT' +version = '16.0.0-SNAPSHOT' if version.endswith('-SNAPSHOT') version_numbers = version.split('-')[0].split('.') version_tag = version.split('-')[1] diff --git a/c_glib/test/run-test.sh b/c_glib/test/run-test.sh index 33e9fbf85d026..c7bc6edca5f0d 100755 --- a/c_glib/test/run-test.sh +++ b/c_glib/test/run-test.sh @@ -34,14 +34,14 @@ for module in "${modules[@]}"; do module_build_dir="${build_dir}/${module}" if [ -d "${module_build_dir}" ]; then LD_LIBRARY_PATH="${module_build_dir}:${LD_LIBRARY_PATH}" + DYLD_LIBRARY_PATH="${module_build_dir}:${DYLD_LIBRARY_PATH}" fi done export LD_LIBRARY_PATH +export DYLD_LIBRARY_PATH if [ "${BUILD}" != "no" ]; then - if [ -f "Makefile" ]; then - make -j8 > /dev/null || exit $? - elif [ -f "build.ninja" ]; then + if [ -f "build.ninja" ]; then ninja || exit $? fi fi @@ -59,4 +59,19 @@ for module in "${modules[@]}"; do done export GI_TYPELIB_PATH -${GDB} ruby ${test_dir}/run-test.rb "$@" +if type rbenv > /dev/null 2>&1; then + RUBY="$(rbenv which ruby)" +else + RUBY=ruby +fi +DEBUGGER_ARGS=() +case "${DEBUGGER}" in + "gdb") + DEBUGGER_ARGS+=(--args) + ;; + "lldb") + DEBUGGER_ARGS+=(--one-line "env DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH}") + DEBUGGER_ARGS+=(--) + ;; +esac +${DEBUGGER} "${DEBUGGER_ARGS[@]}" "${RUBY}" ${test_dir}/run-test.rb "$@" diff --git a/c_glib/test/test-timestamp-data-type.rb b/c_glib/test/test-timestamp-data-type.rb index dac3a9bc631d0..69437609feebf 100644 --- a/c_glib/test/test-timestamp-data-type.rb +++ b/c_glib/test/test-timestamp-data-type.rb @@ -26,6 +26,23 @@ def test_name assert_equal("timestamp", data_type.name) end + sub_test_case("time_zone") do + def test_nil + data_type = Arrow::TimestampDataType.new(:micro) + assert_nil(data_type.time_zone) + end + + def test_time_zone + data_type = Arrow::TimestampDataType.new(:micro, GLib::TimeZone.new("UTC")) + time_zone = data_type.time_zone + assert_not_nil(time_zone) + # glib2 gem 4.2.1 or later is required + if time_zone.respond_to?(:identifier) + assert_equal("UTC", time_zone.identifier) + end + end + end + sub_test_case("second") do def setup @data_type = Arrow::TimestampDataType.new(:second) diff --git a/ci/conda_env_python.txt b/ci/conda_env_python.txt index 97203442129c4..5fdd21d2bd1f9 100644 --- a/ci/conda_env_python.txt +++ b/ci/conda_env_python.txt @@ -23,7 +23,7 @@ cloudpickle fsspec hypothesis numpy>=1.16.6 -pytest +pytest<8 # pytest-lazy-fixture broken on pytest 8.0.0 pytest-faulthandler pytest-lazy-fixture s3fs>=2023.10.0 diff --git a/ci/conda_env_sphinx.txt b/ci/conda_env_sphinx.txt index 0e50875fc1ef8..d0f494d2e085d 100644 --- a/ci/conda_env_sphinx.txt +++ b/ci/conda_env_sphinx.txt @@ -20,7 +20,7 @@ breathe doxygen ipython numpydoc -pydata-sphinx-theme=0.14 +pydata-sphinx-theme=0.14.1 sphinx-autobuild sphinx-design sphinx-copybutton diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index 674acc99f54a9..50d4fc28c58f3 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -18,7 +18,7 @@ _realname=arrow pkgbase=mingw-w64-${_realname} pkgname="${MINGW_PACKAGE_PREFIX}-${_realname}" -pkgver=14.0.2.9000 +pkgver=15.0.0.9000 pkgrel=8000 pkgdesc="Apache Arrow is a cross-language development platform for in-memory data (mingw-w64)" arch=("any") diff --git a/ci/scripts/integration_arrow_build.sh b/ci/scripts/integration_arrow_build.sh index 02f593bf77b23..e5c31527aedff 100755 --- a/ci/scripts/integration_arrow_build.sh +++ b/ci/scripts/integration_arrow_build.sh @@ -46,7 +46,7 @@ if [ "${ARROW_INTEGRATION_JAVA}" == "ON" ]; then export ARROW_JAVA_CDATA="ON" export JAVA_JNI_CMAKE_ARGS="-DARROW_JAVA_JNI_ENABLE_DEFAULT=OFF -DARROW_JAVA_JNI_ENABLE_C=ON" - ${arrow_dir}/ci/scripts/java_jni_build.sh ${arrow_dir} ${ARROW_HOME} ${build_dir} /tmp/dist/java/$(arch) + ${arrow_dir}/ci/scripts/java_jni_build.sh ${arrow_dir} ${ARROW_HOME} ${build_dir} /tmp/dist/java ${arrow_dir}/ci/scripts/java_build.sh ${arrow_dir} ${build_dir} /tmp/dist/java fi diff --git a/ci/scripts/java_jni_build.sh b/ci/scripts/java_jni_build.sh index 320c98c04df1e..d989351ab7e4d 100755 --- a/ci/scripts/java_jni_build.sh +++ b/ci/scripts/java_jni_build.sh @@ -24,7 +24,6 @@ arrow_install_dir=${2} build_dir=${3}/java_jni # The directory where the final binaries will be stored when scripts finish dist_dir=${4} - prefix_dir="${build_dir}/java-jni" echo "=== Clear output directories and leftovers ===" @@ -56,7 +55,6 @@ cmake \ -DBUILD_TESTING=${ARROW_JAVA_BUILD_TESTS} \ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ -DCMAKE_PREFIX_PATH=${arrow_install_dir} \ - -DCMAKE_INSTALL_LIBDIR=lib \ -DCMAKE_INSTALL_PREFIX=${prefix_dir} \ -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD:-OFF} \ -DProtobuf_USE_STATIC_LIBS=ON \ diff --git a/ci/scripts/java_jni_macos_build.sh b/ci/scripts/java_jni_macos_build.sh index d66c39a37c5bd..4ecc029bdd3c2 100755 --- a/ci/scripts/java_jni_macos_build.sh +++ b/ci/scripts/java_jni_macos_build.sh @@ -31,7 +31,7 @@ case ${normalized_arch} in ;; esac # The directory where the final binaries will be stored when scripts finish -dist_dir=${3}/${normalized_arch} +dist_dir=${3} echo "=== Clear output directories and leftovers ===" # Clear output directories and leftovers @@ -82,7 +82,6 @@ cmake \ -DARROW_S3=${ARROW_S3} \ -DARROW_USE_CCACHE=${ARROW_USE_CCACHE} \ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ - -DCMAKE_INSTALL_LIBDIR=lib \ -DCMAKE_INSTALL_PREFIX=${install_dir} \ -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD} \ -DGTest_SOURCE=BUNDLED \ @@ -138,8 +137,8 @@ archery linking check-dependencies \ --allow libncurses \ --allow libobjc \ --allow libz \ - libarrow_cdata_jni.dylib \ - libarrow_dataset_jni.dylib \ - libarrow_orc_jni.dylib \ - libgandiva_jni.dylib + arrow_cdata_jni/${normalized_arch}/libarrow_cdata_jni.dylib \ + arrow_dataset_jni/${normalized_arch}/libarrow_dataset_jni.dylib \ + arrow_orc_jni/${normalized_arch}/libarrow_orc_jni.dylib \ + gandiva_jni/${normalized_arch}/libgandiva_jni.dylib popd diff --git a/ci/scripts/java_jni_manylinux_build.sh b/ci/scripts/java_jni_manylinux_build.sh index 03939715e390f..da4987d307ce4 100755 --- a/ci/scripts/java_jni_manylinux_build.sh +++ b/ci/scripts/java_jni_manylinux_build.sh @@ -28,7 +28,7 @@ case ${normalized_arch} in ;; esac # The directory where the final binaries will be stored when scripts finish -dist_dir=${3}/${normalized_arch} +dist_dir=${3} echo "=== Clear output directories and leftovers ===" # Clear output directories and leftovers @@ -91,7 +91,6 @@ cmake \ -DARROW_S3=${ARROW_S3} \ -DARROW_USE_CCACHE=${ARROW_USE_CCACHE} \ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ - -DCMAKE_INSTALL_LIBDIR=lib \ -DCMAKE_INSTALL_PREFIX=${ARROW_HOME} \ -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD} \ -DGTest_SOURCE=BUNDLED \ @@ -164,8 +163,8 @@ archery linking check-dependencies \ --allow libstdc++ \ --allow libz \ --allow linux-vdso \ - libarrow_cdata_jni.so \ - libarrow_dataset_jni.so \ - libarrow_orc_jni.so \ - libgandiva_jni.so + arrow_cdata_jni/${normalized_arch}/libarrow_cdata_jni.so \ + arrow_dataset_jni/${normalized_arch}/libarrow_dataset_jni.so \ + arrow_orc_jni/${normalized_arch}/libarrow_orc_jni.so \ + gandiva_jni/${normalized_arch}/libgandiva_jni.so popd diff --git a/ci/scripts/java_jni_windows_build.sh b/ci/scripts/java_jni_windows_build.sh index 778ee9696790e..39288f4a9d0ce 100755 --- a/ci/scripts/java_jni_windows_build.sh +++ b/ci/scripts/java_jni_windows_build.sh @@ -22,7 +22,7 @@ set -ex arrow_dir=${1} build_dir=${2} # The directory where the final binaries will be stored when scripts finish -dist_dir=${3}/x86_64 +dist_dir=${3} echo "=== Clear output directories and leftovers ===" # Clear output directories and leftovers @@ -72,7 +72,6 @@ cmake \ -DARROW_WITH_SNAPPY=ON \ -DARROW_WITH_ZSTD=ON \ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ - -DCMAKE_INSTALL_LIBDIR=lib \ -DCMAKE_INSTALL_PREFIX=${install_dir} \ -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD} \ -GNinja \ diff --git a/ci/scripts/r_docker_configure.sh b/ci/scripts/r_docker_configure.sh index 1cbd5f0b5ea96..52db2e6df6611 100755 --- a/ci/scripts/r_docker_configure.sh +++ b/ci/scripts/r_docker_configure.sh @@ -91,8 +91,9 @@ if [ -f "${ARROW_SOURCE_HOME}/ci/scripts/r_install_system_dependencies.sh" ]; th "${ARROW_SOURCE_HOME}/ci/scripts/r_install_system_dependencies.sh" fi -# Install rsync for bundling cpp source and curl to make sure it is installed on all images -$PACKAGE_MANAGER install -y rsync curl +# Install rsync for bundling cpp source and curl to make sure it is installed on all images, +# cmake is now a listed sys req. +$PACKAGE_MANAGER install -y rsync cmake curl # Workaround for html help install failure; see https://github.com/r-lib/devtools/issues/2084#issuecomment-530912786 Rscript -e 'x <- file.path(R.home("doc"), "html"); if (!file.exists(x)) {dir.create(x, recursive=TRUE); file.copy(system.file("html/R.css", package="stats"), x)}' diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d26e06a146b56..016cd8a1b9ec8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -71,7 +71,7 @@ if(POLICY CMP0135) cmake_policy(SET CMP0135 NEW) endif() -set(ARROW_VERSION "15.0.0-SNAPSHOT") +set(ARROW_VERSION "16.0.0-SNAPSHOT") string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" ARROW_BASE_VERSION "${ARROW_VERSION}") diff --git a/cpp/cmake_modules/FindClangTools.cmake b/cpp/cmake_modules/FindClangTools.cmake index 90df60bf541d4..1364ccbed8162 100644 --- a/cpp/cmake_modules/FindClangTools.cmake +++ b/cpp/cmake_modules/FindClangTools.cmake @@ -40,7 +40,8 @@ set(CLANG_TOOLS_SEARCH_PATHS /usr/local/bin /usr/bin "C:/Program Files/LLVM/bin" # Windows, non-conda - "$ENV{CONDA_PREFIX}/Library/bin") # Windows, conda + "$ENV{CONDA_PREFIX}/Library/bin" # Windows, conda + "$ENV{CONDA_PREFIX}/bin") # Unix, conda if(APPLE) find_program(BREW brew) if(BREW) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index a2627c190f738..6bb9c0f6af2ca 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1005,14 +1005,8 @@ if("${MAKE}" STREQUAL "") endif() endif() -# Using make -j in sub-make is fragile -# see discussion https://github.com/apache/arrow/pull/2779 -if(${CMAKE_GENERATOR} MATCHES "Makefiles") - set(MAKE_BUILD_ARGS "") -else() - # limit the maximum number of jobs for ninja - set(MAKE_BUILD_ARGS "-j${NPROC}") -endif() +# Args for external projects using make. +set(MAKE_BUILD_ARGS "-j${NPROC}") include(FetchContent) set(FC_DECLARE_COMMON_OPTIONS) @@ -2042,10 +2036,6 @@ macro(build_jemalloc) endif() set(JEMALLOC_BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS}) - # Paralleism for Make fails with CMake > 3.28 see #39517 - if(${CMAKE_GENERATOR} MATCHES "Makefiles") - list(APPEND JEMALLOC_BUILD_COMMAND "-j1") - endif() if(CMAKE_OSX_SYSROOT) list(APPEND JEMALLOC_BUILD_COMMAND "SDKROOT=${CMAKE_OSX_SYSROOT}") diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc b/cpp/src/arrow/acero/hash_aggregate_test.cc index a4874f3581040..2626fd50379dd 100644 --- a/cpp/src/arrow/acero/hash_aggregate_test.cc +++ b/cpp/src/arrow/acero/hash_aggregate_test.cc @@ -1694,6 +1694,42 @@ TEST_P(GroupBy, SumMeanProductScalar) { } } +TEST_P(GroupBy, MeanOverflow) { + BatchesWithSchema input; + // would overflow if intermediate sum is integer + input.batches = { + ExecBatchFromJSON({int64(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY}, + + "[[9223372036854775805, 1], [9223372036854775805, 1], " + "[9223372036854775805, 2], [9223372036854775805, 3]]"), + ExecBatchFromJSON({int64(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY}, + "[[null, 1], [null, 1], [null, 2], [null, 3]]"), + ExecBatchFromJSON({int64(), int64()}, + "[[9223372036854775805, 1], [9223372036854775805, 2], " + "[9223372036854775805, 3]]"), + }; + input.schema = schema({field("argument", int64()), field("key", int64())}); + for (bool use_threads : {true, false}) { + SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); + ASSERT_OK_AND_ASSIGN(Datum actual, + RunGroupBy(input, {"key"}, + { + {"hash_mean", nullptr, "argument", "hash_mean"}, + }, + use_threads)); + Datum expected = ArrayFromJSON(struct_({ + field("key", int64()), + field("hash_mean", float64()), + }), + R"([ + [1, 9223372036854775805], + [2, 9223372036854775805], + [3, 9223372036854775805] + ])"); + AssertDatumsApproxEqual(expected, actual, /*verbose=*/true); + } +} + TEST_P(GroupBy, VarianceAndStddev) { auto batch = RecordBatchFromJSON( schema({field("argument", int32()), field("key", int64())}), R"([ diff --git a/cpp/src/arrow/chunk_resolver.h b/cpp/src/arrow/chunk_resolver.h index 818070ffe350a..d3ae315568d08 100644 --- a/cpp/src/arrow/chunk_resolver.h +++ b/cpp/src/arrow/chunk_resolver.h @@ -18,87 +18,151 @@ #pragma once #include +#include #include #include #include "arrow/type_fwd.h" #include "arrow/util/macros.h" -namespace arrow { -namespace internal { +namespace arrow::internal { struct ChunkLocation { - int64_t chunk_index, index_in_chunk; + /// \brief Index of the chunk in the array of chunks + /// + /// The value is always in the range `[0, chunks.size()]`. `chunks.size()` is used + /// to represent out-of-bounds locations. + int64_t chunk_index; + + /// \brief Index of the value in the chunk + /// + /// The value is undefined if chunk_index >= chunks.size() + int64_t index_in_chunk; }; -// An object that resolves an array chunk depending on a logical index +/// \brief An utility that incrementally resolves logical indices into +/// physical indices in a chunked array. struct ARROW_EXPORT ChunkResolver { - explicit ChunkResolver(const ArrayVector& chunks); + private: + /// \brief Array containing `chunks.size() + 1` offsets. + /// + /// `offsets_[i]` is the starting logical index of chunk `i`. `offsets_[0]` is always 0 + /// and `offsets_[chunks.size()]` is the logical length of the chunked array. + std::vector offsets_; - explicit ChunkResolver(const std::vector& chunks); + /// \brief Cache of the index of the last resolved chunk. + /// + /// \invariant `cached_chunk_ in [0, chunks.size()]` + mutable std::atomic cached_chunk_; + public: + explicit ChunkResolver(const ArrayVector& chunks); + explicit ChunkResolver(const std::vector& chunks); explicit ChunkResolver(const RecordBatchVector& batches); ChunkResolver(ChunkResolver&& other) noexcept - : offsets_(std::move(other.offsets_)), cached_chunk_(other.cached_chunk_.load()) {} + : offsets_(std::move(other.offsets_)), + cached_chunk_(other.cached_chunk_.load(std::memory_order_relaxed)) {} ChunkResolver& operator=(ChunkResolver&& other) { offsets_ = std::move(other.offsets_); - cached_chunk_.store(other.cached_chunk_.load()); + cached_chunk_.store(other.cached_chunk_.load(std::memory_order_relaxed)); return *this; } - /// \brief Return a ChunkLocation containing the chunk index and in-chunk value index of - /// the chunked array at logical index - inline ChunkLocation Resolve(const int64_t index) const { - // It is common for the algorithms below to make consecutive accesses at - // a relatively small distance from each other, hence often falling in - // the same chunk. - // This is trivial when merging (assuming each side of the merge uses - // its own resolver), but also in the inner recursive invocations of + /// \brief Resolve a logical index to a ChunkLocation. + /// + /// The returned ChunkLocation contains the chunk index and the within-chunk index + /// equivalent to the logical index. + /// + /// \pre index >= 0 + /// \post location.chunk_index in [0, chunks.size()] + /// \param index The logical index to resolve + /// \return ChunkLocation with a valid chunk_index if index is within + /// bounds, or with chunk_index == chunks.size() if logical index is + /// `>= chunked_array.length()`. + inline ChunkLocation Resolve(int64_t index) const { + const auto cached_chunk = cached_chunk_.load(std::memory_order_relaxed); + const auto chunk_index = + ResolveChunkIndex(index, cached_chunk); + return {chunk_index, index - offsets_[chunk_index]}; + } + + /// \brief Resolve a logical index to a ChunkLocation. + /// + /// The returned ChunkLocation contains the chunk index and the within-chunk index + /// equivalent to the logical index. + /// + /// \pre index >= 0 + /// \post location.chunk_index in [0, chunks.size()] + /// \param index The logical index to resolve + /// \param cached_chunk_index 0 or the chunk_index of the last ChunkLocation + /// returned by this ChunkResolver. + /// \return ChunkLocation with a valid chunk_index if index is within + /// bounds, or with chunk_index == chunks.size() if logical index is + /// `>= chunked_array.length()`. + inline ChunkLocation ResolveWithChunkIndexHint(int64_t index, + int64_t cached_chunk_index) const { + assert(cached_chunk_index < static_cast(offsets_.size())); + const auto chunk_index = + ResolveChunkIndex(index, cached_chunk_index); + return {chunk_index, index - offsets_[chunk_index]}; + } + + private: + template + inline int64_t ResolveChunkIndex(int64_t index, int64_t cached_chunk) const { + // It is common for algorithms sequentially processing arrays to make consecutive + // accesses at a relatively small distance from each other, hence often falling in the + // same chunk. + // + // This is guaranteed when merging (assuming each side of the merge uses its + // own resolver), and is the most common case in recursive invocations of // partitioning. - if (offsets_.size() <= 1) { - return {0, index}; + const auto num_offsets = static_cast(offsets_.size()); + const int64_t* offsets = offsets_.data(); + if (ARROW_PREDICT_TRUE(index >= offsets[cached_chunk]) && + (cached_chunk + 1 == num_offsets || index < offsets[cached_chunk + 1])) { + return cached_chunk; } - const auto cached_chunk = cached_chunk_.load(); - const bool cache_hit = - (index >= offsets_[cached_chunk] && index < offsets_[cached_chunk + 1]); - if (ARROW_PREDICT_TRUE(cache_hit)) { - return {cached_chunk, index - offsets_[cached_chunk]}; + // lo < hi is guaranteed by `num_offsets = chunks.size() + 1` + const auto chunk_index = Bisect(index, offsets, /*lo=*/0, /*hi=*/num_offsets); + if constexpr (StoreCachedChunk) { + assert(chunk_index < static_cast(offsets_.size())); + cached_chunk_.store(chunk_index, std::memory_order_relaxed); } - auto chunk_index = Bisect(index); - cached_chunk_.store(chunk_index); - return {chunk_index, index - offsets_[chunk_index]}; + return chunk_index; } - protected: - // Find the chunk index corresponding to a value index using binary search - inline int64_t Bisect(const int64_t index) const { - // Like std::upper_bound(), but hand-written as it can help the compiler. - // Search [lo, lo + n) - int64_t lo = 0; - auto n = static_cast(offsets_.size()); - while (n > 1) { + /// \brief Find the index of the chunk that contains the logical index. + /// + /// Any non-negative index is accepted. When `hi=num_offsets`, the largest + /// possible return value is `num_offsets-1` which is equal to + /// `chunks.size()`. The is returned when the logical index is out-of-bounds. + /// + /// \pre index >= 0 + /// \pre lo < hi + /// \pre lo >= 0 && hi <= offsets_.size() + static inline int64_t Bisect(int64_t index, const int64_t* offsets, int64_t lo, + int64_t hi) { + // Similar to std::upper_bound(), but slightly different as our offsets + // array always starts with 0. + auto n = hi - lo; + // First iteration does not need to check for n > 1 + // (lo < hi is guaranteed by the precondition). + assert(n > 1 && "lo < hi is a precondition of Bisect"); + do { const int64_t m = n >> 1; const int64_t mid = lo + m; - if (static_cast(index) >= offsets_[mid]) { + if (index >= offsets[mid]) { lo = mid; n -= m; } else { n = m; } - } + } while (n > 1); return lo; } - - private: - // Collection of starting offsets used for binary search - std::vector offsets_; - - // Tracks the most recently used chunk index to allow fast - // access for consecutive indices corresponding to the same chunk - mutable std::atomic cached_chunk_; }; -} // namespace internal -} // namespace arrow +} // namespace arrow::internal diff --git a/cpp/src/arrow/compute/CMakeLists.txt b/cpp/src/arrow/compute/CMakeLists.txt index 1134e0a98ae45..e14d78ff6e5ca 100644 --- a/cpp/src/arrow/compute/CMakeLists.txt +++ b/cpp/src/arrow/compute/CMakeLists.txt @@ -89,7 +89,8 @@ add_arrow_test(internals_test kernel_test.cc light_array_test.cc registry_test.cc - key_hash_test.cc) + key_hash_test.cc + row/compare_test.cc) add_arrow_compute_test(expression_test SOURCES expression_test.cc) diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index c37e45513d040..5052d8dd66694 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -38,6 +38,7 @@ #include "arrow/compute/row/grouper.h" #include "arrow/record_batch.h" #include "arrow/stl_allocator.h" +#include "arrow/type_traits.h" #include "arrow/util/bit_run_reader.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/bitmap_writer.h" @@ -441,9 +442,10 @@ struct GroupedCountImpl : public GroupedAggregator { // ---------------------------------------------------------------------- // Sum/Mean/Product implementation -template +template ::Type> struct GroupedReducingAggregator : public GroupedAggregator { - using AccType = typename FindAccumulatorType::Type; + using AccType = AccumulateType; using CType = typename TypeTraits::CType; using InputCType = typename TypeTraits::CType; @@ -483,7 +485,8 @@ struct GroupedReducingAggregator : public GroupedAggregator { Status Merge(GroupedAggregator&& raw_other, const ArrayData& group_id_mapping) override { - auto other = checked_cast*>(&raw_other); + auto other = + checked_cast*>(&raw_other); CType* reduced = reduced_.mutable_data(); int64_t* counts = counts_.mutable_data(); @@ -733,9 +736,18 @@ using GroupedProductFactory = // ---------------------------------------------------------------------- // Mean implementation +template +struct GroupedMeanAccType { + using Type = typename std::conditional::value, DoubleType, + typename FindAccumulatorType::Type>::type; +}; + template -struct GroupedMeanImpl : public GroupedReducingAggregator> { - using Base = GroupedReducingAggregator>; +struct GroupedMeanImpl + : public GroupedReducingAggregator, + typename GroupedMeanAccType::Type> { + using Base = GroupedReducingAggregator, + typename GroupedMeanAccType::Type>; using CType = typename Base::CType; using InputCType = typename Base::InputCType; using MeanType = @@ -746,7 +758,7 @@ struct GroupedMeanImpl : public GroupedReducingAggregator static enable_if_number Reduce(const DataType&, const CType u, const InputCType v) { - return static_cast(to_unsigned(u) + to_unsigned(static_cast(v))); + return static_cast(u) + static_cast(v); } static CType Reduce(const DataType&, const CType u, const CType v) { diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index ad33d7f8951f4..44f5fea79078a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -1286,12 +1286,27 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { auto absolute_value = MakeUnaryArithmeticFunction("abs", absolute_value_doc); AddDecimalUnaryKernels(absolute_value.get()); + + // abs(duration) + for (auto unit : TimeUnit::values()) { + auto exec = ArithmeticExecFromOp(duration(unit)); + DCHECK_OK( + absolute_value->AddKernel({duration(unit)}, OutputType(duration(unit)), exec)); + } + DCHECK_OK(registry->AddFunction(std::move(absolute_value))); // ---------------------------------------------------------------------- auto absolute_value_checked = MakeUnaryArithmeticFunctionNotNull( "abs_checked", absolute_value_checked_doc); AddDecimalUnaryKernels(absolute_value_checked.get()); + // abs_checked(duraton) + for (auto unit : TimeUnit::values()) { + auto exec = + ArithmeticExecFromOp(duration(unit)); + DCHECK_OK(absolute_value_checked->AddKernel({duration(unit)}, + OutputType(duration(unit)), exec)); + } DCHECK_OK(registry->AddFunction(std::move(absolute_value_checked))); // ---------------------------------------------------------------------- @@ -1545,12 +1560,27 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // ---------------------------------------------------------------------- auto negate = MakeUnaryArithmeticFunction("negate", negate_doc); AddDecimalUnaryKernels(negate.get()); + + // Add neg(duration) -> duration + for (auto unit : TimeUnit::values()) { + auto exec = ArithmeticExecFromOp(duration(unit)); + DCHECK_OK(negate->AddKernel({duration(unit)}, OutputType(duration(unit)), exec)); + } + DCHECK_OK(registry->AddFunction(std::move(negate))); // ---------------------------------------------------------------------- auto negate_checked = MakeUnarySignedArithmeticFunctionNotNull( "negate_checked", negate_checked_doc); AddDecimalUnaryKernels(negate_checked.get()); + + // Add neg_checked(duration) -> duration + for (auto unit : TimeUnit::values()) { + auto exec = ArithmeticExecFromOp(duration(unit)); + DCHECK_OK( + negate_checked->AddKernel({duration(unit)}, OutputType(duration(unit)), exec)); + } + DCHECK_OK(registry->AddFunction(std::move(negate_checked))); // ---------------------------------------------------------------------- @@ -1581,6 +1611,11 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // ---------------------------------------------------------------------- auto sign = MakeUnaryArithmeticFunctionWithFixedIntOutType("sign", sign_doc); + // sign(duration) + for (auto unit : TimeUnit::values()) { + auto exec = ScalarUnary::Exec; + DCHECK_OK(sign->AddKernel({duration(unit)}, int8(), std::move(exec))); + } DCHECK_OK(registry->AddFunction(std::move(sign))); // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index aad648ca275c3..daf8ed76d628d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -22,6 +22,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common_internal.h" +#include "arrow/type.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" @@ -806,6 +807,14 @@ std::shared_ptr MakeScalarMinMax(std::string name, FunctionDoc d kernel.mem_allocation = MemAllocation::type::PREALLOCATE; DCHECK_OK(func->AddKernel(std::move(kernel))); } + for (const auto& ty : DurationTypes()) { + auto exec = GeneratePhysicalNumeric(ty); + ScalarKernel kernel{KernelSignature::Make({ty}, ty, /*is_varargs=*/true), exec, + MinMaxState::Init}; + kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::type::PREALLOCATE; + DCHECK_OK(func->AddKernel(std::move(kernel))); + } for (const auto& ty : BaseBinaryTypes()) { auto exec = GenerateTypeAgnosticVarBinaryBase(ty); diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 48fa780b03104..8f5952b40500a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -1281,7 +1281,7 @@ using CompareNumericBasedTypes = ::testing::Types; using CompareParametricTemporalTypes = - ::testing::Types; + ::testing::Types; using CompareFixedSizeBinaryTypes = ::testing::Types; TYPED_TEST_SUITE(TestVarArgsCompareNumeric, CompareNumericBasedTypes); @@ -2121,6 +2121,11 @@ TEST(TestMaxElementWiseMinElementWise, CommonTemporal) { ScalarFromJSON(date64(), "172800000"), }), ResultWith(ScalarFromJSON(date64(), "86400000"))); + EXPECT_THAT(MinElementWise({ + ScalarFromJSON(duration(TimeUnit::SECOND), "1"), + ScalarFromJSON(duration(TimeUnit::MILLI), "12000"), + }), + ResultWith(ScalarFromJSON(duration(TimeUnit::MILLI), "1000"))); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc index 6764845dfca81..8fdc6172aa6d3 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc @@ -95,7 +95,7 @@ struct FixedSizeBinaryTransformExecBase { ctx->Allocate(output_width * input_nstrings)); uint8_t* output_str = values_buffer->mutable_data(); - const uint8_t* input_data = input.GetValues(1); + const uint8_t* input_data = input.GetValues(1, input.offset * input_width); for (int64_t i = 0; i < input_nstrings; i++) { if (!input.IsNull(i)) { const uint8_t* input_string = input_data + i * input_width; @@ -132,7 +132,8 @@ struct FixedSizeBinaryTransformExecWithState DCHECK_EQ(1, types.size()); const auto& options = State::Get(ctx); const int32_t input_width = types[0].type->byte_width(); - const int32_t output_width = StringTransform::FixedOutputSize(options, input_width); + ARROW_ASSIGN_OR_RAISE(const int32_t output_width, + StringTransform::FixedOutputSize(options, input_width)); return fixed_size_binary(output_width); } }; @@ -2377,7 +2378,8 @@ struct BinaryReplaceSliceTransform : ReplaceStringSliceTransformBase { return output - output_start; } - static int32_t FixedOutputSize(const ReplaceSliceOptions& opts, int32_t input_width) { + static Result FixedOutputSize(const ReplaceSliceOptions& opts, + int32_t input_width) { int32_t before_slice = 0; int32_t after_slice = 0; const int32_t start = static_cast(opts.start); @@ -2436,6 +2438,7 @@ void AddAsciiStringReplaceSlice(FunctionRegistry* registry) { namespace { struct SliceBytesTransform : StringSliceTransformBase { + using StringSliceTransformBase::StringSliceTransformBase; int64_t MaxCodeunits(int64_t ninputs, int64_t input_bytes) override { const SliceOptions& opt = *this->options; if ((opt.start >= 0) != (opt.stop >= 0)) { @@ -2454,22 +2457,15 @@ struct SliceBytesTransform : StringSliceTransformBase { return SliceBackward(input, input_string_bytes, output); } - int64_t SliceForward(const uint8_t* input, int64_t input_string_bytes, - uint8_t* output) { - // Slice in forward order (step > 0) - const SliceOptions& opt = *this->options; - const uint8_t* begin = input; - const uint8_t* end = input + input_string_bytes; - const uint8_t* begin_sliced; - const uint8_t* end_sliced; - - if (!input_string_bytes) { - return 0; - } - // First, compute begin_sliced and end_sliced + static std::pair SliceForwardRange(const SliceOptions& opt, + int64_t input_string_bytes) { + int64_t begin = 0; + int64_t end = input_string_bytes; + int64_t begin_sliced = 0; + int64_t end_sliced = 0; if (opt.start >= 0) { // start counting from the left - begin_sliced = std::min(begin + opt.start, end); + begin_sliced = std::min(opt.start, end); if (opt.stop > opt.start) { // continue counting from begin_sliced const int64_t length = opt.stop - opt.start; @@ -2479,7 +2475,7 @@ struct SliceBytesTransform : StringSliceTransformBase { end_sliced = std::max(end + opt.stop, begin_sliced); } else { // zero length slice - return 0; + return {0, 0}; } } else { // start counting from the right @@ -2491,7 +2487,7 @@ struct SliceBytesTransform : StringSliceTransformBase { // and therefore we also need this if (end_sliced <= begin_sliced) { // zero length slice - return 0; + return {0, 0}; } } else if ((opt.stop < 0) && (opt.stop > opt.start)) { // stop is negative, but larger than start, so we count again from the right @@ -2501,12 +2497,30 @@ struct SliceBytesTransform : StringSliceTransformBase { end_sliced = std::max(end + opt.stop, begin_sliced); } else { // zero length slice - return 0; + return {0, 0}; } } + return {begin_sliced, end_sliced}; + } + + int64_t SliceForward(const uint8_t* input, int64_t input_string_bytes, + uint8_t* output) { + // Slice in forward order (step > 0) + if (!input_string_bytes) { + return 0; + } + + const SliceOptions& opt = *this->options; + auto [begin_index, end_index] = SliceForwardRange(opt, input_string_bytes); + const uint8_t* begin_sliced = input + begin_index; + const uint8_t* end_sliced = input + end_index; + + if (begin_sliced == end_sliced) { + return 0; + } // Second, copy computed slice to output - DCHECK(begin_sliced <= end_sliced); + DCHECK(begin_sliced < end_sliced); if (opt.step == 1) { // fast case, where we simply can finish with a memcpy std::copy(begin_sliced, end_sliced, output); @@ -2525,18 +2539,13 @@ struct SliceBytesTransform : StringSliceTransformBase { return dest - output; } - int64_t SliceBackward(const uint8_t* input, int64_t input_string_bytes, - uint8_t* output) { + static std::pair SliceBackwardRange(const SliceOptions& opt, + int64_t input_string_bytes) { // Slice in reverse order (step < 0) - const SliceOptions& opt = *this->options; - const uint8_t* begin = input; - const uint8_t* end = input + input_string_bytes; - const uint8_t* begin_sliced = begin; - const uint8_t* end_sliced = end; - - if (!input_string_bytes) { - return 0; - } + int64_t begin = 0; + int64_t end = input_string_bytes; + int64_t begin_sliced = begin; + int64_t end_sliced = end; if (opt.start >= 0) { // +1 because begin_sliced acts as as the end of a reverse iterator @@ -2555,6 +2564,28 @@ struct SliceBytesTransform : StringSliceTransformBase { } end_sliced--; + if (begin_sliced <= end_sliced) { + // zero length slice + return {0, 0}; + } + + return {begin_sliced, end_sliced}; + } + + int64_t SliceBackward(const uint8_t* input, int64_t input_string_bytes, + uint8_t* output) { + if (!input_string_bytes) { + return 0; + } + + const SliceOptions& opt = *this->options; + auto [begin_index, end_index] = SliceBackwardRange(opt, input_string_bytes); + const uint8_t* begin_sliced = input + begin_index; + const uint8_t* end_sliced = input + end_index; + + if (begin_sliced == end_sliced) { + return 0; + } // Copy computed slice to output uint8_t* dest = output; const uint8_t* i = begin_sliced; @@ -2568,6 +2599,22 @@ struct SliceBytesTransform : StringSliceTransformBase { return dest - output; } + + static Result FixedOutputSize(SliceOptions options, int32_t input_width_32) { + auto step = options.step; + if (step == 0) { + return Status::Invalid("Slice step cannot be zero"); + } + if (step > 0) { + // forward slice + auto [begin_index, end_index] = SliceForwardRange(options, input_width_32); + return static_cast((end_index - begin_index + step - 1) / step); + } else { + // backward slice + auto [begin_index, end_index] = SliceBackwardRange(options, input_width_32); + return static_cast((end_index - begin_index + step + 1) / step); + } + } }; template @@ -2594,6 +2641,12 @@ void AddAsciiStringSlice(FunctionRegistry* registry) { DCHECK_OK( func->AddKernel({ty}, ty, std::move(exec), SliceBytesTransform::State::Init)); } + using TransformExec = FixedSizeBinaryTransformExecWithState; + ScalarKernel fsb_kernel({InputType(Type::FIXED_SIZE_BINARY)}, + OutputType(TransformExec::OutputType), TransformExec::Exec, + StringSliceTransformBase::State::Init); + fsb_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(std::move(fsb_kernel))); DCHECK_OK(registry->AddFunction(std::move(func))); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_internal.h b/cpp/src/arrow/compute/kernels/scalar_string_internal.h index 7a5d5a7c86e85..6723d11c8deb8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_internal.h +++ b/cpp/src/arrow/compute/kernels/scalar_string_internal.h @@ -250,6 +250,8 @@ struct StringSliceTransformBase : public StringTransformBase { using State = OptionsWrapper; const SliceOptions* options; + StringSliceTransformBase() = default; + explicit StringSliceTransformBase(const SliceOptions& options) : options{&options} {} Status PreExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) override { options = &State::Get(ctx); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 5dec16d89e29c..d7e35d07334ea 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -33,10 +33,10 @@ #include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/type.h" +#include "arrow/type_fwd.h" #include "arrow/util/value_parsing.h" -namespace arrow { -namespace compute { +namespace arrow::compute { // interesting utf8 characters for testing (lower case / upper case): // * ῦ / Υ͂ (3 to 4 code units) (Note, we don't support this yet, utf8proc does not use @@ -712,11 +712,140 @@ TEST_F(TestFixedSizeBinaryKernels, BinaryLength) { "[6, null, 6]"); } +TEST_F(TestFixedSizeBinaryKernels, BinarySliceEmpty) { + SliceOptions options{2, 4}; + CheckScalarUnary("binary_slice", ArrayFromJSON(fixed_size_binary(0), R"([""])"), + ArrayFromJSON(fixed_size_binary(0), R"([""])"), &options); + + CheckScalarUnary("binary_slice", + ArrayFromJSON(fixed_size_binary(0), R"(["", null, ""])"), + ArrayFromJSON(fixed_size_binary(0), R"(["", null, ""])"), &options); + + CheckUnary("binary_slice", R"([null, null])", fixed_size_binary(2), R"([null, null])", + &options); +} + +TEST_F(TestFixedSizeBinaryKernels, BinarySliceBasic) { + SliceOptions options{2, 4}; + CheckUnary("binary_slice", R"(["abcdef", null, "foobaz"])", fixed_size_binary(2), + R"(["cd", null, "ob"])", &options); + + SliceOptions options_edgecase_1{-3, 1}; + CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(0), + R"(["", ""])", &options_edgecase_1); + + SliceOptions options_edgecase_2{-10, -3}; + CheckUnary("binary_slice", R"(["abcdef", "foobaz", null])", fixed_size_binary(3), + R"(["abc", "foo", null])", &options_edgecase_2); + + auto input = ArrayFromJSON(this->type(), R"(["foobaz"])"); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + testing::HasSubstr("Function 'binary_slice' cannot be called without options"), + CallFunction("binary_slice", {input})); + + SliceOptions options_invalid{2, 4, 0}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("Slice step cannot be zero"), + CallFunction("binary_slice", {input}, &options_invalid)); +} + +TEST_F(TestFixedSizeBinaryKernels, BinarySlicePosPos) { + SliceOptions options_step{1, 5, 2}; + CheckUnary("binary_slice", R"([null, "abcdef", "foobaz"])", fixed_size_binary(2), + R"([null, "bd", "ob"])", &options_step); + + SliceOptions options_step_neg{5, 0, -2}; + CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3), + R"(["fdb", "zbo"])", &options_step_neg); +} + +TEST_F(TestFixedSizeBinaryKernels, BinarySlicePosNeg) { + SliceOptions options{2, -1}; + CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3), + R"(["cde", "oba"])", &options); + + SliceOptions options_step{1, -1, 2}; + CheckUnary("binary_slice", R"(["abcdef", null, "foobaz"])", fixed_size_binary(2), + R"(["bd", null, "ob"])", &options_step); + + SliceOptions options_step_neg{5, -4, -2}; + CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(2), + R"(["fd", "zb"])", &options_step_neg); + + options_step_neg.stop = -6; + CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3), + R"(["fdb", "zbo"])", &options_step_neg); +} + +TEST_F(TestFixedSizeBinaryKernels, BinarySliceNegNeg) { + SliceOptions options{-2, -1}; + CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(1), + R"(["e", "a"])", &options); + + SliceOptions options_step{-4, -1, 2}; + CheckUnary("binary_slice", R"(["abcdef", "foobaz", null, null])", fixed_size_binary(2), + R"(["ce", "oa", null, null])", &options_step); + + SliceOptions options_step_neg{-1, -3, -2}; + CheckUnary("binary_slice", R"([null, "abcdef", null, "foobaz"])", fixed_size_binary(1), + R"([null, "f", null, "z"])", &options_step_neg); + + options_step_neg.stop = -4; + CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(2), + R"(["fd", "zb"])", &options_step_neg); +} + +TEST_F(TestFixedSizeBinaryKernels, BinarySliceNegPos) { + SliceOptions options{-2, 4}; + CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(0), + R"(["", ""])", &options); + + SliceOptions options_step{-4, 5, 2}; + CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(2), + R"(["ce", "oa"])", &options_step); + + SliceOptions options_step_neg{-1, 1, -2}; + CheckUnary("binary_slice", R"([null, "abcdef", "foobaz", null])", fixed_size_binary(2), + R"([null, "fd", "zb", null])", &options_step_neg); + + options_step_neg.stop = 0; + CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3), + R"(["fdb", "zbo"])", &options_step_neg); +} + +TEST_F(TestFixedSizeBinaryKernels, BinarySliceConsistentyWithVarLenBinary) { + std::string source_str = "abcdef"; + for (size_t str_len = 0; str_len < source_str.size(); ++str_len) { + auto input_str = source_str.substr(0, str_len); + auto fixed_input = ArrayFromJSON(fixed_size_binary(static_cast(str_len)), + R"([")" + input_str + R"("])"); + auto varlen_input = ArrayFromJSON(binary(), R"([")" + input_str + R"("])"); + for (auto start = -6; start <= 6; ++start) { + for (auto stop = -6; stop <= 6; ++stop) { + for (auto step = -3; step <= 4; ++step) { + if (step == 0) { + continue; + } + SliceOptions options{start, stop, step}; + auto expected = + CallFunction("binary_slice", {varlen_input}, &options).ValueOrDie(); + auto actual = + CallFunction("binary_slice", {fixed_input}, &options).ValueOrDie(); + actual = Cast(actual, binary()).ValueOrDie(); + ASSERT_OK(actual.make_array()->ValidateFull()); + AssertDatumsEqual(expected, actual); + } + } + } + } +} + TEST_F(TestFixedSizeBinaryKernels, BinaryReplaceSlice) { ReplaceSliceOptions options{0, 1, "XX"}; CheckUnary("binary_replace_slice", "[]", fixed_size_binary(7), "[]", &options); - CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(7), - R"([null, "XXbcdef"])", &options); + CheckUnary("binary_replace_slice", R"(["foobaz", null, "abcdef"])", + fixed_size_binary(7), R"(["XXoobaz", null, "XXbcdef"])", &options); ReplaceSliceOptions options_shrink{0, 2, ""}; CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(4), @@ -731,8 +860,8 @@ TEST_F(TestFixedSizeBinaryKernels, BinaryReplaceSlice) { R"([null, "abXXef"])", &options_middle); ReplaceSliceOptions options_neg_start{-3, -2, "XX"}; - CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(7), - R"([null, "abcXXef"])", &options_neg_start); + CheckUnary("binary_replace_slice", R"(["foobaz", null, "abcdef"])", + fixed_size_binary(7), R"(["fooXXaz", null, "abcXXef"])", &options_neg_start); ReplaceSliceOptions options_neg_end{2, -2, "XX"}; CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(6), @@ -807,7 +936,7 @@ TEST_F(TestFixedSizeBinaryKernels, CountSubstringIgnoreCase) { offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 1]", &options); MatchSubstringOptions options_empty{"", /*ignore_case=*/true}; - CheckUnary("count_substring", R"([" ", null, "abcABc"])", offset_type(), + CheckUnary("count_substring", R"([" ", null, "abcdef"])", offset_type(), "[7, null, 7]", &options_empty); } @@ -2382,5 +2511,4 @@ TEST(TestStringKernels, UnicodeLibraryAssumptions) { } #endif -} // namespace compute -} // namespace arrow +} // namespace arrow::compute diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc index d4482334285bc..8dac6525fe2e6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc @@ -3665,5 +3665,17 @@ TEST_F(ScalarTemporalTest, TestCeilFloorRoundTemporalDate) { CheckScalarUnary("ceil_temporal", arr_ns, arr_ns, &round_to_2_hours); } +TEST_F(ScalarTemporalTest, DurationUnaryArithmetics) { + auto arr = ArrayFromJSON(duration(TimeUnit::SECOND), "[2, -1, null, 3, 0]"); + CheckScalarUnary("negate", arr, + ArrayFromJSON(duration(TimeUnit::SECOND), "[-2, 1, null, -3, 0]")); + CheckScalarUnary("negate_checked", arr, + ArrayFromJSON(duration(TimeUnit::SECOND), "[-2, 1, null, -3, 0]")); + CheckScalarUnary("abs", arr, + ArrayFromJSON(duration(TimeUnit::SECOND), "[2, 1, null, 3, 0]")); + CheckScalarUnary("abs_checked", arr, + ArrayFromJSON(duration(TimeUnit::SECOND), "[2, 1, null, 3, 0]")); + CheckScalarUnary("sign", arr, ArrayFromJSON(int8(), "[1, -1, null, 1, 0]")); +} } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc index a88ce389360f5..f49e201492c9b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc +++ b/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc @@ -1510,7 +1510,7 @@ struct ISOCalendar { for (int i = 0; i < 3; i++) { field_builders.push_back( checked_cast(struct_builder->field_builder(i))); - RETURN_NOT_OK(field_builders[i]->Reserve(1)); + RETURN_NOT_OK(field_builders[i]->Reserve(in.length)); } auto visit_null = [&]() { return struct_builder->AppendNull(); }; std::function visit_value; diff --git a/cpp/src/arrow/compute/kernels/scalar_validity.cc b/cpp/src/arrow/compute/kernels/scalar_validity.cc index 6b1cec0f5ccc6..8505fc4c6e0af 100644 --- a/cpp/src/arrow/compute/kernels/scalar_validity.cc +++ b/cpp/src/arrow/compute/kernels/scalar_validity.cc @@ -169,6 +169,7 @@ std::shared_ptr MakeIsFiniteFunction(std::string name, FunctionD func->AddKernel({InputType(Type::DECIMAL128)}, boolean(), ConstBoolExec)); DCHECK_OK( func->AddKernel({InputType(Type::DECIMAL256)}, boolean(), ConstBoolExec)); + DCHECK_OK(func->AddKernel({InputType(Type::DURATION)}, boolean(), ConstBoolExec)); return func; } @@ -187,7 +188,8 @@ std::shared_ptr MakeIsInfFunction(std::string name, FunctionDoc func->AddKernel({InputType(Type::DECIMAL128)}, boolean(), ConstBoolExec)); DCHECK_OK( func->AddKernel({InputType(Type::DECIMAL256)}, boolean(), ConstBoolExec)); - + DCHECK_OK( + func->AddKernel({InputType(Type::DURATION)}, boolean(), ConstBoolExec)); return func; } @@ -205,6 +207,8 @@ std::shared_ptr MakeIsNanFunction(std::string name, FunctionDoc func->AddKernel({InputType(Type::DECIMAL128)}, boolean(), ConstBoolExec)); DCHECK_OK( func->AddKernel({InputType(Type::DECIMAL256)}, boolean(), ConstBoolExec)); + DCHECK_OK( + func->AddKernel({InputType(Type::DURATION)}, boolean(), ConstBoolExec)); return func; } diff --git a/cpp/src/arrow/compute/kernels/scalar_validity_test.cc b/cpp/src/arrow/compute/kernels/scalar_validity_test.cc index 94d951c838209..d1462838f3be6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_validity_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_validity_test.cc @@ -103,6 +103,9 @@ TEST(TestValidityKernels, IsFinite) { } CheckScalar("is_finite", {std::make_shared(4)}, ArrayFromJSON(boolean(), "[null, null, null, null]")); + CheckScalar("is_finite", + {ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42, null]")}, + ArrayFromJSON(boolean(), "[true, true, true, null]")); } TEST(TestValidityKernels, IsInf) { @@ -116,6 +119,8 @@ TEST(TestValidityKernels, IsInf) { } CheckScalar("is_inf", {std::make_shared(4)}, ArrayFromJSON(boolean(), "[null, null, null, null]")); + CheckScalar("is_inf", {ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42, null]")}, + ArrayFromJSON(boolean(), "[false, false, false, null]")); } TEST(TestValidityKernels, IsNan) { @@ -129,6 +134,8 @@ TEST(TestValidityKernels, IsNan) { } CheckScalar("is_nan", {std::make_shared(4)}, ArrayFromJSON(boolean(), "[null, null, null, null]")); + CheckScalar("is_nan", {ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42, null]")}, + ArrayFromJSON(boolean(), "[false, false, false, null]")); } TEST(TestValidityKernels, IsValidIsNullNullType) { diff --git a/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc index 25e30e65a3526..e65d5dbcab1c9 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc @@ -128,6 +128,13 @@ struct TakeBenchmark { Bench(values); } + void FixedSizeBinary() { + const int32_t byte_width = static_cast(state.range(2)); + auto values = rand.FixedSizeBinary(args.size, byte_width, args.null_proportion); + Bench(values); + state.counters["byte_width"] = byte_width; + } + void String() { int32_t string_min_length = 0, string_max_length = 32; auto values = std::static_pointer_cast(rand.String( @@ -149,6 +156,7 @@ struct TakeBenchmark { for (auto _ : state) { ABORT_NOT_OK(Take(values, indices).status()); } + state.SetItemsProcessed(state.iterations() * values->length()); } }; @@ -166,8 +174,7 @@ struct FilterBenchmark { void Int64() { const int64_t array_size = args.size / sizeof(int64_t); - auto values = std::static_pointer_cast>( - rand.Int64(array_size, -100, 100, args.values_null_proportion)); + auto values = rand.Int64(array_size, -100, 100, args.values_null_proportion); Bench(values); } @@ -181,6 +188,15 @@ struct FilterBenchmark { Bench(values); } + void FixedSizeBinary() { + const int32_t byte_width = static_cast(state.range(2)); + const int64_t array_size = args.size / byte_width; + auto values = + rand.FixedSizeBinary(array_size, byte_width, args.values_null_proportion); + Bench(values); + state.counters["byte_width"] = byte_width; + } + void String() { int32_t string_min_length = 0, string_max_length = 32; int32_t string_mean_length = (string_max_length + string_min_length) / 2; @@ -202,6 +218,7 @@ struct FilterBenchmark { for (auto _ : state) { ABORT_NOT_OK(Filter(values, filter).status()); } + state.SetItemsProcessed(state.iterations() * values->length()); } void BenchRecordBatch() { @@ -236,6 +253,7 @@ struct FilterBenchmark { for (auto _ : state) { ABORT_NOT_OK(Filter(batch, filter).status()); } + state.SetItemsProcessed(state.iterations() * num_rows); } }; @@ -255,6 +273,14 @@ static void FilterFSLInt64FilterWithNulls(benchmark::State& state) { FilterBenchmark(state, true).FSLInt64(); } +static void FilterFixedSizeBinaryFilterNoNulls(benchmark::State& state) { + FilterBenchmark(state, false).FixedSizeBinary(); +} + +static void FilterFixedSizeBinaryFilterWithNulls(benchmark::State& state) { + FilterBenchmark(state, true).FixedSizeBinary(); +} + static void FilterStringFilterNoNulls(benchmark::State& state) { FilterBenchmark(state, false).String(); } @@ -283,6 +309,19 @@ static void TakeInt64MonotonicIndices(benchmark::State& state) { TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true).Int64(); } +static void TakeFixedSizeBinaryRandomIndicesNoNulls(benchmark::State& state) { + TakeBenchmark(state, false).FixedSizeBinary(); +} + +static void TakeFixedSizeBinaryRandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, true).FixedSizeBinary(); +} + +static void TakeFixedSizeBinaryMonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true) + .FixedSizeBinary(); +} + static void TakeFSLInt64RandomIndicesNoNulls(benchmark::State& state) { TakeBenchmark(state, false).FSLInt64(); } @@ -315,8 +354,22 @@ void FilterSetArgs(benchmark::internal::Benchmark* bench) { } } +void FilterFSBSetArgs(benchmark::internal::Benchmark* bench) { + for (int64_t size : g_data_sizes) { + for (int i = 0; i < static_cast(g_filter_params.size()); ++i) { + // FixedSizeBinary of primitive sizes (powers of two up to 32) + // have a faster path. + for (int32_t byte_width : {8, 9}) { + bench->Args({static_cast(size), i, byte_width}); + } + } + } +} + BENCHMARK(FilterInt64FilterNoNulls)->Apply(FilterSetArgs); BENCHMARK(FilterInt64FilterWithNulls)->Apply(FilterSetArgs); +BENCHMARK(FilterFixedSizeBinaryFilterNoNulls)->Apply(FilterFSBSetArgs); +BENCHMARK(FilterFixedSizeBinaryFilterWithNulls)->Apply(FilterFSBSetArgs); BENCHMARK(FilterFSLInt64FilterNoNulls)->Apply(FilterSetArgs); BENCHMARK(FilterFSLInt64FilterWithNulls)->Apply(FilterSetArgs); BENCHMARK(FilterStringFilterNoNulls)->Apply(FilterSetArgs); @@ -340,9 +393,24 @@ void TakeSetArgs(benchmark::internal::Benchmark* bench) { } } +void TakeFSBSetArgs(benchmark::internal::Benchmark* bench) { + for (int64_t size : g_data_sizes) { + for (auto nulls : std::vector({1000, 10, 2, 1, 0})) { + // FixedSizeBinary of primitive sizes (powers of two up to 32) + // have a faster path. + for (int32_t byte_width : {8, 9}) { + bench->Args({static_cast(size), nulls, byte_width}); + } + } + } +} + BENCHMARK(TakeInt64RandomIndicesNoNulls)->Apply(TakeSetArgs); BENCHMARK(TakeInt64RandomIndicesWithNulls)->Apply(TakeSetArgs); BENCHMARK(TakeInt64MonotonicIndices)->Apply(TakeSetArgs); +BENCHMARK(TakeFixedSizeBinaryRandomIndicesNoNulls)->Apply(TakeFSBSetArgs); +BENCHMARK(TakeFixedSizeBinaryRandomIndicesWithNulls)->Apply(TakeFSBSetArgs); +BENCHMARK(TakeFixedSizeBinaryMonotonicIndices)->Apply(TakeFSBSetArgs); BENCHMARK(TakeFSLInt64RandomIndicesNoNulls)->Apply(TakeSetArgs); BENCHMARK(TakeFSLInt64RandomIndicesWithNulls)->Apply(TakeSetArgs); BENCHMARK(TakeFSLInt64MonotonicIndices)->Apply(TakeSetArgs); diff --git a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc index a25b04ae4fa65..8825d697fdf77 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc @@ -146,36 +146,40 @@ class DropNullCounter { /// \brief The Filter implementation for primitive (fixed-width) types does not /// use the logical Arrow type but rather the physical C type. This way we only -/// generate one take function for each byte width. We use the same -/// implementation here for boolean and fixed-byte-size inputs with some -/// template specialization. -template +/// generate one take function for each byte width. +/// +/// We use compile-time specialization for two variations: +/// - operating on boolean data (using kIsBoolean = true) +/// - operating on fixed-width data of arbitrary width (using kByteWidth = -1), +/// with the actual width only known at runtime +template class PrimitiveFilterImpl { public: - using T = typename std::conditional::value, - uint8_t, typename ArrowType::c_type>::type; - PrimitiveFilterImpl(const ArraySpan& values, const ArraySpan& filter, FilterOptions::NullSelectionBehavior null_selection, ArrayData* out_arr) - : values_is_valid_(values.buffers[0].data), - values_data_(reinterpret_cast(values.buffers[1].data)), + : byte_width_(values.type->byte_width()), + values_is_valid_(values.buffers[0].data), + values_data_(values.buffers[1].data), values_null_count_(values.null_count), values_offset_(values.offset), values_length_(values.length), filter_(filter), null_selection_(null_selection) { - if (values.type->id() != Type::BOOL) { + if constexpr (kByteWidth >= 0 && !kIsBoolean) { + DCHECK_EQ(kByteWidth, byte_width_); + } + if constexpr (!kIsBoolean) { // No offset applied for boolean because it's a bitmap - values_data_ += values.offset; + values_data_ += values.offset * byte_width(); } if (out_arr->buffers[0] != nullptr) { // May be unallocated if neither filter nor values contain nulls out_is_valid_ = out_arr->buffers[0]->mutable_data(); } - out_data_ = reinterpret_cast(out_arr->buffers[1]->mutable_data()); - out_offset_ = out_arr->offset; + out_data_ = out_arr->buffers[1]->mutable_data(); + DCHECK_EQ(out_arr->offset, 0); out_length_ = out_arr->length; out_position_ = 0; } @@ -201,14 +205,11 @@ class PrimitiveFilterImpl { [&](int64_t position, int64_t segment_length, bool filter_valid) { if (filter_valid) { CopyBitmap(values_is_valid_, values_offset_ + position, segment_length, - out_is_valid_, out_offset_ + out_position_); + out_is_valid_, out_position_); WriteValueSegment(position, segment_length); } else { - bit_util::SetBitsTo(out_is_valid_, out_offset_ + out_position_, - segment_length, false); - memset(out_data_ + out_offset_ + out_position_, 0, - segment_length * sizeof(T)); - out_position_ += segment_length; + bit_util::SetBitsTo(out_is_valid_, out_position_, segment_length, false); + WriteNullSegment(segment_length); } return true; }); @@ -218,7 +219,7 @@ class PrimitiveFilterImpl { if (out_is_valid_) { // Set all to valid, so only if nulls are produced by EMIT_NULL, we need // to set out_is_valid[i] to false. - bit_util::SetBitsTo(out_is_valid_, out_offset_, out_length_, true); + bit_util::SetBitsTo(out_is_valid_, 0, out_length_, true); } return VisitPlainxREEFilterOutputSegments( filter_, /*filter_may_have_nulls=*/true, null_selection_, @@ -226,11 +227,8 @@ class PrimitiveFilterImpl { if (filter_valid) { WriteValueSegment(position, segment_length); } else { - bit_util::SetBitsTo(out_is_valid_, out_offset_ + out_position_, - segment_length, false); - memset(out_data_ + out_offset_ + out_position_, 0, - segment_length * sizeof(T)); - out_position_ += segment_length; + bit_util::SetBitsTo(out_is_valid_, out_position_, segment_length, false); + WriteNullSegment(segment_length); } return true; }); @@ -260,13 +258,13 @@ class PrimitiveFilterImpl { values_length_); auto WriteNotNull = [&](int64_t index) { - bit_util::SetBit(out_is_valid_, out_offset_ + out_position_); + bit_util::SetBit(out_is_valid_, out_position_); // Increments out_position_ WriteValue(index); }; auto WriteMaybeNull = [&](int64_t index) { - bit_util::SetBitTo(out_is_valid_, out_offset_ + out_position_, + bit_util::SetBitTo(out_is_valid_, out_position_, bit_util::GetBit(values_is_valid_, values_offset_ + index)); // Increments out_position_ WriteValue(index); @@ -279,15 +277,14 @@ class PrimitiveFilterImpl { BitBlockCount data_block = data_counter.NextWord(); if (filter_block.AllSet() && data_block.AllSet()) { // Fastest path: all values in block are included and not null - bit_util::SetBitsTo(out_is_valid_, out_offset_ + out_position_, - filter_block.length, true); + bit_util::SetBitsTo(out_is_valid_, out_position_, filter_block.length, true); WriteValueSegment(in_position, filter_block.length); in_position += filter_block.length; } else if (filter_block.AllSet()) { // Faster: all values are selected, but some values are null // Batch copy bits from values validity bitmap to output validity bitmap CopyBitmap(values_is_valid_, values_offset_ + in_position, filter_block.length, - out_is_valid_, out_offset_ + out_position_); + out_is_valid_, out_position_); WriteValueSegment(in_position, filter_block.length); in_position += filter_block.length; } else if (filter_block.NoneSet() && null_selection_ == FilterOptions::DROP) { @@ -326,7 +323,7 @@ class PrimitiveFilterImpl { WriteNotNull(in_position); } else if (!is_valid) { // Filter slot is null, so we have a null in the output - bit_util::ClearBit(out_is_valid_, out_offset_ + out_position_); + bit_util::ClearBit(out_is_valid_, out_position_); WriteNull(); } ++in_position; @@ -362,7 +359,7 @@ class PrimitiveFilterImpl { WriteMaybeNull(in_position); } else if (!is_valid) { // Filter slot is null, so we have a null in the output - bit_util::ClearBit(out_is_valid_, out_offset_ + out_position_); + bit_util::ClearBit(out_is_valid_, out_position_); WriteNull(); } ++in_position; @@ -376,54 +373,72 @@ class PrimitiveFilterImpl { // Write the next out_position given the selected in_position for the input // data and advance out_position void WriteValue(int64_t in_position) { - out_data_[out_offset_ + out_position_++] = values_data_[in_position]; + if constexpr (kIsBoolean) { + bit_util::SetBitTo(out_data_, out_position_, + bit_util::GetBit(values_data_, values_offset_ + in_position)); + } else { + memcpy(out_data_ + out_position_ * byte_width(), + values_data_ + in_position * byte_width(), byte_width()); + } + ++out_position_; } void WriteValueSegment(int64_t in_start, int64_t length) { - std::memcpy(out_data_ + out_position_, values_data_ + in_start, length * sizeof(T)); + if constexpr (kIsBoolean) { + CopyBitmap(values_data_, values_offset_ + in_start, length, out_data_, + out_position_); + } else { + memcpy(out_data_ + out_position_ * byte_width(), + values_data_ + in_start * byte_width(), length * byte_width()); + } out_position_ += length; } void WriteNull() { - // Zero the memory - out_data_[out_offset_ + out_position_++] = T{}; + if constexpr (kIsBoolean) { + // Zero the bit + bit_util::ClearBit(out_data_, out_position_); + } else { + // Zero the memory + memset(out_data_ + out_position_ * byte_width(), 0, byte_width()); + } + ++out_position_; + } + + void WriteNullSegment(int64_t length) { + if constexpr (kIsBoolean) { + // Zero the bits + bit_util::SetBitsTo(out_data_, out_position_, length, false); + } else { + // Zero the memory + memset(out_data_ + out_position_ * byte_width(), 0, length * byte_width()); + } + out_position_ += length; + } + + constexpr int32_t byte_width() const { + if constexpr (kByteWidth >= 0) { + return kByteWidth; + } else { + return byte_width_; + } } private: + int32_t byte_width_; const uint8_t* values_is_valid_; - const T* values_data_; + const uint8_t* values_data_; int64_t values_null_count_; int64_t values_offset_; int64_t values_length_; const ArraySpan& filter_; FilterOptions::NullSelectionBehavior null_selection_; uint8_t* out_is_valid_ = NULLPTR; - T* out_data_; - int64_t out_offset_; + uint8_t* out_data_; int64_t out_length_; int64_t out_position_; }; -template <> -inline void PrimitiveFilterImpl::WriteValue(int64_t in_position) { - bit_util::SetBitTo(out_data_, out_offset_ + out_position_++, - bit_util::GetBit(values_data_, values_offset_ + in_position)); -} - -template <> -inline void PrimitiveFilterImpl::WriteValueSegment(int64_t in_start, - int64_t length) { - CopyBitmap(values_data_, values_offset_ + in_start, length, out_data_, - out_offset_ + out_position_); - out_position_ += length; -} - -template <> -inline void PrimitiveFilterImpl::WriteNull() { - // Zero the bit - bit_util::ClearBit(out_data_, out_offset_ + out_position_++); -} - Status PrimitiveFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { const ArraySpan& values = batch[0].array; const ArraySpan& filter = batch[1].array; @@ -459,22 +474,32 @@ Status PrimitiveFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult switch (bit_width) { case 1: - PrimitiveFilterImpl(values, filter, null_selection, out_arr).Exec(); + PrimitiveFilterImpl<1, /*kIsBoolean=*/true>(values, filter, null_selection, out_arr) + .Exec(); break; case 8: - PrimitiveFilterImpl(values, filter, null_selection, out_arr).Exec(); + PrimitiveFilterImpl<1>(values, filter, null_selection, out_arr).Exec(); break; case 16: - PrimitiveFilterImpl(values, filter, null_selection, out_arr).Exec(); + PrimitiveFilterImpl<2>(values, filter, null_selection, out_arr).Exec(); break; case 32: - PrimitiveFilterImpl(values, filter, null_selection, out_arr).Exec(); + PrimitiveFilterImpl<4>(values, filter, null_selection, out_arr).Exec(); break; case 64: - PrimitiveFilterImpl(values, filter, null_selection, out_arr).Exec(); + PrimitiveFilterImpl<8>(values, filter, null_selection, out_arr).Exec(); + break; + case 128: + // For INTERVAL_MONTH_DAY_NANO, DECIMAL128 + PrimitiveFilterImpl<16>(values, filter, null_selection, out_arr).Exec(); + break; + case 256: + // For DECIMAL256 + PrimitiveFilterImpl<32>(values, filter, null_selection, out_arr).Exec(); break; default: - DCHECK(false) << "Invalid values bit width"; + // Non-specializing on byte width + PrimitiveFilterImpl<-1>(values, filter, null_selection, out_arr).Exec(); break; } return Status::OK(); @@ -1050,10 +1075,10 @@ void PopulateFilterKernels(std::vector* out) { {InputType(match::Primitive()), plain_filter, PrimitiveFilterExec}, {InputType(match::BinaryLike()), plain_filter, BinaryFilterExec}, {InputType(match::LargeBinaryLike()), plain_filter, BinaryFilterExec}, - {InputType(Type::FIXED_SIZE_BINARY), plain_filter, FSBFilterExec}, {InputType(null()), plain_filter, NullFilterExec}, - {InputType(Type::DECIMAL128), plain_filter, FSBFilterExec}, - {InputType(Type::DECIMAL256), plain_filter, FSBFilterExec}, + {InputType(Type::FIXED_SIZE_BINARY), plain_filter, PrimitiveFilterExec}, + {InputType(Type::DECIMAL128), plain_filter, PrimitiveFilterExec}, + {InputType(Type::DECIMAL256), plain_filter, PrimitiveFilterExec}, {InputType(Type::DICTIONARY), plain_filter, DictionaryFilterExec}, {InputType(Type::EXTENSION), plain_filter, ExtensionFilterExec}, {InputType(Type::LIST), plain_filter, ListFilterExec}, @@ -1068,10 +1093,10 @@ void PopulateFilterKernels(std::vector* out) { {InputType(match::Primitive()), ree_filter, PrimitiveFilterExec}, {InputType(match::BinaryLike()), ree_filter, BinaryFilterExec}, {InputType(match::LargeBinaryLike()), ree_filter, BinaryFilterExec}, - {InputType(Type::FIXED_SIZE_BINARY), ree_filter, FSBFilterExec}, {InputType(null()), ree_filter, NullFilterExec}, - {InputType(Type::DECIMAL128), ree_filter, FSBFilterExec}, - {InputType(Type::DECIMAL256), ree_filter, FSBFilterExec}, + {InputType(Type::FIXED_SIZE_BINARY), ree_filter, PrimitiveFilterExec}, + {InputType(Type::DECIMAL128), ree_filter, PrimitiveFilterExec}, + {InputType(Type::DECIMAL256), ree_filter, PrimitiveFilterExec}, {InputType(Type::DICTIONARY), ree_filter, DictionaryFilterExec}, {InputType(Type::EXTENSION), ree_filter, ExtensionFilterExec}, {InputType(Type::LIST), ree_filter, ListFilterExec}, diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc index 98eb37e9c5fd2..a0fe2808e3e4e 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc @@ -77,7 +77,8 @@ Status PreallocatePrimitiveArrayData(KernelContext* ctx, int64_t length, int bit if (bit_width == 1) { ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->AllocateBitmap(length)); } else { - ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->Allocate(length * bit_width / 8)); + ARROW_ASSIGN_OR_RAISE(out->buffers[1], + ctx->Allocate(bit_util::BytesForBits(length * bit_width))); } return Status::OK(); } @@ -899,10 +900,6 @@ Status FilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { } // namespace -Status FSBFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - return FilterExec(ctx, batch, out); -} - Status ListFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { return FilterExec>(ctx, batch, out); } @@ -946,7 +943,20 @@ Status LargeVarBinaryTakeExec(KernelContext* ctx, const ExecSpan& batch, } Status FSBTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - return TakeExec(ctx, batch, out); + const ArraySpan& values = batch[0].array; + const auto byte_width = values.type->byte_width(); + // Use primitive Take implementation (presumably faster) for some byte widths + switch (byte_width) { + case 1: + case 2: + case 4: + case 8: + case 16: + case 32: + return PrimitiveTakeExec(ctx, batch, out); + default: + return TakeExec(ctx, batch, out); + } } Status ListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.h b/cpp/src/arrow/compute/kernels/vector_selection_internal.h index b9eba6ea6631f..95f3e51cd67e3 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.h @@ -70,7 +70,6 @@ void VisitPlainxREEFilterOutputSegments( FilterOptions::NullSelectionBehavior null_selection, const EmitREEFilterSegment& emit_segment); -Status FSBFilterExec(KernelContext*, const ExecSpan&, ExecResult*); Status ListFilterExec(KernelContext*, const ExecSpan&, ExecResult*); Status LargeListFilterExec(KernelContext*, const ExecSpan&, ExecResult*); Status FSLFilterExec(KernelContext*, const ExecSpan&, ExecResult*); @@ -79,6 +78,7 @@ Status MapFilterExec(KernelContext*, const ExecSpan&, ExecResult*); Status VarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status LargeVarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*); +Status PrimitiveTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status FSBTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*); diff --git a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc index 612de8505d3ab..89b3f7d0d3c58 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc @@ -334,11 +334,15 @@ using TakeState = OptionsWrapper; /// only generate one take function for each byte width. /// /// This function assumes that the indices have been boundschecked. -template +template struct PrimitiveTakeImpl { + static constexpr int kValueWidth = ValueWidthConstant::value; + static void Exec(const ArraySpan& values, const ArraySpan& indices, ArrayData* out_arr) { - const auto* values_data = values.GetValues(1); + DCHECK_EQ(values.type->byte_width(), kValueWidth); + const auto* values_data = + values.GetValues(1, 0) + kValueWidth * values.offset; const uint8_t* values_is_valid = values.buffers[0].data; auto values_offset = values.offset; @@ -346,9 +350,10 @@ struct PrimitiveTakeImpl { const uint8_t* indices_is_valid = indices.buffers[0].data; auto indices_offset = indices.offset; - auto out = out_arr->GetMutableValues(1); + auto out = out_arr->GetMutableValues(1, 0) + kValueWidth * out_arr->offset; auto out_is_valid = out_arr->buffers[0]->mutable_data(); auto out_offset = out_arr->offset; + DCHECK_EQ(out_offset, 0); // If either the values or indices have nulls, we preemptively zero out the // out validity bitmap so that we don't have to use ClearBit in each @@ -357,6 +362,19 @@ struct PrimitiveTakeImpl { bit_util::SetBitsTo(out_is_valid, out_offset, indices.length, false); } + auto WriteValue = [&](int64_t position) { + memcpy(out + position * kValueWidth, + values_data + indices_data[position] * kValueWidth, kValueWidth); + }; + + auto WriteZero = [&](int64_t position) { + memset(out + position * kValueWidth, 0, kValueWidth); + }; + + auto WriteZeroSegment = [&](int64_t position, int64_t length) { + memset(out + position * kValueWidth, 0, kValueWidth * length); + }; + OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, indices.length); int64_t position = 0; @@ -370,7 +388,7 @@ struct PrimitiveTakeImpl { // Fastest path: neither values nor index nulls bit_util::SetBitsTo(out_is_valid, out_offset + position, block.length, true); for (int64_t i = 0; i < block.length; ++i) { - out[position] = values_data[indices_data[position]]; + WriteValue(position); ++position; } } else if (block.popcount > 0) { @@ -379,14 +397,14 @@ struct PrimitiveTakeImpl { if (bit_util::GetBit(indices_is_valid, indices_offset + position)) { // index is not null bit_util::SetBit(out_is_valid, out_offset + position); - out[position] = values_data[indices_data[position]]; + WriteValue(position); } else { - out[position] = ValueCType{}; + WriteZero(position); } ++position; } } else { - memset(out + position, 0, sizeof(ValueCType) * block.length); + WriteZeroSegment(position, block.length); position += block.length; } } else { @@ -397,11 +415,11 @@ struct PrimitiveTakeImpl { if (bit_util::GetBit(values_is_valid, values_offset + indices_data[position])) { // value is not null - out[position] = values_data[indices_data[position]]; + WriteValue(position); bit_util::SetBit(out_is_valid, out_offset + position); ++valid_count; } else { - out[position] = ValueCType{}; + WriteZero(position); } ++position; } @@ -414,16 +432,16 @@ struct PrimitiveTakeImpl { bit_util::GetBit(values_is_valid, values_offset + indices_data[position])) { // index is not null && value is not null - out[position] = values_data[indices_data[position]]; + WriteValue(position); bit_util::SetBit(out_is_valid, out_offset + position); ++valid_count; } else { - out[position] = ValueCType{}; + WriteZero(position); } ++position; } } else { - memset(out + position, 0, sizeof(ValueCType) * block.length); + WriteZeroSegment(position, block.length); position += block.length; } } @@ -554,6 +572,8 @@ void TakeIndexDispatch(const ArraySpan& values, const ArraySpan& indices, } } +} // namespace + Status PrimitiveTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { const ArraySpan& values = batch[0].array; const ArraySpan& indices = batch[1].array; @@ -577,24 +597,40 @@ Status PrimitiveTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* TakeIndexDispatch(values, indices, out_arr); break; case 8: - TakeIndexDispatch(values, indices, out_arr); + TakeIndexDispatch>( + values, indices, out_arr); break; case 16: - TakeIndexDispatch(values, indices, out_arr); + TakeIndexDispatch>( + values, indices, out_arr); break; case 32: - TakeIndexDispatch(values, indices, out_arr); + TakeIndexDispatch>( + values, indices, out_arr); break; case 64: - TakeIndexDispatch(values, indices, out_arr); + TakeIndexDispatch>( + values, indices, out_arr); break; - default: - DCHECK(false) << "Invalid values byte width"; + case 128: + // For INTERVAL_MONTH_DAY_NANO, DECIMAL128 + TakeIndexDispatch>( + values, indices, out_arr); + break; + case 256: + // For DECIMAL256 + TakeIndexDispatch>( + values, indices, out_arr); break; + default: + return Status::NotImplemented("Unsupported primitive type for take: ", + *values.type); } return Status::OK(); } +namespace { + // ---------------------------------------------------------------------- // Null take @@ -836,8 +872,8 @@ void PopulateTakeKernels(std::vector* out) { {InputType(match::LargeBinaryLike()), take_indices, LargeVarBinaryTakeExec}, {InputType(Type::FIXED_SIZE_BINARY), take_indices, FSBTakeExec}, {InputType(null()), take_indices, NullTakeExec}, - {InputType(Type::DECIMAL128), take_indices, FSBTakeExec}, - {InputType(Type::DECIMAL256), take_indices, FSBTakeExec}, + {InputType(Type::DECIMAL128), take_indices, PrimitiveTakeExec}, + {InputType(Type::DECIMAL256), take_indices, PrimitiveTakeExec}, {InputType(Type::DICTIONARY), take_indices, DictionaryTake}, {InputType(Type::EXTENSION), take_indices, ExtensionTake}, {InputType(Type::LIST), take_indices, ListTakeExec}, diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc index bdf9f5454fdef..ec94b328ea361 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc @@ -309,6 +309,33 @@ class TestFilterKernel : public ::testing::Test { AssertFilter(values_array, ree_filter, expected_array); } + void TestNumericBasics(const std::shared_ptr& type) { + ARROW_SCOPED_TRACE("type = ", *type); + AssertFilter(type, "[]", "[]", "[]"); + + AssertFilter(type, "[9]", "[0]", "[]"); + AssertFilter(type, "[9]", "[1]", "[9]"); + AssertFilter(type, "[9]", "[null]", "[null]"); + AssertFilter(type, "[null]", "[0]", "[]"); + AssertFilter(type, "[null]", "[1]", "[null]"); + AssertFilter(type, "[null]", "[null]", "[null]"); + + AssertFilter(type, "[7, 8, 9]", "[0, 1, 0]", "[8]"); + AssertFilter(type, "[7, 8, 9]", "[1, 0, 1]", "[7, 9]"); + AssertFilter(type, "[null, 8, 9]", "[0, 1, 0]", "[8]"); + AssertFilter(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8]"); + AssertFilter(type, "[7, 8, 9]", "[1, null, 1]", "[7, null, 9]"); + + AssertFilter(ArrayFromJSON(type, "[7, 8, 9]"), + ArrayFromJSON(boolean(), "[0, 1, 1, 1, 0, 1]")->Slice(3, 3), + ArrayFromJSON(type, "[7, 9]")); + + ASSERT_RAISES(Invalid, Filter(ArrayFromJSON(type, "[7, 8, 9]"), + ArrayFromJSON(boolean(), "[]"), emit_null_)); + ASSERT_RAISES(Invalid, Filter(ArrayFromJSON(type, "[7, 8, 9]"), + ArrayFromJSON(boolean(), "[]"), drop_)); + } + const FilterOptions emit_null_, drop_; }; @@ -342,6 +369,33 @@ void ValidateFilter(const std::shared_ptr& values, /*verbose=*/true); } +TEST_F(TestFilterKernel, Temporal) { + this->TestNumericBasics(time32(TimeUnit::MILLI)); + this->TestNumericBasics(time64(TimeUnit::MICRO)); + this->TestNumericBasics(timestamp(TimeUnit::NANO, "Europe/Paris")); + this->TestNumericBasics(duration(TimeUnit::SECOND)); + this->TestNumericBasics(date32()); + this->AssertFilter(date64(), "[0, 86400000, null]", "[null, 1, 0]", "[null, 86400000]"); +} + +TEST_F(TestFilterKernel, Duration) { + for (auto type : DurationTypes()) { + this->TestNumericBasics(type); + } +} + +TEST_F(TestFilterKernel, Interval) { + this->TestNumericBasics(month_interval()); + + auto type = day_time_interval(); + this->AssertFilter(type, "[[1, -600], [2, 3000], null]", "[null, 1, 0]", + "[null, [2, 3000]]"); + type = month_day_nano_interval(); + this->AssertFilter(type, + "[[1, -2, 34567890123456789], [2, 3, -34567890123456789], null]", + "[null, 1, 0]", "[null, [2, 3, -34567890123456789]]"); +} + class TestFilterKernelWithNull : public TestFilterKernel { protected: void AssertFilter(const std::string& values, const std::string& filter, @@ -401,30 +455,7 @@ class TestFilterKernelWithNumeric : public TestFilterKernel { TYPED_TEST_SUITE(TestFilterKernelWithNumeric, NumericArrowTypes); TYPED_TEST(TestFilterKernelWithNumeric, FilterNumeric) { - auto type = this->type_singleton(); - this->AssertFilter(type, "[]", "[]", "[]"); - - this->AssertFilter(type, "[9]", "[0]", "[]"); - this->AssertFilter(type, "[9]", "[1]", "[9]"); - this->AssertFilter(type, "[9]", "[null]", "[null]"); - this->AssertFilter(type, "[null]", "[0]", "[]"); - this->AssertFilter(type, "[null]", "[1]", "[null]"); - this->AssertFilter(type, "[null]", "[null]", "[null]"); - - this->AssertFilter(type, "[7, 8, 9]", "[0, 1, 0]", "[8]"); - this->AssertFilter(type, "[7, 8, 9]", "[1, 0, 1]", "[7, 9]"); - this->AssertFilter(type, "[null, 8, 9]", "[0, 1, 0]", "[8]"); - this->AssertFilter(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8]"); - this->AssertFilter(type, "[7, 8, 9]", "[1, null, 1]", "[7, null, 9]"); - - this->AssertFilter(ArrayFromJSON(type, "[7, 8, 9]"), - ArrayFromJSON(boolean(), "[0, 1, 1, 1, 0, 1]")->Slice(3, 3), - ArrayFromJSON(type, "[7, 9]")); - - ASSERT_RAISES(Invalid, Filter(ArrayFromJSON(type, "[7, 8, 9]"), - ArrayFromJSON(boolean(), "[]"), this->emit_null_)); - ASSERT_RAISES(Invalid, Filter(ArrayFromJSON(type, "[7, 8, 9]"), - ArrayFromJSON(boolean(), "[]"), this->drop_)); + this->TestNumericBasics(this->type_singleton()); } template @@ -588,7 +619,7 @@ TYPED_TEST(TestFilterKernelWithDecimal, FilterNumeric) { ArrayFromJSON(boolean(), "[]"), this->drop_)); } -TEST(TestFilterKernel, NoValidityBitmapButUnknownNullCount) { +TEST_F(TestFilterKernel, NoValidityBitmapButUnknownNullCount) { auto values = ArrayFromJSON(int32(), "[1, 2, 3, 4]"); auto filter = ArrayFromJSON(boolean(), "[true, true, false, true]"); @@ -1136,6 +1167,20 @@ class TestTakeKernel : public ::testing::Test { TestNoValidityBitmapButUnknownNullCount(ArrayFromJSON(type, values), ArrayFromJSON(int16(), indices)); } + + void TestNumericBasics(const std::shared_ptr& type) { + ARROW_SCOPED_TRACE("type = ", *type); + CheckTake(type, "[7, 8, 9]", "[]", "[]"); + CheckTake(type, "[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]"); + CheckTake(type, "[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]"); + CheckTake(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]"); + CheckTake(type, "[null, 8, 9]", "[]", "[]"); + CheckTake(type, "[7, 8, 9]", "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7, 7, 9]"); + + std::shared_ptr arr; + ASSERT_RAISES(IndexError, TakeJSON(type, "[7, 8, 9]", int8(), "[0, 9, 0]", &arr)); + ASSERT_RAISES(IndexError, TakeJSON(type, "[7, 8, 9]", int8(), "[0, -1, 0]", &arr)); + } }; template @@ -1201,6 +1246,34 @@ TEST_F(TestTakeKernel, TakeBoolean) { TakeJSON(boolean(), "[true, false, true]", int8(), "[0, -1, 0]", &arr)); } +TEST_F(TestTakeKernel, Temporal) { + this->TestNumericBasics(time32(TimeUnit::MILLI)); + this->TestNumericBasics(time64(TimeUnit::MICRO)); + this->TestNumericBasics(timestamp(TimeUnit::NANO, "Europe/Paris")); + this->TestNumericBasics(duration(TimeUnit::SECOND)); + this->TestNumericBasics(date32()); + CheckTake(date64(), "[0, 86400000, null]", "[null, 1, 1, 0]", + "[null, 86400000, 86400000, 0]"); +} + +TEST_F(TestTakeKernel, Duration) { + for (auto type : DurationTypes()) { + this->TestNumericBasics(type); + } +} + +TEST_F(TestTakeKernel, Interval) { + this->TestNumericBasics(month_interval()); + + auto type = day_time_interval(); + CheckTake(type, "[[1, -600], [2, 3000], null]", "[0, null, 2, 1]", + "[[1, -600], null, null, [2, 3000]]"); + type = month_day_nano_interval(); + CheckTake(type, "[[1, -2, 34567890123456789], [2, 3, -34567890123456789], null]", + "[0, null, 2, 1]", + "[[1, -2, 34567890123456789], null, null, [2, 3, -34567890123456789]]"); +} + template class TestTakeKernelWithNumeric : public TestTakeKernelTyped { protected: @@ -1216,18 +1289,7 @@ class TestTakeKernelWithNumeric : public TestTakeKernelTyped { TYPED_TEST_SUITE(TestTakeKernelWithNumeric, NumericArrowTypes); TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { - this->AssertTake("[7, 8, 9]", "[]", "[]"); - this->AssertTake("[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]"); - this->AssertTake("[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]"); - this->AssertTake("[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]"); - this->AssertTake("[null, 8, 9]", "[]", "[]"); - this->AssertTake("[7, 8, 9]", "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7, 7, 9]"); - - std::shared_ptr arr; - ASSERT_RAISES(IndexError, - TakeJSON(this->type_singleton(), "[7, 8, 9]", int8(), "[0, 9, 0]", &arr)); - ASSERT_RAISES(IndexError, TakeJSON(this->type_singleton(), "[7, 8, 9]", int8(), - "[0, -1, 0]", &arr)); + this->TestNumericBasics(this->type_singleton()); } template @@ -1816,6 +1878,7 @@ TEST(TestTakeMetaFunction, ArityChecking) { template struct FilterRandomTest { static void Test(const std::shared_ptr& type) { + ARROW_SCOPED_TRACE("type = ", *type); auto rand = random::RandomArrayGenerator(kRandomSeed); const int64_t length = static_cast(1ULL << 10); for (auto null_probability : {0.0, 0.01, 0.1, 0.999, 1.0}) { @@ -1856,6 +1919,7 @@ void CheckTakeRandom(const std::shared_ptr& values, int64_t indices_lengt template struct TakeRandomTest { static void Test(const std::shared_ptr& type) { + ARROW_SCOPED_TRACE("type = ", *type); auto rand = random::RandomArrayGenerator(kRandomSeed); const int64_t values_length = 64 * 16 + 1; const int64_t indices_length = 64 * 4 + 1; @@ -1897,8 +1961,10 @@ TEST(TestFilter, RandomString) { } TEST(TestFilter, RandomFixedSizeBinary) { - FilterRandomTest<>::Test(fixed_size_binary(0)); - FilterRandomTest<>::Test(fixed_size_binary(16)); + // FixedSizeBinary filter is special-cased for some widths + for (int32_t width : {0, 1, 16, 32, 35}) { + FilterRandomTest<>::Test(fixed_size_binary(width)); + } } TEST(TestTake, PrimitiveRandom) { TestRandomPrimitiveCTypes(); } @@ -1911,8 +1977,10 @@ TEST(TestTake, RandomString) { } TEST(TestTake, RandomFixedSizeBinary) { - TakeRandomTest::Test(fixed_size_binary(0)); - TakeRandomTest::Test(fixed_size_binary(16)); + // FixedSizeBinary take is special-cased for some widths + for (int32_t width : {0, 1, 16, 32, 35}) { + TakeRandomTest::Test(fixed_size_binary(width)); + } } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index e08a2bc10372f..d3914173b65aa 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -24,6 +24,7 @@ namespace arrow { using internal::checked_cast; +using internal::ChunkLocation; namespace compute { namespace internal { @@ -748,11 +749,15 @@ class TableSorter { auto& comparator = comparator_; const auto& first_sort_key = sort_keys_[0]; + ChunkLocation left_loc{0, 0}; + ChunkLocation right_loc{0, 0}; std::merge(nulls_begin, nulls_middle, nulls_middle, nulls_end, temp_indices, [&](uint64_t left, uint64_t right) { // First column is either null or nan - const auto left_loc = left_resolver_.Resolve(left); - const auto right_loc = right_resolver_.Resolve(right); + left_loc = + left_resolver_.ResolveWithChunkIndexHint(left, left_loc.chunk_index); + right_loc = right_resolver_.ResolveWithChunkIndexHint( + right, right_loc.chunk_index); auto chunk_left = first_sort_key.GetChunk(left_loc); auto chunk_right = first_sort_key.GetChunk(right_loc); const auto left_is_null = chunk_left.IsNull(); @@ -783,11 +788,15 @@ class TableSorter { // Untyped implementation auto& comparator = comparator_; + ChunkLocation left_loc{0, 0}; + ChunkLocation right_loc{0, 0}; std::merge(nulls_begin, nulls_middle, nulls_middle, nulls_end, temp_indices, [&](uint64_t left, uint64_t right) { // First column is always null - const auto left_loc = left_resolver_.Resolve(left); - const auto right_loc = right_resolver_.Resolve(right); + left_loc = + left_resolver_.ResolveWithChunkIndexHint(left, left_loc.chunk_index); + right_loc = right_resolver_.ResolveWithChunkIndexHint( + right, right_loc.chunk_index); return comparator.Compare(left_loc, right_loc, 1); }); // Copy back temp area into main buffer @@ -807,11 +816,15 @@ class TableSorter { auto& comparator = comparator_; const auto& first_sort_key = sort_keys_[0]; + ChunkLocation left_loc{0, 0}; + ChunkLocation right_loc{0, 0}; std::merge(range_begin, range_middle, range_middle, range_end, temp_indices, [&](uint64_t left, uint64_t right) { // Both values are never null nor NaN. - const auto left_loc = left_resolver_.Resolve(left); - const auto right_loc = right_resolver_.Resolve(right); + left_loc = + left_resolver_.ResolveWithChunkIndexHint(left, left_loc.chunk_index); + right_loc = right_resolver_.ResolveWithChunkIndexHint( + right, right_loc.chunk_index); auto chunk_left = first_sort_key.GetChunk(left_loc); auto chunk_right = first_sort_key.GetChunk(right_loc); DCHECK(!chunk_left.IsNull()); diff --git a/cpp/src/arrow/compute/key_hash.cc b/cpp/src/arrow/compute/key_hash.cc index f5867b405ec71..1902b9ce9a88e 100644 --- a/cpp/src/arrow/compute/key_hash.cc +++ b/cpp/src/arrow/compute/key_hash.cc @@ -105,23 +105,23 @@ inline void Hashing32::StripeMask(int i, uint32_t* mask1, uint32_t* mask2, } template -void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_t* keys, - uint32_t* hashes) { +void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t key_length, + const uint8_t* keys, uint32_t* hashes) { // Calculate the number of rows that skip the last 16 bytes // uint32_t num_rows_safe = num_rows; - while (num_rows_safe > 0 && (num_rows - num_rows_safe) * length < kStripeSize) { + while (num_rows_safe > 0 && (num_rows - num_rows_safe) * key_length < kStripeSize) { --num_rows_safe; } // Compute masks for the last 16 byte stripe // - uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize); + uint64_t num_stripes = bit_util::CeilDiv(key_length, kStripeSize); uint32_t mask1, mask2, mask3, mask4; - StripeMask(((length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4); + StripeMask(((key_length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4); for (uint32_t i = 0; i < num_rows_safe; ++i) { - const uint8_t* key = keys + static_cast(i) * length; + const uint8_t* key = keys + static_cast(i) * key_length; uint32_t acc1, acc2, acc3, acc4; ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4); ProcessLastStripe(mask1, mask2, mask3, mask4, key + (num_stripes - 1) * kStripeSize, @@ -138,11 +138,11 @@ void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_ uint32_t last_stripe_copy[4]; for (uint32_t i = num_rows_safe; i < num_rows; ++i) { - const uint8_t* key = keys + static_cast(i) * length; + const uint8_t* key = keys + static_cast(i) * key_length; uint32_t acc1, acc2, acc3, acc4; ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4); memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize, - length - (num_stripes - 1) * kStripeSize); + key_length - (num_stripes - 1) * kStripeSize); ProcessLastStripe(mask1, mask2, mask3, mask4, reinterpret_cast(last_stripe_copy), &acc1, &acc2, &acc3, &acc4); @@ -168,15 +168,16 @@ void Hashing32::HashVarLenImp(uint32_t num_rows, const T* offsets, } for (uint32_t i = 0; i < num_rows_safe; ++i) { - uint64_t length = offsets[i + 1] - offsets[i]; + uint64_t key_length = offsets[i + 1] - offsets[i]; // Compute masks for the last 16 byte stripe. // For an empty string set number of stripes to 1 but mask to all zeroes. // - int is_non_empty = length == 0 ? 0 : 1; - uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty); + int is_non_empty = key_length == 0 ? 0 : 1; + uint64_t num_stripes = + bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty); uint32_t mask1, mask2, mask3, mask4; - StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1, + StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1, &mask2, &mask3, &mask4); const uint8_t* key = concatenated_keys + offsets[i]; @@ -198,23 +199,24 @@ void Hashing32::HashVarLenImp(uint32_t num_rows, const T* offsets, uint32_t last_stripe_copy[4]; for (uint32_t i = num_rows_safe; i < num_rows; ++i) { - uint64_t length = offsets[i + 1] - offsets[i]; + uint64_t key_length = offsets[i + 1] - offsets[i]; // Compute masks for the last 16 byte stripe. // For an empty string set number of stripes to 1 but mask to all zeroes. // - int is_non_empty = length == 0 ? 0 : 1; - uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty); + int is_non_empty = key_length == 0 ? 0 : 1; + uint64_t num_stripes = + bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty); uint32_t mask1, mask2, mask3, mask4; - StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1, + StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1, &mask2, &mask3, &mask4); const uint8_t* key = concatenated_keys + offsets[i]; uint32_t acc1, acc2, acc3, acc4; ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4); - if (length > 0) { + if (key_length > 0) { memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize, - length - (num_stripes - 1) * kStripeSize); + key_length - (num_stripes - 1) * kStripeSize); } if (num_stripes > 0) { ProcessLastStripe(mask1, mask2, mask3, mask4, @@ -309,9 +311,9 @@ void Hashing32::HashIntImp(uint32_t num_keys, const T* keys, uint32_t* hashes) { } } -void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_key, +void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length, const uint8_t* keys, uint32_t* hashes) { - switch (length_key) { + switch (key_length) { case sizeof(uint8_t): if (combine_hashes) { HashIntImp(num_keys, keys, hashes); @@ -352,27 +354,27 @@ void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_ } } -void Hashing32::HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t num_rows, - uint64_t length, const uint8_t* keys, uint32_t* hashes, - uint32_t* hashes_temp_for_combine) { - if (ARROW_POPCOUNT64(length) == 1 && length <= sizeof(uint64_t)) { - HashInt(combine_hashes, num_rows, length, keys, hashes); +void Hashing32::HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t num_keys, + uint64_t key_length, const uint8_t* keys, uint32_t* hashes, + uint32_t* temp_hashes_for_combine) { + if (ARROW_POPCOUNT64(key_length) == 1 && key_length <= sizeof(uint64_t)) { + HashInt(combine_hashes, num_keys, key_length, keys, hashes); return; } uint32_t num_processed = 0; #if defined(ARROW_HAVE_RUNTIME_AVX2) if (hardware_flags & arrow::internal::CpuInfo::AVX2) { - num_processed = HashFixedLen_avx2(combine_hashes, num_rows, length, keys, hashes, - hashes_temp_for_combine); + num_processed = HashFixedLen_avx2(combine_hashes, num_keys, key_length, keys, hashes, + temp_hashes_for_combine); } #endif if (combine_hashes) { - HashFixedLenImp(num_rows - num_processed, length, keys + length * num_processed, - hashes + num_processed); + HashFixedLenImp(num_keys - num_processed, key_length, + keys + key_length * num_processed, hashes + num_processed); } else { - HashFixedLenImp(num_rows - num_processed, length, - keys + length * num_processed, hashes + num_processed); + HashFixedLenImp(num_keys - num_processed, key_length, + keys + key_length * num_processed, hashes + num_processed); } } @@ -423,13 +425,13 @@ void Hashing32::HashMultiColumn(const std::vector& cols, } if (cols[icol].metadata().is_fixed_length) { - uint32_t col_width = cols[icol].metadata().fixed_length; - if (col_width == 0) { + uint32_t key_length = cols[icol].metadata().fixed_length; + if (key_length == 0) { HashBit(icol > 0, cols[icol].bit_offset(1), batch_size_next, cols[icol].data(1) + first_row / 8, hashes + first_row); } else { - HashFixed(ctx->hardware_flags, icol > 0, batch_size_next, col_width, - cols[icol].data(1) + first_row * col_width, hashes + first_row, + HashFixed(ctx->hardware_flags, icol > 0, batch_size_next, key_length, + cols[icol].data(1) + first_row * key_length, hashes + first_row, hash_temp); } } else if (cols[icol].metadata().fixed_length == sizeof(uint32_t)) { @@ -463,8 +465,9 @@ void Hashing32::HashMultiColumn(const std::vector& cols, Status Hashing32::HashBatch(const ExecBatch& key_batch, uint32_t* hashes, std::vector& column_arrays, int64_t hardware_flags, util::TempVectorStack* temp_stack, - int64_t offset, int64_t length) { - RETURN_NOT_OK(ColumnArraysFromExecBatch(key_batch, offset, length, &column_arrays)); + int64_t start_rows, int64_t num_rows) { + RETURN_NOT_OK( + ColumnArraysFromExecBatch(key_batch, start_rows, num_rows, &column_arrays)); LightContext ctx; ctx.hardware_flags = hardware_flags; @@ -574,23 +577,23 @@ inline void Hashing64::StripeMask(int i, uint64_t* mask1, uint64_t* mask2, } template -void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_t* keys, - uint64_t* hashes) { +void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t key_length, + const uint8_t* keys, uint64_t* hashes) { // Calculate the number of rows that skip the last 32 bytes // uint32_t num_rows_safe = num_rows; - while (num_rows_safe > 0 && (num_rows - num_rows_safe) * length < kStripeSize) { + while (num_rows_safe > 0 && (num_rows - num_rows_safe) * key_length < kStripeSize) { --num_rows_safe; } // Compute masks for the last 32 byte stripe // - uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize); + uint64_t num_stripes = bit_util::CeilDiv(key_length, kStripeSize); uint64_t mask1, mask2, mask3, mask4; - StripeMask(((length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4); + StripeMask(((key_length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4); for (uint32_t i = 0; i < num_rows_safe; ++i) { - const uint8_t* key = keys + static_cast(i) * length; + const uint8_t* key = keys + static_cast(i) * key_length; uint64_t acc1, acc2, acc3, acc4; ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4); ProcessLastStripe(mask1, mask2, mask3, mask4, key + (num_stripes - 1) * kStripeSize, @@ -607,11 +610,11 @@ void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_ uint64_t last_stripe_copy[4]; for (uint32_t i = num_rows_safe; i < num_rows; ++i) { - const uint8_t* key = keys + static_cast(i) * length; + const uint8_t* key = keys + static_cast(i) * key_length; uint64_t acc1, acc2, acc3, acc4; ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4); memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize, - length - (num_stripes - 1) * kStripeSize); + key_length - (num_stripes - 1) * kStripeSize); ProcessLastStripe(mask1, mask2, mask3, mask4, reinterpret_cast(last_stripe_copy), &acc1, &acc2, &acc3, &acc4); @@ -637,15 +640,16 @@ void Hashing64::HashVarLenImp(uint32_t num_rows, const T* offsets, } for (uint32_t i = 0; i < num_rows_safe; ++i) { - uint64_t length = offsets[i + 1] - offsets[i]; + uint64_t key_length = offsets[i + 1] - offsets[i]; // Compute masks for the last 32 byte stripe. // For an empty string set number of stripes to 1 but mask to all zeroes. // - int is_non_empty = length == 0 ? 0 : 1; - uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty); + int is_non_empty = key_length == 0 ? 0 : 1; + uint64_t num_stripes = + bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty); uint64_t mask1, mask2, mask3, mask4; - StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1, + StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1, &mask2, &mask3, &mask4); const uint8_t* key = concatenated_keys + offsets[i]; @@ -667,22 +671,23 @@ void Hashing64::HashVarLenImp(uint32_t num_rows, const T* offsets, uint64_t last_stripe_copy[4]; for (uint32_t i = num_rows_safe; i < num_rows; ++i) { - uint64_t length = offsets[i + 1] - offsets[i]; + uint64_t key_length = offsets[i + 1] - offsets[i]; // Compute masks for the last 32 byte stripe // - int is_non_empty = length == 0 ? 0 : 1; - uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty); + int is_non_empty = key_length == 0 ? 0 : 1; + uint64_t num_stripes = + bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty); uint64_t mask1, mask2, mask3, mask4; - StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1, + StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1, &mask2, &mask3, &mask4); const uint8_t* key = concatenated_keys + offsets[i]; uint64_t acc1, acc2, acc3, acc4; ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4); - if (length > 0) { + if (key_length > 0) { memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize, - length - (num_stripes - 1) * kStripeSize); + key_length - (num_stripes - 1) * kStripeSize); } if (num_stripes > 0) { ProcessLastStripe(mask1, mask2, mask3, mask4, @@ -759,9 +764,9 @@ void Hashing64::HashIntImp(uint32_t num_keys, const T* keys, uint64_t* hashes) { } } -void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_key, +void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length, const uint8_t* keys, uint64_t* hashes) { - switch (length_key) { + switch (key_length) { case sizeof(uint8_t): if (combine_hashes) { HashIntImp(num_keys, keys, hashes); @@ -802,17 +807,17 @@ void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_ } } -void Hashing64::HashFixed(bool combine_hashes, uint32_t num_rows, uint64_t length, +void Hashing64::HashFixed(bool combine_hashes, uint32_t num_keys, uint64_t key_length, const uint8_t* keys, uint64_t* hashes) { - if (ARROW_POPCOUNT64(length) == 1 && length <= sizeof(uint64_t)) { - HashInt(combine_hashes, num_rows, length, keys, hashes); + if (ARROW_POPCOUNT64(key_length) == 1 && key_length <= sizeof(uint64_t)) { + HashInt(combine_hashes, num_keys, key_length, keys, hashes); return; } if (combine_hashes) { - HashFixedLenImp(num_rows, length, keys, hashes); + HashFixedLenImp(num_keys, key_length, keys, hashes); } else { - HashFixedLenImp(num_rows, length, keys, hashes); + HashFixedLenImp(num_keys, key_length, keys, hashes); } } @@ -860,13 +865,13 @@ void Hashing64::HashMultiColumn(const std::vector& cols, } if (cols[icol].metadata().is_fixed_length) { - uint64_t col_width = cols[icol].metadata().fixed_length; - if (col_width == 0) { + uint64_t key_length = cols[icol].metadata().fixed_length; + if (key_length == 0) { HashBit(icol > 0, cols[icol].bit_offset(1), batch_size_next, cols[icol].data(1) + first_row / 8, hashes + first_row); } else { - HashFixed(icol > 0, batch_size_next, col_width, - cols[icol].data(1) + first_row * col_width, hashes + first_row); + HashFixed(icol > 0, batch_size_next, key_length, + cols[icol].data(1) + first_row * key_length, hashes + first_row); } } else if (cols[icol].metadata().fixed_length == sizeof(uint32_t)) { HashVarLen(icol > 0, batch_size_next, cols[icol].offsets() + first_row, @@ -897,8 +902,9 @@ void Hashing64::HashMultiColumn(const std::vector& cols, Status Hashing64::HashBatch(const ExecBatch& key_batch, uint64_t* hashes, std::vector& column_arrays, int64_t hardware_flags, util::TempVectorStack* temp_stack, - int64_t offset, int64_t length) { - RETURN_NOT_OK(ColumnArraysFromExecBatch(key_batch, offset, length, &column_arrays)); + int64_t start_row, int64_t num_rows) { + RETURN_NOT_OK( + ColumnArraysFromExecBatch(key_batch, start_row, num_rows, &column_arrays)); LightContext ctx; ctx.hardware_flags = hardware_flags; diff --git a/cpp/src/arrow/compute/key_hash.h b/cpp/src/arrow/compute/key_hash.h index b193716c9bdfd..1173df5ed103e 100644 --- a/cpp/src/arrow/compute/key_hash.h +++ b/cpp/src/arrow/compute/key_hash.h @@ -51,10 +51,10 @@ class ARROW_EXPORT Hashing32 { static Status HashBatch(const ExecBatch& key_batch, uint32_t* hashes, std::vector& column_arrays, int64_t hardware_flags, util::TempVectorStack* temp_stack, - int64_t offset, int64_t length); + int64_t start_row, int64_t num_rows); static void HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t num_keys, - uint64_t length_key, const uint8_t* keys, uint32_t* hashes, + uint64_t key_length, const uint8_t* keys, uint32_t* hashes, uint32_t* temp_hashes_for_combine); private: @@ -100,7 +100,7 @@ class ARROW_EXPORT Hashing32 { static inline void StripeMask(int i, uint32_t* mask1, uint32_t* mask2, uint32_t* mask3, uint32_t* mask4); template - static void HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_t* keys, + static void HashFixedLenImp(uint32_t num_rows, uint64_t key_length, const uint8_t* keys, uint32_t* hashes); template static void HashVarLenImp(uint32_t num_rows, const T* offsets, @@ -112,7 +112,7 @@ class ARROW_EXPORT Hashing32 { const uint8_t* keys, uint32_t* hashes); template static void HashIntImp(uint32_t num_keys, const T* keys, uint32_t* hashes); - static void HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_key, + static void HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length, const uint8_t* keys, uint32_t* hashes); #if defined(ARROW_HAVE_RUNTIME_AVX2) @@ -129,11 +129,11 @@ class ARROW_EXPORT Hashing32 { __m256i mask_last_stripe, const uint8_t* keys, int64_t offset_A, int64_t offset_B); template - static uint32_t HashFixedLenImp_avx2(uint32_t num_rows, uint64_t length, + static uint32_t HashFixedLenImp_avx2(uint32_t num_rows, uint64_t key_length, const uint8_t* keys, uint32_t* hashes, uint32_t* hashes_temp_for_combine); static uint32_t HashFixedLen_avx2(bool combine_hashes, uint32_t num_rows, - uint64_t length, const uint8_t* keys, + uint64_t key_length, const uint8_t* keys, uint32_t* hashes, uint32_t* hashes_temp_for_combine); template static uint32_t HashVarLenImp_avx2(uint32_t num_rows, const T* offsets, @@ -164,9 +164,9 @@ class ARROW_EXPORT Hashing64 { static Status HashBatch(const ExecBatch& key_batch, uint64_t* hashes, std::vector& column_arrays, int64_t hardware_flags, util::TempVectorStack* temp_stack, - int64_t offset, int64_t length); + int64_t start_row, int64_t num_rows); - static void HashFixed(bool combine_hashes, uint32_t num_keys, uint64_t length_key, + static void HashFixed(bool combine_hashes, uint32_t num_keys, uint64_t key_length, const uint8_t* keys, uint64_t* hashes); private: @@ -203,7 +203,7 @@ class ARROW_EXPORT Hashing64 { static inline void StripeMask(int i, uint64_t* mask1, uint64_t* mask2, uint64_t* mask3, uint64_t* mask4); template - static void HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_t* keys, + static void HashFixedLenImp(uint32_t num_rows, uint64_t key_length, const uint8_t* keys, uint64_t* hashes); template static void HashVarLenImp(uint32_t num_rows, const T* offsets, @@ -211,11 +211,11 @@ class ARROW_EXPORT Hashing64 { template static void HashBitImp(int64_t bit_offset, uint32_t num_keys, const uint8_t* keys, uint64_t* hashes); - static void HashBit(bool T_COMBINE_HASHES, int64_t bit_offset, uint32_t num_keys, + static void HashBit(bool combine_hashes, int64_t bit_offset, uint32_t num_keys, const uint8_t* keys, uint64_t* hashes); template static void HashIntImp(uint32_t num_keys, const T* keys, uint64_t* hashes); - static void HashInt(bool T_COMBINE_HASHES, uint32_t num_keys, uint64_t length_key, + static void HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length, const uint8_t* keys, uint64_t* hashes); }; diff --git a/cpp/src/arrow/compute/key_hash_avx2.cc b/cpp/src/arrow/compute/key_hash_avx2.cc index 1b444b576784f..aec2800c647d7 100644 --- a/cpp/src/arrow/compute/key_hash_avx2.cc +++ b/cpp/src/arrow/compute/key_hash_avx2.cc @@ -190,7 +190,7 @@ uint32_t Hashing32::HashFixedLenImp_avx2(uint32_t num_rows, uint64_t length, // Do not process rows that could read past the end of the buffer using 16 // byte loads. Round down number of rows to process to multiple of 2. // - uint64_t num_rows_to_skip = bit_util::CeilDiv(length, kStripeSize); + uint64_t num_rows_to_skip = bit_util::CeilDiv(kStripeSize, length); uint32_t num_rows_to_process = (num_rows_to_skip > num_rows) ? 0 diff --git a/cpp/src/arrow/compute/key_hash_test.cc b/cpp/src/arrow/compute/key_hash_test.cc index 3e6d41525cf44..c998df7169c4a 100644 --- a/cpp/src/arrow/compute/key_hash_test.cc +++ b/cpp/src/arrow/compute/key_hash_test.cc @@ -252,5 +252,64 @@ TEST(VectorHash, BasicString) { RunTestVectorHash(); } TEST(VectorHash, BasicLargeString) { RunTestVectorHash(); } +void HashFixedLengthFrom(int key_length, int num_rows, int start_row) { + int num_rows_to_hash = num_rows - start_row; + auto num_bytes_aligned = arrow::bit_util::RoundUpToMultipleOf64(key_length * num_rows); + + const auto hardware_flags_for_testing = HardwareFlagsForTesting(); + ASSERT_GT(hardware_flags_for_testing.size(), 0); + + std::vector> hashes32(hardware_flags_for_testing.size()); + std::vector> hashes64(hardware_flags_for_testing.size()); + for (auto& h : hashes32) { + h.resize(num_rows_to_hash); + } + for (auto& h : hashes64) { + h.resize(num_rows_to_hash); + } + + FixedSizeBinaryBuilder keys_builder(fixed_size_binary(key_length)); + for (int j = 0; j < num_rows; ++j) { + ASSERT_OK(keys_builder.Append(std::string(key_length, 42))); + } + ASSERT_OK_AND_ASSIGN(auto keys, keys_builder.Finish()); + // Make sure the buffer is aligned as expected. + ASSERT_EQ(keys->data()->buffers[1]->capacity(), num_bytes_aligned); + + constexpr int mini_batch_size = 1024; + std::vector temp_buffer; + temp_buffer.resize(mini_batch_size * 4); + + for (int i = 0; i < static_cast(hardware_flags_for_testing.size()); ++i) { + const auto hardware_flags = hardware_flags_for_testing[i]; + Hashing32::HashFixed(hardware_flags, + /*combine_hashes=*/false, num_rows_to_hash, key_length, + keys->data()->GetValues(1) + start_row * key_length, + hashes32[i].data(), temp_buffer.data()); + Hashing64::HashFixed( + /*combine_hashes=*/false, num_rows_to_hash, key_length, + keys->data()->GetValues(1) + start_row * key_length, hashes64[i].data()); + } + + // Verify that all implementations (scalar, SIMD) give the same hashes. + for (int i = 1; i < static_cast(hardware_flags_for_testing.size()); ++i) { + for (int j = 0; j < num_rows_to_hash; ++j) { + ASSERT_EQ(hashes32[i][j], hashes32[0][j]) + << "scalar and simd approaches yielded different 32-bit hashes"; + ASSERT_EQ(hashes64[i][j], hashes64[0][j]) + << "scalar and simd approaches yielded different 64-bit hashes"; + } + } +} + +// Some carefully chosen cases that may cause troubles like GH-39778. +TEST(VectorHash, FixedLengthTailByteSafety) { + // Tow cases of key_length < stripe (16-byte). + HashFixedLengthFrom(/*key_length=*/3, /*num_rows=*/1450, /*start_row=*/1447); + HashFixedLengthFrom(/*key_length=*/5, /*num_rows=*/883, /*start_row=*/858); + // Case of key_length > stripe (16-byte). + HashFixedLengthFrom(/*key_length=*/19, /*num_rows=*/64, /*start_row=*/63); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/light_array.cc b/cpp/src/arrow/compute/light_array.cc index 73ea01a03a8fa..b225e04b05cea 100644 --- a/cpp/src/arrow/compute/light_array.cc +++ b/cpp/src/arrow/compute/light_array.cc @@ -20,6 +20,8 @@ #include #include "arrow/util/bitmap_ops.h" +#include "arrow/util/int_util_overflow.h" +#include "arrow/util/macros.h" namespace arrow { namespace compute { @@ -325,11 +327,10 @@ Status ResizableArrayData::ResizeVaryingLengthBuffer() { column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie(); if (!column_metadata.is_fixed_length) { - int min_new_size = static_cast(reinterpret_cast( - buffers_[kFixedLengthBuffer]->data())[num_rows_]); + int64_t min_new_size = buffers_[kFixedLengthBuffer]->data_as()[num_rows_]; ARROW_DCHECK(var_len_buf_size_ > 0); if (var_len_buf_size_ < min_new_size) { - int new_size = var_len_buf_size_; + int64_t new_size = var_len_buf_size_; while (new_size < min_new_size) { new_size *= 2; } @@ -383,27 +384,22 @@ int ExecBatchBuilder::NumRowsToSkip(const std::shared_ptr& column, KeyColumnMetadata column_metadata = ColumnMetadataFromDataType(column->type).ValueOrDie(); + ARROW_DCHECK(!column_metadata.is_fixed_length || column_metadata.fixed_length > 0); int num_rows_left = num_rows; int num_bytes_skipped = 0; while (num_rows_left > 0 && num_bytes_skipped < num_tail_bytes_to_skip) { + --num_rows_left; + int row_id_removed = row_ids[num_rows_left]; if (column_metadata.is_fixed_length) { - if (column_metadata.fixed_length == 0) { - num_rows_left = std::max(num_rows_left, 8) - 8; - ++num_bytes_skipped; - } else { - --num_rows_left; - num_bytes_skipped += column_metadata.fixed_length; - } + num_bytes_skipped += column_metadata.fixed_length; } else { - --num_rows_left; - int row_id_removed = row_ids[num_rows_left]; const int32_t* offsets = column->GetValues(1); num_bytes_skipped += offsets[row_id_removed + 1] - offsets[row_id_removed]; - // Skip consecutive rows with the same id - while (num_rows_left > 0 && row_id_removed == row_ids[num_rows_left - 1]) { - --num_rows_left; - } + } + // Skip consecutive rows with the same id + while (num_rows_left > 0 && row_id_removed == row_ids[num_rows_left - 1]) { + --num_rows_left; } } @@ -470,12 +466,11 @@ void ExecBatchBuilder::Visit(const std::shared_ptr& column, int num_r if (!metadata.is_fixed_length) { const uint8_t* ptr_base = column->buffers[2]->data(); - const uint32_t* offsets = - reinterpret_cast(column->buffers[1]->data()) + column->offset; + const int32_t* offsets = column->GetValues(1); for (int i = 0; i < num_rows; ++i) { uint16_t row_id = row_ids[i]; const uint8_t* field_ptr = ptr_base + offsets[row_id]; - uint32_t field_length = offsets[row_id + 1] - offsets[row_id]; + int32_t field_length = offsets[row_id + 1] - offsets[row_id]; process_value_fn(i, field_ptr, field_length); } } else { @@ -485,7 +480,7 @@ void ExecBatchBuilder::Visit(const std::shared_ptr& column, int num_r const uint8_t* field_ptr = column->buffers[1]->data() + (column->offset + row_id) * static_cast(metadata.fixed_length); - process_value_fn(i, field_ptr, metadata.fixed_length); + process_value_fn(i, field_ptr, static_cast(metadata.fixed_length)); } } } @@ -516,14 +511,14 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr& source break; case 1: Visit(source, num_rows_to_append, row_ids, - [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + [&](int i, const uint8_t* ptr, int32_t num_bytes) { target->mutable_data(1)[num_rows_before + i] = *ptr; }); break; case 2: Visit( source, num_rows_to_append, row_ids, - [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + [&](int i, const uint8_t* ptr, int32_t num_bytes) { reinterpret_cast(target->mutable_data(1))[num_rows_before + i] = *reinterpret_cast(ptr); }); @@ -531,7 +526,7 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr& source case 4: Visit( source, num_rows_to_append, row_ids, - [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + [&](int i, const uint8_t* ptr, int32_t num_bytes) { reinterpret_cast(target->mutable_data(1))[num_rows_before + i] = *reinterpret_cast(ptr); }); @@ -539,7 +534,7 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr& source case 8: Visit( source, num_rows_to_append, row_ids, - [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + [&](int i, const uint8_t* ptr, int32_t num_bytes) { reinterpret_cast(target->mutable_data(1))[num_rows_before + i] = *reinterpret_cast(ptr); }); @@ -549,7 +544,7 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr& source num_rows_to_append - NumRowsToSkip(source, num_rows_to_append, row_ids, sizeof(uint64_t)); Visit(source, num_rows_to_process, row_ids, - [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + [&](int i, const uint8_t* ptr, int32_t num_bytes) { uint64_t* dst = reinterpret_cast( target->mutable_data(1) + static_cast(num_bytes) * (num_rows_before + i)); @@ -563,7 +558,7 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr& source if (num_rows_to_append > num_rows_to_process) { Visit(source, num_rows_to_append - num_rows_to_process, row_ids + num_rows_to_process, - [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + [&](int i, const uint8_t* ptr, int32_t num_bytes) { uint64_t* dst = reinterpret_cast( target->mutable_data(1) + static_cast(num_bytes) * @@ -580,16 +575,23 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr& source // Step 1: calculate target offsets // - uint32_t* offsets = reinterpret_cast(target->mutable_data(1)); - uint32_t sum = num_rows_before == 0 ? 0 : offsets[num_rows_before]; + int32_t* offsets = reinterpret_cast(target->mutable_data(1)); + int32_t sum = num_rows_before == 0 ? 0 : offsets[num_rows_before]; Visit(source, num_rows_to_append, row_ids, - [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + [&](int i, const uint8_t* ptr, int32_t num_bytes) { offsets[num_rows_before + i] = num_bytes; }); for (int i = 0; i < num_rows_to_append; ++i) { - uint32_t length = offsets[num_rows_before + i]; + int32_t length = offsets[num_rows_before + i]; offsets[num_rows_before + i] = sum; - sum += length; + int32_t new_sum_maybe_overflow = 0; + if (ARROW_PREDICT_FALSE( + arrow::internal::AddWithOverflow(sum, length, &new_sum_maybe_overflow))) { + return Status::Invalid("Overflow detected in ExecBatchBuilder when appending ", + num_rows_before + i + 1, "-th element of length ", length, + " bytes to current length ", sum, " bytes"); + } + sum = new_sum_maybe_overflow; } offsets[num_rows_before + num_rows_to_append] = sum; @@ -603,7 +605,7 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr& source num_rows_to_append - NumRowsToSkip(source, num_rows_to_append, row_ids, sizeof(uint64_t)); Visit(source, num_rows_to_process, row_ids, - [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + [&](int i, const uint8_t* ptr, int32_t num_bytes) { uint64_t* dst = reinterpret_cast(target->mutable_data(2) + offsets[num_rows_before + i]); const uint64_t* src = reinterpret_cast(ptr); @@ -613,7 +615,7 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr& source } }); Visit(source, num_rows_to_append - num_rows_to_process, row_ids + num_rows_to_process, - [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + [&](int i, const uint8_t* ptr, int32_t num_bytes) { uint64_t* dst = reinterpret_cast( target->mutable_data(2) + offsets[num_rows_before + num_rows_to_process + i]); diff --git a/cpp/src/arrow/compute/light_array.h b/cpp/src/arrow/compute/light_array.h index 84aa86d64bb62..67de71bf56c92 100644 --- a/cpp/src/arrow/compute/light_array.h +++ b/cpp/src/arrow/compute/light_array.h @@ -353,7 +353,7 @@ class ARROW_EXPORT ResizableArrayData { MemoryPool* pool_; int num_rows_; int num_rows_allocated_; - int var_len_buf_size_; + int64_t var_len_buf_size_; static constexpr int kMaxBuffers = 3; std::shared_ptr buffers_[kMaxBuffers]; }; diff --git a/cpp/src/arrow/compute/light_array_test.cc b/cpp/src/arrow/compute/light_array_test.cc index 3ceba43604b28..ecc5f3ad37931 100644 --- a/cpp/src/arrow/compute/light_array_test.cc +++ b/cpp/src/arrow/compute/light_array_test.cc @@ -407,6 +407,70 @@ TEST(ExecBatchBuilder, AppendValuesBeyondLimit) { ASSERT_EQ(0, pool->bytes_allocated()); } +TEST(ExecBatchBuilder, AppendVarLengthBeyondLimit) { + // GH-39332: check appending variable-length data past 2GB. + if constexpr (sizeof(void*) == 4) { + GTEST_SKIP() << "Test only works on 64-bit platforms"; + } + + std::unique_ptr owned_pool = MemoryPool::CreateDefault(); + MemoryPool* pool = owned_pool.get(); + constexpr auto eight_mb = 8 * 1024 * 1024; + constexpr auto eight_mb_minus_one = eight_mb - 1; + // String of size 8mb to repetitively fill the heading multiple of 8mbs of an array + // of int32_max bytes. + std::string str_8mb(eight_mb, 'a'); + // String of size (8mb - 1) to be the last element of an array of int32_max bytes. + std::string str_8mb_minus_1(eight_mb_minus_one, 'b'); + std::shared_ptr values_8mb = ConstantArrayGenerator::String(1, str_8mb); + std::shared_ptr values_8mb_minus_1 = + ConstantArrayGenerator::String(1, str_8mb_minus_1); + + ExecBatch batch_8mb({values_8mb}, 1); + ExecBatch batch_8mb_minus_1({values_8mb_minus_1}, 1); + + auto num_rows = std::numeric_limits::max() / eight_mb; + std::vector body_row_ids(num_rows, 0); + std::vector tail_row_id(1, 0); + + { + // Building an array of (int32_max + 1) = (8mb * num_rows + 8mb) bytes should raise an + // error of overflow. + ExecBatchBuilder builder; + ASSERT_OK(builder.AppendSelected(pool, batch_8mb, num_rows, body_row_ids.data(), + /*num_cols=*/1)); + std::stringstream ss; + ss << "Invalid: Overflow detected in ExecBatchBuilder when appending " << num_rows + 1 + << "-th element of length " << eight_mb << " bytes to current length " + << eight_mb * num_rows << " bytes"; + ASSERT_RAISES_WITH_MESSAGE( + Invalid, ss.str(), + builder.AppendSelected(pool, batch_8mb, 1, tail_row_id.data(), + /*num_cols=*/1)); + } + + { + // Building an array of int32_max = (8mb * num_rows + 8mb - 1) bytes should succeed. + ExecBatchBuilder builder; + ASSERT_OK(builder.AppendSelected(pool, batch_8mb, num_rows, body_row_ids.data(), + /*num_cols=*/1)); + ASSERT_OK(builder.AppendSelected(pool, batch_8mb_minus_1, 1, tail_row_id.data(), + /*num_cols=*/1)); + ExecBatch built = builder.Flush(); + auto datum = built[0]; + ASSERT_TRUE(datum.is_array()); + auto array = datum.array_as(); + ASSERT_EQ(array->length(), num_rows + 1); + for (int i = 0; i < num_rows; ++i) { + ASSERT_EQ(array->GetString(i), str_8mb); + } + ASSERT_EQ(array->GetString(num_rows), str_8mb_minus_1); + ASSERT_NE(0, pool->bytes_allocated()); + } + + ASSERT_EQ(0, pool->bytes_allocated()); +} + TEST(KeyColumnArray, FromExecBatch) { ExecBatch batch = JSONToExecBatch({int64(), boolean()}, "[[1, true], [2, false], [null, null]]"); @@ -474,15 +538,18 @@ TEST(ExecBatchBuilder, AppendBatchesSomeRows) { TEST(ExecBatchBuilder, AppendBatchDupRows) { std::unique_ptr owned_pool = MemoryPool::CreateDefault(); MemoryPool* pool = owned_pool.get(); + // Case of cross-word copying for the last row, which may exceed the buffer boundary. - // This is a simplified case of GH-32570 + // { + // This is a simplified case of GH-32570 // 64-byte data fully occupying one minimal 64-byte aligned memory region. - ExecBatch batch_string = JSONToExecBatch({binary()}, R"([["123456789ABCDEF0"], - ["123456789ABCDEF0"], - ["123456789ABCDEF0"], - ["ABCDEF0"], - ["123456789"]])"); // 9-byte tail row, larger than a word. + ExecBatch batch_string = JSONToExecBatch({binary()}, R"([ + ["123456789ABCDEF0"], + ["123456789ABCDEF0"], + ["123456789ABCDEF0"], + ["ABCDEF0"], + ["123456789"]])"); // 9-byte tail row, larger than a word. ASSERT_EQ(batch_string[0].array()->buffers[1]->capacity(), 64); ASSERT_EQ(batch_string[0].array()->buffers[2]->capacity(), 64); ExecBatchBuilder builder; @@ -494,6 +561,66 @@ TEST(ExecBatchBuilder, AppendBatchDupRows) { ASSERT_EQ(batch_string_appended, built); ASSERT_NE(0, pool->bytes_allocated()); } + + { + // This is a simplified case of GH-39583, using fsb(3) type. + // 63-byte data occupying almost one minimal 64-byte aligned memory region. + ExecBatch batch_fsb = JSONToExecBatch({fixed_size_binary(3)}, R"([ + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["000"], + ["123"]])"); // 3-byte tail row, not aligned to a word. + ASSERT_EQ(batch_fsb[0].array()->buffers[1]->capacity(), 64); + ExecBatchBuilder builder; + uint16_t row_ids[4] = {20, 20, 20, + 20}; // Get the last row 4 times, 3 to skip a word. + ASSERT_OK(builder.AppendSelected(pool, batch_fsb, 4, row_ids, /*num_cols=*/1)); + ExecBatch built = builder.Flush(); + ExecBatch batch_fsb_appended = JSONToExecBatch( + {fixed_size_binary(3)}, R"([["123"], ["123"], ["123"], ["123"]])"); + ASSERT_EQ(batch_fsb_appended, built); + ASSERT_NE(0, pool->bytes_allocated()); + } + + { + // This is a simplified case of GH-39583, using fsb(9) type. + // 63-byte data occupying almost one minimal 64-byte aligned memory region. + ExecBatch batch_fsb = JSONToExecBatch({fixed_size_binary(9)}, R"([ + ["000000000"], + ["000000000"], + ["000000000"], + ["000000000"], + ["000000000"], + ["000000000"], + ["123456789"]])"); // 9-byte tail row, not aligned to a word. + ASSERT_EQ(batch_fsb[0].array()->buffers[1]->capacity(), 64); + ExecBatchBuilder builder; + uint16_t row_ids[2] = {6, 6}; // Get the last row 2 times, 1 to skip a word. + ASSERT_OK(builder.AppendSelected(pool, batch_fsb, 2, row_ids, /*num_cols=*/1)); + ExecBatch built = builder.Flush(); + ExecBatch batch_fsb_appended = + JSONToExecBatch({fixed_size_binary(9)}, R"([["123456789"], ["123456789"]])"); + ASSERT_EQ(batch_fsb_appended, built); + ASSERT_NE(0, pool->bytes_allocated()); + } + ASSERT_EQ(0, pool->bytes_allocated()); } diff --git a/cpp/src/arrow/compute/row/compare_internal.cc b/cpp/src/arrow/compute/row/compare_internal.cc index 7c402e7a2384d..078a8287c71c0 100644 --- a/cpp/src/arrow/compute/row/compare_internal.cc +++ b/cpp/src/arrow/compute/row/compare_internal.cc @@ -208,8 +208,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, // Non-zero length guarantees no underflow int32_t num_loops_less_one = static_cast(bit_util::CeilDiv(length, 8)) - 1; - - uint64_t tail_mask = ~0ULL >> (64 - 8 * (length - num_loops_less_one * 8)); + int32_t num_tail_bytes = length - num_loops_less_one * 8; const uint64_t* key_left_ptr = reinterpret_cast(left_base + irow_left * length); @@ -224,9 +223,11 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, uint64_t key_right = key_right_ptr[i]; result_or |= key_left ^ key_right; } - uint64_t key_left = util::SafeLoad(key_left_ptr + i); - uint64_t key_right = key_right_ptr[i]; - result_or |= tail_mask & (key_left ^ key_right); + uint64_t key_left = 0; + memcpy(&key_left, key_left_ptr + i, num_tail_bytes); + uint64_t key_right = 0; + memcpy(&key_right, key_right_ptr + i, num_tail_bytes); + result_or |= key_left ^ key_right; return result_or == 0 ? 0xff : 0; }); } diff --git a/cpp/src/arrow/compute/row/compare_test.cc b/cpp/src/arrow/compute/row/compare_test.cc new file mode 100644 index 0000000000000..1d8562cd56d3c --- /dev/null +++ b/cpp/src/arrow/compute/row/compare_test.cc @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/row/compare_internal.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { +namespace compute { + +using arrow::bit_util::BytesForBits; +using arrow::internal::CpuInfo; +using arrow::util::MiniBatch; +using arrow::util::TempVectorStack; + +// Specialized case for GH-39577. +TEST(KeyCompare, CompareColumnsToRowsCuriousFSB) { + int fsb_length = 9; + MemoryPool* pool = default_memory_pool(); + TempVectorStack stack; + ASSERT_OK(stack.Init(pool, 8 * MiniBatch::kMiniBatchLength * sizeof(uint64_t))); + + int num_rows = 7; + auto column_right = ArrayFromJSON(fixed_size_binary(fsb_length), R"([ + "000000000", + "111111111", + "222222222", + "333333333", + "444444444", + "555555555", + "666666666"])"); + ExecBatch batch_right({column_right}, num_rows); + + std::vector column_metadatas_right; + ASSERT_OK(ColumnMetadatasFromExecBatch(batch_right, &column_metadatas_right)); + + RowTableMetadata table_metadata_right; + table_metadata_right.FromColumnMetadataVector(column_metadatas_right, sizeof(uint64_t), + sizeof(uint64_t)); + + std::vector column_arrays_right; + ASSERT_OK(ColumnArraysFromExecBatch(batch_right, &column_arrays_right)); + + RowTableImpl row_table; + ASSERT_OK(row_table.Init(pool, table_metadata_right)); + + RowTableEncoder row_encoder; + row_encoder.Init(column_metadatas_right, sizeof(uint64_t), sizeof(uint64_t)); + row_encoder.PrepareEncodeSelected(0, num_rows, column_arrays_right); + + std::vector row_ids_right(num_rows); + std::iota(row_ids_right.begin(), row_ids_right.end(), 0); + ASSERT_OK(row_encoder.EncodeSelected(&row_table, num_rows, row_ids_right.data())); + + auto column_left = ArrayFromJSON(fixed_size_binary(fsb_length), R"([ + "000000000", + "111111111", + "222222222", + "333333333", + "444444444", + "555555555", + "777777777"])"); + ExecBatch batch_left({column_left}, num_rows); + std::vector column_arrays_left; + ASSERT_OK(ColumnArraysFromExecBatch(batch_left, &column_arrays_left)); + + std::vector row_ids_left(num_rows); + std::iota(row_ids_left.begin(), row_ids_left.end(), 0); + + LightContext ctx{CpuInfo::GetInstance()->hardware_flags(), &stack}; + + { + uint32_t num_rows_no_match; + std::vector row_ids_out(num_rows); + KeyCompare::CompareColumnsToRows(num_rows, NULLPTR, row_ids_left.data(), &ctx, + &num_rows_no_match, row_ids_out.data(), + column_arrays_left, row_table, true, NULLPTR); + ASSERT_EQ(num_rows_no_match, 1); + ASSERT_EQ(row_ids_out[0], 6); + } + + { + std::vector match_bitvector(BytesForBits(num_rows)); + KeyCompare::CompareColumnsToRows(num_rows, NULLPTR, row_ids_left.data(), &ctx, + NULLPTR, NULLPTR, column_arrays_left, row_table, + true, match_bitvector.data()); + for (int i = 0; i < num_rows; ++i) { + SCOPED_TRACE(i); + ASSERT_EQ(arrow::bit_util::GetBit(match_bitvector.data(), i), i != 6); + } + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/csv/writer_benchmark.cc b/cpp/src/arrow/csv/writer_benchmark.cc index 54c0f50613754..9baa00d48a6d2 100644 --- a/cpp/src/arrow/csv/writer_benchmark.cc +++ b/cpp/src/arrow/csv/writer_benchmark.cc @@ -97,7 +97,7 @@ void BenchmarkWriteCsv(benchmark::State& state, const WriteOptions& options, const RecordBatch& batch) { int64_t total_size = 0; - while (state.KeepRunning()) { + for (auto _ : state) { auto out = io::BufferOutputStream::Create().ValueOrDie(); ABORT_NOT_OK(WriteCSV(batch, options, out.get())); auto buffer = out->Finish().ValueOrDie(); @@ -106,6 +106,7 @@ void BenchmarkWriteCsv(benchmark::State& state, const WriteOptions& options, // byte size of the generated csv dataset state.SetBytesProcessed(total_size); + state.SetItemsProcessed(state.iterations() * batch.num_columns() * batch.num_rows()); state.counters["null_percent"] = static_cast(state.range(0)); } diff --git a/cpp/src/arrow/dataset/file_benchmark.cc b/cpp/src/arrow/dataset/file_benchmark.cc index 8953cbd110643..8aa2ac5a6fa77 100644 --- a/cpp/src/arrow/dataset/file_benchmark.cc +++ b/cpp/src/arrow/dataset/file_benchmark.cc @@ -30,7 +30,12 @@ namespace arrow { namespace dataset { -static std::shared_ptr GetDataset() { +struct SampleDataset { + std::shared_ptr dataset; + int64_t num_fragments; +}; + +static SampleDataset GetDataset() { std::vector files; std::vector paths; for (int a = 0; a < 100; a++) { @@ -50,25 +55,35 @@ static std::shared_ptr GetDataset() { FinishOptions finish_options; finish_options.inspect_options.fragments = 0; EXPECT_OK_AND_ASSIGN(auto dataset, factory->Finish(finish_options)); - return dataset; + return {dataset, static_cast(paths.size())}; } // A benchmark of filtering fragments in a dataset. static void GetAllFragments(benchmark::State& state) { auto dataset = GetDataset(); for (auto _ : state) { - ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments()); + ASSERT_OK_AND_ASSIGN(auto fragments, dataset.dataset->GetFragments()); ABORT_NOT_OK(fragments.Visit([](std::shared_ptr) { return Status::OK(); })); } + state.SetItemsProcessed(state.iterations() * dataset.num_fragments); + state.counters["num_fragments"] = static_cast(dataset.num_fragments); } static void GetFilteredFragments(benchmark::State& state, compute::Expression filter) { auto dataset = GetDataset(); - ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*dataset->schema())); + ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*dataset.dataset->schema())); + int64_t num_filtered_fragments = 0; for (auto _ : state) { - ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments(filter)); - ABORT_NOT_OK(fragments.Visit([](std::shared_ptr) { return Status::OK(); })); + num_filtered_fragments = 0; + ASSERT_OK_AND_ASSIGN(auto fragments, dataset.dataset->GetFragments(filter)); + ABORT_NOT_OK(fragments.Visit([&](std::shared_ptr) { + ++num_filtered_fragments; + return Status::OK(); + })); } + state.SetItemsProcessed(state.iterations() * dataset.num_fragments); + state.counters["num_fragments"] = static_cast(dataset.num_fragments); + state.counters["num_filtered_fragments"] = static_cast(num_filtered_fragments); } using compute::field_ref; diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index 0ce08502921f3..140917a2e6341 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -813,11 +813,17 @@ Status ParquetFileFragment::EnsureCompleteMetadata(parquet::arrow::FileReader* r Status ParquetFileFragment::SetMetadata( std::shared_ptr metadata, - std::shared_ptr manifest) { + std::shared_ptr manifest, + std::shared_ptr original_metadata) { DCHECK(row_groups_.has_value()); metadata_ = std::move(metadata); manifest_ = std::move(manifest); + original_metadata_ = original_metadata ? std::move(original_metadata) : metadata_; + // The SchemaDescriptor needs to be owned by a FileMetaData instance, + // because SchemaManifest only stores a raw pointer (GH-39562). + DCHECK_EQ(manifest_->descr, original_metadata_->schema()) + << "SchemaDescriptor should be owned by the original FileMetaData"; statistics_expressions_.resize(row_groups_->size(), compute::literal(true)); statistics_expressions_complete_.resize(manifest_->descr->num_columns(), false); @@ -846,7 +852,8 @@ Result ParquetFileFragment::SplitByRowGroup( parquet_format_.MakeFragment(source_, partition_expression(), physical_schema_, {row_group})); - RETURN_NOT_OK(fragment->SetMetadata(metadata_, manifest_)); + RETURN_NOT_OK(fragment->SetMetadata(metadata_, manifest_, + /*original_metadata=*/original_metadata_)); fragments[i++] = std::move(fragment); } @@ -1106,7 +1113,8 @@ ParquetDatasetFactory::CollectParquetFragments(const Partitioning& partitioning) format_->MakeFragment({path, filesystem_}, std::move(partition_expression), physical_schema_, std::move(row_groups))); - RETURN_NOT_OK(fragment->SetMetadata(metadata_subset, manifest_)); + RETURN_NOT_OK(fragment->SetMetadata(metadata_subset, manifest_, + /*original_metadata=*/metadata_)); fragments[i++] = std::move(fragment); } diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h index 1e81a34fb3cf0..5141f36385e3f 100644 --- a/cpp/src/arrow/dataset/file_parquet.h +++ b/cpp/src/arrow/dataset/file_parquet.h @@ -188,7 +188,8 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { std::optional> row_groups); Status SetMetadata(std::shared_ptr metadata, - std::shared_ptr manifest); + std::shared_ptr manifest, + std::shared_ptr original_metadata = {}); // Overridden to opportunistically set metadata since a reader must be opened anyway. Result> ReadPhysicalSchemaImpl() override { @@ -219,6 +220,8 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { std::vector statistics_expressions_complete_; std::shared_ptr metadata_; std::shared_ptr manifest_; + // The FileMetaData that owns the SchemaDescriptor pointed by SchemaManifest. + std::shared_ptr original_metadata_; friend class ParquetFileFormat; friend class ParquetDatasetFactory; diff --git a/cpp/src/arrow/dataset/scanner_benchmark.cc b/cpp/src/arrow/dataset/scanner_benchmark.cc index be953b3555945..287d76418ff16 100644 --- a/cpp/src/arrow/dataset/scanner_benchmark.cc +++ b/cpp/src/arrow/dataset/scanner_benchmark.cc @@ -162,7 +162,11 @@ void ScanOnly( acero::DeclarationToTable(std::move(scan))); ASSERT_GT(collected->num_rows(), 0); - ASSERT_EQ(collected->num_columns(), 2); + if (factory_name == "scan") { + ASSERT_EQ(collected->num_columns(), 6); + } else if (factory_name == "scan2") { + ASSERT_EQ(collected->num_columns(), 2); + } } static constexpr int kScanIdx = 0; diff --git a/cpp/src/arrow/device.cc b/cpp/src/arrow/device.cc index 14d3bac0af1b7..de709923dc44e 100644 --- a/cpp/src/arrow/device.cc +++ b/cpp/src/arrow/device.cc @@ -241,7 +241,11 @@ bool CPUDevice::Equals(const Device& other) const { } std::shared_ptr CPUDevice::memory_manager(MemoryPool* pool) { - return CPUMemoryManager::Make(Instance(), pool); + if (pool == default_memory_pool()) { + return default_cpu_memory_manager(); + } else { + return CPUMemoryManager::Make(Instance(), pool); + } } std::shared_ptr CPUDevice::default_memory_manager() { diff --git a/cpp/src/arrow/filesystem/azurefs.cc b/cpp/src/arrow/filesystem/azurefs.cc index 730adabd48bec..a5179c22190e1 100644 --- a/cpp/src/arrow/filesystem/azurefs.cc +++ b/cpp/src/arrow/filesystem/azurefs.cc @@ -305,31 +305,9 @@ Status ValidateFileLocation(const AzureLocation& location) { return internal::AssertNoTrailingSlash(location.path); } -std::string_view BodyTextView(const Http::RawResponse& raw_response) { - const auto& body = raw_response.GetBody(); -#ifndef NDEBUG - auto& headers = raw_response.GetHeaders(); - auto content_type = headers.find("Content-Type"); - if (content_type != headers.end()) { - DCHECK_EQ(std::string_view{content_type->second}.substr(5), "text/"); - } -#endif - return std::string_view{reinterpret_cast(body.data()), body.size()}; -} - -Status StatusFromErrorResponse(const std::string& url, - const Http::RawResponse& raw_response, - const std::string& context) { - // There isn't an Azure specification that response body on error - // doesn't contain any binary data but we assume it. We hope that - // error response body has useful information for the error. - auto body_text = BodyTextView(raw_response); - return Status::IOError(context, ": ", url, ": ", raw_response.GetReasonPhrase(), " (", - static_cast(raw_response.GetStatusCode()), - "): ", body_text); -} - bool IsContainerNotFound(const Storage::StorageException& e) { + // In some situations, only the ReasonPhrase is set and the + // ErrorCode is empty, so we check both. if (e.ErrorCode == "ContainerNotFound" || e.ReasonPhrase == "The specified container does not exist." || e.ReasonPhrase == "The specified filesystem does not exist.") { @@ -1515,13 +1493,9 @@ class AzureFileSystem::Impl { DCHECK(location.path.empty()); try { auto response = container_client.Delete(); - if (response.Value.Deleted) { - return Status::OK(); - } else { - return StatusFromErrorResponse( - container_client.GetUrl(), *response.RawResponse, - "Failed to delete a container: " + location.container); - } + // Only the "*IfExists" functions ever set Deleted to false. + // All the others either succeed or throw an exception. + DCHECK(response.Value.Deleted); } catch (const Storage::StorageException& exception) { if (IsContainerNotFound(exception)) { return PathNotFound(location); @@ -1530,6 +1504,7 @@ class AzureFileSystem::Impl { "Failed to delete a container: ", location.container, ": ", container_client.GetUrl()); } + return Status::OK(); } /// Deletes contents of a directory and possibly the directory itself @@ -1649,23 +1624,29 @@ class AzureFileSystem::Impl { /// \pre location.container is not empty. /// \pre location.path is not empty. Status DeleteDirOnFileSystem(const DataLake::DataLakeFileSystemClient& adlfs_client, - const AzureLocation& location) { + const AzureLocation& location, bool recursive, + bool require_dir_to_exist) { DCHECK(!location.container.empty()); DCHECK(!location.path.empty()); auto directory_client = adlfs_client.GetDirectoryClient(location.path); - // XXX: should "directory not found" be considered an error? try { - auto response = directory_client.DeleteRecursive(); - if (response.Value.Deleted) { + auto response = + recursive ? directory_client.DeleteRecursive() : directory_client.DeleteEmpty(); + // Only the "*IfExists" functions ever set Deleted to false. + // All the others either succeed or throw an exception. + DCHECK(response.Value.Deleted); + } catch (const Storage::StorageException& exception) { + if (exception.ErrorCode == "FilesystemNotFound" || + exception.ErrorCode == "PathNotFound") { + if (require_dir_to_exist) { + return PathNotFound(location); + } return Status::OK(); - } else { - return StatusFromErrorResponse(directory_client.GetUrl(), *response.RawResponse, - "Failed to delete a directory: " + location.path); } - } catch (const Storage::StorageException& exception) { return ExceptionToStatus(exception, "Failed to delete a directory: ", location.path, ": ", directory_client.GetUrl()); } + return Status::OK(); } /// \pre location.container is not empty. @@ -1855,7 +1836,8 @@ Status AzureFileSystem::DeleteDir(const std::string& path) { return PathNotFound(location); } if (hns_support == HNSSupport::kEnabled) { - return impl_->DeleteDirOnFileSystem(adlfs_client, location); + return impl_->DeleteDirOnFileSystem(adlfs_client, location, /*recursive=*/true, + /*require_dir_to_exist=*/true); } DCHECK_EQ(hns_support, HNSSupport::kDisabled); auto container_client = impl_->GetBlobContainerClient(location.container); diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index 522973bec7231..a1c5250ba66fa 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -282,6 +282,7 @@ TEST(FlightTypes, PollInfo) { std::nullopt}, PollInfo{std::make_unique(info), FlightDescriptor::Command("poll"), 0.1, expiration_time}, + PollInfo{}, }; std::vector reprs = { " " "progress=0.1 expiration_time=2023-06-19 03:14:06.004339000>", + "", }; ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); diff --git a/cpp/src/arrow/flight/serialization_internal.cc b/cpp/src/arrow/flight/serialization_internal.cc index 64a40564afd72..e5a7503a6386b 100644 --- a/cpp/src/arrow/flight/serialization_internal.cc +++ b/cpp/src/arrow/flight/serialization_internal.cc @@ -306,8 +306,10 @@ Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info) { // PollInfo Status FromProto(const pb::PollInfo& pb_info, PollInfo* info) { - ARROW_ASSIGN_OR_RAISE(auto flight_info, FromProto(pb_info.info())); - info->info = std::make_unique(std::move(flight_info)); + if (pb_info.has_info()) { + ARROW_ASSIGN_OR_RAISE(auto flight_info, FromProto(pb_info.info())); + info->info = std::make_unique(std::move(flight_info)); + } if (pb_info.has_flight_descriptor()) { FlightDescriptor descriptor; RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &descriptor)); @@ -331,7 +333,9 @@ Status FromProto(const pb::PollInfo& pb_info, PollInfo* info) { } Status ToProto(const PollInfo& info, pb::PollInfo* pb_info) { - RETURN_NOT_OK(ToProto(*info.info, pb_info->mutable_info())); + if (info.info) { + RETURN_NOT_OK(ToProto(*info.info, pb_info->mutable_info())); + } if (info.descriptor) { RETURN_NOT_OK(ToProto(*info.descriptor, pb_info->mutable_flight_descriptor())); } diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 9da83fa8a11f2..1d43c41b69d9f 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -373,7 +373,12 @@ arrow::Result> PollInfo::Deserialize( std::string PollInfo::ToString() const { std::stringstream ss; - ss << "(*other.info) : NULLPTR), descriptor(other.descriptor), progress(other.progress), expiration_time(other.expiration_time) {} + PollInfo(PollInfo&& other) noexcept = default; // NOLINT(runtime/explicit) + ~PollInfo() = default; + PollInfo& operator=(const PollInfo& other) { + info = other.info ? std::make_unique(*other.info) : NULLPTR; + descriptor = other.descriptor; + progress = other.progress; + expiration_time = other.expiration_time; + return *this; + } + PollInfo& operator=(PollInfo&& other) = default; /// \brief Get the wire-format representation of this type. /// diff --git a/cpp/src/arrow/ipc/message.cc b/cpp/src/arrow/ipc/message.cc index fbcd6f139b6d2..e196dd7bf5389 100644 --- a/cpp/src/arrow/ipc/message.cc +++ b/cpp/src/arrow/ipc/message.cc @@ -607,6 +607,7 @@ class MessageDecoder::MessageDecoderImpl { MemoryPool* pool, bool skip_body) : listener_(std::move(listener)), pool_(pool), + memory_manager_(CPUDevice::memory_manager(pool_)), state_(initial_state), next_required_size_(initial_next_required_size), chunks_(), @@ -822,8 +823,7 @@ class MessageDecoder::MessageDecoderImpl { if (buffer->is_cpu()) { metadata_ = buffer; } else { - ARROW_ASSIGN_OR_RAISE(metadata_, - Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_))); + ARROW_ASSIGN_OR_RAISE(metadata_, Buffer::ViewOrCopy(buffer, memory_manager_)); } return ConsumeMetadata(); } @@ -834,16 +834,15 @@ class MessageDecoder::MessageDecoderImpl { if (chunks_[0]->is_cpu()) { metadata_ = std::move(chunks_[0]); } else { - ARROW_ASSIGN_OR_RAISE( - metadata_, - Buffer::ViewOrCopy(chunks_[0], CPUDevice::memory_manager(pool_))); + ARROW_ASSIGN_OR_RAISE(metadata_, + Buffer::ViewOrCopy(chunks_[0], memory_manager_)); } chunks_.erase(chunks_.begin()); } else { metadata_ = SliceBuffer(chunks_[0], 0, next_required_size_); if (!chunks_[0]->is_cpu()) { - ARROW_ASSIGN_OR_RAISE( - metadata_, Buffer::ViewOrCopy(metadata_, CPUDevice::memory_manager(pool_))); + ARROW_ASSIGN_OR_RAISE(metadata_, + Buffer::ViewOrCopy(metadata_, memory_manager_)); } chunks_[0] = SliceBuffer(chunks_[0], next_required_size_); } @@ -911,8 +910,7 @@ class MessageDecoder::MessageDecoderImpl { if (buffer->is_cpu()) { return util::SafeLoadAs(buffer->data()); } else { - ARROW_ASSIGN_OR_RAISE(auto cpu_buffer, - Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_))); + ARROW_ASSIGN_OR_RAISE(auto cpu_buffer, Buffer::ViewOrCopy(buffer, memory_manager_)); return util::SafeLoadAs(cpu_buffer->data()); } } @@ -924,8 +922,7 @@ class MessageDecoder::MessageDecoderImpl { std::shared_ptr last_chunk; for (auto& chunk : chunks_) { if (!chunk->is_cpu()) { - ARROW_ASSIGN_OR_RAISE( - chunk, Buffer::ViewOrCopy(chunk, CPUDevice::memory_manager(pool_))); + ARROW_ASSIGN_OR_RAISE(chunk, Buffer::ViewOrCopy(chunk, memory_manager_)); } auto data = chunk->data(); auto data_size = chunk->size(); @@ -951,6 +948,7 @@ class MessageDecoder::MessageDecoderImpl { std::shared_ptr listener_; MemoryPool* pool_; + std::shared_ptr memory_manager_; State state_; int64_t next_required_size_; std::vector> chunks_; diff --git a/cpp/src/arrow/testing/generator.cc b/cpp/src/arrow/testing/generator.cc index 36c88c20efe6e..5ea6a541e8922 100644 --- a/cpp/src/arrow/testing/generator.cc +++ b/cpp/src/arrow/testing/generator.cc @@ -38,6 +38,7 @@ #include "arrow/type.h" #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/string.h" @@ -103,7 +104,13 @@ std::shared_ptr ConstantArrayGenerator::Float64(int64_t size, std::shared_ptr ConstantArrayGenerator::String(int64_t size, std::string value) { - return ConstantArray(size, value); + using BuilderType = typename TypeTraits::BuilderType; + auto type = TypeTraits::type_singleton(); + auto builder_fn = [&](BuilderType* builder) { + DCHECK_OK(builder->Append(std::string_view(value.data()))); + }; + return ArrayFromBuilderVisitor(type, value.size() * size, size, builder_fn) + .ValueOrDie(); } std::shared_ptr ConstantArrayGenerator::Zeroes( diff --git a/cpp/src/arrow/util/bit_stream_utils.h b/cpp/src/arrow/util/bit_stream_utils.h index 2afb2e5193697..811694e43b76c 100644 --- a/cpp/src/arrow/util/bit_stream_utils.h +++ b/cpp/src/arrow/util/bit_stream_utils.h @@ -183,7 +183,7 @@ class BitReader { /// Returns the number of bytes left in the stream, not including the current /// byte (i.e., there may be an additional fraction of a byte). - int bytes_left() { + int bytes_left() const { return max_bytes_ - (byte_offset_ + static_cast(bit_util::BytesForBits(bit_offset_))); } diff --git a/cpp/src/arrow/util/byte_stream_split_internal.h b/cpp/src/arrow/util/byte_stream_split_internal.h index 4bc732ec24313..f70b3991473fa 100644 --- a/cpp/src/arrow/util/byte_stream_split_internal.h +++ b/cpp/src/arrow/util/byte_stream_split_internal.h @@ -26,7 +26,6 @@ #include #ifdef ARROW_HAVE_SSE4_2 -// Enable the SIMD for ByteStreamSplit Encoder/Decoder #define ARROW_HAVE_SIMD_SPLIT #endif // ARROW_HAVE_SSE4_2 @@ -37,17 +36,15 @@ namespace arrow::util::internal { // #if defined(ARROW_HAVE_SSE4_2) -template +template void ByteStreamSplitDecodeSse2(const uint8_t* data, int64_t num_values, int64_t stride, - T* out) { - constexpr size_t kNumStreams = sizeof(T); - static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams."); - constexpr size_t kNumStreamsLog2 = (kNumStreams == 8U ? 3U : 2U); + uint8_t* out) { + static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams."); + constexpr int kNumStreamsLog2 = (kNumStreams == 8 ? 3 : 2); constexpr int64_t kBlockSize = sizeof(__m128i) * kNumStreams; - const int64_t size = num_values * sizeof(T); + const int64_t size = num_values * kNumStreams; const int64_t num_blocks = size / kBlockSize; - uint8_t* output_data = reinterpret_cast(out); // First handle suffix. // This helps catch if the simd-based processing overflows into the suffix @@ -55,11 +52,11 @@ void ByteStreamSplitDecodeSse2(const uint8_t* data, int64_t num_values, int64_t const int64_t num_processed_elements = (num_blocks * kBlockSize) / kNumStreams; for (int64_t i = num_processed_elements; i < num_values; ++i) { uint8_t gathered_byte_data[kNumStreams]; - for (size_t b = 0; b < kNumStreams; ++b) { - const size_t byte_index = b * stride + i; + for (int b = 0; b < kNumStreams; ++b) { + const int64_t byte_index = b * stride + i; gathered_byte_data[b] = data[byte_index]; } - out[i] = arrow::util::SafeLoadAs(&gathered_byte_data[0]); + memcpy(out + i * kNumStreams, gathered_byte_data, kNumStreams); } // The blocks get processed hierarchically using the unpack intrinsics. @@ -67,53 +64,52 @@ void ByteStreamSplitDecodeSse2(const uint8_t* data, int64_t num_values, int64_t // Stage 1: AAAA BBBB CCCC DDDD // Stage 2: ACAC ACAC BDBD BDBD // Stage 3: ABCD ABCD ABCD ABCD - __m128i stage[kNumStreamsLog2 + 1U][kNumStreams]; - constexpr size_t kNumStreamsHalf = kNumStreams / 2U; + __m128i stage[kNumStreamsLog2 + 1][kNumStreams]; + constexpr int kNumStreamsHalf = kNumStreams / 2U; for (int64_t i = 0; i < num_blocks; ++i) { - for (size_t j = 0; j < kNumStreams; ++j) { + for (int j = 0; j < kNumStreams; ++j) { stage[0][j] = _mm_loadu_si128( reinterpret_cast(&data[i * sizeof(__m128i) + j * stride])); } - for (size_t step = 0; step < kNumStreamsLog2; ++step) { - for (size_t j = 0; j < kNumStreamsHalf; ++j) { + for (int step = 0; step < kNumStreamsLog2; ++step) { + for (int j = 0; j < kNumStreamsHalf; ++j) { stage[step + 1U][j * 2] = _mm_unpacklo_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]); stage[step + 1U][j * 2 + 1U] = _mm_unpackhi_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]); } } - for (size_t j = 0; j < kNumStreams; ++j) { - _mm_storeu_si128(reinterpret_cast<__m128i*>( - &output_data[(i * kNumStreams + j) * sizeof(__m128i)]), - stage[kNumStreamsLog2][j]); + for (int j = 0; j < kNumStreams; ++j) { + _mm_storeu_si128( + reinterpret_cast<__m128i*>(out + (i * kNumStreams + j) * sizeof(__m128i)), + stage[kNumStreamsLog2][j]); } } } -template -void ByteStreamSplitEncodeSse2(const uint8_t* raw_values, const size_t num_values, +template +void ByteStreamSplitEncodeSse2(const uint8_t* raw_values, const int64_t num_values, uint8_t* output_buffer_raw) { - constexpr size_t kNumStreams = sizeof(T); - static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams."); - constexpr size_t kBlockSize = sizeof(__m128i) * kNumStreams; + static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams."); + constexpr int kBlockSize = sizeof(__m128i) * kNumStreams; __m128i stage[3][kNumStreams]; __m128i final_result[kNumStreams]; - const size_t size = num_values * sizeof(T); - const size_t num_blocks = size / kBlockSize; + const int64_t size = num_values * kNumStreams; + const int64_t num_blocks = size / kBlockSize; const __m128i* raw_values_sse = reinterpret_cast(raw_values); __m128i* output_buffer_streams[kNumStreams]; - for (size_t i = 0; i < kNumStreams; ++i) { + for (int i = 0; i < kNumStreams; ++i) { output_buffer_streams[i] = reinterpret_cast<__m128i*>(&output_buffer_raw[num_values * i]); } // First handle suffix. - const size_t num_processed_elements = (num_blocks * kBlockSize) / sizeof(T); - for (size_t i = num_processed_elements; i < num_values; ++i) { - for (size_t j = 0U; j < kNumStreams; ++j) { + const int64_t num_processed_elements = (num_blocks * kBlockSize) / kNumStreams; + for (int64_t i = num_processed_elements; i < num_values; ++i) { + for (int j = 0; j < kNumStreams; ++j) { const uint8_t byte_in_value = raw_values[i * kNumStreams + j]; output_buffer_raw[j * num_values + i] = byte_in_value; } @@ -131,48 +127,47 @@ void ByteStreamSplitEncodeSse2(const uint8_t* raw_values, const size_t num_value // 0: AAAA AAAA BBBB BBBB 1: CCCC CCCC DDDD DDDD ... // Step 4: __mm_unpacklo_epi64 and _mm_unpackhi_epi64: // 0: AAAA AAAA AAAA AAAA 1: BBBB BBBB BBBB BBBB ... - for (size_t block_index = 0; block_index < num_blocks; ++block_index) { + for (int64_t block_index = 0; block_index < num_blocks; ++block_index) { // First copy the data to stage 0. - for (size_t i = 0; i < kNumStreams; ++i) { + for (int i = 0; i < kNumStreams; ++i) { stage[0][i] = _mm_loadu_si128(&raw_values_sse[block_index * kNumStreams + i]); } // The shuffling of bytes is performed through the unpack intrinsics. // In my measurements this gives better performance then an implementation // which uses the shuffle intrinsics. - for (size_t stage_lvl = 0; stage_lvl < 2U; ++stage_lvl) { - for (size_t i = 0; i < kNumStreams / 2U; ++i) { + for (int stage_lvl = 0; stage_lvl < 2; ++stage_lvl) { + for (int i = 0; i < kNumStreams / 2; ++i) { stage[stage_lvl + 1][i * 2] = _mm_unpacklo_epi8(stage[stage_lvl][i * 2], stage[stage_lvl][i * 2 + 1]); stage[stage_lvl + 1][i * 2 + 1] = _mm_unpackhi_epi8(stage[stage_lvl][i * 2], stage[stage_lvl][i * 2 + 1]); } } - if constexpr (kNumStreams == 8U) { + if constexpr (kNumStreams == 8) { // This is the path for double. __m128i tmp[8]; - for (size_t i = 0; i < 4; ++i) { + for (int i = 0; i < 4; ++i) { tmp[i * 2] = _mm_unpacklo_epi32(stage[2][i], stage[2][i + 4]); tmp[i * 2 + 1] = _mm_unpackhi_epi32(stage[2][i], stage[2][i + 4]); } - - for (size_t i = 0; i < 4; ++i) { + for (int i = 0; i < 4; ++i) { final_result[i * 2] = _mm_unpacklo_epi32(tmp[i], tmp[i + 4]); final_result[i * 2 + 1] = _mm_unpackhi_epi32(tmp[i], tmp[i + 4]); } } else { // this is the path for float. __m128i tmp[4]; - for (size_t i = 0; i < 2; ++i) { + for (int i = 0; i < 2; ++i) { tmp[i * 2] = _mm_unpacklo_epi8(stage[2][i * 2], stage[2][i * 2 + 1]); tmp[i * 2 + 1] = _mm_unpackhi_epi8(stage[2][i * 2], stage[2][i * 2 + 1]); } - for (size_t i = 0; i < 2; ++i) { + for (int i = 0; i < 2; ++i) { final_result[i * 2] = _mm_unpacklo_epi64(tmp[i], tmp[i + 2]); final_result[i * 2 + 1] = _mm_unpackhi_epi64(tmp[i], tmp[i + 2]); } } - for (size_t i = 0; i < kNumStreams; ++i) { + for (int i = 0; i < kNumStreams; ++i) { _mm_storeu_si128(&output_buffer_streams[i][block_index], final_result[i]); } } @@ -180,52 +175,50 @@ void ByteStreamSplitEncodeSse2(const uint8_t* raw_values, const size_t num_value #endif // ARROW_HAVE_SSE4_2 #if defined(ARROW_HAVE_AVX2) -template +template void ByteStreamSplitDecodeAvx2(const uint8_t* data, int64_t num_values, int64_t stride, - T* out) { - constexpr size_t kNumStreams = sizeof(T); - static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams."); - constexpr size_t kNumStreamsLog2 = (kNumStreams == 8U ? 3U : 2U); + uint8_t* out) { + static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams."); + constexpr int kNumStreamsLog2 = (kNumStreams == 8 ? 3 : 2); constexpr int64_t kBlockSize = sizeof(__m256i) * kNumStreams; - const int64_t size = num_values * sizeof(T); + const int64_t size = num_values * kNumStreams; if (size < kBlockSize) // Back to SSE for small size - return ByteStreamSplitDecodeSse2(data, num_values, stride, out); + return ByteStreamSplitDecodeSse2(data, num_values, stride, out); const int64_t num_blocks = size / kBlockSize; - uint8_t* output_data = reinterpret_cast(out); // First handle suffix. const int64_t num_processed_elements = (num_blocks * kBlockSize) / kNumStreams; for (int64_t i = num_processed_elements; i < num_values; ++i) { uint8_t gathered_byte_data[kNumStreams]; - for (size_t b = 0; b < kNumStreams; ++b) { - const size_t byte_index = b * stride + i; + for (int b = 0; b < kNumStreams; ++b) { + const int64_t byte_index = b * stride + i; gathered_byte_data[b] = data[byte_index]; } - out[i] = arrow::util::SafeLoadAs(&gathered_byte_data[0]); + memcpy(out + i * kNumStreams, gathered_byte_data, kNumStreams); } // Processed hierarchically using unpack intrinsics, then permute intrinsics. - __m256i stage[kNumStreamsLog2 + 1U][kNumStreams]; + __m256i stage[kNumStreamsLog2 + 1][kNumStreams]; __m256i final_result[kNumStreams]; - constexpr size_t kNumStreamsHalf = kNumStreams / 2U; + constexpr int kNumStreamsHalf = kNumStreams / 2; for (int64_t i = 0; i < num_blocks; ++i) { - for (size_t j = 0; j < kNumStreams; ++j) { + for (int j = 0; j < kNumStreams; ++j) { stage[0][j] = _mm256_loadu_si256( reinterpret_cast(&data[i * sizeof(__m256i) + j * stride])); } - for (size_t step = 0; step < kNumStreamsLog2; ++step) { - for (size_t j = 0; j < kNumStreamsHalf; ++j) { - stage[step + 1U][j * 2] = + for (int step = 0; step < kNumStreamsLog2; ++step) { + for (int j = 0; j < kNumStreamsHalf; ++j) { + stage[step + 1][j * 2] = _mm256_unpacklo_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]); - stage[step + 1U][j * 2 + 1U] = + stage[step + 1][j * 2 + 1] = _mm256_unpackhi_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]); } } - if constexpr (kNumStreams == 8U) { + if constexpr (kNumStreams == 8) { // path for double, 128i index: // {0x00, 0x08}, {0x01, 0x09}, {0x02, 0x0A}, {0x03, 0x0B}, // {0x04, 0x0C}, {0x05, 0x0D}, {0x06, 0x0E}, {0x07, 0x0F}, @@ -258,40 +251,41 @@ void ByteStreamSplitDecodeAvx2(const uint8_t* data, int64_t num_values, int64_t stage[kNumStreamsLog2][3], 0b00110001); } - for (size_t j = 0; j < kNumStreams; ++j) { - _mm256_storeu_si256(reinterpret_cast<__m256i*>( - &output_data[(i * kNumStreams + j) * sizeof(__m256i)]), - final_result[j]); + for (int j = 0; j < kNumStreams; ++j) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(out + (i * kNumStreams + j) * sizeof(__m256i)), + final_result[j]); } } } -template -void ByteStreamSplitEncodeAvx2(const uint8_t* raw_values, const size_t num_values, +template +void ByteStreamSplitEncodeAvx2(const uint8_t* raw_values, const int64_t num_values, uint8_t* output_buffer_raw) { - constexpr size_t kNumStreams = sizeof(T); - static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams."); - constexpr size_t kBlockSize = sizeof(__m256i) * kNumStreams; + static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams."); + constexpr int kBlockSize = sizeof(__m256i) * kNumStreams; - if constexpr (kNumStreams == 8U) // Back to SSE, currently no path for double. - return ByteStreamSplitEncodeSse2(raw_values, num_values, output_buffer_raw); + if constexpr (kNumStreams == 8) // Back to SSE, currently no path for double. + return ByteStreamSplitEncodeSse2(raw_values, num_values, + output_buffer_raw); - const size_t size = num_values * sizeof(T); + const int64_t size = num_values * kNumStreams; if (size < kBlockSize) // Back to SSE for small size - return ByteStreamSplitEncodeSse2(raw_values, num_values, output_buffer_raw); - const size_t num_blocks = size / kBlockSize; + return ByteStreamSplitEncodeSse2(raw_values, num_values, + output_buffer_raw); + const int64_t num_blocks = size / kBlockSize; const __m256i* raw_values_simd = reinterpret_cast(raw_values); __m256i* output_buffer_streams[kNumStreams]; - for (size_t i = 0; i < kNumStreams; ++i) { + for (int i = 0; i < kNumStreams; ++i) { output_buffer_streams[i] = reinterpret_cast<__m256i*>(&output_buffer_raw[num_values * i]); } // First handle suffix. - const size_t num_processed_elements = (num_blocks * kBlockSize) / sizeof(T); - for (size_t i = num_processed_elements; i < num_values; ++i) { - for (size_t j = 0U; j < kNumStreams; ++j) { + const int64_t num_processed_elements = (num_blocks * kBlockSize) / kNumStreams; + for (int64_t i = num_processed_elements; i < num_values; ++i) { + for (int j = 0; j < kNumStreams; ++j) { const uint8_t byte_in_value = raw_values[i * kNumStreams + j]; output_buffer_raw[j * num_values + i] = byte_in_value; } @@ -301,20 +295,20 @@ void ByteStreamSplitEncodeAvx2(const uint8_t* raw_values, const size_t num_value // 1. Processed hierarchically to 32i block using the unpack intrinsics. // 2. Pack 128i block using _mm256_permutevar8x32_epi32. // 3. Pack final 256i block with _mm256_permute2x128_si256. - constexpr size_t kNumUnpack = 3U; + constexpr int kNumUnpack = 3; __m256i stage[kNumUnpack + 1][kNumStreams]; static const __m256i kPermuteMask = _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); __m256i permute[kNumStreams]; __m256i final_result[kNumStreams]; - for (size_t block_index = 0; block_index < num_blocks; ++block_index) { - for (size_t i = 0; i < kNumStreams; ++i) { + for (int64_t block_index = 0; block_index < num_blocks; ++block_index) { + for (int i = 0; i < kNumStreams; ++i) { stage[0][i] = _mm256_loadu_si256(&raw_values_simd[block_index * kNumStreams + i]); } - for (size_t stage_lvl = 0; stage_lvl < kNumUnpack; ++stage_lvl) { - for (size_t i = 0; i < kNumStreams / 2U; ++i) { + for (int stage_lvl = 0; stage_lvl < kNumUnpack; ++stage_lvl) { + for (int i = 0; i < kNumStreams / 2; ++i) { stage[stage_lvl + 1][i * 2] = _mm256_unpacklo_epi8(stage[stage_lvl][i * 2], stage[stage_lvl][i * 2 + 1]); stage[stage_lvl + 1][i * 2 + 1] = @@ -322,7 +316,7 @@ void ByteStreamSplitEncodeAvx2(const uint8_t* raw_values, const size_t num_value } } - for (size_t i = 0; i < kNumStreams; ++i) { + for (int i = 0; i < kNumStreams; ++i) { permute[i] = _mm256_permutevar8x32_epi32(stage[kNumUnpack][i], kPermuteMask); } @@ -331,7 +325,7 @@ void ByteStreamSplitEncodeAvx2(const uint8_t* raw_values, const size_t num_value final_result[2] = _mm256_permute2x128_si256(permute[1], permute[3], 0b00100000); final_result[3] = _mm256_permute2x128_si256(permute[1], permute[3], 0b00110001); - for (size_t i = 0; i < kNumStreams; ++i) { + for (int i = 0; i < kNumStreams; ++i) { _mm256_storeu_si256(&output_buffer_streams[i][block_index], final_result[i]); } } @@ -339,53 +333,51 @@ void ByteStreamSplitEncodeAvx2(const uint8_t* raw_values, const size_t num_value #endif // ARROW_HAVE_AVX2 #if defined(ARROW_HAVE_AVX512) -template +template void ByteStreamSplitDecodeAvx512(const uint8_t* data, int64_t num_values, int64_t stride, - T* out) { - constexpr size_t kNumStreams = sizeof(T); - static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams."); - constexpr size_t kNumStreamsLog2 = (kNumStreams == 8U ? 3U : 2U); + uint8_t* out) { + static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams."); + constexpr int kNumStreamsLog2 = (kNumStreams == 8 ? 3 : 2); constexpr int64_t kBlockSize = sizeof(__m512i) * kNumStreams; - const int64_t size = num_values * sizeof(T); + const int64_t size = num_values * kNumStreams; if (size < kBlockSize) // Back to AVX2 for small size - return ByteStreamSplitDecodeAvx2(data, num_values, stride, out); + return ByteStreamSplitDecodeAvx2(data, num_values, stride, out); const int64_t num_blocks = size / kBlockSize; - uint8_t* output_data = reinterpret_cast(out); // First handle suffix. const int64_t num_processed_elements = (num_blocks * kBlockSize) / kNumStreams; for (int64_t i = num_processed_elements; i < num_values; ++i) { uint8_t gathered_byte_data[kNumStreams]; - for (size_t b = 0; b < kNumStreams; ++b) { - const size_t byte_index = b * stride + i; + for (int b = 0; b < kNumStreams; ++b) { + const int64_t byte_index = b * stride + i; gathered_byte_data[b] = data[byte_index]; } - out[i] = arrow::util::SafeLoadAs(&gathered_byte_data[0]); + memcpy(out + i * kNumStreams, gathered_byte_data, kNumStreams); } // Processed hierarchically using the unpack, then two shuffles. - __m512i stage[kNumStreamsLog2 + 1U][kNumStreams]; + __m512i stage[kNumStreamsLog2 + 1][kNumStreams]; __m512i shuffle[kNumStreams]; __m512i final_result[kNumStreams]; - constexpr size_t kNumStreamsHalf = kNumStreams / 2U; + constexpr int kNumStreamsHalf = kNumStreams / 2U; for (int64_t i = 0; i < num_blocks; ++i) { - for (size_t j = 0; j < kNumStreams; ++j) { + for (int j = 0; j < kNumStreams; ++j) { stage[0][j] = _mm512_loadu_si512( reinterpret_cast(&data[i * sizeof(__m512i) + j * stride])); } - for (size_t step = 0; step < kNumStreamsLog2; ++step) { - for (size_t j = 0; j < kNumStreamsHalf; ++j) { - stage[step + 1U][j * 2] = + for (int step = 0; step < kNumStreamsLog2; ++step) { + for (int j = 0; j < kNumStreamsHalf; ++j) { + stage[step + 1][j * 2] = _mm512_unpacklo_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]); - stage[step + 1U][j * 2 + 1U] = + stage[step + 1][j * 2 + 1] = _mm512_unpackhi_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]); } } - if constexpr (kNumStreams == 8U) { + if constexpr (kNumStreams == 8) { // path for double, 128i index: // {0x00, 0x04, 0x08, 0x0C}, {0x10, 0x14, 0x18, 0x1C}, // {0x01, 0x05, 0x09, 0x0D}, {0x11, 0x15, 0x19, 0x1D}, @@ -435,49 +427,49 @@ void ByteStreamSplitDecodeAvx512(const uint8_t* data, int64_t num_values, int64_ final_result[3] = _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b11011101); } - for (size_t j = 0; j < kNumStreams; ++j) { - _mm512_storeu_si512(reinterpret_cast<__m512i*>( - &output_data[(i * kNumStreams + j) * sizeof(__m512i)]), - final_result[j]); + for (int j = 0; j < kNumStreams; ++j) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>(out + (i * kNumStreams + j) * sizeof(__m512i)), + final_result[j]); } } } -template -void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, const size_t num_values, +template +void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, const int64_t num_values, uint8_t* output_buffer_raw) { - constexpr size_t kNumStreams = sizeof(T); - static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams."); - constexpr size_t kBlockSize = sizeof(__m512i) * kNumStreams; + static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams."); + constexpr int kBlockSize = sizeof(__m512i) * kNumStreams; - const size_t size = num_values * sizeof(T); + const int64_t size = num_values * kNumStreams; if (size < kBlockSize) // Back to AVX2 for small size - return ByteStreamSplitEncodeAvx2(raw_values, num_values, output_buffer_raw); + return ByteStreamSplitEncodeAvx2(raw_values, num_values, + output_buffer_raw); - const size_t num_blocks = size / kBlockSize; + const int64_t num_blocks = size / kBlockSize; const __m512i* raw_values_simd = reinterpret_cast(raw_values); __m512i* output_buffer_streams[kNumStreams]; - for (size_t i = 0; i < kNumStreams; ++i) { + for (int i = 0; i < kNumStreams; ++i) { output_buffer_streams[i] = reinterpret_cast<__m512i*>(&output_buffer_raw[num_values * i]); } // First handle suffix. - const size_t num_processed_elements = (num_blocks * kBlockSize) / sizeof(T); - for (size_t i = num_processed_elements; i < num_values; ++i) { - for (size_t j = 0U; j < kNumStreams; ++j) { + const int64_t num_processed_elements = (num_blocks * kBlockSize) / kNumStreams; + for (int64_t i = num_processed_elements; i < num_values; ++i) { + for (int j = 0; j < kNumStreams; ++j) { const uint8_t byte_in_value = raw_values[i * kNumStreams + j]; output_buffer_raw[j * num_values + i] = byte_in_value; } } - constexpr size_t KNumUnpack = (kNumStreams == 8U) ? 2U : 3U; + constexpr int KNumUnpack = (kNumStreams == 8) ? 2 : 3; __m512i final_result[kNumStreams]; __m512i unpack[KNumUnpack + 1][kNumStreams]; __m512i permutex[kNumStreams]; __m512i permutex_mask; - if constexpr (kNumStreams == 8U) { + if constexpr (kNumStreams == 8) { // use _mm512_set_epi32, no _mm512_set_epi16 for some old gcc version. permutex_mask = _mm512_set_epi32(0x001F0017, 0x000F0007, 0x001E0016, 0x000E0006, 0x001D0015, 0x000D0005, 0x001C0014, 0x000C0004, @@ -488,13 +480,13 @@ void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, const size_t num_val 0x09, 0x05, 0x01, 0x0C, 0x08, 0x04, 0x00); } - for (size_t block_index = 0; block_index < num_blocks; ++block_index) { - for (size_t i = 0; i < kNumStreams; ++i) { + for (int64_t block_index = 0; block_index < num_blocks; ++block_index) { + for (int i = 0; i < kNumStreams; ++i) { unpack[0][i] = _mm512_loadu_si512(&raw_values_simd[block_index * kNumStreams + i]); } - for (size_t unpack_lvl = 0; unpack_lvl < KNumUnpack; ++unpack_lvl) { - for (size_t i = 0; i < kNumStreams / 2U; ++i) { + for (int unpack_lvl = 0; unpack_lvl < KNumUnpack; ++unpack_lvl) { + for (int i = 0; i < kNumStreams / 2; ++i) { unpack[unpack_lvl + 1][i * 2] = _mm512_unpacklo_epi8( unpack[unpack_lvl][i * 2], unpack[unpack_lvl][i * 2 + 1]); unpack[unpack_lvl + 1][i * 2 + 1] = _mm512_unpackhi_epi8( @@ -502,7 +494,7 @@ void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, const size_t num_val } } - if constexpr (kNumStreams == 8U) { + if constexpr (kNumStreams == 8) { // path for double // 1. unpack to epi16 block // 2. permutexvar_epi16 to 128i block @@ -511,7 +503,7 @@ void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, const size_t num_val // {0x01, 0x05, 0x09, 0x0D}, {0x11, 0x15, 0x19, 0x1D}, // {0x02, 0x06, 0x0A, 0x0E}, {0x12, 0x16, 0x1A, 0x1E}, // {0x03, 0x07, 0x0B, 0x0F}, {0x13, 0x17, 0x1B, 0x1F}, - for (size_t i = 0; i < kNumStreams; ++i) + for (int i = 0; i < kNumStreams; ++i) permutex[i] = _mm512_permutexvar_epi16(permutex_mask, unpack[KNumUnpack][i]); __m512i shuffle[kNumStreams]; @@ -537,7 +529,7 @@ void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, const size_t num_val // 1. Processed hierarchically to 32i block using the unpack intrinsics. // 2. Pack 128i block using _mm256_permutevar8x32_epi32. // 3. Pack final 256i block with _mm256_permute2x128_si256. - for (size_t i = 0; i < kNumStreams; ++i) + for (int i = 0; i < kNumStreams; ++i) permutex[i] = _mm512_permutexvar_epi32(permutex_mask, unpack[KNumUnpack][i]); final_result[0] = _mm512_shuffle_i32x4(permutex[0], permutex[2], 0b01000100); @@ -546,7 +538,7 @@ void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, const size_t num_val final_result[3] = _mm512_shuffle_i32x4(permutex[1], permutex[3], 0b11101110); } - for (size_t i = 0; i < kNumStreams; ++i) { + for (int i = 0; i < kNumStreams; ++i) { _mm512_storeu_si512(&output_buffer_streams[i][block_index], final_result[i]); } } @@ -554,32 +546,32 @@ void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, const size_t num_val #endif // ARROW_HAVE_AVX512 #if defined(ARROW_HAVE_SIMD_SPLIT) -template +template void inline ByteStreamSplitDecodeSimd(const uint8_t* data, int64_t num_values, - int64_t stride, T* out) { + int64_t stride, uint8_t* out) { #if defined(ARROW_HAVE_AVX512) - return ByteStreamSplitDecodeAvx512(data, num_values, stride, out); + return ByteStreamSplitDecodeAvx512(data, num_values, stride, out); #elif defined(ARROW_HAVE_AVX2) - return ByteStreamSplitDecodeAvx2(data, num_values, stride, out); + return ByteStreamSplitDecodeAvx2(data, num_values, stride, out); #elif defined(ARROW_HAVE_SSE4_2) - return ByteStreamSplitDecodeSse2(data, num_values, stride, out); + return ByteStreamSplitDecodeSse2(data, num_values, stride, out); #else #error "ByteStreamSplitDecodeSimd not implemented" #endif } -template +template void inline ByteStreamSplitEncodeSimd(const uint8_t* raw_values, const int64_t num_values, uint8_t* output_buffer_raw) { #if defined(ARROW_HAVE_AVX512) - return ByteStreamSplitEncodeAvx512(raw_values, static_cast(num_values), - output_buffer_raw); + return ByteStreamSplitEncodeAvx512(raw_values, num_values, + output_buffer_raw); #elif defined(ARROW_HAVE_AVX2) - return ByteStreamSplitEncodeAvx2(raw_values, static_cast(num_values), - output_buffer_raw); + return ByteStreamSplitEncodeAvx2(raw_values, num_values, + output_buffer_raw); #elif defined(ARROW_HAVE_SSE4_2) - return ByteStreamSplitEncodeSse2(raw_values, static_cast(num_values), - output_buffer_raw); + return ByteStreamSplitEncodeSse2(raw_values, num_values, + output_buffer_raw); #else #error "ByteStreamSplitEncodeSimd not implemented" #endif @@ -678,10 +670,9 @@ inline void DoMergeStreams(const uint8_t** src_streams, int width, int64_t nvalu } } -template +template void ByteStreamSplitEncodeScalar(const uint8_t* raw_values, const int64_t num_values, uint8_t* output_buffer_raw) { - constexpr int kNumStreams = static_cast(sizeof(T)); std::array dest_streams; for (int stream = 0; stream < kNumStreams; ++stream) { dest_streams[stream] = &output_buffer_raw[stream * num_values]; @@ -689,35 +680,35 @@ void ByteStreamSplitEncodeScalar(const uint8_t* raw_values, const int64_t num_va DoSplitStreams(raw_values, kNumStreams, num_values, dest_streams.data()); } -template +template void ByteStreamSplitDecodeScalar(const uint8_t* data, int64_t num_values, int64_t stride, - T* out) { - constexpr int kNumStreams = static_cast(sizeof(T)); + uint8_t* out) { std::array src_streams; for (int stream = 0; stream < kNumStreams; ++stream) { src_streams[stream] = &data[stream * stride]; } - DoMergeStreams(src_streams.data(), kNumStreams, num_values, - reinterpret_cast(out)); + DoMergeStreams(src_streams.data(), kNumStreams, num_values, out); } -template +template void inline ByteStreamSplitEncode(const uint8_t* raw_values, const int64_t num_values, uint8_t* output_buffer_raw) { #if defined(ARROW_HAVE_SIMD_SPLIT) - return ByteStreamSplitEncodeSimd(raw_values, num_values, output_buffer_raw); + return ByteStreamSplitEncodeSimd(raw_values, num_values, + output_buffer_raw); #else - return ByteStreamSplitEncodeScalar(raw_values, num_values, output_buffer_raw); + return ByteStreamSplitEncodeScalar(raw_values, num_values, + output_buffer_raw); #endif } -template +template void inline ByteStreamSplitDecode(const uint8_t* data, int64_t num_values, int64_t stride, - T* out) { + uint8_t* out) { #if defined(ARROW_HAVE_SIMD_SPLIT) - return ByteStreamSplitDecodeSimd(data, num_values, stride, out); + return ByteStreamSplitDecodeSimd(data, num_values, stride, out); #else - return ByteStreamSplitDecodeScalar(data, num_values, stride, out); + return ByteStreamSplitDecodeScalar(data, num_values, stride, out); #endif } diff --git a/cpp/src/arrow/util/byte_stream_split_test.cc b/cpp/src/arrow/util/byte_stream_split_test.cc index c98f0a086738b..71c6063179ea6 100644 --- a/cpp/src/arrow/util/byte_stream_split_test.cc +++ b/cpp/src/arrow/util/byte_stream_split_test.cc @@ -61,18 +61,30 @@ void ReferenceByteStreamSplitEncode(const uint8_t* src, int width, template class TestByteStreamSplitSpecialized : public ::testing::Test { public: - using EncodeFunc = NamedFunc)>>; - using DecodeFunc = NamedFunc)>>; - static constexpr int kWidth = static_cast(sizeof(T)); + using EncodeFunc = NamedFunc)>>; + using DecodeFunc = NamedFunc)>>; + void SetUp() override { encode_funcs_.push_back({"reference", &ReferenceEncode}); - encode_funcs_.push_back({"scalar", &ByteStreamSplitEncodeScalar}); - decode_funcs_.push_back({"scalar", &ByteStreamSplitDecodeScalar}); + encode_funcs_.push_back({"scalar", &ByteStreamSplitEncodeScalar}); + decode_funcs_.push_back({"scalar", &ByteStreamSplitDecodeScalar}); #if defined(ARROW_HAVE_SIMD_SPLIT) - encode_funcs_.push_back({"simd", &ByteStreamSplitEncodeSimd}); - decode_funcs_.push_back({"simd", &ByteStreamSplitDecodeSimd}); + encode_funcs_.push_back({"simd", &ByteStreamSplitEncodeSimd}); + decode_funcs_.push_back({"simd", &ByteStreamSplitDecodeSimd}); +#endif +#if defined(ARROW_HAVE_SSE4_2) + encode_funcs_.push_back({"sse2", &ByteStreamSplitEncodeSse2}); + decode_funcs_.push_back({"sse2", &ByteStreamSplitDecodeSse2}); +#endif +#if defined(ARROW_HAVE_AVX2) + encode_funcs_.push_back({"avx2", &ByteStreamSplitEncodeAvx2}); + decode_funcs_.push_back({"avx2", &ByteStreamSplitDecodeAvx2}); +#endif +#if defined(ARROW_HAVE_AVX512) + encode_funcs_.push_back({"avx512", &ByteStreamSplitEncodeAvx512}); + decode_funcs_.push_back({"avx512", &ByteStreamSplitDecodeAvx512}); #endif } @@ -92,7 +104,7 @@ class TestByteStreamSplitSpecialized : public ::testing::Test { ARROW_SCOPED_TRACE("decode_func = ", decode_func); decoded.assign(decoded.size(), T{}); decode_func.func(encoded.data(), num_values, /*stride=*/num_values, - decoded.data()); + reinterpret_cast(decoded.data())); ASSERT_EQ(decoded, input); } } @@ -118,7 +130,7 @@ class TestByteStreamSplitSpecialized : public ::testing::Test { while (offset < num_values) { auto chunk_size = std::min(num_values - offset, chunk_size_dist(gen)); decode_func.func(encoded.data() + offset, chunk_size, /*stride=*/num_values, - decoded.data() + offset); + reinterpret_cast(decoded.data() + offset)); offset += chunk_size; } ASSERT_EQ(offset, num_values); diff --git a/cpp/src/parquet/column_reader.cc b/cpp/src/parquet/column_reader.cc index f5d9734aa1e01..ac4627d69c0f6 100644 --- a/cpp/src/parquet/column_reader.cc +++ b/cpp/src/parquet/column_reader.cc @@ -760,7 +760,7 @@ class ColumnReaderImplBase { if (page->encoding() == Encoding::PLAIN_DICTIONARY || page->encoding() == Encoding::PLAIN) { - auto dictionary = MakeTypedDecoder(Encoding::PLAIN, descr_); + auto dictionary = MakeTypedDecoder(Encoding::PLAIN, descr_, pool_); dictionary->SetData(page->num_values(), page->data(), page->size()); // The dictionary is fully decoded during DictionaryDecoder::Init, so the @@ -883,47 +883,21 @@ class ColumnReaderImplBase { current_decoder_ = it->second.get(); } else { switch (encoding) { - case Encoding::PLAIN: { - auto decoder = MakeTypedDecoder(Encoding::PLAIN, descr_); - current_decoder_ = decoder.get(); - decoders_[static_cast(encoding)] = std::move(decoder); - break; - } - case Encoding::BYTE_STREAM_SPLIT: { - auto decoder = MakeTypedDecoder(Encoding::BYTE_STREAM_SPLIT, descr_); - current_decoder_ = decoder.get(); - decoders_[static_cast(encoding)] = std::move(decoder); - break; - } - case Encoding::RLE: { - auto decoder = MakeTypedDecoder(Encoding::RLE, descr_); + case Encoding::PLAIN: + case Encoding::BYTE_STREAM_SPLIT: + case Encoding::RLE: + case Encoding::DELTA_BINARY_PACKED: + case Encoding::DELTA_BYTE_ARRAY: + case Encoding::DELTA_LENGTH_BYTE_ARRAY: { + auto decoder = MakeTypedDecoder(encoding, descr_, pool_); current_decoder_ = decoder.get(); decoders_[static_cast(encoding)] = std::move(decoder); break; } + case Encoding::RLE_DICTIONARY: throw ParquetException("Dictionary page must be before data page."); - case Encoding::DELTA_BINARY_PACKED: { - auto decoder = MakeTypedDecoder(Encoding::DELTA_BINARY_PACKED, descr_); - current_decoder_ = decoder.get(); - decoders_[static_cast(encoding)] = std::move(decoder); - break; - } - case Encoding::DELTA_BYTE_ARRAY: { - auto decoder = MakeTypedDecoder(Encoding::DELTA_BYTE_ARRAY, descr_); - current_decoder_ = decoder.get(); - decoders_[static_cast(encoding)] = std::move(decoder); - break; - } - case Encoding::DELTA_LENGTH_BYTE_ARRAY: { - auto decoder = - MakeTypedDecoder(Encoding::DELTA_LENGTH_BYTE_ARRAY, descr_); - current_decoder_ = decoder.get(); - decoders_[static_cast(encoding)] = std::move(decoder); - break; - } - default: throw ParquetException("Unknown encoding type."); } diff --git a/cpp/src/parquet/column_writer.cc b/cpp/src/parquet/column_writer.cc index 12b2837fbfd1e..23366b2daafd5 100644 --- a/cpp/src/parquet/column_writer.cc +++ b/cpp/src/parquet/column_writer.cc @@ -271,7 +271,12 @@ class SerializedPageWriter : public PageWriter { } int64_t WriteDictionaryPage(const DictionaryPage& page) override { - int64_t uncompressed_size = page.size(); + int64_t uncompressed_size = page.buffer()->size(); + if (uncompressed_size > std::numeric_limits::max()) { + throw ParquetException( + "Uncompressed dictionary page size overflows INT32_MAX. Size:", + uncompressed_size); + } std::shared_ptr compressed_data; if (has_compressor()) { auto buffer = std::static_pointer_cast( @@ -288,6 +293,11 @@ class SerializedPageWriter : public PageWriter { dict_page_header.__set_is_sorted(page.is_sorted()); const uint8_t* output_data_buffer = compressed_data->data(); + if (compressed_data->size() > std::numeric_limits::max()) { + throw ParquetException( + "Compressed dictionary page size overflows INT32_MAX. Size: ", + uncompressed_size); + } int32_t output_data_len = static_cast(compressed_data->size()); if (data_encryptor_.get()) { @@ -371,18 +381,29 @@ class SerializedPageWriter : public PageWriter { const int64_t uncompressed_size = page.uncompressed_size(); std::shared_ptr compressed_data = page.buffer(); const uint8_t* output_data_buffer = compressed_data->data(); - int32_t output_data_len = static_cast(compressed_data->size()); + int64_t output_data_len = compressed_data->size(); + + if (output_data_len > std::numeric_limits::max()) { + throw ParquetException("Compressed data page size overflows INT32_MAX. Size:", + output_data_len); + } if (data_encryptor_.get()) { PARQUET_THROW_NOT_OK(encryption_buffer_->Resize( data_encryptor_->CiphertextSizeDelta() + output_data_len, false)); UpdateEncryption(encryption::kDataPage); - output_data_len = data_encryptor_->Encrypt(compressed_data->data(), output_data_len, + output_data_len = data_encryptor_->Encrypt(compressed_data->data(), + static_cast(output_data_len), encryption_buffer_->mutable_data()); output_data_buffer = encryption_buffer_->data(); } format::PageHeader page_header; + + if (uncompressed_size > std::numeric_limits::max()) { + throw ParquetException("Uncompressed data page size overflows INT32_MAX. Size:", + uncompressed_size); + } page_header.__set_uncompressed_page_size(static_cast(uncompressed_size)); page_header.__set_compressed_page_size(static_cast(output_data_len)); @@ -421,7 +442,7 @@ class SerializedPageWriter : public PageWriter { if (offset_index_builder_ != nullptr) { const int64_t compressed_size = output_data_len + header_size; if (compressed_size > std::numeric_limits::max()) { - throw ParquetException("Compressed page size overflows to INT32_MAX."); + throw ParquetException("Compressed page size overflows INT32_MAX."); } if (!page.first_row_index().has_value()) { throw ParquetException("First row index is not set in data page."); diff --git a/cpp/src/parquet/column_writer_test.cc b/cpp/src/parquet/column_writer_test.cc index 59fc848d7fd57..97421629d2ca6 100644 --- a/cpp/src/parquet/column_writer_test.cc +++ b/cpp/src/parquet/column_writer_test.cc @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. +#include #include #include +#include #include #include "arrow/io/buffered.h" @@ -25,6 +27,7 @@ #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_builders.h" +#include "parquet/column_page.h" #include "parquet/column_reader.h" #include "parquet/column_writer.h" #include "parquet/file_reader.h" @@ -479,6 +482,9 @@ using TestValuesWriterInt64Type = TestPrimitiveWriter; using TestByteArrayValuesWriter = TestPrimitiveWriter; using TestFixedLengthByteArrayValuesWriter = TestPrimitiveWriter; +using ::testing::HasSubstr; +using ::testing::ThrowsMessage; + TYPED_TEST(TestPrimitiveWriter, RequiredPlain) { this->TestRequiredWithEncoding(Encoding::PLAIN); } @@ -889,6 +895,45 @@ TEST_F(TestByteArrayValuesWriter, CheckDefaultStats) { ASSERT_TRUE(this->metadata_is_stats_set()); } +TEST(TestPageWriter, ThrowsOnPagesTooLarge) { + NodePtr item = schema::Int32("item"); // optional item + NodePtr list(GroupNode::Make("b", Repetition::REPEATED, {item}, ConvertedType::LIST)); + NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list})); // optional list + std::vector fields = {bag}; + NodePtr root = GroupNode::Make("schema", Repetition::REPEATED, fields); + + SchemaDescriptor schema; + schema.Init(root); + + auto sink = CreateOutputStream(); + auto props = WriterProperties::Builder().build(); + + auto metadata = ColumnChunkMetaDataBuilder::Make(props, schema.Column(0)); + std::unique_ptr pager = + PageWriter::Open(sink, Compression::UNCOMPRESSED, metadata.get()); + + uint8_t data; + std::shared_ptr buffer = + std::make_shared(&data, std::numeric_limits::max() + int64_t{1}); + DataPageV1 over_compressed_limit(buffer, /*num_values=*/100, Encoding::BIT_PACKED, + Encoding::BIT_PACKED, Encoding::BIT_PACKED, + /*uncompressed_size=*/100); + EXPECT_THAT([&]() { pager->WriteDataPage(over_compressed_limit); }, + ThrowsMessage(HasSubstr("overflows INT32_MAX"))); + DictionaryPage dictionary_over_compressed_limit(buffer, /*num_values=*/100, + Encoding::PLAIN); + EXPECT_THAT([&]() { pager->WriteDictionaryPage(dictionary_over_compressed_limit); }, + ThrowsMessage(HasSubstr("overflows INT32_MAX"))); + + buffer = std::make_shared(&data, 1); + DataPageV1 over_uncompressed_limit( + buffer, /*num_values=*/100, Encoding::BIT_PACKED, Encoding::BIT_PACKED, + Encoding::BIT_PACKED, + /*uncompressed_size=*/std::numeric_limits::max() + int64_t{1}); + EXPECT_THAT([&]() { pager->WriteDataPage(over_compressed_limit); }, + ThrowsMessage(HasSubstr("overflows INT32_MAX"))); +} + TEST(TestColumnWriter, RepeatedListsUpdateSpacedBug) { // In ARROW-3930 we discovered a bug when writing from Arrow when we had data // that looks like this: diff --git a/cpp/src/parquet/encoding.cc b/cpp/src/parquet/encoding.cc index b07ad6c9fb062..b801b5ab11bb9 100644 --- a/cpp/src/parquet/encoding.cc +++ b/cpp/src/parquet/encoding.cc @@ -850,8 +850,8 @@ std::shared_ptr ByteStreamSplitEncoder::FlushValues() { AllocateBuffer(this->memory_pool(), EstimatedDataEncodedSize()); uint8_t* output_buffer_raw = output_buffer->mutable_data(); const uint8_t* raw_values = sink_.data(); - ::arrow::util::internal::ByteStreamSplitEncode(raw_values, num_values_in_buffer_, - output_buffer_raw); + ::arrow::util::internal::ByteStreamSplitEncode( + raw_values, num_values_in_buffer_, output_buffer_raw); sink_.Reset(); num_values_in_buffer_ = 0; return std::move(output_buffer); @@ -3577,7 +3577,7 @@ class ByteStreamSplitDecoder : public DecoderImpl, virtual public TypedDecoder decode_buffer_; - static constexpr size_t kNumStreams = sizeof(T); + static constexpr int kNumStreams = sizeof(T); }; template @@ -3607,8 +3607,8 @@ int ByteStreamSplitDecoder::Decode(T* buffer, int max_values) { const int num_decoded_previously = num_values_in_buffer_ - num_values_; const uint8_t* data = data_ + num_decoded_previously; - ::arrow::util::internal::ByteStreamSplitDecode(data, values_to_decode, - num_values_in_buffer_, buffer); + ::arrow::util::internal::ByteStreamSplitDecode( + data, values_to_decode, num_values_in_buffer_, reinterpret_cast(buffer)); num_values_ -= values_to_decode; len_ -= sizeof(T) * values_to_decode; return values_to_decode; @@ -3618,7 +3618,7 @@ template int ByteStreamSplitDecoder::DecodeArrow( int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, typename EncodingTraits::Accumulator* builder) { - constexpr int value_size = static_cast(kNumStreams); + constexpr int value_size = kNumStreams; int values_decoded = num_values - null_count; if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) { ParquetException::EofException(); @@ -3634,8 +3634,9 @@ int ByteStreamSplitDecoder::DecodeArrow( // Use fast decoding into intermediate buffer. This will also decode // some null values, but it's fast enough that we don't care. T* decode_out = EnsureDecodeBuffer(values_decoded); - ::arrow::util::internal::ByteStreamSplitDecode(data, values_decoded, - num_values_in_buffer_, decode_out); + ::arrow::util::internal::ByteStreamSplitDecode( + data, values_decoded, num_values_in_buffer_, + reinterpret_cast(decode_out)); // XXX If null_count is 0, we could even append in bulk or decode directly into // builder @@ -3648,12 +3649,13 @@ int ByteStreamSplitDecoder::DecodeArrow( [&]() { builder->UnsafeAppendNull(); }); #else + // XXX should operate over runs of 0s / 1s VisitNullBitmapInline( valid_bits, valid_bits_offset, num_values, null_count, [&]() { uint8_t gathered_byte_data[kNumStreams]; - for (size_t b = 0; b < kNumStreams; ++b) { - const size_t byte_index = b * num_values_in_buffer_ + offset; + for (int b = 0; b < kNumStreams; ++b) { + const int64_t byte_index = b * num_values_in_buffer_ + offset; gathered_byte_data[b] = data[byte_index]; } builder->UnsafeAppend(SafeLoadAs(&gathered_byte_data[0])); diff --git a/cpp/src/parquet/encoding_benchmark.cc b/cpp/src/parquet/encoding_benchmark.cc index b5b6cc8d93e03..76c411244b22d 100644 --- a/cpp/src/parquet/encoding_benchmark.cc +++ b/cpp/src/parquet/encoding_benchmark.cc @@ -369,7 +369,8 @@ static void BM_ByteStreamSplitDecode(benchmark::State& state, DecodeFunc&& decod for (auto _ : state) { decode_func(values_raw, static_cast(values.size()), - static_cast(values.size()), output.data()); + static_cast(values.size()), + reinterpret_cast(output.data())); benchmark::ClobberMemory(); } state.SetBytesProcessed(state.iterations() * values.size() * sizeof(T)); @@ -390,22 +391,22 @@ static void BM_ByteStreamSplitEncode(benchmark::State& state, EncodeFunc&& encod static void BM_ByteStreamSplitDecode_Float_Scalar(benchmark::State& state) { BM_ByteStreamSplitDecode( - state, ::arrow::util::internal::ByteStreamSplitDecodeScalar); + state, ::arrow::util::internal::ByteStreamSplitDecodeScalar); } static void BM_ByteStreamSplitDecode_Double_Scalar(benchmark::State& state) { BM_ByteStreamSplitDecode( - state, ::arrow::util::internal::ByteStreamSplitDecodeScalar); + state, ::arrow::util::internal::ByteStreamSplitDecodeScalar); } static void BM_ByteStreamSplitEncode_Float_Scalar(benchmark::State& state) { BM_ByteStreamSplitEncode( - state, ::arrow::util::internal::ByteStreamSplitEncodeScalar); + state, ::arrow::util::internal::ByteStreamSplitEncodeScalar); } static void BM_ByteStreamSplitEncode_Double_Scalar(benchmark::State& state) { BM_ByteStreamSplitEncode( - state, ::arrow::util::internal::ByteStreamSplitEncodeScalar); + state, ::arrow::util::internal::ByteStreamSplitEncodeScalar); } BENCHMARK(BM_ByteStreamSplitDecode_Float_Scalar)->Range(MIN_RANGE, MAX_RANGE); @@ -416,22 +417,22 @@ BENCHMARK(BM_ByteStreamSplitEncode_Double_Scalar)->Range(MIN_RANGE, MAX_RANGE); #if defined(ARROW_HAVE_SSE4_2) static void BM_ByteStreamSplitDecode_Float_Sse2(benchmark::State& state) { BM_ByteStreamSplitDecode( - state, ::arrow::util::internal::ByteStreamSplitDecodeSse2); + state, ::arrow::util::internal::ByteStreamSplitDecodeSse2); } static void BM_ByteStreamSplitDecode_Double_Sse2(benchmark::State& state) { BM_ByteStreamSplitDecode( - state, ::arrow::util::internal::ByteStreamSplitDecodeSse2); + state, ::arrow::util::internal::ByteStreamSplitDecodeSse2); } static void BM_ByteStreamSplitEncode_Float_Sse2(benchmark::State& state) { BM_ByteStreamSplitEncode( - state, ::arrow::util::internal::ByteStreamSplitEncodeSse2); + state, ::arrow::util::internal::ByteStreamSplitEncodeSse2); } static void BM_ByteStreamSplitEncode_Double_Sse2(benchmark::State& state) { BM_ByteStreamSplitEncode( - state, ::arrow::util::internal::ByteStreamSplitEncodeSse2); + state, ::arrow::util::internal::ByteStreamSplitEncodeSse2); } BENCHMARK(BM_ByteStreamSplitDecode_Float_Sse2)->Range(MIN_RANGE, MAX_RANGE); @@ -443,22 +444,22 @@ BENCHMARK(BM_ByteStreamSplitEncode_Double_Sse2)->Range(MIN_RANGE, MAX_RANGE); #if defined(ARROW_HAVE_AVX2) static void BM_ByteStreamSplitDecode_Float_Avx2(benchmark::State& state) { BM_ByteStreamSplitDecode( - state, ::arrow::util::internal::ByteStreamSplitDecodeAvx2); + state, ::arrow::util::internal::ByteStreamSplitDecodeAvx2); } static void BM_ByteStreamSplitDecode_Double_Avx2(benchmark::State& state) { BM_ByteStreamSplitDecode( - state, ::arrow::util::internal::ByteStreamSplitDecodeAvx2); + state, ::arrow::util::internal::ByteStreamSplitDecodeAvx2); } static void BM_ByteStreamSplitEncode_Float_Avx2(benchmark::State& state) { BM_ByteStreamSplitEncode( - state, ::arrow::util::internal::ByteStreamSplitEncodeAvx2); + state, ::arrow::util::internal::ByteStreamSplitEncodeAvx2); } static void BM_ByteStreamSplitEncode_Double_Avx2(benchmark::State& state) { BM_ByteStreamSplitEncode( - state, ::arrow::util::internal::ByteStreamSplitEncodeAvx2); + state, ::arrow::util::internal::ByteStreamSplitEncodeAvx2); } BENCHMARK(BM_ByteStreamSplitDecode_Float_Avx2)->Range(MIN_RANGE, MAX_RANGE); @@ -470,22 +471,22 @@ BENCHMARK(BM_ByteStreamSplitEncode_Double_Avx2)->Range(MIN_RANGE, MAX_RANGE); #if defined(ARROW_HAVE_AVX512) static void BM_ByteStreamSplitDecode_Float_Avx512(benchmark::State& state) { BM_ByteStreamSplitDecode( - state, ::arrow::util::internal::ByteStreamSplitDecodeAvx512); + state, ::arrow::util::internal::ByteStreamSplitDecodeAvx512); } static void BM_ByteStreamSplitDecode_Double_Avx512(benchmark::State& state) { BM_ByteStreamSplitDecode( - state, ::arrow::util::internal::ByteStreamSplitDecodeAvx512); + state, ::arrow::util::internal::ByteStreamSplitDecodeAvx512); } static void BM_ByteStreamSplitEncode_Float_Avx512(benchmark::State& state) { BM_ByteStreamSplitEncode( - state, ::arrow::util::internal::ByteStreamSplitEncodeAvx512); + state, ::arrow::util::internal::ByteStreamSplitEncodeAvx512); } static void BM_ByteStreamSplitEncode_Double_Avx512(benchmark::State& state) { BM_ByteStreamSplitEncode( - state, ::arrow::util::internal::ByteStreamSplitEncodeAvx512); + state, ::arrow::util::internal::ByteStreamSplitEncodeAvx512); } BENCHMARK(BM_ByteStreamSplitDecode_Float_Avx512)->Range(MIN_RANGE, MAX_RANGE); diff --git a/cpp/src/parquet/metadata.cc b/cpp/src/parquet/metadata.cc index d651ea5db0f18..3f101b5ae3ac6 100644 --- a/cpp/src/parquet/metadata.cc +++ b/cpp/src/parquet/metadata.cc @@ -761,7 +761,7 @@ class FileMetaData::FileMetaDataImpl { return metadata_->row_groups[i]; } - void AppendRowGroups(const std::unique_ptr& other) { + void AppendRowGroups(FileMetaDataImpl* other) { std::ostringstream diff_output; if (!schema()->Equals(*other->schema(), &diff_output)) { auto msg = "AppendRowGroups requires equal schemas.\n" + diff_output.str(); @@ -800,6 +800,7 @@ class FileMetaData::FileMetaDataImpl { metadata->schema = metadata_->schema; metadata->row_groups.resize(row_groups.size()); + int i = 0; for (int selected_index : row_groups) { metadata->num_rows += row_group(selected_index).num_rows; @@ -822,7 +823,7 @@ class FileMetaData::FileMetaDataImpl { } void set_file_decryptor(std::shared_ptr file_decryptor) { - file_decryptor_ = file_decryptor; + file_decryptor_ = std::move(file_decryptor); } private: @@ -886,13 +887,14 @@ std::shared_ptr FileMetaData::Make( const void* metadata, uint32_t* metadata_len, std::shared_ptr file_decryptor) { return std::shared_ptr(new FileMetaData( - metadata, metadata_len, default_reader_properties(), file_decryptor)); + metadata, metadata_len, default_reader_properties(), std::move(file_decryptor))); } FileMetaData::FileMetaData(const void* metadata, uint32_t* metadata_len, const ReaderProperties& properties, std::shared_ptr file_decryptor) - : impl_(new FileMetaDataImpl(metadata, metadata_len, properties, file_decryptor)) {} + : impl_(new FileMetaDataImpl(metadata, metadata_len, properties, + std::move(file_decryptor))) {} FileMetaData::FileMetaData() : impl_(new FileMetaDataImpl()) {} @@ -942,7 +944,7 @@ const std::string& FileMetaData::footer_signing_key_metadata() const { void FileMetaData::set_file_decryptor( std::shared_ptr file_decryptor) { - impl_->set_file_decryptor(file_decryptor); + impl_->set_file_decryptor(std::move(file_decryptor)); } ParquetVersion::type FileMetaData::version() const { @@ -975,7 +977,7 @@ const std::shared_ptr& FileMetaData::key_value_metadata( void FileMetaData::set_file_path(const std::string& path) { impl_->set_file_path(path); } void FileMetaData::AppendRowGroups(const FileMetaData& other) { - impl_->AppendRowGroups(other.impl_); + impl_->AppendRowGroups(other.impl_.get()); } std::shared_ptr FileMetaData::Subset( @@ -1839,7 +1841,7 @@ class FileMetaDataBuilder::FileMetaDataBuilderImpl { std::unique_ptr Finish( const std::shared_ptr& key_value_metadata) { int64_t total_rows = 0; - for (auto row_group : row_groups_) { + for (const auto& row_group : row_groups_) { total_rows += row_group.num_rows; } metadata_->__set_num_rows(total_rows); @@ -1858,7 +1860,7 @@ class FileMetaDataBuilder::FileMetaDataBuilderImpl { format::KeyValue kv_pair; kv_pair.__set_key(key_value_metadata_->key(i)); kv_pair.__set_value(key_value_metadata_->value(i)); - metadata_->key_value_metadata.push_back(kv_pair); + metadata_->key_value_metadata.push_back(std::move(kv_pair)); } metadata_->__isset.key_value_metadata = true; } diff --git a/cpp/src/parquet/metadata.h b/cpp/src/parquet/metadata.h index e47c45ff0492a..640b898024346 100644 --- a/cpp/src/parquet/metadata.h +++ b/cpp/src/parquet/metadata.h @@ -306,9 +306,15 @@ class PARQUET_EXPORT FileMetaData { int num_schema_elements() const; /// \brief The total number of rows. + /// + /// If the FileMetaData was obtained by calling `SubSet()`, this is the total + /// number of rows in the selected row groups. int64_t num_rows() const; /// \brief The number of row groups in the file. + /// + /// If the FileMetaData was obtained by calling `SubSet()`, this is the number + /// of selected row groups. int num_row_groups() const; /// \brief Return the RowGroupMetaData of the corresponding row group ordinal. @@ -338,7 +344,7 @@ class PARQUET_EXPORT FileMetaData { /// \brief Size of the original thrift encoded metadata footer. uint32_t size() const; - /// \brief Indicate if all of the FileMetadata's RowGroups can be decompressed. + /// \brief Indicate if all of the FileMetaData's RowGroups can be decompressed. /// /// This will return false if any of the RowGroup's page is compressed with a /// compression format which is not compiled in the current parquet library. diff --git a/cpp/src/parquet/reader_test.cc b/cpp/src/parquet/reader_test.cc index 2c2b62f5d12f6..551f62798e3b5 100644 --- a/cpp/src/parquet/reader_test.cc +++ b/cpp/src/parquet/reader_test.cc @@ -120,11 +120,27 @@ std::string concatenated_gzip_members() { return data_file("concatenated_gzip_members.parquet"); } +std::string byte_stream_split() { return data_file("byte_stream_split.zstd.parquet"); } + +template +std::vector ReadColumnValues(ParquetFileReader* file_reader, int row_group, + int column, int64_t expected_values_read) { + auto column_reader = checked_pointer_cast>( + file_reader->RowGroup(row_group)->Column(column)); + std::vector values(expected_values_read); + int64_t values_read; + auto levels_read = column_reader->ReadBatch(expected_values_read, nullptr, nullptr, + values.data(), &values_read); + EXPECT_EQ(expected_values_read, levels_read); + EXPECT_EQ(expected_values_read, values_read); + return values; +} + // TODO: Assert on definition and repetition levels -template +template void AssertColumnValues(std::shared_ptr> col, int64_t batch_size, int64_t expected_levels_read, - std::vector& expected_values, + const std::vector& expected_values, int64_t expected_values_read) { std::vector values(batch_size); int64_t values_read; @@ -1412,7 +1428,6 @@ TEST_P(TestCodec, LargeFileValues) { // column 0 ("a") auto col = checked_pointer_cast(group->Column(0)); - std::vector values(kNumRows); int64_t values_read; auto levels_read = @@ -1474,6 +1489,38 @@ TEST(TestFileReader, TestOverflowInt16PageOrdinal) { } } +#ifdef ARROW_WITH_ZSTD +TEST(TestByteStreamSplit, FloatIntegrationFile) { + auto file_path = byte_stream_split(); + auto file = ParquetFileReader::OpenFile(file_path); + + const int64_t kNumRows = 300; + + ASSERT_EQ(kNumRows, file->metadata()->num_rows()); + ASSERT_EQ(2, file->metadata()->num_columns()); + ASSERT_EQ(1, file->metadata()->num_row_groups()); + + // column 0 ("f32") + { + auto values = + ReadColumnValues(file.get(), /*row_group=*/0, /*column=*/0, kNumRows); + ASSERT_EQ(values[0], 1.7640524f); + ASSERT_EQ(values[1], 0.4001572f); + ASSERT_EQ(values[kNumRows - 2], -0.39944902f); + ASSERT_EQ(values[kNumRows - 1], 0.37005588f); + } + // column 1 ("f64") + { + auto values = + ReadColumnValues(file.get(), /*row_group=*/0, /*column=*/1, kNumRows); + ASSERT_EQ(values[0], -1.3065268517353166); + ASSERT_EQ(values[1], 1.658130679618188); + ASSERT_EQ(values[kNumRows - 2], -0.9301565025243212); + ASSERT_EQ(values[kNumRows - 1], -0.17858909208732915); + } +} +#endif // ARROW_WITH_ZSTD + struct PageIndexReaderParam { std::vector row_group_indices; std::vector column_indices; diff --git a/cpp/submodules/parquet-testing b/cpp/submodules/parquet-testing index d69d979223e88..4cb3cff24c965 160000 --- a/cpp/submodules/parquet-testing +++ b/cpp/submodules/parquet-testing @@ -1 +1 @@ -Subproject commit d69d979223e883faef9dc6fe3cf573087243c28a +Subproject commit 4cb3cff24c965fb329cdae763eabce47395a68a0 diff --git a/cpp/thirdparty/versions.txt b/cpp/thirdparty/versions.txt index e9df0c8d7566b..2664775c0fbf4 100644 --- a/cpp/thirdparty/versions.txt +++ b/cpp/thirdparty/versions.txt @@ -115,8 +115,8 @@ ARROW_UTF8PROC_BUILD_VERSION=v2.7.0 ARROW_UTF8PROC_BUILD_SHA256_CHECKSUM=4bb121e297293c0fd55f08f83afab6d35d48f0af4ecc07523ad8ec99aa2b12a1 ARROW_XSIMD_BUILD_VERSION=9.0.1 ARROW_XSIMD_BUILD_SHA256_CHECKSUM=b1bb5f92167fd3a4f25749db0be7e61ed37e0a5d943490f3accdcd2cd2918cc0 -ARROW_ZLIB_BUILD_VERSION=1.2.13 -ARROW_ZLIB_BUILD_SHA256_CHECKSUM=b3a24de97a8fdbc835b9833169501030b8977031bcb54b3b3ac13740f846ab30 +ARROW_ZLIB_BUILD_VERSION=1.3 +ARROW_ZLIB_BUILD_SHA256_CHECKSUM=ff0ba4c292013dbc27530b3a81e1f9a813cd39de01ca5e0f8bf355702efa593e ARROW_ZSTD_BUILD_VERSION=1.5.5 ARROW_ZSTD_BUILD_SHA256_CHECKSUM=9c4396cc829cfae319a6e2615202e82aad41372073482fce286fac78646d3ee4 diff --git a/cpp/vcpkg.json b/cpp/vcpkg.json index c0bf5dce50e32..a0f0aa1008dcd 100644 --- a/cpp/vcpkg.json +++ b/cpp/vcpkg.json @@ -1,6 +1,6 @@ { "name": "arrow", - "version-string": "15.0.0-SNAPSHOT", + "version-string": "16.0.0-SNAPSHOT", "dependencies": [ "abseil", { diff --git a/csharp/Directory.Build.props b/csharp/Directory.Build.props index ae6edda0e2f0e..c759c49b395d8 100644 --- a/csharp/Directory.Build.props +++ b/csharp/Directory.Build.props @@ -27,9 +27,9 @@ Apache Arrow library - Copyright 2016-2019 The Apache Software Foundation + Copyright 2016-2024 The Apache Software Foundation The Apache Software Foundation - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT diff --git a/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj b/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj index aae26273ac282..68c3e47e01902 100644 --- a/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj +++ b/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj @@ -5,7 +5,7 @@ - + diff --git a/csharp/test/Apache.Arrow.Compression.Tests/Apache.Arrow.Compression.Tests.csproj b/csharp/test/Apache.Arrow.Compression.Tests/Apache.Arrow.Compression.Tests.csproj index 7a93d8f92635b..8ed7a93bdcf27 100644 --- a/csharp/test/Apache.Arrow.Compression.Tests/Apache.Arrow.Compression.Tests.csproj +++ b/csharp/test/Apache.Arrow.Compression.Tests/Apache.Arrow.Compression.Tests.csproj @@ -8,7 +8,7 @@ - + diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj index 7799577535ded..b5e7170a8c31d 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj @@ -7,7 +7,7 @@ - + diff --git a/csharp/test/Apache.Arrow.Flight.Tests/Apache.Arrow.Flight.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Tests/Apache.Arrow.Flight.Tests.csproj index 972aa178eabe8..a7c52846fd9a4 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/Apache.Arrow.Flight.Tests.csproj +++ b/csharp/test/Apache.Arrow.Flight.Tests/Apache.Arrow.Flight.Tests.csproj @@ -7,7 +7,7 @@ - + diff --git a/csharp/test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj b/csharp/test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj index afb636123b37b..d8a92ff756751 100644 --- a/csharp/test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj +++ b/csharp/test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj @@ -15,7 +15,7 @@ - + all runtime; build; native; contentfiles; analyzers diff --git a/dev/archery/archery/benchmark/runner.py b/dev/archery/archery/benchmark/runner.py index c12c74135e96e..a91989fb95257 100644 --- a/dev/archery/archery/benchmark/runner.py +++ b/dev/archery/archery/benchmark/runner.py @@ -210,6 +210,7 @@ def from_rev_or_path(src, root, rev_or_path, cmake_conf, **kwargs): """ build = None if StaticBenchmarkRunner.is_json_result(rev_or_path): + kwargs.pop('benchmark_extras', None) return StaticBenchmarkRunner.from_json(rev_or_path, **kwargs) elif CMakeBuild.is_build_dir(rev_or_path): build = CMakeBuild.from_path(rev_or_path) diff --git a/dev/archery/archery/bot.py b/dev/archery/archery/bot.py index 68b24dc08d71b..4e5104362254c 100644 --- a/dev/archery/archery/bot.py +++ b/dev/archery/archery/bot.py @@ -280,7 +280,7 @@ def handle_issue_comment(self, command, payload): # https://developer.github.com/v4/enum/commentauthorassociation/ # Checking privileges here enables the bot to respond # without relying on the handler. - allowed_roles = {'OWNER', 'MEMBER', 'CONTRIBUTOR', 'COLLABORATOR'} + allowed_roles = {'OWNER', 'MEMBER', 'COLLABORATOR'} if payload['comment']['author_association'] not in allowed_roles: raise EventError( "Only contributors can submit requests to this bot. " diff --git a/dev/archery/archery/cli.py b/dev/archery/archery/cli.py index 052fe23bfc969..0ad3eee14d1f3 100644 --- a/dev/archery/archery/cli.py +++ b/dev/archery/archery/cli.py @@ -407,7 +407,7 @@ def benchmark_filter_options(cmd): @click.pass_context def benchmark_list(ctx, rev_or_path, src, preserve, output, cmake_extras, java_home, java_options, build_extras, benchmark_extras, - language, **kwargs): + cpp_benchmark_extras, language, **kwargs): """ List benchmark suite. """ with tmpdir(preserve=preserve) as root: @@ -418,7 +418,8 @@ def benchmark_list(ctx, rev_or_path, src, preserve, output, cmake_extras, cmake_extras=cmake_extras, **kwargs) runner_base = CppBenchmarkRunner.from_rev_or_path( - src, root, rev_or_path, conf) + src, root, rev_or_path, conf, + benchmark_extras=cpp_benchmark_extras) elif language == "java": for key in {'cpp_package_prefix', 'cxx_flags', 'cxx', 'cc'}: @@ -546,7 +547,8 @@ def benchmark_run(ctx, rev_or_path, src, preserve, output, cmake_extras, def benchmark_diff(ctx, src, preserve, output, language, cmake_extras, suite_filter, benchmark_filter, repetitions, no_counters, java_home, java_options, build_extras, benchmark_extras, - threshold, contender, baseline, **kwargs): + cpp_benchmark_extras, threshold, contender, baseline, + **kwargs): """Compare (diff) benchmark runs. This command acts like git-diff but for benchmark results. @@ -633,12 +635,14 @@ def benchmark_diff(ctx, src, preserve, output, language, cmake_extras, src, root, contender, conf, repetitions=repetitions, suite_filter=suite_filter, - benchmark_filter=benchmark_filter) + benchmark_filter=benchmark_filter, + benchmark_extras=cpp_benchmark_extras) runner_base = CppBenchmarkRunner.from_rev_or_path( src, root, baseline, conf, repetitions=repetitions, suite_filter=suite_filter, - benchmark_filter=benchmark_filter) + benchmark_filter=benchmark_filter, + benchmark_extras=cpp_benchmark_extras) elif language == "java": for key in {'cpp_package_prefix', 'cxx_flags', 'cxx', 'cc'}: diff --git a/dev/archery/archery/integration/tester_java.py b/dev/archery/archery/integration/tester_java.py index 032ac13e74ec2..8e7a0bb99f9de 100644 --- a/dev/archery/archery/integration/tester_java.py +++ b/dev/archery/archery/integration/tester_java.py @@ -40,7 +40,6 @@ def load_version_from_pom(): _JAVA_OPTS = [ "-Dio.netty.tryReflectionSetAccessible=true", "-Darrow.struct.conflict.policy=CONFLICT_APPEND", - "--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED", # GH-39113: avoid failures accessing files in `/tmp/hsperfdata_...` "-XX:-UsePerfData", ] @@ -83,13 +82,24 @@ def setup_jpype(): import jpype jar_path = f"{_ARROW_TOOLS_JAR}:{_ARROW_C_DATA_JAR}" # XXX Didn't manage to tone down the logging level here (DEBUG -> INFO) + java_opts = _JAVA_OPTS[:] + proc = subprocess.run( + ['java', '--add-opens'], + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True) + if 'Unrecognized option: --add-opens' not in proc.stderr: + # Java 9+ + java_opts.append( + '--add-opens=java.base/java.nio=' + 'org.apache.arrow.memory.core,ALL-UNNAMED') jpype.startJVM(jpype.getDefaultJVMPath(), "-Djava.class.path=" + jar_path, # This flag is too heavy for IPC and Flight tests "-Darrow.memory.debug.allocator=true", # Reduce internal use of signals by the JVM "-Xrs", - *_JAVA_OPTS) + *java_opts) class _CDataBase: @@ -249,6 +259,8 @@ def __init__(self, *args, **kwargs): self._java_opts.append( '--add-opens=java.base/java.nio=' 'org.apache.arrow.memory.core,ALL-UNNAMED') + self._java_opts.append( + '--add-reads=org.apache.arrow.flight.core=ALL-UNNAMED') def _run(self, arrow_path=None, json_path=None, command='VALIDATE'): cmd = ( diff --git a/dev/archery/setup.py b/dev/archery/setup.py index e2c89ae204bd6..2ecc72e04e8aa 100755 --- a/dev/archery/setup.py +++ b/dev/archery/setup.py @@ -35,6 +35,7 @@ 'setuptools_scm'], 'docker': ['ruamel.yaml', 'python-dotenv'], 'integration': ['cffi'], + 'integration-java': ['jpype1'], 'lint': ['numpydoc==1.1.0', 'autopep8', 'flake8==6.1.0', 'cython-lint', 'cmake_format==0.6.13'], 'numpydoc': ['numpydoc==1.1.0'], diff --git a/dev/release/post-11-bump-versions-test.rb b/dev/release/post-11-bump-versions-test.rb index 4b6933d6102a9..78d9320bfb312 100644 --- a/dev/release/post-11-bump-versions-test.rb +++ b/dev/release/post-11-bump-versions-test.rb @@ -197,6 +197,15 @@ def test_version_post_tag ] if release_type == :major expected_changes += [ + { + path: "docs/source/index.rst", + hunks: [ + [ + "- Go ", + "+ Go ", + ], + ], + }, { path: "r/pkgdown/assets/versions.json", hunks: [ @@ -212,6 +221,15 @@ def test_version_post_tag ], ], }, + { + path: "r/_pkgdown.yml", + hunks: [ + [ + "- [Go](https://pkg.go.dev/github.com/apache/arrow/go/v#{@snapshot_major_version})
", + "+ [Go](https://pkg.go.dev/github.com/apache/arrow/go/v#{@next_major_version})
", + ], + ], + }, ] else expected_changes += [ diff --git a/dev/release/utils-prepare.sh b/dev/release/utils-prepare.sh index 8e4c8a84ae8fd..51367087228a4 100644 --- a/dev/release/utils-prepare.sh +++ b/dev/release/utils-prepare.sh @@ -127,6 +127,7 @@ update_versions() { DESCRIPTION rm -f DESCRIPTION.bak git add DESCRIPTION + # Replace dev version with release version sed -i.bak -E -e \ "/^ com.h2database @@ -85,4 +90,27 @@ + + + jdk11+ + + [11,] + + !m2e.version + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + --add-reads=org.apache.arrow.adapter.jdbc=com.fasterxml.jackson.dataformat.yaml --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED + + + + + + + diff --git a/java/adapter/jdbc/src/main/java/module-info.java b/java/adapter/jdbc/src/main/java/module-info.java new file mode 100644 index 0000000000000..5b59ce768472a --- /dev/null +++ b/java/adapter/jdbc/src/main/java/module-info.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +module org.apache.arrow.adapter.jdbc { + exports org.apache.arrow.adapter.jdbc.consumer; + exports org.apache.arrow.adapter.jdbc; + exports org.apache.arrow.adapter.jdbc.binder; + + requires com.fasterxml.jackson.databind; + requires java.sql; + requires jdk.unsupported; + requires org.apache.arrow.memory.core; + requires org.apache.arrow.vector; +} diff --git a/java/adapter/orc/CMakeLists.txt b/java/adapter/orc/CMakeLists.txt index a9b3a48027937..d29856ff8cd5e 100644 --- a/java/adapter/orc/CMakeLists.txt +++ b/java/adapter/orc/CMakeLists.txt @@ -37,6 +37,11 @@ set_property(TARGET arrow_java_jni_orc PROPERTY OUTPUT_NAME "arrow_orc_jni") target_link_libraries(arrow_java_jni_orc arrow_java_jni_orc_headers jni Arrow::arrow_static) +set(ARROW_JAVA_JNI_ORC_LIBDIR + "${CMAKE_INSTALL_PREFIX}/lib/arrow_orc_jni/${ARROW_JAVA_JNI_ARCH_DIR}") +set(ARROW_JAVA_JNI_ORC_BINDIR + "${CMAKE_INSTALL_PREFIX}/bin/arrow_orc_jni/${ARROW_JAVA_JNI_ARCH_DIR}") + install(TARGETS arrow_java_jni_orc - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + LIBRARY DESTINATION ${ARROW_JAVA_JNI_ORC_LIBDIR} + RUNTIME DESTINATION ${ARROW_JAVA_JNI_ORC_BINDIR}) diff --git a/java/adapter/orc/pom.xml b/java/adapter/orc/pom.xml index a42a458e2072a..79e51470a426e 100644 --- a/java/adapter/orc/pom.xml +++ b/java/adapter/orc/pom.xml @@ -31,6 +31,10 @@ compile ${arrow.vector.classifier} + + org.immutables + value + org.apache.orc orc-core @@ -71,7 +75,7 @@ org.apache.hadoop hadoop-common - 3.3.3 + 3.3.6 test @@ -111,7 +115,7 @@ org.apache.arrow arrow-java-root - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT ../../pom.xml diff --git a/java/adapter/orc/src/main/java/module-info.java b/java/adapter/orc/src/main/java/module-info.java new file mode 100644 index 0000000000000..d18a978e93fa8 --- /dev/null +++ b/java/adapter/orc/src/main/java/module-info.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +open module org.apache.arrow.adapter.orc { + exports org.apache.arrow.adapter.orc; + + requires hadoop.client.api; + requires org.apache.arrow.memory.core; + requires org.apache.arrow.vector; +} diff --git a/java/adapter/orc/src/main/java/org/apache/arrow/adapter/orc/OrcJniUtils.java b/java/adapter/orc/src/main/java/org/apache/arrow/adapter/orc/OrcJniUtils.java index 2701c228709c2..9b599234bdf51 100644 --- a/java/adapter/orc/src/main/java/org/apache/arrow/adapter/orc/OrcJniUtils.java +++ b/java/adapter/orc/src/main/java/org/apache/arrow/adapter/orc/OrcJniUtils.java @@ -39,7 +39,7 @@ static void loadOrcAdapterLibraryFromJar() synchronized (OrcJniUtils.class) { if (!isLoaded) { final String libraryToLoad = - getNormalizedArch() + "/" + System.mapLibraryName(LIBRARY_NAME); + LIBRARY_NAME + "/" + getNormalizedArch() + "/" + System.mapLibraryName(LIBRARY_NAME); final File libraryFile = moveFileFromJarToTemp(System.getProperty("java.io.tmpdir"), libraryToLoad, LIBRARY_NAME); System.load(libraryFile.getAbsolutePath()); diff --git a/java/algorithm/pom.xml b/java/algorithm/pom.xml index 3e32d955ec417..25669010d2d42 100644 --- a/java/algorithm/pom.xml +++ b/java/algorithm/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT arrow-algorithm Arrow Algorithms @@ -42,6 +42,10 @@ arrow-memory-netty test + + org.immutables + value + diff --git a/java/algorithm/src/main/java/module-info.java b/java/algorithm/src/main/java/module-info.java new file mode 100644 index 0000000000000..b347f55aa4d00 --- /dev/null +++ b/java/algorithm/src/main/java/module-info.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +module org.apache.arrow.algorithm { + exports org.apache.arrow.algorithm.search; + exports org.apache.arrow.algorithm.deduplicate; + exports org.apache.arrow.algorithm.dictionary; + exports org.apache.arrow.algorithm.rank; + exports org.apache.arrow.algorithm.misc; + exports org.apache.arrow.algorithm.sort; + + requires jdk.unsupported; + requires org.apache.arrow.memory.core; + requires org.apache.arrow.vector; +} diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java index e62ebdecb1bac..6226921b22ed6 100644 --- a/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java +++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java @@ -20,6 +20,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.arrow.algorithm.sort.VectorValueComparator; import org.apache.arrow.vector.ValueVector; @@ -95,7 +96,7 @@ public int search(V keyVector, int keyIndex) throws ExecutionException, Interrup final int valueCount = vector.getValueCount(); for (int i = 0; i < numThreads; i++) { final int tid = i; - threadPool.submit(() -> { + Future unused = threadPool.submit(() -> { // convert to long to avoid overflow int start = (int) (((long) valueCount) * tid / numThreads); int end = (int) ((long) valueCount) * (tid + 1) / numThreads; @@ -153,7 +154,7 @@ public int search( final int valueCount = vector.getValueCount(); for (int i = 0; i < numThreads; i++) { final int tid = i; - threadPool.submit(() -> { + Future unused = threadPool.submit(() -> { // convert to long to avoid overflow int start = (int) (((long) valueCount) * tid / numThreads); int end = (int) ((long) valueCount) * (tid + 1) / numThreads; diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthOutOfPlaceVectorSorter.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthOutOfPlaceVectorSorter.java index c3b68facfda97..05a4585792dc2 100644 --- a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthOutOfPlaceVectorSorter.java +++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthOutOfPlaceVectorSorter.java @@ -54,7 +54,7 @@ public void sortOutOfPlace(V srcVector, V dstVector, VectorValueComparator co "Expected capacity %s, actual capacity %s", (srcVector.getValueCount() + 7) / 8, dstValidityBuffer.capacity()); Preconditions.checkArgument( - dstValueBuffer.capacity() >= srcVector.getValueCount() * srcVector.getTypeWidth(), + dstValueBuffer.capacity() >= srcVector.getValueCount() * ((long) srcVector.getTypeWidth()), "Not enough capacity for the data buffer of the dst vector. " + "Expected capacity %s, actual capacity %s", srcVector.getValueCount() * srcVector.getTypeWidth(), dstValueBuffer.capacity()); @@ -73,8 +73,8 @@ public void sortOutOfPlace(V srcVector, V dstVector, VectorValueComparator co } else { BitVectorHelper.setBit(dstValidityBuffer, dstIndex); MemoryUtil.UNSAFE.copyMemory( - srcValueBuffer.memoryAddress() + srcIndex * valueWidth, - dstValueBuffer.memoryAddress() + dstIndex * valueWidth, + srcValueBuffer.memoryAddress() + srcIndex * ((long) valueWidth), + dstValueBuffer.memoryAddress() + dstIndex * ((long) valueWidth), valueWidth); } } diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VariableWidthOutOfPlaceVectorSorter.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VariableWidthOutOfPlaceVectorSorter.java index c60e273e9e851..863b07c348ef2 100644 --- a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VariableWidthOutOfPlaceVectorSorter.java +++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/VariableWidthOutOfPlaceVectorSorter.java @@ -51,12 +51,12 @@ public void sortOutOfPlace(V srcVector, V dstVector, VectorValueComparator co "Expected capacity %s, actual capacity %s", (srcVector.getValueCount() + 7) / 8, dstValidityBuffer.capacity()); Preconditions.checkArgument( - dstOffsetBuffer.capacity() >= (srcVector.getValueCount() + 1) * BaseVariableWidthVector.OFFSET_WIDTH, + dstOffsetBuffer.capacity() >= (srcVector.getValueCount() + 1) * ((long) BaseVariableWidthVector.OFFSET_WIDTH), "Not enough capacity for the offset buffer of the dst vector. " + "Expected capacity %s, actual capacity %s", (srcVector.getValueCount() + 1) * BaseVariableWidthVector.OFFSET_WIDTH, dstOffsetBuffer.capacity()); long dataSize = srcVector.getOffsetBuffer().getInt( - srcVector.getValueCount() * BaseVariableWidthVector.OFFSET_WIDTH); + srcVector.getValueCount() * ((long) BaseVariableWidthVector.OFFSET_WIDTH)); Preconditions.checkArgument( dstValueBuffer.capacity() >= dataSize, "No enough capacity for the data buffer of the dst vector. " + "Expected capacity %s, actual capacity %s", dataSize, dstValueBuffer.capacity()); @@ -77,15 +77,16 @@ public void sortOutOfPlace(V srcVector, V dstVector, VectorValueComparator co BitVectorHelper.unsetBit(dstValidityBuffer, dstIndex); } else { BitVectorHelper.setBit(dstValidityBuffer, dstIndex); - int srcOffset = srcOffsetBuffer.getInt(srcIndex * BaseVariableWidthVector.OFFSET_WIDTH); - int valueLength = srcOffsetBuffer.getInt((srcIndex + 1) * BaseVariableWidthVector.OFFSET_WIDTH) - srcOffset; + int srcOffset = srcOffsetBuffer.getInt(srcIndex * ((long) BaseVariableWidthVector.OFFSET_WIDTH)); + int valueLength = + srcOffsetBuffer.getInt((srcIndex + 1) * ((long) BaseVariableWidthVector.OFFSET_WIDTH)) - srcOffset; MemoryUtil.UNSAFE.copyMemory( srcValueBuffer.memoryAddress() + srcOffset, dstValueBuffer.memoryAddress() + dstOffset, valueLength); dstOffset += valueLength; } - dstOffsetBuffer.setInt((dstIndex + 1) * BaseVariableWidthVector.OFFSET_WIDTH, dstOffset); + dstOffsetBuffer.setInt((dstIndex + 1) * ((long) BaseVariableWidthVector.OFFSET_WIDTH), dstOffset); } } } diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestDeduplicationUtils.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestDeduplicationUtils.java index def83fba7b74a..ac083b84f1611 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestDeduplicationUtils.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestDeduplicationUtils.java @@ -20,6 +20,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import java.nio.charset.StandardCharsets; + import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -107,7 +109,7 @@ public void testDeduplicateVariableWidth() { for (int i = 0; i < VECTOR_LENGTH; i++) { String str = String.valueOf(i * i); for (int j = 0; j < REPETITION_COUNT; j++) { - origVec.set(i * REPETITION_COUNT + j, str.getBytes()); + origVec.set(i * REPETITION_COUNT + j, str.getBytes(StandardCharsets.UTF_8)); } } @@ -120,7 +122,7 @@ public void testDeduplicateVariableWidth() { assertEquals(VECTOR_LENGTH, dedupVec.getValueCount()); for (int i = 0; i < VECTOR_LENGTH; i++) { - assertArrayEquals(String.valueOf(i * i).getBytes(), dedupVec.get(i)); + assertArrayEquals(String.valueOf(i * i).getBytes(StandardCharsets.UTF_8), dedupVec.get(i)); } DeduplicationUtils.populateRunLengths( diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestVectorRunDeduplicator.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestVectorRunDeduplicator.java index 4bfa6e2555176..788213b162870 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestVectorRunDeduplicator.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/deduplicate/TestVectorRunDeduplicator.java @@ -20,6 +20,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import java.nio.charset.StandardCharsets; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; @@ -104,20 +106,20 @@ public void testDeduplicateVariableWidth() { for (int i = 0; i < VECTOR_LENGTH; i++) { String str = String.valueOf(i * i); for (int j = 0; j < REPETITION_COUNT; j++) { - origVec.set(i * REPETITION_COUNT + j, str.getBytes()); + origVec.set(i * REPETITION_COUNT + j, str.getBytes(StandardCharsets.UTF_8)); } } int distinctCount = deduplicator.getRunCount(); assertEquals(VECTOR_LENGTH, distinctCount); - dedupVec.allocateNew(distinctCount * 10, distinctCount); + dedupVec.allocateNew(distinctCount * 10L, distinctCount); deduplicator.populateDeduplicatedValues(dedupVec); assertEquals(VECTOR_LENGTH, dedupVec.getValueCount()); for (int i = 0; i < VECTOR_LENGTH; i++) { - assertArrayEquals(String.valueOf(i * i).getBytes(), dedupVec.get(i)); + assertArrayEquals(String.valueOf(i * i).getBytes(StandardCharsets.UTF_8), dedupVec.get(i)); } deduplicator.populateRunLengths(lengthVec); diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableBasedDictionaryBuilder.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableBasedDictionaryBuilder.java index 0a3314535f234..45c47626b720e 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableBasedDictionaryBuilder.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableBasedDictionaryBuilder.java @@ -21,6 +21,9 @@ import static org.junit.Assert.assertNull; import static org.junit.jupiter.api.Assertions.assertEquals; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; @@ -57,16 +60,16 @@ public void testBuildVariableWidthDictionaryWithNull() { dictionary.allocateNew(); // fill data - vec.set(0, "hello".getBytes()); - vec.set(1, "abc".getBytes()); + vec.set(0, "hello".getBytes(StandardCharsets.UTF_8)); + vec.set(1, "abc".getBytes(StandardCharsets.UTF_8)); vec.setNull(2); - vec.set(3, "world".getBytes()); - vec.set(4, "12".getBytes()); - vec.set(5, "dictionary".getBytes()); + vec.set(3, "world".getBytes(StandardCharsets.UTF_8)); + vec.set(4, "12".getBytes(StandardCharsets.UTF_8)); + vec.set(5, "dictionary".getBytes(StandardCharsets.UTF_8)); vec.setNull(6); - vec.set(7, "hello".getBytes()); - vec.set(8, "good".getBytes()); - vec.set(9, "abc".getBytes()); + vec.set(7, "hello".getBytes(StandardCharsets.UTF_8)); + vec.set(8, "good".getBytes(StandardCharsets.UTF_8)); + vec.set(9, "abc".getBytes(StandardCharsets.UTF_8)); HashTableBasedDictionaryBuilder dictionaryBuilder = new HashTableBasedDictionaryBuilder<>(dictionary, true); @@ -76,13 +79,13 @@ public void testBuildVariableWidthDictionaryWithNull() { assertEquals(7, result); assertEquals(7, dictionary.getValueCount()); - assertEquals("hello", new String(dictionary.get(0))); - assertEquals("abc", new String(dictionary.get(1))); + assertEquals("hello", new String(Objects.requireNonNull(dictionary.get(0)), StandardCharsets.UTF_8)); + assertEquals("abc", new String(Objects.requireNonNull(dictionary.get(1)), StandardCharsets.UTF_8)); assertNull(dictionary.get(2)); - assertEquals("world", new String(dictionary.get(3))); - assertEquals("12", new String(dictionary.get(4))); - assertEquals("dictionary", new String(dictionary.get(5))); - assertEquals("good", new String(dictionary.get(6))); + assertEquals("world", new String(Objects.requireNonNull(dictionary.get(3)), StandardCharsets.UTF_8)); + assertEquals("12", new String(Objects.requireNonNull(dictionary.get(4)), StandardCharsets.UTF_8)); + assertEquals("dictionary", new String(Objects.requireNonNull(dictionary.get(5)), StandardCharsets.UTF_8)); + assertEquals("good", new String(Objects.requireNonNull(dictionary.get(6)), StandardCharsets.UTF_8)); } } @@ -97,16 +100,16 @@ public void testBuildVariableWidthDictionaryWithoutNull() { dictionary.allocateNew(); // fill data - vec.set(0, "hello".getBytes()); - vec.set(1, "abc".getBytes()); + vec.set(0, "hello".getBytes(StandardCharsets.UTF_8)); + vec.set(1, "abc".getBytes(StandardCharsets.UTF_8)); vec.setNull(2); - vec.set(3, "world".getBytes()); - vec.set(4, "12".getBytes()); - vec.set(5, "dictionary".getBytes()); + vec.set(3, "world".getBytes(StandardCharsets.UTF_8)); + vec.set(4, "12".getBytes(StandardCharsets.UTF_8)); + vec.set(5, "dictionary".getBytes(StandardCharsets.UTF_8)); vec.setNull(6); - vec.set(7, "hello".getBytes()); - vec.set(8, "good".getBytes()); - vec.set(9, "abc".getBytes()); + vec.set(7, "hello".getBytes(StandardCharsets.UTF_8)); + vec.set(8, "good".getBytes(StandardCharsets.UTF_8)); + vec.set(9, "abc".getBytes(StandardCharsets.UTF_8)); HashTableBasedDictionaryBuilder dictionaryBuilder = new HashTableBasedDictionaryBuilder<>(dictionary, false); @@ -116,12 +119,12 @@ public void testBuildVariableWidthDictionaryWithoutNull() { assertEquals(6, result); assertEquals(6, dictionary.getValueCount()); - assertEquals("hello", new String(dictionary.get(0))); - assertEquals("abc", new String(dictionary.get(1))); - assertEquals("world", new String(dictionary.get(2))); - assertEquals("12", new String(dictionary.get(3))); - assertEquals("dictionary", new String(dictionary.get(4))); - assertEquals("good", new String(dictionary.get(5))); + assertEquals("hello", new String(Objects.requireNonNull(dictionary.get(0)), StandardCharsets.UTF_8)); + assertEquals("abc", new String(Objects.requireNonNull(dictionary.get(1)), StandardCharsets.UTF_8)); + assertEquals("world", new String(Objects.requireNonNull(dictionary.get(2)), StandardCharsets.UTF_8)); + assertEquals("12", new String(Objects.requireNonNull(dictionary.get(3)), StandardCharsets.UTF_8)); + assertEquals("dictionary", new String(Objects.requireNonNull(dictionary.get(4)), StandardCharsets.UTF_8)); + assertEquals("good", new String(Objects.requireNonNull(dictionary.get(5)), StandardCharsets.UTF_8)); } } diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableDictionaryEncoder.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableDictionaryEncoder.java index dd22ac96fac88..60efbf58bebda 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableDictionaryEncoder.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestHashTableDictionaryEncoder.java @@ -76,7 +76,7 @@ public void testEncodeAndDecode() { dictionary.allocateNew(); for (int i = 0; i < DICTIONARY_LENGTH; i++) { // encode "i" as i - dictionary.setSafe(i, String.valueOf(i).getBytes()); + dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } dictionary.setValueCount(DICTIONARY_LENGTH); @@ -84,7 +84,7 @@ public void testEncodeAndDecode() { rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH); for (int i = 0; i < VECTOR_LENGTH; i++) { int val = (random.nextInt() & Integer.MAX_VALUE) % DICTIONARY_LENGTH; - rawVector.set(i, String.valueOf(val).getBytes()); + rawVector.set(i, String.valueOf(val).getBytes(StandardCharsets.UTF_8)); } rawVector.setValueCount(VECTOR_LENGTH); @@ -98,7 +98,7 @@ public void testEncodeAndDecode() { // verify encoding results assertEquals(rawVector.getValueCount(), encodedVector.getValueCount()); for (int i = 0; i < VECTOR_LENGTH; i++) { - assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8)); } // perform decoding @@ -108,7 +108,8 @@ public void testEncodeAndDecode() { // verify decoding results assertEquals(encodedVector.getValueCount(), decodedVector.getValueCount()); for (int i = 0; i < VECTOR_LENGTH; i++) { - assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8), + decodedVector.get(i)); } } } @@ -126,7 +127,7 @@ public void testEncodeAndDecodeWithNull() { dictionary.setNull(0); for (int i = 1; i < DICTIONARY_LENGTH; i++) { // encode "i" as i - dictionary.setSafe(i, String.valueOf(i).getBytes()); + dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } dictionary.setValueCount(DICTIONARY_LENGTH); @@ -137,7 +138,7 @@ public void testEncodeAndDecodeWithNull() { rawVector.setNull(i); } else { int val = (random.nextInt() & Integer.MAX_VALUE) % (DICTIONARY_LENGTH - 1) + 1; - rawVector.set(i, String.valueOf(val).getBytes()); + rawVector.set(i, String.valueOf(val).getBytes(StandardCharsets.UTF_8)); } } rawVector.setValueCount(VECTOR_LENGTH); @@ -155,7 +156,7 @@ public void testEncodeAndDecodeWithNull() { if (i % 10 == 0) { assertEquals(0, encodedVector.get(i)); } else { - assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8)); } } @@ -168,7 +169,8 @@ public void testEncodeAndDecodeWithNull() { if (i % 10 == 0) { assertTrue(decodedVector.isNull(i)); } else { - assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8), + decodedVector.get(i)); } } } @@ -185,7 +187,7 @@ public void testEncodeNullWithoutNullInDictionary() { dictionary.allocateNew(); for (int i = 0; i < DICTIONARY_LENGTH; i++) { // encode "i" as i - dictionary.setSafe(i, String.valueOf(i).getBytes()); + dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } dictionary.setValueCount(DICTIONARY_LENGTH); diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestLinearDictionaryEncoder.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestLinearDictionaryEncoder.java index 104d1b35b0660..a76aedffa308d 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestLinearDictionaryEncoder.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestLinearDictionaryEncoder.java @@ -77,7 +77,7 @@ public void testEncodeAndDecode() { dictionary.allocateNew(); for (int i = 0; i < DICTIONARY_LENGTH; i++) { // encode "i" as i - dictionary.setSafe(i, String.valueOf(i).getBytes()); + dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } dictionary.setValueCount(DICTIONARY_LENGTH); @@ -85,7 +85,7 @@ public void testEncodeAndDecode() { rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH); for (int i = 0; i < VECTOR_LENGTH; i++) { int val = (random.nextInt() & Integer.MAX_VALUE) % DICTIONARY_LENGTH; - rawVector.set(i, String.valueOf(val).getBytes()); + rawVector.set(i, String.valueOf(val).getBytes(StandardCharsets.UTF_8)); } rawVector.setValueCount(VECTOR_LENGTH); @@ -99,7 +99,7 @@ public void testEncodeAndDecode() { // verify encoding results assertEquals(rawVector.getValueCount(), encodedVector.getValueCount()); for (int i = 0; i < VECTOR_LENGTH; i++) { - assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8)); } // perform decoding @@ -109,7 +109,8 @@ public void testEncodeAndDecode() { // verify decoding results assertEquals(encodedVector.getValueCount(), decodedVector.getValueCount()); for (int i = 0; i < VECTOR_LENGTH; i++) { - assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8), + decodedVector.get(i)); } } } @@ -127,7 +128,7 @@ public void testEncodeAndDecodeWithNull() { dictionary.setNull(0); for (int i = 1; i < DICTIONARY_LENGTH; i++) { // encode "i" as i - dictionary.setSafe(i, String.valueOf(i).getBytes()); + dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } dictionary.setValueCount(DICTIONARY_LENGTH); @@ -138,7 +139,7 @@ public void testEncodeAndDecodeWithNull() { rawVector.setNull(i); } else { int val = (random.nextInt() & Integer.MAX_VALUE) % (DICTIONARY_LENGTH - 1) + 1; - rawVector.set(i, String.valueOf(val).getBytes()); + rawVector.set(i, String.valueOf(val).getBytes(StandardCharsets.UTF_8)); } } rawVector.setValueCount(VECTOR_LENGTH); @@ -156,7 +157,7 @@ public void testEncodeAndDecodeWithNull() { if (i % 10 == 0) { assertEquals(0, encodedVector.get(i)); } else { - assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8)); } } @@ -170,7 +171,8 @@ public void testEncodeAndDecodeWithNull() { if (i % 10 == 0) { assertTrue(decodedVector.isNull(i)); } else { - assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8), + decodedVector.get(i)); } } } @@ -187,7 +189,7 @@ public void testEncodeNullWithoutNullInDictionary() { dictionary.allocateNew(); for (int i = 0; i < DICTIONARY_LENGTH; i++) { // encode "i" as i - dictionary.setSafe(i, String.valueOf(i).getBytes()); + dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } dictionary.setValueCount(DICTIONARY_LENGTH); diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchDictionaryEncoder.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchDictionaryEncoder.java index a156e987c20ce..e01c2e7905b46 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchDictionaryEncoder.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchDictionaryEncoder.java @@ -78,7 +78,7 @@ public void testEncodeAndDecode() { dictionary.allocateNew(); for (int i = 0; i < DICTIONARY_LENGTH; i++) { // encode "i" as i - dictionary.setSafe(i, String.valueOf(i).getBytes()); + dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } dictionary.setValueCount(DICTIONARY_LENGTH); @@ -86,7 +86,7 @@ public void testEncodeAndDecode() { rawVector.allocateNew(10 * VECTOR_LENGTH, VECTOR_LENGTH); for (int i = 0; i < VECTOR_LENGTH; i++) { int val = (random.nextInt() & Integer.MAX_VALUE) % DICTIONARY_LENGTH; - rawVector.set(i, String.valueOf(val).getBytes()); + rawVector.set(i, String.valueOf(val).getBytes(StandardCharsets.UTF_8)); } rawVector.setValueCount(VECTOR_LENGTH); @@ -101,7 +101,7 @@ public void testEncodeAndDecode() { // verify encoding results assertEquals(rawVector.getValueCount(), encodedVector.getValueCount()); for (int i = 0; i < VECTOR_LENGTH; i++) { - assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8)); } // perform decoding @@ -111,7 +111,8 @@ public void testEncodeAndDecode() { // verify decoding results assertEquals(encodedVector.getValueCount(), decodedVector.getValueCount()); for (int i = 0; i < VECTOR_LENGTH; i++) { - assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8), + decodedVector.get(i)); } } } @@ -129,7 +130,7 @@ public void testEncodeAndDecodeWithNull() { dictionary.setNull(0); for (int i = 1; i < DICTIONARY_LENGTH; i++) { // encode "i" as i - dictionary.setSafe(i, String.valueOf(i).getBytes()); + dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } dictionary.setValueCount(DICTIONARY_LENGTH); @@ -140,7 +141,7 @@ public void testEncodeAndDecodeWithNull() { rawVector.setNull(i); } else { int val = (random.nextInt() & Integer.MAX_VALUE) % (DICTIONARY_LENGTH - 1) + 1; - rawVector.set(i, String.valueOf(val).getBytes()); + rawVector.set(i, String.valueOf(val).getBytes(StandardCharsets.UTF_8)); } } rawVector.setValueCount(VECTOR_LENGTH); @@ -159,7 +160,7 @@ public void testEncodeAndDecodeWithNull() { if (i % 10 == 0) { assertEquals(0, encodedVector.get(i)); } else { - assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes()); + assertArrayEquals(rawVector.get(i), String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8)); } } @@ -173,7 +174,8 @@ public void testEncodeAndDecodeWithNull() { if (i % 10 == 0) { assertTrue(decodedVector.isNull(i)); } else { - assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(), decodedVector.get(i)); + assertArrayEquals(String.valueOf(encodedVector.get(i)).getBytes(StandardCharsets.UTF_8), + decodedVector.get(i)); } } } @@ -190,7 +192,7 @@ public void testEncodeNullWithoutNullInDictionary() { dictionary.allocateNew(); for (int i = 0; i < DICTIONARY_LENGTH; i++) { // encode "i" as i - dictionary.setSafe(i, String.valueOf(i).getBytes()); + dictionary.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } dictionary.setValueCount(DICTIONARY_LENGTH); diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchTreeBasedDictionaryBuilder.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchTreeBasedDictionaryBuilder.java index d8e9edce83b7f..340b7e67e861f 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchTreeBasedDictionaryBuilder.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/dictionary/TestSearchTreeBasedDictionaryBuilder.java @@ -20,6 +20,9 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + import org.apache.arrow.algorithm.sort.DefaultVectorComparators; import org.apache.arrow.algorithm.sort.VectorValueComparator; import org.apache.arrow.memory.BufferAllocator; @@ -60,16 +63,16 @@ public void testBuildVariableWidthDictionaryWithNull() { sortedDictionary.allocateNew(); // fill data - vec.set(0, "hello".getBytes()); - vec.set(1, "abc".getBytes()); + vec.set(0, "hello".getBytes(StandardCharsets.UTF_8)); + vec.set(1, "abc".getBytes(StandardCharsets.UTF_8)); vec.setNull(2); - vec.set(3, "world".getBytes()); - vec.set(4, "12".getBytes()); - vec.set(5, "dictionary".getBytes()); + vec.set(3, "world".getBytes(StandardCharsets.UTF_8)); + vec.set(4, "12".getBytes(StandardCharsets.UTF_8)); + vec.set(5, "dictionary".getBytes(StandardCharsets.UTF_8)); vec.setNull(6); - vec.set(7, "hello".getBytes()); - vec.set(8, "good".getBytes()); - vec.set(9, "abc".getBytes()); + vec.set(7, "hello".getBytes(StandardCharsets.UTF_8)); + vec.set(8, "good".getBytes(StandardCharsets.UTF_8)); + vec.set(9, "abc".getBytes(StandardCharsets.UTF_8)); VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); SearchTreeBasedDictionaryBuilder dictionaryBuilder = @@ -83,12 +86,12 @@ public void testBuildVariableWidthDictionaryWithNull() { dictionaryBuilder.populateSortedDictionary(sortedDictionary); assertTrue(sortedDictionary.isNull(0)); - assertEquals("12", new String(sortedDictionary.get(1))); - assertEquals("abc", new String(sortedDictionary.get(2))); - assertEquals("dictionary", new String(sortedDictionary.get(3))); - assertEquals("good", new String(sortedDictionary.get(4))); - assertEquals("hello", new String(sortedDictionary.get(5))); - assertEquals("world", new String(sortedDictionary.get(6))); + assertEquals("12", new String(Objects.requireNonNull(sortedDictionary.get(1)), StandardCharsets.UTF_8)); + assertEquals("abc", new String(Objects.requireNonNull(sortedDictionary.get(2)), StandardCharsets.UTF_8)); + assertEquals("dictionary", new String(Objects.requireNonNull(sortedDictionary.get(3)), StandardCharsets.UTF_8)); + assertEquals("good", new String(Objects.requireNonNull(sortedDictionary.get(4)), StandardCharsets.UTF_8)); + assertEquals("hello", new String(Objects.requireNonNull(sortedDictionary.get(5)), StandardCharsets.UTF_8)); + assertEquals("world", new String(Objects.requireNonNull(sortedDictionary.get(6)), StandardCharsets.UTF_8)); } } @@ -105,16 +108,16 @@ public void testBuildVariableWidthDictionaryWithoutNull() { sortedDictionary.allocateNew(); // fill data - vec.set(0, "hello".getBytes()); - vec.set(1, "abc".getBytes()); + vec.set(0, "hello".getBytes(StandardCharsets.UTF_8)); + vec.set(1, "abc".getBytes(StandardCharsets.UTF_8)); vec.setNull(2); - vec.set(3, "world".getBytes()); - vec.set(4, "12".getBytes()); - vec.set(5, "dictionary".getBytes()); + vec.set(3, "world".getBytes(StandardCharsets.UTF_8)); + vec.set(4, "12".getBytes(StandardCharsets.UTF_8)); + vec.set(5, "dictionary".getBytes(StandardCharsets.UTF_8)); vec.setNull(6); - vec.set(7, "hello".getBytes()); - vec.set(8, "good".getBytes()); - vec.set(9, "abc".getBytes()); + vec.set(7, "hello".getBytes(StandardCharsets.UTF_8)); + vec.set(8, "good".getBytes(StandardCharsets.UTF_8)); + vec.set(9, "abc".getBytes(StandardCharsets.UTF_8)); VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vec); SearchTreeBasedDictionaryBuilder dictionaryBuilder = @@ -127,12 +130,12 @@ public void testBuildVariableWidthDictionaryWithoutNull() { dictionaryBuilder.populateSortedDictionary(sortedDictionary); - assertEquals("12", new String(sortedDictionary.get(0))); - assertEquals("abc", new String(sortedDictionary.get(1))); - assertEquals("dictionary", new String(sortedDictionary.get(2))); - assertEquals("good", new String(sortedDictionary.get(3))); - assertEquals("hello", new String(sortedDictionary.get(4))); - assertEquals("world", new String(sortedDictionary.get(5))); + assertEquals("12", new String(Objects.requireNonNull(sortedDictionary.get(0)), StandardCharsets.UTF_8)); + assertEquals("abc", new String(Objects.requireNonNull(sortedDictionary.get(1)), StandardCharsets.UTF_8)); + assertEquals("dictionary", new String(Objects.requireNonNull(sortedDictionary.get(2)), StandardCharsets.UTF_8)); + assertEquals("good", new String(Objects.requireNonNull(sortedDictionary.get(3)), StandardCharsets.UTF_8)); + assertEquals("hello", new String(Objects.requireNonNull(sortedDictionary.get(4)), StandardCharsets.UTF_8)); + assertEquals("world", new String(Objects.requireNonNull(sortedDictionary.get(5)), StandardCharsets.UTF_8)); } } diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/misc/TestPartialSumUtils.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/misc/TestPartialSumUtils.java index 4e2d5900f8ccc..630dd80b44084 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/misc/TestPartialSumUtils.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/misc/TestPartialSumUtils.java @@ -67,7 +67,7 @@ public void testToPartialSumVector() { // verify results assertEquals(PARTIAL_SUM_VECTOR_LENGTH, partialSum.getValueCount()); for (int i = 0; i < partialSum.getValueCount(); i++) { - assertEquals(i * 3 + sumBase, partialSum.get(i)); + assertEquals(i * 3L + sumBase, partialSum.get(i)); } } } diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/rank/TestVectorRank.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/rank/TestVectorRank.java index f372a809bab53..0e6627eb4822a 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/rank/TestVectorRank.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/rank/TestVectorRank.java @@ -20,6 +20,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import java.nio.charset.StandardCharsets; + import org.apache.arrow.algorithm.sort.DefaultVectorComparators; import org.apache.arrow.algorithm.sort.VectorValueComparator; import org.apache.arrow.memory.BufferAllocator; @@ -89,16 +91,16 @@ public void testVariableWidthRank() { vector.allocateNew(VECTOR_LENGTH * 5, VECTOR_LENGTH); vector.setValueCount(VECTOR_LENGTH); - vector.set(0, String.valueOf(1).getBytes()); - vector.set(1, String.valueOf(5).getBytes()); - vector.set(2, String.valueOf(3).getBytes()); - vector.set(3, String.valueOf(7).getBytes()); - vector.set(4, String.valueOf(9).getBytes()); - vector.set(5, String.valueOf(8).getBytes()); - vector.set(6, String.valueOf(2).getBytes()); - vector.set(7, String.valueOf(0).getBytes()); - vector.set(8, String.valueOf(4).getBytes()); - vector.set(9, String.valueOf(6).getBytes()); + vector.set(0, String.valueOf(1).getBytes(StandardCharsets.UTF_8)); + vector.set(1, String.valueOf(5).getBytes(StandardCharsets.UTF_8)); + vector.set(2, String.valueOf(3).getBytes(StandardCharsets.UTF_8)); + vector.set(3, String.valueOf(7).getBytes(StandardCharsets.UTF_8)); + vector.set(4, String.valueOf(9).getBytes(StandardCharsets.UTF_8)); + vector.set(5, String.valueOf(8).getBytes(StandardCharsets.UTF_8)); + vector.set(6, String.valueOf(2).getBytes(StandardCharsets.UTF_8)); + vector.set(7, String.valueOf(0).getBytes(StandardCharsets.UTF_8)); + vector.set(8, String.valueOf(4).getBytes(StandardCharsets.UTF_8)); + vector.set(9, String.valueOf(6).getBytes(StandardCharsets.UTF_8)); VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vector); diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java index 767935aaa4bae..9ccecfa84a73a 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java @@ -19,6 +19,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -130,8 +131,8 @@ public void testParallelStringSearch() throws ExecutionException, InterruptedExc : DefaultVectorComparators.createDefaultComparator(targetVector); for (int i = 0; i < VECTOR_LENGTH; i++) { - targetVector.setSafe(i, String.valueOf(i).getBytes()); - keyVector.setSafe(i, String.valueOf(i * 2).getBytes()); + targetVector.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); + keyVector.setSafe(i, String.valueOf(i * 2).getBytes(StandardCharsets.UTF_8)); } targetVector.setValueCount(VECTOR_LENGTH); keyVector.setValueCount(VECTOR_LENGTH); diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorRangeSearcher.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorRangeSearcher.java index d7659dc4cfa03..18f4fa0355f4f 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorRangeSearcher.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorRangeSearcher.java @@ -81,7 +81,7 @@ public void testGetLowerBounds() { VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(intVector); for (int i = 0; i < maxValue; i++) { int result = VectorRangeSearcher.getFirstMatch(intVector, comparator, intVector, i * repeat); - assertEquals(i * repeat, result); + assertEquals(i * ((long) repeat), result); } } } diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java index 2847ddbb8ada6..32fa10bbd98d0 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java @@ -20,6 +20,8 @@ import static org.apache.arrow.vector.complex.BaseRepeatedValueVector.OFFSET_WIDTH; import static org.junit.Assert.assertEquals; +import java.nio.charset.StandardCharsets; + import org.apache.arrow.algorithm.sort.DefaultVectorComparators; import org.apache.arrow.algorithm.sort.VectorValueComparator; import org.apache.arrow.memory.BufferAllocator; @@ -142,7 +144,7 @@ public void testBinarySearchVarChar() { rawVector.set(i, content); } } - negVector.set(0, "abcd".getBytes()); + negVector.set(0, "abcd".getBytes(StandardCharsets.UTF_8)); // do search VectorValueComparator comparator = @@ -181,7 +183,7 @@ public void testLinearSearchVarChar() { rawVector.set(i, content); } } - negVector.set(0, "abcd".getBytes()); + negVector.set(0, "abcd".getBytes(StandardCharsets.UTF_8)); // do search VectorValueComparator comparator = diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestCompositeVectorComparator.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestCompositeVectorComparator.java index cac9933cc0bc2..9624432924b5a 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestCompositeVectorComparator.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestCompositeVectorComparator.java @@ -17,8 +17,10 @@ package org.apache.arrow.algorithm.sort; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import org.apache.arrow.memory.BufferAllocator; @@ -67,9 +69,9 @@ public void testCompareVectorSchemaRoot() { for (int i = 0; i < vectorLength; i++) { intVec1.set(i, i); - strVec1.set(i, new String("a" + i).getBytes()); + strVec1.set(i, ("a" + i).getBytes(StandardCharsets.UTF_8)); intVec2.set(i, i); - strVec2.set(i, new String("a5").getBytes()); + strVec2.set(i, "a5".getBytes(StandardCharsets.UTF_8)); } VectorValueComparator innerComparator1 = @@ -86,7 +88,7 @@ public void testCompareVectorSchemaRoot() { // verify results // both elements are equal, the result is equal - assertTrue(comparator.compare(5, 5) == 0); + assertEquals(0, comparator.compare(5, 5)); // the first element being equal, the second is smaller, and the result is smaller assertTrue(comparator.compare(1, 1) < 0); diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java index 43c634b7647fb..c40854fb17410 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java @@ -65,6 +65,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.junit.jupiter.api.Assertions; /** * Test cases for {@link DefaultVectorComparators}. @@ -258,7 +259,8 @@ public void testCompareUInt2() { vec.allocateNew(10); ValueVectorDataPopulator.setVector( - vec, null, (char) -2, (char) -1, (char) 0, (char) 1, (char) 2, (char) -2, null, + vec, null, (char) (Character.MAX_VALUE - 1), Character.MAX_VALUE, (char) 0, (char) 1, + (char) 2, (char) (Character.MAX_VALUE - 1), null, '\u7FFF', // value for the max 16-byte signed integer '\u8000' // value for the min 16-byte signed integer ); @@ -272,8 +274,8 @@ public void testCompareUInt2() { assertTrue(comparator.compare(1, 3) > 0); assertTrue(comparator.compare(2, 5) > 0); assertTrue(comparator.compare(4, 5) < 0); - assertTrue(comparator.compare(1, 6) == 0); - assertTrue(comparator.compare(0, 7) == 0); + Assertions.assertEquals(0, comparator.compare(1, 6)); + Assertions.assertEquals(0, comparator.compare(0, 7)); assertTrue(comparator.compare(8, 9) < 0); assertTrue(comparator.compare(4, 8) < 0); assertTrue(comparator.compare(5, 9) < 0); diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthSorting.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthSorting.java index ba2a341bf44a0..80c72b4e21a27 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthSorting.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthSorting.java @@ -131,37 +131,37 @@ public static Collection getParameters() { for (boolean inPlace : new boolean[] {true, false}) { params.add(new Object[] { length, nullFrac, inPlace, "TinyIntVector", - (Function) (allocator -> new TinyIntVector("vector", allocator)), + (Function) allocator -> new TinyIntVector("vector", allocator), TestSortingUtil.TINY_INT_GENERATOR }); params.add(new Object[] { length, nullFrac, inPlace, "SmallIntVector", - (Function) (allocator -> new SmallIntVector("vector", allocator)), + (Function) allocator -> new SmallIntVector("vector", allocator), TestSortingUtil.SMALL_INT_GENERATOR }); params.add(new Object[] { length, nullFrac, inPlace, "IntVector", - (Function) (allocator -> new IntVector("vector", allocator)), + (Function) allocator -> new IntVector("vector", allocator), TestSortingUtil.INT_GENERATOR }); params.add(new Object[] { length, nullFrac, inPlace, "BigIntVector", - (Function) (allocator -> new BigIntVector("vector", allocator)), + (Function) allocator -> new BigIntVector("vector", allocator), TestSortingUtil.LONG_GENERATOR }); params.add(new Object[] { length, nullFrac, inPlace, "Float4Vector", - (Function) (allocator -> new Float4Vector("vector", allocator)), + (Function) allocator -> new Float4Vector("vector", allocator), TestSortingUtil.FLOAT_GENERATOR }); params.add(new Object[] { length, nullFrac, inPlace, "Float8Vector", - (Function) (allocator -> new Float8Vector("vector", allocator)), + (Function) allocator -> new Float8Vector("vector", allocator), TestSortingUtil.DOUBLE_GENERATOR }); } diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestSortingUtil.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestSortingUtil.java index ea86551061d56..e22b22d4e6757 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestSortingUtil.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestSortingUtil.java @@ -20,6 +20,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.lang.reflect.Array; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Random; import java.util.function.BiConsumer; @@ -122,7 +123,7 @@ static String generateRandomString(int length) { str[i] = (byte) (r % (upper - lower + 1) + lower); } - return new String(str); + return new String(str, StandardCharsets.UTF_8); } /** diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestStableVectorComparator.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestStableVectorComparator.java index 07419359427f9..f2de5d23fce89 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestStableVectorComparator.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestStableVectorComparator.java @@ -20,12 +20,16 @@ import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.VarCharVector; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.junit.jupiter.api.Assertions; /** * Test cases for {@link StableVectorComparator}. @@ -51,11 +55,11 @@ public void testCompare() { vec.setValueCount(10); // fill data to sort - vec.set(0, "ba".getBytes()); - vec.set(1, "abc".getBytes()); - vec.set(2, "aa".getBytes()); - vec.set(3, "abc".getBytes()); - vec.set(4, "a".getBytes()); + vec.set(0, "ba".getBytes(StandardCharsets.UTF_8)); + vec.set(1, "abc".getBytes(StandardCharsets.UTF_8)); + vec.set(2, "aa".getBytes(StandardCharsets.UTF_8)); + vec.set(3, "abc".getBytes(StandardCharsets.UTF_8)); + vec.set(4, "a".getBytes(StandardCharsets.UTF_8)); VectorValueComparator comparator = new TestVarCharSorter(); VectorValueComparator stableComparator = new StableVectorComparator<>(comparator); @@ -66,7 +70,7 @@ public void testCompare() { assertTrue(stableComparator.compare(2, 3) < 0); assertTrue(stableComparator.compare(1, 3) < 0); assertTrue(stableComparator.compare(3, 1) > 0); - assertTrue(stableComparator.compare(3, 3) == 0); + Assertions.assertEquals(0, stableComparator.compare(3, 3)); } } @@ -77,16 +81,16 @@ public void testStableSortString() { vec.setValueCount(10); // fill data to sort - vec.set(0, "a".getBytes()); - vec.set(1, "abc".getBytes()); - vec.set(2, "aa".getBytes()); - vec.set(3, "a1".getBytes()); - vec.set(4, "abcdefg".getBytes()); - vec.set(5, "accc".getBytes()); - vec.set(6, "afds".getBytes()); - vec.set(7, "0".getBytes()); - vec.set(8, "01".getBytes()); - vec.set(9, "0c".getBytes()); + vec.set(0, "a".getBytes(StandardCharsets.UTF_8)); + vec.set(1, "abc".getBytes(StandardCharsets.UTF_8)); + vec.set(2, "aa".getBytes(StandardCharsets.UTF_8)); + vec.set(3, "a1".getBytes(StandardCharsets.UTF_8)); + vec.set(4, "abcdefg".getBytes(StandardCharsets.UTF_8)); + vec.set(5, "accc".getBytes(StandardCharsets.UTF_8)); + vec.set(6, "afds".getBytes(StandardCharsets.UTF_8)); + vec.set(7, "0".getBytes(StandardCharsets.UTF_8)); + vec.set(8, "01".getBytes(StandardCharsets.UTF_8)); + vec.set(9, "0c".getBytes(StandardCharsets.UTF_8)); // sort the vector VariableWidthOutOfPlaceVectorSorter sorter = new VariableWidthOutOfPlaceVectorSorter(); @@ -103,16 +107,16 @@ public void testStableSortString() { // verify results // the results are stable - assertEquals("0", new String(sortedVec.get(0))); - assertEquals("01", new String(sortedVec.get(1))); - assertEquals("0c", new String(sortedVec.get(2))); - assertEquals("a", new String(sortedVec.get(3))); - assertEquals("abc", new String(sortedVec.get(4))); - assertEquals("aa", new String(sortedVec.get(5))); - assertEquals("a1", new String(sortedVec.get(6))); - assertEquals("abcdefg", new String(sortedVec.get(7))); - assertEquals("accc", new String(sortedVec.get(8))); - assertEquals("afds", new String(sortedVec.get(9))); + assertEquals("0", new String(Objects.requireNonNull(sortedVec.get(0)), StandardCharsets.UTF_8)); + assertEquals("01", new String(Objects.requireNonNull(sortedVec.get(1)), StandardCharsets.UTF_8)); + assertEquals("0c", new String(Objects.requireNonNull(sortedVec.get(2)), StandardCharsets.UTF_8)); + assertEquals("a", new String(Objects.requireNonNull(sortedVec.get(3)), StandardCharsets.UTF_8)); + assertEquals("abc", new String(Objects.requireNonNull(sortedVec.get(4)), StandardCharsets.UTF_8)); + assertEquals("aa", new String(Objects.requireNonNull(sortedVec.get(5)), StandardCharsets.UTF_8)); + assertEquals("a1", new String(Objects.requireNonNull(sortedVec.get(6)), StandardCharsets.UTF_8)); + assertEquals("abcdefg", new String(Objects.requireNonNull(sortedVec.get(7)), StandardCharsets.UTF_8)); + assertEquals("accc", new String(Objects.requireNonNull(sortedVec.get(8)), StandardCharsets.UTF_8)); + assertEquals("afds", new String(Objects.requireNonNull(sortedVec.get(9)), StandardCharsets.UTF_8)); } } } diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthOutOfPlaceVectorSorter.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthOutOfPlaceVectorSorter.java index 8f4e3b8e19426..2486034f1fa32 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthOutOfPlaceVectorSorter.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthOutOfPlaceVectorSorter.java @@ -20,6 +20,9 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BaseVariableWidthVector; @@ -62,16 +65,16 @@ public void testSortString() { vec.setValueCount(10); // fill data to sort - vec.set(0, "hello".getBytes()); - vec.set(1, "abc".getBytes()); + vec.set(0, "hello".getBytes(StandardCharsets.UTF_8)); + vec.set(1, "abc".getBytes(StandardCharsets.UTF_8)); vec.setNull(2); - vec.set(3, "world".getBytes()); - vec.set(4, "12".getBytes()); - vec.set(5, "dictionary".getBytes()); + vec.set(3, "world".getBytes(StandardCharsets.UTF_8)); + vec.set(4, "12".getBytes(StandardCharsets.UTF_8)); + vec.set(5, "dictionary".getBytes(StandardCharsets.UTF_8)); vec.setNull(6); - vec.set(7, "hello".getBytes()); - vec.set(8, "good".getBytes()); - vec.set(9, "yes".getBytes()); + vec.set(7, "hello".getBytes(StandardCharsets.UTF_8)); + vec.set(8, "good".getBytes(StandardCharsets.UTF_8)); + vec.set(9, "yes".getBytes(StandardCharsets.UTF_8)); // sort the vector OutOfPlaceVectorSorter sorter = getSorter(); @@ -93,14 +96,14 @@ public void testSortString() { assertTrue(sortedVec.isNull(0)); assertTrue(sortedVec.isNull(1)); - assertEquals("12", new String(sortedVec.get(2))); - assertEquals("abc", new String(sortedVec.get(3))); - assertEquals("dictionary", new String(sortedVec.get(4))); - assertEquals("good", new String(sortedVec.get(5))); - assertEquals("hello", new String(sortedVec.get(6))); - assertEquals("hello", new String(sortedVec.get(7))); - assertEquals("world", new String(sortedVec.get(8))); - assertEquals("yes", new String(sortedVec.get(9))); + assertEquals("12", new String(Objects.requireNonNull(sortedVec.get(2)), StandardCharsets.UTF_8)); + assertEquals("abc", new String(Objects.requireNonNull(sortedVec.get(3)), StandardCharsets.UTF_8)); + assertEquals("dictionary", new String(Objects.requireNonNull(sortedVec.get(4)), StandardCharsets.UTF_8)); + assertEquals("good", new String(Objects.requireNonNull(sortedVec.get(5)), StandardCharsets.UTF_8)); + assertEquals("hello", new String(Objects.requireNonNull(sortedVec.get(6)), StandardCharsets.UTF_8)); + assertEquals("hello", new String(Objects.requireNonNull(sortedVec.get(7)), StandardCharsets.UTF_8)); + assertEquals("world", new String(Objects.requireNonNull(sortedVec.get(8)), StandardCharsets.UTF_8)); + assertEquals("yes", new String(Objects.requireNonNull(sortedVec.get(9)), StandardCharsets.UTF_8)); sortedVec.close(); } diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthSorting.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthSorting.java index 068fe8b69a883..7951c39d550d2 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthSorting.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestVariableWidthSorting.java @@ -21,6 +21,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -94,7 +95,7 @@ void sortOutOfPlace() { VectorValueComparator comparator = DefaultVectorComparators.createDefaultComparator(vector); try (V sortedVec = (V) vector.getField().getFieldType().createNewSingleVector("", allocator, null)) { - int dataSize = vector.getOffsetBuffer().getInt(vector.getValueCount() * 4); + int dataSize = vector.getOffsetBuffer().getInt(vector.getValueCount() * 4L); sortedVec.allocateNew(dataSize, vector.getValueCount()); sortedVec.setValueCount(vector.getValueCount()); @@ -113,7 +114,7 @@ public static Collection getParameters() { for (double nullFrac : NULL_FRACTIONS) { params.add(new Object[]{ length, nullFrac, "VarCharVector", - (Function) (allocator -> new VarCharVector("vector", allocator)), + (Function) allocator -> new VarCharVector("vector", allocator), TestSortingUtil.STRING_GENERATOR }); } @@ -130,7 +131,7 @@ public static void verifyResults(V vector, String[] expe if (expected[i] == null) { assertTrue(vector.isNull(i)); } else { - assertArrayEquals(((Text) vector.getObject(i)).getBytes(), expected[i].getBytes()); + assertArrayEquals(((Text) vector.getObject(i)).getBytes(), expected[i].getBytes(StandardCharsets.UTF_8)); } } } @@ -151,8 +152,8 @@ public int compare(String str1, String str2) { return str1 == null ? -1 : 1; } - byte[] bytes1 = str1.getBytes(); - byte[] bytes2 = str2.getBytes(); + byte[] bytes1 = str1.getBytes(StandardCharsets.UTF_8); + byte[] bytes2 = str2.getBytes(StandardCharsets.UTF_8); for (int i = 0; i < bytes1.length && i < bytes2.length; i++) { if (bytes1[i] != bytes2[i]) { diff --git a/java/bom/pom.xml b/java/bom/pom.xml index 5c2ed33dadddf..025632c45a56d 100644 --- a/java/bom/pom.xml +++ b/java/bom/pom.xml @@ -20,7 +20,7 @@ org.apache.arrow arrow-bom - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT pom Arrow Bill of Materials Arrow Bill of Materials @@ -82,11 +82,6 @@ flight-core ${project.version} - - org.apache.arrow - flight-grpc - ${project.version} - org.apache.arrow flight-integration-tests diff --git a/java/c/CMakeLists.txt b/java/c/CMakeLists.txt index 8ff208aaeb010..83909c5e13e1b 100644 --- a/java/c/CMakeLists.txt +++ b/java/c/CMakeLists.txt @@ -30,6 +30,11 @@ add_library(arrow_java_jni_cdata SHARED src/main/cpp/jni_wrapper.cc) set_property(TARGET arrow_java_jni_cdata PROPERTY OUTPUT_NAME "arrow_cdata_jni") target_link_libraries(arrow_java_jni_cdata arrow_java_jni_cdata_headers jni) +set(ARROW_JAVA_JNI_C_LIBDIR + "${CMAKE_INSTALL_PREFIX}/lib/arrow_cdata_jni/${ARROW_JAVA_JNI_ARCH_DIR}") +set(ARROW_JAVA_JNI_C_BINDIR + "${CMAKE_INSTALL_PREFIX}/bin/arrow_cdata_jni/${ARROW_JAVA_JNI_ARCH_DIR}") + install(TARGETS arrow_java_jni_cdata - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + LIBRARY DESTINATION ${ARROW_JAVA_JNI_C_LIBDIR} + RUNTIME DESTINATION ${ARROW_JAVA_JNI_C_BINDIR}) diff --git a/java/c/pom.xml b/java/c/pom.xml index 8fc3f36994d8a..ffd41b62dd674 100644 --- a/java/c/pom.xml +++ b/java/c/pom.xml @@ -13,7 +13,7 @@ arrow-java-root org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT 4.0.0 @@ -48,6 +48,10 @@ org.slf4j slf4j-api + + org.immutables + value + org.apache.arrow arrow-memory-unsafe diff --git a/java/c/src/main/java/module-info.java b/java/c/src/main/java/module-info.java new file mode 100644 index 0000000000000..0a62c9b9875b4 --- /dev/null +++ b/java/c/src/main/java/module-info.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +open module org.apache.arrow.c { + exports org.apache.arrow.c; + exports org.apache.arrow.c.jni; + + requires flatbuffers.java; + requires jdk.unsupported; + requires org.apache.arrow.memory.core; + requires org.apache.arrow.vector; + requires org.slf4j; +} diff --git a/java/c/src/main/java/org/apache/arrow/c/Data.java b/java/c/src/main/java/org/apache/arrow/c/Data.java index a92853b3504f0..c90ce7604d6e7 100644 --- a/java/c/src/main/java/org/apache/arrow/c/Data.java +++ b/java/c/src/main/java/org/apache/arrow/c/Data.java @@ -19,8 +19,6 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.StructVectorLoader; -import org.apache.arrow.vector.StructVectorUnloader; import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; diff --git a/java/c/src/main/java/org/apache/arrow/vector/StructVectorLoader.java b/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java similarity index 98% rename from java/c/src/main/java/org/apache/arrow/vector/StructVectorLoader.java rename to java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java index 4a62be7851ac7..d9afd0189d807 100644 --- a/java/c/src/main/java/org/apache/arrow/vector/StructVectorLoader.java +++ b/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.vector; +package org.apache.arrow.c; import static org.apache.arrow.util.Preconditions.checkArgument; @@ -27,6 +27,8 @@ import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.Collections2; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.TypeLayout; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.compression.CompressionCodec; import org.apache.arrow.vector.compression.CompressionUtil; diff --git a/java/c/src/main/java/org/apache/arrow/vector/StructVectorUnloader.java b/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java similarity index 97% rename from java/c/src/main/java/org/apache/arrow/vector/StructVectorUnloader.java rename to java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java index e75156cf237bb..aa6d9b4d0f6a7 100644 --- a/java/c/src/main/java/org/apache/arrow/vector/StructVectorUnloader.java +++ b/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.arrow.vector; +package org.apache.arrow.c; import java.util.ArrayList; import java.util.List; import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.TypeLayout; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.compression.CompressionCodec; import org.apache.arrow.vector.compression.CompressionUtil; diff --git a/java/c/src/main/java/org/apache/arrow/c/jni/JniLoader.java b/java/c/src/main/java/org/apache/arrow/c/jni/JniLoader.java index e435461349257..ef9f432cf0036 100644 --- a/java/c/src/main/java/org/apache/arrow/c/jni/JniLoader.java +++ b/java/c/src/main/java/org/apache/arrow/c/jni/JniLoader.java @@ -80,7 +80,7 @@ private synchronized void loadRemaining() { private void load(String name) { final String libraryToLoad = - getNormalizedArch() + "/" + System.mapLibraryName(name); + name + "/" + getNormalizedArch() + "/" + System.mapLibraryName(name); try { File temp = File.createTempFile("jnilib-", ".tmp", new File(System.getProperty("java.io.tmpdir"))); temp.deleteOnExit(); diff --git a/java/compression/pom.xml b/java/compression/pom.xml index 9a9f029fee137..dea8c778735a8 100644 --- a/java/compression/pom.xml +++ b/java/compression/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT arrow-compression Arrow Compression @@ -35,6 +35,10 @@ arrow-memory-unsafe test + + org.immutables + value + org.apache.commons commons-compress diff --git a/java/compression/src/main/java/module-info.java b/java/compression/src/main/java/module-info.java new file mode 100644 index 0000000000000..6bf989e4c142e --- /dev/null +++ b/java/compression/src/main/java/module-info.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +module org.apache.arrow.compression { + exports org.apache.arrow.compression; + + requires com.github.luben.zstd_jni; + requires org.apache.arrow.memory.core; + requires org.apache.arrow.vector; + requires org.apache.commons.compress; +} diff --git a/java/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java b/java/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java index 01156fa2b0e0b..5fff4fafd677e 100644 --- a/java/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java +++ b/java/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java @@ -175,7 +175,7 @@ void testCompressVariableWidthBuffers(int vectorLength, CompressionCodec codec) if (i % 10 == 0) { origVec.setNull(i); } else { - origVec.setSafe(i, String.valueOf(i).getBytes()); + origVec.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } } origVec.setValueCount(vectorLength); @@ -199,7 +199,7 @@ void testCompressVariableWidthBuffers(int vectorLength, CompressionCodec codec) if (i % 10 == 0) { assertTrue(newVec.isNull(i)); } else { - assertArrayEquals(String.valueOf(i).getBytes(), newVec.get(i)); + assertArrayEquals(String.valueOf(i).getBytes(StandardCharsets.UTF_8), newVec.get(i)); } } diff --git a/java/dataset/CMakeLists.txt b/java/dataset/CMakeLists.txt index ede3ee7330d21..348850c3be5da 100644 --- a/java/dataset/CMakeLists.txt +++ b/java/dataset/CMakeLists.txt @@ -47,6 +47,12 @@ if(BUILD_TESTING) add_test(NAME arrow-java-jni-dataset-test COMMAND arrow-java-jni-dataset-test) endif() +set(ARROW_JAVA_JNI_DATASET_LIBDIR + "${CMAKE_INSTALL_PREFIX}/lib/arrow_dataset_jni/${ARROW_JAVA_JNI_ARCH_DIR}") + +set(ARROW_JAVA_JNI_DATASET_BINDIR + "${CMAKE_INSTALL_PREFIX}/bin/arrow_dataset_jni/${ARROW_JAVA_JNI_ARCH_DIR}") + install(TARGETS arrow_java_jni_dataset - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + LIBRARY DESTINATION ${ARROW_JAVA_JNI_DATASET_LIBDIR} + RUNTIME DESTINATION ${ARROW_JAVA_JNI_DATASET_BINDIR}) diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml index 7d6092743bf4d..8723fafa8dadd 100644 --- a/java/dataset/pom.xml +++ b/java/dataset/pom.xml @@ -15,7 +15,7 @@ arrow-java-root org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT 4.0.0 @@ -47,6 +47,10 @@ arrow-c-data compile + + org.immutables + value + org.apache.arrow arrow-memory-netty @@ -120,7 +124,7 @@ org.apache.orc orc-core - 1.7.6 + 1.9.2 test @@ -161,6 +165,15 @@ + + maven-surefire-plugin + + false + + ${project.basedir}/../../testing/data + + + org.xolstice.maven.plugins protobuf-maven-plugin @@ -182,4 +195,30 @@ + + + jdk11+ + + [11,] + + !m2e.version + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + false + + ${project.basedir}/../../testing/data + + --add-reads=org.apache.arrow.dataset=com.fasterxml.jackson.databind --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED + + + + + + diff --git a/java/dataset/src/main/java/module-info.java b/java/dataset/src/main/java/module-info.java new file mode 100644 index 0000000000000..1672d12ffec69 --- /dev/null +++ b/java/dataset/src/main/java/module-info.java @@ -0,0 +1,29 @@ +/* + + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +open module org.apache.arrow.dataset { + exports org.apache.arrow.dataset.file; + exports org.apache.arrow.dataset.source; + exports org.apache.arrow.dataset.jni; + exports org.apache.arrow.dataset.substrait; + exports org.apache.arrow.dataset.scanner; + + requires org.apache.arrow.c; + requires org.apache.arrow.memory.core; + requires org.apache.arrow.vector; +} diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/DirectReservationListener.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/DirectReservationListener.java index eb26400cbf882..3922e90335da4 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/DirectReservationListener.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/DirectReservationListener.java @@ -40,7 +40,12 @@ private DirectReservationListener() { methodUnreserve = this.getDeclaredMethodBaseOnJDKVersion(classBits, "unreserveMemory"); methodUnreserve.setAccessible(true); } catch (Exception e) { - throw new RuntimeException(e); + final RuntimeException failure = new RuntimeException( + "Failed to initialize DirectReservationListener. When starting Java you must include " + + "`--add-opens=java.base/java.nio=org.apache.arrow.dataset,org.apache.arrow.memory.core,ALL-UNNAMED` " + + "(See https://arrow.apache.org/docs/java/install.html)", e); + failure.printStackTrace(); + throw failure; } } diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniLoader.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniLoader.java index a3b31c73e8540..cf2f8fe29e8ba 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniLoader.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniLoader.java @@ -79,7 +79,7 @@ private synchronized void loadRemaining() { private void load(String name) { final String libraryToLoad = - getNormalizedArch() + "/" + System.mapLibraryName(name); + name + "/" + getNormalizedArch() + "/" + System.mapLibraryName(name); try { File temp = File.createTempFile("jnilib-", ".tmp", new File(System.getProperty("java.io.tmpdir"))); temp.deleteOnExit(); diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java b/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java index 75f905877cd1f..2352a65e8fb62 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java @@ -18,8 +18,8 @@ package org.apache.arrow.dataset; import java.io.File; -import java.nio.file.Path; -import java.nio.file.Paths; +import java.io.InputStream; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -47,7 +47,7 @@ public class ParquetWriteSupport implements AutoCloseable { public ParquetWriteSupport(String schemaName, File outputFolder) throws Exception { - avroSchema = readSchemaFromFile(schemaName); + avroSchema = getSchema(schemaName); path = outputFolder.getPath() + "/" + "generated-" + random.nextLong() + ".parquet"; uri = "file://" + path; writer = AvroParquetWriter @@ -56,10 +56,23 @@ public ParquetWriteSupport(String schemaName, File outputFolder) throws Exceptio .build(); } - private static Schema readSchemaFromFile(String schemaName) throws Exception { - Path schemaPath = Paths.get(ParquetWriteSupport.class.getResource("/").getPath(), - "avroschema", schemaName); - return new org.apache.avro.Schema.Parser().parse(schemaPath.toFile()); + public static Schema getSchema(String schemaName) throws Exception { + try { + // Attempt to use JDK 9 behavior of getting the module then the resource stream from the module. + // Note that this code is caller-sensitive. + Method getModuleMethod = Class.class.getMethod("getModule"); + Object module = getModuleMethod.invoke(ParquetWriteSupport.class); + Method getResourceAsStreamFromModule = module.getClass().getMethod("getResourceAsStream", String.class); + try (InputStream is = (InputStream) getResourceAsStreamFromModule.invoke(module, "/avroschema/" + schemaName)) { + return new Schema.Parser() + .parse(is); + } + } catch (NoSuchMethodException ex) { + // Use JDK8 behavior. + try (InputStream is = ParquetWriteSupport.class.getResourceAsStream("/avroschema/" + schemaName)) { + return new Schema.Parser().parse(is); + } + } } public static ParquetWriteSupport writeTempFile(String schemaName, File outputFolder, diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/TestAllTypes.java b/java/dataset/src/test/java/org/apache/arrow/dataset/TestAllTypes.java index 7be49079e7450..13b247452348d 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/TestAllTypes.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/TestAllTypes.java @@ -32,7 +32,6 @@ import org.apache.arrow.dataset.file.DatasetFileWriter; import org.apache.arrow.dataset.file.FileFormat; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.util.ArrowTestDataUtil; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.DateMilliVector; @@ -69,6 +68,7 @@ import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.test.util.ArrowTestDataUtil; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; diff --git a/java/flight/flight-core/pom.xml b/java/flight/flight-core/pom.xml index 8f41d2b65b7d1..0346172f610a6 100644 --- a/java/flight/flight-core/pom.xml +++ b/java/flight/flight-core/pom.xml @@ -14,7 +14,7 @@ arrow-flight org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT ../pom.xml @@ -86,6 +86,10 @@ com.google.protobuf protobuf-java + + com.google.protobuf + protobuf-java-util + io.grpc grpc-api @@ -95,6 +99,11 @@ grpc-services test + + io.grpc + grpc-inprocess + test + com.fasterxml.jackson.core @@ -108,6 +117,10 @@ javax.annotation javax.annotation-api + + org.immutables + value + com.google.api.grpc @@ -305,4 +318,32 @@ + + + + jdk11+ + + [11,] + + !m2e.version + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + --add-opens=org.apache.arrow.flight.core/org.apache.arrow.flight.perf.impl=protobuf.java --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED + false + + ${project.basedir}/../../../testing/data + + + + + + + + diff --git a/java/flight/flight-core/src/main/java/module-info.java b/java/flight/flight-core/src/main/java/module-info.java new file mode 100644 index 0000000000000..f6bf5b73b0972 --- /dev/null +++ b/java/flight/flight-core/src/main/java/module-info.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +module org.apache.arrow.flight.core { + exports org.apache.arrow.flight; + exports org.apache.arrow.flight.auth; + exports org.apache.arrow.flight.auth2; + exports org.apache.arrow.flight.client; + exports org.apache.arrow.flight.impl; + exports org.apache.arrow.flight.sql.impl; + + requires com.fasterxml.jackson.databind; + requires com.google.common; + requires com.google.errorprone.annotations; + requires io.grpc; + requires io.grpc.internal; + requires io.grpc.netty; + requires io.grpc.protobuf; + requires io.grpc.stub; + requires io.netty.common; + requires io.netty.handler; + requires io.netty.transport; + requires org.apache.arrow.format; + requires org.apache.arrow.memory.core; + requires org.apache.arrow.vector; + requires protobuf.java; + requires org.slf4j; +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index b4ee835dee4a0..46cb282e9f3ce 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -35,8 +35,6 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; -import org.apache.arrow.vector.compression.NoCompressionCodec; -import org.apache.arrow.vector.ipc.message.ArrowBodyCompression; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.ipc.message.IpcOption; @@ -144,7 +142,6 @@ public static HeaderType getHeader(byte b) { private final MessageMetadataResult message; private final ArrowBuf appMetadata; private final List bufs; - private final ArrowBodyCompression bodyCompression; private final boolean tryZeroCopyWrite; public ArrowMessage(FlightDescriptor descriptor, Schema schema, IpcOption option) { @@ -155,7 +152,6 @@ public ArrowMessage(FlightDescriptor descriptor, Schema schema, IpcOption option bufs = ImmutableList.of(); this.descriptor = descriptor; this.appMetadata = null; - this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION; this.tryZeroCopyWrite = false; } @@ -172,7 +168,6 @@ public ArrowMessage(ArrowRecordBatch batch, ArrowBuf appMetadata, boolean tryZer this.bufs = ImmutableList.copyOf(batch.getBuffers()); this.descriptor = null; this.appMetadata = appMetadata; - this.bodyCompression = batch.getBodyCompression(); this.tryZeroCopyWrite = tryZeroCopy; } @@ -186,7 +181,6 @@ public ArrowMessage(ArrowDictionaryBatch batch, IpcOption option) { this.bufs = ImmutableList.copyOf(batch.getDictionary().getBuffers()); this.descriptor = null; this.appMetadata = null; - this.bodyCompression = batch.getDictionary().getBodyCompression(); this.tryZeroCopyWrite = false; } @@ -201,7 +195,6 @@ public ArrowMessage(ArrowBuf appMetadata) { this.bufs = ImmutableList.of(); this.descriptor = null; this.appMetadata = appMetadata; - this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION; this.tryZeroCopyWrite = false; } @@ -212,7 +205,6 @@ public ArrowMessage(FlightDescriptor descriptor) { this.bufs = ImmutableList.of(); this.descriptor = descriptor; this.appMetadata = null; - this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION; this.tryZeroCopyWrite = false; } @@ -227,7 +219,6 @@ private ArrowMessage(FlightDescriptor descriptor, MessageMetadataResult message, this.descriptor = descriptor; this.appMetadata = appMetadata; this.bufs = buf == null ? ImmutableList.of() : ImmutableList.of(buf); - this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION; this.tryZeroCopyWrite = false; } @@ -370,7 +361,7 @@ private static int readRawVarint32(InputStream is) throws IOException { * * @return InputStream */ - private InputStream asInputStream(BufferAllocator allocator) { + private InputStream asInputStream() { if (message == null) { // If we have no IPC message, it's a pure-metadata message final FlightData.Builder builder = FlightData.newBuilder(); @@ -422,7 +413,7 @@ private InputStream asInputStream(BufferAllocator allocator) { // Arrow buffer. This is susceptible to use-after-free, so we subclass CompositeByteBuf // below to tie the Arrow buffer refcnt to the Netty buffer refcnt allBufs.add(Unpooled.wrappedBuffer(b.nioBuffer()).retain()); - size += b.readableBytes(); + size += (int) b.readableBytes(); // [ARROW-4213] These buffers must be aligned to an 8-byte boundary in order to be readable from C++. if (b.readableBytes() % 8 != 0) { int paddingBytes = (int) (8 - (b.readableBytes() % 8)); @@ -543,7 +534,7 @@ public ArrowMessageHolderMarshaller(BufferAllocator allocator) { @Override public InputStream stream(ArrowMessage value) { - return value.asInputStream(allocator); + return value.asInputStream(); } @Override diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CancelFlightInfoResult.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CancelFlightInfoResult.java index eff5afdeeb788..165afdff553df 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CancelFlightInfoResult.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CancelFlightInfoResult.java @@ -105,7 +105,7 @@ public boolean equals(Object o) { if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { + if (!(o instanceof CancelFlightInfoResult)) { return false; } CancelFlightInfoResult that = (CancelFlightInfoResult) o; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ErrorFlightMetadata.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ErrorFlightMetadata.java index 6669ce4655010..6e19d2750cb67 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ErrorFlightMetadata.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ErrorFlightMetadata.java @@ -61,7 +61,7 @@ public Iterable getAllByte(String key) { @Override public void insert(String key, String value) { - metadata.put(key, value.getBytes()); + metadata.put(key, value.getBytes(StandardCharsets.UTF_8)); } @Override diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java index dd26d190872ac..93b89e775507e 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java @@ -17,6 +17,7 @@ package org.apache.arrow.flight; +import java.nio.charset.StandardCharsets; import java.util.Collection; import java.util.Set; import java.util.stream.Collectors; @@ -46,7 +47,7 @@ public String get(String key) { } if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { - return new String((byte[]) Iterables.get(values, 0)); + return new String((byte[]) Iterables.get(values, 0), StandardCharsets.UTF_8); } return (String) Iterables.get(values, 0); @@ -63,13 +64,14 @@ public byte[] getByte(String key) { return (byte[]) Iterables.get(values, 0); } - return ((String) Iterables.get(values, 0)).getBytes(); + return ((String) Iterables.get(values, 0)).getBytes(StandardCharsets.UTF_8); } @Override public Iterable getAll(String key) { if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { - return this.keysAndValues.get(key).stream().map(o -> new String((byte[]) o)).collect(Collectors.toList()); + return this.keysAndValues.get(key).stream().map(o -> new String((byte[]) o, StandardCharsets.UTF_8)) + .collect(Collectors.toList()); } return (Collection) (Collection) this.keysAndValues.get(key); } @@ -79,7 +81,8 @@ public Iterable getAllByte(String key) { if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { return (Collection) (Collection) this.keysAndValues.get(key); } - return this.keysAndValues.get(key).stream().map(o -> ((String) o).getBytes()).collect(Collectors.toList()); + return this.keysAndValues.get(key).stream().map(o -> ((String) o).getBytes(StandardCharsets.UTF_8)) + .collect(Collectors.toList()); } @Override @@ -105,6 +108,7 @@ public boolean containsKey(String key) { return this.keysAndValues.containsKey(key); } + @Override public String toString() { return this.keysAndValues.toString(); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index 91e3b4d052f39..fc491ebe0df98 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.InvocationTargetException; import java.net.URISyntaxException; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -627,6 +628,7 @@ default boolean isCancelled() { /** * Shut down this client. */ + @Override public void close() throws InterruptedException { channel.shutdown().awaitTermination(5, TimeUnit.SECONDS); allocator.close(); @@ -746,19 +748,24 @@ public FlightClient build() { try { // Linux builder.channelType( - (Class) Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel")); - final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.epoll.EpollEventLoopGroup") - .newInstance(); + Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel") + .asSubclass(ServerChannel.class)); + final EventLoopGroup elg = + Class.forName("io.netty.channel.epoll.EpollEventLoopGroup").asSubclass(EventLoopGroup.class) + .getDeclaredConstructor().newInstance(); builder.eventLoopGroup(elg); } catch (ClassNotFoundException e) { // BSD builder.channelType( - (Class) Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel")); - final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") - .newInstance(); + Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel") + .asSubclass(ServerChannel.class)); + final EventLoopGroup elg = Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") + .asSubclass(EventLoopGroup.class) + .getDeclaredConstructor().newInstance(); builder.eventLoopGroup(elg); } - } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { + } catch (ClassNotFoundException | InstantiationException | IllegalAccessException | + NoSuchMethodException | InvocationTargetException e) { throw new UnsupportedOperationException( "Could not find suitable Netty native transport implementation for domain socket address."); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDescriptor.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDescriptor.java index 3eff011d9fe77..1836f2edd94c0 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDescriptor.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDescriptor.java @@ -152,7 +152,7 @@ public boolean equals(Object obj) { if (obj == null) { return false; } - if (getClass() != obj.getClass()) { + if (!(obj instanceof FlightDescriptor)) { return false; } FlightDescriptor other = (FlightDescriptor) obj; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightEndpoint.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightEndpoint.java index 1967fe1d91c34..41ead8e1fcddf 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightEndpoint.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightEndpoint.java @@ -33,6 +33,7 @@ import com.google.protobuf.ByteString; import com.google.protobuf.Timestamp; +import com.google.protobuf.util.Timestamps; /** * POJO to convert to/from the underlying protobuf FlightEndpoint. @@ -85,11 +86,11 @@ private FlightEndpoint(Ticket ticket, Instant expirationTime, byte[] appMetadata } if (flt.hasExpirationTime()) { this.expirationTime = Instant.ofEpochSecond( - flt.getExpirationTime().getSeconds(), flt.getExpirationTime().getNanos()); + flt.getExpirationTime().getSeconds(), Timestamps.toNanos(flt.getExpirationTime())); } else { this.expirationTime = null; } - this.appMetadata = (flt.getAppMetadata().size() == 0 ? null : flt.getAppMetadata().toByteArray()); + this.appMetadata = (flt.getAppMetadata().isEmpty() ? null : flt.getAppMetadata().toByteArray()); this.ticket = new Ticket(flt.getTicket()); } @@ -163,7 +164,7 @@ public boolean equals(Object o) { if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { + if (!(o instanceof FlightEndpoint)) { return false; } FlightEndpoint that = (FlightEndpoint) o; diff --git a/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java similarity index 100% rename from java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java rename to java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java index b5279a304c865..39e5f5e3a3ed6 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java @@ -249,7 +249,7 @@ public boolean equals(Object o) { if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { + if (!(o instanceof FlightInfo)) { return false; } FlightInfo that = (FlightInfo) o; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java index 234c9bdcaacc1..d873f7d2828d0 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -21,6 +21,7 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.InvocationTargetException; import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; @@ -134,6 +135,7 @@ public boolean awaitTermination(final long timeout, final TimeUnit unit) throws } /** Shutdown the server, waits for up to 6 seconds for successful shutdown before returning. */ + @Override public void close() throws InterruptedException { shutdown(); final boolean terminated = awaitTermination(3000, TimeUnit.MILLISECONDS); @@ -146,7 +148,7 @@ public void close() throws InterruptedException { server.shutdownNow(); int count = 0; - while (!server.isTerminated() & count < 30) { + while (!server.isTerminated() && count < 30) { count++; logger.debug("Waiting for termination"); Thread.sleep(100); @@ -216,22 +218,23 @@ public FlightServer build() { try { try { // Linux - builder.channelType( - (Class) Class - .forName("io.netty.channel.epoll.EpollServerDomainSocketChannel")); - final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.epoll.EpollEventLoopGroup") - .newInstance(); + builder.channelType(Class + .forName("io.netty.channel.epoll.EpollServerDomainSocketChannel") + .asSubclass(ServerChannel.class)); + final EventLoopGroup elg = Class.forName("io.netty.channel.epoll.EpollEventLoopGroup") + .asSubclass(EventLoopGroup.class).getConstructor().newInstance(); builder.bossEventLoopGroup(elg).workerEventLoopGroup(elg); } catch (ClassNotFoundException e) { // BSD builder.channelType( - (Class) Class - .forName("io.netty.channel.kqueue.KQueueServerDomainSocketChannel")); - final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") - .newInstance(); + Class.forName("io.netty.channel.kqueue.KQueueServerDomainSocketChannel") + .asSubclass(ServerChannel.class)); + final EventLoopGroup elg = Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") + .asSubclass(EventLoopGroup.class).getConstructor().newInstance(); builder.bossEventLoopGroup(elg).workerEventLoopGroup(elg); } - } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { + } catch (ClassNotFoundException | InstantiationException | IllegalAccessException | NoSuchMethodException | + InvocationTargetException e) { throw new UnsupportedOperationException( "Could not find suitable Netty native transport implementation for domain socket address."); } @@ -342,7 +345,8 @@ private void closeInputStreamIfNotNull(InputStream stream) { if (stream != null) { try { stream.close(); - } catch (IOException ignored) { + } catch (IOException expected) { + // stream closes gracefully, doesn't expect an exception. } } } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java index 5231a7aaf76e4..f55b47d2a945b 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java @@ -20,6 +20,7 @@ import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.function.BooleanSupplier; import java.util.function.Consumer; @@ -142,7 +143,7 @@ public void listActions(Flight.Empty request, StreamObserver } private static class GetListener extends OutboundStreamListenerImpl implements ServerStreamListener { - private ServerCallStreamObserver responseObserver; + private final ServerCallStreamObserver serverCallResponseObserver; private final Consumer errorHandler; private Runnable onCancelHandler = null; private Runnable onReadyHandler = null; @@ -152,10 +153,10 @@ public GetListener(ServerCallStreamObserver responseObserver, Cons super(null, responseObserver); this.errorHandler = errorHandler; this.completed = false; - this.responseObserver = responseObserver; - this.responseObserver.setOnCancelHandler(this::onCancel); - this.responseObserver.setOnReadyHandler(this::onReady); - this.responseObserver.disableAutoInboundFlowControl(); + this.serverCallResponseObserver = responseObserver; + this.serverCallResponseObserver.setOnCancelHandler(this::onCancel); + this.serverCallResponseObserver.setOnReadyHandler(this::onReady); + this.serverCallResponseObserver.disableAutoInboundFlowControl(); } private void onCancel() { @@ -183,7 +184,7 @@ public void setOnReadyHandler(Runnable handler) { @Override public boolean isCancelled() { - return responseObserver.isCancelled(); + return serverCallResponseObserver.isCancelled(); } @Override @@ -228,7 +229,7 @@ public StreamObserver doPutCustom(final StreamObserver observer = fs.asObserver(); - executors.submit(() -> { + Future unused = executors.submit(() -> { try { producer.acceptPut(makeContext(responseObserver), fs, ackStream).run(); } catch (Throwable ex) { @@ -277,7 +278,8 @@ public void pollFlightInfo(Flight.FlightDescriptor request, StreamObserver, FlightServerMiddleware> middleware = ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get(); + final Map, FlightServerMiddleware> middleware = ServerInterceptorAdapter + .SERVER_MIDDLEWARE_KEY.get(); if (middleware == null || middleware.isEmpty()) { logger.error("Uncaught exception in Flight method body", t); return; @@ -377,7 +379,7 @@ public StreamObserver doExchangeCustom(StreamObserver observer = fs.asObserver(); try { - executors.submit(() -> { + Future unused = executors.submit(() -> { try { producer.doExchange(makeContext(responseObserver), fs, listener); } catch (Exception ex) { @@ -416,8 +418,9 @@ public boolean isCancelled() { } @Override - public T getMiddleware(Key key) { - final Map, FlightServerMiddleware> middleware = ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get(); + public T getMiddleware(FlightServerMiddleware.Key key) { + final Map, FlightServerMiddleware> middleware = ServerInterceptorAdapter + .SERVER_MIDDLEWARE_KEY.get(); if (middleware == null) { return null; } @@ -430,8 +433,9 @@ public T getMiddleware(Key key) { } @Override - public Map, FlightServerMiddleware> getMiddleware() { - final Map, FlightServerMiddleware> middleware = ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get(); + public Map, FlightServerMiddleware> getMiddleware() { + final Map, FlightServerMiddleware> middleware = + ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get(); if (middleware == null) { return Collections.emptyMap(); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java index ad4ffcbebdec1..7a5a941603ace 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java @@ -27,6 +27,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.arrow.flight.ArrowMessage.HeaderType; import org.apache.arrow.flight.grpc.StatusUtils; @@ -56,9 +57,17 @@ */ public class FlightStream implements AutoCloseable { // Use AutoCloseable sentinel objects to simplify logic in #close - private final AutoCloseable DONE = () -> { + private final AutoCloseable DONE = new AutoCloseable() { + @Override + public void close() throws Exception { + + } }; - private final AutoCloseable DONE_EX = () -> { + private final AutoCloseable DONE_EX = new AutoCloseable() { + @Override + public void close() throws Exception { + + } }; private final BufferAllocator allocator; @@ -76,7 +85,7 @@ public class FlightStream implements AutoCloseable { // we don't block forever trying to write to a server that has rejected a call. final CompletableFuture cancelled; - private volatile int pending = 1; + private final AtomicInteger pending = new AtomicInteger(); private volatile VectorSchemaRoot fulfilledRoot; private DictionaryProvider.MapDictionaryProvider dictionaries; private volatile VectorLoader loader; @@ -169,6 +178,7 @@ public FlightDescriptor getDescriptor() { * *

If the stream isn't complete and is cancellable, this method will cancel and drain the stream first. */ + @Override public void close() throws Exception { final List closeables = new ArrayList<>(); Throwable suppressor = null; @@ -227,7 +237,7 @@ public boolean next() { return false; } - pending--; + pending.decrementAndGet(); requestOutstanding(); Object data = queue.take(); @@ -359,9 +369,9 @@ public ArrowBuf getLatestMetadata() { } private synchronized void requestOutstanding() { - if (pending < pendingTarget) { - requestor.request(pendingTarget - pending); - pending = pendingTarget; + if (pending.get() < pendingTarget) { + requestor.request(pendingTarget - pending.get()); + pending.set(pendingTarget); } } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java index 9dba773bf3386..fe192aa0c3f9d 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java @@ -71,7 +71,7 @@ public SocketAddress toSocketAddress() { case LocationSchemes.GRPC_DOMAIN_SOCKET: { try { // This dependency is not available on non-Unix platforms. - return (SocketAddress) Class.forName("io.netty.channel.unix.DomainSocketAddress") + return Class.forName("io.netty.channel.unix.DomainSocketAddress").asSubclass(SocketAddress.class) .getConstructor(String.class) .newInstance(uri.getPath()); } catch (InstantiationException | ClassNotFoundException | InvocationTargetException | @@ -144,7 +144,7 @@ public boolean equals(Object o) { if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { + if (!(o instanceof Location)) { return false; } Location location = (Location) o; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/PollInfo.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/PollInfo.java index 2bb3c6db69569..59150d8814cd9 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/PollInfo.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/PollInfo.java @@ -27,6 +27,7 @@ import org.apache.arrow.flight.impl.Flight; import com.google.protobuf.Timestamp; +import com.google.protobuf.util.Timestamps; /** * A POJO representation of the execution of a long-running query. @@ -57,7 +58,7 @@ public PollInfo(FlightInfo flightInfo, FlightDescriptor flightDescriptor, Double this.flightDescriptor = flt.hasFlightDescriptor() ? new FlightDescriptor(flt.getFlightDescriptor()) : null; this.progress = flt.hasProgress() ? flt.getProgress() : null; this.expirationTime = flt.hasExpirationTime() ? - Instant.ofEpochSecond(flt.getExpirationTime().getSeconds(), flt.getExpirationTime().getNanos()) : + Instant.ofEpochSecond(flt.getExpirationTime().getSeconds(), Timestamps.toNanos(flt.getExpirationTime())) : null; } @@ -133,7 +134,8 @@ public boolean equals(Object o) { if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { + + if (!(o instanceof PollInfo)) { return false; } PollInfo pollInfo = (PollInfo) o; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Ticket.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Ticket.java index a93cd087905db..eb2f4af70d781 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Ticket.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Ticket.java @@ -88,7 +88,7 @@ public boolean equals(Object obj) { if (obj == null) { return false; } - if (getClass() != obj.getClass()) { + if (!(obj instanceof Ticket)) { return false; } Ticket other = (Ticket) obj; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/AuthConstants.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/AuthConstants.java index ac55872e5b18b..e3ccdc626d71b 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/AuthConstants.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/AuthConstants.java @@ -20,8 +20,8 @@ import org.apache.arrow.flight.FlightConstants; import io.grpc.Context; +import io.grpc.Metadata; import io.grpc.Metadata.BinaryMarshaller; -import io.grpc.Metadata.Key; import io.grpc.MethodDescriptor; /** @@ -32,7 +32,7 @@ public final class AuthConstants { public static final String HANDSHAKE_DESCRIPTOR_NAME = MethodDescriptor .generateFullMethodName(FlightConstants.SERVICE, "Handshake"); public static final String TOKEN_NAME = "Auth-Token-bin"; - public static final Key TOKEN_KEY = Key.of(TOKEN_NAME, new BinaryMarshaller() { + public static final Metadata.Key TOKEN_KEY = Metadata.Key.of(TOKEN_NAME, new BinaryMarshaller() { @Override public byte[] toBytes(byte[] value) { diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerTokenAuthenticator.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerTokenAuthenticator.java index 2006e0a2b1241..5eb5863e792d4 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerTokenAuthenticator.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerTokenAuthenticator.java @@ -55,7 +55,6 @@ public AuthResult authenticate(CallHeaders incomingHeaders) { * Validate the bearer token. * @param bearerToken The bearer token to validate. * @return A successful AuthResult if validation succeeded. - * @throws Exception If the token validation fails. */ protected abstract AuthResult validateBearer(String bearerToken); diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ClientInterceptorAdapter.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ClientInterceptorAdapter.java index ae11e52605623..db27aa481ec75 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ClientInterceptorAdapter.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ClientInterceptorAdapter.java @@ -23,7 +23,6 @@ import org.apache.arrow.flight.CallInfo; import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.FlightClientMiddleware; -import org.apache.arrow.flight.FlightClientMiddleware.Factory; import org.apache.arrow.flight.FlightMethod; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStatusCode; @@ -46,9 +45,9 @@ */ public class ClientInterceptorAdapter implements ClientInterceptor { - private final List factories; + private final List factories; - public ClientInterceptorAdapter(List factories) { + public ClientInterceptorAdapter(List factories) { this.factories = factories; } @@ -59,7 +58,7 @@ public ClientCall interceptCall(MethodDescriptor getAll(String key) { - return this.metadata.getAll(Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + return this.metadata.getAll(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); } @Override public Iterable getAllByte(String key) { if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { - return this.metadata.getAll(Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + return this.metadata.getAll(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); } return StreamSupport.stream(getAll(key).spliterator(), false) .map(String::getBytes).collect(Collectors.toList()); @@ -69,12 +70,12 @@ public Iterable getAllByte(String key) { @Override public void insert(String key, String value) { - this.metadata.put(Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value); + this.metadata.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value); } @Override public void insert(String key, byte[] value) { - this.metadata.put(Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), value); + this.metadata.put(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), value); } @Override @@ -85,13 +86,14 @@ public Set keys() { @Override public boolean containsKey(String key) { if (key.endsWith("-bin")) { - final Key grpcKey = Key.of(key, Metadata.BINARY_BYTE_MARSHALLER); + final Metadata.Key grpcKey = Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER); return this.metadata.containsKey(grpcKey); } - final Key grpcKey = Key.of(key, Metadata.ASCII_STRING_MARSHALLER); + final Metadata.Key grpcKey = Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER); return this.metadata.containsKey(grpcKey); } + @Override public String toString() { return this.metadata.toString(); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ServerInterceptorAdapter.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ServerInterceptorAdapter.java index 9b038b9d49272..70c667df56020 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ServerInterceptorAdapter.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ServerInterceptorAdapter.java @@ -61,7 +61,7 @@ public static class KeyFactory { private final FlightServerMiddleware.Key key; private final FlightServerMiddleware.Factory factory; - public KeyFactory(Key key, Factory factory) { + public KeyFactory(FlightServerMiddleware.Key key, FlightServerMiddleware.Factory factory) { this.key = key; this.factory = factory; } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java index 55e8418642d36..7f0dcf2da3f0d 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java @@ -17,6 +17,7 @@ package org.apache.arrow.flight.grpc; +import java.nio.charset.StandardCharsets; import java.util.Iterator; import java.util.Objects; import java.util.function.Function; @@ -171,7 +172,7 @@ private static ErrorFlightMetadata parseTrailers(Metadata trailers) { if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { metadata.insert(key, trailers.get(keyOfBinary(key))); } else { - metadata.insert(key, Objects.requireNonNull(trailers.get(keyOfAscii(key))).getBytes()); + metadata.insert(key, Objects.requireNonNull(trailers.get(keyOfAscii(key))).getBytes(StandardCharsets.UTF_8)); } } return metadata; diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java index 64f70856a3b05..393fa086775ed 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java @@ -23,9 +23,8 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.List; -import java.util.Random; -import org.apache.arrow.util.ArrowTestDataUtil; +import org.apache.arrow.vector.test.util.ArrowTestDataUtil; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.function.Executable; @@ -34,8 +33,6 @@ */ public class FlightTestUtil { - private static final Random RANDOM = new Random(); - public static final String LOCALHOST = "localhost"; static Path getFlightTestDataRoot() { diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java index 07db301309e3d..77c039afd87a0 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java @@ -20,6 +20,7 @@ import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST; import static org.apache.arrow.flight.Location.forGrpcInsecure; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collections; import java.util.concurrent.ExecutionException; @@ -45,7 +46,7 @@ public class TestApplicationMetadata { // The command used to trigger the test for ARROW-6136. - private static final byte[] COMMAND_ARROW_6136 = "ARROW-6136".getBytes(); + private static final byte[] COMMAND_ARROW_6136 = "ARROW-6136".getBytes(StandardCharsets.UTF_8); // The expected error message. private static final String MESSAGE_ARROW_6136 = "The stream should not be double-closed."; diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java index 7586d50c8e713..596debcf89dd2 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java @@ -22,6 +22,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; @@ -153,7 +154,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l loadData.run(); } else { final ExecutorService service = Executors.newSingleThreadExecutor(); - service.submit(loadData); + Future unused = service.submit(loadData); service.shutdown(); } } @@ -237,7 +238,8 @@ public WaitResult waitForListener(long timeout) { try { Thread.sleep(1); sleepTime.addAndGet(1L); - } catch (InterruptedException ignore) { + } catch (InterruptedException expected) { + // it is expected and no action needed } } return WaitResult.READY; diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index 41b3a4693e579..ae520ee9b991b 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -17,6 +17,7 @@ package org.apache.arrow.flight; + import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST; import static org.apache.arrow.flight.Location.forGrpcInsecure; @@ -114,11 +115,11 @@ public void roundTripInfo() throws Exception { Field.nullable("b", new ArrowType.FixedSizeBinary(32)) ), metadata); final FlightInfo info1 = FlightInfo.builder(schema, FlightDescriptor.path(), Collections.emptyList()) - .setAppMetadata("foo".getBytes()).build(); + .setAppMetadata("foo".getBytes(StandardCharsets.UTF_8)).build(); final FlightInfo info2 = new FlightInfo(schema, FlightDescriptor.command(new byte[2]), Collections.singletonList( FlightEndpoint.builder(new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock")) - .setAppMetadata("bar".getBytes()).build() + .setAppMetadata("bar".getBytes(StandardCharsets.UTF_8)).build() ), 200, 500); final FlightInfo info3 = new FlightInfo(schema, FlightDescriptor.path("a", "b"), Arrays.asList(new FlightEndpoint( @@ -160,7 +161,7 @@ public void roundTripDescriptor() throws Exception { public void getDescriptors() throws Exception { test(c -> { int count = 0; - for (FlightInfo i : c.listFlights(Criteria.ALL)) { + for (FlightInfo unused : c.listFlights(Criteria.ALL)) { count += 1; } Assertions.assertEquals(1, count); @@ -171,7 +172,8 @@ public void getDescriptors() throws Exception { public void getDescriptorsWithCriteria() throws Exception { test(c -> { int count = 0; - for (FlightInfo i : c.listFlights(new Criteria(new byte[]{1}))) { + for (FlightInfo unused : c.listFlights(new Criteria(new byte[]{1}))) { + count += 1; } Assertions.assertEquals(0, count); diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java index 8b1a897467d58..41df36c863325 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java @@ -21,6 +21,7 @@ import static org.apache.arrow.flight.Location.forGrpcInsecure; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; import java.util.Iterator; @@ -87,8 +88,8 @@ public void multipleProperties() { @Test public void binaryProperties() { final FlightCallHeaders headers = new FlightCallHeaders(); - headers.insert("key-bin", "value".getBytes()); - headers.insert("key3-bin", "ëfßæ".getBytes()); + headers.insert("key-bin", "value".getBytes(StandardCharsets.UTF_8)); + headers.insert("key3-bin", "ëfßæ".getBytes(StandardCharsets.UTF_8)); testHeaders(headers); } @@ -96,7 +97,7 @@ public void binaryProperties() { public void mixedProperties() { final FlightCallHeaders headers = new FlightCallHeaders(); headers.insert("key", "value"); - headers.insert("key3-bin", "ëfßæ".getBytes()); + headers.insert("key3-bin", "ëfßæ".getBytes(StandardCharsets.UTF_8)); testHeaders(headers); } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java index b3a716ab3cec5..40930131e0ca8 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java @@ -18,7 +18,8 @@ package org.apache.arrow.flight; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; import java.util.TreeSet; @@ -54,7 +55,7 @@ public void testReuseSchema() { Schema newSchema = DictionaryUtils.generateSchema(schema, null, new TreeSet<>()); // assert that no new schema is created. - assertTrue(schema == newSchema); + assertSame(schema, newSchema); } @Test @@ -78,7 +79,7 @@ public void testCreateSchema() { Schema newSchema = DictionaryUtils.generateSchema(schema, dictProvider, dictionaryUsed); // assert that a new schema is created. - assertTrue(schema != newSchema); + assertNotSame(schema, newSchema); // assert the column is converted as expected ArrowType newColType = newSchema.getFields().get(0).getType(); diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java index f9db9bfd23a88..b70353df8e9a7 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java @@ -476,7 +476,7 @@ public void doExchange(CallContext context, FlightStream reader, ServerStreamLis } /** Emulate DoGet. */ - private void doGet(CallContext context, FlightStream reader, ServerStreamListener writer) { + private void doGet(CallContext unusedContext, FlightStream unusedReader, ServerStreamListener writer) { try (VectorSchemaRoot root = VectorSchemaRoot.create(SCHEMA, allocator)) { writer.start(root); root.allocateNew(); @@ -493,7 +493,7 @@ private void doGet(CallContext context, FlightStream reader, ServerStreamListene } /** Emulate DoPut. */ - private void doPut(CallContext context, FlightStream reader, ServerStreamListener writer) { + private void doPut(CallContext unusedContext, FlightStream reader, ServerStreamListener writer) { int counter = 0; while (reader.next()) { if (!reader.hasRoot()) { @@ -510,7 +510,7 @@ private void doPut(CallContext context, FlightStream reader, ServerStreamListene } /** Exchange metadata without ever exchanging data. */ - private void metadataOnly(CallContext context, FlightStream reader, ServerStreamListener writer) { + private void metadataOnly(CallContext unusedContext, FlightStream reader, ServerStreamListener writer) { final ArrowBuf buf = allocator.buffer(4); buf.writeInt(42); writer.putMetadata(buf); @@ -522,7 +522,7 @@ private void metadataOnly(CallContext context, FlightStream reader, ServerStream } /** Echo the client's response back to it. */ - private void echo(CallContext context, FlightStream reader, ServerStreamListener writer) { + private void echo(CallContext unusedContext, FlightStream reader, ServerStreamListener writer) { VectorSchemaRoot root = null; VectorLoader loader = null; while (reader.next()) { @@ -555,7 +555,7 @@ private void echo(CallContext context, FlightStream reader, ServerStreamListener } /** Accept a set of messages, then return some result. */ - private void transform(CallContext context, FlightStream reader, ServerStreamListener writer) { + private void transform(CallContext unusedContext, FlightStream reader, ServerStreamListener writer) { final Schema schema = reader.getSchema(); for (final Field field : schema.getFields()) { if (!(field.getType() instanceof ArrowType.Int)) { @@ -597,11 +597,11 @@ private void transform(CallContext context, FlightStream reader, ServerStreamLis } /** Immediately cancel the call. */ - private void cancel(CallContext context, FlightStream reader, ServerStreamListener writer) { + private void cancel(CallContext unusedContext, FlightStream unusedReader, ServerStreamListener writer) { writer.error(CallStatus.CANCELLED.withDescription("expected").toRuntimeException()); } - private void error(CallContext context, FlightStream reader, ServerStreamListener writer) { + private void error(CallContext unusedContext, FlightStream reader, ServerStreamListener writer) { VectorSchemaRoot root = null; VectorLoader loader = null; while (reader.next()) { diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java index 1987d98196e9d..4ec7301466228 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java @@ -20,6 +20,8 @@ import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST; import static org.apache.arrow.flight.Location.forGrpcInsecure; +import java.nio.charset.StandardCharsets; + import org.apache.arrow.flight.perf.impl.PerfOuterClass; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -53,7 +55,7 @@ public void testGrpcMetadata() throws Exception { .start(); final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { final CallStatus flightStatus = FlightTestUtil.assertCode(FlightStatusCode.CANCELLED, () -> { - FlightStream stream = client.getStream(new Ticket("abs".getBytes())); + FlightStream stream = client.getStream(new Ticket("abs".getBytes(StandardCharsets.UTF_8))); stream.next(); }); PerfOuterClass.Perf newPerf = null; diff --git a/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java similarity index 98% rename from java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java rename to java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java index 9010f2d4a98f0..2569d2ac2b384 100644 --- a/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java @@ -176,7 +176,7 @@ public void testProxyChannelWithClosedChannel() throws IOException, InterruptedE /** * Private class used for testing purposes that overrides service behavior. */ - private class TestServiceAdapter extends TestServiceGrpc.TestServiceImplBase { + private static class TestServiceAdapter extends TestServiceGrpc.TestServiceImplBase { /** * gRPC service that receives an empty object & returns and empty protobuf object. diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java index 0e4669f29ce43..de1b7750da3bf 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java @@ -21,6 +21,7 @@ import static org.apache.arrow.flight.Location.forGrpcInsecure; import static org.junit.jupiter.api.Assertions.fail; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.Optional; @@ -139,7 +140,7 @@ public void supportsNullSchemas() throws Exception public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { return new FlightInfo(null, descriptor, Collections.emptyList(), - 0, 0, false, IpcOption.DEFAULT, "foo".getBytes()); + 0, 0, false, IpcOption.DEFAULT, "foo".getBytes(StandardCharsets.UTF_8)); } }; @@ -149,7 +150,7 @@ public FlightInfo getFlightInfo(CallContext context, FlightInfo flightInfo = client.getInfo(FlightDescriptor.path("test")); Assertions.assertEquals(Optional.empty(), flightInfo.getSchemaOptional()); Assertions.assertEquals(new Schema(Collections.emptyList()), flightInfo.getSchema()); - Assertions.assertArrayEquals(flightInfo.getAppMetadata(), "foo".getBytes()); + Assertions.assertArrayEquals(flightInfo.getAppMetadata(), "foo".getBytes(StandardCharsets.UTF_8)); Exception e = Assertions.assertThrows( FlightRuntimeException.class, diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java index 1feb6afcf8f05..430dc29a7d0c2 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java @@ -17,6 +17,7 @@ package org.apache.arrow.flight; +import static com.google.common.collect.ImmutableList.toImmutableList; import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST; import static org.apache.arrow.flight.Location.forGrpcInsecure; @@ -89,7 +90,7 @@ private static VectorSchemaRoot generateData(BufferAllocator allocator) { final Stream fields = fieldNames .stream() .map(fieldName -> new Field(fieldName, FieldType.nullable(new ArrowType.Int(32, true)), null)); - final Schema schema = new Schema(fields::iterator, null); + final Schema schema = new Schema(fields.collect(toImmutableList()), null); final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); root.allocateNew(); diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java index 0e19468d2b409..3bc8f2f90a612 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java @@ -38,15 +38,13 @@ public class TestServerMiddleware { - private static final RuntimeException EXPECTED_EXCEPTION = new RuntimeException("test"); - /** * Make sure errors in DoPut are intercepted. */ @Test public void doPutErrors() { test( - new ErrorProducer(EXPECTED_EXCEPTION), + new ErrorProducer(new RuntimeException("test")), (allocator, client) -> { final FlightDescriptor descriptor = FlightDescriptor.path("test"); try (final VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) { @@ -91,7 +89,7 @@ public void doPutCustomCode() { */ @Test public void doPutUncaught() { - test(new ServerErrorProducer(EXPECTED_EXCEPTION), + test(new ServerErrorProducer(new RuntimeException("test")), (allocator, client) -> { final FlightDescriptor descriptor = FlightDescriptor.path("test"); try (final VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) { @@ -106,13 +104,13 @@ public void doPutUncaught() { Assertions.assertEquals(FlightStatusCode.OK, status.code()); Assertions.assertNull(status.cause()); Assertions.assertNotNull(err); - Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + Assertions.assertEquals("test", err.getMessage()); }); } @Test public void listFlightsUncaught() { - test(new ServerErrorProducer(EXPECTED_EXCEPTION), + test(new ServerErrorProducer(new RuntimeException("test")), (allocator, client) -> client.listFlights(new Criteria(new byte[0])).forEach((action) -> { }), (recorder) -> { final CallStatus status = recorder.statusFuture.get(); @@ -121,13 +119,13 @@ public void listFlightsUncaught() { Assertions.assertEquals(FlightStatusCode.OK, status.code()); Assertions.assertNull(status.cause()); Assertions.assertNotNull(err); - Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + Assertions.assertEquals("test", err.getMessage()); }); } @Test public void doActionUncaught() { - test(new ServerErrorProducer(EXPECTED_EXCEPTION), + test(new ServerErrorProducer(new RuntimeException("test")), (allocator, client) -> client.doAction(new Action("test")).forEachRemaining(result -> { }), (recorder) -> { final CallStatus status = recorder.statusFuture.get(); @@ -136,13 +134,13 @@ public void doActionUncaught() { Assertions.assertEquals(FlightStatusCode.OK, status.code()); Assertions.assertNull(status.cause()); Assertions.assertNotNull(err); - Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + Assertions.assertEquals("test", err.getMessage()); }); } @Test public void listActionsUncaught() { - test(new ServerErrorProducer(EXPECTED_EXCEPTION), + test(new ServerErrorProducer(new RuntimeException("test")), (allocator, client) -> client.listActions().forEach(result -> { }), (recorder) -> { final CallStatus status = recorder.statusFuture.get(); @@ -151,13 +149,13 @@ public void listActionsUncaught() { Assertions.assertEquals(FlightStatusCode.OK, status.code()); Assertions.assertNull(status.cause()); Assertions.assertNotNull(err); - Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + Assertions.assertEquals("test", err.getMessage()); }); } @Test public void getFlightInfoUncaught() { - test(new ServerErrorProducer(EXPECTED_EXCEPTION), + test(new ServerErrorProducer(new RuntimeException("test")), (allocator, client) -> { FlightTestUtil.assertCode(FlightStatusCode.INTERNAL, () -> client.getInfo(FlightDescriptor.path("test"))); }, (recorder) -> { @@ -165,13 +163,13 @@ public void getFlightInfoUncaught() { Assertions.assertNotNull(status); Assertions.assertEquals(FlightStatusCode.INTERNAL, status.code()); Assertions.assertNotNull(status.cause()); - Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), status.cause().getMessage()); + Assertions.assertEquals(new RuntimeException("test").getMessage(), status.cause().getMessage()); }); } @Test public void doGetUncaught() { - test(new ServerErrorProducer(EXPECTED_EXCEPTION), + test(new ServerErrorProducer(new RuntimeException("test")), (allocator, client) -> { try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) { while (stream.next()) { @@ -186,7 +184,7 @@ public void doGetUncaught() { Assertions.assertEquals(FlightStatusCode.OK, status.code()); Assertions.assertNull(status.cause()); Assertions.assertNotNull(err); - Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + Assertions.assertEquals("test", err.getMessage()); }); } @@ -305,7 +303,7 @@ static class ServerMiddlewarePair { final FlightServerMiddleware.Key key; final FlightServerMiddleware.Factory factory; - ServerMiddlewarePair(Key key, Factory factory) { + ServerMiddlewarePair(FlightServerMiddleware.Key key, FlightServerMiddleware.Factory factory) { this.key = key; this.factory = factory; } @@ -339,7 +337,7 @@ static void test(FlightProducer producer, BiConsumer verify) { final ErrorRecorder.Factory factory = new ErrorRecorder.Factory(); final List> middleware = Collections - .singletonList(new ServerMiddlewarePair<>(Key.of("m"), factory)); + .singletonList(new ServerMiddlewarePair<>(FlightServerMiddleware.Key.of("m"), factory)); test(producer, middleware, (allocator, client) -> { body.accept(allocator, client); try { diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java index 0f6697a8e519c..75bc5f6e61589 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java @@ -27,7 +27,6 @@ import java.util.Iterator; import java.util.function.Consumer; -import org.apache.arrow.flight.FlightClient.Builder; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.junit.jupiter.api.Assertions; @@ -105,7 +104,7 @@ public void connectTlsDisableServerVerification() { }); } - void test(Consumer testFn) { + void test(Consumer testFn) { final FlightTestUtil.CertKeyPair certKey = FlightTestUtil.exampleTlsCerts().get(0); try ( BufferAllocator a = new RootAllocator(Long.MAX_VALUE); @@ -113,7 +112,8 @@ void test(Consumer testFn) { FlightServer s = FlightServer.builder(a, forGrpcInsecure(LOCALHOST, 0), producer) .useTls(certKey.cert, certKey.key) .build().start()) { - final Builder builder = FlightClient.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, s.getPort())); + final FlightClient.Builder builder = FlightClient.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, + s.getPort())); testFn.accept(builder); } catch (InterruptedException | IOException e) { throw new RuntimeException(e); diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java index 4b8a11870dab6..d34a3a2d3a51e 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java @@ -251,6 +251,7 @@ public ClientCookieMiddleware onCallStarted(CallInfo info) { private void startServerAndClient() throws IOException { final FlightProducer flightProducer = new NoOpFlightProducer() { + @Override public void listFlights(CallContext context, Criteria criteria, StreamListener listener) { listener.onCompleted(); diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java index 0ded2f7065f9c..b1e83ea61ed53 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import org.apache.arrow.flight.BackpressureStrategy; import org.apache.arrow.flight.FlightDescriptor; @@ -48,10 +49,7 @@ public class PerformanceTestServer implements AutoCloseable { - private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(PerformanceTestServer.class); - private final FlightServer flightServer; - private final Location location; private final BufferAllocator allocator; private final PerfProducer producer; private final boolean isNonBlocking; @@ -78,7 +76,6 @@ public WaitResult waitForListener(long timeout) { public PerformanceTestServer(BufferAllocator incomingAllocator, Location location, BackpressureStrategy bpStrategy, boolean isNonBlocking) { this.allocator = incomingAllocator.newChildAllocator("perf-server", 0, Long.MAX_VALUE); - this.location = location; this.producer = new PerfProducer(bpStrategy); this.flightServer = FlightServer.builder(this.allocator, location, producer).build(); this.isNonBlocking = isNonBlocking; @@ -110,16 +107,18 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { bpStrategy.register(listener); final Runnable loadData = () -> { - VectorSchemaRoot root = null; + Token token = null; try { - Token token = Token.parseFrom(ticket.getBytes()); - Perf perf = token.getDefinition(); - Schema schema = Schema.deserializeMessage(perf.getSchema().asReadOnlyByteBuffer()); - root = VectorSchemaRoot.create(schema, allocator); - BigIntVector a = (BigIntVector) root.getVector("a"); - BigIntVector b = (BigIntVector) root.getVector("b"); - BigIntVector c = (BigIntVector) root.getVector("c"); - BigIntVector d = (BigIntVector) root.getVector("d"); + token = Token.parseFrom(ticket.getBytes()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + Perf perf = token.getDefinition(); + Schema schema = Schema.deserializeMessage(perf.getSchema().asReadOnlyByteBuffer()); + try ( + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + BigIntVector a = (BigIntVector) root.getVector("a") + ) { listener.setUseZeroCopy(true); listener.start(root); root.allocateNew(); @@ -158,14 +157,6 @@ public void getStream(CallContext context, Ticket ticket, listener.putNext(); } listener.completed(); - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException(e); - } finally { - try { - AutoCloseables.close(root); - } catch (Exception e) { - throw new RuntimeException(e); - } } }; @@ -173,7 +164,7 @@ public void getStream(CallContext context, Ticket ticket, loadData.run(); } else { final ExecutorService service = Executors.newSingleThreadExecutor(); - service.submit(loadData); + Future unused = service.submit(loadData); service.shutdown(); } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java index 17c83c205feb0..290e82de36c57 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java @@ -110,14 +110,14 @@ public void throughput() throws Exception { double seconds = r.nanos * 1.0d / 1000 / 1000 / 1000; throughPuts[i] = (r.bytes * 1.0d / 1024 / 1024) / seconds; - System.out.println(String.format( - "Transferred %d records totaling %s bytes at %f MiB/s. %f record/s. %f batch/s.", + System.out.printf( + "Transferred %d records totaling %s bytes at %f MiB/s. %f record/s. %f batch/s.%n", r.rows, r.bytes, throughPuts[i], (r.rows * 1.0d) / seconds, (r.batches * 1.0d) / seconds - )); + ); } } pool.shutdown(); @@ -126,11 +126,11 @@ public void throughput() throws Exception { double average = Arrays.stream(throughPuts).sum() / numRuns; double sqrSum = Arrays.stream(throughPuts).map(val -> val - average).map(val -> val * val).sum(); double stddev = Math.sqrt(sqrSum / numRuns); - System.out.println(String.format("Average throughput: %f MiB/s, standard deviation: %f MiB/s", - average, stddev)); + System.out.printf("Average throughput: %f MiB/s, standard deviation: %f MiB/s%n", + average, stddev); } - private final class Consumer implements Callable { + private static final class Consumer implements Callable { private final FlightClient client; private final Ticket ticket; @@ -157,7 +157,7 @@ public Result call() throws Exception { aSum += a.get(i); } } - r.bytes += rows * 32; + r.bytes += rows * 32L; r.rows += rows; r.aSum = aSum; r.batches++; @@ -173,7 +173,7 @@ public Result call() throws Exception { } - private final class Result { + private static final class Result { private long rows; private long aSum; private long bytes; diff --git a/java/flight/flight-grpc/src/test/protobuf/test.proto b/java/flight/flight-core/src/test/protobuf/test.proto similarity index 100% rename from java/flight/flight-grpc/src/test/protobuf/test.proto rename to java/flight/flight-core/src/test/protobuf/test.proto diff --git a/java/flight/flight-grpc/pom.xml b/java/flight/flight-grpc/pom.xml deleted file mode 100644 index af765f8c436be..0000000000000 --- a/java/flight/flight-grpc/pom.xml +++ /dev/null @@ -1,123 +0,0 @@ - - - - - arrow-flight - org.apache.arrow - 15.0.0-SNAPSHOT - ../pom.xml - - 4.0.0 - - flight-grpc - Arrow Flight GRPC - (Experimental)Contains utility class to expose Flight gRPC service and client - jar - - - 1 - - - - - org.apache.arrow - flight-core - - - io.netty - netty-transport-native-unix-common - - - io.netty - netty-transport-native-kqueue - - - io.netty - netty-transport-native-epoll - - - - - io.grpc - grpc-stub - - - io.grpc - grpc-inprocess - test - - - org.apache.arrow - arrow-memory-core - compile - - - org.apache.arrow - arrow-memory-netty - runtime - - - io.grpc - grpc-protobuf - - - com.google.guava - guava - - - com.google.protobuf - protobuf-java - - - io.grpc - grpc-api - - - - - - - - kr.motd.maven - os-maven-plugin - 1.7.0 - - - - - org.xolstice.maven.plugins - protobuf-maven-plugin - 0.6.1 - - com.google.protobuf:protoc:${dep.protobuf-bom.version}:exe:${os.detected.classifier} - false - grpc-java - io.grpc:protoc-gen-grpc-java:${dep.grpc-bom.version}:exe:${os.detected.classifier} - - - - test - - ${basedir}/src/test/protobuf - ${project.build.directory}/generated-test-sources//protobuf - - - compile - compile-custom - - - - - - - - diff --git a/java/flight/flight-grpc/src/test/resources/logback.xml b/java/flight/flight-grpc/src/test/resources/logback.xml deleted file mode 100644 index 4c54d18a210ff..0000000000000 --- a/java/flight/flight-grpc/src/test/resources/logback.xml +++ /dev/null @@ -1,28 +0,0 @@ - - - - - - - - - %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n - - - - - - - - - diff --git a/java/flight/flight-integration-tests/pom.xml b/java/flight/flight-integration-tests/pom.xml index bb4f6a6b18733..944c624d630a2 100644 --- a/java/flight/flight-integration-tests/pom.xml +++ b/java/flight/flight-integration-tests/pom.xml @@ -15,7 +15,7 @@ arrow-flight org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT ../pom.xml diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java index 2ea9874f3dec3..64b5882c0f50d 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java @@ -98,14 +98,14 @@ private void run(String[] args) throws Exception { Scenarios.getScenario(cmd.getOptionValue("scenario")).client(allocator, defaultLocation, client); } else { final String inputPath = cmd.getOptionValue("j"); - testStream(allocator, defaultLocation, client, inputPath); + testStream(allocator, client, inputPath); } } catch (InterruptedException e) { throw new RuntimeException(e); } } - private static void testStream(BufferAllocator allocator, Location server, FlightClient client, String inputPath) + private static void testStream(BufferAllocator allocator, FlightClient client, String inputPath) throws IOException { // 1. Read data from JSON and upload to server. FlightDescriptor descriptor = FlightDescriptor.path(inputPath); diff --git a/java/flight/flight-sql-jdbc-core/pom.xml b/java/flight/flight-sql-jdbc-core/pom.xml index 1f20912b9974f..ce1f52e39676e 100644 --- a/java/flight/flight-sql-jdbc-core/pom.xml +++ b/java/flight/flight-sql-jdbc-core/pom.xml @@ -16,7 +16,7 @@ arrow-flight org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT ../pom.xml 4.0.0 diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriver.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriver.java index aa1b460fc136a..183e3d5c7b055 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriver.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriver.java @@ -43,6 +43,7 @@ import org.apache.calcite.avatica.Meta; import org.apache.calcite.avatica.UnregisteredDriver; + /** * JDBC driver for querying data from an Apache Arrow Flight server. */ @@ -99,6 +100,7 @@ protected String getFactoryClassName(final JdbcVersion jdbcVersion) { } @Override + @SuppressWarnings("StringSplitter") protected DriverVersion createDriverVersion() { if (version == null) { final InputStream flightProperties = this.getClass().getResourceAsStream("/properties/flight.properties"); diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java index 382750914992f..d25f03ac27b48 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -255,11 +255,6 @@ private static final class StatementHandleKey { public final String connectionId; public final int id; - StatementHandleKey(String connectionId, int id) { - this.connectionId = connectionId; - this.id = id; - } - StatementHandleKey(StatementHandle statementHandle) { this.connectionId = statementHandle.connectionId; this.id = statementHandle.id; diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessor.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessor.java index aea9b75fa6c3f..8d2fe1cc70319 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessor.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessor.java @@ -46,59 +46,57 @@ public class ArrowFlightJdbcBaseIntVectorAccessor extends ArrowFlightJdbcAccesso private final MinorType type; private final boolean isUnsigned; - private final int bytesToAllocate; private final Getter getter; private final NumericHolder holder; public ArrowFlightJdbcBaseIntVectorAccessor(UInt1Vector vector, IntSupplier currentRowSupplier, ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { - this(vector, currentRowSupplier, true, UInt1Vector.TYPE_WIDTH, setCursorWasNull); + this(vector, currentRowSupplier, true, setCursorWasNull); } public ArrowFlightJdbcBaseIntVectorAccessor(UInt2Vector vector, IntSupplier currentRowSupplier, ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { - this(vector, currentRowSupplier, true, UInt2Vector.TYPE_WIDTH, setCursorWasNull); + this(vector, currentRowSupplier, true, setCursorWasNull); } public ArrowFlightJdbcBaseIntVectorAccessor(UInt4Vector vector, IntSupplier currentRowSupplier, ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { - this(vector, currentRowSupplier, true, UInt4Vector.TYPE_WIDTH, setCursorWasNull); + this(vector, currentRowSupplier, true, setCursorWasNull); } public ArrowFlightJdbcBaseIntVectorAccessor(UInt8Vector vector, IntSupplier currentRowSupplier, ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { - this(vector, currentRowSupplier, true, UInt8Vector.TYPE_WIDTH, setCursorWasNull); + this(vector, currentRowSupplier, true, setCursorWasNull); } public ArrowFlightJdbcBaseIntVectorAccessor(TinyIntVector vector, IntSupplier currentRowSupplier, ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { - this(vector, currentRowSupplier, false, TinyIntVector.TYPE_WIDTH, setCursorWasNull); + this(vector, currentRowSupplier, false, setCursorWasNull); } public ArrowFlightJdbcBaseIntVectorAccessor(SmallIntVector vector, IntSupplier currentRowSupplier, ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { - this(vector, currentRowSupplier, false, SmallIntVector.TYPE_WIDTH, setCursorWasNull); + this(vector, currentRowSupplier, false, setCursorWasNull); } public ArrowFlightJdbcBaseIntVectorAccessor(IntVector vector, IntSupplier currentRowSupplier, ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { - this(vector, currentRowSupplier, false, IntVector.TYPE_WIDTH, setCursorWasNull); + this(vector, currentRowSupplier, false, setCursorWasNull); } public ArrowFlightJdbcBaseIntVectorAccessor(BigIntVector vector, IntSupplier currentRowSupplier, ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { - this(vector, currentRowSupplier, false, BigIntVector.TYPE_WIDTH, setCursorWasNull); + this(vector, currentRowSupplier, false, setCursorWasNull); } private ArrowFlightJdbcBaseIntVectorAccessor(BaseIntVector vector, IntSupplier currentRowSupplier, - boolean isUnsigned, int bytesToAllocate, - ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + boolean isUnsigned, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { super(currentRowSupplier, setCursorWasNull); this.type = vector.getMinorType(); this.holder = new NumericHolder(); this.getter = createGetter(vector); this.isUnsigned = isUnsigned; - this.bytesToAllocate = bytesToAllocate; } @Override diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessor.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessor.java index f55fd12f9a517..67d98c2e69847 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessor.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessor.java @@ -32,7 +32,6 @@ public class ArrowFlightJdbcBitVectorAccessor extends ArrowFlightJdbcAccessor { private final BitVector vector; private final NullableBitHolder holder; - private static final int BYTES_T0_ALLOCATE = 1; /** * Constructor for the BitVectorAccessor. diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java index 6237a8b58d68a..e95cf00bc7a21 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java @@ -33,6 +33,7 @@ import org.apache.calcite.avatica.ConnectionConfigImpl; import org.apache.calcite.avatica.ConnectionProperty; + /** * A {@link ConnectionConfig} for the {@link ArrowFlightConnection}. */ diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapper.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapper.java index 5ee43ce012e94..c28071490caa6 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapper.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapper.java @@ -248,13 +248,13 @@ public PreparedStatement prepareStatement(final String sqlQuery, final int autoG } @Override - public PreparedStatement prepareStatement(final String sqlQuery, final int... columnIndices) + public PreparedStatement prepareStatement(final String sqlQuery, final int[] columnIndices) throws SQLException { return realConnection.prepareStatement(sqlQuery, columnIndices); } @Override - public PreparedStatement prepareStatement(final String sqlQuery, final String... columnNames) + public PreparedStatement prepareStatement(final String sqlQuery, final String[] columnNames) throws SQLException { return realConnection.prepareStatement(sqlQuery, columnNames); } @@ -306,12 +306,12 @@ public Properties getClientInfo() throws SQLException { } @Override - public Array createArrayOf(final String typeName, final Object... elements) throws SQLException { + public Array createArrayOf(final String typeName, final Object[] elements) throws SQLException { return realConnection.createArrayOf(typeName, elements); } @Override - public Struct createStruct(final String typeName, final Object... attributes) + public Struct createStruct(final String typeName, final Object[] attributes) throws SQLException { return realConnection.createStruct(typeName, attributes); } diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java index b21b03340e9f9..843fe0cb89d9f 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java @@ -49,7 +49,6 @@ import org.apache.calcite.avatica.AvaticaParameter; import org.apache.calcite.avatica.ColumnMetaData; import org.apache.calcite.avatica.proto.Common; -import org.apache.calcite.avatica.proto.Common.ColumnMetaData.Builder; /** * Convert objects between Arrow and Avatica. @@ -71,7 +70,7 @@ public static List convertArrowFieldsToColumnMetaDataList(final final Field field = fields.get(index); final ArrowType fieldType = field.getType(); - final Builder builder = Common.ColumnMetaData.newBuilder() + final Common.ColumnMetaData.Builder builder = Common.ColumnMetaData.newBuilder() .setOrdinal(index) .setColumnName(field.getName()) .setLabel(field.getName()); @@ -90,10 +89,10 @@ public static List convertArrowFieldsToColumnMetaDataList(final /** * Set on Column MetaData Builder. * - * @param builder {@link Builder} + * @param builder {@link Common.ColumnMetaData.Builder} * @param metadataMap {@link Map} */ - public static void setOnColumnMetaDataBuilder(final Builder builder, + public static void setOnColumnMetaDataBuilder(final Common.ColumnMetaData.Builder builder, final Map metadataMap) { final FlightSqlColumnMetadata columnMetadata = new FlightSqlColumnMetadata(metadataMap); final String catalogName = columnMetadata.getCatalogName(); diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/UrlParser.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/UrlParser.java index e52251f53918a..64255e2213a1a 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/UrlParser.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/UrlParser.java @@ -21,10 +21,10 @@ import java.net.URLDecoder; import java.util.HashMap; import java.util.Map; - /** * URL Parser for extracting key values from a connection string. */ + public final class UrlParser { private UrlParser() { } @@ -37,6 +37,7 @@ private UrlParser() { * @param url {@link String} * @return {@link Map} */ + @SuppressWarnings("StringSplitter") public static Map parse(String url, String separator) { Map resultMap = new HashMap<>(); if (url != null) { diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArrayTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArrayTest.java index 90c926612f15a..76f01514c9501 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArrayTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArrayTest.java @@ -140,7 +140,7 @@ public void testShouldGetResultSetReturnValidResultSetWithOffsets() throws SQLEx Assert.assertEquals((Object) resultSet.getInt(1), dataVector.getObject(count + 3)); count++; } - Assert.assertEquals(count, 5); + Assert.assertEquals(5, count); } } diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java index 784fd5b292b27..e1f64c9dd8732 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java @@ -181,6 +181,8 @@ public void testConnectWithInsensitiveCasePropertyKeys2() throws Exception { /** * Tests whether an exception is thrown upon attempting to connect to a * malformed URI. + * + * @throws SQLException If an error occurs. */ @Test(expected = SQLException.class) public void testShouldThrowExceptionWhenAttemptingToConnectToMalformedUrl() throws SQLException { @@ -194,7 +196,7 @@ public void testShouldThrowExceptionWhenAttemptingToConnectToMalformedUrl() thro * Tests whether an exception is thrown upon attempting to connect to a * malformed URI. * - * @throws Exception If an error occurs. + * @throws SQLException If an error occurs. */ @Test(expected = SQLException.class) public void testShouldThrowExceptionWhenAttemptingToConnectToUrlNoPrefix() throws SQLException { diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionMutualTlsTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionMutualTlsTest.java index cc44cc57be9b3..03f15d77ade11 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionMutualTlsTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionMutualTlsTest.java @@ -50,8 +50,6 @@ public class ConnectionMutualTlsTest { @ClassRule public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE; private static final String tlsRootCertsPath; - - private static final String serverMTlsCACertPath; private static final String clientMTlsCertPath; private static final String badClientMTlsCertPath; private static final String clientMTlsKeyPath; @@ -68,8 +66,6 @@ public class ConnectionMutualTlsTest { final File serverMTlsCACert = FlightSqlTestCertificates.exampleCACert(); - serverMTlsCACertPath = serverMTlsCACert.getPath(); - final FlightSqlTestCertificates.CertKeyPair clientMTlsCertKey = FlightSqlTestCertificates.exampleTlsCerts().get(1); diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java index 231371a923a28..0e3e015a04636 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java @@ -94,10 +94,6 @@ private static void resultSetNextUntilDone(ResultSet resultSet) throws SQLExcept } } - private static void setMaxRowsLimit(int maxRowsLimit, Statement statement) throws SQLException { - statement.setLargeMaxRows(maxRowsLimit); - } - /** * Tests whether the {@link ArrowFlightJdbcDriver} can run a query successfully. * @@ -411,9 +407,9 @@ public void testPartitionedFlightServer() throws Exception { // Construct the data-only nodes first. FlightProducer firstProducer = new PartitionedFlightSqlProducer.DataOnlyFlightSqlProducer( - new Ticket("first".getBytes()), firstPartition); + new Ticket("first".getBytes(StandardCharsets.UTF_8)), firstPartition); FlightProducer secondProducer = new PartitionedFlightSqlProducer.DataOnlyFlightSqlProducer( - new Ticket("second".getBytes()), secondPartition); + new Ticket("second".getBytes(StandardCharsets.UTF_8)), secondPartition); final FlightServer.Builder firstBuilder = FlightServer.builder( allocator, forGrpcInsecure("localhost", 0), firstProducer); @@ -427,10 +423,10 @@ public void testPartitionedFlightServer() throws Exception { firstServer.start(); secondServer.start(); final FlightEndpoint firstEndpoint = - new FlightEndpoint(new Ticket("first".getBytes()), firstServer.getLocation()); + new FlightEndpoint(new Ticket("first".getBytes(StandardCharsets.UTF_8)), firstServer.getLocation()); final FlightEndpoint secondEndpoint = - new FlightEndpoint(new Ticket("second".getBytes()), secondServer.getLocation()); + new FlightEndpoint(new Ticket("second".getBytes(StandardCharsets.UTF_8)), secondServer.getLocation()); // Finally start the root node. try (final PartitionedFlightSqlProducer rootProducer = new PartitionedFlightSqlProducer( diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorTest.java index 099b0122179f1..a9b5c46e01e9b 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorTest.java @@ -123,7 +123,7 @@ public void testShouldGetObjectWithBooleanClassReturnGetBoolean() throws SQLExce when(accessor.getObject(Boolean.class)).thenCallRealMethod(); - Assert.assertEquals(accessor.getObject(Boolean.class), true); + Assert.assertEquals(true, accessor.getObject(Boolean.class)); verify(accessor).getBoolean(); } @@ -134,7 +134,7 @@ public void testShouldGetObjectWithBigDecimalClassReturnGetBigDecimal() throws S when(accessor.getObject(BigDecimal.class)).thenCallRealMethod(); - Assert.assertEquals(accessor.getObject(BigDecimal.class), expected); + Assert.assertEquals(expected, accessor.getObject(BigDecimal.class)); verify(accessor).getBigDecimal(); } @@ -145,7 +145,7 @@ public void testShouldGetObjectWithStringClassReturnGetString() throws SQLExcept when(accessor.getObject(String.class)).thenCallRealMethod(); - Assert.assertEquals(accessor.getObject(String.class), expected); + Assert.assertEquals(expected, accessor.getObject(String.class)); verify(accessor).getString(); } @@ -167,7 +167,7 @@ public void testShouldGetObjectWithObjectClassReturnGetObject() throws SQLExcept when(accessor.getObject(Object.class)).thenCallRealMethod(); - Assert.assertEquals(accessor.getObject(Object.class), expected); + Assert.assertEquals(expected, accessor.getObject(Object.class)); verify(accessor).getObject(); } diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessorTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessorTest.java index 38d842724b9c1..e2c17b2f085ae 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessorTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessorTest.java @@ -179,7 +179,7 @@ public void testShouldGetTimestampReturnValidTimestampWithCalendar() throws Exce final Timestamp resultWithoutCalendar = accessor.getTimestamp(null); final Timestamp result = accessor.getTimestamp(calendar); - long offset = timeZone.getOffset(resultWithoutCalendar.getTime()) - + long offset = (long) timeZone.getOffset(resultWithoutCalendar.getTime()) - timeZoneForVector.getOffset(resultWithoutCalendar.getTime()); collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); @@ -212,7 +212,7 @@ public void testShouldGetDateReturnValidDateWithCalendar() throws Exception { final Date resultWithoutCalendar = accessor.getDate(null); final Date result = accessor.getDate(calendar); - long offset = timeZone.getOffset(resultWithoutCalendar.getTime()) - + long offset = (long) timeZone.getOffset(resultWithoutCalendar.getTime()) - timeZoneForVector.getOffset(resultWithoutCalendar.getTime()); collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); @@ -245,7 +245,7 @@ public void testShouldGetTimeReturnValidTimeWithCalendar() throws Exception { final Time resultWithoutCalendar = accessor.getTime(null); final Time result = accessor.getTime(calendar); - long offset = timeZone.getOffset(resultWithoutCalendar.getTime()) - + long offset = (long) timeZone.getOffset(resultWithoutCalendar.getTime()) - timeZoneForVector.getOffset(resultWithoutCalendar.getTime()); collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListAccessorTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListAccessorTest.java index b2eb8f1dbee8f..e958fb60ba41e 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListAccessorTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListAccessorTest.java @@ -114,7 +114,7 @@ public void testShouldGetObjectReturnValidList() throws Exception { accessorIterator.assertAccessorGetter(vector, AbstractArrowFlightJdbcListVectorAccessor::getObject, (accessor, currentRow) -> equalTo( - Arrays.asList(0, (currentRow), (currentRow) * 2, (currentRow) * 3, (currentRow) * 4))); + Arrays.asList(0, currentRow, currentRow * 2, currentRow * 3, currentRow * 4))); } @Test @@ -137,7 +137,7 @@ public void testShouldGetArrayReturnValidArray() throws Exception { Object[] arrayObject = (Object[]) array.getArray(); collector.checkThat(arrayObject, equalTo( - new Object[] {0, currentRow, (currentRow) * 2, (currentRow) * 3, (currentRow) * 4})); + new Object[] {0, currentRow, currentRow * 2, currentRow * 3, currentRow * 4})); }); } @@ -161,7 +161,7 @@ public void testShouldGetArrayReturnValidArrayPassingOffsets() throws Exception Object[] arrayObject = (Object[]) array.getArray(1, 3); collector.checkThat(arrayObject, equalTo( - new Object[] {currentRow, (currentRow) * 2, (currentRow) * 3})); + new Object[] {currentRow, currentRow * 2, currentRow * 3})); }); } diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessorTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessorTest.java index b3c85fc0ab1f3..735fe9f40ba0e 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessorTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessorTest.java @@ -202,8 +202,8 @@ public void testShouldGetObjectWorkWithNestedComplexData() throws SQLException { new ArrowFlightJdbcStructVectorAccessor(rootVector, () -> 0, (boolean wasNull) -> { }); - Assert.assertEquals(accessor.getObject(), expected); - Assert.assertEquals(accessor.getString(), expected.toString()); + Assert.assertEquals(expected, accessor.getObject()); + Assert.assertEquals(expected.toString(), accessor.getString()); } } } diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessorTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessorTest.java index 809d6e8d35386..00537bfa028e9 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessorTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessorTest.java @@ -68,7 +68,7 @@ private void iterate(final CheckedFunction is(arrayToAssert[currentRow] ? result : resultIfFalse)) + (accessor, currentRow) -> is(arrayToAssert[currentRow] ? result : resultIfFalse) ); } diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java index c165bfb7ce336..52a397edab18f 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java @@ -159,7 +159,7 @@ public void addSelectQuery(final String sqlCommand, final Schema schema, * @param updatedRows the number of rows affected. */ public void addUpdateQuery(final String sqlCommand, final long updatedRows) { - addUpdateQuery(sqlCommand, ((flightStream, putResultStreamListener) -> { + addUpdateQuery(sqlCommand, (flightStream, putResultStreamListener) -> { final DoPutUpdateResult result = DoPutUpdateResult.newBuilder().setRecordCount(updatedRows).build(); try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); @@ -171,7 +171,7 @@ public void addUpdateQuery(final String sqlCommand, final long updatedRows) { } finally { putResultStreamListener.onCompleted(); } - })); + }); } /** diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestRule.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestRule.java index a200fc8d39c15..fd8fb57fcafde 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestRule.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestRule.java @@ -18,6 +18,7 @@ package org.apache.arrow.driver.jdbc.utils; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.util.Random; import java.util.concurrent.TimeUnit; import java.util.stream.IntStream; @@ -456,9 +457,9 @@ public VarBinaryVector createVarBinaryVector() { public VarBinaryVector createVarBinaryVector(final String fieldName) { VarBinaryVector valueVector = new VarBinaryVector(fieldName, this.getRootAllocator()); valueVector.allocateNew(3); - valueVector.setSafe(0, (fieldName + "__BINARY_DATA_0001").getBytes()); - valueVector.setSafe(1, (fieldName + "__BINARY_DATA_0002").getBytes()); - valueVector.setSafe(2, (fieldName + "__BINARY_DATA_0003").getBytes()); + valueVector.setSafe(0, (fieldName + "__BINARY_DATA_0001").getBytes(StandardCharsets.UTF_8)); + valueVector.setSafe(1, (fieldName + "__BINARY_DATA_0002").getBytes(StandardCharsets.UTF_8)); + valueVector.setSafe(2, (fieldName + "__BINARY_DATA_0003").getBytes(StandardCharsets.UTF_8)); valueVector.setValueCount(3); return valueVector; @@ -472,9 +473,9 @@ public VarBinaryVector createVarBinaryVector(final String fieldName) { public LargeVarBinaryVector createLargeVarBinaryVector() { LargeVarBinaryVector valueVector = new LargeVarBinaryVector("", this.getRootAllocator()); valueVector.allocateNew(3); - valueVector.setSafe(0, "BINARY_DATA_0001".getBytes()); - valueVector.setSafe(1, "BINARY_DATA_0002".getBytes()); - valueVector.setSafe(2, "BINARY_DATA_0003".getBytes()); + valueVector.setSafe(0, "BINARY_DATA_0001".getBytes(StandardCharsets.UTF_8)); + valueVector.setSafe(1, "BINARY_DATA_0002".getBytes(StandardCharsets.UTF_8)); + valueVector.setSafe(2, "BINARY_DATA_0003".getBytes(StandardCharsets.UTF_8)); valueVector.setValueCount(3); return valueVector; @@ -488,9 +489,9 @@ public LargeVarBinaryVector createLargeVarBinaryVector() { public FixedSizeBinaryVector createFixedSizeBinaryVector() { FixedSizeBinaryVector valueVector = new FixedSizeBinaryVector("", this.getRootAllocator(), 16); valueVector.allocateNew(3); - valueVector.setSafe(0, "BINARY_DATA_0001".getBytes()); - valueVector.setSafe(1, "BINARY_DATA_0002".getBytes()); - valueVector.setSafe(2, "BINARY_DATA_0003".getBytes()); + valueVector.setSafe(0, "BINARY_DATA_0001".getBytes(StandardCharsets.UTF_8)); + valueVector.setSafe(1, "BINARY_DATA_0002".getBytes(StandardCharsets.UTF_8)); + valueVector.setSafe(2, "BINARY_DATA_0003".getBytes(StandardCharsets.UTF_8)); valueVector.setValueCount(3); return valueVector; diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ThrowableAssertionUtils.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ThrowableAssertionUtils.java index f1bd44539ac58..48334dc0f92e2 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ThrowableAssertionUtils.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ThrowableAssertionUtils.java @@ -25,7 +25,7 @@ public class ThrowableAssertionUtils { private ThrowableAssertionUtils() { } - public static void simpleAssertThrowableClass( + public static void simpleAssertThrowableClass( final Class expectedThrowable, final ThrowingRunnable runnable) { try { runnable.run(); diff --git a/java/flight/flight-sql-jdbc-driver/pom.xml b/java/flight/flight-sql-jdbc-driver/pom.xml index 653ee5c192756..28534a9b0badd 100644 --- a/java/flight/flight-sql-jdbc-driver/pom.xml +++ b/java/flight/flight-sql-jdbc-driver/pom.xml @@ -16,7 +16,7 @@ arrow-flight org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT ../pom.xml 4.0.0 @@ -107,7 +107,7 @@ joda-time joda-time - 2.12.5 + 2.12.6 runtime diff --git a/java/flight/flight-sql/pom.xml b/java/flight/flight-sql/pom.xml index 3c7e4b3495e5a..a0598f70b9545 100644 --- a/java/flight/flight-sql/pom.xml +++ b/java/flight/flight-sql/pom.xml @@ -14,7 +14,7 @@ arrow-flight org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT ../pom.xml @@ -50,6 +50,10 @@ org.apache.arrow arrow-memory-core + + org.immutables + value + org.apache.arrow arrow-jdbc @@ -110,4 +114,27 @@ + + + jdk11+ + + [11,] + + !m2e.version + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + --add-reads=org.apache.arrow.flight.sql=org.slf4j --add-reads=org.apache.arrow.flight.core=ALL-UNNAMED --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED + + + + + + + diff --git a/java/flight/flight-sql/src/main/java/module-info.java b/java/flight/flight-sql/src/main/java/module-info.java new file mode 100644 index 0000000000000..5514d5b870afd --- /dev/null +++ b/java/flight/flight-sql/src/main/java/module-info.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +module org.apache.arrow.flight.sql { + exports org.apache.arrow.flight.sql; + exports org.apache.arrow.flight.sql.example; + exports org.apache.arrow.flight.sql.util; + + requires com.google.common; + requires java.sql; + requires org.apache.arrow.flight.core; + requires org.apache.arrow.memory.core; + requires org.apache.arrow.vector; + requires protobuf.java; +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java index e2d79129c1fc9..dbe39ab1d07b4 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java @@ -433,7 +433,7 @@ default void cancelFlightInfo(CancelFlightInfoRequest request, CallContext conte * @param info The FlightInfo of the query to cancel. * @param context Per-call context. * @param listener Whether cancellation succeeded. - * @deprecated Prefer {@link #cancelFlightInfo(FlightInfo, CallContext, StreamListener)}. + * @deprecated Prefer {@link #cancelFlightInfo(CancelFlightInfoRequest, CallContext, StreamListener)}. */ @Deprecated default void cancelQuery(FlightInfo info, CallContext context, StreamListener listener) { diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java index 18793f9b905fe..338a60e2ae6df 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java @@ -17,7 +17,6 @@ package org.apache.arrow.flight.sql; -import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.stream.IntStream.range; import static org.apache.arrow.flight.FlightProducer.ServerStreamListener; import static org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedTransaction; @@ -32,6 +31,7 @@ import java.util.function.ObjIntConsumer; import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlNullOrdering; import org.apache.arrow.flight.sql.impl.FlightSql.SqlOuterJoinsSupportLevel; import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity; import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedElementActions; @@ -80,7 +80,7 @@ public class SqlInfoBuilder { * @return a new {@link NullableVarCharHolder} with the provided input data {@code string}. */ public static NullableVarCharHolder getHolderForUtf8(final String string, final ArrowBuf buf) { - final byte[] bytes = string.getBytes(UTF_8); + final byte[] bytes = string.getBytes(StandardCharsets.UTF_8); buf.setBytes(0, bytes); final NullableVarCharHolder holder = new NullableVarCharHolder(); holder.buffer = buf; @@ -502,6 +502,26 @@ public SqlInfoBuilder withSqlQuotedIdentifierCase(final SqlSupportedCaseSensitiv return withBitIntProvider(SqlInfo.SQL_QUOTED_IDENTIFIER_CASE_VALUE, value.getNumber()); } + /** + * Sets a value for {@link SqlInfo#SQL_ALL_TABLES_ARE_SELECTABLE} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_ALL_TABLES_ARE_SELECTABLE} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlAllTablesAreSelectable(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_ALL_TABLES_ARE_SELECTABLE_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_NULL_ORDERING} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_NULL_ORDERING} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlNullOrdering(final SqlNullOrdering value) { + return withBitIntProvider(SqlInfo.SQL_NULL_ORDERING_VALUE, value.getNumber()); + } + /** * Sets a value SqlInf @link SqlInfo#SQL_MAX_BINARY_LITERAL_LENGTH} in the builder. * @@ -572,6 +592,16 @@ public SqlInfoBuilder withSqlMaxColumnsInSelect(final long value) { return withBitIntProvider(SqlInfo.SQL_MAX_COLUMNS_IN_SELECT_VALUE, value); } + /** + * Sets a value for {@link SqlInfo#SQL_MAX_COLUMNS_IN_TABLE} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_COLUMNS_IN_TABLE} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxColumnsInTable(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_COLUMNS_IN_TABLE_VALUE, value); + } + /** * Sets a value for {@link SqlInfo#SQL_MAX_CONNECTIONS} in the builder. * @@ -1020,7 +1050,7 @@ private void setDataVarCharListField(final VectorSchemaRoot root, final int inde final int length = values.length; range(0, length) .forEach(i -> onCreateArrowBuf(buf -> { - final byte[] bytes = values[i].getBytes(UTF_8); + final byte[] bytes = values[i].getBytes(StandardCharsets.UTF_8); buf.setBytes(0, bytes); writer.writeVarChar(0, bytes.length, buf); })); diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java index 3cc8f4a1c1bd5..1d43728b789f5 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java @@ -69,6 +69,7 @@ import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -237,7 +238,10 @@ public FlightSqlExample(final Location location) { SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UPPERCASE : metaData.storesLowerCaseQuotedIdentifiers() ? SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_LOWERCASE : - SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UNKNOWN); + SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UNKNOWN) + .withSqlAllTablesAreSelectable(true) + .withSqlNullOrdering(SqlNullOrdering.SQL_NULLS_SORTED_AT_END) + .withSqlMaxColumnsInTable(42); } catch (SQLException e) { throw new RuntimeException(e); } @@ -379,6 +383,7 @@ private static int saveToVectors(final Map ve return saveToVectors(vectorToColumnName, data, emptyToNull, alwaysTrue); } + @SuppressWarnings("StringSplitter") private static int saveToVectors(final Map vectorToColumnName, final ResultSet data, boolean emptyToNull, Predicate resultSetPredicate) @@ -509,7 +514,7 @@ private static VectorSchemaRoot getTypeInfoRoot(CommandGetXdbcTypeInfo request, } }; } else { - predicate = (resultSet -> true); + predicate = resultSet -> true; } int rows = saveToVectors(mapper, typeInfo, true, predicate); @@ -682,7 +687,7 @@ public void getStreamPreparedStatement(final CommandPreparedStatementQuery comma public void closePreparedStatement(final ActionClosePreparedStatementRequest request, final CallContext context, final StreamListener listener) { // Running on another thread - executorService.submit(() -> { + Future unused = executorService.submit(() -> { try { preparedStatementLoadingCache.invalidate(request.getPreparedStatementHandle()); } catch (final Exception e) { @@ -771,7 +776,7 @@ public void listFlights(CallContext context, Criteria criteria, StreamListener listener) { // Running on another thread - executorService.submit(() -> { + Future unused = executorService.submit(() -> { try { final ByteString preparedStatementHandle = copyFrom(randomUUID().toString().getBytes(UTF_8)); // Ownership of the connection will be passed to the context. Do NOT close! diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java similarity index 95% rename from java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java rename to java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java index 7635b80ecd0fd..a39736e939f0b 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight; +package org.apache.arrow.flight.sql.test; import static java.util.Arrays.asList; import static java.util.Collections.emptyList; @@ -38,6 +38,15 @@ import java.util.Optional; import java.util.stream.IntStream; +import org.apache.arrow.flight.CancelFlightInfoRequest; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStatusCode; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.RenewFlightEndpointRequest; import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.flight.sql.FlightSqlClient.PreparedStatement; import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; @@ -106,6 +115,12 @@ public static void setUp() throws Exception { .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION_VALUE), "10.14.2.0 - (1828579)"); GET_SQL_INFO_EXPECTED_RESULTS_MAP .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE), "false"); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put(Integer.toString(FlightSql.SqlInfo.SQL_ALL_TABLES_ARE_SELECTABLE_VALUE), "true"); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put( + Integer.toString(FlightSql.SqlInfo.SQL_NULL_ORDERING_VALUE), + Integer.toString(FlightSql.SqlNullOrdering.SQL_NULLS_SORTED_AT_END_VALUE)); GET_SQL_INFO_EXPECTED_RESULTS_MAP .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_CATALOG_VALUE), "false"); GET_SQL_INFO_EXPECTED_RESULTS_MAP @@ -122,6 +137,8 @@ public static void setUp() throws Exception { .put( Integer.toString(FlightSql.SqlInfo.SQL_QUOTED_IDENTIFIER_CASE_VALUE), Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE_VALUE)); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put(Integer.toString(FlightSql.SqlInfo.SQL_MAX_COLUMNS_IN_TABLE_VALUE), "42"); } @AfterAll @@ -135,12 +152,15 @@ private static List> getNonConformingResultsForGetSqlInfo(final Lis FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION, FlightSql.SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION, FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY, + FlightSql.SqlInfo.SQL_ALL_TABLES_ARE_SELECTABLE, + FlightSql.SqlInfo.SQL_NULL_ORDERING, FlightSql.SqlInfo.SQL_DDL_CATALOG, FlightSql.SqlInfo.SQL_DDL_SCHEMA, FlightSql.SqlInfo.SQL_DDL_TABLE, FlightSql.SqlInfo.SQL_IDENTIFIER_CASE, FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR, - FlightSql.SqlInfo.SQL_QUOTED_IDENTIFIER_CASE); + FlightSql.SqlInfo.SQL_QUOTED_IDENTIFIER_CASE, + FlightSql.SqlInfo.SQL_MAX_COLUMNS_IN_TABLE); } private static List> getNonConformingResultsForGetSqlInfo( @@ -152,6 +172,7 @@ private static List> getNonConformingResultsForGetSqlInfo( final List result = results.get(index); final String providedName = result.get(0); final String expectedName = Integer.toString(args[index].getNumber()); + System.err.println(expectedName); if (!(GET_SQL_INFO_EXPECTED_RESULTS_MAP.get(providedName).equals(result.get(1)) && providedName.equals(expectedName))) { nonConformingResults.add(result); @@ -603,31 +624,21 @@ public void testGetSqlInfoResultsWithSingleArg() throws Exception { } @Test - public void testGetSqlInfoResultsWithTwoArgs() throws Exception { - final FlightSql.SqlInfo[] args = { - FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, - FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION}; - final FlightInfo info = sqlClient.getSqlInfo(args); - try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> MatcherAssert.assertThat( - stream.getSchema(), - is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA) - ), - () -> MatcherAssert.assertThat( - getNonConformingResultsForGetSqlInfo(getResults(stream), args), - is(emptyList()) - ) - ); - } - } - - @Test - public void testGetSqlInfoResultsWithThreeArgs() throws Exception { + public void testGetSqlInfoResultsWithManyArgs() throws Exception { final FlightSql.SqlInfo[] args = { FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION, - FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR}; + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION, + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY, + FlightSql.SqlInfo.SQL_ALL_TABLES_ARE_SELECTABLE, + FlightSql.SqlInfo.SQL_NULL_ORDERING, + FlightSql.SqlInfo.SQL_DDL_CATALOG, + FlightSql.SqlInfo.SQL_DDL_SCHEMA, + FlightSql.SqlInfo.SQL_DDL_TABLE, + FlightSql.SqlInfo.SQL_IDENTIFIER_CASE, + FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR, + FlightSql.SqlInfo.SQL_QUOTED_IDENTIFIER_CASE, + FlightSql.SqlInfo.SQL_MAX_COLUMNS_IN_TABLE}; final FlightInfo info = sqlClient.getSqlInfo(args); try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { Assertions.assertAll( @@ -912,7 +923,7 @@ public void testCancelFlightInfo() { FlightInfo info = sqlClient.getSqlInfo(); CancelFlightInfoRequest request = new CancelFlightInfoRequest(info); FlightRuntimeException fre = assertThrows(FlightRuntimeException.class, () -> sqlClient.cancelFlightInfo(request)); - assertEquals(FlightStatusCode.UNIMPLEMENTED, fre.status().code()); + Assertions.assertEquals(FlightStatusCode.UNIMPLEMENTED, fre.status().code()); } @Test diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSqlStreams.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStreams.java similarity index 96% rename from java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSqlStreams.java rename to java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStreams.java index 1dd925eb53add..1dd96f0fd4e9c 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSqlStreams.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStreams.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight; +package org.apache.arrow.flight.sql.test; import static java.util.Arrays.asList; import static java.util.Collections.emptyList; @@ -28,6 +28,15 @@ import java.util.Collections; import java.util.List; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Ticket; import org.apache.arrow.flight.sql.BasicFlightSqlProducer; import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.flight.sql.FlightSqlProducer; diff --git a/java/flight/pom.xml b/java/flight/pom.xml index 7ddda94f77b49..2f777ab42b756 100644 --- a/java/flight/pom.xml +++ b/java/flight/pom.xml @@ -15,7 +15,7 @@ arrow-java-root org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT 4.0.0 @@ -26,7 +26,6 @@ flight-core - flight-grpc flight-sql flight-sql-jdbc-core flight-sql-jdbc-driver @@ -61,7 +60,7 @@ 4.11.0 - 4.11.0 + 5.2.0 diff --git a/java/format/pom.xml b/java/format/pom.xml index 3f581311e20ea..a98edefbeb217 100644 --- a/java/format/pom.xml +++ b/java/format/pom.xml @@ -15,7 +15,7 @@ arrow-java-root org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT arrow-format diff --git a/java/format/src/main/java/module-info.java b/java/format/src/main/java/module-info.java new file mode 100644 index 0000000000000..bda779c91afbc --- /dev/null +++ b/java/format/src/main/java/module-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +module org.apache.arrow.format { + exports org.apache.arrow.flatbuf; + requires transitive flatbuffers.java; +} diff --git a/java/gandiva/CMakeLists.txt b/java/gandiva/CMakeLists.txt index 2aa8d92959e42..369829d7a30d5 100644 --- a/java/gandiva/CMakeLists.txt +++ b/java/gandiva/CMakeLists.txt @@ -84,6 +84,11 @@ if(CXX_LINKER_SUPPORTS_VERSION_SCRIPT) ) endif() +set(ARROW_JAVA_JNI_GANDIVA_LIBDIR + "${CMAKE_INSTALL_PREFIX}/lib/gandiva_jni/${ARROW_JAVA_JNI_ARCH_DIR}") +set(ARROW_JAVA_JNI_GANDIVA_BINDIR + "${CMAKE_INSTALL_PREFIX}/bin/gandiva_jni/${ARROW_JAVA_JNI_ARCH_DIR}") + install(TARGETS arrow_java_jni_gandiva - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + LIBRARY DESTINATION ${ARROW_JAVA_JNI_GANDIVA_LIBDIR} + RUNTIME DESTINATION ${ARROW_JAVA_JNI_GANDIVA_BINDIR}) diff --git a/java/gandiva/pom.xml b/java/gandiva/pom.xml index e837a09ff8330..6337efcf7e348 100644 --- a/java/gandiva/pom.xml +++ b/java/gandiva/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT org.apache.arrow.gandiva @@ -34,6 +34,10 @@ org.apache.arrow arrow-memory-core + + org.immutables + value + org.apache.arrow arrow-memory-netty @@ -92,7 +96,7 @@ org.apache.maven.plugins maven-gpg-plugin - 1.5 + 3.1.0 sign-artifacts diff --git a/java/gandiva/src/main/java/module-info.java b/java/gandiva/src/main/java/module-info.java new file mode 100644 index 0000000000000..533717d91f7f0 --- /dev/null +++ b/java/gandiva/src/main/java/module-info.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +open module org.apache.arrow.gandiva { + exports org.apache.arrow.gandiva.expression; + exports org.apache.arrow.gandiva.exceptions; + exports org.apache.arrow.gandiva.evaluator; + exports org.apache.arrow.gandiva.ipc; + + requires com.google.common; + requires org.apache.arrow.format; + requires org.apache.arrow.memory.core; + requires org.apache.arrow.vector; + requires org.slf4j; + requires protobuf.java; +} \ No newline at end of file diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java index 2528989f3784b..57748e9c8e1af 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java @@ -71,7 +71,7 @@ private static JniLoader setupInstance() throws GandivaException { private static void loadGandivaLibraryFromJar(final String tmpDir) throws IOException, GandivaException { final String libraryToLoad = - getNormalizedArch() + "/" + System.mapLibraryName(LIBRARY_NAME); + LIBRARY_NAME + "/" + getNormalizedArch() + "/" + System.mapLibraryName(LIBRARY_NAME); final File libraryFile = moveFileFromJarToTemp(tmpDir, libraryToLoad, LIBRARY_NAME); System.load(libraryFile.getAbsolutePath()); } diff --git a/java/maven/module-info-compiler-maven-plugin/pom.xml b/java/maven/module-info-compiler-maven-plugin/pom.xml index 46c0d563f4eb9..37d14ad412d88 100644 --- a/java/maven/module-info-compiler-maven-plugin/pom.xml +++ b/java/maven/module-info-compiler-maven-plugin/pom.xml @@ -16,7 +16,7 @@ org.apache.arrow.maven.plugins arrow-maven-plugins - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT module-info-compiler-maven-plugin maven-plugin @@ -80,15 +80,15 @@ maven-plugin-plugin - 3.10.2 + 3.11.0 maven-jar-plugin - 3.0.2 + 3.3.0 maven-install-plugin - 2.5.2 + 3.1.1 maven-deploy-plugin @@ -104,7 +104,7 @@ org.apache.maven.plugins maven-plugin-plugin - 3.10.2 + 3.11.0 true diff --git a/java/maven/pom.xml b/java/maven/pom.xml index 56f3c4c434f64..3a88ec762e19c 100644 --- a/java/maven/pom.xml +++ b/java/maven/pom.xml @@ -17,7 +17,7 @@ --> org.apache.arrow.maven.plugins arrow-maven-plugins - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT Arrow Maven Plugins pom @@ -240,7 +240,7 @@ org.slf4j jcl-over-slf4j - 1.7.5 + 2.0.11 diff --git a/java/memory/memory-core/pom.xml b/java/memory/memory-core/pom.xml index 6e411c0cd5440..2a92d032942c9 100644 --- a/java/memory/memory-core/pom.xml +++ b/java/memory/memory-core/pom.xml @@ -13,7 +13,7 @@ arrow-memory org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT 4.0.0 diff --git a/java/memory/memory-netty-buffer-patch/pom.xml b/java/memory/memory-netty-buffer-patch/pom.xml index 1d4407c638d8a..97b224e9ccc5c 100644 --- a/java/memory/memory-netty-buffer-patch/pom.xml +++ b/java/memory/memory-netty-buffer-patch/pom.xml @@ -15,7 +15,7 @@ arrow-memory org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT 4.0.0 diff --git a/java/memory/memory-netty/pom.xml b/java/memory/memory-netty/pom.xml index 159ab5160c983..9b20e1bde2ae7 100644 --- a/java/memory/memory-netty/pom.xml +++ b/java/memory/memory-netty/pom.xml @@ -13,7 +13,7 @@ arrow-memory org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT 4.0.0 diff --git a/java/memory/memory-unsafe/pom.xml b/java/memory/memory-unsafe/pom.xml index 5ef4e8a9149a5..07a140e594522 100644 --- a/java/memory/memory-unsafe/pom.xml +++ b/java/memory/memory-unsafe/pom.xml @@ -13,7 +13,7 @@ arrow-memory org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT 4.0.0 diff --git a/java/memory/pom.xml b/java/memory/pom.xml index 55fbb90353f34..9e2d612765738 100644 --- a/java/memory/pom.xml +++ b/java/memory/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT arrow-memory Arrow Memory diff --git a/java/performance/pom.xml b/java/performance/pom.xml index 5e0b6c1b54541..a1d53171f549b 100644 --- a/java/performance/pom.xml +++ b/java/performance/pom.xml @@ -14,7 +14,7 @@ arrow-java-root org.apache.arrow - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT arrow-performance jar @@ -187,11 +187,11 @@ maven-install-plugin - 2.5.1 + 3.1.1 maven-jar-plugin - 2.4 + 3.3.0 maven-javadoc-plugin diff --git a/java/performance/src/test/java/org/apache/arrow/adapter/AvroAdapterBenchmarks.java b/java/performance/src/test/java/org/apache/arrow/adapter/AvroAdapterBenchmarks.java index 884647b5af3ea..c07aeffafb10c 100644 --- a/java/performance/src/test/java/org/apache/arrow/adapter/AvroAdapterBenchmarks.java +++ b/java/performance/src/test/java/org/apache/arrow/adapter/AvroAdapterBenchmarks.java @@ -21,10 +21,10 @@ import java.io.ByteArrayOutputStream; import java.util.concurrent.TimeUnit; -import org.apache.arrow.AvroToArrow; -import org.apache.arrow.AvroToArrowConfig; -import org.apache.arrow.AvroToArrowConfigBuilder; -import org.apache.arrow.AvroToArrowVectorIterator; +import org.apache.arrow.adapter.avro.AvroToArrow; +import org.apache.arrow.adapter.avro.AvroToArrowConfig; +import org.apache.arrow.adapter.avro.AvroToArrowConfigBuilder; +import org.apache.arrow.adapter.avro.AvroToArrowVectorIterator; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; diff --git a/java/pom.xml b/java/pom.xml index 042488a5b949a..3e595648ed085 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -20,7 +20,7 @@ org.apache.arrow arrow-java-root - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT pom Apache Arrow Java Root POM @@ -31,13 +31,13 @@ ${project.build.directory}/generated-sources 1.9.0 5.10.1 - 2.0.9 + 2.0.11 33.0.0-jre - 4.1.104.Final + 4.1.106.Final 1.60.0 3.23.1 2.16.0 - 2.7.1 + 3.3.6 23.5.26 1.11.3 @@ -438,14 +438,14 @@ org.immutables value - 2.8.2 + 2.10.0 maven-enforcer-plugin - 3.0.0-M2 + 3.4.1 org.apache.maven.plugins @@ -496,7 +496,7 @@ org.jacoco jacoco-maven-plugin - 0.8.7 + 0.8.11 @@ -1055,7 +1055,6 @@ -DARROW_JAVA_JNI_ENABLE_DEFAULT=OFF -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Release - -DCMAKE_INSTALL_LIBDIR=lib/${os.detected.arch} -DCMAKE_INSTALL_PREFIX=${arrow.c.jni.dist.dir} ../ @@ -1128,7 +1127,6 @@ -DARROW_SUBSTRAIT=${ARROW_DATASET} -DARROW_USE_CCACHE=ON -DCMAKE_BUILD_TYPE=Release - -DCMAKE_INSTALL_LIBDIR=lib/${os.detected.arch} -DCMAKE_INSTALL_PREFIX=java-dist -DCMAKE_UNITY_BUILD=ON @@ -1169,7 +1167,6 @@ -DARROW_JAVA_JNI_ENABLE_DEFAULT=ON -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Release - -DCMAKE_INSTALL_LIBDIR=lib/${os.detected.arch} -DCMAKE_INSTALL_PREFIX=${arrow.dataset.jni.dist.dir} -DCMAKE_PREFIX_PATH=${project.basedir}/../java-dist/lib/${os.detected.arch}/cmake -DProtobuf_USE_STATIC_LIBS=ON @@ -1248,7 +1245,6 @@ -DARROW_WITH_ZLIB=ON -DARROW_WITH_ZSTD=ON -DCMAKE_BUILD_TYPE=Release - -DCMAKE_INSTALL_LIBDIR=lib/${os.detected.arch} -DCMAKE_INSTALL_PREFIX=java-dist -DCMAKE_UNITY_BUILD=ON -GNinja @@ -1290,7 +1286,6 @@ -DARROW_JAVA_JNI_ENABLE_DEFAULT=ON -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Release - -DCMAKE_INSTALL_LIBDIR=lib/${os.detected.arch} -DCMAKE_INSTALL_PREFIX=${arrow.dataset.jni.dist.dir} -DCMAKE_PREFIX_PATH=${project.basedir}/../java-dist/lib/${os.detected.arch}/cmake diff --git a/java/tools/pom.xml b/java/tools/pom.xml index 8df436bac9aef..0688fae1ab78c 100644 --- a/java/tools/pom.xml +++ b/java/tools/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT arrow-tools Arrow Tools @@ -34,6 +34,10 @@ org.apache.arrow arrow-compression + + org.immutables + value + com.google.guava guava diff --git a/java/tools/src/main/java/module-info.java b/java/tools/src/main/java/module-info.java new file mode 100644 index 0000000000000..6b4329eb84f2a --- /dev/null +++ b/java/tools/src/main/java/module-info.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +module org.apache.arrow.tools { + exports org.apache.arrow.tools; + + requires com.fasterxml.jackson.databind; + requires com.google.common; + requires org.apache.arrow.compression; + requires org.apache.arrow.memory.core; + requires org.apache.arrow.vector; + requires org.slf4j; +} diff --git a/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java b/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java index c49b04c855846..1201d0f760524 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java @@ -43,11 +43,9 @@ public class FileRoundtrip { private static final Logger LOGGER = LoggerFactory.getLogger(FileRoundtrip.class); private final Options options; - private final PrintStream out; private final PrintStream err; - FileRoundtrip(PrintStream out, PrintStream err) { - this.out = out; + FileRoundtrip(PrintStream err) { this.err = err; this.options = new Options(); this.options.addOption("i", "in", true, "input file"); @@ -56,7 +54,7 @@ public class FileRoundtrip { } public static void main(String[] args) { - System.exit(new FileRoundtrip(System.out, System.err).run(args)); + System.exit(new FileRoundtrip(System.err).run(args)); } private File validateFile(String type, String fileName) { diff --git a/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java b/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java index 178a0834fa44f..1bc7ead7b73bb 100644 --- a/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java +++ b/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java @@ -35,7 +35,6 @@ import org.apache.arrow.vector.ipc.ArrowFileReader; import org.apache.arrow.vector.ipc.ArrowFileWriter; import org.apache.arrow.vector.ipc.message.ArrowBlock; -import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; public class ArrowFileTestFixtures { @@ -63,7 +62,6 @@ static void validateOutput(File testOutFile, BufferAllocator allocator) throws E ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - Schema schema = root.getSchema(); for (ArrowBlock rbBlock : arrowReader.getRecordBlocks()) { if (!arrowReader.loadRecordBatch(rbBlock)) { throw new IOException("Expected to read record batch"); diff --git a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java index 714cb416bf996..9cf893ee5c283 100644 --- a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java +++ b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java @@ -142,7 +142,6 @@ public void basicTest() throws InterruptedException, IOException { Collections.emptyList()); TinyIntVector vector = new TinyIntVector("testField", FieldType.nullable(TINYINT.getType()), alloc); - Schema schema = new Schema(asList(field)); // Try an empty stream, just the header. testEchoServer(serverPort, field, vector, 0); diff --git a/java/tools/src/test/java/org/apache/arrow/tools/TestFileRoundtrip.java b/java/tools/src/test/java/org/apache/arrow/tools/TestFileRoundtrip.java index ddac6f79384d9..a5d6c9658fd4f 100644 --- a/java/tools/src/test/java/org/apache/arrow/tools/TestFileRoundtrip.java +++ b/java/tools/src/test/java/org/apache/arrow/tools/TestFileRoundtrip.java @@ -56,7 +56,7 @@ public void test() throws Exception { writeInput(testInFile, allocator); String[] args = {"-i", testInFile.getAbsolutePath(), "-o", testOutFile.getAbsolutePath()}; - int result = new FileRoundtrip(System.out, System.err).run(args); + int result = new FileRoundtrip(System.err).run(args); assertEquals(0, result); validateOutput(testOutFile, allocator); diff --git a/java/vector/pom.xml b/java/vector/pom.xml index 17d8f312a52a5..dc453963b62f6 100644 --- a/java/vector/pom.xml +++ b/java/vector/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 15.0.0-SNAPSHOT + 16.0.0-SNAPSHOT arrow-vector Arrow Vectors @@ -30,6 +30,10 @@ org.apache.arrow arrow-memory-core + + org.immutables + value + com.fasterxml.jackson.core jackson-core @@ -49,7 +53,7 @@ commons-codec commons-codec - 1.15 + 1.16.0 org.apache.arrow diff --git a/java/vector/src/main/codegen/includes/vv_imports.ftl b/java/vector/src/main/codegen/includes/vv_imports.ftl index c9a8820b258b1..f4c72a1a6cbae 100644 --- a/java/vector/src/main/codegen/includes/vv_imports.ftl +++ b/java/vector/src/main/codegen/includes/vv_imports.ftl @@ -48,9 +48,6 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.nio.ByteBuffer; -import java.sql.Date; -import java.sql.Time; -import java.sql.Timestamp; import java.math.BigDecimal; import java.math.BigInteger; import java.time.Duration; diff --git a/java/vector/src/main/java/module-info.java b/java/vector/src/main/java/module-info.java new file mode 100644 index 0000000000000..20f7094715f4d --- /dev/null +++ b/java/vector/src/main/java/module-info.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +module org.apache.arrow.vector { + exports org.apache.arrow.vector; + exports org.apache.arrow.vector.compare; + exports org.apache.arrow.vector.compare.util; + exports org.apache.arrow.vector.complex; + exports org.apache.arrow.vector.complex.impl; + exports org.apache.arrow.vector.complex.reader; + exports org.apache.arrow.vector.complex.writer; + exports org.apache.arrow.vector.compression; + exports org.apache.arrow.vector.dictionary; + exports org.apache.arrow.vector.holders; + exports org.apache.arrow.vector.ipc; + exports org.apache.arrow.vector.ipc.message; + exports org.apache.arrow.vector.table; + exports org.apache.arrow.vector.types; + exports org.apache.arrow.vector.types.pojo; + exports org.apache.arrow.vector.util; + exports org.apache.arrow.vector.validate; + + opens org.apache.arrow.vector.types.pojo to com.fasterxml.jackson.databind; + + requires com.fasterxml.jackson.annotation; + requires com.fasterxml.jackson.core; + requires com.fasterxml.jackson.databind; + requires com.fasterxml.jackson.datatype.jsr310; + requires flatbuffers.java; + requires jdk.unsupported; + requires org.apache.arrow.format; + requires org.apache.arrow.memory.core; + requires org.apache.commons.codec; + requires org.eclipse.collections.impl; + requires org.slf4j; +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java index 90229460111c3..c456c625389ba 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java @@ -110,7 +110,7 @@ public String getName() { */ @Override public long getValidityBufferAddress() { - return (validityBuffer.memoryAddress()); + return validityBuffer.memoryAddress(); } /** @@ -120,7 +120,7 @@ public long getValidityBufferAddress() { */ @Override public long getDataBufferAddress() { - return (valueBuffer.memoryAddress()); + return valueBuffer.memoryAddress(); } /** @@ -298,6 +298,7 @@ public boolean allocateNewSafe() { * @param valueCount the desired number of elements in the vector * @throws org.apache.arrow.memory.OutOfMemoryException on error */ + @Override public void allocateNew(int valueCount) { computeAndCheckBufferSize(valueCount); @@ -521,6 +522,7 @@ public void loadFieldBuffers(ArrowFieldNode fieldNode, List ownBuffers * * @return the inner buffers. */ + @Override public List getFieldBuffers() { List result = new ArrayList<>(2); setReaderAndWriterIndex(); @@ -597,6 +599,7 @@ public TransferPair getTransferPair(BufferAllocator allocator) { * @param allocator allocator for the target vector * @return TransferPair */ + @Override public abstract TransferPair getTransferPair(String ref, BufferAllocator allocator); /** @@ -605,6 +608,7 @@ public TransferPair getTransferPair(BufferAllocator allocator) { * @param allocator allocator for the target vector * @return TransferPair */ + @Override public abstract TransferPair getTransferPair(Field field, BufferAllocator allocator); /** @@ -911,6 +915,7 @@ public void copyFromSafe(int fromIndex, int thisIndex, ValueVector from) { * * @param index position of element */ + @Override public void setNull(int index) { handleSafe(index); // not really needed to set the bit to 0 as long as diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java index a77278138f28c..c239edbcc3c29 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java @@ -228,6 +228,7 @@ private void initOffsetBuffer() { * Reset the vector to initial state. Same as {@link #zeroVector()}. * Note that this method doesn't release any memory. */ + @Override public void reset() { zeroVector(); lastSet = -1; @@ -318,6 +319,7 @@ public void loadFieldBuffers(ArrowFieldNode fieldNode, List ownBuffers * Get the buffers belonging to this vector. * @return the inner buffers. */ + @Override public List getFieldBuffers() { // before flight/IPC, we must bring the vector to a consistent state. // this is because, it is possible that the offset buffers of some trailing values @@ -471,6 +473,7 @@ private void allocateValidityBuffer(final long size) { * Resize the vector to increase the capacity. The internal behavior is to * double the current value capacity. */ + @Override public void reAlloc() { reallocDataBuffer(); reallocValidityAndOffsetBuffers(); @@ -691,6 +694,7 @@ public TransferPair getTransferPair(BufferAllocator allocator) { * @param allocator allocator for the target vector * @return TransferPair */ + @Override public abstract TransferPair getTransferPair(String ref, BufferAllocator allocator); /** @@ -699,6 +703,7 @@ public TransferPair getTransferPair(BufferAllocator allocator) { * @param allocator allocator for the target vector * @return TransferPair */ + @Override public abstract TransferPair getTransferPair(Field field, BufferAllocator allocator); /** @@ -835,6 +840,7 @@ private void splitAndTransferValidityBuffer(int startIndex, int length, * * @return the number of null elements. */ + @Override public int getNullCount() { return BitVectorHelper.getNullCount(validityBuffer, valueCount); } @@ -856,6 +862,7 @@ public boolean isSafe(int index) { * @param index position of element * @return true if element at given index is null */ + @Override public boolean isNull(int index) { return (isSet(index) == 0); } @@ -879,6 +886,7 @@ public int isSet(int index) { * * @return valueCount for the vector */ + @Override public int getValueCount() { return valueCount; } @@ -888,6 +896,7 @@ public int getValueCount() { * * @param valueCount value count */ + @Override public void setValueCount(int valueCount) { assert valueCount >= 0; this.valueCount = valueCount; @@ -1091,6 +1100,7 @@ public void setSafe(int index, ByteBuffer value, int start, int length) { * * @param index position of element */ + @Override public void setNull(int index) { // We need to check and realloc both validity and offset buffer while (index >= getValueCapacity()) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java index 679e5d06c016e..070919c356791 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java @@ -28,15 +28,12 @@ import org.apache.arrow.vector.util.DataSizeRoundingUtil; import org.apache.arrow.vector.util.TransferPair; import org.apache.arrow.vector.util.ValueVectorUtility; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Base class for other Arrow Vector Types. Provides basic functionality around * memory management. */ public abstract class BaseValueVector implements ValueVector { - private static final Logger logger = LoggerFactory.getLogger(BaseValueVector.class); public static final String MAX_ALLOCATION_SIZE_PROPERTY = "arrow.vector.max_allocation_bytes"; public static final long MAX_ALLOCATION_SIZE = Long.getLong(MAX_ALLOCATION_SIZE_PROPERTY, Long.MAX_VALUE); @@ -160,6 +157,7 @@ long computeCombinedBufferSize(int valueCount, int typeWidth) { * * @return Concrete instance of FieldReader by using double-checked locking. */ + @Override public FieldReader getReader() { FieldReader reader = fieldReader; @@ -178,7 +176,7 @@ public FieldReader getReader() { /** * Container for primitive vectors (1 for the validity bit-mask and one to hold the values). */ - class DataAndValidityBuffers { + static class DataAndValidityBuffers { private ArrowBuf dataBuf; private ArrowBuf validityBuf; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java index 46bc9815f037a..4cf495a349f02 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java @@ -247,6 +247,7 @@ private void initOffsetBuffer() { * Reset the vector to initial state. Same as {@link #zeroVector()}. * Note that this method doesn't release any memory. */ + @Override public void reset() { zeroVector(); lastSet = -1; @@ -337,6 +338,7 @@ public void loadFieldBuffers(ArrowFieldNode fieldNode, List ownBuffers * Get the buffers belonging to this vector. * @return the inner buffers. */ + @Override public List getFieldBuffers() { // before flight/IPC, we must bring the vector to a consistent state. // this is because, it is possible that the offset buffers of some trailing values @@ -493,6 +495,7 @@ private void allocateValidityBuffer(final long size) { * Resize the vector to increase the capacity. The internal behavior is to * double the current value capacity. */ + @Override public void reAlloc() { reallocDataBuffer(); reallocValidityAndOffsetBuffers(); @@ -732,6 +735,7 @@ public TransferPair getTransferPair(BufferAllocator allocator) { * @param allocator allocator for the target vector * @return TransferPair */ + @Override public abstract TransferPair getTransferPair(String ref, BufferAllocator allocator); /** @@ -740,6 +744,7 @@ public TransferPair getTransferPair(BufferAllocator allocator) { * @param allocator allocator for the target vector * @return TransferPair */ + @Override public abstract TransferPair getTransferPair(Field field, BufferAllocator allocator); /** @@ -796,7 +801,8 @@ private void splitAndTransferOffsetBuffer(int startIndex, int length, BaseVariab final int dataLength = end - start; if (start == 0) { - final ArrowBuf slicedOffsetBuffer = offsetBuffer.slice(startIndex * OFFSET_WIDTH, (1 + length) * OFFSET_WIDTH); + final ArrowBuf slicedOffsetBuffer = offsetBuffer.slice(startIndex * ((long) OFFSET_WIDTH), + (1 + length) * ((long) OFFSET_WIDTH)); target.offsetBuffer = transferBuffer(slicedOffsetBuffer, target.allocator); } else { target.allocateOffsetBuffer((long) (length + 1) * OFFSET_WIDTH); @@ -883,6 +889,7 @@ private void splitAndTransferValidityBuffer(int startIndex, int length, * * @return the number of null elements. */ + @Override public int getNullCount() { return BitVectorHelper.getNullCount(validityBuffer, valueCount); } @@ -904,6 +911,7 @@ public boolean isSafe(int index) { * @param index position of element * @return true if element at given index is null */ + @Override public boolean isNull(int index) { return (isSet(index) == 0); } @@ -927,6 +935,7 @@ public int isSet(int index) { * * @return valueCount for the vector */ + @Override public int getValueCount() { return valueCount; } @@ -936,6 +945,7 @@ public int getValueCount() { * * @param valueCount value count */ + @Override public void setValueCount(int valueCount) { assert valueCount >= 0; this.valueCount = valueCount; @@ -1016,7 +1026,7 @@ public void setValueLengthSafe(int index, int length) { handleSafe(index, length); fillHoles(index); final int startOffset = getStartOffset(index); - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + length); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + length); lastSet = index; } @@ -1119,7 +1129,7 @@ public void set(int index, ByteBuffer value, int start, int length) { fillHoles(index); BitVectorHelper.setBit(validityBuffer, index); final int startOffset = getStartOffset(index); - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + length); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + length); valueBuffer.setBytes(startOffset, value, start, length); lastSet = index; } @@ -1140,7 +1150,7 @@ public void setSafe(int index, ByteBuffer value, int start, int length) { fillHoles(index); BitVectorHelper.setBit(validityBuffer, index); final int startOffset = getStartOffset(index); - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + length); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + length); valueBuffer.setBytes(startOffset, value, start, length); lastSet = index; } @@ -1150,6 +1160,7 @@ public void setSafe(int index, ByteBuffer value, int start, int length) { * * @param index position of element */ + @Override public void setNull(int index) { // We need to check and realloc both validity and offset buffer while (index >= getValueCapacity()) { @@ -1174,7 +1185,7 @@ public void set(int index, int isSet, int start, int end, ArrowBuf buffer) { fillHoles(index); BitVectorHelper.setValidityBit(validityBuffer, index, isSet); final int startOffset = offsetBuffer.getInt((long) index * OFFSET_WIDTH); - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + dataLength); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + dataLength); valueBuffer.setBytes(startOffset, buffer, start, dataLength); lastSet = index; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java index b0052e7e33009..095d98aa265fe 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BigIntVector.java @@ -129,6 +129,7 @@ public void get(int index, NullableBigIntHolder holder) { * @param index position of element * @return element at given index */ + @Override public Long getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java index 104819147b109..a34df8cf6f68b 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java @@ -105,7 +105,7 @@ public MinorType getMinorType() { @Override public void setInitialCapacity(int valueCount) { final int size = getValidityBufferSizeFromCount(valueCount); - if (size * 2 > MAX_ALLOCATION_SIZE) { + if (size * 2L > MAX_ALLOCATION_SIZE) { throw new OversizedAllocationException("Requested amount of memory is more than max allowed"); } lastValueCapacity = valueCount; @@ -149,15 +149,14 @@ public int getBufferSize() { * @param length length of the split. * @param target destination vector */ + @Override public void splitAndTransferTo(int startIndex, int length, BaseFixedWidthVector target) { Preconditions.checkArgument(startIndex >= 0 && length >= 0 && startIndex + length <= valueCount, "Invalid parameters startIndex: %s, length: %s for valueCount: %s", startIndex, length, valueCount); compareTypes(target, "splitAndTransferTo"); target.clear(); - target.validityBuffer = splitAndTransferBuffer(startIndex, length, target, - validityBuffer, target.validityBuffer); - target.valueBuffer = splitAndTransferBuffer(startIndex, length, target, - valueBuffer, target.valueBuffer); + target.validityBuffer = splitAndTransferBuffer(startIndex, length, validityBuffer, target.validityBuffer); + target.valueBuffer = splitAndTransferBuffer(startIndex, length, valueBuffer, target.valueBuffer); target.refreshValueCapacity(); target.setValueCount(length); @@ -166,7 +165,6 @@ public void splitAndTransferTo(int startIndex, int length, BaseFixedWidthVector private ArrowBuf splitAndTransferBuffer( int startIndex, int length, - BaseFixedWidthVector target, ArrowBuf sourceBuffer, ArrowBuf destBuffer) { int firstByteSource = BitVectorHelper.byteIndex(startIndex); @@ -276,11 +274,12 @@ public void get(int index, NullableBitHolder holder) { * @param index position of element * @return element at given index */ + @Override public Boolean getObject(int index) { if (isSet(index) == 0) { return null; } else { - return new Boolean(getBit(index) != 0); + return getBit(index) != 0; } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BufferLayout.java b/java/vector/src/main/java/org/apache/arrow/vector/BufferLayout.java index 09c874e398022..9725693348a48 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BufferLayout.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BufferLayout.java @@ -144,7 +144,7 @@ public boolean equals(Object obj) { if (obj == null) { return false; } - if (getClass() != obj.getClass()) { + if (!(obj instanceof BufferLayout)) { return false; } BufferLayout other = (BufferLayout) obj; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java index c99c5786058b7..13645d3b26004 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DateDayVector.java @@ -131,6 +131,7 @@ public void get(int index, NullableDateDayHolder holder) { * @param index position of element * @return element at given index */ + @Override public Integer getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java index 6ab8ac4eed229..1333fb0adcefa 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DateMilliVector.java @@ -133,6 +133,7 @@ public void get(int index, NullableDateMilliHolder holder) { * @param index position of element * @return element at given index */ + @Override public LocalDateTime getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/Decimal256Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/Decimal256Vector.java index fe650c7d28074..931c4eea0afb1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/Decimal256Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/Decimal256Vector.java @@ -151,6 +151,7 @@ public void get(int index, NullableDecimal256Holder holder) { * @param index position of element * @return element at given index */ + @Override public BigDecimal getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java index 7c3662c86748b..eefcee837f719 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java @@ -150,6 +150,7 @@ public void get(int index, NullableDecimalHolder holder) { * @param index position of element * @return element at given index */ + @Override public BigDecimal getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java index b6abc16194b77..636afef1e9f7b 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java @@ -143,6 +143,7 @@ public void get(int index, NullableDurationHolder holder) { * @param index position of element * @return element at given index */ + @Override public Duration getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/Float4Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/Float4Vector.java index 4b56a22f2d0c4..46f9447be2938 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/Float4Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/Float4Vector.java @@ -131,6 +131,7 @@ public void get(int index, NullableFloat4Holder holder) { * @param index position of element * @return element at given index */ + @Override public Float getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/Float8Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/Float8Vector.java index 7e4fae7087ba5..840f9d4ba087b 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/Float8Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/Float8Vector.java @@ -131,6 +131,7 @@ public void get(int index, NullableFloat8Holder holder) { * @param index position of element * @return element at given index */ + @Override public Double getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/GenerateSampleData.java b/java/vector/src/main/java/org/apache/arrow/vector/GenerateSampleData.java index efebfd83543d7..6cda18a8a53d3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/GenerateSampleData.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/GenerateSampleData.java @@ -108,8 +108,8 @@ private static void writeTimeStampData(TimeStampVector vector, int valueCount) { } private static void writeDecimalData(DecimalVector vector, int valueCount) { - final BigDecimal even = new BigDecimal(0.0543278923); - final BigDecimal odd = new BigDecimal(2.0543278923); + final BigDecimal even = new BigDecimal("0.0543278923"); + final BigDecimal odd = new BigDecimal("2.0543278923"); for (int i = 0; i < valueCount; i++) { if (i % 2 == 0) { vector.setSafe(i, even); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java index 5c8ef440e8ea4..08ead148af312 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/IntVector.java @@ -131,6 +131,7 @@ public void get(int index, NullableIntHolder holder) { * @param index position of element * @return element at given index */ + @Override public Integer getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/IntervalDayVector.java b/java/vector/src/main/java/org/apache/arrow/vector/IntervalDayVector.java index 7c0d19baa9a6f..f53eb37138dcb 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/IntervalDayVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/IntervalDayVector.java @@ -164,6 +164,7 @@ public void get(int index, NullableIntervalDayHolder holder) { * @param index position of element * @return element at given index */ + @Override public Duration getObject(int index) { if (isSet(index) == 0) { return null; @@ -206,23 +207,23 @@ private StringBuilder getAsStringBuilderHelper(int index) { final int days = valueBuffer.getInt(startIndex); int millis = valueBuffer.getInt(startIndex + MILLISECOND_OFFSET); - final int hours = millis / (org.apache.arrow.vector.util.DateUtility.hoursToMillis); - millis = millis % (org.apache.arrow.vector.util.DateUtility.hoursToMillis); + final int hours = millis / org.apache.arrow.vector.util.DateUtility.hoursToMillis; + millis = millis % org.apache.arrow.vector.util.DateUtility.hoursToMillis; - final int minutes = millis / (org.apache.arrow.vector.util.DateUtility.minutesToMillis); - millis = millis % (org.apache.arrow.vector.util.DateUtility.minutesToMillis); + final int minutes = millis / org.apache.arrow.vector.util.DateUtility.minutesToMillis; + millis = millis % org.apache.arrow.vector.util.DateUtility.minutesToMillis; - final int seconds = millis / (org.apache.arrow.vector.util.DateUtility.secondsToMillis); - millis = millis % (org.apache.arrow.vector.util.DateUtility.secondsToMillis); + final int seconds = millis / org.apache.arrow.vector.util.DateUtility.secondsToMillis; + millis = millis % org.apache.arrow.vector.util.DateUtility.secondsToMillis; final String dayString = (Math.abs(days) == 1) ? " day " : " days "; - return (new StringBuilder() + return new StringBuilder() .append(days).append(dayString) .append(hours).append(":") .append(minutes).append(":") .append(seconds).append(".") - .append(millis)); + .append(millis); } /*----------------------------------------------------------------* diff --git a/java/vector/src/main/java/org/apache/arrow/vector/IntervalMonthDayNanoVector.java b/java/vector/src/main/java/org/apache/arrow/vector/IntervalMonthDayNanoVector.java index fc0aa9d27b1c3..716af6fec9cd8 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/IntervalMonthDayNanoVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/IntervalMonthDayNanoVector.java @@ -186,6 +186,7 @@ public void get(int index, NullableIntervalMonthDayNanoHolder holder) { * @param index position of element * @return element at given index */ + @Override public PeriodDuration getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/IntervalYearVector.java b/java/vector/src/main/java/org/apache/arrow/vector/IntervalYearVector.java index 7fe572f3ff1f8..c5f384604aa83 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/IntervalYearVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/IntervalYearVector.java @@ -147,6 +147,7 @@ public void get(int index, NullableIntervalYearHolder holder) { * @param index position of element * @return element at given index */ + @Override public Period getObject(int index) { if (isSet(index) == 0) { return null; @@ -181,11 +182,11 @@ private StringBuilder getAsStringBuilderHelper(int index) { final String yearString = (Math.abs(years) == 1) ? " year " : " years "; final String monthString = (Math.abs(months) == 1) ? " month " : " months "; - return (new StringBuilder() + return new StringBuilder() .append(years) .append(yearString) .append(months) - .append(monthString)); + .append(monthString); } /*----------------------------------------------------------------* diff --git a/java/vector/src/main/java/org/apache/arrow/vector/LargeVarBinaryVector.java b/java/vector/src/main/java/org/apache/arrow/vector/LargeVarBinaryVector.java index 0750f68f4f716..8560ba3a68b04 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/LargeVarBinaryVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/LargeVarBinaryVector.java @@ -131,6 +131,7 @@ public void read(int index, ReusableBuffer buffer) { * @param index position of element to get * @return byte array for non-null element, null otherwise */ + @Override public byte[] getObject(int index) { return get(index); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/LargeVarCharVector.java b/java/vector/src/main/java/org/apache/arrow/vector/LargeVarCharVector.java index 6f08fcb81fee1..df424c87488a0 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/LargeVarCharVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/LargeVarCharVector.java @@ -121,6 +121,7 @@ public byte[] get(int index) { * @param index position of element to get * @return Text object for non-null element, null otherwise */ + @Override public Text getObject(int index) { assert index >= 0; if (NULL_CHECKING_ENABLED && isSet(index) == 0) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java index 518ee707396ea..37a6fe110401e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/SmallIntVector.java @@ -131,6 +131,7 @@ public void get(int index, NullableSmallIntHolder holder) { * @param index position of element * @return element at given index */ + @Override public Short getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeMicroVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeMicroVector.java index 86738cd221ec4..c463dc36336c8 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeMicroVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeMicroVector.java @@ -131,6 +131,7 @@ public void get(int index, NullableTimeMicroHolder holder) { * @param index position of element * @return element at given index */ + @Override public Long getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeMilliVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeMilliVector.java index 480add91097bb..1e745d9b9923b 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeMilliVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeMilliVector.java @@ -133,6 +133,7 @@ public void get(int index, NullableTimeMilliHolder holder) { * @param index position of element * @return element at given index */ + @Override public LocalDateTime getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeNanoVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeNanoVector.java index 82609cdc446ed..426e865a5c18b 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeNanoVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeNanoVector.java @@ -131,6 +131,7 @@ public void get(int index, NullableTimeNanoHolder holder) { * @param index position of element * @return element at given index */ + @Override public Long getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeSecVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeSecVector.java index 9b7614e55b6e8..c760ed29e04e6 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeSecVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeSecVector.java @@ -131,6 +131,7 @@ public void get(int index, NullableTimeSecHolder holder) { * @param index position of element * @return element at given index */ + @Override public Integer getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroTZVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroTZVector.java index a37b444d1a368..b515f8e2c83c0 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroTZVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroTZVector.java @@ -133,6 +133,7 @@ public void get(int index, NullableTimeStampMicroTZHolder holder) { * @param index position of element * @return element at given index */ + @Override public Long getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroVector.java index 88ce27a187ebc..2f65921f22b26 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroVector.java @@ -119,6 +119,7 @@ public void get(int index, NullableTimeStampMicroHolder holder) { * @param index position of element * @return element at given index */ + @Override public LocalDateTime getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliTZVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliTZVector.java index 775594ceea640..d0293099432a9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliTZVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliTZVector.java @@ -133,6 +133,7 @@ public void get(int index, NullableTimeStampMilliTZHolder holder) { * @param index position of element * @return element at given index */ + @Override public Long getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliVector.java index a42773269f8b5..96440fd5ac3f7 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliVector.java @@ -119,6 +119,7 @@ public void get(int index, NullableTimeStampMilliHolder holder) { * @param index position of element * @return element at given index */ + @Override public LocalDateTime getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoTZVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoTZVector.java index af43cf6fc9b64..f93ec9b24c43a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoTZVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoTZVector.java @@ -133,6 +133,7 @@ public void get(int index, NullableTimeStampNanoTZHolder holder) { * @param index position of element * @return element at given index */ + @Override public Long getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoVector.java index 7b02b1c87d3fb..723e62f8d6e02 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoVector.java @@ -119,6 +119,7 @@ public void get(int index, NullableTimeStampNanoHolder holder) { * @param index position of element * @return element at given index */ + @Override public LocalDateTime getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampSecVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampSecVector.java index 1e249140335d2..2de01fd52e457 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampSecVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampSecVector.java @@ -119,6 +119,7 @@ public void get(int index, NullableTimeStampSecHolder holder) { * @param index position of element * @return element at given index */ + @Override public LocalDateTime getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java index 4c4eee1342ff0..e9ea59298d093 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TinyIntVector.java @@ -131,6 +131,7 @@ public void get(int index, NullableTinyIntHolder holder) { * @param index position of element * @return element at given index */ + @Override public Byte getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java b/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java index 60fe2a6a6ee63..ae465418cf2fd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java @@ -55,7 +55,7 @@ public class TypeLayout { /** - * Constructs a new {@TypeLayout} for the given arrowType. + * Constructs a new {@link TypeLayout} for the given arrowType. */ public static TypeLayout getTypeLayout(final ArrowType arrowType) { TypeLayout layout = arrowType.accept(new ArrowTypeVisitor() { @@ -421,6 +421,7 @@ public List getBufferTypes() { return types; } + @Override public String toString() { return bufferLayouts.toString(); } @@ -438,7 +439,7 @@ public boolean equals(Object obj) { if (obj == null) { return false; } - if (getClass() != obj.getClass()) { + if (!(obj instanceof TypeLayout)) { return false; } TypeLayout other = (TypeLayout) obj; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java index 777df3fb1efe7..fcb04eaf08821 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt1Vector.java @@ -136,6 +136,7 @@ public void get(int index, NullableUInt1Holder holder) { * @param index position of element * @return element at given index */ + @Override public Byte getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java index e5b95be191df1..a9708a4faa9a7 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt2Vector.java @@ -127,6 +127,7 @@ public void get(int index, NullableUInt2Holder holder) { * @param index position of element * @return element at given index */ + @Override public Character getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java index bda98b12005ce..f9bed0c013a2a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt4Vector.java @@ -136,6 +136,7 @@ public void get(int index, NullableUInt4Holder holder) { * @param index position of element * @return element at given index */ + @Override public Integer getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java index 5e7c18902f0ae..a3e16b5e30dde 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/UInt8Vector.java @@ -136,6 +136,7 @@ public void get(int index, NullableUInt8Holder holder) { * @param index position of element * @return element at given index */ + @Override public Long getObject(int index) { if (isSet(index) == 0) { return null; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java b/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java index 87790c1168bd0..ab67ebad965aa 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java @@ -132,6 +132,7 @@ public void read(int index, ReusableBuffer buffer) { * @param index position of element to get * @return byte array for non-null element, null otherwise */ + @Override public byte[] getObject(int index) { return get(index); } @@ -176,7 +177,7 @@ public void set(int index, VarBinaryHolder holder) { BitVectorHelper.setBit(validityBuffer, index); final int dataLength = holder.end - holder.start; final int startOffset = getStartOffset(index); - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + dataLength); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + dataLength); valueBuffer.setBytes(startOffset, holder.buffer, holder.start, dataLength); lastSet = index; } @@ -196,7 +197,7 @@ public void setSafe(int index, VarBinaryHolder holder) { fillHoles(index); BitVectorHelper.setBit(validityBuffer, index); final int startOffset = getStartOffset(index); - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + dataLength); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + dataLength); valueBuffer.setBytes(startOffset, holder.buffer, holder.start, dataLength); lastSet = index; } @@ -215,10 +216,10 @@ public void set(int index, NullableVarBinaryHolder holder) { final int startOffset = getStartOffset(index); if (holder.isSet != 0) { final int dataLength = holder.end - holder.start; - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + dataLength); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + dataLength); valueBuffer.setBytes(startOffset, holder.buffer, holder.start, dataLength); } else { - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset); } lastSet = index; } @@ -238,7 +239,7 @@ public void setSafe(int index, NullableVarBinaryHolder holder) { handleSafe(index, dataLength); fillHoles(index); final int startOffset = getStartOffset(index); - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + dataLength); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + dataLength); valueBuffer.setBytes(startOffset, holder.buffer, holder.start, dataLength); } else { fillEmpties(index + 1); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java b/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java index 7350dc99bbda8..c6d5a7090bc6f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java @@ -118,6 +118,7 @@ public byte[] get(int index) { * @param index position of element to get * @return Text object for non-null element, null otherwise */ + @Override public Text getObject(int index) { assert index >= 0; if (NULL_CHECKING_ENABLED && isSet(index) == 0) { @@ -182,7 +183,7 @@ public void set(int index, VarCharHolder holder) { BitVectorHelper.setBit(validityBuffer, index); final int dataLength = holder.end - holder.start; final int startOffset = getStartOffset(index); - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + dataLength); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + dataLength); valueBuffer.setBytes(startOffset, holder.buffer, holder.start, dataLength); lastSet = index; } @@ -203,7 +204,7 @@ public void setSafe(int index, VarCharHolder holder) { BitVectorHelper.setBit(validityBuffer, index); final int startOffset = getStartOffset(index); - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + dataLength); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + dataLength); valueBuffer.setBytes(startOffset, holder.buffer, holder.start, dataLength); lastSet = index; } @@ -222,10 +223,10 @@ public void set(int index, NullableVarCharHolder holder) { final int startOffset = getStartOffset(index); if (holder.isSet != 0) { final int dataLength = holder.end - holder.start; - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + dataLength); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + dataLength); valueBuffer.setBytes(startOffset, holder.buffer, holder.start, dataLength); } else { - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset); } lastSet = index; } @@ -245,7 +246,7 @@ public void setSafe(int index, NullableVarCharHolder holder) { handleSafe(index, dataLength); fillHoles(index); final int startOffset = getStartOffset(index); - offsetBuffer.setInt((index + 1) * OFFSET_WIDTH, startOffset + dataLength); + offsetBuffer.setInt((index + 1) * ((long) OFFSET_WIDTH), startOffset + dataLength); valueBuffer.setBytes(startOffset, holder.buffer, holder.start, dataLength); } else { fillEmpties(index + 1); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java index 898bfe3d39780..8e6cdb6c45bc5 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java @@ -55,6 +55,7 @@ public void allocateNew() throws OutOfMemoryException { } } + @Override public BufferAllocator getAllocator() { return allocator; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java index 797a5af31f9a4..80efea6cbe39e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Locale; import java.util.stream.Collectors; import org.apache.arrow.memory.ArrowBuf; @@ -56,7 +57,7 @@ public abstract class AbstractStructVector extends AbstractContainerVector { } ConflictPolicy conflictPolicy; try { - conflictPolicy = ConflictPolicy.valueOf(conflictPolicyStr.toUpperCase()); + conflictPolicy = ConflictPolicy.valueOf(conflictPolicyStr.toUpperCase(Locale.ROOT)); } catch (Exception e) { conflictPolicy = ConflictPolicy.CONFLICT_REPLACE; } @@ -172,6 +173,7 @@ public void reAlloc() { * @return resultant {@link org.apache.arrow.vector.ValueVector} * @throws java.lang.IllegalStateException raised if there is a hard schema change */ + @Override public T addOrGet(String childName, FieldType fieldType, Class clazz) { final ValueVector existing = getChild(childName); boolean create = false; @@ -411,7 +413,7 @@ public int getBufferSize() { for (final ValueVector v : vectors.values()) { for (final ArrowBuf buf : v.getBuffers(false)) { - actualBufSize += buf.writerIndex(); + actualBufSize += (int) buf.writerIndex(); } } return actualBufSize; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java index 95deceb4e75ca..8ba2e48dc2fa3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java @@ -54,7 +54,7 @@ public abstract class BaseRepeatedValueVector extends BaseValueVector implements public static final byte OFFSET_WIDTH = 4; protected ArrowBuf offsetBuffer; protected FieldVector vector; - protected final CallBack callBack; + protected final CallBack repeatedCallBack; protected int valueCount; protected long offsetAllocationSizeInBytes = INITIAL_VALUE_ALLOCATION * OFFSET_WIDTH; private final String name; @@ -70,7 +70,7 @@ protected BaseRepeatedValueVector(String name, BufferAllocator allocator, FieldV this.name = name; this.offsetBuffer = allocator.getEmpty(); this.vector = Preconditions.checkNotNull(vector, "data vector cannot be null"); - this.callBack = callBack; + this.repeatedCallBack = callBack; this.valueCount = 0; } @@ -123,7 +123,7 @@ protected void reallocOffsetBuffer() { } newAllocationSize = CommonUtil.nextPowerOfTwo(newAllocationSize); - newAllocationSize = Math.min(newAllocationSize, (long) (OFFSET_WIDTH) * Integer.MAX_VALUE); + newAllocationSize = Math.min(newAllocationSize, (long) OFFSET_WIDTH * Integer.MAX_VALUE); assert newAllocationSize >= 1; if (newAllocationSize > MAX_ALLOCATION_SIZE || newAllocationSize <= offsetBuffer.capacity()) { @@ -157,7 +157,7 @@ public FieldVector getDataVector() { @Override public void setInitialCapacity(int numRecords) { - offsetAllocationSizeInBytes = (numRecords + 1) * OFFSET_WIDTH; + offsetAllocationSizeInBytes = (numRecords + 1L) * OFFSET_WIDTH; if (vector instanceof BaseFixedWidthVector || vector instanceof BaseVariableWidthVector) { vector.setInitialCapacity(numRecords * RepeatedValueVector.DEFAULT_REPEAT_PER_RECORD); } else { @@ -194,7 +194,7 @@ public void setInitialCapacity(int numRecords, double density) { throw new OversizedAllocationException("Requested amount of memory is more than max allowed"); } - offsetAllocationSizeInBytes = (numRecords + 1) * OFFSET_WIDTH; + offsetAllocationSizeInBytes = (numRecords + 1L) * OFFSET_WIDTH; int innerValueCapacity = Math.max((int) (numRecords * density), 1); @@ -222,7 +222,7 @@ public void setInitialCapacity(int numRecords, double density) { * for in this vector across all records. */ public void setInitialTotalCapacity(int numRecords, int totalNumberOfElements) { - offsetAllocationSizeInBytes = (numRecords + 1) * OFFSET_WIDTH; + offsetAllocationSizeInBytes = (numRecords + 1L) * OFFSET_WIDTH; vector.setInitialCapacity(totalNumberOfElements); } @@ -313,13 +313,13 @@ public int size() { public AddOrGetResult addOrGetVector(FieldType fieldType) { boolean created = false; if (vector instanceof NullVector) { - vector = fieldType.createNewSingleVector(defaultDataVectorName, allocator, callBack); + vector = fieldType.createNewSingleVector(defaultDataVectorName, allocator, repeatedCallBack); // returned vector must have the same field created = true; - if (callBack != null && + if (repeatedCallBack != null && // not a schema change if changing from ZeroVector to ZeroVector (fieldType.getType().getTypeID() != ArrowTypeID.Null)) { - callBack.doWork(); + repeatedCallBack.doWork(); } } @@ -355,6 +355,7 @@ public int getInnerValueCountAt(int index) { } /** Return if value at index is null (this implementation is always false). */ + @Override public boolean isNull(int index) { return false; } @@ -376,6 +377,7 @@ public int startNewValue(int index) { } /** Preallocates the number of repeated values. */ + @Override public void setValueCount(int valueCount) { this.valueCount = valueCount; while (valueCount > getOffsetBufferValueCapacity()) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java index 367335436aecd..48b53d7de2e3f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java @@ -234,11 +234,9 @@ public boolean allocateNewSafe() { } finally { if (!success) { clear(); - return false; } } - - return true; + return success; } private void allocateValidityBuffer(final long size) { @@ -257,12 +255,12 @@ public void reAlloc() { private void reallocValidityBuffer() { final int currentBufferCapacity = checkedCastToInt(validityBuffer.capacity()); - long newAllocationSize = currentBufferCapacity * 2; + long newAllocationSize = currentBufferCapacity * 2L; if (newAllocationSize == 0) { if (validityAllocationSizeInBytes > 0) { newAllocationSize = validityAllocationSizeInBytes; } else { - newAllocationSize = getValidityBufferSizeFromCount(INITIAL_VALUE_ALLOCATION) * 2; + newAllocationSize = getValidityBufferSizeFromCount(INITIAL_VALUE_ALLOCATION) * 2L; } } @@ -273,7 +271,7 @@ private void reallocValidityBuffer() { throw new OversizedAllocationException("Unable to expand the buffer"); } - final ArrowBuf newBuf = allocator.buffer((int) newAllocationSize); + final ArrowBuf newBuf = allocator.buffer(newAllocationSize); newBuf.setBytes(0, validityBuffer, 0, currentBufferCapacity); newBuf.setZero(currentBufferCapacity, newBuf.capacity() - currentBufferCapacity); validityBuffer.getReferenceManager().release(1); @@ -468,6 +466,7 @@ public List getObject(int index) { /** * Returns whether the value at index null. */ + @Override public boolean isNull(int index) { return (isSet(index) == 0); } @@ -503,6 +502,7 @@ private int getValidityBufferValueCapacity() { /** * Sets the value at index to null. Reallocates if index is larger than capacity. */ + @Override public void setNull(int index) { while (index >= getValidityBufferValueCapacity()) { reallocValidityBuffer(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java index 312bed6ab3349..b934cbd81db16 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java @@ -194,7 +194,7 @@ public void setInitialCapacity(int numRecords, double density) { throw new OversizedAllocationException("Requested amount of memory is more than max allowed"); } - offsetAllocationSizeInBytes = (numRecords + 1) * OFFSET_WIDTH; + offsetAllocationSizeInBytes = (numRecords + 1L) * OFFSET_WIDTH; int innerValueCapacity = Math.max((int) (numRecords * density), 1); @@ -222,7 +222,7 @@ public void setInitialCapacity(int numRecords, double density) { * for in this vector across all records. */ public void setInitialTotalCapacity(int numRecords, int totalNumberOfElements) { - offsetAllocationSizeInBytes = (numRecords + 1) * OFFSET_WIDTH; + offsetAllocationSizeInBytes = (numRecords + 1L) * OFFSET_WIDTH; vector.setInitialCapacity(totalNumberOfElements); } @@ -332,6 +332,7 @@ public void allocateNew() throws OutOfMemoryException { * * @return false if memory allocation fails, true otherwise. */ + @Override public boolean allocateNewSafe() { boolean success = false; try { @@ -347,7 +348,7 @@ public boolean allocateNewSafe() { } catch (Exception e) { e.printStackTrace(); clear(); - return false; + success = false; } finally { if (!dataAlloc) { clear(); @@ -357,10 +358,9 @@ public boolean allocateNewSafe() { } finally { if (!success) { clear(); - return false; } } - return true; + return success; } private void allocateValidityBuffer(final long size) { @@ -408,7 +408,7 @@ protected void reallocOffsetBuffer() { } newAllocationSize = CommonUtil.nextPowerOfTwo(newAllocationSize); - newAllocationSize = Math.min(newAllocationSize, (long) (OFFSET_WIDTH) * Integer.MAX_VALUE); + newAllocationSize = Math.min(newAllocationSize, (long) OFFSET_WIDTH * Integer.MAX_VALUE); assert newAllocationSize >= 1; if (newAllocationSize > MAX_ALLOCATION_SIZE || newAllocationSize <= offsetBuffer.capacity()) { @@ -425,12 +425,12 @@ protected void reallocOffsetBuffer() { private void reallocValidityBuffer() { final int currentBufferCapacity = checkedCastToInt(validityBuffer.capacity()); - long newAllocationSize = currentBufferCapacity * 2; + long newAllocationSize = currentBufferCapacity * 2L; if (newAllocationSize == 0) { if (validityAllocationSizeInBytes > 0) { newAllocationSize = validityAllocationSizeInBytes; } else { - newAllocationSize = getValidityBufferSizeFromCount(INITIAL_VALUE_ALLOCATION) * 2; + newAllocationSize = getValidityBufferSizeFromCount(INITIAL_VALUE_ALLOCATION) * 2L; } } newAllocationSize = CommonUtil.nextPowerOfTwo(newAllocationSize); @@ -440,7 +440,7 @@ private void reallocValidityBuffer() { throw new OversizedAllocationException("Unable to expand the buffer"); } - final ArrowBuf newBuf = allocator.buffer((int) newAllocationSize); + final ArrowBuf newBuf = allocator.buffer(newAllocationSize); newBuf.setBytes(0, validityBuffer, 0, currentBufferCapacity); newBuf.setZero(currentBufferCapacity, newBuf.capacity() - currentBufferCapacity); validityBuffer.getReferenceManager().release(1); @@ -526,7 +526,7 @@ public TransferPair makeTransferPair(ValueVector target) { @Override public long getValidityBufferAddress() { - return (validityBuffer.memoryAddress()); + return validityBuffer.memoryAddress(); } @Override @@ -536,7 +536,7 @@ public long getDataBufferAddress() { @Override public long getOffsetBufferAddress() { - return (offsetBuffer.memoryAddress()); + return offsetBuffer.memoryAddress(); } @Override @@ -754,6 +754,7 @@ public UnionLargeListReader getReader() { * Initialize the data vector (and execute callback) if it hasn't already been done, * returns the data vector. */ + @Override public AddOrGetResult addOrGetVector(FieldType fieldType) { boolean created = false; if (vector instanceof NullVector) { @@ -988,6 +989,7 @@ public void setNotNull(int index) { * Sets list at index to be null. * @param index position in vector */ + @Override public void setNull(int index) { while (index >= getValidityAndOffsetValueCapacity()) { reallocValidityAndOffsetBuffers(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index e5a83921b3135..5154ac17279c5 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -291,6 +291,7 @@ public void allocateNew() throws OutOfMemoryException { * * @return false if memory allocation fails, true otherwise. */ + @Override public boolean allocateNewSafe() { boolean success = false; try { @@ -303,10 +304,9 @@ public boolean allocateNewSafe() { } finally { if (!success) { clear(); - return false; } } - return true; + return success; } protected void allocateValidityBuffer(final long size) { @@ -336,12 +336,23 @@ protected void reallocValidityAndOffsetBuffers() { private void reallocValidityBuffer() { final int currentBufferCapacity = checkedCastToInt(validityBuffer.capacity()); - long newAllocationSize = currentBufferCapacity * 2; + long newAllocationSize = getNewAllocationSize(currentBufferCapacity); + + final ArrowBuf newBuf = allocator.buffer(newAllocationSize); + newBuf.setBytes(0, validityBuffer, 0, currentBufferCapacity); + newBuf.setZero(currentBufferCapacity, newBuf.capacity() - currentBufferCapacity); + validityBuffer.getReferenceManager().release(1); + validityBuffer = newBuf; + validityAllocationSizeInBytes = (int) newAllocationSize; + } + + private long getNewAllocationSize(int currentBufferCapacity) { + long newAllocationSize = currentBufferCapacity * 2L; if (newAllocationSize == 0) { if (validityAllocationSizeInBytes > 0) { newAllocationSize = validityAllocationSizeInBytes; } else { - newAllocationSize = getValidityBufferSizeFromCount(INITIAL_VALUE_ALLOCATION) * 2; + newAllocationSize = getValidityBufferSizeFromCount(INITIAL_VALUE_ALLOCATION) * 2L; } } newAllocationSize = CommonUtil.nextPowerOfTwo(newAllocationSize); @@ -350,13 +361,7 @@ private void reallocValidityBuffer() { if (newAllocationSize > MAX_ALLOCATION_SIZE) { throw new OversizedAllocationException("Unable to expand the buffer"); } - - final ArrowBuf newBuf = allocator.buffer((int) newAllocationSize); - newBuf.setBytes(0, validityBuffer, 0, currentBufferCapacity); - newBuf.setZero(currentBufferCapacity, newBuf.capacity() - currentBufferCapacity); - validityBuffer.getReferenceManager().release(1); - validityBuffer = newBuf; - validityAllocationSizeInBytes = (int) newAllocationSize; + return newAllocationSize; } /** @@ -425,7 +430,7 @@ public TransferPair makeTransferPair(ValueVector target) { @Override public long getValidityBufferAddress() { - return (validityBuffer.memoryAddress()); + return validityBuffer.memoryAddress(); } @Override @@ -435,7 +440,7 @@ public long getDataBufferAddress() { @Override public long getOffsetBufferAddress() { - return (offsetBuffer.memoryAddress()); + return offsetBuffer.memoryAddress(); } @Override @@ -625,6 +630,7 @@ public UnionListReader getReader() { } /** Initialize the child data vector to field type. */ + @Override public AddOrGetResult addOrGetVector(FieldType fieldType) { AddOrGetResult result = super.addOrGetVector(fieldType); invalidateReader(); @@ -837,6 +843,7 @@ public void setNotNull(int index) { * Sets list at index to be null. * @param index position in vector */ + @Override public void setNull(int index) { while (index >= getValidityAndOffsetValueCapacity()) { reallocValidityAndOffsetBuffers(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java index 27db1574808a3..9d0dc5ca3fd15 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java @@ -495,10 +495,9 @@ public boolean allocateNewSafe() { } finally { if (!success) { clear(); - return false; } } - return true; + return success; } private void allocateValidityBuffer(final long size) { @@ -518,12 +517,23 @@ public void reAlloc() { private void reallocValidityBuffer() { final int currentBufferCapacity = checkedCastToInt(validityBuffer.capacity()); - long newAllocationSize = currentBufferCapacity * 2; + long newAllocationSize = getNewAllocationSize(currentBufferCapacity); + + final ArrowBuf newBuf = allocator.buffer(newAllocationSize); + newBuf.setBytes(0, validityBuffer, 0, currentBufferCapacity); + newBuf.setZero(currentBufferCapacity, newBuf.capacity() - currentBufferCapacity); + validityBuffer.getReferenceManager().release(1); + validityBuffer = newBuf; + validityAllocationSizeInBytes = (int) newAllocationSize; + } + + private long getNewAllocationSize(int currentBufferCapacity) { + long newAllocationSize = currentBufferCapacity * 2L; if (newAllocationSize == 0) { if (validityAllocationSizeInBytes > 0) { newAllocationSize = validityAllocationSizeInBytes; } else { - newAllocationSize = BitVectorHelper.getValidityBufferSize(BaseValueVector.INITIAL_VALUE_ALLOCATION) * 2; + newAllocationSize = BitVectorHelper.getValidityBufferSize(BaseValueVector.INITIAL_VALUE_ALLOCATION) * 2L; } } newAllocationSize = CommonUtil.nextPowerOfTwo(newAllocationSize); @@ -532,13 +542,7 @@ private void reallocValidityBuffer() { if (newAllocationSize > BaseValueVector.MAX_ALLOCATION_SIZE) { throw new OversizedAllocationException("Unable to expand the buffer"); } - - final ArrowBuf newBuf = allocator.buffer((int) newAllocationSize); - newBuf.setBytes(0, validityBuffer, 0, currentBufferCapacity); - newBuf.setZero(currentBufferCapacity, newBuf.capacity() - currentBufferCapacity); - validityBuffer.getReferenceManager().release(1); - validityBuffer = newBuf; - validityAllocationSizeInBytes = (int) newAllocationSize; + return newAllocationSize; } @Override @@ -607,6 +611,7 @@ public void get(int index, ComplexHolder holder) { /** * Return the number of null values in the vector. */ + @Override public int getNullCount() { return BitVectorHelper.getNullCount(validityBuffer, valueCount); } @@ -614,6 +619,7 @@ public int getNullCount() { /** * Returns true if the value at the provided index is null. */ + @Override public boolean isNull(int index) { return isSet(index) == 0; } @@ -643,6 +649,7 @@ public void setIndexDefined(int index) { /** * Marks the value at index as null/not set. */ + @Override public void setNull(int index) { while (index >= getValidityBufferValueCapacity()) { /* realloc the inner buffers if needed */ diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java index c80fcb89d0cc9..028901ee847da 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java @@ -46,6 +46,7 @@ public int getPosition() { return index; } + @Override public void setPosition(int index) { this.index = index; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java index f7be277f592a6..7f724829ef1eb 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java @@ -19,6 +19,7 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; +import java.util.Locale; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.vector.FieldVector; @@ -301,13 +302,15 @@ public boolean isEmptyStruct() { return writer.isEmptyStruct(); } + @Override protected FieldWriter getWriter() { return writer; } private FieldWriter promoteToUnion() { String name = vector.getField().getName(); - TransferPair tp = vector.getTransferPair(vector.getMinorType().name().toLowerCase(), vector.getAllocator()); + TransferPair tp = vector.getTransferPair(vector.getMinorType().name().toLowerCase(Locale.ROOT), + vector.getAllocator()); tp.transfer(); if (parentContainer != null) { // TODO allow dictionaries in complex types diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/StructOrListWriterImpl.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/StructOrListWriterImpl.java index 5c4cd2af98d55..6a217bbc8b547 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/StructOrListWriterImpl.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/StructOrListWriterImpl.java @@ -88,8 +88,9 @@ public StructOrListWriter struct(final String name) { * * @param name Unused. * - * @deprecated use {@link #listOfStruct()} instead. + * @deprecated use {@link #listOfStruct(String)} instead. */ + @Deprecated public StructOrListWriter listoftstruct(final String name) { return listOfStruct(name); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionFixedSizeListReader.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionFixedSizeListReader.java index ece729ae563af..f69fea3bd5779 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionFixedSizeListReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionFixedSizeListReader.java @@ -99,6 +99,7 @@ public boolean next() { } } + @Override public void copyAsValue(ListWriter writer) { ComplexCopier.copy(this, (FieldWriter) writer); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java index faf088b55981d..0f3ba50f2b3a1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java @@ -34,7 +34,6 @@ public class UnionLargeListReader extends AbstractFieldReader { private LargeListVector vector; private ValueVector data; - private long index; private static final long OFFSET_WIDTH = 8L; public UnionLargeListReader(LargeListVector vector) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionListReader.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionListReader.java index 74548bc985f6a..7dadcabdcee88 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionListReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionListReader.java @@ -60,8 +60,8 @@ public void setPosition(int index) { currentOffset = 0; maxOffset = 0; } else { - currentOffset = vector.getOffsetBuffer().getInt(index * OFFSET_WIDTH) - 1; - maxOffset = vector.getOffsetBuffer().getInt((index + 1) * OFFSET_WIDTH); + currentOffset = vector.getOffsetBuffer().getInt(index * (long) OFFSET_WIDTH) - 1; + maxOffset = vector.getOffsetBuffer().getInt((index + 1) * (long) OFFSET_WIDTH); } } @@ -106,6 +106,7 @@ public boolean next() { } } + @Override public void copyAsValue(ListWriter writer) { ComplexCopier.copy(this, (FieldWriter) writer); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java index 6f40e5814b972..5687e4025acee 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java @@ -60,7 +60,7 @@ public boolean equals(Object o) { if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { + if (!(o instanceof Dictionary)) { return false; } Dictionary that = (Dictionary) o; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/table/package-info.java b/java/vector/src/main/java/org/apache/arrow/vector/table/package-info.java index cdd5093b9f554..b11ada51292f9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/table/package-info.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/table/package-info.java @@ -17,7 +17,7 @@ package org.apache.arrow.vector.table; -/** +/* * Support for Table, an immutable, columnar, tabular data structure based on FieldVectors. * See the Arrow Java documentation for details: Table */ diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/FloatingPointPrecision.java b/java/vector/src/main/java/org/apache/arrow/vector/types/FloatingPointPrecision.java index c52fc1243d99f..85c2532236866 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/FloatingPointPrecision.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/FloatingPointPrecision.java @@ -39,7 +39,7 @@ public enum FloatingPointPrecision { } } - private short flatbufID; + private final short flatbufID; private FloatingPointPrecision(short flatbufID) { this.flatbufID = flatbufID; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/IntervalUnit.java b/java/vector/src/main/java/org/apache/arrow/vector/types/IntervalUnit.java index 1b17240d016b3..d2314ea7cce3c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/IntervalUnit.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/IntervalUnit.java @@ -36,7 +36,7 @@ public enum IntervalUnit { } } - private short flatbufID; + private final short flatbufID; private IntervalUnit(short flatbufID) { this.flatbufID = flatbufID; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java index 8d41b92d867e9..592e18826f09c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java @@ -74,7 +74,7 @@ public String toString() { public boolean equals(Object o) { if (this == o) { return true; - } else if (o == null || getClass() != o.getClass()) { + } else if (!(o instanceof DictionaryEncoding)) { return false; } DictionaryEncoding that = (DictionaryEncoding) o; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java index 54c609d4a104f..d3623618e7a55 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java @@ -37,7 +37,6 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.Collections2; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.TypeLayout; import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -93,16 +92,19 @@ private Field( this(name, new FieldType(nullable, type, dictionary, convertMetadata(metadata)), children); } - private Field(String name, FieldType fieldType, List children, TypeLayout typeLayout) { + /** + * Constructs a new Field object. + * + * @param name name of the field + * @param fieldType type of the field + * @param children child fields, if any + */ + public Field(String name, FieldType fieldType, List children) { this.name = name; this.fieldType = checkNotNull(fieldType); this.children = children == null ? Collections.emptyList() : Collections2.toImmutableList(children); } - public Field(String name, FieldType fieldType, List children) { - this(name, fieldType, children, fieldType == null ? null : TypeLayout.getTypeLayout(fieldType.getType())); - } - /** * Construct a new vector of this type using the given allocator. */ @@ -279,7 +281,7 @@ public boolean equals(Object obj) { } Field that = (Field) obj; return Objects.equals(this.name, that.name) && - Objects.equals(this.isNullable(), that.isNullable()) && + this.isNullable() == that.isNullable() && Objects.equals(this.getType(), that.getType()) && Objects.equals(this.getDictionary(), that.getDictionary()) && Objects.equals(this.getMetadata(), that.getMetadata()) && diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java index d5c0d85671fcc..8988993920d79 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java @@ -118,7 +118,7 @@ public boolean equals(Object obj) { return false; } FieldType that = (FieldType) obj; - return Objects.equals(this.isNullable(), that.isNullable()) && + return this.isNullable() == that.isNullable() && Objects.equals(this.getType(), that.getType()) && Objects.equals(this.getDictionary(), that.getDictionary()) && Objects.equals(this.getMetadata(), that.getMetadata()); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Schema.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Schema.java index dcffea0ef5367..392b3c2e2ec73 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Schema.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Schema.java @@ -159,10 +159,10 @@ private Schema(@JsonProperty("fields") Iterable fields, /** * Private constructor to bypass automatic collection copy. - * @param unsafe a ignored argument. Its only purpose is to prevent using the constructor + * @param ignored an ignored argument. Its only purpose is to prevent using the constructor * by accident because of type collisions (List vs Iterable). */ - private Schema(boolean unsafe, List fields, Map metadata) { + private Schema(boolean ignored, List fields, Map metadata) { this.fields = fields; this.metadata = metadata; } @@ -245,13 +245,12 @@ public int getSchema(FlatBufferBuilder builder) { /** * Returns the serialized flatbuffer bytes of the schema wrapped in a message table. - * Use {@link #deserializeMessage() to rebuild the Schema.} + * Use {@link #deserializeMessage(ByteBuffer)} to rebuild the Schema. */ public byte[] serializeAsMessage() { ByteArrayOutputStream out = new ByteArrayOutputStream(); try (WriteChannel channel = new WriteChannel(Channels.newChannel(out))) { - long size = MessageSerializer.serialize( - new WriteChannel(Channels.newChannel(out)), this); + MessageSerializer.serialize(channel, this); return out.toByteArray(); } catch (IOException ex) { throw new RuntimeException(ex); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/MapWithOrdinalImpl.java b/java/vector/src/main/java/org/apache/arrow/vector/util/MapWithOrdinalImpl.java index 7c9c0e9408860..1f18587afdfd1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/MapWithOrdinalImpl.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/MapWithOrdinalImpl.java @@ -27,8 +27,6 @@ import java.util.stream.Collectors; import org.eclipse.collections.impl.map.mutable.primitive.IntObjectHashMap; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * An implementation of map that supports constant time look-up by a generic key or an ordinal. @@ -48,7 +46,6 @@ * @param value type */ public class MapWithOrdinalImpl implements MapWithOrdinal { - private static final Logger logger = LoggerFactory.getLogger(MapWithOrdinalImpl.class); private final Map> primary = new LinkedHashMap<>(); private final IntObjectHashMap secondary = new IntObjectHashMap<>(); @@ -93,10 +90,6 @@ public V put(K key, V value) { return oldPair == null ? null : oldPair.getValue(); } - public boolean put(K key, V value, boolean override) { - return put(key, value) != null; - } - @Override public V remove(Object key) { final Entry oldPair = primary.remove(key); @@ -146,6 +139,7 @@ public Set> entrySet() { * @param id ordinal value for lookup * @return an instance of V */ + @Override public V getByOrdinal(int id) { return secondary.get(id); } @@ -156,6 +150,7 @@ public V getByOrdinal(int id) { * @param key key for ordinal lookup * @return ordinal value corresponding to key if it exists or -1 */ + @Override public int getOrdinal(K key) { Map.Entry pair = primary.get(key); if (pair != null) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/Text.java b/java/vector/src/main/java/org/apache/arrow/vector/util/Text.java index 5f5f5d3bd6d22..95e35ce6938c3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/Text.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/Text.java @@ -210,7 +210,7 @@ public void set(String string) { } /** - * Set to a utf8 byte array. + * Set to an utf8 byte array. * * @param utf8 the byte array to initialize from */ diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ITTestLargeVector.java b/java/vector/src/test/java/org/apache/arrow/vector/ITTestLargeVector.java index 19648dc9e13fb..8596399e7e08c 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ITTestLargeVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ITTestLargeVector.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; @@ -133,7 +134,7 @@ public void testLargeDecimalVector() { for (int i = 0; i < vecLength; i++) { ArrowBuf buf = largeVec.get(i); - assertEquals(buf.capacity(), DecimalVector.TYPE_WIDTH); + assertEquals(DecimalVector.TYPE_WIDTH, buf.capacity()); assertEquals(0, buf.getLong(0)); assertEquals(0, buf.getLong(8)); @@ -215,7 +216,7 @@ public void testLargeVarCharVector() { logger.trace("Successfully allocated a vector with capacity " + vecLength); for (int i = 0; i < vecLength; i++) { - largeVec.setSafe(i, strElement.getBytes()); + largeVec.setSafe(i, strElement.getBytes(StandardCharsets.UTF_8)); if ((i + 1) % 10000 == 0) { logger.trace("Successfully written " + (i + 1) + " values"); @@ -228,7 +229,7 @@ public void testLargeVarCharVector() { for (int i = 0; i < vecLength; i++) { byte[] val = largeVec.get(i); - assertEquals(strElement, new String(val)); + assertEquals(strElement, new String(val, StandardCharsets.UTF_8)); if ((i + 1) % 10000 == 0) { logger.trace("Successfully read " + (i + 1) + " values"); @@ -254,7 +255,7 @@ public void testLargeLargeVarCharVector() { logger.trace("Successfully allocated a vector with capacity " + vecLength); for (int i = 0; i < vecLength; i++) { - largeVec.setSafe(i, strElement.getBytes()); + largeVec.setSafe(i, strElement.getBytes(StandardCharsets.UTF_8)); if ((i + 1) % 10000 == 0) { logger.trace("Successfully written " + (i + 1) + " values"); @@ -267,7 +268,7 @@ public void testLargeLargeVarCharVector() { for (int i = 0; i < vecLength; i++) { byte[] val = largeVec.get(i); - assertEquals(strElement, new String(val)); + assertEquals(strElement, new String(val, StandardCharsets.UTF_8)); if ((i + 1) % 10000 == 0) { logger.trace("Successfully read " + (i + 1) + " values"); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java b/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java index 96005dc511cab..1da4a4c4914b9 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestBitVectorHelper.java @@ -47,7 +47,7 @@ public void testGetNullCount() throws Exception { validityBuffer.setByte(0, 0xFF); count = BitVectorHelper.getNullCount(validityBuffer, 8); - assertEquals(count, 0); + assertEquals(0, count); validityBuffer.close(); // test case 3, 1 null value for 0x7F @@ -55,7 +55,7 @@ public void testGetNullCount() throws Exception { validityBuffer.setByte(0, 0x7F); count = BitVectorHelper.getNullCount(validityBuffer, 8); - assertEquals(count, 1); + assertEquals(1, count); validityBuffer.close(); // test case 4, validity buffer has multiple bytes, 11 items @@ -64,7 +64,7 @@ public void testGetNullCount() throws Exception { validityBuffer.setByte(1, 0b01010101); count = BitVectorHelper.getNullCount(validityBuffer, 11); - assertEquals(count, 5); + assertEquals(5, count); validityBuffer.close(); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestBufferOwnershipTransfer.java b/java/vector/src/test/java/org/apache/arrow/vector/TestBufferOwnershipTransfer.java index 8efadad9b3bf4..056b6bdd2b787 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestBufferOwnershipTransfer.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestBufferOwnershipTransfer.java @@ -21,6 +21,8 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import java.nio.charset.StandardCharsets; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.ReferenceManager; import org.apache.arrow.memory.RootAllocator; @@ -65,7 +67,7 @@ public void testTransferVariableWidth() { VarCharVector v1 = new VarCharVector("v1", childAllocator1); v1.allocateNew(); - v1.setSafe(4094, "hello world".getBytes(), 0, 11); + v1.setSafe(4094, "hello world".getBytes(StandardCharsets.UTF_8), 0, 11); v1.setValueCount(4001); VarCharVector v2 = new VarCharVector("v2", childAllocator2); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestCopyFrom.java b/java/vector/src/test/java/org/apache/arrow/vector/TestCopyFrom.java index 3786f63c31bb6..97de27bec8237 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestCopyFrom.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestCopyFrom.java @@ -23,9 +23,10 @@ import static org.junit.Assert.assertNull; import java.math.BigDecimal; -import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Period; +import java.util.Objects; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -84,7 +85,7 @@ public void testCopyFromWithNulls() { if (i % 3 == 0) { continue; } - byte[] b = Integer.toString(i).getBytes(); + byte[] b = Integer.toString(i).getBytes(StandardCharsets.UTF_8); vector.setSafe(i, b, 0, b.length); } @@ -156,7 +157,7 @@ public void testCopyFromWithNulls1() { if (i % 3 == 0) { continue; } - byte[] b = Integer.toString(i).getBytes(); + byte[] b = Integer.toString(i).getBytes(StandardCharsets.UTF_8); vector.setSafe(i, b, 0, b.length); } @@ -950,7 +951,7 @@ public void testCopyFromWithNulls13() { assertEquals(0, vector1.getValueCount()); int initialCapacity = vector1.getValueCapacity(); - final double baseValue = 104567897654.876543654; + final double baseValue = 104567897654.87654; final BigDecimal[] decimals = new BigDecimal[4096]; for (int i = 0; i < initialCapacity; i++) { if ((i & 1) == 0) { @@ -1082,13 +1083,13 @@ public void testCopySafeArrow7837() { // to trigger a reallocation of the vector. vc2.setInitialCapacity(/*valueCount*/20, /*density*/0.5); - vc1.setSafe(0, "1234567890".getBytes(Charset.forName("utf-8"))); + vc1.setSafe(0, "1234567890".getBytes(StandardCharsets.UTF_8)); assertFalse(vc1.isNull(0)); - assertEquals(vc1.getObject(0).toString(), "1234567890"); + assertEquals("1234567890", Objects.requireNonNull(vc1.getObject(0)).toString()); vc2.copyFromSafe(0, 0, vc1); assertFalse(vc2.isNull(0)); - assertEquals(vc2.getObject(0).toString(), "1234567890"); + assertEquals("1234567890", Objects.requireNonNull(vc2.getObject(0)).toString()); vc2.copyFromSafe(0, 5, vc1); assertTrue(vc2.isNull(1)); @@ -1096,7 +1097,7 @@ public void testCopySafeArrow7837() { assertTrue(vc2.isNull(3)); assertTrue(vc2.isNull(4)); assertFalse(vc2.isNull(5)); - assertEquals(vc2.getObject(5).toString(), "1234567890"); + assertEquals("1234567890", Objects.requireNonNull(vc2.getObject(5)).toString()); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimal256Vector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimal256Vector.java index b703959d2bb1e..fc5dfc38587a4 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimal256Vector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimal256Vector.java @@ -40,8 +40,8 @@ public class TestDecimal256Vector { static { intValues = new long[60]; for (int i = 0; i < intValues.length / 2; i++) { - intValues[i] = 1 << i + 1; - intValues[2 * i] = -1 * (1 << i + 1); + intValues[i] = 1L << (i + 1); + intValues[2 * i] = -1L * (1 << (i + 1)); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java index ba25cbe8b52a0..572f13fea1ed1 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java @@ -40,8 +40,8 @@ public class TestDecimalVector { static { intValues = new long[60]; for (int i = 0; i < intValues.length / 2; i++) { - intValues[i] = 1 << i + 1; - intValues[2 * i] = -1 * (1 << i + 1); + intValues[i] = 1L << (i + 1); + intValues[2 * i] = -1L * (1 << (i + 1)); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDenseUnionVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDenseUnionVector.java index 9cb12481612b2..8fd33eb5a8432 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDenseUnionVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDenseUnionVector.java @@ -349,16 +349,16 @@ public void testGetFieldTypeInfo() throws Exception { assertEquals(vector.getField(), field); // Union has 2 child vectors - assertEquals(vector.size(), 2); + assertEquals(2, vector.size()); // Check child field 0 VectorWithOrdinal intChild = vector.getChildVectorWithOrdinal("int"); - assertEquals(intChild.ordinal, 0); + assertEquals(0, intChild.ordinal); assertEquals(intChild.vector.getField(), children.get(0)); // Check child field 1 VectorWithOrdinal varcharChild = vector.getChildVectorWithOrdinal("varchar"); - assertEquals(varcharChild.ordinal, 1); + assertEquals(1, varcharChild.ordinal); assertEquals(varcharChild.vector.getField(), children.get(1)); } @@ -458,8 +458,8 @@ public void testMultipleStructs() { // register relative types byte typeId1 = unionVector.registerNewTypeId(structVector1.getField()); byte typeId2 = unionVector.registerNewTypeId(structVector2.getField()); - assertEquals(typeId1, 0); - assertEquals(typeId2, 1); + assertEquals(0, typeId1); + assertEquals(1, typeId2); // add two struct vectors to union vector unionVector.addVector(typeId1, structVector1); @@ -519,8 +519,8 @@ public void testMultipleVarChars() { byte typeId1 = unionVector.registerNewTypeId(childVector1.getField()); byte typeId2 = unionVector.registerNewTypeId(childVector2.getField()); - assertEquals(typeId1, 0); - assertEquals(typeId2, 1); + assertEquals(0, typeId1); + assertEquals(1, typeId2); while (unionVector.getValueCapacity() < 5) { unionVector.reAlloc(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java index 501059733c616..9ffa79470eeb8 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java @@ -552,7 +552,7 @@ public void testEncodeWithEncoderInstance() { // now run through the decoder and verify we get the original back try (ValueVector decoded = encoder.decode(encoded)) { assertEquals(vector.getClass(), decoded.getClass()); - assertEquals(vector.getValueCount(), (decoded).getValueCount()); + assertEquals(vector.getValueCount(), decoded.getValueCount()); for (int i = 0; i < 5; i++) { assertEquals(vector.getObject(i), ((VarCharVector) decoded).getObject(i)); } @@ -591,7 +591,7 @@ public void testEncodeMultiVectors() { // now run through the decoder and verify we get the original back try (ValueVector decoded = encoder.decode(encoded)) { assertEquals(vector1.getClass(), decoded.getClass()); - assertEquals(vector1.getValueCount(), (decoded).getValueCount()); + assertEquals(vector1.getValueCount(), decoded.getValueCount()); for (int i = 0; i < 5; i++) { assertEquals(vector1.getObject(i), ((VarCharVector) decoded).getObject(i)); } @@ -611,7 +611,7 @@ public void testEncodeMultiVectors() { // now run through the decoder and verify we get the original back try (ValueVector decoded = encoder.decode(encoded)) { assertEquals(vector2.getClass(), decoded.getClass()); - assertEquals(vector2.getValueCount(), (decoded).getValueCount()); + assertEquals(vector2.getValueCount(), decoded.getValueCount()); for (int i = 0; i < 3; i++) { assertEquals(vector2.getObject(i), ((VarCharVector) decoded).getObject(i)); } @@ -841,7 +841,8 @@ public void testEncodeStructSubFieldWithCertainColumns() { // initialize dictionaries DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); - setVector(dictVector1, "aa".getBytes(), "bb".getBytes(), "cc".getBytes(), "dd".getBytes()); + setVector(dictVector1, "aa".getBytes(StandardCharsets.UTF_8), "bb".getBytes(StandardCharsets.UTF_8), + "cc".getBytes(StandardCharsets.UTF_8), "dd".getBytes(StandardCharsets.UTF_8)); provider.put(new Dictionary(dictVector1, new DictionaryEncoding(1L, false, null))); StructSubfieldEncoder encoder = new StructSubfieldEncoder(allocator, provider); @@ -1049,20 +1050,20 @@ private void testDictionary(Dictionary dictionary, ToIntBiFunction ((UInt2Vector) vector).get(index)); } } @@ -1096,7 +1097,7 @@ public void testDictionaryUInt4() { setVector(dictionaryVector, "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"); Dictionary dictionary4 = new Dictionary(dictionaryVector, new DictionaryEncoding(/*id=*/30L, /*ordered=*/false, - /*indexType=*/new ArrowType.Int(/*indexType=*/32, /*isSigned*/false))); + /*indexType=*/new ArrowType.Int(/*bitWidth=*/32, /*isSigned*/false))); testDictionary(dictionary4, (vector, index) -> ((UInt4Vector) vector).get(index)); } } @@ -1107,7 +1108,7 @@ public void testDictionaryUInt8() { setVector(dictionaryVector, "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"); Dictionary dictionary8 = new Dictionary(dictionaryVector, new DictionaryEncoding(/*id=*/40L, /*ordered=*/false, - /*indexType=*/new ArrowType.Int(/*indexType=*/64, /*isSigned*/false))); + /*indexType=*/new ArrowType.Int(/*bitWidth=*/64, /*isSigned*/false))); testDictionary(dictionary8, (vector, index) -> (int) ((UInt8Vector) vector).get(index)); } } @@ -1119,13 +1120,13 @@ public void testDictionaryUIntOverflow() { try (VarCharVector dictionaryVector = new VarCharVector("dict vector", allocator)) { dictionaryVector.allocateNew(vecLength * 3, vecLength); for (int i = 0; i < vecLength; i++) { - dictionaryVector.set(i, String.valueOf(i).getBytes()); + dictionaryVector.set(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } dictionaryVector.setValueCount(vecLength); Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(/*id=*/10L, /*ordered=*/false, - /*indexType=*/new ArrowType.Int(/*indexType=*/8, /*isSigned*/false))); + /*indexType=*/new ArrowType.Int(/*bitWidth=*/8, /*isSigned*/false))); try (VarCharVector vector = new VarCharVector("vector", allocator)) { setVector(vector, "255"); @@ -1137,7 +1138,7 @@ public void testDictionaryUIntOverflow() { try (VarCharVector decodedVector = (VarCharVector) DictionaryEncoder.decode(encodedVector, dictionary)) { assertEquals(1, decodedVector.getValueCount()); - assertArrayEquals("255".getBytes(), decodedVector.get(0)); + assertArrayEquals("255".getBytes(StandardCharsets.UTF_8), decodedVector.get(0)); } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestFixedSizeListVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestFixedSizeListVector.java index 0023b1dddb8e7..bde6dd491dd71 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestFixedSizeListVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestFixedSizeListVector.java @@ -25,6 +25,7 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; @@ -61,7 +62,7 @@ public void terminate() throws Exception { @Test public void testIntType() { - try (FixedSizeListVector vector = FixedSizeListVector.empty("list", 2, allocator)) { + try (FixedSizeListVector vector = FixedSizeListVector.empty("list", /*size=*/2, allocator)) { IntVector nested = (IntVector) vector.addOrGetVector(FieldType.nullable(MinorType.INT.getType())).getVector(); vector.allocateNew(); @@ -88,7 +89,7 @@ public void testIntType() { @Test public void testFloatTypeNullable() { - try (FixedSizeListVector vector = FixedSizeListVector.empty("list", 2, allocator)) { + try (FixedSizeListVector vector = FixedSizeListVector.empty("list", /*size=*/2, allocator)) { Float4Vector nested = (Float4Vector) vector.addOrGetVector(FieldType.nullable(MinorType.FLOAT4.getType())) .getVector(); vector.allocateNew(); @@ -235,7 +236,7 @@ public void testTransferPair() { @Test public void testConsistentChildName() throws Exception { - try (FixedSizeListVector listVector = FixedSizeListVector.empty("sourceVector", 2, allocator)) { + try (FixedSizeListVector listVector = FixedSizeListVector.empty("sourceVector", /*size=*/2, allocator)) { String emptyListStr = listVector.getField().toString(); Assert.assertTrue(emptyListStr.contains(ListVector.DATA_VECTOR_NAME)); @@ -251,7 +252,7 @@ public void testUnionFixedSizeListWriterWithNulls() throws Exception { * each list of size 3 and having its data values alternating between null and a non-null. * Read and verify */ - try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*listSize=*/3, allocator)) { + try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*size=*/3, allocator)) { UnionFixedSizeListWriter writer = vector.getWriter(); writer.allocate(); @@ -279,7 +280,7 @@ public void testUnionFixedSizeListWriterWithNulls() throws Exception { @Test public void testUnionFixedSizeListWriter() throws Exception { - try (final FixedSizeListVector vector1 = FixedSizeListVector.empty("vector", 3, allocator)) { + try (final FixedSizeListVector vector1 = FixedSizeListVector.empty("vector", /*size=*/3, allocator)) { UnionFixedSizeListWriter writer1 = vector1.getWriter(); writer1.allocate(); @@ -307,7 +308,7 @@ public void testUnionFixedSizeListWriter() throws Exception { @Test public void testWriteDecimal() throws Exception { - try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*listSize=*/3, allocator)) { + try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*size=*/3, allocator)) { UnionFixedSizeListWriter writer = vector.getWriter(); writer.allocate(); @@ -335,7 +336,7 @@ public void testWriteDecimal() throws Exception { @Test public void testDecimalIndexCheck() throws Exception { - try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*listSize=*/3, allocator)) { + try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*size=*/3, allocator)) { UnionFixedSizeListWriter writer = vector.getWriter(); writer.allocate(); @@ -355,7 +356,7 @@ public void testDecimalIndexCheck() throws Exception { @Test(expected = IllegalStateException.class) public void testWriteIllegalData() throws Exception { - try (final FixedSizeListVector vector1 = FixedSizeListVector.empty("vector", 3, allocator)) { + try (final FixedSizeListVector vector1 = FixedSizeListVector.empty("vector", /*size=*/3, allocator)) { UnionFixedSizeListWriter writer1 = vector1.getWriter(); writer1.allocate(); @@ -378,7 +379,7 @@ public void testWriteIllegalData() throws Exception { @Test public void testSplitAndTransfer() throws Exception { - try (final FixedSizeListVector vector1 = FixedSizeListVector.empty("vector", 3, allocator)) { + try (final FixedSizeListVector vector1 = FixedSizeListVector.empty("vector", /*size=*/3, allocator)) { UnionFixedSizeListWriter writer1 = vector1.getWriter(); writer1.allocate(); @@ -399,9 +400,9 @@ public void testSplitAndTransfer() throws Exception { assertEquals(2, targetVector.getValueCount()); int[] realValue1 = convertListToIntArray(targetVector.getObject(0)); - assertTrue(Arrays.equals(values1, realValue1)); + assertArrayEquals(values1, realValue1); int[] realValue2 = convertListToIntArray(targetVector.getObject(1)); - assertTrue(Arrays.equals(values2, realValue2)); + assertArrayEquals(values2, realValue2); targetVector.clear(); } @@ -409,7 +410,7 @@ public void testSplitAndTransfer() throws Exception { @Test public void testZeroWidthVector() { - try (final FixedSizeListVector vector1 = FixedSizeListVector.empty("vector", 0, allocator)) { + try (final FixedSizeListVector vector1 = FixedSizeListVector.empty("vector", /*size=*/0, allocator)) { UnionFixedSizeListWriter writer1 = vector1.getWriter(); writer1.allocate(); @@ -440,7 +441,7 @@ public void testZeroWidthVector() { @Test public void testVectorWithNulls() { - try (final FixedSizeListVector vector1 = FixedSizeListVector.empty("vector", 4, allocator)) { + try (final FixedSizeListVector vector1 = FixedSizeListVector.empty("vector", /*size=*/4, allocator)) { UnionFixedSizeListWriter writer1 = vector1.getWriter(); writer1.allocate(); @@ -472,7 +473,7 @@ public void testVectorWithNulls() { @Test public void testWriteVarCharHelpers() throws Exception { - try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*listSize=*/4, allocator)) { + try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*size=*/4, allocator)) { UnionFixedSizeListWriter writer = vector.getWriter(); writer.allocate(); @@ -491,7 +492,7 @@ public void testWriteVarCharHelpers() throws Exception { @Test public void testWriteLargeVarCharHelpers() throws Exception { - try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*listSize=*/4, allocator)) { + try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*size=*/4, allocator)) { UnionFixedSizeListWriter writer = vector.getWriter(); writer.allocate(); @@ -510,43 +511,47 @@ public void testWriteLargeVarCharHelpers() throws Exception { @Test public void testWriteVarBinaryHelpers() throws Exception { - try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*listSize=*/4, allocator)) { + try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*size=*/4, allocator)) { UnionFixedSizeListWriter writer = vector.getWriter(); writer.allocate(); writer.startList(); - writer.writeVarBinary("row1,1".getBytes()); - writer.writeVarBinary("row1,2".getBytes(), 0, "row1,2".getBytes().length); - writer.writeVarBinary(ByteBuffer.wrap("row1,3".getBytes())); - writer.writeVarBinary(ByteBuffer.wrap("row1,4".getBytes()), 0, "row1,4".getBytes().length); + writer.writeVarBinary("row1,1".getBytes(StandardCharsets.UTF_8)); + writer.writeVarBinary("row1,2".getBytes(StandardCharsets.UTF_8), 0, + "row1,2".getBytes(StandardCharsets.UTF_8).length); + writer.writeVarBinary(ByteBuffer.wrap("row1,3".getBytes(StandardCharsets.UTF_8))); + writer.writeVarBinary(ByteBuffer.wrap("row1,4".getBytes(StandardCharsets.UTF_8)), 0, + "row1,4".getBytes(StandardCharsets.UTF_8).length); writer.endList(); - assertEquals("row1,1", new String((byte[]) (vector.getObject(0).get(0)))); - assertEquals("row1,2", new String((byte[]) (vector.getObject(0).get(1)))); - assertEquals("row1,3", new String((byte[]) (vector.getObject(0).get(2)))); - assertEquals("row1,4", new String((byte[]) (vector.getObject(0).get(3)))); + assertEquals("row1,1", new String((byte[]) vector.getObject(0).get(0), StandardCharsets.UTF_8)); + assertEquals("row1,2", new String((byte[]) vector.getObject(0).get(1), StandardCharsets.UTF_8)); + assertEquals("row1,3", new String((byte[]) vector.getObject(0).get(2), StandardCharsets.UTF_8)); + assertEquals("row1,4", new String((byte[]) vector.getObject(0).get(3), StandardCharsets.UTF_8)); } } @Test public void testWriteLargeVarBinaryHelpers() throws Exception { - try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*listSize=*/4, allocator)) { + try (final FixedSizeListVector vector = FixedSizeListVector.empty("vector", /*size=*/4, allocator)) { UnionFixedSizeListWriter writer = vector.getWriter(); writer.allocate(); writer.startList(); - writer.writeLargeVarBinary("row1,1".getBytes()); - writer.writeLargeVarBinary("row1,2".getBytes(), 0, "row1,2".getBytes().length); - writer.writeLargeVarBinary(ByteBuffer.wrap("row1,3".getBytes())); - writer.writeLargeVarBinary(ByteBuffer.wrap("row1,4".getBytes()), 0, "row1,4".getBytes().length); + writer.writeLargeVarBinary("row1,1".getBytes(StandardCharsets.UTF_8)); + writer.writeLargeVarBinary("row1,2".getBytes(StandardCharsets.UTF_8), 0, + "row1,2".getBytes(StandardCharsets.UTF_8).length); + writer.writeLargeVarBinary(ByteBuffer.wrap("row1,3".getBytes(StandardCharsets.UTF_8))); + writer.writeLargeVarBinary(ByteBuffer.wrap("row1,4".getBytes(StandardCharsets.UTF_8)), 0, + "row1,4".getBytes(StandardCharsets.UTF_8).length); writer.endList(); - assertEquals("row1,1", new String((byte[]) (vector.getObject(0).get(0)))); - assertEquals("row1,2", new String((byte[]) (vector.getObject(0).get(1)))); - assertEquals("row1,3", new String((byte[]) (vector.getObject(0).get(2)))); - assertEquals("row1,4", new String((byte[]) (vector.getObject(0).get(3)))); + assertEquals("row1,1", new String((byte[]) vector.getObject(0).get(0), StandardCharsets.UTF_8)); + assertEquals("row1,2", new String((byte[]) vector.getObject(0).get(1), StandardCharsets.UTF_8)); + assertEquals("row1,3", new String((byte[]) vector.getObject(0).get(2), StandardCharsets.UTF_8)); + assertEquals("row1,4", new String((byte[]) vector.getObject(0).get(3), StandardCharsets.UTF_8)); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java index 993ce0b089769..ffd87c99d508d 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java @@ -102,9 +102,9 @@ public void testCopyFrom() throws Exception { Object result = outVector.getObject(0); ArrayList resultSet = (ArrayList) result; assertEquals(3, resultSet.size()); - assertEquals(new Long(1), resultSet.get(0)); - assertEquals(new Long(2), resultSet.get(1)); - assertEquals(new Long(3), resultSet.get(2)); + assertEquals(Long.valueOf(1), resultSet.get(0)); + assertEquals(Long.valueOf(2), resultSet.get(1)); + assertEquals(Long.valueOf(3), resultSet.get(2)); /* index 1 */ result = outVector.getObject(1); @@ -143,7 +143,7 @@ public void testSetLastSetUsage() throws Exception { assertEquals(-1L, listVector.getLastSet()); int index = 0; - int offset = 0; + int offset; /* write [10, 11, 12] to the list vector at index 0 */ BitVectorHelper.setBit(validityBuffer, index); @@ -222,41 +222,40 @@ public void testSetLastSetUsage() throws Exception { assertEquals(Integer.toString(0), Integer.toString(offset)); Long actual = dataVector.getObject(offset); - assertEquals(new Long(10), actual); + assertEquals(Long.valueOf(10), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(11), actual); + assertEquals(Long.valueOf(11), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(12), actual); + assertEquals(Long.valueOf(12), actual); index++; offset = (int) offsetBuffer.getLong(index * LargeListVector.OFFSET_WIDTH); assertEquals(Integer.toString(3), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(13), actual); + assertEquals(Long.valueOf(13), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(14), actual); + assertEquals(Long.valueOf(14), actual); index++; offset = (int) offsetBuffer.getLong(index * LargeListVector.OFFSET_WIDTH); assertEquals(Integer.toString(5), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(15), actual); + assertEquals(Long.valueOf(15), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(16), actual); + assertEquals(Long.valueOf(16), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(17), actual); + assertEquals(Long.valueOf(17), actual); index++; offset = (int) offsetBuffer.getLong(index * LargeListVector.OFFSET_WIDTH); assertEquals(Integer.toString(8), Integer.toString(offset)); - actual = dataVector.getObject(offset); assertNull(actual); } @@ -323,8 +322,8 @@ public void testSplitAndTransfer() throws Exception { /* check the vector output */ int index = 0; - int offset = 0; - Long actual = null; + int offset; + Long actual; /* index 0 */ assertFalse(listVector.isNull(index)); @@ -332,13 +331,13 @@ public void testSplitAndTransfer() throws Exception { assertEquals(Integer.toString(0), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(10), actual); + assertEquals(Long.valueOf(10), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(11), actual); + assertEquals(Long.valueOf(11), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(12), actual); + assertEquals(Long.valueOf(12), actual); /* index 1 */ index++; @@ -347,10 +346,10 @@ public void testSplitAndTransfer() throws Exception { assertEquals(Integer.toString(3), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(13), actual); + assertEquals(Long.valueOf(13), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(14), actual); + assertEquals(Long.valueOf(14), actual); /* index 2 */ index++; @@ -359,16 +358,16 @@ public void testSplitAndTransfer() throws Exception { assertEquals(Integer.toString(5), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(15), actual); + assertEquals(Long.valueOf(15), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(16), actual); + assertEquals(Long.valueOf(16), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(17), actual); + assertEquals(Long.valueOf(17), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(18), actual); + assertEquals(Long.valueOf(18), actual); /* index 3 */ index++; @@ -377,7 +376,7 @@ public void testSplitAndTransfer() throws Exception { assertEquals(Integer.toString(9), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(19), actual); + assertEquals(Long.valueOf(19), actual); /* index 4 */ index++; @@ -386,16 +385,16 @@ public void testSplitAndTransfer() throws Exception { assertEquals(Integer.toString(10), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(20), actual); + assertEquals(Long.valueOf(20), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(21), actual); + assertEquals(Long.valueOf(21), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(22), actual); + assertEquals(Long.valueOf(22), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(23), actual); + assertEquals(Long.valueOf(23), actual); /* index 5 */ index++; @@ -522,15 +521,15 @@ public void testNestedLargeListVector() throws Exception { assertEquals(4, resultSet.get(1).size()); /* size of second inner list */ list = resultSet.get(0); - assertEquals(new Long(50), list.get(0)); - assertEquals(new Long(100), list.get(1)); - assertEquals(new Long(200), list.get(2)); + assertEquals(Long.valueOf(50), list.get(0)); + assertEquals(Long.valueOf(100), list.get(1)); + assertEquals(Long.valueOf(200), list.get(2)); list = resultSet.get(1); - assertEquals(new Long(75), list.get(0)); - assertEquals(new Long(125), list.get(1)); - assertEquals(new Long(150), list.get(2)); - assertEquals(new Long(175), list.get(3)); + assertEquals(Long.valueOf(75), list.get(0)); + assertEquals(Long.valueOf(125), list.get(1)); + assertEquals(Long.valueOf(150), list.get(2)); + assertEquals(Long.valueOf(175), list.get(3)); /* get listVector value at index 1 -- the value itself is a listvector */ result = listVector.getObject(1); @@ -542,16 +541,16 @@ public void testNestedLargeListVector() throws Exception { assertEquals(3, resultSet.get(2).size()); /* size of third inner list */ list = resultSet.get(0); - assertEquals(new Long(10), list.get(0)); + assertEquals(Long.valueOf(10), list.get(0)); list = resultSet.get(1); - assertEquals(new Long(15), list.get(0)); - assertEquals(new Long(20), list.get(1)); + assertEquals(Long.valueOf(15), list.get(0)); + assertEquals(Long.valueOf(20), list.get(1)); list = resultSet.get(2); - assertEquals(new Long(25), list.get(0)); - assertEquals(new Long(30), list.get(1)); - assertEquals(new Long(35), list.get(2)); + assertEquals(Long.valueOf(25), list.get(0)); + assertEquals(Long.valueOf(30), list.get(1)); + assertEquals(Long.valueOf(35), list.get(2)); /* check underlying bitVector */ assertFalse(listVector.isNull(0)); @@ -656,13 +655,13 @@ public void testNestedLargeListVector2() throws Exception { assertEquals(2, resultSet.get(1).size()); /* size of second inner list */ list = resultSet.get(0); - assertEquals(new Long(50), list.get(0)); - assertEquals(new Long(100), list.get(1)); - assertEquals(new Long(200), list.get(2)); + assertEquals(Long.valueOf(50), list.get(0)); + assertEquals(Long.valueOf(100), list.get(1)); + assertEquals(Long.valueOf(200), list.get(2)); list = resultSet.get(1); - assertEquals(new Long(75), list.get(0)); - assertEquals(new Long(125), list.get(1)); + assertEquals(Long.valueOf(75), list.get(0)); + assertEquals(Long.valueOf(125), list.get(1)); /* get listVector value at index 1 -- the value itself is a listvector */ result = listVector.getObject(1); @@ -673,13 +672,13 @@ public void testNestedLargeListVector2() throws Exception { assertEquals(3, resultSet.get(1).size()); /* size of second inner list */ list = resultSet.get(0); - assertEquals(new Long(15), list.get(0)); - assertEquals(new Long(20), list.get(1)); + assertEquals(Long.valueOf(15), list.get(0)); + assertEquals(Long.valueOf(20), list.get(1)); list = resultSet.get(1); - assertEquals(new Long(25), list.get(0)); - assertEquals(new Long(30), list.get(1)); - assertEquals(new Long(35), list.get(2)); + assertEquals(Long.valueOf(25), list.get(0)); + assertEquals(Long.valueOf(30), list.get(1)); + assertEquals(Long.valueOf(35), list.get(2)); /* check underlying bitVector */ assertFalse(listVector.isNull(0)); @@ -723,15 +722,15 @@ public void testGetBufferAddress() throws Exception { Object result = listVector.getObject(0); ArrayList resultSet = (ArrayList) result; assertEquals(3, resultSet.size()); - assertEquals(new Long(50), resultSet.get(0)); - assertEquals(new Long(100), resultSet.get(1)); - assertEquals(new Long(200), resultSet.get(2)); + assertEquals(Long.valueOf(50), resultSet.get(0)); + assertEquals(Long.valueOf(100), resultSet.get(1)); + assertEquals(Long.valueOf(200), resultSet.get(2)); result = listVector.getObject(1); resultSet = (ArrayList) result; assertEquals(2, resultSet.size()); - assertEquals(new Long(250), resultSet.get(0)); - assertEquals(new Long(300), resultSet.get(1)); + assertEquals(Long.valueOf(250), resultSet.get(0)); + assertEquals(Long.valueOf(300), resultSet.get(1)); List buffers = listVector.getFieldBuffers(); @@ -739,7 +738,7 @@ public void testGetBufferAddress() throws Exception { long offsetAddress = listVector.getOffsetBufferAddress(); try { - long dataAddress = listVector.getDataBufferAddress(); + listVector.getDataBufferAddress(); } catch (UnsupportedOperationException ue) { error = true; } finally { @@ -849,11 +848,11 @@ public void testClearAndReuse() { Object result = vector.getObject(0); ArrayList resultSet = (ArrayList) result; - assertEquals(new Long(7), resultSet.get(0)); + assertEquals(Long.valueOf(7), resultSet.get(0)); result = vector.getObject(1); resultSet = (ArrayList) result; - assertEquals(new Long(8), resultSet.get(0)); + assertEquals(Long.valueOf(8), resultSet.get(0)); // Clear and release the buffers to trigger a realloc when adding next value vector.clear(); @@ -869,11 +868,11 @@ public void testClearAndReuse() { result = vector.getObject(0); resultSet = (ArrayList) result; - assertEquals(new Long(7), resultSet.get(0)); + assertEquals(Long.valueOf(7), resultSet.get(0)); result = vector.getObject(1); resultSet = (ArrayList) result; - assertEquals(new Long(8), resultSet.get(0)); + assertEquals(Long.valueOf(8), resultSet.get(0)); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeVarBinaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeVarBinaryVector.java index ecababde8de3a..36607903b01a2 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeVarBinaryVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeVarBinaryVector.java @@ -22,7 +22,9 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; +import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.Objects; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; @@ -61,7 +63,7 @@ public void testSetNullableLargeVarBinaryHolder() { String str = "hello"; try (ArrowBuf buf = allocator.buffer(16)) { - buf.setBytes(0, str.getBytes()); + buf.setBytes(0, str.getBytes(StandardCharsets.UTF_8)); binHolder.start = 0; binHolder.end = str.length(); @@ -72,7 +74,7 @@ public void testSetNullableLargeVarBinaryHolder() { // verify results assertTrue(vector.isNull(0)); - assertEquals(str, new String(vector.get(1))); + assertEquals(str, new String(Objects.requireNonNull(vector.get(1)), StandardCharsets.UTF_8)); } } } @@ -90,7 +92,7 @@ public void testSetNullableLargeVarBinaryHolderSafe() { String str = "hello world"; try (ArrowBuf buf = allocator.buffer(16)) { - buf.setBytes(0, str.getBytes()); + buf.setBytes(0, str.getBytes(StandardCharsets.UTF_8)); binHolder.start = 0; binHolder.end = str.length(); @@ -100,7 +102,7 @@ public void testSetNullableLargeVarBinaryHolderSafe() { vector.setSafe(1, nullHolder); // verify results - assertEquals(str, new String(vector.get(0))); + assertEquals(str, new String(Objects.requireNonNull(vector.get(0)), StandardCharsets.UTF_8)); assertTrue(vector.isNull(1)); } } @@ -113,18 +115,18 @@ public void testGetBytesRepeatedly() { final String str = "hello world"; final String str2 = "foo"; - vector.setSafe(0, str.getBytes()); - vector.setSafe(1, str2.getBytes()); + vector.setSafe(0, str.getBytes(StandardCharsets.UTF_8)); + vector.setSafe(1, str2.getBytes(StandardCharsets.UTF_8)); // verify results ReusableByteArray reusableByteArray = new ReusableByteArray(); vector.read(0, reusableByteArray); byte[] oldBuffer = reusableByteArray.getBuffer(); - assertArrayEquals(str.getBytes(), Arrays.copyOfRange(reusableByteArray.getBuffer(), + assertArrayEquals(str.getBytes(StandardCharsets.UTF_8), Arrays.copyOfRange(reusableByteArray.getBuffer(), 0, (int) reusableByteArray.getLength())); vector.read(1, reusableByteArray); - assertArrayEquals(str2.getBytes(), Arrays.copyOfRange(reusableByteArray.getBuffer(), + assertArrayEquals(str2.getBytes(StandardCharsets.UTF_8), Arrays.copyOfRange(reusableByteArray.getBuffer(), 0, (int) reusableByteArray.getLength())); // There should not have been any reallocation since the newer value is smaller in length. @@ -137,7 +139,7 @@ public void testGetTransferPairWithField() { try (BufferAllocator childAllocator1 = allocator.newChildAllocator("child1", 1000000, 1000000); LargeVarBinaryVector v1 = new LargeVarBinaryVector("v1", childAllocator1)) { v1.allocateNew(); - v1.setSafe(4094, "hello world".getBytes(), 0, 11); + v1.setSafe(4094, "hello world".getBytes(StandardCharsets.UTF_8), 0, 11); v1.setValueCount(4001); TransferPair tp = v1.getTransferPair(v1.getField(), allocator); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeVarCharVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeVarCharVector.java index 7d074c393648f..62d09da86d652 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeVarCharVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeVarCharVector.java @@ -27,6 +27,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; @@ -48,12 +49,12 @@ public class TestLargeVarCharVector { - private static final byte[] STR1 = "AAAAA1".getBytes(); - private static final byte[] STR2 = "BBBBBBBBB2".getBytes(); - private static final byte[] STR3 = "CCCC3".getBytes(); - private static final byte[] STR4 = "DDDDDDDD4".getBytes(); - private static final byte[] STR5 = "EEE5".getBytes(); - private static final byte[] STR6 = "FFFFF6".getBytes(); + private static final byte[] STR1 = "AAAAA1".getBytes(StandardCharsets.UTF_8); + private static final byte[] STR2 = "BBBBBBBBB2".getBytes(StandardCharsets.UTF_8); + private static final byte[] STR3 = "CCCC3".getBytes(StandardCharsets.UTF_8); + private static final byte[] STR4 = "DDDDDDDD4".getBytes(StandardCharsets.UTF_8); + private static final byte[] STR5 = "EEE5".getBytes(StandardCharsets.UTF_8); + private static final byte[] STR6 = "FFFFF6".getBytes(StandardCharsets.UTF_8); private BufferAllocator allocator; @@ -74,7 +75,7 @@ public void testTransfer() { LargeVarCharVector v1 = new LargeVarCharVector("v1", childAllocator1); LargeVarCharVector v2 = new LargeVarCharVector("v2", childAllocator2);) { v1.allocateNew(); - v1.setSafe(4094, "hello world".getBytes(), 0, 11); + v1.setSafe(4094, "hello world".getBytes(StandardCharsets.UTF_8), 0, 11); v1.setValueCount(4001); long memoryBeforeTransfer = childAllocator1.getAllocatedMemory(); @@ -207,12 +208,12 @@ public void testSizeOfValueBuffer() { @Test public void testSetLastSetUsage() { - final byte[] STR1 = "AAAAA1".getBytes(); - final byte[] STR2 = "BBBBBBBBB2".getBytes(); - final byte[] STR3 = "CCCC3".getBytes(); - final byte[] STR4 = "DDDDDDDD4".getBytes(); - final byte[] STR5 = "EEE5".getBytes(); - final byte[] STR6 = "FFFFF6".getBytes(); + final byte[] STR1 = "AAAAA1".getBytes(StandardCharsets.UTF_8); + final byte[] STR2 = "BBBBBBBBB2".getBytes(StandardCharsets.UTF_8); + final byte[] STR3 = "CCCC3".getBytes(StandardCharsets.UTF_8); + final byte[] STR4 = "DDDDDDDD4".getBytes(StandardCharsets.UTF_8); + final byte[] STR5 = "EEE5".getBytes(StandardCharsets.UTF_8); + final byte[] STR6 = "FFFFF6".getBytes(StandardCharsets.UTF_8); try (final LargeVarCharVector vector = new LargeVarCharVector("myvector", allocator)) { vector.allocateNew(1024 * 10, 1024); @@ -353,7 +354,7 @@ public void testSplitAndTransfer() { for (int i = 0; i < length; i++) { final boolean expectedSet = ((start + i) % 3) == 0; if (expectedSet) { - final byte[] expectedValue = compareArray[start + i].getBytes(); + final byte[] expectedValue = compareArray[start + i].getBytes(StandardCharsets.UTF_8); assertFalse(newLargeVarCharVector.isNull(i)); assertArrayEquals(expectedValue, newLargeVarCharVector.get(i)); } else { @@ -367,8 +368,8 @@ public void testSplitAndTransfer() { @Test public void testReallocAfterVectorTransfer() { - final byte[] STR1 = "AAAAA1".getBytes(); - final byte[] STR2 = "BBBBBBBBB2".getBytes(); + final byte[] STR1 = "AAAAA1".getBytes(StandardCharsets.UTF_8); + final byte[] STR2 = "BBBBBBBBB2".getBytes(StandardCharsets.UTF_8); try (final LargeVarCharVector vector = new LargeVarCharVector("vector", allocator)) { /* 4096 values with 10 byte per record */ @@ -675,7 +676,7 @@ public void testSetNullableLargeVarCharHolder() { String str = "hello"; ArrowBuf buf = allocator.buffer(16); - buf.setBytes(0, str.getBytes()); + buf.setBytes(0, str.getBytes(StandardCharsets.UTF_8)); stringHolder.start = 0; stringHolder.end = str.length(); @@ -686,7 +687,7 @@ public void testSetNullableLargeVarCharHolder() { // verify results assertTrue(vector.isNull(0)); - assertEquals(str, new String(vector.get(1))); + assertEquals(str, new String(Objects.requireNonNull(vector.get(1)), StandardCharsets.UTF_8)); buf.close(); } @@ -705,7 +706,7 @@ public void testSetNullableLargeVarCharHolderSafe() { String str = "hello world"; ArrowBuf buf = allocator.buffer(16); - buf.setBytes(0, str.getBytes()); + buf.setBytes(0, str.getBytes(StandardCharsets.UTF_8)); stringHolder.start = 0; stringHolder.end = str.length(); @@ -715,7 +716,7 @@ public void testSetNullableLargeVarCharHolderSafe() { vector.setSafe(1, nullHolder); // verify results - assertEquals(str, new String(vector.get(0))); + assertEquals(str, new String(Objects.requireNonNull(vector.get(0)), StandardCharsets.UTF_8)); assertTrue(vector.isNull(1)); buf.close(); @@ -743,7 +744,7 @@ public void testLargeVariableWidthVectorNullHashCode() { largeVarChVec.allocateNew(100, 1); largeVarChVec.setValueCount(1); - largeVarChVec.set(0, "abc".getBytes()); + largeVarChVec.set(0, "abc".getBytes(StandardCharsets.UTF_8)); largeVarChVec.setNull(0); assertEquals(0, largeVarChVec.hashCode(0)); @@ -756,7 +757,7 @@ public void testUnloadLargeVariableWidthVector() { largeVarCharVector.allocateNew(5, 2); largeVarCharVector.setValueCount(2); - largeVarCharVector.set(0, "abcd".getBytes()); + largeVarCharVector.set(0, "abcd".getBytes(StandardCharsets.UTF_8)); List bufs = largeVarCharVector.getFieldBuffers(); assertEquals(3, bufs.size()); @@ -821,7 +822,7 @@ public void testGetTransferPairWithField() { try (BufferAllocator childAllocator1 = allocator.newChildAllocator("child1", 1000000, 1000000); LargeVarCharVector v1 = new LargeVarCharVector("v1", childAllocator1)) { v1.allocateNew(); - v1.setSafe(4094, "hello world".getBytes(), 0, 11); + v1.setSafe(4094, "hello world".getBytes(StandardCharsets.UTF_8), 0, 11); v1.setValueCount(4001); TransferPair tp = v1.getTransferPair(v1.getField(), allocator); @@ -835,7 +836,7 @@ public void testGetTransferPairWithField() { private void populateLargeVarcharVector(final LargeVarCharVector vector, int valueCount, String[] values) { for (int i = 0; i < valueCount; i += 3) { final String s = String.format("%010d", i); - vector.set(i, s.getBytes()); + vector.set(i, s.getBytes(StandardCharsets.UTF_8)); if (values != null) { values[i] = s; } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java index 278f497b47991..97f2d9fd6def1 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java @@ -107,9 +107,9 @@ public void testCopyFrom() throws Exception { Object result = outVector.getObject(0); ArrayList resultSet = (ArrayList) result; assertEquals(3, resultSet.size()); - assertEquals(new Long(1), (Long) resultSet.get(0)); - assertEquals(new Long(2), (Long) resultSet.get(1)); - assertEquals(new Long(3), (Long) resultSet.get(2)); + assertEquals(Long.valueOf(1), resultSet.get(0)); + assertEquals(Long.valueOf(2), resultSet.get(1)); + assertEquals(Long.valueOf(3), resultSet.get(2)); /* index 1 */ result = outVector.getObject(1); @@ -148,7 +148,7 @@ public void testSetLastSetUsage() throws Exception { assertEquals(Integer.toString(-1), Integer.toString(listVector.getLastSet())); int index = 0; - int offset = 0; + int offset; /* write [10, 11, 12] to the list vector at index 0 */ BitVectorHelper.setBit(validityBuffer, index); @@ -227,36 +227,36 @@ public void testSetLastSetUsage() throws Exception { assertEquals(Integer.toString(0), Integer.toString(offset)); Long actual = dataVector.getObject(offset); - assertEquals(new Long(10), actual); + assertEquals(Long.valueOf(10), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(11), actual); + assertEquals(Long.valueOf(11), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(12), actual); + assertEquals(Long.valueOf(12), actual); index++; offset = offsetBuffer.getInt(index * ListVector.OFFSET_WIDTH); assertEquals(Integer.toString(3), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(13), actual); + assertEquals(Long.valueOf(13), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(14), actual); + assertEquals(Long.valueOf(14), actual); index++; offset = offsetBuffer.getInt(index * ListVector.OFFSET_WIDTH); assertEquals(Integer.toString(5), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(15), actual); + assertEquals(Long.valueOf(15), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(16), actual); + assertEquals(Long.valueOf(16), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(17), actual); + assertEquals(Long.valueOf(17), actual); index++; offset = offsetBuffer.getInt(index * ListVector.OFFSET_WIDTH); @@ -328,8 +328,8 @@ public void testSplitAndTransfer() throws Exception { /* check the vector output */ int index = 0; - int offset = 0; - Long actual = null; + int offset; + Long actual; /* index 0 */ assertFalse(listVector.isNull(index)); @@ -337,13 +337,13 @@ public void testSplitAndTransfer() throws Exception { assertEquals(Integer.toString(0), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(10), actual); + assertEquals(Long.valueOf(10), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(11), actual); + assertEquals(Long.valueOf(11), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(12), actual); + assertEquals(Long.valueOf(12), actual); /* index 1 */ index++; @@ -352,10 +352,10 @@ public void testSplitAndTransfer() throws Exception { assertEquals(Integer.toString(3), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(13), actual); + assertEquals(Long.valueOf(13), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(14), actual); + assertEquals(Long.valueOf(14), actual); /* index 2 */ index++; @@ -364,16 +364,16 @@ public void testSplitAndTransfer() throws Exception { assertEquals(Integer.toString(5), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(15), actual); + assertEquals(Long.valueOf(15), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(16), actual); + assertEquals(Long.valueOf(16), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(17), actual); + assertEquals(Long.valueOf(17), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(18), actual); + assertEquals(Long.valueOf(18), actual); /* index 3 */ index++; @@ -382,7 +382,7 @@ public void testSplitAndTransfer() throws Exception { assertEquals(Integer.toString(9), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(19), actual); + assertEquals(Long.valueOf(19), actual); /* index 4 */ index++; @@ -391,16 +391,16 @@ public void testSplitAndTransfer() throws Exception { assertEquals(Integer.toString(10), Integer.toString(offset)); actual = dataVector.getObject(offset); - assertEquals(new Long(20), actual); + assertEquals(Long.valueOf(20), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(21), actual); + assertEquals(Long.valueOf(21), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(22), actual); + assertEquals(Long.valueOf(22), actual); offset++; actual = dataVector.getObject(offset); - assertEquals(new Long(23), actual); + assertEquals(Long.valueOf(23), actual); /* index 5 */ index++; @@ -527,15 +527,15 @@ public void testNestedListVector() throws Exception { assertEquals(4, resultSet.get(1).size()); /* size of second inner list */ list = resultSet.get(0); - assertEquals(new Long(50), list.get(0)); - assertEquals(new Long(100), list.get(1)); - assertEquals(new Long(200), list.get(2)); + assertEquals(Long.valueOf(50), list.get(0)); + assertEquals(Long.valueOf(100), list.get(1)); + assertEquals(Long.valueOf(200), list.get(2)); list = resultSet.get(1); - assertEquals(new Long(75), list.get(0)); - assertEquals(new Long(125), list.get(1)); - assertEquals(new Long(150), list.get(2)); - assertEquals(new Long(175), list.get(3)); + assertEquals(Long.valueOf(75), list.get(0)); + assertEquals(Long.valueOf(125), list.get(1)); + assertEquals(Long.valueOf(150), list.get(2)); + assertEquals(Long.valueOf(175), list.get(3)); /* get listVector value at index 1 -- the value itself is a listvector */ result = listVector.getObject(1); @@ -547,16 +547,16 @@ public void testNestedListVector() throws Exception { assertEquals(3, resultSet.get(2).size()); /* size of third inner list */ list = resultSet.get(0); - assertEquals(new Long(10), list.get(0)); + assertEquals(Long.valueOf(10), list.get(0)); list = resultSet.get(1); - assertEquals(new Long(15), list.get(0)); - assertEquals(new Long(20), list.get(1)); + assertEquals(Long.valueOf(15), list.get(0)); + assertEquals(Long.valueOf(20), list.get(1)); list = resultSet.get(2); - assertEquals(new Long(25), list.get(0)); - assertEquals(new Long(30), list.get(1)); - assertEquals(new Long(35), list.get(2)); + assertEquals(Long.valueOf(25), list.get(0)); + assertEquals(Long.valueOf(30), list.get(1)); + assertEquals(Long.valueOf(35), list.get(2)); /* check underlying bitVector */ assertFalse(listVector.isNull(0)); @@ -661,13 +661,13 @@ public void testNestedListVector2() throws Exception { assertEquals(2, resultSet.get(1).size()); /* size of second inner list */ list = resultSet.get(0); - assertEquals(new Long(50), list.get(0)); - assertEquals(new Long(100), list.get(1)); - assertEquals(new Long(200), list.get(2)); + assertEquals(Long.valueOf(50), list.get(0)); + assertEquals(Long.valueOf(100), list.get(1)); + assertEquals(Long.valueOf(200), list.get(2)); list = resultSet.get(1); - assertEquals(new Long(75), list.get(0)); - assertEquals(new Long(125), list.get(1)); + assertEquals(Long.valueOf(75), list.get(0)); + assertEquals(Long.valueOf(125), list.get(1)); /* get listVector value at index 1 -- the value itself is a listvector */ result = listVector.getObject(1); @@ -678,13 +678,13 @@ public void testNestedListVector2() throws Exception { assertEquals(3, resultSet.get(1).size()); /* size of second inner list */ list = resultSet.get(0); - assertEquals(new Long(15), list.get(0)); - assertEquals(new Long(20), list.get(1)); + assertEquals(Long.valueOf(15), list.get(0)); + assertEquals(Long.valueOf(20), list.get(1)); list = resultSet.get(1); - assertEquals(new Long(25), list.get(0)); - assertEquals(new Long(30), list.get(1)); - assertEquals(new Long(35), list.get(2)); + assertEquals(Long.valueOf(25), list.get(0)); + assertEquals(Long.valueOf(30), list.get(1)); + assertEquals(Long.valueOf(35), list.get(2)); /* check underlying bitVector */ assertFalse(listVector.isNull(0)); @@ -728,15 +728,15 @@ public void testGetBufferAddress() throws Exception { Object result = listVector.getObject(0); ArrayList resultSet = (ArrayList) result; assertEquals(3, resultSet.size()); - assertEquals(new Long(50), resultSet.get(0)); - assertEquals(new Long(100), resultSet.get(1)); - assertEquals(new Long(200), resultSet.get(2)); + assertEquals(Long.valueOf(50), resultSet.get(0)); + assertEquals(Long.valueOf(100), resultSet.get(1)); + assertEquals(Long.valueOf(200), resultSet.get(2)); result = listVector.getObject(1); resultSet = (ArrayList) result; assertEquals(2, resultSet.size()); - assertEquals(new Long(250), resultSet.get(0)); - assertEquals(new Long(300), resultSet.get(1)); + assertEquals(Long.valueOf(250), resultSet.get(0)); + assertEquals(Long.valueOf(300), resultSet.get(1)); List buffers = listVector.getFieldBuffers(); @@ -744,7 +744,7 @@ public void testGetBufferAddress() throws Exception { long offsetAddress = listVector.getOffsetBufferAddress(); try { - long dataAddress = listVector.getDataBufferAddress(); + listVector.getDataBufferAddress(); } catch (UnsupportedOperationException ue) { error = true; } finally { @@ -777,7 +777,7 @@ public void testSetInitialCapacity() { try (final ListVector vector = ListVector.empty("", allocator)) { vector.addOrGetVector(FieldType.nullable(MinorType.INT.getType())); - /** + /* * use the default multiplier of 5, * 512 * 5 => 2560 * 4 => 10240 bytes => 16KB => 4096 value capacity. */ @@ -792,7 +792,7 @@ public void testSetInitialCapacity() { assertEquals(512, vector.getValueCapacity()); assertTrue(vector.getDataVector().getValueCapacity() >= 512 * 4); - /** + /* * inner value capacity we pass to data vector is 512 * 0.1 => 51 * For an int vector this is 204 bytes of memory for data buffer * and 7 bytes for validity buffer. @@ -805,7 +805,7 @@ public void testSetInitialCapacity() { assertEquals(512, vector.getValueCapacity()); assertTrue(vector.getDataVector().getValueCapacity() >= 51); - /** + /* * inner value capacity we pass to data vector is 512 * 0.01 => 5 * For an int vector this is 20 bytes of memory for data buffer * and 1 byte for validity buffer. @@ -818,7 +818,7 @@ public void testSetInitialCapacity() { assertEquals(512, vector.getValueCapacity()); assertTrue(vector.getDataVector().getValueCapacity() >= 5); - /** + /* * inner value capacity we pass to data vector is 5 * 0.1 => 0 * which is then rounded off to 1. So we pass value count as 1 * to the inner int vector. @@ -854,11 +854,11 @@ public void testClearAndReuse() { Object result = vector.getObject(0); ArrayList resultSet = (ArrayList) result; - assertEquals(new Long(7), resultSet.get(0)); + assertEquals(Long.valueOf(7), resultSet.get(0)); result = vector.getObject(1); resultSet = (ArrayList) result; - assertEquals(new Long(8), resultSet.get(0)); + assertEquals(Long.valueOf(8), resultSet.get(0)); // Clear and release the buffers to trigger a realloc when adding next value vector.clear(); @@ -874,11 +874,11 @@ public void testClearAndReuse() { result = vector.getObject(0); resultSet = (ArrayList) result; - assertEquals(new Long(7), resultSet.get(0)); + assertEquals(Long.valueOf(7), resultSet.get(0)); result = vector.getObject(1); resultSet = (ArrayList) result; - assertEquals(new Long(8), resultSet.get(0)); + assertEquals(Long.valueOf(8), resultSet.get(0)); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java index 1db55198e4bb3..43f4c3b536fdc 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java @@ -335,8 +335,8 @@ public void testSplitAndTransfer() throws Exception { /* check the vector output */ int index = 0; - int offset = 0; - Map result = null; + int offset; + Map result; /* index 0 */ assertFalse(mapVector.isNull(index)); @@ -571,18 +571,18 @@ public void testMapWithListValue() throws Exception { assertEquals(1L, getResultKey(resultStruct)); ArrayList list = (ArrayList) getResultValue(resultStruct); assertEquals(3, list.size()); // value is a list with 3 elements - assertEquals(new Long(50), list.get(0)); - assertEquals(new Long(100), list.get(1)); - assertEquals(new Long(200), list.get(2)); + assertEquals(Long.valueOf(50), list.get(0)); + assertEquals(Long.valueOf(100), list.get(1)); + assertEquals(Long.valueOf(200), list.get(2)); // Second Map entry resultStruct = (Map) resultSet.get(1); list = (ArrayList) getResultValue(resultStruct); assertEquals(4, list.size()); // value is a list with 4 elements - assertEquals(new Long(75), list.get(0)); - assertEquals(new Long(125), list.get(1)); - assertEquals(new Long(150), list.get(2)); - assertEquals(new Long(175), list.get(3)); + assertEquals(Long.valueOf(75), list.get(0)); + assertEquals(Long.valueOf(125), list.get(1)); + assertEquals(Long.valueOf(150), list.get(2)); + assertEquals(Long.valueOf(175), list.get(3)); // Get mapVector element at index 1 result = mapVector.getObject(1); @@ -593,24 +593,24 @@ public void testMapWithListValue() throws Exception { assertEquals(3L, getResultKey(resultStruct)); list = (ArrayList) getResultValue(resultStruct); assertEquals(1, list.size()); // value is a list with 1 element - assertEquals(new Long(10), list.get(0)); + assertEquals(Long.valueOf(10), list.get(0)); // Second Map entry resultStruct = (Map) resultSet.get(1); assertEquals(4L, getResultKey(resultStruct)); list = (ArrayList) getResultValue(resultStruct); assertEquals(2, list.size()); // value is a list with 1 element - assertEquals(new Long(15), list.get(0)); - assertEquals(new Long(20), list.get(1)); + assertEquals(Long.valueOf(15), list.get(0)); + assertEquals(Long.valueOf(20), list.get(1)); // Third Map entry resultStruct = (Map) resultSet.get(2); assertEquals(5L, getResultKey(resultStruct)); list = (ArrayList) getResultValue(resultStruct); assertEquals(3, list.size()); // value is a list with 1 element - assertEquals(new Long(25), list.get(0)); - assertEquals(new Long(30), list.get(1)); - assertEquals(new Long(35), list.get(2)); + assertEquals(Long.valueOf(25), list.get(0)); + assertEquals(Long.valueOf(30), list.get(1)); + assertEquals(Long.valueOf(35), list.get(2)); /* check underlying bitVector */ assertFalse(mapVector.isNull(0)); @@ -1012,8 +1012,8 @@ public void testMapWithMapKeyAndMapValue() throws Exception { final ArrowBuf offsetBuffer = mapVector.getOffsetBuffer(); /* mapVector has 2 entries at index 0 and 4 entries at index 1 */ - assertEquals(0, offsetBuffer.getInt(0 * MapVector.OFFSET_WIDTH)); - assertEquals(2, offsetBuffer.getInt(1 * MapVector.OFFSET_WIDTH)); + assertEquals(0, offsetBuffer.getInt(0)); + assertEquals(2, offsetBuffer.getInt(MapVector.OFFSET_WIDTH)); assertEquals(6, offsetBuffer.getInt(2 * MapVector.OFFSET_WIDTH)); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java b/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java index 716fa0bde454d..3580a321f01c9 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; @@ -57,7 +58,7 @@ public void terminate() throws Exception { private void populateVarcharVector(final VarCharVector vector, int valueCount, String[] compareArray) { for (int i = 0; i < valueCount; i += 3) { final String s = String.format("%010d", i); - vector.set(i, s.getBytes()); + vector.set(i, s.getBytes(StandardCharsets.UTF_8)); if (compareArray != null) { compareArray[i] = s; } @@ -86,7 +87,7 @@ public void test() throws Exception { for (int i = 0; i < length; i++) { final boolean expectedSet = ((start + i) % 3) == 0; if (expectedSet) { - final byte[] expectedValue = compareArray[start + i].getBytes(); + final byte[] expectedValue = compareArray[start + i].getBytes(StandardCharsets.UTF_8); assertFalse(newVarCharVector.isNull(i)); assertArrayEquals(expectedValue, newVarCharVector.get(i)); } else { @@ -141,7 +142,7 @@ public void testTransfer() { for (int i = 0; i < valueCount; i++) { final boolean expectedSet = (i % 3) == 0; if (expectedSet) { - final byte[] expectedValue = compareArray[i].getBytes(); + final byte[] expectedValue = compareArray[i].getBytes(StandardCharsets.UTF_8); assertFalse(newVarCharVector.isNull(i)); assertArrayEquals(expectedValue, newVarCharVector.get(i)); } else { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java index b53171a597681..1b0387feb73ff 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java @@ -404,16 +404,16 @@ public void testGetFieldTypeInfo() throws Exception { assertTrue(vector.getField().equals(field)); // Union has 2 child vectors - assertEquals(vector.size(), 2); + assertEquals(2, vector.size()); // Check child field 0 VectorWithOrdinal intChild = vector.getChildVectorWithOrdinal("int"); - assertEquals(intChild.ordinal, 0); + assertEquals(0, intChild.ordinal); assertEquals(intChild.vector.getField(), children.get(0)); // Check child field 1 VectorWithOrdinal varcharChild = vector.getChildVectorWithOrdinal("varchar"); - assertEquals(varcharChild.ordinal, 1); + assertEquals(1, varcharChild.ordinal); assertEquals(varcharChild.vector.getField(), children.get(1)); } @@ -455,7 +455,7 @@ public void testGetBufferAddress() throws Exception { try { - long offsetAddress = vector.getOffsetBufferAddress(); + vector.getOffsetBufferAddress(); } catch (UnsupportedOperationException ue) { error = true; } finally { @@ -464,7 +464,7 @@ public void testGetBufferAddress() throws Exception { } try { - long dataAddress = vector.getDataBufferAddress(); + vector.getDataBufferAddress(); } catch (UnsupportedOperationException ue) { error = true; } finally { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java index fb96870804441..614aff18d4554 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java @@ -83,7 +83,7 @@ public void init() { allocator = new RootAllocator(Long.MAX_VALUE); } - private static final Charset utf8Charset = Charset.forName("UTF-8"); + private static final Charset utf8Charset = StandardCharsets.UTF_8; private static final byte[] STR1 = "AAAAA1".getBytes(utf8Charset); private static final byte[] STR2 = "BBBBBBBBB2".getBytes(utf8Charset); private static final byte[] STR3 = "CCCC3".getBytes(utf8Charset); @@ -127,10 +127,9 @@ public void testFixedType1() { try (final UInt4Vector vector = new UInt4Vector(EMPTY_SCHEMA_PATH, allocator)) { boolean error = false; - int initialCapacity = 0; vector.allocateNew(1024); - initialCapacity = vector.getValueCapacity(); + int initialCapacity = vector.getValueCapacity(); assertTrue(initialCapacity >= 1024); // Put and set a few values @@ -562,8 +561,6 @@ public void testNullableFixedType1() { assertEquals(103, vector.get(initialCapacity - 2)); assertEquals(104, vector.get(initialCapacity - 1)); - int val = 0; - /* check unset bits/null values */ for (int i = 2, j = 101; i <= 99 || j <= initialCapacity - 3; i++, j++) { if (i <= 99) { @@ -606,8 +603,6 @@ public void testNullableFixedType1() { assertEquals(104, vector.get(initialCapacity - 1)); assertEquals(10000, vector.get(initialCapacity)); - val = 0; - /* check unset bits/null values */ for (int i = 2, j = 101; i < 99 || j < initialCapacity - 3; i++, j++) { if (i <= 99) { @@ -735,7 +730,6 @@ public void testNullableFixedType2() { public void testNullableFixedType3() { // Create a new value vector for 1024 integers try (final IntVector vector = newVector(IntVector.class, EMPTY_SCHEMA_PATH, MinorType.INT, allocator)) { - boolean error = false; int initialCapacity = 1024; /* no memory allocation has happened yet so capacity of underlying buffer should be 0 */ @@ -765,7 +759,6 @@ public void testNullableFixedType3() { } vector.setValueCount(1024); - Field field = vector.getField(); List buffers = vector.getFieldBuffers(); @@ -1105,7 +1098,6 @@ public void testNullableVarType1() { assertEquals(txt, vector.getObject(7)); // Ensure null value throws. - boolean b = false; assertNull(vector.get(8)); } } @@ -1182,18 +1174,18 @@ public void testGetBytesRepeatedly() { final String str = "hello world"; final String str2 = "foo"; - vector.setSafe(0, str.getBytes()); - vector.setSafe(1, str2.getBytes()); + vector.setSafe(0, str.getBytes(StandardCharsets.UTF_8)); + vector.setSafe(1, str2.getBytes(StandardCharsets.UTF_8)); // verify results ReusableByteArray reusableByteArray = new ReusableByteArray(); vector.read(0, reusableByteArray); - assertArrayEquals(str.getBytes(), Arrays.copyOfRange(reusableByteArray.getBuffer(), + assertArrayEquals(str.getBytes(StandardCharsets.UTF_8), Arrays.copyOfRange(reusableByteArray.getBuffer(), 0, (int) reusableByteArray.getLength())); byte[] oldBuffer = reusableByteArray.getBuffer(); vector.read(1, reusableByteArray); - assertArrayEquals(str2.getBytes(), Arrays.copyOfRange(reusableByteArray.getBuffer(), + assertArrayEquals(str2.getBytes(StandardCharsets.UTF_8), Arrays.copyOfRange(reusableByteArray.getBuffer(), 0, (int) reusableByteArray.getLength())); // There should not have been any reallocation since the newer value is smaller in length. @@ -1219,7 +1211,6 @@ public void testGetBytesRepeatedly() { public void testReallocAfterVectorTransfer1() { try (final Float8Vector vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator)) { int initialCapacity = 4096; - boolean error = false; /* use the default capacity; 4096*8 => 32KB */ vector.setInitialCapacity(initialCapacity); @@ -1259,7 +1250,7 @@ public void testReallocAfterVectorTransfer1() { } /* this should trigger a realloc */ - vector.setSafe(capacityAfterRealloc1, baseValue + (double) (capacityAfterRealloc1)); + vector.setSafe(capacityAfterRealloc1, baseValue + (double) capacityAfterRealloc1); assertTrue(vector.getValueCapacity() >= initialCapacity * 4); int capacityAfterRealloc2 = vector.getValueCapacity(); @@ -1301,7 +1292,6 @@ public void testReallocAfterVectorTransfer1() { public void testReallocAfterVectorTransfer2() { try (final Float8Vector vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator)) { int initialCapacity = 4096; - boolean error = false; vector.allocateNew(initialCapacity); assertTrue(vector.getValueCapacity() >= initialCapacity); @@ -1338,7 +1328,7 @@ public void testReallocAfterVectorTransfer2() { } /* this should trigger a realloc */ - vector.setSafe(capacityAfterRealloc1, baseValue + (double) (capacityAfterRealloc1)); + vector.setSafe(capacityAfterRealloc1, baseValue + (double) capacityAfterRealloc1); assertTrue(vector.getValueCapacity() >= initialCapacity * 4); int capacityAfterRealloc2 = vector.getValueCapacity(); @@ -1494,7 +1484,6 @@ public void testReallocAfterVectorTransfer4() { assertTrue(valueCapacity >= 4096); /* populate the vector */ - int baseValue = 1000; for (int i = 0; i < valueCapacity; i++) { if ((i & 1) == 0) { vector.set(i, 1000 + i); @@ -1649,7 +1638,7 @@ public void testFillEmptiesNotOverfill() { int initialCapacity = vector.getValueCapacity(); assertTrue(initialCapacity >= 4095); - vector.setSafe(4094, "hello".getBytes(), 0, 5); + vector.setSafe(4094, "hello".getBytes(StandardCharsets.UTF_8), 0, 5); /* the above set method should NOT have triggered a realloc */ assertEquals(initialCapacity, vector.getValueCapacity()); @@ -1663,7 +1652,7 @@ public void testFillEmptiesNotOverfill() { @Test public void testSetSafeWithArrowBufNoExcessAllocs() { final int numValues = BaseFixedWidthVector.INITIAL_VALUE_ALLOCATION * 2; - final byte[] valueBytes = "hello world".getBytes(); + final byte[] valueBytes = "hello world".getBytes(StandardCharsets.UTF_8); final int valueBytesLength = valueBytes.length; final int isSet = 1; @@ -1720,7 +1709,7 @@ public void testCopyFromWithNulls() { if (i % 3 == 0) { continue; } - byte[] b = Integer.toString(i).getBytes(); + byte[] b = Integer.toString(i).getBytes(StandardCharsets.UTF_8); vector.setSafe(i, b, 0, b.length); } @@ -1781,7 +1770,7 @@ public void testCopyFromWithNulls1() { if (i % 3 == 0) { continue; } - byte[] b = Integer.toString(i).getBytes(); + byte[] b = Integer.toString(i).getBytes(StandardCharsets.UTF_8); vector.setSafe(i, b, 0, b.length); } @@ -2137,7 +2126,7 @@ public void testGetBufferAddress2() { long dataAddress = vector.getDataBufferAddress(); try { - long offsetAddress = vector.getOffsetBufferAddress(); + vector.getOffsetBufferAddress(); } catch (UnsupportedOperationException ue) { error = true; } finally { @@ -2275,7 +2264,7 @@ public void testSetNullableVarCharHolder() { String str = "hello"; ArrowBuf buf = allocator.buffer(16); - buf.setBytes(0, str.getBytes()); + buf.setBytes(0, str.getBytes(StandardCharsets.UTF_8)); stringHolder.start = 0; stringHolder.end = str.length(); @@ -2286,7 +2275,7 @@ public void testSetNullableVarCharHolder() { // verify results assertTrue(vector.isNull(0)); - assertEquals(str, new String(vector.get(1))); + assertEquals(str, new String(vector.get(1), StandardCharsets.UTF_8)); buf.close(); } @@ -2305,7 +2294,7 @@ public void testSetNullableVarCharHolderSafe() { String str = "hello world"; ArrowBuf buf = allocator.buffer(16); - buf.setBytes(0, str.getBytes()); + buf.setBytes(0, str.getBytes(StandardCharsets.UTF_8)); stringHolder.start = 0; stringHolder.end = str.length(); @@ -2315,7 +2304,7 @@ public void testSetNullableVarCharHolderSafe() { vector.setSafe(1, nullHolder); // verify results - assertEquals(str, new String(vector.get(0))); + assertEquals(str, new String(vector.get(0), StandardCharsets.UTF_8)); assertTrue(vector.isNull(1)); buf.close(); @@ -2335,7 +2324,7 @@ public void testSetNullableVarBinaryHolder() { String str = "hello"; ArrowBuf buf = allocator.buffer(16); - buf.setBytes(0, str.getBytes()); + buf.setBytes(0, str.getBytes(StandardCharsets.UTF_8)); binHolder.start = 0; binHolder.end = str.length(); @@ -2346,7 +2335,7 @@ public void testSetNullableVarBinaryHolder() { // verify results assertTrue(vector.isNull(0)); - assertEquals(str, new String(vector.get(1))); + assertEquals(str, new String(vector.get(1), StandardCharsets.UTF_8)); buf.close(); } @@ -2365,7 +2354,7 @@ public void testSetNullableVarBinaryHolderSafe() { String str = "hello world"; ArrowBuf buf = allocator.buffer(16); - buf.setBytes(0, str.getBytes()); + buf.setBytes(0, str.getBytes(StandardCharsets.UTF_8)); binHolder.start = 0; binHolder.end = str.length(); @@ -2375,7 +2364,7 @@ public void testSetNullableVarBinaryHolderSafe() { vector.setSafe(1, nullHolder); // verify results - assertEquals(str, new String(vector.get(0))); + assertEquals(str, new String(vector.get(0), StandardCharsets.UTF_8)); assertTrue(vector.isNull(1)); buf.close(); @@ -2431,8 +2420,8 @@ public void testGetPointerVariableWidth() { for (int i = 0; i < sampleData.length; i++) { String str = sampleData[i]; if (str != null) { - vec1.set(i, sampleData[i].getBytes()); - vec2.set(i, sampleData[i].getBytes()); + vec1.set(i, sampleData[i].getBytes(StandardCharsets.UTF_8)); + vec2.set(i, sampleData[i].getBytes(StandardCharsets.UTF_8)); } else { vec1.setNull(i); vec2.setNull(i); @@ -2827,7 +2816,7 @@ public void testVariableWidthVectorNullHashCode() { varChVec.allocateNew(100, 1); varChVec.setValueCount(1); - varChVec.set(0, "abc".getBytes()); + varChVec.set(0, "abc".getBytes(StandardCharsets.UTF_8)); varChVec.setNull(0); assertEquals(0, varChVec.hashCode(0)); @@ -2945,7 +2934,7 @@ public void testUnloadVariableWidthVector() { varCharVector.allocateNew(5, 2); varCharVector.setValueCount(2); - varCharVector.set(0, "abcd".getBytes()); + varCharVector.set(0, "abcd".getBytes(StandardCharsets.UTF_8)); List bufs = varCharVector.getFieldBuffers(); assertEquals(3, bufs.size()); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharListVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharListVector.java index a9b155499f773..bfe489fa5af4e 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharListVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharListVector.java @@ -17,6 +17,8 @@ package org.apache.arrow.vector; +import java.nio.charset.StandardCharsets; + import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.complex.ListVector; @@ -44,7 +46,7 @@ public void terminate() throws Exception { @Test public void testVarCharListWithNulls() { - byte[] bytes = "a".getBytes(); + byte[] bytes = "a".getBytes(StandardCharsets.UTF_8); try (ListVector vector = new ListVector("VarList", allocator, FieldType.nullable(Types .MinorType.VARCHAR.getType()), null); ArrowBuf tempBuf = allocator.buffer(bytes.length)) { @@ -63,15 +65,15 @@ public void testVarCharListWithNulls() { writer.setPosition(2); writer.startList(); - bytes = "b".getBytes(); + bytes = "b".getBytes(StandardCharsets.UTF_8); tempBuf.setBytes(0, bytes); writer.writeVarChar(0, bytes.length, tempBuf); writer.endList(); writer.setValueCount(2); - Assert.assertTrue(vector.getValueCount() == 2); - Assert.assertTrue(vector.getDataVector().getValueCount() == 2); + Assert.assertEquals(2, vector.getValueCount()); + Assert.assertEquals(2, vector.getDataVector().getValueCount()); } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java index dfc75ec8e34cf..b96f6ab6afedd 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java @@ -143,7 +143,7 @@ public void testFixedWidthVectorAllocation() { totalCapacity = vec2.getValidityBuffer().capacity() + vec2.getDataBuffer().capacity(); // the total capacity must be a power of two - assertEquals(totalCapacity & (totalCapacity - 1), 0); + assertEquals(0, totalCapacity & (totalCapacity - 1)); } } @@ -163,7 +163,7 @@ public void testVariableWidthVectorAllocation() { totalCapacity = vec2.getValidityBuffer().capacity() + vec2.getOffsetBuffer().capacity(); // the total capacity must be a power of two - assertEquals(totalCapacity & (totalCapacity - 1), 0); + assertEquals(0, totalCapacity & (totalCapacity - 1)); } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java index 7d5701ddb765b..9043bd4f8f2d4 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java @@ -324,12 +324,12 @@ public void testVariableRepeatedClearAndSet() throws Exception { vector.allocateNewSafe(); // Initial allocation vector.clear(); // clear vector. - vector.setSafe(0, "hello world".getBytes()); + vector.setSafe(0, "hello world".getBytes(StandardCharsets.UTF_8)); int savedValueCapacity = vector.getValueCapacity(); for (int i = 0; i < 1024; ++i) { vector.clear(); // clear vector. - vector.setSafe(0, "hello world".getBytes()); + vector.setSafe(0, "hello world".getBytes(StandardCharsets.UTF_8)); } // should be deterministic, and not cause a run-away increase in capacity. diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorSchemaRoot.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorSchemaRoot.java index ce3fb2cdf0ea1..207962eb45b85 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorSchemaRoot.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorSchemaRoot.java @@ -61,7 +61,7 @@ public void testResetRowCount() { VectorSchemaRoot vsr = VectorSchemaRoot.of(vec1, vec2); vsr.allocateNew(); - assertEquals(vsr.getRowCount(), 0); + assertEquals(0, vsr.getRowCount()); for (int i = 0; i < size; i++) { vec1.setSafe(i, i % 2); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestTypeEqualsVisitor.java b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestTypeEqualsVisitor.java index c0a3bd89dc18c..62fa0336ea925 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestTypeEqualsVisitor.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestTypeEqualsVisitor.java @@ -20,7 +20,6 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; -import java.nio.charset.Charset; import java.util.HashMap; import java.util.Map; @@ -52,11 +51,6 @@ public void init() { allocator = new RootAllocator(Long.MAX_VALUE); } - private static final Charset utf8Charset = Charset.forName("UTF-8"); - private static final byte[] STR1 = "AAAAA1".getBytes(utf8Charset); - private static final byte[] STR2 = "BBBBBBBBB2".getBytes(utf8Charset); - private static final byte[] STR3 = "CCCC3".getBytes(utf8Charset); - @After public void terminate() throws Exception { allocator.close(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index 4c8c96a0d74d3..b7fc681c16118 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -24,6 +24,8 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.Objects; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; @@ -413,8 +415,8 @@ public void testPromoteLargeVarCharHelpersOnStruct() throws Exception { writer.end(); final LargeVarCharVector uv = v.getChild("c", LargeVarCharVector.class); - assertEquals("foo", uv.getObject(0).toString()); - assertEquals("foo2", uv.getObject(1).toString()); + assertEquals("foo", Objects.requireNonNull(uv.getObject(0)).toString()); + assertEquals("foo2", Objects.requireNonNull(uv.getObject(1)).toString()); } } @@ -433,8 +435,8 @@ public void testPromoteVarCharHelpersOnStruct() throws Exception { writer.end(); final VarCharVector uv = v.getChild("c", VarCharVector.class); - assertEquals("foo", uv.getObject(0).toString()); - assertEquals("foo2", uv.getObject(1).toString()); + assertEquals("foo", Objects.requireNonNull(uv.getObject(0)).toString()); + assertEquals("foo2", Objects.requireNonNull(uv.getObject(1)).toString()); } } @@ -455,8 +457,8 @@ public void testPromoteVarCharHelpersDirect() throws Exception { // The "test" vector in the parent container should have been replaced with a UnionVector. UnionVector promotedVector = container.getChild("test", UnionVector.class); VarCharVector vector = promotedVector.getVarCharVector(); - assertEquals("foo", vector.getObject(0).toString()); - assertEquals("foo2", vector.getObject(1).toString()); + assertEquals("foo", Objects.requireNonNull(vector.getObject(0)).toString()); + assertEquals("foo2", Objects.requireNonNull(vector.getObject(1)).toString()); } } @@ -477,8 +479,8 @@ public void testPromoteLargeVarCharHelpersDirect() throws Exception { // The "test" vector in the parent container should have been replaced with a UnionVector. UnionVector promotedVector = container.getChild("test", UnionVector.class); LargeVarCharVector vector = promotedVector.getLargeVarCharVector(); - assertEquals("foo", vector.getObject(0).toString()); - assertEquals("foo2", vector.getObject(1).toString()); + assertEquals("foo", Objects.requireNonNull(vector.getObject(0)).toString()); + assertEquals("foo2", Objects.requireNonNull(vector.getObject(1)).toString()); } } @@ -491,20 +493,22 @@ public void testPromoteVarBinaryHelpersOnStruct() throws Exception { writer.start(); writer.setPosition(0); - writer.varBinary("c").writeVarBinary("row1".getBytes()); + writer.varBinary("c").writeVarBinary("row1".getBytes(StandardCharsets.UTF_8)); writer.setPosition(1); - writer.varBinary("c").writeVarBinary("row2".getBytes(), 0, "row2".getBytes().length); + writer.varBinary("c").writeVarBinary("row2".getBytes(StandardCharsets.UTF_8), 0, + "row2".getBytes(StandardCharsets.UTF_8).length); writer.setPosition(2); - writer.varBinary("c").writeVarBinary(ByteBuffer.wrap("row3".getBytes())); + writer.varBinary("c").writeVarBinary(ByteBuffer.wrap("row3".getBytes(StandardCharsets.UTF_8))); writer.setPosition(3); - writer.varBinary("c").writeVarBinary(ByteBuffer.wrap("row4".getBytes()), 0, "row4".getBytes().length); + writer.varBinary("c").writeVarBinary(ByteBuffer.wrap("row4".getBytes(StandardCharsets.UTF_8)), 0, + "row4".getBytes(StandardCharsets.UTF_8).length); writer.end(); final VarBinaryVector uv = v.getChild("c", VarBinaryVector.class); - assertEquals("row1", new String(uv.get(0))); - assertEquals("row2", new String(uv.get(1))); - assertEquals("row3", new String(uv.get(2))); - assertEquals("row4", new String(uv.get(3))); + assertEquals("row1", new String(Objects.requireNonNull(uv.get(0)), StandardCharsets.UTF_8)); + assertEquals("row2", new String(Objects.requireNonNull(uv.get(1)), StandardCharsets.UTF_8)); + assertEquals("row3", new String(Objects.requireNonNull(uv.get(2)), StandardCharsets.UTF_8)); + assertEquals("row4", new String(Objects.requireNonNull(uv.get(3)), StandardCharsets.UTF_8)); } } @@ -517,22 +521,24 @@ public void testPromoteVarBinaryHelpersDirect() throws Exception { writer.start(); writer.setPosition(0); - writer.writeVarBinary("row1".getBytes()); + writer.writeVarBinary("row1".getBytes(StandardCharsets.UTF_8)); writer.setPosition(1); - writer.writeVarBinary("row2".getBytes(), 0, "row2".getBytes().length); + writer.writeVarBinary("row2".getBytes(StandardCharsets.UTF_8), 0, + "row2".getBytes(StandardCharsets.UTF_8).length); writer.setPosition(2); - writer.writeVarBinary(ByteBuffer.wrap("row3".getBytes())); + writer.writeVarBinary(ByteBuffer.wrap("row3".getBytes(StandardCharsets.UTF_8))); writer.setPosition(3); - writer.writeVarBinary(ByteBuffer.wrap("row4".getBytes()), 0, "row4".getBytes().length); + writer.writeVarBinary(ByteBuffer.wrap("row4".getBytes(StandardCharsets.UTF_8)), 0, + "row4".getBytes(StandardCharsets.UTF_8).length); writer.end(); // The "test" vector in the parent container should have been replaced with a UnionVector. UnionVector promotedVector = container.getChild("test", UnionVector.class); VarBinaryVector uv = promotedVector.getVarBinaryVector(); - assertEquals("row1", new String(uv.get(0))); - assertEquals("row2", new String(uv.get(1))); - assertEquals("row3", new String(uv.get(2))); - assertEquals("row4", new String(uv.get(3))); + assertEquals("row1", new String(Objects.requireNonNull(uv.get(0)), StandardCharsets.UTF_8)); + assertEquals("row2", new String(Objects.requireNonNull(uv.get(1)), StandardCharsets.UTF_8)); + assertEquals("row3", new String(Objects.requireNonNull(uv.get(2)), StandardCharsets.UTF_8)); + assertEquals("row4", new String(Objects.requireNonNull(uv.get(3)), StandardCharsets.UTF_8)); } } @@ -545,20 +551,22 @@ public void testPromoteLargeVarBinaryHelpersOnStruct() throws Exception { writer.start(); writer.setPosition(0); - writer.largeVarBinary("c").writeLargeVarBinary("row1".getBytes()); + writer.largeVarBinary("c").writeLargeVarBinary("row1".getBytes(StandardCharsets.UTF_8)); writer.setPosition(1); - writer.largeVarBinary("c").writeLargeVarBinary("row2".getBytes(), 0, "row2".getBytes().length); + writer.largeVarBinary("c").writeLargeVarBinary("row2".getBytes(StandardCharsets.UTF_8), 0, + "row2".getBytes(StandardCharsets.UTF_8).length); writer.setPosition(2); - writer.largeVarBinary("c").writeLargeVarBinary(ByteBuffer.wrap("row3".getBytes())); + writer.largeVarBinary("c").writeLargeVarBinary(ByteBuffer.wrap("row3".getBytes(StandardCharsets.UTF_8))); writer.setPosition(3); - writer.largeVarBinary("c").writeLargeVarBinary(ByteBuffer.wrap("row4".getBytes()), 0, "row4".getBytes().length); + writer.largeVarBinary("c").writeLargeVarBinary(ByteBuffer.wrap("row4".getBytes(StandardCharsets.UTF_8)), 0, + "row4".getBytes(StandardCharsets.UTF_8).length); writer.end(); final LargeVarBinaryVector uv = v.getChild("c", LargeVarBinaryVector.class); - assertEquals("row1", new String(uv.get(0))); - assertEquals("row2", new String(uv.get(1))); - assertEquals("row3", new String(uv.get(2))); - assertEquals("row4", new String(uv.get(3))); + assertEquals("row1", new String(Objects.requireNonNull(uv.get(0)), StandardCharsets.UTF_8)); + assertEquals("row2", new String(Objects.requireNonNull(uv.get(1)), StandardCharsets.UTF_8)); + assertEquals("row3", new String(Objects.requireNonNull(uv.get(2)), StandardCharsets.UTF_8)); + assertEquals("row4", new String(Objects.requireNonNull(uv.get(3)), StandardCharsets.UTF_8)); } } @@ -571,22 +579,24 @@ public void testPromoteLargeVarBinaryHelpersDirect() throws Exception { writer.start(); writer.setPosition(0); - writer.writeLargeVarBinary("row1".getBytes()); + writer.writeLargeVarBinary("row1".getBytes(StandardCharsets.UTF_8)); writer.setPosition(1); - writer.writeLargeVarBinary("row2".getBytes(), 0, "row2".getBytes().length); + writer.writeLargeVarBinary("row2".getBytes(StandardCharsets.UTF_8), 0, + "row2".getBytes(StandardCharsets.UTF_8).length); writer.setPosition(2); - writer.writeLargeVarBinary(ByteBuffer.wrap("row3".getBytes())); + writer.writeLargeVarBinary(ByteBuffer.wrap("row3".getBytes(StandardCharsets.UTF_8))); writer.setPosition(3); - writer.writeLargeVarBinary(ByteBuffer.wrap("row4".getBytes()), 0, "row4".getBytes().length); + writer.writeLargeVarBinary(ByteBuffer.wrap("row4".getBytes(StandardCharsets.UTF_8)), 0, + "row4".getBytes(StandardCharsets.UTF_8).length); writer.end(); // The "test" vector in the parent container should have been replaced with a UnionVector. UnionVector promotedVector = container.getChild("test", UnionVector.class); LargeVarBinaryVector uv = promotedVector.getLargeVarBinaryVector(); - assertEquals("row1", new String(uv.get(0))); - assertEquals("row2", new String(uv.get(1))); - assertEquals("row3", new String(uv.get(2))); - assertEquals("row4", new String(uv.get(3))); + assertEquals("row1", new String(Objects.requireNonNull(uv.get(0)), StandardCharsets.UTF_8)); + assertEquals("row2", new String(Objects.requireNonNull(uv.get(1)), StandardCharsets.UTF_8)); + assertEquals("row3", new String(Objects.requireNonNull(uv.get(2)), StandardCharsets.UTF_8)); + assertEquals("row4", new String(Objects.requireNonNull(uv.get(3)), StandardCharsets.UTF_8)); } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java index e03ce0c056bf1..19f0ea9d4e392 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java @@ -21,6 +21,7 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.time.LocalDateTime; import java.util.ArrayList; import java.util.HashSet; @@ -128,9 +129,7 @@ public void simpleNestedTypes() { @Test public void transferPairSchemaChange() { SchemaChangeCallBack callBack1 = new SchemaChangeCallBack(); - SchemaChangeCallBack callBack2 = new SchemaChangeCallBack(); try (NonNullableStructVector parent = populateStructVector(callBack1)) { - TransferPair tp = parent.getTransferPair("newVector", allocator, callBack2); ComplexWriter writer = new ComplexWriterImpl("newWriter", parent); StructWriter rootWriter = writer.rootAsStruct(); @@ -818,7 +817,7 @@ public void promotableWriter() { for (int i = 100; i < 200; i++) { VarCharWriter varCharWriter = rootWriter.varChar("a"); varCharWriter.setPosition(i); - byte[] bytes = Integer.toString(i).getBytes(); + byte[] bytes = Integer.toString(i).getBytes(StandardCharsets.UTF_8); ArrowBuf tempBuf = allocator.buffer(bytes.length); tempBuf.setBytes(0, bytes); varCharWriter.writeVarChar(0, bytes.length, tempBuf); @@ -1719,21 +1718,23 @@ public void structWriterVarBinaryHelpers() { StructWriter rootWriter = writer.rootAsStruct(); rootWriter.start(); rootWriter.setPosition(0); - rootWriter.varBinary("c").writeVarBinary("row1".getBytes()); + rootWriter.varBinary("c").writeVarBinary("row1".getBytes(StandardCharsets.UTF_8)); rootWriter.setPosition(1); - rootWriter.varBinary("c").writeVarBinary("row2".getBytes(), 0, "row2".getBytes().length); + rootWriter.varBinary("c").writeVarBinary("row2".getBytes(StandardCharsets.UTF_8), 0, + "row2".getBytes(StandardCharsets.UTF_8).length); rootWriter.setPosition(2); - rootWriter.varBinary("c").writeVarBinary(ByteBuffer.wrap("row3".getBytes())); + rootWriter.varBinary("c").writeVarBinary(ByteBuffer.wrap("row3".getBytes(StandardCharsets.UTF_8))); rootWriter.setPosition(3); - rootWriter.varBinary("c").writeVarBinary(ByteBuffer.wrap("row4".getBytes()), 0, "row4".getBytes().length); + rootWriter.varBinary("c").writeVarBinary(ByteBuffer.wrap( + "row4".getBytes(StandardCharsets.UTF_8)), 0, "row4".getBytes(StandardCharsets.UTF_8).length); rootWriter.end(); VarBinaryVector uv = parent.getChild("root", StructVector.class).getChild("c", VarBinaryVector.class); - assertEquals("row1", new String(uv.get(0))); - assertEquals("row2", new String(uv.get(1))); - assertEquals("row3", new String(uv.get(2))); - assertEquals("row4", new String(uv.get(3))); + assertEquals("row1", new String(uv.get(0), StandardCharsets.UTF_8)); + assertEquals("row2", new String(uv.get(1), StandardCharsets.UTF_8)); + assertEquals("row3", new String(uv.get(2), StandardCharsets.UTF_8)); + assertEquals("row4", new String(uv.get(3), StandardCharsets.UTF_8)); } } @@ -1744,23 +1745,24 @@ public void structWriterLargeVarBinaryHelpers() { StructWriter rootWriter = writer.rootAsStruct(); rootWriter.start(); rootWriter.setPosition(0); - rootWriter.largeVarBinary("c").writeLargeVarBinary("row1".getBytes()); + rootWriter.largeVarBinary("c").writeLargeVarBinary("row1".getBytes(StandardCharsets.UTF_8)); rootWriter.setPosition(1); - rootWriter.largeVarBinary("c").writeLargeVarBinary("row2".getBytes(), 0, "row2".getBytes().length); + rootWriter.largeVarBinary("c").writeLargeVarBinary("row2".getBytes(StandardCharsets.UTF_8), 0, + "row2".getBytes(StandardCharsets.UTF_8).length); rootWriter.setPosition(2); - rootWriter.largeVarBinary("c").writeLargeVarBinary(ByteBuffer.wrap("row3".getBytes())); + rootWriter.largeVarBinary("c").writeLargeVarBinary(ByteBuffer.wrap("row3".getBytes(StandardCharsets.UTF_8))); rootWriter.setPosition(3); - rootWriter.largeVarBinary("c").writeLargeVarBinary(ByteBuffer.wrap("row4".getBytes()), 0, - "row4".getBytes().length); + rootWriter.largeVarBinary("c").writeLargeVarBinary(ByteBuffer.wrap( + "row4".getBytes(StandardCharsets.UTF_8)), 0, "row4".getBytes(StandardCharsets.UTF_8).length); rootWriter.end(); LargeVarBinaryVector uv = parent.getChild("root", StructVector.class).getChild("c", LargeVarBinaryVector.class); - assertEquals("row1", new String(uv.get(0))); - assertEquals("row2", new String(uv.get(1))); - assertEquals("row3", new String(uv.get(2))); - assertEquals("row4", new String(uv.get(3))); + assertEquals("row1", new String(uv.get(0), StandardCharsets.UTF_8)); + assertEquals("row2", new String(uv.get(1), StandardCharsets.UTF_8)); + assertEquals("row3", new String(uv.get(2), StandardCharsets.UTF_8)); + assertEquals("row4", new String(uv.get(3), StandardCharsets.UTF_8)); } } @@ -1800,16 +1802,18 @@ public void listVarBinaryHelpers() { listVector.allocateNew(); UnionListWriter listWriter = new UnionListWriter(listVector); listWriter.startList(); - listWriter.writeVarBinary("row1".getBytes()); - listWriter.writeVarBinary("row2".getBytes(), 0, "row2".getBytes().length); - listWriter.writeVarBinary(ByteBuffer.wrap("row3".getBytes())); - listWriter.writeVarBinary(ByteBuffer.wrap("row4".getBytes()), 0, "row4".getBytes().length); + listWriter.writeVarBinary("row1".getBytes(StandardCharsets.UTF_8)); + listWriter.writeVarBinary("row2".getBytes(StandardCharsets.UTF_8), 0, + "row2".getBytes(StandardCharsets.UTF_8).length); + listWriter.writeVarBinary(ByteBuffer.wrap("row3".getBytes(StandardCharsets.UTF_8))); + listWriter.writeVarBinary(ByteBuffer.wrap( + "row4".getBytes(StandardCharsets.UTF_8)), 0, "row4".getBytes(StandardCharsets.UTF_8).length); listWriter.endList(); listWriter.setValueCount(1); - assertEquals("row1", new String((byte[]) listVector.getObject(0).get(0))); - assertEquals("row2", new String((byte[]) listVector.getObject(0).get(1))); - assertEquals("row3", new String((byte[]) listVector.getObject(0).get(2))); - assertEquals("row4", new String((byte[]) listVector.getObject(0).get(3))); + assertEquals("row1", new String((byte[]) listVector.getObject(0).get(0), StandardCharsets.UTF_8)); + assertEquals("row2", new String((byte[]) listVector.getObject(0).get(1), StandardCharsets.UTF_8)); + assertEquals("row3", new String((byte[]) listVector.getObject(0).get(2), StandardCharsets.UTF_8)); + assertEquals("row4", new String((byte[]) listVector.getObject(0).get(3), StandardCharsets.UTF_8)); } } @@ -1819,16 +1823,18 @@ public void listLargeVarBinaryHelpers() { listVector.allocateNew(); UnionListWriter listWriter = new UnionListWriter(listVector); listWriter.startList(); - listWriter.writeLargeVarBinary("row1".getBytes()); - listWriter.writeLargeVarBinary("row2".getBytes(), 0, "row2".getBytes().length); - listWriter.writeLargeVarBinary(ByteBuffer.wrap("row3".getBytes())); - listWriter.writeLargeVarBinary(ByteBuffer.wrap("row4".getBytes()), 0, "row4".getBytes().length); + listWriter.writeLargeVarBinary("row1".getBytes(StandardCharsets.UTF_8)); + listWriter.writeLargeVarBinary("row2".getBytes(StandardCharsets.UTF_8), 0, + "row2".getBytes(StandardCharsets.UTF_8).length); + listWriter.writeLargeVarBinary(ByteBuffer.wrap("row3".getBytes(StandardCharsets.UTF_8))); + listWriter.writeLargeVarBinary(ByteBuffer.wrap( + "row4".getBytes(StandardCharsets.UTF_8)), 0, "row4".getBytes(StandardCharsets.UTF_8).length); listWriter.endList(); listWriter.setValueCount(1); - assertEquals("row1", new String((byte[]) listVector.getObject(0).get(0))); - assertEquals("row2", new String((byte[]) listVector.getObject(0).get(1))); - assertEquals("row3", new String((byte[]) listVector.getObject(0).get(2))); - assertEquals("row4", new String((byte[]) listVector.getObject(0).get(3))); + assertEquals("row1", new String((byte[]) listVector.getObject(0).get(0), StandardCharsets.UTF_8)); + assertEquals("row2", new String((byte[]) listVector.getObject(0).get(1), StandardCharsets.UTF_8)); + assertEquals("row3", new String((byte[]) listVector.getObject(0).get(2), StandardCharsets.UTF_8)); + assertEquals("row4", new String((byte[]) listVector.getObject(0).get(3), StandardCharsets.UTF_8)); } } @@ -1847,35 +1853,39 @@ public void unionWithVarCharAndBinaryHelpers() throws Exception { unionWriter.setPosition(3); unionWriter.writeLargeVarChar(new Text("row4")); unionWriter.setPosition(4); - unionWriter.writeVarBinary("row5".getBytes()); + unionWriter.writeVarBinary("row5".getBytes(StandardCharsets.UTF_8)); unionWriter.setPosition(5); - unionWriter.writeVarBinary("row6".getBytes(), 0, "row6".getBytes().length); + unionWriter.writeVarBinary("row6".getBytes(StandardCharsets.UTF_8), 0, + "row6".getBytes(StandardCharsets.UTF_8).length); unionWriter.setPosition(6); - unionWriter.writeVarBinary(ByteBuffer.wrap("row7".getBytes())); + unionWriter.writeVarBinary(ByteBuffer.wrap("row7".getBytes(StandardCharsets.UTF_8))); unionWriter.setPosition(7); - unionWriter.writeVarBinary(ByteBuffer.wrap("row8".getBytes()), 0, "row8".getBytes().length); + unionWriter.writeVarBinary(ByteBuffer.wrap("row8".getBytes(StandardCharsets.UTF_8)), 0, + "row8".getBytes(StandardCharsets.UTF_8).length); unionWriter.setPosition(8); - unionWriter.writeLargeVarBinary("row9".getBytes()); + unionWriter.writeLargeVarBinary("row9".getBytes(StandardCharsets.UTF_8)); unionWriter.setPosition(9); - unionWriter.writeLargeVarBinary("row10".getBytes(), 0, "row10".getBytes().length); + unionWriter.writeLargeVarBinary("row10".getBytes(StandardCharsets.UTF_8), 0, + "row10".getBytes(StandardCharsets.UTF_8).length); unionWriter.setPosition(10); - unionWriter.writeLargeVarBinary(ByteBuffer.wrap("row11".getBytes())); + unionWriter.writeLargeVarBinary(ByteBuffer.wrap("row11".getBytes(StandardCharsets.UTF_8))); unionWriter.setPosition(11); - unionWriter.writeLargeVarBinary(ByteBuffer.wrap("row12".getBytes()), 0, "row12".getBytes().length); + unionWriter.writeLargeVarBinary(ByteBuffer.wrap( + "row12".getBytes(StandardCharsets.UTF_8)), 0, "row12".getBytes(StandardCharsets.UTF_8).length); unionWriter.end(); - assertEquals("row1", new String(vector.getVarCharVector().get(0))); - assertEquals("row2", new String(vector.getVarCharVector().get(1))); - assertEquals("row3", new String(vector.getLargeVarCharVector().get(2))); - assertEquals("row4", new String(vector.getLargeVarCharVector().get(3))); - assertEquals("row5", new String(vector.getVarBinaryVector().get(4))); - assertEquals("row6", new String(vector.getVarBinaryVector().get(5))); - assertEquals("row7", new String(vector.getVarBinaryVector().get(6))); - assertEquals("row8", new String(vector.getVarBinaryVector().get(7))); - assertEquals("row9", new String(vector.getLargeVarBinaryVector().get(8))); - assertEquals("row10", new String(vector.getLargeVarBinaryVector().get(9))); - assertEquals("row11", new String(vector.getLargeVarBinaryVector().get(10))); - assertEquals("row12", new String(vector.getLargeVarBinaryVector().get(11))); + assertEquals("row1", new String(vector.getVarCharVector().get(0), StandardCharsets.UTF_8)); + assertEquals("row2", new String(vector.getVarCharVector().get(1), StandardCharsets.UTF_8)); + assertEquals("row3", new String(vector.getLargeVarCharVector().get(2), StandardCharsets.UTF_8)); + assertEquals("row4", new String(vector.getLargeVarCharVector().get(3), StandardCharsets.UTF_8)); + assertEquals("row5", new String(vector.getVarBinaryVector().get(4), StandardCharsets.UTF_8)); + assertEquals("row6", new String(vector.getVarBinaryVector().get(5), StandardCharsets.UTF_8)); + assertEquals("row7", new String(vector.getVarBinaryVector().get(6), StandardCharsets.UTF_8)); + assertEquals("row8", new String(vector.getVarBinaryVector().get(7), StandardCharsets.UTF_8)); + assertEquals("row9", new String(vector.getLargeVarBinaryVector().get(8), StandardCharsets.UTF_8)); + assertEquals("row10", new String(vector.getLargeVarBinaryVector().get(9), StandardCharsets.UTF_8)); + assertEquals("row11", new String(vector.getLargeVarBinaryVector().get(10), StandardCharsets.UTF_8)); + assertEquals("row12", new String(vector.getLargeVarBinaryVector().get(11), StandardCharsets.UTF_8)); } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java index 8663c0c49990d..de9187edb667e 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java @@ -694,19 +694,19 @@ protected void writeBatchData(ArrowWriter writer, IntVector vector, VectorSchema protected void validateBatchData(ArrowReader reader, IntVector vector) throws IOException { reader.loadNextBatch(); - assertEquals(vector.getValueCount(), 5); + assertEquals(5, vector.getValueCount()); assertTrue(vector.isNull(0)); - assertEquals(vector.get(1), 1); - assertEquals(vector.get(2), 2); + assertEquals(1, vector.get(1)); + assertEquals(2, vector.get(2)); assertTrue(vector.isNull(3)); - assertEquals(vector.get(4), 1); + assertEquals(1, vector.get(4)); reader.loadNextBatch(); - assertEquals(vector.getValueCount(), 3); + assertEquals(3, vector.getValueCount()); assertTrue(vector.isNull(0)); - assertEquals(vector.get(1), 1); - assertEquals(vector.get(2), 2); + assertEquals(1, vector.get(1)); + assertEquals(2, vector.get(2)); } protected VectorSchemaRoot writeMapData(BufferAllocator bufferAllocator) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java index 9348cd3a66708..145bdd588e945 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java @@ -79,8 +79,8 @@ public void testStreamZeroLengthBatch() throws IOException { VectorSchemaRoot root = reader.getVectorSchemaRoot(); IntVector vector = (IntVector) root.getFieldVectors().get(0); reader.loadNextBatch(); - assertEquals(vector.getValueCount(), 0); - assertEquals(root.getRowCount(), 0); + assertEquals(0, vector.getValueCount()); + assertEquals(0, root.getRowCount()); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java index 0aa49d9daa0da..bd5bd4feabbd4 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java @@ -476,7 +476,7 @@ public void testRoundtripEmptyVector() throws Exception { assertEquals(schema, readSchema); try (final VectorSchemaRoot data = reader.read()) { assertNotNull(data); - assertEquals(data.getRowCount(), 0); + assertEquals(0, data.getRowCount()); } assertNull(reader.read()); } @@ -496,7 +496,7 @@ public void testRoundtripEmptyVector() throws Exception { assertEquals(schema, readSchema); try (final VectorSchemaRoot data = reader.read()) { assertNotNull(data); - assertEquals(data.getRowCount(), 0); + assertEquals(0, data.getRowCount()); } assertNull(reader.read()); } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestUIntDictionaryRoundTrip.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestUIntDictionaryRoundTrip.java index 6aa7a0c6df5c3..ac95121eb73f2 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestUIntDictionaryRoundTrip.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestUIntDictionaryRoundTrip.java @@ -27,6 +27,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.channels.Channels; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collection; import java.util.Map; @@ -138,7 +139,7 @@ private void readData( VarCharVector dictVector = (VarCharVector) dictionary.getVector(); assertEquals(expectedDictItems.length, dictVector.getValueCount()); for (int i = 0; i < dictVector.getValueCount(); i++) { - assertArrayEquals(expectedDictItems[i].getBytes(), dictVector.get(i)); + assertArrayEquals(expectedDictItems[i].getBytes(StandardCharsets.UTF_8), dictVector.get(i)); } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/message/TestMessageMetadataResult.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/message/TestMessageMetadataResult.java index ee5361547a0b9..0505a18484b54 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/message/TestMessageMetadataResult.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/message/TestMessageMetadataResult.java @@ -30,7 +30,7 @@ public void getMessageLength_returnsConstructValue() { // This API is used by spark. MessageMetadataResult result = new MessageMetadataResult(1, ByteBuffer.allocate(0), new org.apache.arrow.flatbuf.Message()); - assertEquals(result.getMessageLength(), 1); + assertEquals(1, result.getMessageLength()); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/table/BaseTableTest.java b/java/vector/src/test/java/org/apache/arrow/vector/table/BaseTableTest.java index 78f2ee51b8912..1b7f984992ada 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/table/BaseTableTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/table/BaseTableTest.java @@ -28,8 +28,10 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -282,8 +284,8 @@ void testDecode() { VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); dictionaryVector.allocateNew(2); - dictionaryVector.set(0, "one".getBytes()); - dictionaryVector.set(1, "two".getBytes()); + dictionaryVector.set(0, "one".getBytes(StandardCharsets.UTF_8)); + dictionaryVector.set(1, "two".getBytes(StandardCharsets.UTF_8)); dictionaryVector.setValueCount(2); Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); @@ -297,8 +299,8 @@ void testDecode() { try (Table t = new Table(vectorList, vectorList.get(0).getValueCount(), provider)) { VarCharVector v = (VarCharVector) t.decode(encoded.getName(), 1L); assertNotNull(v); - assertEquals("one", new String(v.get(0))); - assertEquals("two", new String(v.get(1))); + assertEquals("one", new String(Objects.requireNonNull(v.get(0)), StandardCharsets.UTF_8)); + assertEquals("two", new String(Objects.requireNonNull(v.get(1)), StandardCharsets.UTF_8)); } } @@ -319,8 +321,8 @@ private DictionaryProvider getDictionary() { VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); dictionaryVector.allocateNew(2); - dictionaryVector.set(0, "one".getBytes()); - dictionaryVector.set(1, "two".getBytes()); + dictionaryVector.set(0, "one".getBytes(StandardCharsets.UTF_8)); + dictionaryVector.set(1, "two".getBytes(StandardCharsets.UTF_8)); dictionaryVector.setValueCount(2); Dictionary dictionary = new Dictionary(dictionaryVector, encoding); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/table/RowTest.java b/java/vector/src/test/java/org/apache/arrow/vector/table/RowTest.java index eb50e866b19f0..3e6a096104d44 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/table/RowTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/table/RowTest.java @@ -650,8 +650,8 @@ void getVarChar() { c.setPosition(1); assertEquals(c.getVarCharObj(1), "two"); assertEquals(c.getVarCharObj(1), c.getVarCharObj(VARCHAR_VECTOR_NAME_1)); - assertArrayEquals("two".getBytes(), c.getVarChar(VARCHAR_VECTOR_NAME_1)); - assertArrayEquals("two".getBytes(), c.getVarChar(1)); + assertArrayEquals("two".getBytes(StandardCharsets.UTF_8), c.getVarChar(VARCHAR_VECTOR_NAME_1)); + assertArrayEquals("two".getBytes(StandardCharsets.UTF_8), c.getVarChar(1)); } } @@ -661,7 +661,7 @@ void getVarBinary() { try (Table t = new Table(vectorList)) { Row c = t.immutableRow(); c.setPosition(1); - assertArrayEquals(c.getVarBinary(1), "two".getBytes()); + assertArrayEquals(c.getVarBinary(1), "two".getBytes(StandardCharsets.UTF_8)); assertArrayEquals(c.getVarBinary(1), c.getVarBinary(VARBINARY_VECTOR_NAME_1)); } } @@ -672,7 +672,7 @@ void getLargeVarBinary() { try (Table t = new Table(vectorList)) { Row c = t.immutableRow(); c.setPosition(1); - assertArrayEquals(c.getLargeVarBinary(1), "two".getBytes()); + assertArrayEquals(c.getLargeVarBinary(1), "two".getBytes(StandardCharsets.UTF_8)); assertArrayEquals(c.getLargeVarBinary(1), c.getLargeVarBinary(VARBINARY_VECTOR_NAME_1)); } } @@ -685,8 +685,8 @@ void getLargeVarChar() { c.setPosition(1); assertEquals(c.getLargeVarCharObj(1), "two"); assertEquals(c.getLargeVarCharObj(1), c.getLargeVarCharObj(VARCHAR_VECTOR_NAME_1)); - assertArrayEquals("two".getBytes(), c.getLargeVarChar(VARCHAR_VECTOR_NAME_1)); - assertArrayEquals("two".getBytes(), c.getLargeVarChar(1)); + assertArrayEquals("two".getBytes(StandardCharsets.UTF_8), c.getLargeVarChar(VARCHAR_VECTOR_NAME_1)); + assertArrayEquals("two".getBytes(StandardCharsets.UTF_8), c.getLargeVarChar(1)); } } @@ -696,7 +696,7 @@ void getFixedBinary() { try (Table t = new Table(vectorList)) { Row c = t.immutableRow(); c.setPosition(1); - assertArrayEquals(c.getFixedSizeBinary(1), "two".getBytes()); + assertArrayEquals(c.getFixedSizeBinary(1), "two".getBytes(StandardCharsets.UTF_8)); assertArrayEquals(c.getFixedSizeBinary(1), c.getFixedSizeBinary(FIXEDBINARY_VECTOR_NAME_1)); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/table/TestUtils.java b/java/vector/src/test/java/org/apache/arrow/vector/table/TestUtils.java index cb0b7b8eb6b87..c0b3bfdf73220 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/table/TestUtils.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/table/TestUtils.java @@ -20,6 +20,7 @@ import static org.apache.arrow.vector.complex.BaseRepeatedValueVector.OFFSET_WIDTH; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; @@ -117,8 +118,8 @@ static List intPlusVarcharColumns(BufferAllocator allocator) { IntVector v1 = getSimpleIntVector(allocator); VarCharVector v2 = new VarCharVector(VARCHAR_VECTOR_NAME_1, allocator); v2.allocateNew(2); - v2.set(0, "one".getBytes()); - v2.set(1, "two".getBytes()); + v2.set(0, "one".getBytes(StandardCharsets.UTF_8)); + v2.set(1, "two".getBytes(StandardCharsets.UTF_8)); v2.setValueCount(2); vectorList.add(v1); vectorList.add(v2); @@ -134,8 +135,8 @@ static List intPlusLargeVarcharColumns(BufferAllocator allocator) { IntVector v1 = getSimpleIntVector(allocator); LargeVarCharVector v2 = new LargeVarCharVector(VARCHAR_VECTOR_NAME_1, allocator); v2.allocateNew(2); - v2.set(0, "one".getBytes()); - v2.set(1, "two".getBytes()); + v2.set(0, "one".getBytes(StandardCharsets.UTF_8)); + v2.set(1, "two".getBytes(StandardCharsets.UTF_8)); v2.setValueCount(2); vectorList.add(v1); vectorList.add(v2); @@ -152,8 +153,8 @@ static List intPlusVarBinaryColumns(BufferAllocator allocator) { IntVector v1 = getSimpleIntVector(allocator); VarBinaryVector v2 = new VarBinaryVector(VARBINARY_VECTOR_NAME_1, allocator); v2.allocateNew(2); - v2.set(0, "one".getBytes()); - v2.set(1, "two".getBytes()); + v2.set(0, "one".getBytes(StandardCharsets.UTF_8)); + v2.set(1, "two".getBytes(StandardCharsets.UTF_8)); v2.setValueCount(2); vectorList.add(v1); vectorList.add(v2); @@ -170,8 +171,8 @@ static List intPlusLargeVarBinaryColumns(BufferAllocator allocator) IntVector v1 = getSimpleIntVector(allocator); LargeVarBinaryVector v2 = new LargeVarBinaryVector(VARBINARY_VECTOR_NAME_1, allocator); v2.allocateNew(2); - v2.set(0, "one".getBytes()); - v2.set(1, "two".getBytes()); + v2.set(0, "one".getBytes(StandardCharsets.UTF_8)); + v2.set(1, "two".getBytes(StandardCharsets.UTF_8)); v2.setValueCount(2); vectorList.add(v1); vectorList.add(v2); @@ -188,8 +189,8 @@ static List intPlusFixedBinaryColumns(BufferAllocator allocator) { IntVector v1 = getSimpleIntVector(allocator); FixedSizeBinaryVector v2 = new FixedSizeBinaryVector(FIXEDBINARY_VECTOR_NAME_1, allocator, 3); v2.allocateNew(2); - v2.set(0, "one".getBytes()); - v2.set(1, "two".getBytes()); + v2.set(0, "one".getBytes(StandardCharsets.UTF_8)); + v2.set(1, "two".getBytes(StandardCharsets.UTF_8)); v2.setValueCount(2); vectorList.add(v1); vectorList.add(v2); diff --git a/java/vector/src/test/java/org/apache/arrow/util/ArrowTestDataUtil.java b/java/vector/src/test/java/org/apache/arrow/vector/test/util/ArrowTestDataUtil.java similarity index 97% rename from java/vector/src/test/java/org/apache/arrow/util/ArrowTestDataUtil.java rename to java/vector/src/test/java/org/apache/arrow/vector/test/util/ArrowTestDataUtil.java index 120c0adc884ed..901a09e313f59 100644 --- a/java/vector/src/test/java/org/apache/arrow/util/ArrowTestDataUtil.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/test/util/ArrowTestDataUtil.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.util; +package org.apache.arrow.vector.test.util; import java.nio.file.Path; import java.nio.file.Paths; diff --git a/java/vector/src/test/java/org/apache/arrow/vector/testing/TestValueVectorPopulator.java b/java/vector/src/test/java/org/apache/arrow/vector/testing/TestValueVectorPopulator.java index 74257c45ca887..3c075c9293079 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/testing/TestValueVectorPopulator.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/testing/TestValueVectorPopulator.java @@ -20,6 +20,8 @@ import static junit.framework.TestCase.assertTrue; import static org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector; +import java.nio.charset.StandardCharsets; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BigIntVector; @@ -204,13 +206,14 @@ public void testPopulateFixedSizeBinaryVector() { if (i % 2 == 0) { vector1.setNull(i); } else { - vector1.set(i, ("test" + i).getBytes()); + vector1.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); } } vector1.setValueCount(10); - setVector(vector2, null, "test1".getBytes(), null, "test3".getBytes(), null, "test5".getBytes(), null, - "test7".getBytes(), null, "test9".getBytes()); + setVector(vector2, null, "test1".getBytes(StandardCharsets.UTF_8), null, + "test3".getBytes(StandardCharsets.UTF_8), null, "test5".getBytes(StandardCharsets.UTF_8), null, + "test7".getBytes(StandardCharsets.UTF_8), null, "test9".getBytes(StandardCharsets.UTF_8)); assertTrue(VectorEqualsVisitor.vectorEquals(vector1, vector2)); } } @@ -571,13 +574,14 @@ public void testPopulateVarBinaryVector() { if (i % 2 == 0) { vector1.setNull(i); } else { - vector1.set(i, ("test" + i).getBytes()); + vector1.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); } } vector1.setValueCount(10); - setVector(vector2, null, "test1".getBytes(), null, "test3".getBytes(), null, "test5".getBytes(), null, - "test7".getBytes(), null, "test9".getBytes()); + setVector(vector2, null, "test1".getBytes(StandardCharsets.UTF_8), null, + "test3".getBytes(StandardCharsets.UTF_8), null, "test5".getBytes(StandardCharsets.UTF_8), null, + "test7".getBytes(StandardCharsets.UTF_8), null, "test9".getBytes(StandardCharsets.UTF_8)); assertTrue(VectorEqualsVisitor.vectorEquals(vector1, vector2)); } } @@ -592,7 +596,7 @@ public void testPopulateVarCharVector() { if (i % 2 == 0) { vector1.setNull(i); } else { - vector1.set(i, ("test" + i).getBytes()); + vector1.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); } } vector1.setValueCount(10); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java index 084350410a4f5..872b2f3934b07 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java @@ -221,7 +221,7 @@ public void roundtripLocation() throws IOException { final ExtensionTypeVector deserialized = (ExtensionTypeVector) readerRoot.getFieldVectors().get(0); Assert.assertTrue(deserialized instanceof LocationVector); - Assert.assertEquals(deserialized.getName(), "location"); + Assert.assertEquals("location", deserialized.getName()); StructVector deserStruct = (StructVector) deserialized.getUnderlyingVector(); Assert.assertNotNull(deserStruct.getChild("Latitude")); Assert.assertNotNull(deserStruct.getChild("Longitude")); @@ -273,7 +273,7 @@ public void testVectorCompare() { // Test out vector appender VectorBatchAppender.batchAppend(a1, a2, bb); - assertEquals(a1.getValueCount(), 6); + assertEquals(6, a1.getValueCount()); validateVisitor.visit(a1, null); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/util/TestElementAddressableVectorIterator.java b/java/vector/src/test/java/org/apache/arrow/vector/util/TestElementAddressableVectorIterator.java index 419872225e16f..1c8281c85981b 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/util/TestElementAddressableVectorIterator.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/util/TestElementAddressableVectorIterator.java @@ -20,6 +20,8 @@ import static junit.framework.TestCase.assertNull; import static org.junit.Assert.assertEquals; +import java.nio.charset.StandardCharsets; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.memory.util.ArrowBufPointer; @@ -98,7 +100,7 @@ public void testIterateVarCharVector() { if (i == 0) { strVector.setNull(i); } else { - strVector.set(i, String.valueOf(i).getBytes()); + strVector.set(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); } } @@ -125,7 +127,7 @@ public void testIterateVarCharVector() { assertEquals(expected.length(), pt.getLength()); pt.getBuf().getBytes(pt.getOffset(), actual); - assertEquals(expected, new String(actual)); + assertEquals(expected, new String(actual, StandardCharsets.UTF_8)); } index += 1; } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/util/TestReusableByteArray.java b/java/vector/src/test/java/org/apache/arrow/vector/util/TestReusableByteArray.java index b11aa5638d651..f562e63b4bf8d 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/util/TestReusableByteArray.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/util/TestReusableByteArray.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Base64; @@ -54,25 +55,27 @@ public void testSetByteArrayRepeatedly() { ReusableByteArray byteArray = new ReusableByteArray(); try (ArrowBuf workingBuf = allocator.buffer(100)) { final String str = "test"; - workingBuf.setBytes(0, str.getBytes()); - byteArray.set(workingBuf, 0, str.getBytes().length); - assertEquals(str.getBytes().length, byteArray.getLength()); - assertArrayEquals(str.getBytes(), Arrays.copyOfRange(byteArray.getBuffer(), 0, (int) byteArray.getLength())); - assertEquals(Base64.getEncoder().encodeToString(str.getBytes()), byteArray.toString()); - assertEquals(new ReusableByteArray(str.getBytes()), byteArray); - assertEquals(new ReusableByteArray(str.getBytes()).hashCode(), byteArray.hashCode()); + workingBuf.setBytes(0, str.getBytes(StandardCharsets.UTF_8)); + byteArray.set(workingBuf, 0, str.getBytes(StandardCharsets.UTF_8).length); + assertEquals(str.getBytes(StandardCharsets.UTF_8).length, byteArray.getLength()); + assertArrayEquals(str.getBytes(StandardCharsets.UTF_8), Arrays.copyOfRange(byteArray.getBuffer(), 0, + (int) byteArray.getLength())); + assertEquals(Base64.getEncoder().encodeToString(str.getBytes(StandardCharsets.UTF_8)), byteArray.toString()); + assertEquals(new ReusableByteArray(str.getBytes(StandardCharsets.UTF_8)), byteArray); + assertEquals(new ReusableByteArray(str.getBytes(StandardCharsets.UTF_8)).hashCode(), byteArray.hashCode()); // Test a longer string. Should require reallocation. final String str2 = "test_longer"; byte[] oldBuffer = byteArray.getBuffer(); workingBuf.clear(); - workingBuf.setBytes(0, str2.getBytes()); - byteArray.set(workingBuf, 0, str2.getBytes().length); - assertEquals(str2.getBytes().length, byteArray.getLength()); - assertArrayEquals(str2.getBytes(), Arrays.copyOfRange(byteArray.getBuffer(), 0, (int) byteArray.getLength())); - assertEquals(Base64.getEncoder().encodeToString(str2.getBytes()), byteArray.toString()); - assertEquals(new ReusableByteArray(str2.getBytes()), byteArray); - assertEquals(new ReusableByteArray(str2.getBytes()).hashCode(), byteArray.hashCode()); + workingBuf.setBytes(0, str2.getBytes(StandardCharsets.UTF_8)); + byteArray.set(workingBuf, 0, str2.getBytes(StandardCharsets.UTF_8).length); + assertEquals(str2.getBytes(StandardCharsets.UTF_8).length, byteArray.getLength()); + assertArrayEquals(str2.getBytes(StandardCharsets.UTF_8), Arrays.copyOfRange(byteArray.getBuffer(), 0, + (int) byteArray.getLength())); + assertEquals(Base64.getEncoder().encodeToString(str2.getBytes(StandardCharsets.UTF_8)), byteArray.toString()); + assertEquals(new ReusableByteArray(str2.getBytes(StandardCharsets.UTF_8)), byteArray); + assertEquals(new ReusableByteArray(str2.getBytes(StandardCharsets.UTF_8)).hashCode(), byteArray.hashCode()); // Verify reallocation needed. assertNotSame(oldBuffer, byteArray.getBuffer()); @@ -82,13 +85,14 @@ public void testSetByteArrayRepeatedly() { final String str3 = "short"; oldBuffer = byteArray.getBuffer(); workingBuf.clear(); - workingBuf.setBytes(0, str3.getBytes()); - byteArray.set(workingBuf, 0, str3.getBytes().length); - assertEquals(str3.getBytes().length, byteArray.getLength()); - assertArrayEquals(str3.getBytes(), Arrays.copyOfRange(byteArray.getBuffer(), 0, (int) byteArray.getLength())); - assertEquals(Base64.getEncoder().encodeToString(str3.getBytes()), byteArray.toString()); - assertEquals(new ReusableByteArray(str3.getBytes()), byteArray); - assertEquals(new ReusableByteArray(str3.getBytes()).hashCode(), byteArray.hashCode()); + workingBuf.setBytes(0, str3.getBytes(StandardCharsets.UTF_8)); + byteArray.set(workingBuf, 0, str3.getBytes(StandardCharsets.UTF_8).length); + assertEquals(str3.getBytes(StandardCharsets.UTF_8).length, byteArray.getLength()); + assertArrayEquals(str3.getBytes(StandardCharsets.UTF_8), Arrays.copyOfRange(byteArray.getBuffer(), 0, + (int) byteArray.getLength())); + assertEquals(Base64.getEncoder().encodeToString(str3.getBytes(StandardCharsets.UTF_8)), byteArray.toString()); + assertEquals(new ReusableByteArray(str3.getBytes(StandardCharsets.UTF_8)), byteArray); + assertEquals(new ReusableByteArray(str3.getBytes(StandardCharsets.UTF_8)).hashCode(), byteArray.hashCode()); // Verify reallocation was not needed. assertSame(oldBuffer, byteArray.getBuffer()); diff --git a/java/vector/src/test/java/org/apache/arrow/util/TestSchemaUtil.java b/java/vector/src/test/java/org/apache/arrow/vector/util/TestSchemaUtil.java similarity index 98% rename from java/vector/src/test/java/org/apache/arrow/util/TestSchemaUtil.java rename to java/vector/src/test/java/org/apache/arrow/vector/util/TestSchemaUtil.java index cefff83823289..52b6584086832 100644 --- a/java/vector/src/test/java/org/apache/arrow/util/TestSchemaUtil.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/util/TestSchemaUtil.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.util; +package org.apache.arrow.vector.util; import static java.util.Arrays.asList; import static org.junit.Assert.assertEquals; diff --git a/java/vector/src/test/java/org/apache/arrow/vector/util/TestVectorAppender.java b/java/vector/src/test/java/org/apache/arrow/vector/util/TestVectorAppender.java index ab36ea2fd2129..93e7535947536 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/util/TestVectorAppender.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/util/TestVectorAppender.java @@ -437,8 +437,6 @@ public void testAppendStructVector() { delta.accept(appender, null); assertEquals(length1 + length2, target.getValueCount()); - IntVector child1 = (IntVector) target.getVectorById(0); - VarCharVector child2 = (VarCharVector) target.getVectorById(1); try (IntVector expected1 = new IntVector("expected1", allocator); VarCharVector expected2 = new VarCharVector("expected2", allocator)) { diff --git a/js/package.json b/js/package.json index d1346eb37c9ed..57f9267afa3a8 100644 --- a/js/package.json +++ b/js/package.json @@ -121,5 +121,5 @@ "engines": { "node": ">=12.0" }, - "version": "15.0.0-SNAPSHOT" + "version": "16.0.0-SNAPSHOT" } diff --git a/js/src/builder/buffer.ts b/js/src/builder/buffer.ts index 18c6dcda738b9..ad1c06b0d9f0f 100644 --- a/js/src/builder/buffer.ts +++ b/js/src/builder/buffer.ts @@ -27,30 +27,17 @@ function roundLengthUpToNearest64Bytes(len: number, BPE: number) { /** @ignore */ function resizeArray(arr: T, len = 0): T { - // TODO: remove when https://github.com/microsoft/TypeScript/issues/54636 is fixed - const buffer = arr.buffer as ArrayBufferLike & { resizable: boolean; resize: (byteLength: number) => void; maxByteLength: number }; - const byteLength = len * arr.BYTES_PER_ELEMENT; - if (buffer.resizable && byteLength <= buffer.maxByteLength) { - buffer.resize(byteLength); - return arr; - } - - // Fallback for non-resizable buffers return arr.length >= len ? arr.subarray(0, len) as T : memcpy(new (arr.constructor as any)(len), arr, 0); } -/** @ignore */ -export const SAFE_ARRAY_SIZE = 2 ** 32 - 1; - /** @ignore */ export class BufferBuilder { constructor(bufferType: ArrayCtor, initialSize = 0, stride = 1) { this.length = Math.ceil(initialSize / stride); - // TODO: remove as any when https://github.com/microsoft/TypeScript/issues/54636 is fixed - this.buffer = new bufferType(new (ArrayBuffer as any)(this.length * bufferType.BYTES_PER_ELEMENT, { maxByteLength: SAFE_ARRAY_SIZE })) as T; + this.buffer = new bufferType(this.length) as T; this.stride = stride; this.BYTES_PER_ELEMENT = bufferType.BYTES_PER_ELEMENT; this.ArrayType = bufferType; @@ -94,8 +81,7 @@ export class BufferBuilder { } public clear() { this.length = 0; - // TODO: remove as any when https://github.com/microsoft/TypeScript/issues/54636 is fixed - this.buffer = new this.ArrayType(new (ArrayBuffer as any)(0, { maxByteLength: SAFE_ARRAY_SIZE })) as T; + this.buffer = new this.ArrayType() as T; return this; } protected _resize(newLength: number) { diff --git a/matlab/CMakeLists.txt b/matlab/CMakeLists.txt index 47d2acd613f8b..206ecb318b3cc 100644 --- a/matlab/CMakeLists.txt +++ b/matlab/CMakeLists.txt @@ -94,7 +94,7 @@ endfunction() set(CMAKE_CXX_STANDARD 17) -set(MLARROW_VERSION "15.0.0-SNAPSHOT") +set(MLARROW_VERSION "16.0.0-SNAPSHOT") string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" MLARROW_BASE_VERSION "${MLARROW_VERSION}") project(mlarrow VERSION "${MLARROW_BASE_VERSION}") diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 2df1e67b9f4c7..54a5b99e058a5 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -21,7 +21,7 @@ cmake_minimum_required(VERSION 3.16) project(pyarrow) -set(PYARROW_VERSION "15.0.0-SNAPSHOT") +set(PYARROW_VERSION "16.0.0-SNAPSHOT") string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" PYARROW_BASE_VERSION "${PYARROW_VERSION}") # Running from a Python sdist tarball diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 5c2d22aef1895..1416f5f4346d9 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -66,8 +66,7 @@ cdef shared_ptr[CDataType] _ndarray_to_type(object values, dtype = values.dtype if type is None and dtype != object: - with nogil: - check_status(NumPyDtypeToArrow(dtype, &c_type)) + c_type = GetResultValue(NumPyDtypeToArrow(dtype)) if type is not None: c_type = type.sp_type diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py index 9301a5fee5ade..f83753ac57d47 100644 --- a/python/pyarrow/dataset.py +++ b/python/pyarrow/dataset.py @@ -292,8 +292,8 @@ def _ensure_partitioning(scheme): elif isinstance(scheme, (Partitioning, PartitioningFactory)): pass else: - ValueError("Expected Partitioning or PartitioningFactory, got {}" - .format(type(scheme))) + raise ValueError("Expected Partitioning or PartitioningFactory, got {}" + .format(type(scheme))) return scheme diff --git a/python/pyarrow/includes/libarrow_python.pxd b/python/pyarrow/includes/libarrow_python.pxd index e3179062a1e52..906f0b7d28e59 100644 --- a/python/pyarrow/includes/libarrow_python.pxd +++ b/python/pyarrow/includes/libarrow_python.pxd @@ -73,7 +73,7 @@ cdef extern from "arrow/python/api.h" namespace "arrow::py" nogil: object obj, object mask, const PyConversionOptions& options, CMemoryPool* pool) - CStatus NumPyDtypeToArrow(object dtype, shared_ptr[CDataType]* type) + CResult[shared_ptr[CDataType]] NumPyDtypeToArrow(object dtype) CStatus NdarrayToArrow(CMemoryPool* pool, object ao, object mo, c_bool from_pandas, diff --git a/python/pyarrow/pandas_compat.py b/python/pyarrow/pandas_compat.py index 39dee85492400..61e6318e29c24 100644 --- a/python/pyarrow/pandas_compat.py +++ b/python/pyarrow/pandas_compat.py @@ -967,20 +967,9 @@ def _extract_index_level(table, result_table, field_name, # The serialized index column was removed by the user return result_table, None, None - pd = _pandas_api.pd - col = table.column(i) - values = col.to_pandas(types_mapper=types_mapper).values - - if hasattr(values, 'flags') and not values.flags.writeable: - # ARROW-1054: in pandas 0.19.2, factorize will reject - # non-writeable arrays when calling MultiIndex.from_arrays - values = values.copy() - - if isinstance(col.type, pa.lib.TimestampType) and col.type.tz is not None: - index_level = make_tz_aware(pd.Series(values, copy=False), col.type.tz) - else: - index_level = pd.Series(values, dtype=values.dtype, copy=False) + index_level = col.to_pandas(types_mapper=types_mapper) + index_level.name = None result_table = result_table.remove_column( result_table.schema.get_field_index(field_name) ) diff --git a/python/pyarrow/src/arrow/python/common.cc b/python/pyarrow/src/arrow/python/common.cc index 6fe2ed4dae321..2f44a9122f024 100644 --- a/python/pyarrow/src/arrow/python/common.cc +++ b/python/pyarrow/src/arrow/python/common.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include "arrow/memory_pool.h" @@ -90,9 +91,15 @@ class PythonErrorDetail : public StatusDetail { std::string ToString() const override { // This is simple enough not to need the GIL - const auto ty = reinterpret_cast(exc_type_.obj()); - // XXX Should we also print traceback? - return std::string("Python exception: ") + ty->tp_name; + Result result = FormatImpl(); + + if (result.ok()) { + return result.ValueOrDie(); + } else { + // Fallback to just the exception type + const auto ty = reinterpret_cast(exc_type_.obj()); + return std::string("Python exception: ") + ty->tp_name; + } } void RestorePyError() const { @@ -131,6 +138,42 @@ class PythonErrorDetail : public StatusDetail { } protected: + Result FormatImpl() const { + PyAcquireGIL lock; + + // Use traceback.format_exception() + OwnedRef traceback_module; + RETURN_NOT_OK(internal::ImportModule("traceback", &traceback_module)); + + OwnedRef fmt_exception; + RETURN_NOT_OK(internal::ImportFromModule(traceback_module.obj(), "format_exception", + &fmt_exception)); + + OwnedRef formatted; + formatted.reset(PyObject_CallFunctionObjArgs(fmt_exception.obj(), exc_type_.obj(), + exc_value_.obj(), exc_traceback_.obj(), + NULL)); + RETURN_IF_PYERROR(); + + std::stringstream ss; + ss << "Python exception: "; + Py_ssize_t num_lines = PySequence_Length(formatted.obj()); + RETURN_IF_PYERROR(); + + for (Py_ssize_t i = 0; i < num_lines; ++i) { + Py_ssize_t line_size; + + PyObject* line = PySequence_GetItem(formatted.obj(), i); + RETURN_IF_PYERROR(); + + const char* data = PyUnicode_AsUTF8AndSize(line, &line_size); + RETURN_IF_PYERROR(); + + ss << std::string_view(data, line_size); + } + return ss.str(); + } + PythonErrorDetail() = default; OwnedRefNoGIL exc_type_, exc_value_, exc_traceback_; diff --git a/python/pyarrow/src/arrow/python/inference.cc b/python/pyarrow/src/arrow/python/inference.cc index 9537aec574470..10116f9afad69 100644 --- a/python/pyarrow/src/arrow/python/inference.cc +++ b/python/pyarrow/src/arrow/python/inference.cc @@ -468,10 +468,7 @@ class TypeInferrer { if (numpy_dtype_count_ > 0) { // All NumPy scalars and Nones/nulls if (numpy_dtype_count_ + none_count_ == total_count_) { - std::shared_ptr type; - RETURN_NOT_OK(NumPyDtypeToArrow(numpy_unifier_.current_dtype(), &type)); - *out = type; - return Status::OK(); + return NumPyDtypeToArrow(numpy_unifier_.current_dtype()).Value(out); } // The "bad path": data contains a mix of NumPy scalars and diff --git a/python/pyarrow/src/arrow/python/numpy_convert.cc b/python/pyarrow/src/arrow/python/numpy_convert.cc index 49706807644d2..dfee88c092e65 100644 --- a/python/pyarrow/src/arrow/python/numpy_convert.cc +++ b/python/pyarrow/src/arrow/python/numpy_convert.cc @@ -59,12 +59,11 @@ NumPyBuffer::~NumPyBuffer() { #define TO_ARROW_TYPE_CASE(NPY_NAME, FACTORY) \ case NPY_##NPY_NAME: \ - *out = FACTORY(); \ - break; + return FACTORY(); namespace { -Status GetTensorType(PyObject* dtype, std::shared_ptr* out) { +Result> GetTensorType(PyObject* dtype) { if (!PyObject_TypeCheck(dtype, &PyArrayDescr_Type)) { return Status::TypeError("Did not pass numpy.dtype object"); } @@ -84,11 +83,8 @@ Status GetTensorType(PyObject* dtype, std::shared_ptr* out) { TO_ARROW_TYPE_CASE(FLOAT16, float16); TO_ARROW_TYPE_CASE(FLOAT32, float32); TO_ARROW_TYPE_CASE(FLOAT64, float64); - default: { - return Status::NotImplemented("Unsupported numpy type ", descr->type_num); - } } - return Status::OK(); + return Status::NotImplemented("Unsupported numpy type ", descr->type_num); } Status GetNumPyType(const DataType& type, int* type_num) { @@ -120,15 +116,21 @@ Status GetNumPyType(const DataType& type, int* type_num) { } // namespace -Status NumPyDtypeToArrow(PyObject* dtype, std::shared_ptr* out) { +Result> NumPyScalarToArrowDataType(PyObject* scalar) { + PyArray_Descr* descr = PyArray_DescrFromScalar(scalar); + OwnedRef descr_ref(reinterpret_cast(descr)); + return NumPyDtypeToArrow(descr); +} + +Result> NumPyDtypeToArrow(PyObject* dtype) { if (!PyObject_TypeCheck(dtype, &PyArrayDescr_Type)) { return Status::TypeError("Did not pass numpy.dtype object"); } PyArray_Descr* descr = reinterpret_cast(dtype); - return NumPyDtypeToArrow(descr, out); + return NumPyDtypeToArrow(descr); } -Status NumPyDtypeToArrow(PyArray_Descr* descr, std::shared_ptr* out) { +Result> NumPyDtypeToArrow(PyArray_Descr* descr) { int type_num = fix_numpy_type_num(descr->type_num); switch (type_num) { @@ -151,20 +153,15 @@ Status NumPyDtypeToArrow(PyArray_Descr* descr, std::shared_ptr* out) { reinterpret_cast(descr->c_metadata); switch (date_dtype->meta.base) { case NPY_FR_s: - *out = timestamp(TimeUnit::SECOND); - break; + return timestamp(TimeUnit::SECOND); case NPY_FR_ms: - *out = timestamp(TimeUnit::MILLI); - break; + return timestamp(TimeUnit::MILLI); case NPY_FR_us: - *out = timestamp(TimeUnit::MICRO); - break; + return timestamp(TimeUnit::MICRO); case NPY_FR_ns: - *out = timestamp(TimeUnit::NANO); - break; + return timestamp(TimeUnit::NANO); case NPY_FR_D: - *out = date32(); - break; + return date32(); case NPY_FR_GENERIC: return Status::NotImplemented("Unbound or generic datetime64 time unit"); default: @@ -176,29 +173,22 @@ Status NumPyDtypeToArrow(PyArray_Descr* descr, std::shared_ptr* out) { reinterpret_cast(descr->c_metadata); switch (timedelta_dtype->meta.base) { case NPY_FR_s: - *out = duration(TimeUnit::SECOND); - break; + return duration(TimeUnit::SECOND); case NPY_FR_ms: - *out = duration(TimeUnit::MILLI); - break; + return duration(TimeUnit::MILLI); case NPY_FR_us: - *out = duration(TimeUnit::MICRO); - break; + return duration(TimeUnit::MICRO); case NPY_FR_ns: - *out = duration(TimeUnit::NANO); - break; + return duration(TimeUnit::NANO); case NPY_FR_GENERIC: return Status::NotImplemented("Unbound or generic timedelta64 time unit"); default: return Status::NotImplemented("Unsupported timedelta64 time unit"); } } break; - default: { - return Status::NotImplemented("Unsupported numpy type ", descr->type_num); - } } - return Status::OK(); + return Status::NotImplemented("Unsupported numpy type ", descr->type_num); } #undef TO_ARROW_TYPE_CASE @@ -230,9 +220,8 @@ Status NdarrayToTensor(MemoryPool* pool, PyObject* ao, strides[i] = array_strides[i]; } - std::shared_ptr type; - RETURN_NOT_OK( - GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray)), &type)); + ARROW_ASSIGN_OR_RAISE( + auto type, GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray)))); *out = std::make_shared(type, data, shape, strides, dim_names); return Status::OK(); } @@ -435,9 +424,9 @@ Status NdarraysToSparseCOOTensor(MemoryPool* pool, PyObject* data_ao, PyObject* PyArrayObject* ndarray_data = reinterpret_cast(data_ao); std::shared_ptr data = std::make_shared(data_ao); - std::shared_ptr type_data; - RETURN_NOT_OK(GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray_data)), - &type_data)); + ARROW_ASSIGN_OR_RAISE( + auto type_data, + GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray_data)))); std::shared_ptr coords; RETURN_NOT_OK(NdarrayToTensor(pool, coords_ao, {}, &coords)); @@ -462,9 +451,9 @@ Status NdarraysToSparseCSXMatrix(MemoryPool* pool, PyObject* data_ao, PyObject* PyArrayObject* ndarray_data = reinterpret_cast(data_ao); std::shared_ptr data = std::make_shared(data_ao); - std::shared_ptr type_data; - RETURN_NOT_OK(GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray_data)), - &type_data)); + ARROW_ASSIGN_OR_RAISE( + auto type_data, + GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray_data)))); std::shared_ptr indptr, indices; RETURN_NOT_OK(NdarrayToTensor(pool, indptr_ao, {}, &indptr)); @@ -491,9 +480,9 @@ Status NdarraysToSparseCSFTensor(MemoryPool* pool, PyObject* data_ao, PyObject* const int ndim = static_cast(shape.size()); PyArrayObject* ndarray_data = reinterpret_cast(data_ao); std::shared_ptr data = std::make_shared(data_ao); - std::shared_ptr type_data; - RETURN_NOT_OK(GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray_data)), - &type_data)); + ARROW_ASSIGN_OR_RAISE( + auto type_data, + GetTensorType(reinterpret_cast(PyArray_DESCR(ndarray_data)))); std::vector> indptr(ndim - 1); std::vector> indices(ndim); diff --git a/python/pyarrow/src/arrow/python/numpy_convert.h b/python/pyarrow/src/arrow/python/numpy_convert.h index 10451077a221d..2d1086e135528 100644 --- a/python/pyarrow/src/arrow/python/numpy_convert.h +++ b/python/pyarrow/src/arrow/python/numpy_convert.h @@ -49,9 +49,11 @@ class ARROW_PYTHON_EXPORT NumPyBuffer : public Buffer { }; ARROW_PYTHON_EXPORT -Status NumPyDtypeToArrow(PyObject* dtype, std::shared_ptr* out); +Result> NumPyDtypeToArrow(PyObject* dtype); ARROW_PYTHON_EXPORT -Status NumPyDtypeToArrow(PyArray_Descr* descr, std::shared_ptr* out); +Result> NumPyDtypeToArrow(PyArray_Descr* descr); +ARROW_PYTHON_EXPORT +Result> NumPyScalarToArrowDataType(PyObject* scalar); ARROW_PYTHON_EXPORT Status NdarrayToTensor(MemoryPool* pool, PyObject* ao, const std::vector& dim_names, diff --git a/python/pyarrow/src/arrow/python/numpy_to_arrow.cc b/python/pyarrow/src/arrow/python/numpy_to_arrow.cc index 2727ce32f4494..8903df31be826 100644 --- a/python/pyarrow/src/arrow/python/numpy_to_arrow.cc +++ b/python/pyarrow/src/arrow/python/numpy_to_arrow.cc @@ -462,8 +462,7 @@ template inline Status NumPyConverter::ConvertData(std::shared_ptr* data) { RETURN_NOT_OK(PrepareInputData(data)); - std::shared_ptr input_type; - RETURN_NOT_OK(NumPyDtypeToArrow(reinterpret_cast(dtype_), &input_type)); + ARROW_ASSIGN_OR_RAISE(auto input_type, NumPyDtypeToArrow(dtype_)); if (!input_type->Equals(*type_)) { RETURN_NOT_OK(CastBuffer(input_type, *data, length_, null_bitmap_, null_count_, type_, @@ -490,7 +489,7 @@ inline Status NumPyConverter::ConvertData(std::shared_ptr* d Status s = StaticCastBuffer(**data, length_, pool_, data); RETURN_NOT_OK(s); } else { - RETURN_NOT_OK(NumPyDtypeToArrow(reinterpret_cast(dtype_), &input_type)); + ARROW_ASSIGN_OR_RAISE(input_type, NumPyDtypeToArrow(dtype_)); if (!input_type->Equals(*type_)) { // The null bitmap was already computed in VisitNative() RETURN_NOT_OK(CastBuffer(input_type, *data, length_, null_bitmap_, null_count_, @@ -498,7 +497,7 @@ inline Status NumPyConverter::ConvertData(std::shared_ptr* d } } } else { - RETURN_NOT_OK(NumPyDtypeToArrow(reinterpret_cast(dtype_), &input_type)); + ARROW_ASSIGN_OR_RAISE(input_type, NumPyDtypeToArrow(dtype_)); if (!input_type->Equals(*type_)) { RETURN_NOT_OK(CastBuffer(input_type, *data, length_, null_bitmap_, null_count_, type_, cast_options_, pool_, data)); @@ -531,7 +530,7 @@ inline Status NumPyConverter::ConvertData(std::shared_ptr* d } *data = std::move(result); } else { - RETURN_NOT_OK(NumPyDtypeToArrow(reinterpret_cast(dtype_), &input_type)); + ARROW_ASSIGN_OR_RAISE(input_type, NumPyDtypeToArrow(dtype_)); if (!input_type->Equals(*type_)) { // The null bitmap was already computed in VisitNative() RETURN_NOT_OK(CastBuffer(input_type, *data, length_, null_bitmap_, null_count_, @@ -539,7 +538,7 @@ inline Status NumPyConverter::ConvertData(std::shared_ptr* d } } } else { - RETURN_NOT_OK(NumPyDtypeToArrow(reinterpret_cast(dtype_), &input_type)); + ARROW_ASSIGN_OR_RAISE(input_type, NumPyDtypeToArrow(dtype_)); if (!input_type->Equals(*type_)) { RETURN_NOT_OK(CastBuffer(input_type, *data, length_, null_bitmap_, null_count_, type_, cast_options_, pool_, data)); diff --git a/python/pyarrow/src/arrow/python/python_test.cc b/python/pyarrow/src/arrow/python/python_test.cc index 01ab8a3038099..746bf410911f9 100644 --- a/python/pyarrow/src/arrow/python/python_test.cc +++ b/python/pyarrow/src/arrow/python/python_test.cc @@ -174,10 +174,14 @@ Status TestOwnedRefNoGILMoves() { } } -std::string FormatPythonException(const std::string& exc_class_name) { +std::string FormatPythonException(const std::string& exc_class_name, + const std::string& exc_value) { std::stringstream ss; ss << "Python exception: "; ss << exc_class_name; + ss << ": "; + ss << exc_value; + ss << "\n"; return ss.str(); } @@ -205,7 +209,8 @@ Status TestCheckPyErrorStatus() { } PyErr_SetString(PyExc_TypeError, "some error"); - ASSERT_OK(check_error(st, "some error", FormatPythonException("TypeError"))); + ASSERT_OK( + check_error(st, "some error", FormatPythonException("TypeError", "some error"))); ASSERT_TRUE(st.IsTypeError()); PyErr_SetString(PyExc_ValueError, "some error"); @@ -223,7 +228,8 @@ Status TestCheckPyErrorStatus() { } PyErr_SetString(PyExc_NotImplementedError, "some error"); - ASSERT_OK(check_error(st, "some error", FormatPythonException("NotImplementedError"))); + ASSERT_OK(check_error(st, "some error", + FormatPythonException("NotImplementedError", "some error"))); ASSERT_TRUE(st.IsNotImplemented()); // No override if a specific status code is given @@ -246,7 +252,8 @@ Status TestCheckPyErrorStatusNoGIL() { lock.release(); ASSERT_TRUE(st.IsUnknownError()); ASSERT_EQ(st.message(), "zzzt"); - ASSERT_EQ(st.detail()->ToString(), FormatPythonException("ZeroDivisionError")); + ASSERT_EQ(st.detail()->ToString(), + FormatPythonException("ZeroDivisionError", "zzzt")); return Status::OK(); } } @@ -257,7 +264,7 @@ Status TestRestorePyErrorBasics() { ASSERT_FALSE(PyErr_Occurred()); ASSERT_TRUE(st.IsUnknownError()); ASSERT_EQ(st.message(), "zzzt"); - ASSERT_EQ(st.detail()->ToString(), FormatPythonException("ZeroDivisionError")); + ASSERT_EQ(st.detail()->ToString(), FormatPythonException("ZeroDivisionError", "zzzt")); RestorePyError(st); ASSERT_TRUE(PyErr_Occurred()); diff --git a/python/pyarrow/src/arrow/python/python_to_arrow.cc b/python/pyarrow/src/arrow/python/python_to_arrow.cc index 23b92598e321e..d1d94ac17a13e 100644 --- a/python/pyarrow/src/arrow/python/python_to_arrow.cc +++ b/python/pyarrow/src/arrow/python/python_to_arrow.cc @@ -386,8 +386,7 @@ class PyValue { } } else if (PyArray_CheckAnyScalarExact(obj)) { // validate that the numpy scalar has np.datetime64 dtype - std::shared_ptr numpy_type; - RETURN_NOT_OK(NumPyDtypeToArrow(PyArray_DescrFromScalar(obj), &numpy_type)); + ARROW_ASSIGN_OR_RAISE(auto numpy_type, NumPyScalarToArrowDataType(obj)); if (!numpy_type->Equals(*type)) { return Status::NotImplemented("Expected np.datetime64 but got: ", numpy_type->ToString()); @@ -466,8 +465,7 @@ class PyValue { } } else if (PyArray_CheckAnyScalarExact(obj)) { // validate that the numpy scalar has np.datetime64 dtype - std::shared_ptr numpy_type; - RETURN_NOT_OK(NumPyDtypeToArrow(PyArray_DescrFromScalar(obj), &numpy_type)); + ARROW_ASSIGN_OR_RAISE(auto numpy_type, NumPyScalarToArrowDataType(obj)); if (!numpy_type->Equals(*type)) { return Status::NotImplemented("Expected np.timedelta64 but got: ", numpy_type->ToString()); diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index d98c93e1c049b..3c450d61a7659 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -5202,7 +5202,17 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None): raise ValueError( "The 'names' argument is not valid when passing a dictionary") return Table.from_pydict(data, schema=schema, metadata=metadata) + elif _pandas_api.is_data_frame(data): + if names is not None or metadata is not None: + raise ValueError( + "The 'names' and 'metadata' arguments are not valid when " + "passing a pandas DataFrame") + return Table.from_pandas(data, schema=schema, nthreads=nthreads) elif hasattr(data, "__arrow_c_stream__"): + if names is not None or metadata is not None: + raise ValueError( + "The 'names' and 'metadata' arguments are not valid when " + "using Arrow PyCapsule Interface") if schema is not None: requested = schema.__arrow_c_schema__() else: @@ -5216,14 +5226,12 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None): table = table.cast(schema) return table elif hasattr(data, "__arrow_c_array__"): - batch = record_batch(data, schema) - return Table.from_batches([batch]) - elif _pandas_api.is_data_frame(data): if names is not None or metadata is not None: raise ValueError( "The 'names' and 'metadata' arguments are not valid when " - "passing a pandas DataFrame") - return Table.from_pandas(data, schema=schema, nthreads=nthreads) + "using Arrow PyCapsule Interface") + batch = record_batch(data, schema) + return Table.from_batches([batch]) else: raise TypeError( "Expected pandas DataFrame, python dictionary or list of arrays") diff --git a/python/pyarrow/tests/parquet/test_datetime.py b/python/pyarrow/tests/parquet/test_datetime.py index 6a9cbd4f73d4f..0896eb37e6473 100644 --- a/python/pyarrow/tests/parquet/test_datetime.py +++ b/python/pyarrow/tests/parquet/test_datetime.py @@ -116,7 +116,7 @@ def test_coerce_timestamps(tempdir): df_expected = df.copy() for i, x in enumerate(df_expected['datetime64']): if isinstance(x, np.ndarray): - df_expected['datetime64'][i] = x.astype('M8[us]') + df_expected.loc[i, 'datetime64'] = x.astype('M8[us]') tm.assert_frame_equal(df_expected, df_read) @@ -429,7 +429,7 @@ def test_noncoerced_nanoseconds_written_without_exception(tempdir): # nanosecond timestamps by default n = 9 df = pd.DataFrame({'x': range(n)}, - index=pd.date_range('2017-01-01', freq='1n', periods=n)) + index=pd.date_range('2017-01-01', freq='ns', periods=n)) tb = pa.Table.from_pandas(df) filename = tempdir / 'written.parquet' diff --git a/python/pyarrow/tests/test_cffi.py b/python/pyarrow/tests/test_cffi.py index a9c17cc100cb4..ff81b06440f03 100644 --- a/python/pyarrow/tests/test_cffi.py +++ b/python/pyarrow/tests/test_cffi.py @@ -414,6 +414,40 @@ def test_export_import_batch_reader(reader_factory): pa.RecordBatchReader._import_from_c(ptr_stream) +@needs_cffi +def test_export_import_exception_reader(): + # See: https://github.com/apache/arrow/issues/37164 + c_stream = ffi.new("struct ArrowArrayStream*") + ptr_stream = int(ffi.cast("uintptr_t", c_stream)) + + gc.collect() # Make sure no Arrow data dangles in a ref cycle + old_allocated = pa.total_allocated_bytes() + + def gen(): + if True: + try: + raise ValueError('foo') + except ValueError as e: + raise NotImplementedError('bar') from e + else: + yield from make_batches() + + original = pa.RecordBatchReader.from_batches(make_schema(), gen()) + original._export_to_c(ptr_stream) + + reader = pa.RecordBatchReader._import_from_c(ptr_stream) + with pytest.raises(OSError) as exc_info: + reader.read_next_batch() + + # inner *and* outer exception should be present + assert 'ValueError: foo' in str(exc_info.value) + assert 'NotImplementedError: bar' in str(exc_info.value) + # Stacktrace containing line of the raise statement + assert 'raise ValueError(\'foo\')' in str(exc_info.value) + + assert pa.total_allocated_bytes() == old_allocated + + @needs_cffi def test_imported_batch_reader_error(): c_stream = ffi.new("struct ArrowArrayStream*") diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 7c5a134d330ac..4b58dc65bae9b 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -561,7 +561,8 @@ def test_slice_compatibility(): def test_binary_slice_compatibility(): - arr = pa.array([b"", b"a", b"a\xff", b"ab\x00", b"abc\xfb", b"ab\xf2de"]) + data = [b"", b"a", b"a\xff", b"ab\x00", b"abc\xfb", b"ab\xf2de"] + arr = pa.array(data) for start, stop, step in itertools.product(range(-6, 6), range(-6, 6), range(-3, 4)): @@ -574,6 +575,13 @@ def test_binary_slice_compatibility(): assert expected.equals(result) # Positional options assert pc.binary_slice(arr, start, stop, step) == result + # Fixed size binary input / output + for item in data: + fsb_scalar = pa.scalar(item, type=pa.binary(len(item))) + expected = item[start:stop:step] + actual = pc.binary_slice(fsb_scalar, start, stop, step) + assert actual.type == pa.binary(len(expected)) + assert actual.as_py() == expected def test_split_pattern(): @@ -2255,6 +2263,19 @@ def test_extract_datetime_components(): _check_datetime_components(timestamps, timezone) +@pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) +def test_iso_calendar_longer_array(unit): + # https://github.com/apache/arrow/issues/38655 + # ensure correct result for array length > 32 + arr = pa.array([datetime.datetime(2022, 1, 2, 9)]*50, pa.timestamp(unit)) + result = pc.iso_calendar(arr) + expected = pa.StructArray.from_arrays( + [[2021]*50, [52]*50, [7]*50], + names=['iso_year', 'iso_week', 'iso_day_of_week'] + ) + assert result.equals(expected) + + @pytest.mark.pandas @pytest.mark.skipif(sys.platform == "win32" and not util.windows_has_tzdata(), reason="Timezone database is not installed on Windows") @@ -2352,10 +2373,10 @@ def _check_temporal_rounding(ts, values, unit): unit_shorthand = { "nanosecond": "ns", "microsecond": "us", - "millisecond": "L", + "millisecond": "ms", "second": "s", "minute": "min", - "hour": "H", + "hour": "h", "day": "D" } greater_unit = { @@ -2363,7 +2384,7 @@ def _check_temporal_rounding(ts, values, unit): "microsecond": "ms", "millisecond": "s", "second": "min", - "minute": "H", + "minute": "h", "hour": "d", } ta = pa.array(ts) diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index ae2146c0bdaee..a4838d63a6b0b 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -178,12 +178,14 @@ def multisourcefs(request): # simply split the dataframe into four chunks to construct a data source # from each chunk into its own directory - df_a, df_b, df_c, df_d = np.array_split(df, 4) + n = len(df) + df_a, df_b, df_c, df_d = [df.iloc[i:i+n//4] for i in range(0, n, n//4)] # create a directory containing a flat sequence of parquet files without # any partitioning involved mockfs.create_dir('plain') - for i, chunk in enumerate(np.array_split(df_a, 10)): + n = len(df_a) + for i, chunk in enumerate([df_a.iloc[i:i+n//10] for i in range(0, n, n//10)]): path = 'plain/chunk-{}.parquet'.format(i) with mockfs.open_output_stream(path) as out: pq.write_table(_table_from_pandas(chunk), out) @@ -699,6 +701,17 @@ def test_partitioning(): load_back_table = load_back.to_table() assert load_back_table.equals(table) + # test invalid partitioning input + with tempfile.TemporaryDirectory() as tempdir: + partitioning = ds.DirectoryPartitioning(partitioning_schema) + ds.write_dataset(table, tempdir, + format='ipc', partitioning=partitioning) + load_back = None + with pytest.raises(ValueError, + match="Expected Partitioning or PartitioningFactory"): + load_back = ds.dataset(tempdir, format='ipc', partitioning=int(0)) + assert load_back is None + def test_partitioning_pickling(pickle_module): schema = pa.schema([ diff --git a/python/pyarrow/tests/test_pandas.py b/python/pyarrow/tests/test_pandas.py index d15ee82d5dbf1..8106219057efe 100644 --- a/python/pyarrow/tests/test_pandas.py +++ b/python/pyarrow/tests/test_pandas.py @@ -113,6 +113,10 @@ def _check_pandas_roundtrip(df, expected=None, use_threads=False, if expected is None: expected = df + for col in expected.columns: + if expected[col].dtype == 'object': + expected[col] = expected[col].replace({np.nan: None}) + with warnings.catch_warnings(): warnings.filterwarnings( "ignore", "elementwise comparison failed", DeprecationWarning) @@ -152,6 +156,9 @@ def _check_array_roundtrip(values, expected=None, mask=None, expected = pd.Series(values).copy() expected[mask.copy()] = None + if expected.dtype == 'object': + expected = expected.replace({np.nan: None}) + tm.assert_series_equal(pd.Series(result), expected, check_names=False) @@ -478,7 +485,7 @@ def test_mixed_column_names(self): preserve_index=True) def test_binary_column_name(self): - if Version("2.0.0") <= Version(pd.__version__) < Version("2.3.0"): + if Version("2.0.0") <= Version(pd.__version__) < Version("3.0.0"): # TODO: regression in pandas, hopefully fixed in next version # https://issues.apache.org/jira/browse/ARROW-18394 # https://github.com/pandas-dev/pandas/issues/50127 @@ -3108,7 +3115,7 @@ def _fully_loaded_dataframe_example(): @pytest.mark.parametrize('columns', ([b'foo'], ['foo'])) def test_roundtrip_with_bytes_unicode(columns): - if Version("2.0.0") <= Version(pd.__version__) < Version("2.3.0"): + if Version("2.0.0") <= Version(pd.__version__) < Version("3.0.0"): # TODO: regression in pandas, hopefully fixed in next version # https://issues.apache.org/jira/browse/ARROW-18394 # https://github.com/pandas-dev/pandas/issues/50127 @@ -3491,7 +3498,7 @@ def test_table_from_pandas_schema_field_order_metadata(): # ensure that a different field order in specified schema doesn't # mangle metadata df = pd.DataFrame({ - "datetime": pd.date_range("2020-01-01T00:00:00Z", freq="H", periods=2), + "datetime": pd.date_range("2020-01-01T00:00:00Z", freq="h", periods=2), "float": np.random.randn(2) }) @@ -4181,8 +4188,6 @@ def _Int64Dtype__from_arrow__(self, array): def test_convert_to_extension_array(monkeypatch): - import pandas.core.internals as _int - # table converted from dataframe with extension types (so pandas_metadata # has this information) df = pd.DataFrame( @@ -4193,16 +4198,15 @@ def test_convert_to_extension_array(monkeypatch): # Int64Dtype is recognized -> convert to extension block by default # for a proper roundtrip result = table.to_pandas() - assert not isinstance(_get_mgr(result).blocks[0], _int.ExtensionBlock) assert _get_mgr(result).blocks[0].values.dtype == np.dtype("int64") - assert isinstance(_get_mgr(result).blocks[1], _int.ExtensionBlock) + assert _get_mgr(result).blocks[1].values.dtype == pd.Int64Dtype() tm.assert_frame_equal(result, df) # test with missing values df2 = pd.DataFrame({'a': pd.array([1, 2, None], dtype='Int64')}) table2 = pa.table(df2) result = table2.to_pandas() - assert isinstance(_get_mgr(result).blocks[0], _int.ExtensionBlock) + assert _get_mgr(result).blocks[0].values.dtype == pd.Int64Dtype() tm.assert_frame_equal(result, df2) # monkeypatch pandas Int64Dtype to *not* have the protocol method @@ -4215,7 +4219,7 @@ def test_convert_to_extension_array(monkeypatch): # Int64Dtype has no __from_arrow__ -> use normal conversion result = table.to_pandas() assert len(_get_mgr(result).blocks) == 1 - assert not isinstance(_get_mgr(result).blocks[0], _int.ExtensionBlock) + assert _get_mgr(result).blocks[0].values.dtype == np.dtype("int64") class MyCustomIntegerType(pa.ExtensionType): @@ -4233,8 +4237,6 @@ def to_pandas_dtype(self): def test_conversion_extensiontype_to_extensionarray(monkeypatch): # converting extension type to linked pandas ExtensionDtype/Array - import pandas.core.internals as _int - storage = pa.array([1, 2, 3, 4], pa.int64()) arr = pa.ExtensionArray.from_storage(MyCustomIntegerType(), storage) table = pa.table({'a': arr}) @@ -4242,12 +4244,12 @@ def test_conversion_extensiontype_to_extensionarray(monkeypatch): # extension type points to Int64Dtype, which knows how to create a # pandas ExtensionArray result = arr.to_pandas() - assert isinstance(_get_mgr(result).blocks[0], _int.ExtensionBlock) + assert _get_mgr(result).blocks[0].values.dtype == pd.Int64Dtype() expected = pd.Series([1, 2, 3, 4], dtype='Int64') tm.assert_series_equal(result, expected) result = table.to_pandas() - assert isinstance(_get_mgr(result).blocks[0], _int.ExtensionBlock) + assert _get_mgr(result).blocks[0].values.dtype == pd.Int64Dtype() expected = pd.DataFrame({'a': pd.array([1, 2, 3, 4], dtype='Int64')}) tm.assert_frame_equal(result, expected) @@ -4261,7 +4263,7 @@ def test_conversion_extensiontype_to_extensionarray(monkeypatch): pd.core.arrays.integer.NumericDtype, "__from_arrow__") result = arr.to_pandas() - assert not isinstance(_get_mgr(result).blocks[0], _int.ExtensionBlock) + assert _get_mgr(result).blocks[0].values.dtype == np.dtype("int64") expected = pd.Series([1, 2, 3, 4]) tm.assert_series_equal(result, expected) @@ -4312,10 +4314,14 @@ def test_array_to_pandas(): def test_roundtrip_empty_table_with_extension_dtype_index(): df = pd.DataFrame(index=pd.interval_range(start=0, end=3)) table = pa.table(df) - table.to_pandas().index == pd.Index([{'left': 0, 'right': 1}, - {'left': 1, 'right': 2}, - {'left': 2, 'right': 3}], - dtype='object') + if Version(pd.__version__) > Version("1.0"): + tm.assert_index_equal(table.to_pandas().index, df.index) + else: + tm.assert_index_equal(table.to_pandas().index, + pd.Index([{'left': 0, 'right': 1}, + {'left': 1, 'right': 2}, + {'left': 2, 'right': 3}], + dtype='object')) @pytest.mark.parametrize("index", ["a", ["a", "b"]]) diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 912ee39f7d712..b6dc53d633543 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -5140,12 +5140,8 @@ def from_numpy_dtype(object dtype): >>> pa.from_numpy_dtype(np.str_) DataType(string) """ - cdef shared_ptr[CDataType] c_type dtype = np.dtype(dtype) - with nogil: - check_status(NumPyDtypeToArrow(dtype, &c_type)) - - return pyarrow_wrap_data_type(c_type) + return pyarrow_wrap_data_type(GetResultValue(NumPyDtypeToArrow(dtype))) def is_boolean_value(object obj): diff --git a/python/requirements-test.txt b/python/requirements-test.txt index 9f07e5c57bd09..b3ba5d852b968 100644 --- a/python/requirements-test.txt +++ b/python/requirements-test.txt @@ -1,6 +1,6 @@ cffi hypothesis pandas -pytest +pytest<8 pytest-lazy-fixture pytz diff --git a/python/requirements-wheel-test.txt b/python/requirements-wheel-test.txt index 516ec0fccc2e9..c74a8ca6908a7 100644 --- a/python/requirements-wheel-test.txt +++ b/python/requirements-wheel-test.txt @@ -1,7 +1,7 @@ cffi cython hypothesis -pytest +pytest<8 pytest-lazy-fixture pytz tzdata; sys_platform == 'win32' diff --git a/python/setup.py b/python/setup.py index b1c825d84d5a9..098d75a3186af 100755 --- a/python/setup.py +++ b/python/setup.py @@ -407,7 +407,7 @@ def get_outputs(self): # If the event of not running from a git clone (e.g. from a git archive # or a Python sdist), see if we can set the version number ourselves -default_version = '15.0.0-SNAPSHOT' +default_version = '16.0.0-SNAPSHOT' if (not os.path.exists('../.git') and not os.environ.get('SETUPTOOLS_SCM_PRETEND_VERSION')): os.environ['SETUPTOOLS_SCM_PRETEND_VERSION'] = \ diff --git a/r/DESCRIPTION b/r/DESCRIPTION index b290a75f932d5..21cc4dec902d2 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -1,6 +1,6 @@ Package: arrow Title: Integration to 'Apache' 'Arrow' -Version: 14.0.2.9000 +Version: 15.0.0.9000 Authors@R: c( person("Neal", "Richardson", email = "neal.p.richardson@gmail.com", role = c("aut")), person("Ian", "Cook", email = "ianmcook@gmail.com", role = c("aut")), @@ -27,7 +27,8 @@ URL: https://github.com/apache/arrow/, https://arrow.apache.org/docs/r/ BugReports: https://github.com/apache/arrow/issues Encoding: UTF-8 Language: en-US -SystemRequirements: C++17; for AWS S3 support on Linux, libcurl and openssl (optional) +SystemRequirements: C++17; for AWS S3 support on Linux, libcurl and openssl (optional); + cmake >= 3.16 (build-time only, and only for full source build) Biarch: true Imports: assertthat, diff --git a/r/NEWS.md b/r/NEWS.md index 1744e6e96e936..58c82c5128b82 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -17,7 +17,38 @@ under the License. --> -# arrow 14.0.2.9000 +# arrow 15.0.0.9000 + +# arrow 15.0.0 + +## New features + +* Bindings for `base::prod` have been added so you can now use it in your dplyr + pipelines (i.e., `tbl |> summarize(prod(col))`) without having to pull the + data into R (@m-muecke, #38601). +* Calling `dimnames` or `colnames` on `Dataset` objects now returns a useful + result rather than just `NULL` (#38377). +* The `code()` method on Schema objects now takes an optional `namespace` + argument which, when `TRUE`, prefixes names with `arrow::` which makes + the output more portable (@orgadish, #38144). + +## Minor improvements and fixes + +* Don't download cmake when ARROW_OFFLINE_BUILD=true and update `SystemRequirements` (#39602). +* Fallback to source build gracefully if binary download fails (#39587). +* An error is now thrown instead of warning and pulling the data into R when any + of `sub`, `gsub`, `stringr::str_replace`, `stringr::str_replace_all` are + passed a length > 1 vector of values in `pattern` (@abfleishman, #39219). +* Missing documentation was added to `?open_dataset` documenting how to use the + ND-JSON support added in arrow 13.0.0 (@Divyansh200102, #38258). +* To make debugging problems easier when using arrow with AWS S3 + (e.g., `s3_bucket`, `S3FileSystem`), the debug log level for S3 can be set + with the `AWS_S3_LOG_LEVEL` environment variable. + See `?S3FileSystem` for more information. (#38267) +* Using arrow with duckdb (i.e., `to_duckdb()`) no longer results in warnings + when quitting your R session. (#38495) +* A large number of minor spelling mistakes were fixed (@jsoref, #38929, #38257) +* The developer documentation has been updated to match changes made in recent releases (#38220) # arrow 14.0.2 diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R index 492729df8c12a..2042f800142b7 100644 --- a/r/R/dplyr-funcs-doc.R +++ b/r/R/dplyr-funcs-doc.R @@ -21,7 +21,7 @@ #' #' The `arrow` package contains methods for 37 `dplyr` table functions, many of #' which are "verbs" that do transformations to one or more tables. -#' The package also has mappings of 211 R functions to the corresponding +#' The package also has mappings of 212 R functions to the corresponding #' functions in the Arrow compute library. These allow you to write code inside #' of `dplyr` methods that call R functions, including many in packages like #' `stringr` and `lubridate`, and they will get translated to Arrow and run @@ -83,7 +83,7 @@ #' Functions can be called either as `pkg::fun()` or just `fun()`, i.e. both #' `str_sub()` and `stringr::str_sub()` work. #' -#' In addition to these functions, you can call any of Arrow's 254 compute +#' In addition to these functions, you can call any of Arrow's 262 compute #' functions directly. Arrow has many functions that don't map to an existing R #' function. In other cases where there is an R function mapping, you can still #' call the Arrow function directly if you don't want the adaptations that the R diff --git a/r/_pkgdown.yml b/r/_pkgdown.yml index 84111e599c457..e9513b8c16b26 100644 --- a/r/_pkgdown.yml +++ b/r/_pkgdown.yml @@ -76,7 +76,7 @@ home: [C GLib](https://arrow.apache.org/docs/c_glib)
[C++](https://arrow.apache.org/docs/cpp)
[C#](https://github.com/apache/arrow/blob/main/csharp/README.md)
- [Go](https://pkg.go.dev/github.com/apache/arrow/go)
+ [Go](https://pkg.go.dev/github.com/apache/arrow/go/v16)
[Java](https://arrow.apache.org/docs/java)
[JavaScript](https://arrow.apache.org/docs/js)
[Julia](https://github.com/apache/arrow-julia/blob/main/README.md)
diff --git a/r/configure b/r/configure index 029fc004dfc4c..0882ee6719c4b 100755 --- a/r/configure +++ b/r/configure @@ -73,7 +73,7 @@ FORCE_BUNDLED_BUILD=`echo $FORCE_BUNDLED_BUILD | tr '[:upper:]' '[:lower:]'` ARROW_USE_PKG_CONFIG=`echo $ARROW_USE_PKG_CONFIG | tr '[:upper:]' '[:lower:]'` # Just used in testing: whether or not it is ok to download dependencies (in the # bundled build) -TEST_OFFLINE_BUILD=`echo $TEST_OFFLINE_BUILD | tr '[:upper:]' '[:lower:]'` +ARROW_OFFLINE_BUILD=`echo $ARROW_OFFLINE_BUILD | tr '[:upper:]' '[:lower:]'` VERSION=`grep '^Version' DESCRIPTION | sed s/Version:\ //` UNAME=`uname -s` diff --git a/r/inst/NOTICE.txt b/r/inst/NOTICE.txt index a609791374c28..2089c6fb20358 100644 --- a/r/inst/NOTICE.txt +++ b/r/inst/NOTICE.txt @@ -1,5 +1,5 @@ Apache Arrow -Copyright 2016-2019 The Apache Software Foundation +Copyright 2016-2024 The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). diff --git a/r/inst/build_arrow_static.sh b/r/inst/build_arrow_static.sh index 9c9fadea4757b..d28cbcb08fbec 100755 --- a/r/inst/build_arrow_static.sh +++ b/r/inst/build_arrow_static.sh @@ -74,6 +74,8 @@ ${CMAKE} -DARROW_BOOST_USE_SHARED=OFF \ -DARROW_DATASET=${ARROW_DATASET:-ON} \ -DARROW_DEPENDENCY_SOURCE=${ARROW_DEPENDENCY_SOURCE:-AUTO} \ -DAWSSDK_SOURCE=${AWSSDK_SOURCE:-} \ + -DBoost_SOURCE=${Boost_SOURCE:-} \ + -Dlz4_SOURCE=${lz4_SOURCE:-} \ -DARROW_FILESYSTEM=ON \ -DARROW_GCS=${ARROW_GCS:-$ARROW_DEFAULT_PARAM} \ -DARROW_JEMALLOC=${ARROW_JEMALLOC:-$ARROW_DEFAULT_PARAM} \ diff --git a/r/man/acero.Rd b/r/man/acero.Rd index 12afdc23138ac..365795d9fc65c 100644 --- a/r/man/acero.Rd +++ b/r/man/acero.Rd @@ -9,7 +9,7 @@ \description{ The \code{arrow} package contains methods for 37 \code{dplyr} table functions, many of which are "verbs" that do transformations to one or more tables. -The package also has mappings of 211 R functions to the corresponding +The package also has mappings of 212 R functions to the corresponding functions in the Arrow compute library. These allow you to write code inside of \code{dplyr} methods that call R functions, including many in packages like \code{stringr} and \code{lubridate}, and they will get translated to Arrow and run @@ -71,7 +71,7 @@ can assume that the function works in Acero just as it does in R. Functions can be called either as \code{pkg::fun()} or just \code{fun()}, i.e. both \code{str_sub()} and \code{stringr::str_sub()} work. -In addition to these functions, you can call any of Arrow's 254 compute +In addition to these functions, you can call any of Arrow's 262 compute functions directly. Arrow has many functions that don't map to an existing R function. In other cases where there is an R function mapping, you can still call the Arrow function directly if you don't want the adaptations that the R diff --git a/r/pkgdown/assets/versions.json b/r/pkgdown/assets/versions.json index 35a1ef3b5ecb3..0b7f9884f9b6f 100644 --- a/r/pkgdown/assets/versions.json +++ b/r/pkgdown/assets/versions.json @@ -1,12 +1,16 @@ [ { - "name": "14.0.2.9000 (dev)", + "name": "15.0.0.9000 (dev)", "version": "dev/" }, { - "name": "14.0.2 (release)", + "name": "15.0.0 (release)", "version": "" }, + { + "name": "14.0.2", + "version": "14.0/" + }, { "name": "13.0.0.1", "version": "13.0/" diff --git a/r/tools/nixlibs.R b/r/tools/nixlibs.R index fe8de284b16b0..17c6ab0a8078b 100644 --- a/r/tools/nixlibs.R +++ b/r/tools/nixlibs.R @@ -79,6 +79,10 @@ find_latest_nightly <- function(description_version, } try_download <- function(from_url, to_file, hush = quietly) { + if (!download_ok) { + # Don't even try + return(FALSE) + } # We download some fairly large files, so ensure the timeout is set appropriately. # This assumes a static library size of 100 MB (generous) and a download speed # of .3 MB/s (slow). This is to anticipate slower user connections or load on @@ -96,18 +100,7 @@ try_download <- function(from_url, to_file, hush = quietly) { !inherits(status, "try-error") && status == 0 } -download_binary <- function(lib) { - libfile <- paste0("arrow-", VERSION, ".zip") - binary_url <- paste0(arrow_repo, "bin/", lib, "/arrow-", VERSION, ".zip") - if (try_download(binary_url, libfile)) { - lg("Successfully retrieved libarrow (%s)", lib) - } else { - lg( - "Downloading libarrow failed for version %s (%s)\n at %s", - VERSION, lib, binary_url - ) - libfile <- NULL - } +validate_checksum <- function(binary_url, libfile, hush = quietly) { # Explicitly setting the env var to "false" will skip checksum validation # e.g. in case the included checksums are stale. skip_checksum <- env_is("ARROW_R_ENFORCE_CHECKSUM", "false") @@ -116,33 +109,66 @@ download_binary <- function(lib) { # validate binary checksum for CRAN release only if (!skip_checksum && dir.exists(checksum_path) && is_release || enforce_checksum) { + # Munge the path to the correct sha file which we include during the + # release process checksum_file <- sub(".+/bin/(.+\\.zip)", "\\1\\.sha512", binary_url) checksum_file <- file.path(checksum_path, checksum_file) - checksum_cmd <- "shasum" - checksum_args <- c("--status", "-a", "512", "-c", checksum_file) - - # shasum is not available on all linux versions - status_shasum <- try( - suppressWarnings( - system2("shasum", args = c("--help"), stdout = FALSE, stderr = FALSE) - ), - silent = TRUE - ) - if (inherits(status_shasum, "try-error") || is.integer(status_shasum) && status_shasum != 0) { - checksum_cmd <- "sha512sum" - checksum_args <- c("--status", "-c", checksum_file) + # Try `shasum`, and if that doesn't work, fall back to `sha512sum` if not found + # system2 doesn't generate an R error, so we can't use a tryCatch to + # move from shasum to sha512sum. + # The warnings from system2 if it fails pop up later in the log and thus are + # more confusing than they are helpful (so we suppress them) + checksum_ok <- suppressWarnings(system2( + "shasum", + args = c("--status", "-a", "512", "-c", checksum_file), + stdout = ifelse(quietly, FALSE, ""), + stderr = ifelse(quietly, FALSE, "") + )) == 0 + + if (!checksum_ok) { + checksum_ok <- suppressWarnings(system2( + "sha512sum", + args = c("--status", "-c", checksum_file), + stdout = ifelse(quietly, FALSE, ""), + stderr = ifelse(quietly, FALSE, "") + )) == 0 } - checksum_ok <- system2(checksum_cmd, args = checksum_args) - - if (checksum_ok != 0) { - lg("Checksum validation failed for libarrow: %s/%s", lib, libfile) - unlink(libfile) - libfile <- NULL + if (checksum_ok) { + lg("Checksum validated successfully for libarrow") } else { - lg("Checksum validated successfully for libarrow: %s/%s", lib, libfile) + lg("Checksum validation failed for libarrow") + unlink(libfile) } + } else { + checksum_ok <- TRUE + } + + # Return whether the checksum was successful + checksum_ok +} + +download_binary <- function(lib) { + libfile <- paste0("arrow-", VERSION, ".zip") + binary_url <- paste0(arrow_repo, "bin/", lib, "/arrow-", VERSION, ".zip") + if (try_download(binary_url, libfile) && validate_checksum(binary_url, libfile)) { + lg("Successfully retrieved libarrow (%s)", lib) + } else { + # If the download or checksum fail, we will set libfile to NULL this will + # normally result in a source build after this. + # TODO: should we condense these together and only call them when verbose? + lg( + "Unable to retrieve libarrow for version %s (%s)", + VERSION, lib + ) + if (!quietly) { + lg( + "Attempted to download the libarrow binary from: %s", + binary_url + ) + } + libfile <- NULL } libfile @@ -464,7 +490,7 @@ env_vars_as_string <- function(env_var_list) { stopifnot( length(env_var_list) == length(names(env_var_list)), all(grepl("^[^0-9]", names(env_var_list))), - all(grepl("^[A-Z0-9_]+$", names(env_var_list))), + all(grepl("^[a-zA-Z0-9_]+$", names(env_var_list))), !any(grepl("'", env_var_list, fixed = TRUE)) ) env_var_string <- paste0(names(env_var_list), "='", env_var_list, "'", collapse = " ") @@ -496,7 +522,7 @@ build_libarrow <- function(src_dir, dst_dir) { Sys.setenv(MAKEFLAGS = makeflags) } if (!quietly) { - lg("Building with MAKEFLAGS=", makeflags) + lg("Building with MAKEFLAGS=%s", makeflags) } # Check for libarrow build dependencies: # * cmake @@ -539,6 +565,19 @@ build_libarrow <- function(src_dir, dst_dir) { env_var_list <- c(env_var_list, ARROW_DEPENDENCY_SOURCE = "BUNDLED") } + # On macOS, if not otherwise set, let's override Boost_SOURCE to be bundled + # Necessary due to #39590 for CRAN + if (on_macos) { + # Using lowercase (e.g. Boost_SOURCE) to match the cmake args we use already. + deps_to_bundle <- c("Boost", "lz4") + for (dep_to_bundle in deps_to_bundle) { + env_var <- paste0(dep_to_bundle, "_SOURCE") + if (Sys.getenv(env_var) == "") { + env_var_list <- c(env_var_list, setNames("BUNDLED", env_var)) + } + } + } + env_var_list <- with_cloud_support(env_var_list) # turn_off_all_optional_features() needs to happen after @@ -595,7 +634,6 @@ ensure_cmake <- function(cmake_minimum_required = "3.16") { if (is.null(cmake)) { # If not found, download it - lg("cmake", .indent = "****") CMAKE_VERSION <- Sys.getenv("CMAKE_VERSION", "3.26.4") if (on_macos) { postfix <- "-macos-universal.tar.gz" @@ -642,10 +680,7 @@ ensure_cmake <- function(cmake_minimum_required = "3.16") { bin_dir, "/cmake" ) - } else { - # Show which one we found - # Full source builds will always show "cmake" in the logs - lg("cmake: %s", cmake, .indent = "****") + lg("cmake %s", CMAKE_VERSION, .indent = "****") } cmake } @@ -653,6 +688,8 @@ ensure_cmake <- function(cmake_minimum_required = "3.16") { find_cmake <- function(paths = c( Sys.getenv("CMAKE"), Sys.which("cmake"), + # CRAN has it here, not on PATH + if (on_macos) "/Applications/CMake.app/Contents/bin/cmake", Sys.which("cmake3") ), version_required = "3.16") { @@ -660,10 +697,25 @@ find_cmake <- function(paths = c( # version_required should be a string or packageVersion; numeric version # can be misleading (e.g. 3.10 is actually 3.1) for (path in paths) { - if (nzchar(path) && cmake_version(path) >= version_required) { + if (nzchar(path) && file.exists(path)) { # Sys.which() returns a named vector, but that plays badly with c() later names(path) <- NULL - return(path) + found_version <- cmake_version(path) + if (found_version >= version_required) { + # Show which one we found + lg("cmake %s: %s", found_version, path, .indent = "****") + # Stop searching here + return(path) + } else { + # Keep trying + lg("Not using cmake found at %s", path, .indent = "****") + if (found_version > 0) { + lg("Version >= %s required; found %s", version_required, found_version, .indent = "*****") + } else { + # If cmake_version() couldn't determine version, it returns 0 + lg("Could not determine version; >= %s required", version_required, .indent = "*****") + } + } } } # If none found, return NULL @@ -854,27 +906,13 @@ on_windows <- tolower(Sys.info()[["sysname"]]) == "windows" # For local debugging, set ARROW_R_DEV=TRUE to make this script print more quietly <- !env_is("ARROW_R_DEV", "true") -not_cran <- env_is("NOT_CRAN", "true") - -if (is_release) { - VERSION <- VERSION[1, 1:3] - arrow_repo <- paste0(getOption("arrow.repo", sprintf("https://apache.jfrog.io/artifactory/arrow/r/%s", VERSION)), "/libarrow/") -} else { - not_cran <- TRUE - arrow_repo <- paste0(getOption("arrow.dev_repo", "https://nightlies.apache.org/arrow/r"), "/libarrow/") -} - -if (!is_release && !test_mode) { - VERSION <- find_latest_nightly(VERSION) -} - # To collect dirs to rm on exit, use cleanup() to add dirs # we reset it to avoid errors on reruns in the same session. options(.arrow.cleanup = character()) on.exit(unlink(getOption(".arrow.cleanup"), recursive = TRUE), add = TRUE) -# enable full featured builds for macOS in case of CRAN source builds. -if (not_cran || on_macos) { +not_cran <- env_is("NOT_CRAN", "true") +if (not_cran) { # Set more eager defaults if (env_is("LIBARROW_BINARY", "")) { Sys.setenv(LIBARROW_BINARY = "true") @@ -889,13 +927,38 @@ if (not_cran || on_macos) { # and don't fall back to a full source build build_ok <- !env_is("LIBARROW_BUILD", "false") -# Check if we're authorized to download (not asked an offline build). -# (Note that cmake will still be downloaded if necessary -# https://arrow.apache.org/docs/developers/cpp/building.html#offline-builds) -download_ok <- !test_mode && !env_is("TEST_OFFLINE_BUILD", "true") +# Check if we're authorized to download +download_ok <- !test_mode && !env_is("ARROW_OFFLINE_BUILD", "true") +if (!download_ok) { + lg("Dependency downloading disabled. Unset ARROW_OFFLINE_BUILD to enable", .indent = "***") +} +# If not forbidden from downloading, check if we are offline and turn off downloading. +# The default libarrow source build will download its source dependencies and fail +# if they can't be retrieved. +# But, don't do this if the user has requested a binary or a non-minimal build: +# we should error rather than silently succeeding with a minimal build. +if (download_ok && Sys.getenv("LIBARROW_BINARY") %in% c("false", "") && !env_is("LIBARROW_MINIMAL", "false")) { + download_ok <- try_download("https://apache.jfrog.io/artifactory/arrow/r/", tempfile()) + if (!download_ok) { + lg("Network connection not available", .indent = "***") + } +} download_libarrow_ok <- download_ok && !env_is("LIBARROW_DOWNLOAD", "false") +# Set binary repos +if (is_release) { + VERSION <- VERSION[1, 1:3] + arrow_repo <- paste0(getOption("arrow.repo", sprintf("https://apache.jfrog.io/artifactory/arrow/r/%s", VERSION)), "/libarrow/") +} else { + arrow_repo <- paste0(getOption("arrow.dev_repo", "https://nightlies.apache.org/arrow/r"), "/libarrow/") +} + +# If we're on a dev version, look for the most recent libarrow binary version +if (download_libarrow_ok && !is_release && !test_mode) { + VERSION <- find_latest_nightly(VERSION) +} + # This "tools/thirdparty_dependencies" path, within the tar file, might exist if # create_package_with_all_dependencies() was run, or if someone has created it # manually before running make build. diff --git a/r/vignettes/developers/setup.Rmd b/r/vignettes/developers/setup.Rmd index 8e7cff7410473..4c1eab1e6972f 100644 --- a/r/vignettes/developers/setup.Rmd +++ b/r/vignettes/developers/setup.Rmd @@ -280,12 +280,11 @@ withr::with_makevars(list(CPPFLAGS = "", LDFLAGS = ""), remotes::install_github( * See the user-facing [article on installation](../install.html) for a large number of environment variables that determine how the build works and what features get built. -* `TEST_OFFLINE_BUILD`: When set to `true`, the build script will not download - prebuilt the C++ library binary. +* `ARROW_OFFLINE_BUILD`: When set to `true`, the build script will not download + prebuilt the C++ library binary or, if needed, `cmake`. It will turn off any features that require a download, unless they're available in `ARROW_THIRDPARTY_DEPENDENCY_DIR` or the `tools/thirdparty_download/` subfolder. `create_package_with_all_dependencies()` creates that subfolder. - Regardless of this flag's value, `cmake` will be downloaded if it's unavailable. # Troubleshooting diff --git a/ruby/red-arrow-cuda/lib/arrow-cuda/version.rb b/ruby/red-arrow-cuda/lib/arrow-cuda/version.rb index 8551b647cb86f..816751fcba8ff 100644 --- a/ruby/red-arrow-cuda/lib/arrow-cuda/version.rb +++ b/ruby/red-arrow-cuda/lib/arrow-cuda/version.rb @@ -16,7 +16,7 @@ # under the License. module ArrowCUDA - VERSION = "15.0.0-SNAPSHOT" + VERSION = "16.0.0-SNAPSHOT" module Version numbers, TAG = VERSION.split("-") diff --git a/ruby/red-arrow-dataset/lib/arrow-dataset/version.rb b/ruby/red-arrow-dataset/lib/arrow-dataset/version.rb index acfdd675687be..e391493e15974 100644 --- a/ruby/red-arrow-dataset/lib/arrow-dataset/version.rb +++ b/ruby/red-arrow-dataset/lib/arrow-dataset/version.rb @@ -16,7 +16,7 @@ # under the License. module ArrowDataset - VERSION = "15.0.0-SNAPSHOT" + VERSION = "16.0.0-SNAPSHOT" module Version numbers, TAG = VERSION.split("-") diff --git a/ruby/red-arrow-flight-sql/lib/arrow-flight-sql/version.rb b/ruby/red-arrow-flight-sql/lib/arrow-flight-sql/version.rb index 3354678e30032..d90751be80cb0 100644 --- a/ruby/red-arrow-flight-sql/lib/arrow-flight-sql/version.rb +++ b/ruby/red-arrow-flight-sql/lib/arrow-flight-sql/version.rb @@ -16,7 +16,7 @@ # under the License. module ArrowFlightSQL - VERSION = "15.0.0-SNAPSHOT" + VERSION = "16.0.0-SNAPSHOT" module Version numbers, TAG = VERSION.split("-") diff --git a/ruby/red-arrow-flight/lib/arrow-flight/version.rb b/ruby/red-arrow-flight/lib/arrow-flight/version.rb index f2141a68432e5..6c2d676809f8f 100644 --- a/ruby/red-arrow-flight/lib/arrow-flight/version.rb +++ b/ruby/red-arrow-flight/lib/arrow-flight/version.rb @@ -16,7 +16,7 @@ # under the License. module ArrowFlight - VERSION = "15.0.0-SNAPSHOT" + VERSION = "16.0.0-SNAPSHOT" module Version numbers, TAG = VERSION.split("-") diff --git a/ruby/red-arrow/lib/arrow/version.rb b/ruby/red-arrow/lib/arrow/version.rb index 235a7df75d672..2b1c14e389116 100644 --- a/ruby/red-arrow/lib/arrow/version.rb +++ b/ruby/red-arrow/lib/arrow/version.rb @@ -16,7 +16,7 @@ # under the License. module Arrow - VERSION = "15.0.0-SNAPSHOT" + VERSION = "16.0.0-SNAPSHOT" module Version numbers, TAG = VERSION.split("-") diff --git a/ruby/red-gandiva/lib/gandiva/version.rb b/ruby/red-gandiva/lib/gandiva/version.rb index 6a1835f0e50e8..0a20a520194b0 100644 --- a/ruby/red-gandiva/lib/gandiva/version.rb +++ b/ruby/red-gandiva/lib/gandiva/version.rb @@ -16,7 +16,7 @@ # under the License. module Gandiva - VERSION = "15.0.0-SNAPSHOT" + VERSION = "16.0.0-SNAPSHOT" module Version numbers, TAG = VERSION.split("-") diff --git a/ruby/red-parquet/lib/parquet/version.rb b/ruby/red-parquet/lib/parquet/version.rb index c5a945c4e4297..cd61c772a3285 100644 --- a/ruby/red-parquet/lib/parquet/version.rb +++ b/ruby/red-parquet/lib/parquet/version.rb @@ -16,7 +16,7 @@ # under the License. module Parquet - VERSION = "15.0.0-SNAPSHOT" + VERSION = "16.0.0-SNAPSHOT" module Version numbers, TAG = VERSION.split("-") diff --git a/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift b/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift index 32728dc7eeaa4..b78f0ccd74997 100644 --- a/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift +++ b/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift @@ -36,12 +36,12 @@ public class ArrowArrayBuilder> public func finish() throws -> ArrowArray { let buffers = self.bufferBuilder.finish() - let arrowData = try ArrowData(self.type, buffers: buffers, nullCount: self.nullCount, stride: self.getStride()) + let arrowData = try ArrowData(self.type, buffers: buffers, nullCount: self.nullCount) return U(arrowData) } public func getStride() -> Int { - MemoryLayout.stride + return self.type.getStride() } } @@ -73,20 +73,12 @@ public class Date32ArrayBuilder: ArrowArrayBuilder Int { - MemoryLayout.stride - } } public class Date64ArrayBuilder: ArrowArrayBuilder { fileprivate convenience init() throws { try self.init(ArrowType(ArrowType.ArrowDate64)) } - - public override func getStride() -> Int { - MemoryLayout.stride - } } public class Time32ArrayBuilder: ArrowArrayBuilder, Time32Array> { diff --git a/swift/Arrow/Sources/Arrow/ArrowData.swift b/swift/Arrow/Sources/Arrow/ArrowData.swift index 60281a8d24133..93986b5955bd8 100644 --- a/swift/Arrow/Sources/Arrow/ArrowData.swift +++ b/swift/Arrow/Sources/Arrow/ArrowData.swift @@ -24,7 +24,7 @@ public class ArrowData { public let length: UInt public let stride: Int - init(_ arrowType: ArrowType, buffers: [ArrowBuffer], nullCount: UInt, stride: Int) throws { + init(_ arrowType: ArrowType, buffers: [ArrowBuffer], nullCount: UInt) throws { let infoType = arrowType.info switch infoType { case let .primitiveInfo(typeId): @@ -45,7 +45,7 @@ public class ArrowData { self.buffers = buffers self.nullCount = nullCount self.length = buffers[1].length - self.stride = stride + self.stride = arrowType.getStride() } public func isNull(_ at: UInt) -> Bool { diff --git a/swift/Arrow/Sources/Arrow/ArrowReader.swift b/swift/Arrow/Sources/Arrow/ArrowReader.swift index d9dc1bdb470e6..237f22dc979e3 100644 --- a/swift/Arrow/Sources/Arrow/ArrowReader.swift +++ b/swift/Arrow/Sources/Arrow/ArrowReader.swift @@ -57,15 +57,17 @@ public class ArrowReader { private func loadPrimitiveData(_ loadInfo: DataLoadInfo) -> Result { do { let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)! + let nullLength = UInt(ceil(Double(node.length) / 8)) try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex) let nullBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex)! let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData, - length: UInt(node.nullCount), messageOffset: loadInfo.messageOffset) + length: nullLength, messageOffset: loadInfo.messageOffset) try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 1) let valueBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 1)! let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData, length: UInt(node.length), messageOffset: loadInfo.messageOffset) - return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowValueBuffer]) + return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowValueBuffer], + nullCount: UInt(node.nullCount)) } catch let error as ArrowError { return .failure(error) } catch { @@ -76,10 +78,11 @@ public class ArrowReader { private func loadVariableData(_ loadInfo: DataLoadInfo) -> Result { let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)! do { + let nullLength = UInt(ceil(Double(node.length) / 8)) try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex) let nullBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex)! let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData, - length: UInt(node.nullCount), messageOffset: loadInfo.messageOffset) + length: nullLength, messageOffset: loadInfo.messageOffset) try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 1) let offsetBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 1)! let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData: loadInfo.fileData, @@ -88,7 +91,8 @@ public class ArrowReader { let valueBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 2)! let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData, length: UInt(node.length), messageOffset: loadInfo.messageOffset) - return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowOffsetBuffer, arrowValueBuffer]) + return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowOffsetBuffer, arrowValueBuffer], + nullCount: UInt(node.nullCount)) } catch let error as ArrowError { return .failure(error) } catch { diff --git a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift index fa52160478f24..fb4a13b766f10 100644 --- a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift +++ b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift @@ -18,10 +18,11 @@ import FlatBuffers import Foundation -private func makeBinaryHolder(_ buffers: [ArrowBuffer]) -> Result { +private func makeBinaryHolder(_ buffers: [ArrowBuffer], + nullCount: UInt) -> Result { do { - let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBinary), buffers: buffers, - nullCount: buffers[0].length, stride: MemoryLayout.stride) + let arrowType = ArrowType(ArrowType.ArrowBinary) + let arrowData = try ArrowData(arrowType, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(BinaryArray(arrowData))) } catch let error as ArrowError { return .failure(error) @@ -30,10 +31,11 @@ private func makeBinaryHolder(_ buffers: [ArrowBuffer]) -> Result Result { +private func makeStringHolder(_ buffers: [ArrowBuffer], + nullCount: UInt) -> Result { do { - let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers, - nullCount: buffers[0].length, stride: MemoryLayout.stride) + let arrowType = ArrowType(ArrowType.ArrowString) + let arrowData = try ArrowData(arrowType, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(StringArray(arrowData))) } catch let error as ArrowError { return .failure(error) @@ -42,31 +44,17 @@ private func makeStringHolder(_ buffers: [ArrowBuffer]) -> Result Result { - switch floatType.precision { - case .single: - return makeFixedHolder(Float.self, buffers: buffers, arrowType: ArrowType.ArrowFloat) - case .double: - return makeFixedHolder(Double.self, buffers: buffers, arrowType: ArrowType.ArrowDouble) - default: - return .failure(.unknownType("Float precision \(floatType.precision) currently not supported")) - } -} - -private func makeDateHolder(_ dateType: org_apache_arrow_flatbuf_Date, - buffers: [ArrowBuffer] +private func makeDateHolder(_ field: ArrowField, + buffers: [ArrowBuffer], + nullCount: UInt ) -> Result { do { - if dateType.unit == .day { - let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers, - nullCount: buffers[0].length, stride: MemoryLayout.stride) + if field.type.id == .date32 { + let arrowData = try ArrowData(field.type, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(Date32Array(arrowData))) } - let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers, - nullCount: buffers[0].length, stride: MemoryLayout.stride) + let arrowData = try ArrowData(field.type, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(Date64Array(arrowData))) } catch let error as ArrowError { return .failure(error) @@ -75,21 +63,26 @@ private func makeDateHolder(_ dateType: org_apache_arrow_flatbuf_Date, } } -private func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time, - buffers: [ArrowBuffer] +private func makeTimeHolder(_ field: ArrowField, + buffers: [ArrowBuffer], + nullCount: UInt ) -> Result { do { - if timeType.unit == .second || timeType.unit == .millisecond { - let arrowUnit: ArrowTime32Unit = timeType.unit == .second ? .seconds : .milliseconds - let arrowData = try ArrowData(ArrowTypeTime32(arrowUnit), buffers: buffers, - nullCount: buffers[0].length, stride: MemoryLayout.stride) - return .success(ArrowArrayHolder(FixedArray(arrowData))) + if field.type.id == .time32 { + if let arrowType = field.type as? ArrowTypeTime32 { + let arrowData = try ArrowData(arrowType, buffers: buffers, nullCount: nullCount) + return .success(ArrowArrayHolder(FixedArray(arrowData))) + } else { + return .failure(.invalid("Incorrect field type for time: \(field.type)")) + } } - let arrowUnit: ArrowTime64Unit = timeType.unit == .microsecond ? .microseconds : .nanoseconds - let arrowData = try ArrowData(ArrowTypeTime64(arrowUnit), buffers: buffers, - nullCount: buffers[0].length, stride: MemoryLayout.stride) - return .success(ArrowArrayHolder(FixedArray(arrowData))) + if let arrowType = field.type as? ArrowTypeTime64 { + let arrowData = try ArrowData(arrowType, buffers: buffers, nullCount: nullCount) + return .success(ArrowArrayHolder(FixedArray(arrowData))) + } else { + return .failure(.invalid("Incorrect field type for time: \(field.type)")) + } } catch let error as ArrowError { return .failure(error) } catch { @@ -97,10 +90,11 @@ private func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time, } } -private func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result { +private func makeBoolHolder(_ buffers: [ArrowBuffer], + nullCount: UInt) -> Result { do { - let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBool), buffers: buffers, - nullCount: buffers[0].length, stride: MemoryLayout.stride) + let arrowType = ArrowType(ArrowType.ArrowBool) + let arrowData = try ArrowData(arrowType, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(BoolArray(arrowData))) } catch let error as ArrowError { return .failure(error) @@ -110,12 +104,11 @@ private func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result( - _: T.Type, buffers: [ArrowBuffer], - arrowType: ArrowType.Info + _: T.Type, field: ArrowField, buffers: [ArrowBuffer], + nullCount: UInt ) -> Result { do { - let arrowData = try ArrowData(ArrowType(arrowType), buffers: buffers, - nullCount: buffers[0].length, stride: MemoryLayout.stride) + let arrowData = try ArrowData(field.type, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(FixedArray(arrowData))) } catch let error as ArrowError { return .failure(error) @@ -124,58 +117,56 @@ private func makeFixedHolder( } } -func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity +func makeArrayHolder( _ field: org_apache_arrow_flatbuf_Field, - buffers: [ArrowBuffer] + buffers: [ArrowBuffer], + nullCount: UInt ) -> Result { - let type = field.typeType - switch type { - case .int: - let intType = field.type(type: org_apache_arrow_flatbuf_Int.self)! - let bitWidth = intType.bitWidth - if bitWidth == 8 { - if intType.isSigned { - return makeFixedHolder(Int8.self, buffers: buffers, arrowType: ArrowType.ArrowInt8) - } else { - return makeFixedHolder(UInt8.self, buffers: buffers, arrowType: ArrowType.ArrowUInt8) - } - } else if bitWidth == 16 { - if intType.isSigned { - return makeFixedHolder(Int16.self, buffers: buffers, arrowType: ArrowType.ArrowInt16) - } else { - return makeFixedHolder(UInt16.self, buffers: buffers, arrowType: ArrowType.ArrowUInt16) - } - } else if bitWidth == 32 { - if intType.isSigned { - return makeFixedHolder(Int32.self, buffers: buffers, arrowType: ArrowType.ArrowInt32) - } else { - return makeFixedHolder(UInt32.self, buffers: buffers, arrowType: ArrowType.ArrowUInt32) - } - } else if bitWidth == 64 { - if intType.isSigned { - return makeFixedHolder(Int64.self, buffers: buffers, arrowType: ArrowType.ArrowInt64) - } else { - return makeFixedHolder(UInt64.self, buffers: buffers, arrowType: ArrowType.ArrowUInt64) - } - } - return .failure(.unknownType("Int width \(bitWidth) currently not supported")) - case .bool: - return makeBoolHolder(buffers) - case .floatingpoint: - let floatType = field.type(type: org_apache_arrow_flatbuf_FloatingPoint.self)! - return makeFloatHolder(floatType, buffers: buffers) - case .utf8: - return makeStringHolder(buffers) + let arrowField = fromProto(field: field) + return makeArrayHolder(arrowField, buffers: buffers, nullCount: nullCount) +} + +func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity + _ field: ArrowField, + buffers: [ArrowBuffer], + nullCount: UInt +) -> Result { + let typeId = field.type.id + switch typeId { + case .int8: + return makeFixedHolder(Int8.self, field: field, buffers: buffers, nullCount: nullCount) + case .uint8: + return makeFixedHolder(UInt8.self, field: field, buffers: buffers, nullCount: nullCount) + case .int16: + return makeFixedHolder(Int16.self, field: field, buffers: buffers, nullCount: nullCount) + case .uint16: + return makeFixedHolder(UInt16.self, field: field, buffers: buffers, nullCount: nullCount) + case .int32: + return makeFixedHolder(Int32.self, field: field, buffers: buffers, nullCount: nullCount) + case .uint32: + return makeFixedHolder(UInt32.self, field: field, buffers: buffers, nullCount: nullCount) + case .int64: + return makeFixedHolder(Int64.self, field: field, buffers: buffers, nullCount: nullCount) + case .uint64: + return makeFixedHolder(UInt64.self, field: field, buffers: buffers, nullCount: nullCount) + case .boolean: + return makeBoolHolder(buffers, nullCount: nullCount) + case .float: + return makeFixedHolder(Float.self, field: field, buffers: buffers, nullCount: nullCount) + case .double: + return makeFixedHolder(Double.self, field: field, buffers: buffers, nullCount: nullCount) + case .string: + return makeStringHolder(buffers, nullCount: nullCount) case .binary: - return makeBinaryHolder(buffers) - case .date: - let dateType = field.type(type: org_apache_arrow_flatbuf_Date.self)! - return makeDateHolder(dateType, buffers: buffers) - case .time: - let timeType = field.type(type: org_apache_arrow_flatbuf_Time.self)! - return makeTimeHolder(timeType, buffers: buffers) + return makeBinaryHolder(buffers, nullCount: nullCount) + case .date32: + return makeDateHolder(field, buffers: buffers, nullCount: nullCount) + case .time32: + return makeTimeHolder(field, buffers: buffers, nullCount: nullCount) + case .time64: + return makeTimeHolder(field, buffers: buffers, nullCount: nullCount) default: - return .failure(.unknownType("Type \(type) currently not supported")) + return .failure(.unknownType("Type \(typeId) currently not supported")) } } diff --git a/swift/Arrow/Sources/Arrow/ArrowType.swift b/swift/Arrow/Sources/Arrow/ArrowType.swift index e63647d0797ee..f5a869f7cdaff 100644 --- a/swift/Arrow/Sources/Arrow/ArrowType.swift +++ b/swift/Arrow/Sources/Arrow/ArrowType.swift @@ -19,6 +19,8 @@ import Foundation public typealias Time32 = Int32 public typealias Time64 = Int64 +public typealias Date32 = Int32 +public typealias Date64 = Int64 func FlatBuffersVersion_23_1_4() { // swiftlint:disable:this identifier_name } @@ -165,6 +167,48 @@ public class ArrowType { return ArrowType.ArrowUnknown } } + + public func getStride( // swiftlint:disable:this cyclomatic_complexity + ) -> Int { + switch self.id { + case .int8: + return MemoryLayout.stride + case .int16: + return MemoryLayout.stride + case .int32: + return MemoryLayout.stride + case .int64: + return MemoryLayout.stride + case .uint8: + return MemoryLayout.stride + case .uint16: + return MemoryLayout.stride + case .uint32: + return MemoryLayout.stride + case .uint64: + return MemoryLayout.stride + case .float: + return MemoryLayout.stride + case .double: + return MemoryLayout.stride + case .boolean: + return MemoryLayout.stride + case .date32: + return MemoryLayout.stride + case .date64: + return MemoryLayout.stride + case .time32: + return MemoryLayout.stride + case .time64: + return MemoryLayout.stride + case .binary: + return MemoryLayout.stride + case .string: + return MemoryLayout.stride + default: + fatalError("Stride requested for unknown type: \(self)") + } + } } extension ArrowType.Info: Equatable { diff --git a/swift/Arrow/Sources/Arrow/ProtoUtil.swift b/swift/Arrow/Sources/Arrow/ProtoUtil.swift new file mode 100644 index 0000000000000..f7fd725fe1140 --- /dev/null +++ b/swift/Arrow/Sources/Arrow/ProtoUtil.swift @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import Foundation + +func fromProto( // swiftlint:disable:this cyclomatic_complexity + field: org_apache_arrow_flatbuf_Field +) -> ArrowField { + let type = field.typeType + var arrowType = ArrowType(ArrowType.ArrowUnknown) + switch type { + case .int: + let intType = field.type(type: org_apache_arrow_flatbuf_Int.self)! + let bitWidth = intType.bitWidth + if bitWidth == 8 { + arrowType = ArrowType(intType.isSigned ? ArrowType.ArrowInt8 : ArrowType.ArrowUInt8) + } else if bitWidth == 16 { + arrowType = ArrowType(intType.isSigned ? ArrowType.ArrowInt16 : ArrowType.ArrowUInt16) + } else if bitWidth == 32 { + arrowType = ArrowType(intType.isSigned ? ArrowType.ArrowInt32 : ArrowType.ArrowUInt32) + } else if bitWidth == 64 { + arrowType = ArrowType(intType.isSigned ? ArrowType.ArrowInt64 : ArrowType.ArrowUInt64) + } + case .bool: + arrowType = ArrowType(ArrowType.ArrowBool) + case .floatingpoint: + let floatType = field.type(type: org_apache_arrow_flatbuf_FloatingPoint.self)! + if floatType.precision == .single { + arrowType = ArrowType(ArrowType.ArrowFloat) + } else if floatType.precision == .double { + arrowType = ArrowType(ArrowType.ArrowDouble) + } + case .utf8: + arrowType = ArrowType(ArrowType.ArrowString) + case .binary: + arrowType = ArrowType(ArrowType.ArrowBinary) + case .date: + let dateType = field.type(type: org_apache_arrow_flatbuf_Date.self)! + if dateType.unit == .day { + arrowType = ArrowType(ArrowType.ArrowDate32) + } else { + arrowType = ArrowType(ArrowType.ArrowDate64) + } + case .time: + let timeType = field.type(type: org_apache_arrow_flatbuf_Time.self)! + if timeType.unit == .second || timeType.unit == .millisecond { + let arrowUnit: ArrowTime32Unit = timeType.unit == .second ? .seconds : .milliseconds + arrowType = ArrowTypeTime32(arrowUnit) + } else { + let arrowUnit: ArrowTime64Unit = timeType.unit == .microsecond ? .microseconds : .nanoseconds + arrowType = ArrowTypeTime64(arrowUnit) + } + default: + arrowType = ArrowType(ArrowType.ArrowUnknown) + } + + return ArrowField(field.name ?? "", type: arrowType, isNullable: field.nullable) +} diff --git a/swift/Arrow/Tests/ArrowTests/ArrayTests.swift b/swift/Arrow/Tests/ArrowTests/ArrayTests.swift index 069dbfc88f3ac..f5bfa0506e62f 100644 --- a/swift/Arrow/Tests/ArrowTests/ArrayTests.swift +++ b/swift/Arrow/Tests/ArrowTests/ArrayTests.swift @@ -211,4 +211,38 @@ final class ArrayTests: XCTestCase { XCTAssertEqual(microArray[1], 20000) XCTAssertEqual(microArray[2], 987654321) } + + func checkHolderForType(_ checkType: ArrowType) throws { + let buffers = [ArrowBuffer(length: 0, capacity: 0, + rawPointer: UnsafeMutableRawPointer.allocate(byteCount: 0, alignment: .zero)), + ArrowBuffer(length: 0, capacity: 0, + rawPointer: UnsafeMutableRawPointer.allocate(byteCount: 0, alignment: .zero))] + let field = ArrowField("", type: checkType, isNullable: true) + switch makeArrayHolder(field, buffers: buffers, nullCount: 0) { + case .success(let holder): + XCTAssertEqual(holder.type.id, checkType.id) + case .failure(let err): + throw err + } + } + + func testArrayHolders() throws { + try checkHolderForType(ArrowType(ArrowType.ArrowInt8)) + try checkHolderForType(ArrowType(ArrowType.ArrowUInt8)) + try checkHolderForType(ArrowType(ArrowType.ArrowInt16)) + try checkHolderForType(ArrowType(ArrowType.ArrowUInt16)) + try checkHolderForType(ArrowType(ArrowType.ArrowInt32)) + try checkHolderForType(ArrowType(ArrowType.ArrowUInt32)) + try checkHolderForType(ArrowType(ArrowType.ArrowInt64)) + try checkHolderForType(ArrowType(ArrowType.ArrowUInt64)) + try checkHolderForType(ArrowTypeTime32(.seconds)) + try checkHolderForType(ArrowTypeTime32(.milliseconds)) + try checkHolderForType(ArrowTypeTime64(.microseconds)) + try checkHolderForType(ArrowTypeTime64(.nanoseconds)) + try checkHolderForType(ArrowType(ArrowType.ArrowBinary)) + try checkHolderForType(ArrowType(ArrowType.ArrowFloat)) + try checkHolderForType(ArrowType(ArrowType.ArrowDouble)) + try checkHolderForType(ArrowType(ArrowType.ArrowBool)) + try checkHolderForType(ArrowType(ArrowType.ArrowString)) + } } diff --git a/swift/Arrow/Tests/ArrowTests/IPCTests.swift b/swift/Arrow/Tests/ArrowTests/IPCTests.swift index 59cad94ef4da5..103c3b24c7b93 100644 --- a/swift/Arrow/Tests/ArrowTests/IPCTests.swift +++ b/swift/Arrow/Tests/ArrowTests/IPCTests.swift @@ -64,14 +64,16 @@ func makeSchema() -> ArrowSchema { return schemaBuilder.addField("col1", type: ArrowType(ArrowType.ArrowUInt8), isNullable: true) .addField("col2", type: ArrowType(ArrowType.ArrowString), isNullable: false) .addField("col3", type: ArrowType(ArrowType.ArrowDate32), isNullable: false) + .addField("col4", type: ArrowType(ArrowType.ArrowInt32), isNullable: false) + .addField("col5", type: ArrowType(ArrowType.ArrowFloat), isNullable: false) .finish() } func makeRecordBatch() throws -> RecordBatch { let uint8Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() uint8Builder.append(10) - uint8Builder.append(22) - uint8Builder.append(33) + uint8Builder.append(nil) + uint8Builder.append(nil) uint8Builder.append(44) let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder() stringBuilder.append("test10") @@ -85,13 +87,28 @@ func makeRecordBatch() throws -> RecordBatch { date32Builder.append(date2) date32Builder.append(date1) date32Builder.append(date2) - let intHolder = ArrowArrayHolder(try uint8Builder.finish()) + let int32Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + int32Builder.append(1) + int32Builder.append(2) + int32Builder.append(3) + int32Builder.append(4) + let floatBuilder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + floatBuilder.append(211.112) + floatBuilder.append(322.223) + floatBuilder.append(433.334) + floatBuilder.append(544.445) + + let uint8Holder = ArrowArrayHolder(try uint8Builder.finish()) let stringHolder = ArrowArrayHolder(try stringBuilder.finish()) let date32Holder = ArrowArrayHolder(try date32Builder.finish()) + let int32Holder = ArrowArrayHolder(try int32Builder.finish()) + let floatHolder = ArrowArrayHolder(try floatBuilder.finish()) let result = RecordBatch.Builder() - .addColumn("col1", arrowArray: intHolder) + .addColumn("col1", arrowArray: uint8Holder) .addColumn("col2", arrowArray: stringHolder) .addColumn("col3", arrowArray: date32Holder) + .addColumn("col4", arrowArray: int32Holder) + .addColumn("col5", arrowArray: floatHolder) .finish() switch result { case .success(let recordBatch): @@ -182,15 +199,20 @@ final class IPCFileReaderTests: XCTestCase { XCTAssertEqual(recordBatches.count, 1) for recordBatch in recordBatches { XCTAssertEqual(recordBatch.length, 4) - XCTAssertEqual(recordBatch.columns.count, 3) - XCTAssertEqual(recordBatch.schema.fields.count, 3) + XCTAssertEqual(recordBatch.columns.count, 5) + XCTAssertEqual(recordBatch.schema.fields.count, 5) XCTAssertEqual(recordBatch.schema.fields[0].name, "col1") XCTAssertEqual(recordBatch.schema.fields[0].type.info, ArrowType.ArrowUInt8) XCTAssertEqual(recordBatch.schema.fields[1].name, "col2") XCTAssertEqual(recordBatch.schema.fields[1].type.info, ArrowType.ArrowString) XCTAssertEqual(recordBatch.schema.fields[2].name, "col3") XCTAssertEqual(recordBatch.schema.fields[2].type.info, ArrowType.ArrowDate32) + XCTAssertEqual(recordBatch.schema.fields[3].name, "col4") + XCTAssertEqual(recordBatch.schema.fields[3].type.info, ArrowType.ArrowInt32) + XCTAssertEqual(recordBatch.schema.fields[4].name, "col5") + XCTAssertEqual(recordBatch.schema.fields[4].type.info, ArrowType.ArrowFloat) let columns = recordBatch.columns + XCTAssertEqual(columns[0].nullCount, 2) let dateVal = "\((columns[2].array as! AsString).asString(0))" // swiftlint:disable:this force_cast XCTAssertEqual(dateVal, "2014-09-10 00:00:00 +0000") @@ -227,13 +249,17 @@ final class IPCFileReaderTests: XCTestCase { case .success(let result): XCTAssertNotNil(result.schema) let schema = result.schema! - XCTAssertEqual(schema.fields.count, 3) + XCTAssertEqual(schema.fields.count, 5) XCTAssertEqual(schema.fields[0].name, "col1") XCTAssertEqual(schema.fields[0].type.info, ArrowType.ArrowUInt8) XCTAssertEqual(schema.fields[1].name, "col2") XCTAssertEqual(schema.fields[1].type.info, ArrowType.ArrowString) XCTAssertEqual(schema.fields[2].name, "col3") XCTAssertEqual(schema.fields[2].type.info, ArrowType.ArrowDate32) + XCTAssertEqual(schema.fields[3].name, "col4") + XCTAssertEqual(schema.fields[3].type.info, ArrowType.ArrowInt32) + XCTAssertEqual(schema.fields[4].name, "col5") + XCTAssertEqual(schema.fields[4].type.info, ArrowType.ArrowFloat) case.failure(let error): throw error } diff --git a/swift/Arrow/Tests/ArrowTests/RecordBatchTests.swift b/swift/Arrow/Tests/ArrowTests/RecordBatchTests.swift index ab6cad1b5e409..8820f1cdb1a91 100644 --- a/swift/Arrow/Tests/ArrowTests/RecordBatchTests.swift +++ b/swift/Arrow/Tests/ArrowTests/RecordBatchTests.swift @@ -23,9 +23,11 @@ final class RecordBatchTests: XCTestCase { let uint8Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() uint8Builder.append(10) uint8Builder.append(22) + uint8Builder.append(nil) let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder() stringBuilder.append("test10") stringBuilder.append("test22") + stringBuilder.append("test33") let intHolder = ArrowArrayHolder(try uint8Builder.finish()) let stringHolder = ArrowArrayHolder(try stringBuilder.finish()) @@ -39,15 +41,16 @@ final class RecordBatchTests: XCTestCase { XCTAssertEqual(schema.fields.count, 2) XCTAssertEqual(schema.fields[0].name, "col1") XCTAssertEqual(schema.fields[0].type.info, ArrowType.ArrowUInt8) - XCTAssertEqual(schema.fields[0].isNullable, false) + XCTAssertEqual(schema.fields[0].isNullable, true) XCTAssertEqual(schema.fields[1].name, "col2") XCTAssertEqual(schema.fields[1].type.info, ArrowType.ArrowString) XCTAssertEqual(schema.fields[1].isNullable, false) XCTAssertEqual(recordBatch.columns.count, 2) let col1: ArrowArray = recordBatch.data(for: 0) let col2: ArrowArray = recordBatch.data(for: 1) - XCTAssertEqual(col1.length, 2) - XCTAssertEqual(col2.length, 2) + XCTAssertEqual(col1.length, 3) + XCTAssertEqual(col2.length, 3) + XCTAssertEqual(col1.nullCount, 1) case .failure(let error): throw error }