diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java
index adf0f317340..02675c9c9ca 100644
--- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java
+++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java
@@ -1275,6 +1275,19 @@ public static ColumnVector fromStrings(String... values) {
     }
   }
 
+  /**
+   * Create a new string vector from the given values. This API
+   * supports inline nulls.
+   *
+   * @param values bytes of each row, expected to be UTF-8 encoded; a null element becomes a null row
+   * @return the device copy of a temporary host column built from {@code values}
+   */
+  public static ColumnVector fromUTF8Strings(byte[]... values) {
+    try (HostColumnVector host = HostColumnVector.fromUTF8Strings(values)) {
+      return host.copyToDevice();
+    }
+  }
+
   /**
    * Create a new vector from the given values. This API supports inline nulls,
    * but is much slower than building from primitive array of unscaledValues.
diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java
index 846bcb3b635..46255428c1c 100644
--- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java
+++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java
@@ -29,6 +29,7 @@
 import java.util.Objects;
 import java.util.Optional;
 import java.util.StringJoiner;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 
 /**
@@ -577,6 +578,44 @@ public static HostColumnVector fromStrings(String... values) {
     });
   }
 
+  /**
+   * Create a new string vector from the given values. This API
+   * supports inline nulls.
+   *
+   * @param values bytes of each row, expected to be UTF-8 encoded; a null element becomes a null row
+   * @return a host column with one row per element of {@code values}
+   */
+  public static HostColumnVector fromUTF8Strings(byte[]... values) {
+    int rows = values.length;
+    long nullCount = 0;
+    long bufferSize = 0;
+    // How many bytes do we need to hold the data.
+    for (byte[] s: values) {
+      if (s == null) {
+        nullCount++;
+      } else {
+        bufferSize += s.length;
+      }
+    }
+
+    // With no nulls present, use an appender that skips the per-row null check.
+    BiConsumer<Builder, byte[]> appendUTF8 = nullCount == 0 ?
+      (b, s) -> b.appendUTF8String(s) :
+      (b, s) -> {
+        if (s == null) {
+          b.appendNull();
+        } else {
+          b.appendUTF8String(s);
+        }
+      };
+
+    return build(rows, bufferSize, (b) -> {
+      for (byte[] s: values) {
+        appendUTF8.accept(b, s);
+      }
+    });
+  }
+
   /**
    * Create a new vector from the given values. This API supports inline nulls,
    * but is much slower than building from primitive array of unscaledValues.
@@ -1085,9 +1124,11 @@ private void appendChildOrNull(ColumnBuilder childBuilder, Object listElement) {
       } else if (listElement instanceof BigDecimal) {
         childBuilder.append((BigDecimal) listElement);
       } else if (listElement instanceof List) {
-        childBuilder.append((List) listElement);
+        childBuilder.append((List<?>) listElement);
       } else if (listElement instanceof StructData) {
         childBuilder.append((StructData) listElement);
+      } else if (listElement instanceof byte[]) {
+        childBuilder.appendUTF8String((byte[]) listElement);
       } else {
         throw new IllegalStateException("Unexpected element type: " + listElement.getClass());
       }
diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
index a7733897d10..f07d6a43883 100644
--- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
+++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
@@ -29,6 +29,7 @@
 
 import java.math.BigDecimal;
 import java.math.RoundingMode;
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -176,6 +177,19 @@ void testStringCreation() {
     }
   }
 
+  @Test
+  void testUTF8StringCreation() {
+    try (ColumnVector cv = ColumnVector.fromUTF8Strings(
+            "d".getBytes(StandardCharsets.UTF_8),
+            "sd".getBytes(StandardCharsets.UTF_8),
+            "sde".getBytes(StandardCharsets.UTF_8),
+            null,
+            "END".getBytes(StandardCharsets.UTF_8));
+         ColumnVector expected = ColumnVector.fromStrings("d", "sd", "sde", null, "END")) {
+      TableTest.assertColumnsAreEqual(expected, cv);
+    }
+  }
+
   @Test
   void testRefCountLeak() throws InterruptedException {
     assumeTrue(Boolean.getBoolean("ai.rapids.cudf.flaky-tests-enabled"));