
Add HiveHash support for nested types #9

Merged · 16 commits · Oct 23, 2024 · Changes from 4 commits
65 changes: 63 additions & 2 deletions src/main/cpp/src/hive_hash.cu
@@ -182,7 +182,7 @@ class hive_device_row_hasher {
       HIVE_INIT_HASH,
       cuda::proclaim_return_type<hive_hash_value_t>(
         [row_index, nulls = this->_check_nulls] __device__(auto hash, auto const& column) {
-          auto cur_hash = cudf::type_dispatcher<cudf::experimental::dispatch_void_if_nested>(
+          auto cur_hash = cudf::type_dispatcher(
             column.type(), element_hasher_adapter{nulls}, column, row_index);
           return HIVE_HASH_FACTOR * hash + cur_hash;
         }));
@@ -210,11 +210,72 @@ class hive_device_row_hasher {
      return this->hash_functor.template operator()<T>(col, row_index);
    }

    struct StackElement {
      cudf::column_device_view col;  // current column
      int child_idx;                 // index of the child column to process next
      int factor_exp;                // factor exponent for the current column
      // factor_exp = parent.factor_exp + parent.child_num - 1 - parent.child_idx

      __device__ StackElement() = delete;
      __device__ StackElement(cudf::column_device_view col, int factor_exp)
        : col(col), child_idx(0), factor_exp(factor_exp) {}
    };
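    // Illustrative note (not part of the original diff): for a struct with
    // children (a, b, c), Hive computes
    //   hash = 31 * (31 * (31 * 0 + h(a)) + h(b)) + h(c)
    //        = h(a) * 31^2 + h(b) * 31^1 + h(c) * 31^0,
    // so the child at index i contributes with exponent
    //   parent.factor_exp + num_children - 1 - i,
    // which is exactly the factor_exp recurrence above.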

    typedef StackElement* StackElementPtr;

    template <typename T, CUDF_ENABLE_IF(cudf::is_nested<T>())>
    __device__ hive_hash_value_t operator()(cudf::column_device_view const& col,
                                            cudf::size_type row_index) const noexcept
    {
-     CUDF_UNREACHABLE("Nested type is not supported");
      hive_hash_value_t ret = HIVE_INIT_HASH;


      cudf::column_device_view curr_col = col.slice(row_index, 1);
      // column_device_view's default constructor is deleted, so we cannot
      // allocate a column_device_view array directly; use a byte array to
      // back the StackElement list instead.
      constexpr int len_of_8_StackElement = 8 * sizeof(StackElement);

Review comment:

Why is 8 used here?

      uint8_t stack_wrapper[len_of_8_StackElement];
      StackElementPtr stack = reinterpret_cast<StackElementPtr>(stack_wrapper);

Review comment:

I am confused about creating a stack of StackElement. Does the below not work in C++?

StackElement stack[8];

@ustcfy (author):

> I am confused about creating a stack of StackElement. Does the below not work in C++? StackElement stack[8];

If I define an array StackElement stack[8]; without initialization, each element of the array would be constructed using StackElement's default constructor. However, the default constructor of the member cudf::column_device_view is deleted, so that does not compile.

@firestarman Oct 14, 2024:

OK, but you can restore the default constructor. Any reason it has to be deleted?
uint8_t stack_wrapper[len_of_8_StackElement] will also initialize 8 empty StackElements, the same as what StackElement stack[8] would do.

@res-life Oct 14, 2024:

> Any reason it has to be deleted?

I do not know why cuDF marks the default constructor as deleted:

column_device_view() = delete;

> uint8_t stack_wrapper[len_of_8_StackElement] will also initialize 8 empty StackElement

Not quite: for an array like A_Class[8], the compiler first allocates 8 * sizeof(A_Class) bytes and then calls the default constructor to initialize the 8 instances. uint8_t stack_wrapper[len_of_8_StackElement] only allocates the bytes and skips the default-constructor calls.

So this is a workaround to skip the default constructor:

      uint8_t stack_wrapper[len_of_8_StackElement];
      StackElementPtr stack = reinterpret_cast<StackElementPtr>(stack_wrapper);
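A minimal standalone sketch of the same trick (illustrative only; NoDefault stands in for cudf::column_device_view, and placement new makes the construct-into-raw-bytes step explicit and well defined):

    #include <cstdint>
    #include <new>

    struct NoDefault {
      int value;
      NoDefault() = delete;                  // default construction forbidden
      explicit NoDefault(int v) : value(v) {}
    };

    int main() {
      // NoDefault arr[8];                   // would not compile: deleted default ctor
      alignas(NoDefault) std::uint8_t storage[8 * sizeof(NoDefault)];
      auto* arr = reinterpret_cast<NoDefault*>(storage);
      int size = 0;
      new (&arr[size++]) NoDefault(42);      // construct in place, no default ctor
      return arr[0].value == 42 ? 0 : 1;
    }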

      int stack_size = 0;

      stack[stack_size++] = StackElement(curr_col, 0);
      // depth-first search
      while (stack_size > 0) {
        StackElementPtr element = &stack[stack_size - 1];
Collaborator:

Nit: StackElement const& element = stack[stack_size - 1];

@ustcfy (author):

> Nit: StackElement const& element = stack[stack_size - 1];

A const reference won't work here, because I need to modify the child_idx member of the element.

        curr_col = element->col;

        if (curr_col.type().id() == cudf::type_id::STRUCT) {
          if (element->child_idx == curr_col.num_child_columns()) {
            // All child columns processed, pop the stack
            stack_size--;
          } else {
            // Push the next child column onto the stack
            stack[stack_size++] = StackElement(
              cudf::detail::structs_column_device_view(curr_col).get_sliced_child(element->child_idx),
              element->factor_exp + curr_col.num_child_columns() - 1 - element->child_idx);
@res-life Oct 14, 2024:

Add a check that throws an exception if stack_size > 8.

Reply:

Can we throw an exception in a CUDA kernel?

Reply:

See the example below; a macro (CUDF_UNREACHABLE) is used to handle this kind of case:

    template <typename T, CUDF_ENABLE_IF(not cudf::column_device_view::has_element_accessor<T>())>
    __device__ hash_value_type operator()(cudf::column_device_view const&,
                                          cudf::size_type,
                                          Nullate const,
                                          hash_value_type const) const noexcept
    {
      CUDF_UNREACHABLE("Unsupported type for xxhash64");
    }

Collaborator:

Throw the exception in CPU code before calling into this kernel.
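A hypothetical host-side guard along those lines (a sketch, not part of this PR; max_depth and the depth limit are assumptions):

    #include <algorithm>
    #include <cudf/column/column_view.hpp>
    #include <cudf/utilities/error.hpp>

    // Maximum nesting depth of a column, counting the column itself as depth 1.
    int max_depth(cudf::column_view const& col)
    {
      int depth = 1;
      for (cudf::size_type i = 0; i < col.num_children(); ++i) {
        depth = std::max(depth, 1 + max_depth(col.child(i)));
      }
      return depth;
    }

    // In the CPU entry point, before the kernel launch:
    //   CUDF_EXPECTS(max_depth(input) <= 8, "hive_hash: nested depth exceeds 8");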

            element->child_idx++;
          }
        } else if (curr_col.type().id() == cudf::type_id::LIST) {
          // lists_column_device_view has a different interface from structs_column_device_view
          curr_col = cudf::detail::lists_column_device_view(curr_col).get_sliced_child();
          if (element->child_idx == curr_col.size()) {
            stack_size--;
          } else {
            stack[stack_size++] = StackElement(
              curr_col.slice(element->child_idx, 1),
              element->factor_exp + curr_col.size() - element->child_idx - 1);
@res-life Oct 14, 2024:

Throw the exception in CPU code before calling into this kernel.

            element->child_idx++;
          }
        } else {  // Process primitive type
          hive_hash_value_t cur_hash = cudf::detail::accumulate(
            thrust::counting_iterator(0),
            thrust::counting_iterator(curr_col.size()),
            HIVE_INIT_HASH,
            [curr_col, hasher = this->hash_functor] __device__(auto hash, auto element_index) {
              return HIVE_HASH_FACTOR * hash +
                     cudf::type_dispatcher<cudf::experimental::dispatch_void_if_nested>(
                       curr_col.type(), hasher, curr_col, element_index);
            });
          // ret += cur_hash * pow(HIVE_HASH_FACTOR, element->factor_exp)
          for (int i = 0; i < element->factor_exp; i++) {
            cur_hash *= HIVE_HASH_FACTOR;
          }
          ret += cur_hash;
          stack_size--;
        }
      }
      return ret;
    }

  private:
1 change: 0 additions & 1 deletion src/main/java/com/nvidia/spark/rapids/jni/Hash.java
@@ -96,7 +96,6 @@ public static ColumnVector hiveHash(ColumnView columns[]) {
       assert columns[i] != null : "Column vectors passed may not be null";
       assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
       assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
-      assert !columns[i].getType().isNestedType() : "Unsupported column type Nested";
       columnViews[i] = columns[i].getNativeView();
     }
     return new ColumnVector(hiveHash(columnViews));
219 changes: 219 additions & 0 deletions src/test/java/com/nvidia/spark/rapids/jni/HashTest.java
@@ -20,6 +20,7 @@
 import ai.rapids.cudf.ColumnView;
 import ai.rapids.cudf.DType;
 import ai.rapids.cudf.HostColumnVector.*;
+import ai.rapids.cudf.Scalar;
 import org.junit.jupiter.api.Test;

 import java.util.Arrays;

@@ -510,4 +511,222 @@ void testHiveHashMixed() {
      assertColumnsAreEqual(expected, result);
    }
  }

  @Test
  void testHiveHashStruct() {
@res-life Oct 11, 2024:

Would you please test all the types, especially the DICTIONARY32, TIMESTAMP_MILLISECONDS, and DECIMAL32 types? The type dispatcher covers:

 switch (dtype.id()) {
    case type_id::INT8:
      return f.template operator()<typename IdTypeMap<type_id::INT8>::type>(
        std::forward<Ts>(args)...);
    case type_id::INT16:
      return f.template operator()<typename IdTypeMap<type_id::INT16>::type>(
        std::forward<Ts>(args)...);
    case type_id::INT32:
      return f.template operator()<typename IdTypeMap<type_id::INT32>::type>(
        std::forward<Ts>(args)...);
    case type_id::INT64:
      return f.template operator()<typename IdTypeMap<type_id::INT64>::type>(
        std::forward<Ts>(args)...);
    case type_id::UINT8:
      return f.template operator()<typename IdTypeMap<type_id::UINT8>::type>(
        std::forward<Ts>(args)...);
    case type_id::UINT16:
      return f.template operator()<typename IdTypeMap<type_id::UINT16>::type>(
        std::forward<Ts>(args)...);
    case type_id::UINT32:
      return f.template operator()<typename IdTypeMap<type_id::UINT32>::type>(
        std::forward<Ts>(args)...);
    case type_id::UINT64:
      return f.template operator()<typename IdTypeMap<type_id::UINT64>::type>(
        std::forward<Ts>(args)...);
    case type_id::FLOAT32:
      return f.template operator()<typename IdTypeMap<type_id::FLOAT32>::type>(
        std::forward<Ts>(args)...);
    case type_id::FLOAT64:
      return f.template operator()<typename IdTypeMap<type_id::FLOAT64>::type>(
        std::forward<Ts>(args)...);
    case type_id::BOOL8:
      return f.template operator()<typename IdTypeMap<type_id::BOOL8>::type>(
        std::forward<Ts>(args)...);
    case type_id::TIMESTAMP_DAYS:
      return f.template operator()<typename IdTypeMap<type_id::TIMESTAMP_DAYS>::type>(
        std::forward<Ts>(args)...);
    case type_id::TIMESTAMP_SECONDS:
      return f.template operator()<typename IdTypeMap<type_id::TIMESTAMP_SECONDS>::type>(
        std::forward<Ts>(args)...);
    case type_id::TIMESTAMP_MILLISECONDS:
      return f.template operator()<typename IdTypeMap<type_id::TIMESTAMP_MILLISECONDS>::type>(
        std::forward<Ts>(args)...);
    case type_id::TIMESTAMP_MICROSECONDS:
      return f.template operator()<typename IdTypeMap<type_id::TIMESTAMP_MICROSECONDS>::type>(
        std::forward<Ts>(args)...);
    case type_id::TIMESTAMP_NANOSECONDS:
      return f.template operator()<typename IdTypeMap<type_id::TIMESTAMP_NANOSECONDS>::type>(
        std::forward<Ts>(args)...);
    case type_id::DURATION_DAYS:
      return f.template operator()<typename IdTypeMap<type_id::DURATION_DAYS>::type>(
        std::forward<Ts>(args)...);
    case type_id::DURATION_SECONDS:
      return f.template operator()<typename IdTypeMap<type_id::DURATION_SECONDS>::type>(
        std::forward<Ts>(args)...);
    case type_id::DURATION_MILLISECONDS:
      return f.template operator()<typename IdTypeMap<type_id::DURATION_MILLISECONDS>::type>(
        std::forward<Ts>(args)...);
    case type_id::DURATION_MICROSECONDS:
      return f.template operator()<typename IdTypeMap<type_id::DURATION_MICROSECONDS>::type>(
        std::forward<Ts>(args)...);
    case type_id::DURATION_NANOSECONDS:
      return f.template operator()<typename IdTypeMap<type_id::DURATION_NANOSECONDS>::type>(
        std::forward<Ts>(args)...);
    case type_id::DICTIONARY32:
      return f.template operator()<typename IdTypeMap<type_id::DICTIONARY32>::type>(
        std::forward<Ts>(args)...);
    case type_id::STRING:
      return f.template operator()<typename IdTypeMap<type_id::STRING>::type>(
        std::forward<Ts>(args)...);
    case type_id::LIST:
      return f.template operator()<typename IdTypeMap<type_id::LIST>::type>(
        std::forward<Ts>(args)...);
    case type_id::DECIMAL32:
      return f.template operator()<typename IdTypeMap<type_id::DECIMAL32>::type>(
        std::forward<Ts>(args)...);
    case type_id::DECIMAL64:
      return f.template operator()<typename IdTypeMap<type_id::DECIMAL64>::type>(
        std::forward<Ts>(args)...);
    case type_id::DECIMAL128:
      return f.template operator()<typename IdTypeMap<type_id::DECIMAL128>::type>(
        std::forward<Ts>(args)...);
    case type_id::STRUCT:
      return f.template operator()<typename IdTypeMap<type_id::STRUCT>::type>(
        std::forward<Ts>(args)...);

@ustcfy (author) Oct 11, 2024:

@res-life

  • Do you mean, for example, adding fields like TIMESTAMP_MILLISECONDS and DECIMAL32 inside the struct? It seems the hive hash for these two types has not been implemented yet.

  • As for DICTIONARY, I will consider adding support once the cases for struct and list are resolved.

@res-life:

> Do you mean, for example, adding fields like TIMESTAMP_MILLISECONDS and DECIMAL32 inside the struct? It seems the hive hash for these two types has not been implemented yet.

If hive hash does not support these two types, we should either find a way to implement them or fall back in the RAPIDS plugin. I assume Spark supports both.

> As for DICTIONARY, I will consider adding support once the cases for struct and list are resolved.

Please help test:

  • DICTIONARY of basic types like int, long
  • DICTIONARY of nested types (maybe cuDF cannot generate this kind of data)

Reply:

More primitive type support can be a follow-up; this PR focuses on nested type support.

    try (ColumnVector strings = ColumnVector.fromStrings(
           "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721",
           "This is a long string (greater than 128 bytes/char string) case to test this " +
           "hash function. Just want an abnormal case here to see if any error may happen when" +
           "doing the hive hashing",
           null, null);
         ColumnVector integers = ColumnVector.fromBoxedInts(
           0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null);
         ColumnVector doubles = ColumnVector.fromBoxedDoubles(0.0, 100.0, -100.0,
           POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
         ColumnVector floats = ColumnVector.fromBoxedFloats(0f, 100f, -100f,
           NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
         ColumnVector bools = ColumnVector.fromBoxedBooleans(
           true, false, null, false, true, null);
         ColumnView structs = ColumnView.makeStructView(strings, integers, doubles, floats, bools);
         ColumnVector result = Hash.hiveHash(new ColumnView[]{structs});
         ColumnVector expected = Hash.hiveHash(new ColumnVector[]{strings}).mul(Scalar.fromInt(923521)) // 923521 = 31^4
           .add(Hash.hiveHash(new ColumnVector[]{integers}).mul(Scalar.fromInt(29791)))                 // 29791 = 31^3
           .add(Hash.hiveHash(new ColumnVector[]{doubles}).mul(Scalar.fromInt(961)))                    // 961 = 31^2
           .add(Hash.hiveHash(new ColumnVector[]{floats}).mul(Scalar.fromInt(31)))
           .add(Hash.hiveHash(new ColumnVector[]{bools}))) {
      assertColumnsAreEqual(expected, result);
Collaborator:

Test cases are using the same interface to get both the expected and the actual results. Please leverage the results in testHiveHashStrings, testHiveHashFloats, etc. A test case would look like:

struct_data = {1.0, 2.3, 5}
actual = Hash.hiveHash(...)
expected = 31 * (31 * hash(1.0) + hash(2.3)) + hash(5)
assert_equals(expected, actual)

Collaborator:

It would be better to use the Spark hive hash interface to get the expected results. Try introducing the corresponding Java dependency with test scope:

    <dependency>
      <groupId>xx</groupId>
      <artifactId>xx</artifactId>
      <scope>test</scope>
    </dependency>

Reply:

I would not suggest doing this in the JNI repo. We can hardcode the expected results just as the other tests do, and test against the Spark hive hash in the RAPIDS plugin tests.

    }
  }

  @Test
  void testHiveHashNestedStruct() {
    try (ColumnVector strings = ColumnVector.fromStrings(
           "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721",
           "This is a long string (greater than 128 bytes/char string) case to test this " +
           "hash function. Just want an abnormal case here to see if any error may happen when" +
           "doing the hive hashing",
           null, null);
         ColumnVector integers = ColumnVector.fromBoxedInts(
           0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null);
         ColumnVector doubles = ColumnVector.fromBoxedDoubles(0.0, 100.0, -100.0,
           POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
         ColumnVector floats = ColumnVector.fromBoxedFloats(0f, 100f, -100f,
           NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
         ColumnVector bools = ColumnVector.fromBoxedBooleans(
           true, false, null, false, true, null);
         ColumnView structs1 = ColumnView.makeStructView(strings, integers);
         ColumnView structs2 = ColumnView.makeStructView(structs1, doubles);
         ColumnView structs3 = ColumnView.makeStructView(bools);
         ColumnView structs = ColumnView.makeStructView(structs2, floats, structs3);
         ColumnVector result = Hash.hiveHash(new ColumnView[]{structs});
         ColumnVector expected = Hash.hiveHash(new ColumnView[]{structs2}).mul(Scalar.fromInt(961)) // 961 = 31^2
           .add(Hash.hiveHash(new ColumnVector[]{floats}).mul(Scalar.fromInt(31)))                  // 31^1
           .add(Hash.hiveHash(new ColumnView[]{structs3}))) {
      assertColumnsAreEqual(expected, result);
    }
  }

  @Test
  void testHiveHashLists() {
    try (ColumnVector stringListCV = ColumnVector.fromLists(
           new ListType(true, new BasicType(true, DType.STRING)),
           Arrays.asList(null, "a"),
           Arrays.asList("B\n", ""),
           Arrays.asList("dE\"\u0100\t\u0101", " \ud720\ud721"),
           Collections.singletonList("This is a long string (greater than 128 bytes/char string) case to test this " +
             "hash function. Just want an abnormal case here to see if any error may happen when" +
             "doing the hive hashing"),
           Collections.singletonList(""),
           null);
         ColumnVector strings1 = ColumnVector.fromStrings(
           null, "B\n", "dE\"\u0100\t\u0101",
           "This is a long string (greater than 128 bytes/char string) case to test this " +
           "hash function. Just want an abnormal case here to see if any error may happen when" +
           "doing the hive hashing", "", null);
         ColumnVector strings2 = ColumnVector.fromStrings(
           "a", "", " \ud720\ud721", null, null, null);
         ColumnVector stringResult = Hash.hiveHash(new ColumnView[]{stringListCV});
         ColumnVector stringExpected = Hash.hiveHash(new ColumnVector[]{strings1}).mul(ColumnVector.fromBoxedInts(31, 31, 31, 1, 1, 0))
           .add(Hash.hiveHash(new ColumnVector[]{strings2}));
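         // Illustrative note (not in the original test): the per-row multipliers
         // mirror the list lengths above. A length-2 list scales its first
         // element's hash by 31^1 and its second by 31^0; singleton lists use
         // 31^0 = 1; a null list row hashes to 0, hence the 0 multiplier.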
         ColumnVector intListCV = ColumnVector.fromLists(
           new ListType(true, new BasicType(true, DType.INT32)),
           Collections.singletonList(null),
           Arrays.asList(0, -2, 3),
           Collections.singletonList(Integer.MAX_VALUE),
           Arrays.asList(5, -6, null),
           Collections.singletonList(Integer.MIN_VALUE),
           null);
         ColumnVector integers1 = ColumnVector.fromBoxedInts(null, 0, Integer.MAX_VALUE, 5, Integer.MIN_VALUE, null);
         ColumnVector integers2 = ColumnVector.fromBoxedInts(null, -2, null, -6, null, null);
         ColumnVector integers3 = ColumnVector.fromBoxedInts(null, 3, null, null, null, null);
         ColumnVector intExpected = Hash.hiveHash(new ColumnVector[]{integers1}).mul(ColumnVector.fromBoxedInts(1, 961, 1, 961, 1, 0))
           .add(Hash.hiveHash(new ColumnVector[]{integers2}).mul(ColumnVector.fromBoxedInts(0, 31, 0, 31, 0, 0)))
           .add(Hash.hiveHash(new ColumnVector[]{integers3}));
         ColumnVector intResult = Hash.hiveHash(new ColumnVector[]{intListCV})) {
      assertColumnsAreEqual(stringExpected, stringResult);
      assertColumnsAreEqual(intExpected, intResult);
    }
  }

  @Test
  void testHiveHashNestedLists() {
    try (ColumnVector nestedStringListCV = ColumnVector.fromLists(
           new ListType(true, new ListType(true, new BasicType(true, DType.STRING))),
           Arrays.asList(null, Arrays.asList("a", null)),
           Arrays.asList(Arrays.asList("B\n", "")),
           Arrays.asList(Collections.singletonList("dE\"\u0100\t\u0101"), Collections.singletonList(" \ud720\ud721")),
           Arrays.asList(Collections.singletonList("This is a long string (greater than 128 bytes/char string) case to test this " +
             "hash function. Just want an abnormal case here to see if any error may happen when" +
             "doing the hive hashing"), null),
           Arrays.asList(Collections.singletonList(""), Collections.singletonList(null)),
           null);
         ColumnVector stringListCV1 = ColumnVector.fromLists(
           new ListType(true, new BasicType(true, DType.STRING)),
           null,
           Arrays.asList("B\n", ""),
           Collections.singletonList("dE\"\u0100\t\u0101"),
           Collections.singletonList("This is a long string (greater than 128 bytes/char string) case to test this " +
             "hash function. Just want an abnormal case here to see if any error may happen when" +
             "doing the hive hashing"),
           Collections.singletonList(""),
           null);
         ColumnVector stringListCV2 = ColumnVector.fromLists(
           new ListType(true, new BasicType(true, DType.STRING)),
           Arrays.asList("a", null),
           null,
           Collections.singletonList(" \ud720\ud721"),
           null,
           Collections.singletonList(null),
           null);
         ColumnVector stringExpected = Hash.hiveHash(new ColumnVector[]{stringListCV1}).mul(ColumnVector.fromBoxedInts(31, 1, 31, 31, 31, 0))
           .add(Hash.hiveHash(new ColumnVector[]{stringListCV2}));
         ColumnVector stringResult = Hash.hiveHash(new ColumnView[]{nestedStringListCV});
         ColumnVector nestedIntListCV = ColumnVector.fromLists(
           new ListType(true, new ListType(true, new BasicType(true, DType.INT32))),
           Arrays.asList(Arrays.asList(null, null), null),
           Arrays.asList(Collections.singletonList(0), Collections.singletonList(-2), Collections.singletonList(3)),
           Arrays.asList(null, Collections.singletonList(Integer.MAX_VALUE)),
           Arrays.asList(Collections.singletonList(5), Collections.singletonList(-6), null),
           Arrays.asList(Collections.singletonList(Integer.MIN_VALUE), null, Collections.singletonList(null)),
           null);
         ColumnVector intListCV1 = ColumnVector.fromLists(
           new ListType(true, new BasicType(true, DType.INT32)),
           Arrays.asList(null, null),
           Collections.singletonList(0),
           null,
           Collections.singletonList(5),
           Collections.singletonList(Integer.MIN_VALUE),
           null);
         ColumnVector intListCV2 = ColumnVector.fromLists(
           new ListType(true, new BasicType(true, DType.INT32)),
           null,
           Collections.singletonList(-2),
           Collections.singletonList(Integer.MAX_VALUE),
           Collections.singletonList(-6),
           null,
           null);
         ColumnVector intListCV3 = ColumnVector.fromLists(
           new ListType(true, new BasicType(true, DType.INT32)),
           null,
           Collections.singletonList(3),
           null,
           null,
           Collections.singletonList(null),
           null);
         ColumnVector intExpected = Hash.hiveHash(new ColumnVector[]{intListCV1}).mul(ColumnVector.fromBoxedInts(31, 961, 31, 961, 961, 0))
           .add(Hash.hiveHash(new ColumnVector[]{intListCV2}).mul(ColumnVector.fromBoxedInts(1, 31, 1, 31, 31, 0)))
           .add(Hash.hiveHash(new ColumnVector[]{intListCV3}));
         ColumnVector intResult = Hash.hiveHash(new ColumnVector[]{nestedIntListCV})) {
      assertColumnsAreEqual(stringExpected, stringResult);
      assertColumnsAreEqual(intExpected, intResult);
    }
  }

  @Test
  void testHiveHashNestedType() {

@firestarman Oct 17, 2024:

Better to have a test verifying the case where the max nested depth is larger than 8. Tests for the list-of-struct type are also missing.
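A sketch of what such a depth test might look like (hypothetical, not part of this PR; it assumes the host side throws a CudfException when the nesting depth exceeds 8, and assertThrows comes from JUnit's Assertions):

    @Test
    void testHiveHashDeeplyNestedStructThrows() {
      try (ColumnVector ints = ColumnVector.fromBoxedInts(1, 2, 3)) {
        // Wrap the column in 9 levels of struct to exceed the assumed depth limit of 8.
        ColumnView view = ints;
        for (int i = 0; i < 9; i++) {
          view = ColumnView.makeStructView(view);
        }
        final ColumnView nested = view;
        assertThrows(CudfException.class,
            () -> Hash.hiveHash(new ColumnView[]{nested}).close());
      }
    }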

    try (ColumnVector stringListCV = ColumnVector.fromLists(
           new ListType(true, new BasicType(true, DType.STRING)),
           Arrays.asList(null, "a"),
           Arrays.asList("B\n", ""),
           Arrays.asList("dE\"\u0100\t\u0101", " \ud720\ud721"),
           Collections.singletonList("This is a long string (greater than 128 bytes/char string) case to test this " +
             "hash function. Just want an abnormal case here to see if any error may happen when" +
             "doing the hive hashing"),
           Collections.singletonList(""),
           null);
         ColumnVector strings1 = ColumnVector.fromStrings(
           null, "B\n", "dE\"\u0100\t\u0101",
           "This is a long string (greater than 128 bytes/char string) case to test this " +
           "hash function. Just want an abnormal case here to see if any error may happen when" +
           "doing the hive hashing", "", null);
         ColumnVector strings2 = ColumnVector.fromStrings(
           "a", "", " \ud720\ud721", null, null, null);
         ColumnView stringStruct = ColumnView.makeStructView(strings1, strings2);
         ColumnVector stringExpected = Hash.hiveHash(new ColumnView[]{stringStruct});
         ColumnVector stringResult = Hash.hiveHash(new ColumnView[]{strings1}).mul(Scalar.fromInt(31))
           .add(Hash.hiveHash(new ColumnView[]{strings2}));
         ColumnVector intListCV = ColumnVector.fromLists(
           new ListType(true, new BasicType(true, DType.INT32)),
           Collections.singletonList(null),
           Arrays.asList(0, -2, 3),
           Collections.singletonList(Integer.MAX_VALUE),
           Arrays.asList(5, -6, null),
           Collections.singletonList(Integer.MIN_VALUE),
           null);
         ColumnVector doubles = ColumnVector.fromBoxedDoubles(
           0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
         ColumnVector floats = ColumnVector.fromBoxedFloats(
           0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
         ColumnView structCV = ColumnView.makeStructView(intListCV, stringListCV, doubles, floats);
         ColumnVector nestedExpected = Hash.hiveHash(new ColumnView[]{intListCV}).mul(Scalar.fromInt(29791)) // 29791 = 31^3
           .add(Hash.hiveHash(new ColumnView[]{stringListCV}).mul(Scalar.fromInt(961)))                      // 961 = 31^2
           .add(Hash.hiveHash(new ColumnVector[]{doubles}).mul(Scalar.fromInt(31)))
           .add(Hash.hiveHash(new ColumnVector[]{floats}));
         ColumnVector nestedResult = Hash.hiveHash(new ColumnView[]{structCV})) {
      assertColumnsAreEqual(stringExpected, stringResult);
      assertColumnsAreEqual(nestedExpected, nestedResult);
    }
  }
}