From 1f5bfcc436a603c8cbe6a2a55253d93d895be68f Mon Sep 17 00:00:00 2001 From: Firestarman Date: Fri, 14 May 2021 17:20:30 +0800 Subject: [PATCH 1/3] Create column from UTF8 String Signed-off-by: Firestarman --- .../java/ai/rapids/cudf/ColumnVector.java | 10 +++++ .../java/ai/rapids/cudf/HostColumnVector.java | 39 ++++++++++++++++++- .../java/ai/rapids/cudf/ColumnVectorTest.java | 14 +++++++ 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index adf0f317340..0b572d5dd3b 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -1275,6 +1275,16 @@ public static ColumnVector fromStrings(String... values) { } } + /** + * Create a new string vector from the given values. This API + * supports inline nulls. + */ + public static ColumnVector fromUTF8StringsBytes(byte[]... values) { + try (HostColumnVector host = HostColumnVector.fromUTF8StringsBytes(values)) { + return host.copyToDevice(); + } + } + /** * Create a new vector from the given values. This API supports inline nulls, * but is much slower than building from primitive array of unscaledValues. diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index 846bcb3b635..44370b658c6 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -29,6 +29,7 @@ import java.util.Objects; import java.util.Optional; import java.util.StringJoiner; +import java.util.function.BiConsumer; import java.util.function.Consumer; /** @@ -577,6 +578,40 @@ public static HostColumnVector fromStrings(String... values) { }); } + /** + * Create a new string vector from the given values. This API + * supports inline nulls. + */ + public static HostColumnVector fromUTF8StringsBytes(byte[]... values) { + int rows = values.length; + long nullCount = 0; + long bufferSize = 0; + // How many bytes do we need to hold the data. + for (byte[] s: values) { + if (s == null) { + nullCount++; + } else { + bufferSize += s.length; + } + } + + BiConsumer appendUTF8 = nullCount == 0 ? + (b, s) -> b.appendUTF8String(s) : + (b, s) -> { + if (s == null) { + b.appendNull(); + } else { + b.appendUTF8String(s); + } + }; + + return build(rows, bufferSize, (b) -> { + for (byte[] s: values) { + appendUTF8.accept(b, s); + } + }); + } + /** * Create a new vector from the given values. This API supports inline nulls, * but is much slower than building from primitive array of unscaledValues. @@ -1085,9 +1120,11 @@ private void appendChildOrNull(ColumnBuilder childBuilder, Object listElement) { } else if (listElement instanceof BigDecimal) { childBuilder.append((BigDecimal) listElement); } else if (listElement instanceof List) { - childBuilder.append((List) listElement); + childBuilder.append((List) listElement); } else if (listElement instanceof StructData) { childBuilder.append((StructData) listElement); + } else if (listElement instanceof byte[]) { + childBuilder.appendUTF8String((byte[]) listElement); } else { throw new IllegalStateException("Unexpected element type: " + listElement.getClass()); } diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index ba10590fe34..bdd26927832 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -29,6 +29,7 @@ import java.math.BigDecimal; import java.math.RoundingMode; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -176,6 +177,19 @@ void testStringCreation() { } } + @Test + void testUTF8StringCreation() { + try (ColumnVector cv = ColumnVector.fromUTF8StringsBytes( + "d".getBytes(StandardCharsets.UTF_8), + "sd".getBytes(StandardCharsets.UTF_8), + "sde".getBytes(StandardCharsets.UTF_8), + null, + "END".getBytes(StandardCharsets.UTF_8)); + ColumnVector expected = ColumnVector.fromStrings("d", "sd", "sde", null, "END")) { + TableTest.assertColumnsAreEqual(expected, cv); + } + } + @Test void testRefCountLeak() throws InterruptedException { assumeTrue(Boolean.getBoolean("ai.rapids.cudf.flaky-tests-enabled")); From 01a1bfe7fabe50f14d6186d3b5ff185c93d14dcb Mon Sep 17 00:00:00 2001 From: Firestarman Date: Wed, 19 May 2021 12:21:33 +0800 Subject: [PATCH 2/3] Address comments Signed-off-by: Firestarman --- java/src/main/java/ai/rapids/cudf/ColumnVector.java | 4 ++-- java/src/main/java/ai/rapids/cudf/HostColumnVector.java | 4 ++-- java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 0b572d5dd3b..02675c9c9ca 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -1279,8 +1279,8 @@ public static ColumnVector fromStrings(String... values) { * Create a new string vector from the given values. This API * supports inline nulls. */ - public static ColumnVector fromUTF8StringsBytes(byte[]... values) { - try (HostColumnVector host = HostColumnVector.fromUTF8StringsBytes(values)) { + public static ColumnVector fromUTF8Strings(byte[]... values) { + try (HostColumnVector host = HostColumnVector.fromUTF8Strings(values)) { return host.copyToDevice(); } } diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index 44370b658c6..ae9f009af66 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -582,7 +582,7 @@ public static HostColumnVector fromStrings(String... values) { * Create a new string vector from the given values. This API * supports inline nulls. */ - public static HostColumnVector fromUTF8StringsBytes(byte[]... values) { + public static HostColumnVector fromUTF8Strings(byte[]... values) { int rows = values.length; long nullCount = 0; long bufferSize = 0; @@ -1123,7 +1123,7 @@ private void appendChildOrNull(ColumnBuilder childBuilder, Object listElement) { childBuilder.append((List) listElement); } else if (listElement instanceof StructData) { childBuilder.append((StructData) listElement); - } else if (listElement instanceof byte[]) { + } else if (listElement instanceof byte[] && DType.STRING.equals(childBuilder.type)) { childBuilder.appendUTF8String((byte[]) listElement); } else { throw new IllegalStateException("Unexpected element type: " + listElement.getClass()); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 63d5e2dc703..f07d6a43883 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -179,7 +179,7 @@ void testStringCreation() { @Test void testUTF8StringCreation() { - try (ColumnVector cv = ColumnVector.fromUTF8StringsBytes( + try (ColumnVector cv = ColumnVector.fromUTF8Strings( "d".getBytes(StandardCharsets.UTF_8), "sd".getBytes(StandardCharsets.UTF_8), "sde".getBytes(StandardCharsets.UTF_8), From 54e500cd2f5deba2c3b7580cf38a76960bb5b939 Mon Sep 17 00:00:00 2001 From: Firestarman Date: Thu, 20 May 2021 09:35:55 +0800 Subject: [PATCH 3/3] Address comments Signed-off-by: Firestarman --- java/src/main/java/ai/rapids/cudf/HostColumnVector.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index ae9f009af66..46255428c1c 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -1123,7 +1123,7 @@ private void appendChildOrNull(ColumnBuilder childBuilder, Object listElement) { childBuilder.append((List) listElement); } else if (listElement instanceof StructData) { childBuilder.append((StructData) listElement); - } else if (listElement instanceof byte[] && DType.STRING.equals(childBuilder.type)) { + } else if (listElement instanceof byte[]) { childBuilder.appendUTF8String((byte[]) listElement); } else { throw new IllegalStateException("Unexpected element type: " + listElement.getClass());