From a33eb897c3930cad64e909d012a1129a746479ef Mon Sep 17 00:00:00 2001 From: liurenjie1024 Date: Thu, 14 Nov 2024 13:42:18 +0800 Subject: [PATCH] Fix comments --- .../rapids/jni/kudo/KudoHostMergeResult.java | 13 +++++--- .../spark/rapids/jni/kudo/KudoSerializer.java | 5 ++- .../rapids/jni/kudo/KudoTableMerger.java | 2 +- .../jni/kudo/MultiKudoTableVisitor.java | 31 ++++++++++++++++--- 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java index 1b5710544a..efcc216b53 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java @@ -31,11 +31,11 @@ public class KudoHostMergeResult implements AutoCloseable { private final Schema schema; private final List columnInfoList; - private final HostMemoryBuffer hostBuf; + private HostMemoryBuffer hostBuf; KudoHostMergeResult(Schema schema, HostMemoryBuffer hostBuf, List columnInfoList) { requireNonNull(schema, "schema is null"); - requireNonNull(columnInfoList, "columnOffsets is null"); + requireNonNull(columnInfoList, "columnInfoList is null"); ensure(schema.getFlattenedColumnNames().length == columnInfoList.size(), () -> "Column offsets size does not match flattened schema size, column offsets size: " + columnInfoList.size() + ", flattened schema size: " + schema.getFlattenedColumnNames().length); @@ -46,11 +46,14 @@ public class KudoHostMergeResult implements AutoCloseable { @Override public void close() throws Exception { - if (hostBuf != null) { - hostBuf.close(); - } + hostBuf.close(); + hostBuf = null; } + /** + * Convert the host buffer into a cudf table. + * @return the cudf table + */ public Table toTable() { try (DeviceMemoryBuffer deviceMemBuf = DeviceMemoryBuffer.allocate(hostBuf.getLength())) { if (hostBuf.getLength() > 0) { diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java index f4b0eb4f62..570719a820 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java @@ -277,8 +277,9 @@ public Pair mergeOnHost(List kudoT * @param kudoTables list of kudo tables. This method doesn't take ownership of the input tables, and caller should * take care of closing them after calling this method. * @return the merged table, and metrics during merge. + * @throws Exception if any error occurs during merge. */ - public Pair mergeToTable(List kudoTables) { + public Pair mergeToTable(List kudoTables) throws Exception { Pair result = mergeOnHost(kudoTables); MergeMetrics.Builder builder = MergeMetrics.builder(result.getRight()); try (KudoHostMergeResult children = result.getLeft()) { @@ -286,8 +287,6 @@ public Pair mergeToTable(List kudoTables) { builder::convertToTableTime); return Pair.of(table, builder.build()); - } catch (Exception e) { - throw new RuntimeException(e); } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java index 07681df8ee..af80391f3d 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java @@ -139,7 +139,7 @@ private int deserializeValidityBuffer(ColumnOffsetInfo curColOffset) { startRow += sliceInfo.getRowCount(); } - return toIntExact(nullCountTotal); + return nullCountTotal; } } else { return 0; diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java index 98f9ed5c75..afa7ba6ea0 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java @@ -56,6 +56,7 @@ protected MultiKudoTableVisitor(List inputTables) { this.currentOffsetOffsets = new long[tables.size()]; this.currentDataOffset = new long[tables.size()]; this.sliceInfoStack = new Deque[tables.size()]; + long totalRowCount = 0L; for (int i = 0; i < tables.size(); i++) { this.currentValidityOffsets[i] = 0; KudoTableHeader header = tables.get(i).getHeader(); @@ -63,8 +64,8 @@ protected MultiKudoTableVisitor(List inputTables) { this.currentDataOffset[i] = header.getValidityBufferLen() + header.getOffsetBufferLen(); this.sliceInfoStack[i] = new ArrayDeque<>(16); this.sliceInfoStack[i].add(new SliceInfo(header.getOffset(), header.getNumRows())); + totalRowCount += header.getNumRows(); } - long totalRowCount = tables.stream().mapToLong(t -> t.getHeader().getNumRows()).sum(); this.totalRowCountStack = new ArrayDeque<>(16); totalRowCountStack.addLast(toIntExact(totalRowCount)); this.hasNull = true; @@ -88,7 +89,12 @@ public R visitTopSchema(Schema schema, List children) { public T visitStruct(Schema structType, List children) { updateHasNull(); T t = doVisitStruct(structType, children); - updateOffsets(false, false, false, -1); + updateOffsets( + false, // Update offset buffer offset + false, // Update data buffer offset + false, // Update slice info + -1 // element size in bytes, not used for struct + ); currentIdx += 1; return t; } @@ -99,7 +105,12 @@ public T visitStruct(Schema structType, List children) { public P preVisitList(Schema listType) { updateHasNull(); P t = doPreVisitList(listType); - updateOffsets(true, false, true, Integer.BYTES); + updateOffsets( + true, // update offset buffer offset + false, // update data buffer offset + true, // update slice info + Integer.BYTES // element size in bytes + ); currentIdx += 1; return t; } @@ -128,9 +139,19 @@ public T visit(Schema primitiveType) { T t = doVisit(primitiveType); if (primitiveType.getType().hasOffsets()) { - updateOffsets(true, true, false, -1); + updateOffsets( + true, // update offset buffer offset + true, // update data buffer offset + false, // update slice info + -1 // element size in bytes, not used for string + ); } else { - updateOffsets(false, true, false, primitiveType.getType().getSizeInBytes()); + updateOffsets( + false, //update offset buffer offset + true, // update data buffer offset + false, // update slice info + primitiveType.getType().getSizeInBytes() // element size in bytes + ); } currentIdx += 1; return t;