diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
index 73261f7cb..4860cd764 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
@@ -176,7 +176,7 @@ public KudoSerializer(Schema schema) {
* @param numRows number of rows to write
* @return number of bytes written
*/
- WriteMetrics writeToStream(Table table, OutputStream out, int rowOffset, int numRows) {
+ WriteMetrics writeToStreamWithMetrics(Table table, OutputStream out, int rowOffset, int numRows) {
HostColumnVector[] columns = null;
try {
columns = IntStream.range(0, table.getNumberOfColumns())
@@ -185,7 +185,7 @@ WriteMetrics writeToStream(Table table, OutputStream out, int rowOffset, int num
.toArray(HostColumnVector[]::new);
Cuda.DEFAULT_STREAM.sync();
- return writeToStream(columns, out, rowOffset, numRows);
+ return writeToStreamWithMetrics(columns, out, rowOffset, numRows);
} finally {
if (columns != null) {
for (HostColumnVector column : columns) {
@@ -195,6 +195,16 @@ WriteMetrics writeToStream(Table table, OutputStream out, int rowOffset, int num
}
}
+ /**
+ * Write partition of an array of {@link HostColumnVector} to an output stream.
+ * See {@link #writeToStreamWithMetrics(HostColumnVector[], OutputStream, int, int)} for more
+ * details.
+ * @return number of bytes written
+ */
+ public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) {
+ return writeToStreamWithMetrics(columns, out, rowOffset, numRows).getWrittenBytes();
+ }
+
/**
* Write partition of an array of {@link HostColumnVector} to an output stream.
*
@@ -208,7 +218,7 @@ WriteMetrics writeToStream(Table table, OutputStream out, int rowOffset, int num
* @param numRows number of rows to write
* @return number of bytes written
*/
- public WriteMetrics writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) {
+ public WriteMetrics writeToStreamWithMetrics(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) {
ensure(numRows > 0, () -> "numRows must be > 0, but was " + numRows);
ensure(columns.length > 0, () -> "columns must not be empty, for row count only records " +
"please call writeRowCountToStream");
@@ -291,13 +301,9 @@ private WriteMetrics writeSliced(HostColumnVector[] columns, DataWriter out, int
KudoTableHeaderCalc headerCalc = new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount);
withTime(() -> Visitors.visitColumns(columns, headerCalc), metrics::addCalcHeaderTime);
KudoTableHeader header = headerCalc.getHeader();
- withTime(() -> {
- try {
- header.writeTo(out);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }, metrics::addCopyHeaderTime);
+ long currentTime = System.nanoTime();
+ header.writeTo(out);
+ metrics.addCopyHeaderTime(System.nanoTime() - currentTime);
metrics.addWrittenBytes(header.getSerializedSize());
long bytesWritten = 0;
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
index 48a6eaf73..86d51116b 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
@@ -208,14 +208,12 @@ private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) th
}
}
- private long copyBufferAndPadForHost(HostMemoryBuffer buffer, long offset, long length) {
- return withTime(() -> {
- try {
- writer.copyDataFrom(buffer, offset, length);
- return padForHostAlignment(writer, length);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }, metrics::addCopyBufferTime);
+ private long copyBufferAndPadForHost(HostMemoryBuffer buffer, long offset, long length)
+ throws IOException {
+ long now = System.nanoTime();
+ writer.copyDataFrom(buffer, offset, length);
+ long ret = padForHostAlignment(writer, length);
+ metrics.addCopyBufferTime(System.nanoTime() - now);
+ return ret;
}
}
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java
index b3b341814..3ffcb5e61 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java
@@ -75,7 +75,7 @@ public void testWriteSimple() throws Exception {
try (Table t = buildSimpleTable()) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
- long bytesWritten = serializer.writeToStream(t, out, 0, 4).getWrittenBytes();
+ long bytesWritten = serializer.writeToStreamWithMetrics(t, out, 0, 4).getWrittenBytes();
assertEquals(189, bytesWritten);
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
@@ -365,7 +365,7 @@ private static void checkMergeTable(Table expected, List tableSlices
ByteArrayOutputStream bout = new ByteArrayOutputStream();
for (TableSlice slice : tableSlices) {
- serializer.writeToStream(slice.getBaseTable(), bout, slice.getStartRow(), slice.getNumRows());
+ serializer.writeToStreamWithMetrics(slice.getBaseTable(), bout, slice.getStartRow(), slice.getNumRows());
}
bout.flush();