diff --git a/java/src/test/java/ai/rapids/cudf/ArrowColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ArrowColumnVectorTest.java index d5d4059d18d..2a11b24b3a8 100644 --- a/java/src/test/java/ai/rapids/cudf/ArrowColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ArrowColumnVectorTest.java @@ -21,7 +21,6 @@ import java.nio.ByteBuffer; import java.util.ArrayList; -import ai.rapids.cudf.HostColumnVector.BasicType; import ai.rapids.cudf.HostColumnVector.ListType; import ai.rapids.cudf.HostColumnVector.StructType; @@ -40,7 +39,7 @@ import org.junit.jupiter.api.Test; -import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; diff --git a/java/src/test/java/ai/rapids/cudf/AssertUtils.java b/java/src/test/java/ai/rapids/cudf/AssertUtils.java new file mode 100644 index 00000000000..184e7dd0c57 --- /dev/null +++ b/java/src/test/java/ai/rapids/cudf/AssertUtils.java @@ -0,0 +1,272 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** Utility methods for asserting in unit tests */ +public class AssertUtils { + + /** + * Checks and asserts that passed in columns match + * @param expect The expected result column + * @param cv The input column + */ + public static void assertColumnsAreEqual(ColumnView expect, ColumnView cv) { + assertColumnsAreEqual(expect, cv, "unnamed"); + } + + /** + * Checks and asserts that passed in columns match + * @param expected The expected result column + * @param cv The input column + * @param colName The name of the column + */ + public static void assertColumnsAreEqual(ColumnView expected, ColumnView cv, String colName) { + assertPartialColumnsAreEqual(expected, 0, expected.getRowCount(), cv, colName, true, false); + } + + /** + * Checks and asserts that passed in host columns match + * @param expected The expected result host column + * @param cv The input host column + * @param colName The name of the host column + */ + public static void assertColumnsAreEqual(HostColumnVector expected, HostColumnVector cv, String colName) { + assertPartialColumnsAreEqual(expected, 0, expected.getRowCount(), cv, colName, true, false); + } + + /** + * Checks and asserts that passed in Struct columns match + * @param expected The expected result Struct column + * @param cv The input Struct column + */ + public static void assertStructColumnsAreEqual(ColumnView expected, ColumnView cv) { + assertPartialStructColumnsAreEqual(expected, 0, expected.getRowCount(), cv, "unnamed", true, false); + } + + /** + * Checks and asserts that passed in Struct columns match + * @param expected The expected result Struct column + * @param rowOffset The row number to look from + * @param length The number of rows to consider + * @param cv The input Struct column + * @param colName The name of the column + * @param enableNullCountCheck Whether to check for nulls in the Struct column + * @param enableNullabilityCheck Whether the table have a validity mask + */ + public static void assertPartialStructColumnsAreEqual(ColumnView expected, long rowOffset, long length, + ColumnView cv, String colName, boolean enableNullCountCheck, boolean enableNullabilityCheck) { + try (HostColumnVector hostExpected = expected.copyToHost(); + HostColumnVector hostcv = cv.copyToHost()) { + assertPartialColumnsAreEqual(hostExpected, rowOffset, length, hostcv, colName, enableNullCountCheck, enableNullabilityCheck); + } + } + + /** + * Checks and asserts that passed in columns match + * @param expected The expected result column + * @param cv The input column + * @param colName The name of the column + * @param enableNullCheck Whether to check for nulls in the column + * @param enableNullabilityCheck Whether the table have a validity mask + */ + public static void assertPartialColumnsAreEqual(ColumnView expected, long rowOffset, long length, + ColumnView cv, String colName, boolean enableNullCheck, boolean enableNullabilityCheck) { + try (HostColumnVector hostExpected = expected.copyToHost(); + HostColumnVector hostcv = cv.copyToHost()) { + assertPartialColumnsAreEqual(hostExpected, rowOffset, length, hostcv, colName, enableNullCheck, enableNullabilityCheck); + } + } + + /** + * Checks and asserts that passed in host columns match + * @param expected The expected result host column + * @param rowOffset start row index + * @param length number of rows from starting offset + * @param cv The input host column + * @param colName The name of the host column + * @param enableNullCountCheck Whether to check for nulls in the host column + */ + public static void assertPartialColumnsAreEqual(HostColumnVectorCore expected, long rowOffset, long length, + HostColumnVectorCore cv, String colName, boolean enableNullCountCheck, boolean enableNullabilityCheck) { + assertEquals(expected.getType(), cv.getType(), "Type For Column " + colName); + assertEquals(length, cv.getRowCount(), "Row Count For Column " + colName); + assertEquals(expected.getNumChildren(), cv.getNumChildren(), "Child Count for Column " + colName); + if (enableNullCountCheck) { + assertEquals(expected.getNullCount(), cv.getNullCount(), "Null Count For Column " + colName); + } else { + // TODO add in a proper check when null counts are supported by serializing a partitioned column + } + if (enableNullabilityCheck) { + assertEquals(expected.hasValidityVector(), cv.hasValidityVector(), "Column nullability is different than expected"); + } + DType type = expected.getType(); + for (long expectedRow = rowOffset; expectedRow < (rowOffset + length); expectedRow++) { + long tableRow = expectedRow - rowOffset; + assertEquals(expected.isNull(expectedRow), cv.isNull(tableRow), + "NULL for Column " + colName + " Row " + tableRow); + if (!expected.isNull(expectedRow)) { + switch (type.typeId) { + case BOOL8: // fall through + case INT8: // fall through + case UINT8: + assertEquals(expected.getByte(expectedRow), cv.getByte(tableRow), + "Column " + colName + " Row " + tableRow); + break; + case INT16: // fall through + case UINT16: + assertEquals(expected.getShort(expectedRow), cv.getShort(tableRow), + "Column " + colName + " Row " + tableRow); + break; + case INT32: // fall through + case UINT32: // fall through + case TIMESTAMP_DAYS: + case DURATION_DAYS: + case DECIMAL32: + assertEquals(expected.getInt(expectedRow), cv.getInt(tableRow), + "Column " + colName + " Row " + tableRow); + break; + case INT64: // fall through + case UINT64: // fall through + case DURATION_MICROSECONDS: // fall through + case DURATION_MILLISECONDS: // fall through + case DURATION_NANOSECONDS: // fall through + case DURATION_SECONDS: // fall through + case TIMESTAMP_MICROSECONDS: // fall through + case TIMESTAMP_MILLISECONDS: // fall through + case TIMESTAMP_NANOSECONDS: // fall through + case TIMESTAMP_SECONDS: + case DECIMAL64: + assertEquals(expected.getLong(expectedRow), cv.getLong(tableRow), + "Column " + colName + " Row " + tableRow); + break; + case DECIMAL128: + assertEquals(expected.getBigDecimal(expectedRow), cv.getBigDecimal(tableRow), + "Column " + colName + " Row " + tableRow); + break; + case FLOAT32: + CudfTestBase.assertEqualsWithinPercentage(expected.getFloat(expectedRow), cv.getFloat(tableRow), 0.0001, + "Column " + colName + " Row " + tableRow); + break; + case FLOAT64: + CudfTestBase.assertEqualsWithinPercentage(expected.getDouble(expectedRow), cv.getDouble(tableRow), 0.0001, + "Column " + colName + " Row " + tableRow); + break; + case STRING: + assertArrayEquals(expected.getUTF8(expectedRow), cv.getUTF8(tableRow), + "Column " + colName + " Row " + tableRow); + break; + case LIST: + HostMemoryBuffer expectedOffsets = expected.getOffsets(); + HostMemoryBuffer cvOffsets = cv.getOffsets(); + int expectedChildRows = expectedOffsets.getInt((expectedRow + 1) * 4) - + expectedOffsets.getInt(expectedRow * 4); + int cvChildRows = cvOffsets.getInt((tableRow + 1) * 4) - + cvOffsets.getInt(tableRow * 4); + assertEquals(expectedChildRows, cvChildRows, "Child row count for Column " + + colName + " Row " + tableRow); + break; + case STRUCT: + // parent column only has validity which was checked above + break; + default: + throw new IllegalArgumentException(type + " is not supported yet"); + } + } + } + + if (type.isNestedType()) { + switch (type.typeId) { + case LIST: + int expectedChildRowOffset = 0; + int numChildRows = 0; + if (length > 0) { + HostMemoryBuffer expectedOffsets = expected.getOffsets(); + HostMemoryBuffer cvOffsets = cv.getOffsets(); + expectedChildRowOffset = expectedOffsets.getInt(rowOffset * 4); + numChildRows = expectedOffsets.getInt((rowOffset + length) * 4) - + expectedChildRowOffset; + } + assertPartialColumnsAreEqual(expected.getNestedChildren().get(0), expectedChildRowOffset, + numChildRows, cv.getNestedChildren().get(0), colName + " list child", + enableNullCountCheck, enableNullabilityCheck); + break; + case STRUCT: + List expectedChildren = expected.getNestedChildren(); + List cvChildren = cv.getNestedChildren(); + for (int i = 0; i < expectedChildren.size(); i++) { + HostColumnVectorCore expectedChild = expectedChildren.get(i); + HostColumnVectorCore cvChild = cvChildren.get(i); + String childName = colName + " child " + i; + assertEquals(length, cvChild.getRowCount(), "Row Count for Column " + colName); + assertPartialColumnsAreEqual(expectedChild, rowOffset, length, cvChild, + colName, enableNullCountCheck, enableNullabilityCheck); + } + break; + default: + throw new IllegalArgumentException(type + " is not supported yet"); + } + } + } + + /** + * Checks and asserts that the two tables from a given rowindex match based on a provided schema. + * @param expected the expected result table + * @param rowOffset the row number to start checking from + * @param length the number of rows to check + * @param table the input table to compare against expected + * @param enableNullCheck whether to check for nulls or not + * @param enableNullabilityCheck whether the table have a validity mask + */ + public static void assertPartialTablesAreEqual(Table expected, long rowOffset, long length, Table table, + boolean enableNullCheck, boolean enableNullabilityCheck) { + assertEquals(expected.getNumberOfColumns(), table.getNumberOfColumns()); + assertEquals(length, table.getRowCount(), "ROW COUNT"); + for (int col = 0; col < expected.getNumberOfColumns(); col++) { + ColumnVector expect = expected.getColumn(col); + ColumnVector cv = table.getColumn(col); + String name = String.valueOf(col); + if (rowOffset != 0 || length != expected.getRowCount()) { + name = name + " PART " + rowOffset + "-" + (rowOffset + length - 1); + } + assertPartialColumnsAreEqual(expect, rowOffset, length, cv, name, enableNullCheck, enableNullabilityCheck); + } + } + + /** + * Checks and asserts that the two tables match + * @param expected the expected result table + * @param table the input table to compare against expected + */ + public static void assertTablesAreEqual(Table expected, Table table) { + assertPartialTablesAreEqual(expected, 0, expected.getRowCount(), table, true, false); + } + + public static void assertTableTypes(DType[] expectedTypes, Table t) { + int len = t.getNumberOfColumns(); + assertEquals(expectedTypes.length, len); + for (int i = 0; i < len; i++) { + ColumnVector vec = t.getColumn(i); + DType type = vec.getType(); + assertEquals(expectedTypes[i], type, "Types don't match at " + i); + } + } +} diff --git a/java/src/test/java/ai/rapids/cudf/BinaryOpTest.java b/java/src/test/java/ai/rapids/cudf/BinaryOpTest.java index 894861b8c44..0ca997d3c80 100644 --- a/java/src/test/java/ai/rapids/cudf/BinaryOpTest.java +++ b/java/src/test/java/ai/rapids/cudf/BinaryOpTest.java @@ -27,7 +27,7 @@ import java.util.Arrays; import java.util.stream.IntStream; -import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; import static ai.rapids.cudf.TestUtils.*; import static org.junit.jupiter.api.Assertions.assertThrows; diff --git a/java/src/test/java/ai/rapids/cudf/ByteColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ByteColumnVectorTest.java index 878fa7e4516..a26dbec4907 100644 --- a/java/src/test/java/ai/rapids/cudf/ByteColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ByteColumnVectorTest.java @@ -127,9 +127,9 @@ public void testCastToByte() { ColumnVector expected1 = ColumnVector.fromBytes((byte)4, (byte)3, (byte)8); ColumnVector expected2 = ColumnVector.fromBytes((byte)100); ColumnVector expected3 = ColumnVector.fromBytes((byte)-23)) { - TableTest.assertColumnsAreEqual(expected1, byteColumnVector1); - TableTest.assertColumnsAreEqual(expected2, byteColumnVector2); - TableTest.assertColumnsAreEqual(expected3, byteColumnVector3); + AssertUtils.assertColumnsAreEqual(expected1, byteColumnVector1); + AssertUtils.assertColumnsAreEqual(expected2, byteColumnVector2); + AssertUtils.assertColumnsAreEqual(expected3, byteColumnVector3); } } diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index a582541a0d4..2f79b47c64f 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -34,8 +34,10 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; +import static ai.rapids.cudf.AssertUtils.assertStructColumnsAreEqual; +import static ai.rapids.cudf.AssertUtils.assertTablesAreEqual; import static ai.rapids.cudf.QuantileMethod.*; -import static ai.rapids.cudf.TableTest.*; import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assumptions.assumeTrue; @@ -86,8 +88,8 @@ void testTransformVector() { ColumnVector cv1 = cv.transform(ptx, true); ColumnVector cv2 = cv.transform(cuda, false); ColumnVector expected = ColumnVector.fromBoxedInts(2*2-2, 3*3-3, null, 4*4-4)) { - TableTest.assertColumnsAreEqual(expected, cv1); - TableTest.assertColumnsAreEqual(expected, cv2); + assertColumnsAreEqual(expected, cv1); + assertColumnsAreEqual(expected, cv2); } } @@ -252,7 +254,7 @@ void testStringCreation() { try (ColumnVector cv = ColumnVector.fromStrings("d", "sd", "sde", null, "END"); HostColumnVector host = cv.copyToHost(); ColumnVector backAgain = host.copyToDevice()) { - TableTest.assertColumnsAreEqual(cv, backAgain); + assertColumnsAreEqual(cv, backAgain); } } @@ -265,7 +267,7 @@ void testUTF8StringCreation() { null, "END".getBytes(StandardCharsets.UTF_8)); ColumnVector expected = ColumnVector.fromStrings("d", "sd", "sde", null, "END")) { - TableTest.assertColumnsAreEqual(expected, cv); + assertColumnsAreEqual(expected, cv); } } @@ -299,7 +301,7 @@ void testConcatNoNulls() { ColumnVector v2 = ColumnVector.fromInts(8, 9); ColumnVector v = ColumnVector.concatenate(v0, v1, v2); ColumnVector expected = ColumnVector.fromInts(1, 2, 3, 4, 5, 6, 7, 8, 9)) { - TableTest.assertColumnsAreEqual(expected, v); + assertColumnsAreEqual(expected, v); } } @@ -310,7 +312,7 @@ void testConcatWithNulls() { ColumnVector v2 = ColumnVector.fromBoxedDoubles(null, 9.0); ColumnVector v = ColumnVector.concatenate(v0, v1, v2); ColumnVector expected = ColumnVector.fromBoxedDoubles(1., 2., 3., 4., 5., 6., 7., null, 9.)) { - TableTest.assertColumnsAreEqual(expected, v); + assertColumnsAreEqual(expected, v); } } @@ -1882,13 +1884,13 @@ void testSubvector() { try (ColumnVector vec = ColumnVector.fromBoxedInts(1, 2, 3, null, 5); ColumnVector expected = ColumnVector.fromBoxedInts(2, 3, null, 5); ColumnVector found = vec.subVector(1, 5)) { - TableTest.assertColumnsAreEqual(expected, found); + assertColumnsAreEqual(expected, found); } try (ColumnVector vec = ColumnVector.fromStrings("1", "2", "3", null, "5"); ColumnVector expected = ColumnVector.fromStrings("2", "3", null, "5"); ColumnVector found = vec.subVector(1, 5)) { - TableTest.assertColumnsAreEqual(expected, found); + assertColumnsAreEqual(expected, found); } } @@ -2014,7 +2016,7 @@ void testTrimStringsWhiteSpace() { try (ColumnVector cv = ColumnVector.fromStrings(" 123", "123 ", null, " 123 ", "\t\t123\n\n"); ColumnVector trimmed = cv.strip(); ColumnVector expected = ColumnVector.fromStrings("123", "123", null, "123", "123")) { - TableTest.assertColumnsAreEqual(expected, trimmed); + assertColumnsAreEqual(expected, trimmed); } } @@ -2024,7 +2026,7 @@ void testTrimStrings() { Scalar one = Scalar.fromString(" 1"); ColumnVector trimmed = cv.strip(one); ColumnVector expected = ColumnVector.fromStrings("23", "23", null, "23", "\t\t123\n\n")) { - TableTest.assertColumnsAreEqual(expected, trimmed); + assertColumnsAreEqual(expected, trimmed); } } @@ -2033,7 +2035,7 @@ void testLeftTrimStringsWhiteSpace() { try (ColumnVector cv = ColumnVector.fromStrings(" 123", "123 ", null, " 123 ", "\t\t123\n\n"); ColumnVector trimmed = cv.lstrip(); ColumnVector expected = ColumnVector.fromStrings("123", "123 ", null, "123 ", "123\n\n")) { - TableTest.assertColumnsAreEqual(expected, trimmed); + assertColumnsAreEqual(expected, trimmed); } } @@ -2043,7 +2045,7 @@ void testLeftTrimStrings() { Scalar one = Scalar.fromString(" 1"); ColumnVector trimmed = cv.lstrip(one); ColumnVector expected = ColumnVector.fromStrings("23", "23 ", null, "231", "\t\t123\n\n")) { - TableTest.assertColumnsAreEqual(expected, trimmed); + assertColumnsAreEqual(expected, trimmed); } } @@ -2052,7 +2054,7 @@ void testRightTrimStringsWhiteSpace() { try (ColumnVector cv = ColumnVector.fromStrings(" 123", "123 ", null, " 123 ", "\t\t123\n\n"); ColumnVector trimmed = cv.rstrip(); ColumnVector expected = ColumnVector.fromStrings(" 123", "123", null, " 123", "\t\t123")) { - TableTest.assertColumnsAreEqual(expected, trimmed); + assertColumnsAreEqual(expected, trimmed); } } @@ -2062,7 +2064,7 @@ void testRightTrimStrings() { Scalar one = Scalar.fromString(" 1"); ColumnVector trimmed = cv.rstrip(one); ColumnVector expected = ColumnVector.fromStrings("123", "123", null, "123", "\t\t123\n\n")) { - TableTest.assertColumnsAreEqual(expected, trimmed); + assertColumnsAreEqual(expected, trimmed); } } @@ -2108,7 +2110,7 @@ void testCountElements() { Arrays.asList(1, 2, 3), Arrays.asList(1, 2, 3, 4)); ColumnVector lengths = cv.countElements(); ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, 2, 3, 4)) { - TableTest.assertColumnsAreEqual(expected, lengths); + assertColumnsAreEqual(expected, lengths); } } @@ -2117,7 +2119,7 @@ void testStringLengths() { try (ColumnVector cv = ColumnVector.fromStrings("1", "12", null, "123", "1234"); ColumnVector lengths = cv.getCharLengths(); ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, 3, 4)) { - TableTest.assertColumnsAreEqual(expected, lengths); + assertColumnsAreEqual(expected, lengths); } } @@ -2126,7 +2128,7 @@ void testGetByteCount() { try (ColumnVector cv = ColumnVector.fromStrings("1", "12", "123", null, "1234"); ColumnVector byteLengthVector = cv.getByteCount(); ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, 3, null, 4)) { - TableTest.assertColumnsAreEqual(expected, byteLengthVector); + assertColumnsAreEqual(expected, byteLengthVector); } } diff --git a/java/src/test/java/ai/rapids/cudf/IfElseTest.java b/java/src/test/java/ai/rapids/cudf/IfElseTest.java index 86ddcc23416..a078befdf40 100644 --- a/java/src/test/java/ai/rapids/cudf/IfElseTest.java +++ b/java/src/test/java/ai/rapids/cudf/IfElseTest.java @@ -25,7 +25,7 @@ import java.util.stream.Stream; -import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; import static org.junit.jupiter.api.Assertions.assertThrows; public class IfElseTest extends CudfTestBase { diff --git a/java/src/test/java/ai/rapids/cudf/IntColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/IntColumnVectorTest.java index dd03c4de69e..2fb8164534b 100644 --- a/java/src/test/java/ai/rapids/cudf/IntColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/IntColumnVectorTest.java @@ -117,8 +117,8 @@ public void testCastToInt() { ColumnVector expected1 = ColumnVector.fromInts(4, 3, 8); ColumnVector intColumnVector2 = shortColumnVector.asInts(); ColumnVector expected2 = ColumnVector.fromInts(100)) { - TableTest.assertColumnsAreEqual(expected1, intColumnVector1); - TableTest.assertColumnsAreEqual(expected2, intColumnVector2); + AssertUtils.assertColumnsAreEqual(expected1, intColumnVector1); + AssertUtils.assertColumnsAreEqual(expected2, intColumnVector2); } } diff --git a/java/src/test/java/ai/rapids/cudf/ScalarTest.java b/java/src/test/java/ai/rapids/cudf/ScalarTest.java index 0889363c2d0..86c340bb321 100644 --- a/java/src/test/java/ai/rapids/cudf/ScalarTest.java +++ b/java/src/test/java/ai/rapids/cudf/ScalarTest.java @@ -29,7 +29,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; -import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; import static org.junit.jupiter.api.Assertions.*; public class ScalarTest extends CudfTestBase { diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 4512a08430c..f61a9b0f902 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -57,6 +57,11 @@ import java.util.stream.Collectors; import static ai.rapids.cudf.ColumnWriterOptions.mapColumn; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; +import static ai.rapids.cudf.AssertUtils.assertPartialColumnsAreEqual; +import static ai.rapids.cudf.AssertUtils.assertPartialTablesAreEqual; +import static ai.rapids.cudf.AssertUtils.assertTableTypes; +import static ai.rapids.cudf.AssertUtils.assertTablesAreEqual; import static ai.rapids.cudf.ParquetWriterOptions.listBuilder; import static ai.rapids.cudf.ParquetWriterOptions.structBuilder; import static ai.rapids.cudf.Table.TestBuilder; @@ -94,242 +99,6 @@ public class TableTest extends CudfTestBase { "8|118.2|128\n" + "9|119.8|129").getBytes(StandardCharsets.UTF_8); - /** - * Checks and asserts that passed in columns match - * @param expect The expected result column - * @param cv The input column - */ - public static void assertColumnsAreEqual(ColumnView expect, ColumnView cv) { - assertColumnsAreEqual(expect, cv, "unnamed"); - } - - /** - * Checks and asserts that passed in columns match - * @param expected The expected result column - * @param cv The input column - * @param colName The name of the column - */ - public static void assertColumnsAreEqual(ColumnView expected, ColumnView cv, String colName) { - assertPartialColumnsAreEqual(expected, 0, expected.getRowCount(), cv, colName, true, false); - } - - /** - * Checks and asserts that passed in host columns match - * @param expected The expected result host column - * @param cv The input host column - * @param colName The name of the host column - */ - public static void assertColumnsAreEqual(HostColumnVector expected, HostColumnVector cv, String colName) { - assertPartialColumnsAreEqual(expected, 0, expected.getRowCount(), cv, colName, true, false); - } - - /** - * Checks and asserts that passed in Struct columns match - * @param expected The expected result Struct column - * @param cv The input Struct column - */ - public static void assertStructColumnsAreEqual(ColumnView expected, ColumnView cv) { - assertPartialStructColumnsAreEqual(expected, 0, expected.getRowCount(), cv, "unnamed", true, false); - } - - /** - * Checks and asserts that passed in Struct columns match - * @param expected The expected result Struct column - * @param rowOffset The row number to look from - * @param length The number of rows to consider - * @param cv The input Struct column - * @param colName The name of the column - * @param enableNullCountCheck Whether to check for nulls in the Struct column - * @param enableNullabilityCheck Whether the table have a validity mask - */ - public static void assertPartialStructColumnsAreEqual(ColumnView expected, long rowOffset, long length, - ColumnView cv, String colName, boolean enableNullCountCheck, boolean enableNullabilityCheck) { - try (HostColumnVector hostExpected = expected.copyToHost(); - HostColumnVector hostcv = cv.copyToHost()) { - assertPartialColumnsAreEqual(hostExpected, rowOffset, length, hostcv, colName, enableNullCountCheck, enableNullabilityCheck); - } - } - - /** - * Checks and asserts that passed in columns match - * @param expected The expected result column - * @param cv The input column - * @param colName The name of the column - * @param enableNullCheck Whether to check for nulls in the column - * @param enableNullabilityCheck Whether the table have a validity mask - */ - public static void assertPartialColumnsAreEqual(ColumnView expected, long rowOffset, long length, - ColumnView cv, String colName, boolean enableNullCheck, boolean enableNullabilityCheck) { - try (HostColumnVector hostExpected = expected.copyToHost(); - HostColumnVector hostcv = cv.copyToHost()) { - assertPartialColumnsAreEqual(hostExpected, rowOffset, length, hostcv, colName, enableNullCheck, enableNullabilityCheck); - } - } - - /** - * Checks and asserts that passed in host columns match - * @param expected The expected result host column - * @param rowOffset start row index - * @param length number of rows from starting offset - * @param cv The input host column - * @param colName The name of the host column - * @param enableNullCountCheck Whether to check for nulls in the host column - */ - public static void assertPartialColumnsAreEqual(HostColumnVectorCore expected, long rowOffset, long length, - HostColumnVectorCore cv, String colName, boolean enableNullCountCheck, boolean enableNullabilityCheck) { - assertEquals(expected.getType(), cv.getType(), "Type For Column " + colName); - assertEquals(length, cv.getRowCount(), "Row Count For Column " + colName); - assertEquals(expected.getNumChildren(), cv.getNumChildren(), "Child Count for Column " + colName); - if (enableNullCountCheck) { - assertEquals(expected.getNullCount(), cv.getNullCount(), "Null Count For Column " + colName); - } else { - // TODO add in a proper check when null counts are supported by serializing a partitioned column - } - if (enableNullabilityCheck) { - assertEquals(expected.hasValidityVector(), cv.hasValidityVector(), "Column nullability is different than expected"); - } - DType type = expected.getType(); - for (long expectedRow = rowOffset; expectedRow < (rowOffset + length); expectedRow++) { - long tableRow = expectedRow - rowOffset; - assertEquals(expected.isNull(expectedRow), cv.isNull(tableRow), - "NULL for Column " + colName + " Row " + tableRow); - if (!expected.isNull(expectedRow)) { - switch (type.typeId) { - case BOOL8: // fall through - case INT8: // fall through - case UINT8: - assertEquals(expected.getByte(expectedRow), cv.getByte(tableRow), - "Column " + colName + " Row " + tableRow); - break; - case INT16: // fall through - case UINT16: - assertEquals(expected.getShort(expectedRow), cv.getShort(tableRow), - "Column " + colName + " Row " + tableRow); - break; - case INT32: // fall through - case UINT32: // fall through - case TIMESTAMP_DAYS: - case DURATION_DAYS: - case DECIMAL32: - assertEquals(expected.getInt(expectedRow), cv.getInt(tableRow), - "Column " + colName + " Row " + tableRow); - break; - case INT64: // fall through - case UINT64: // fall through - case DURATION_MICROSECONDS: // fall through - case DURATION_MILLISECONDS: // fall through - case DURATION_NANOSECONDS: // fall through - case DURATION_SECONDS: // fall through - case TIMESTAMP_MICROSECONDS: // fall through - case TIMESTAMP_MILLISECONDS: // fall through - case TIMESTAMP_NANOSECONDS: // fall through - case TIMESTAMP_SECONDS: - case DECIMAL64: - assertEquals(expected.getLong(expectedRow), cv.getLong(tableRow), - "Column " + colName + " Row " + tableRow); - break; - case DECIMAL128: - assertEquals(expected.getBigDecimal(expectedRow), cv.getBigDecimal(tableRow), - "Column " + colName + " Row " + tableRow); - break; - case FLOAT32: - assertEqualsWithinPercentage(expected.getFloat(expectedRow), cv.getFloat(tableRow), 0.0001, - "Column " + colName + " Row " + tableRow); - break; - case FLOAT64: - assertEqualsWithinPercentage(expected.getDouble(expectedRow), cv.getDouble(tableRow), 0.0001, - "Column " + colName + " Row " + tableRow); - break; - case STRING: - assertArrayEquals(expected.getUTF8(expectedRow), cv.getUTF8(tableRow), - "Column " + colName + " Row " + tableRow); - break; - case LIST: - HostMemoryBuffer expectedOffsets = expected.getOffsets(); - HostMemoryBuffer cvOffsets = cv.getOffsets(); - int expectedChildRows = expectedOffsets.getInt((expectedRow + 1) * 4) - - expectedOffsets.getInt(expectedRow * 4); - int cvChildRows = cvOffsets.getInt((tableRow + 1) * 4) - - cvOffsets.getInt(tableRow * 4); - assertEquals(expectedChildRows, cvChildRows, "Child row count for Column " + - colName + " Row " + tableRow); - break; - case STRUCT: - // parent column only has validity which was checked above - break; - default: - throw new IllegalArgumentException(type + " is not supported yet"); - } - } - } - - if (type.isNestedType()) { - switch (type.typeId) { - case LIST: - int expectedChildRowOffset = 0; - int numChildRows = 0; - if (length > 0) { - HostMemoryBuffer expectedOffsets = expected.getOffsets(); - HostMemoryBuffer cvOffsets = cv.getOffsets(); - expectedChildRowOffset = expectedOffsets.getInt(rowOffset * 4); - numChildRows = expectedOffsets.getInt((rowOffset + length) * 4) - - expectedChildRowOffset; - } - assertPartialColumnsAreEqual(expected.getNestedChildren().get(0), expectedChildRowOffset, - numChildRows, cv.getNestedChildren().get(0), colName + " list child", - enableNullCountCheck, enableNullabilityCheck); - break; - case STRUCT: - List expectedChildren = expected.getNestedChildren(); - List cvChildren = cv.getNestedChildren(); - for (int i = 0; i < expectedChildren.size(); i++) { - HostColumnVectorCore expectedChild = expectedChildren.get(i); - HostColumnVectorCore cvChild = cvChildren.get(i); - String childName = colName + " child " + i; - assertEquals(length, cvChild.getRowCount(), "Row Count for Column " + colName); - assertPartialColumnsAreEqual(expectedChild, rowOffset, length, cvChild, - colName, enableNullCountCheck, enableNullabilityCheck); - } - break; - default: - throw new IllegalArgumentException(type + " is not supported yet"); - } - } - } - - /** - * Checks and asserts that the two tables from a given rowindex match based on a provided schema. - * @param expected the expected result table - * @param rowOffset the row number to start checking from - * @param length the number of rows to check - * @param table the input table to compare against expected - * @param enableNullCheck whether to check for nulls or not - * @param enableNullabilityCheck whether the table have a validity mask - */ - public static void assertPartialTablesAreEqual(Table expected, long rowOffset, long length, Table table, - boolean enableNullCheck, boolean enableNullabilityCheck) { - assertEquals(expected.getNumberOfColumns(), table.getNumberOfColumns()); - assertEquals(length, table.getRowCount(), "ROW COUNT"); - for (int col = 0; col < expected.getNumberOfColumns(); col++) { - ColumnVector expect = expected.getColumn(col); - ColumnVector cv = table.getColumn(col); - String name = String.valueOf(col); - if (rowOffset != 0 || length != expected.getRowCount()) { - name = name + " PART " + rowOffset + "-" + (rowOffset + length - 1); - } - assertPartialColumnsAreEqual(expect, rowOffset, length, cv, name, enableNullCheck, enableNullabilityCheck); - } - } - - /** - * Checks and asserts that the two tables match - * @param expected the expected result table - * @param table the input table to compare against expected - */ - public static void assertTablesAreEqual(Table expected, Table table) { - assertPartialTablesAreEqual(expected, 0, expected.getRowCount(), table, true, false); - } - void assertTablesHaveSameValues(HashMap[] expectedTable, Table table) { assertEquals(expectedTable.length, table.getNumberOfColumns()); int numCols = table.getNumberOfColumns(); @@ -358,16 +127,6 @@ void assertTablesHaveSameValues(HashMap[] expectedTable, Table } } - public static void assertTableTypes(DType[] expectedTypes, Table t) { - int len = t.getNumberOfColumns(); - assertEquals(expectedTypes.length, len); - for (int i = 0; i < len; i++) { - ColumnVector vec = t.getColumn(i); - DType type = vec.getType(); - assertEquals(expectedTypes[i], type, "Types don't match at " + i); - } - } - @Test void testMergeSimple() { try (Table table1 = new Table.TestBuilder() diff --git a/java/src/test/java/ai/rapids/cudf/TimestampColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/TimestampColumnVectorTest.java index 8bf1370a0f7..9a929cec98d 100644 --- a/java/src/test/java/ai/rapids/cudf/TimestampColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/TimestampColumnVectorTest.java @@ -22,7 +22,7 @@ import java.util.function.Function; -import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; import static org.junit.jupiter.api.Assertions.assertEquals; public class TimestampColumnVectorTest extends CudfTestBase { diff --git a/java/src/test/java/ai/rapids/cudf/UnaryOpTest.java b/java/src/test/java/ai/rapids/cudf/UnaryOpTest.java index 76970e8bf76..7fcb7cbd85b 100644 --- a/java/src/test/java/ai/rapids/cudf/UnaryOpTest.java +++ b/java/src/test/java/ai/rapids/cudf/UnaryOpTest.java @@ -22,7 +22,7 @@ import ai.rapids.cudf.HostColumnVector.Builder; import org.junit.jupiter.api.Test; -import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; public class UnaryOpTest extends CudfTestBase { private static final Double[] DOUBLES_1 = new Double[]{1.0, 10.0, -100.1, 5.3, 50.0, 100.0, null, Double.NaN, Double.POSITIVE_INFINITY, 1/9.0, Double.NEGATIVE_INFINITY, 500.0, -500.0}; diff --git a/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java b/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java index 2fb6792b409..e50da0a4d4d 100644 --- a/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java +++ b/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java @@ -36,7 +36,7 @@ import java.util.function.Function; import java.util.stream.Stream; -import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; public class CompiledExpressionTest extends CudfTestBase { @Test