From b17904dbaa4de1a162fcb4a0f64862f9f83b976f Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Tue, 10 Oct 2023 19:51:02 -0500 Subject: [PATCH] Add in java bindings for DataSource (#14254) This PR adds DataSource Java bindings. It also fixes a small bug in CUDF that made it so the bindings would not work for anything but CSV. Authors: - Robert (Bobby) Evans (https://github.com/revans2) Approvers: - Jason Lowe (https://github.com/jlowe) - Vukasin Milovanovic (https://github.com/vuule) - David Wendt (https://github.com/davidwendt) URL: https://github.com/rapidsai/cudf/pull/14254 --- cpp/src/io/utilities/datasource.cpp | 8 + java/src/main/java/ai/rapids/cudf/Cuda.java | 24 +- .../main/java/ai/rapids/cudf/DataSource.java | 189 ++++++++++++++ .../java/ai/rapids/cudf/DataSourceHelper.java | 44 ++++ .../ai/rapids/cudf/DeviceMemoryBuffer.java | 6 +- .../ai/rapids/cudf/MultiBufferDataSource.java | 230 +++++++++++++++++ .../ai/rapids/cudf/ParquetChunkedReader.java | 45 +++- java/src/main/java/ai/rapids/cudf/Table.java | 99 +++++++- java/src/main/native/CMakeLists.txt | 1 + java/src/main/native/src/ChunkedReaderJni.cpp | 36 ++- java/src/main/native/src/CudfJni.cpp | 8 + .../main/native/src/DataSourceHelperJni.cpp | 237 ++++++++++++++++++ java/src/main/native/src/TableJni.cpp | 212 +++++++++++++++- java/src/main/native/src/cudf_jni_apis.hpp | 8 + .../test/java/ai/rapids/cudf/TableTest.java | 225 +++++++++++++++++ 15 files changed, 1358 insertions(+), 14 deletions(-) create mode 100644 java/src/main/java/ai/rapids/cudf/DataSource.java create mode 100644 java/src/main/java/ai/rapids/cudf/DataSourceHelper.java create mode 100644 java/src/main/java/ai/rapids/cudf/MultiBufferDataSource.java create mode 100644 java/src/main/native/src/DataSourceHelperJni.cpp diff --git a/cpp/src/io/utilities/datasource.cpp b/cpp/src/io/utilities/datasource.cpp index 7a7121aa91d..5cdd92ce3b7 100644 --- a/cpp/src/io/utilities/datasource.cpp +++ b/cpp/src/io/utilities/datasource.cpp @@ 
-375,6 +375,14 @@ class user_datasource_wrapper : public datasource { return source->device_read(offset, size, stream); } + std::future device_read_async(size_t offset, + size_t size, + uint8_t* dst, + rmm::cuda_stream_view stream) override + { + return source->device_read_async(offset, size, dst, stream); + } + [[nodiscard]] size_t size() const override { return source->size(); } private: diff --git a/java/src/main/java/ai/rapids/cudf/Cuda.java b/java/src/main/java/ai/rapids/cudf/Cuda.java index e1298e29925..7cc3d30a9cf 100755 --- a/java/src/main/java/ai/rapids/cudf/Cuda.java +++ b/java/src/main/java/ai/rapids/cudf/Cuda.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,9 +15,6 @@ */ package ai.rapids.cudf; -import ai.rapids.cudf.NvtxColor; -import ai.rapids.cudf.NvtxRange; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -90,6 +87,21 @@ private Stream() { this.id = -1; } + private Stream(long id) { + this.cleaner = null; + this.id = id; + } + + /** + * Wrap a given stream ID to make it accessible. + */ + static Stream wrap(long id) { + if (id == -1) { + return DEFAULT_STREAM; + } + return new Stream(id); + } + /** * Have this stream not execute new work until the work recorded in event completes. * @param event the event to wait on. 
@@ -122,7 +134,9 @@ public synchronized void close() { cleaner.delRef(); } if (closed) { - cleaner.logRefCountDebug("double free " + this); + if (cleaner != null) { + cleaner.logRefCountDebug("double free " + this); + } throw new IllegalStateException("Close called too many times " + this); } if (cleaner != null) { diff --git a/java/src/main/java/ai/rapids/cudf/DataSource.java b/java/src/main/java/ai/rapids/cudf/DataSource.java new file mode 100644 index 00000000000..1e5893235df --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/DataSource.java @@ -0,0 +1,189 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.HashMap; + +/** + * Base class that can be used to provide data dynamically to CUDF. This follows somewhat + * closely with cudf::io::datasource. There are a few main differences. + *
+ * First, this does not expose async device reads. It will call the non-async device read API + * instead. This might be added in the future, but there was no direct use case for it in Java + * right now to warrant the added complexity. + *
+ * Second there is no implementation of the device read API that returns a buffer instead of + * writing into one. This is not used by CUDF yet so testing an implementation that isn't used + * didn't feel ideal. If it is needed we will add one in the future. + */ +public abstract class DataSource implements AutoCloseable { + private static final Logger log = LoggerFactory.getLogger(DataSource.class); + + /** + * This is used to keep track of the HostMemoryBuffers in java land so the C++ layer + * does not have to do it. + */ + private final HashMap cachedBuffers = new HashMap<>(); + + @Override + public void close() { + if (!cachedBuffers.isEmpty()) { + throw new IllegalStateException("DataSource closed before all returned host buffers were closed"); + } + } + + /** + * Get the size of the source in bytes. + */ + public abstract long size(); + + /** + * Read data from the source at the given offset. Return a HostMemoryBuffer for the data + * that was read. + * @param offset where to start reading from. + * @param amount the maximum number of bytes to read. + * @return a buffer that points to the data. + * @throws IOException on any error. + */ + public abstract HostMemoryBuffer hostRead(long offset, long amount) throws IOException; + + + /** + * Called when the buffer returned from hostRead is done. The default is to close the buffer. + */ + protected void onHostBufferDone(HostMemoryBuffer buffer) { + if (buffer != null) { + buffer.close(); + } + } + + /** + * Read data from the source at the given offset into dest. Note that dest should not be closed, + * and no reference to it can outlive the call to hostRead. The target amount to read is + * dest's length. + * @param offset the offset to start reading from in the source. + * @param dest where to write the data. + * @return the actual number of bytes written to dest. 
+ */ + public abstract long hostRead(long offset, HostMemoryBuffer dest) throws IOException; + + /** + * Return true if this supports reading directly to the device else false. The default is + * no device support. This cannot change dynamically. It is typically read just once. + */ + public boolean supportsDeviceRead() { + return false; + } + + /** + * Get the size cutoff between device reads and host reads when device reads are supported. + * Anything larger than the cutoff will be a device read and anything smaller will be a + * host read. By default, the cutoff is 0 so all reads will be device reads if device reads + * are supported. + */ + public long getDeviceReadCutoff() { + return 0; + } + + /** + * Read data from the source at the given offset into dest. Note that dest should not be closed, + * and no reference to it can outlive the call to hostRead. The target amount to read is + * dest's length. + * @param offset the offset to start reading from + * @param dest where to write the data. + * @param stream the stream to do the copy on. + * @return the actual number of bytes written to dest. + */ + public long deviceRead(long offset, DeviceMemoryBuffer dest, + Cuda.Stream stream) throws IOException { + throw new IllegalStateException("Device read is not implemented"); + } + + ///////////////////////////////////////////////// + // Internal methods called from JNI + ///////////////////////////////////////////////// + + private static class NoopCleaner extends MemoryBuffer.MemoryBufferCleaner { + @Override + protected boolean cleanImpl(boolean logErrorIfNotClean) { + return true; + } + + @Override + public boolean isClean() { + return true; + } + } + private static final NoopCleaner cleaner = new NoopCleaner(); + + // Called from JNI + private void onHostBufferDone(long bufferId) { + HostMemoryBuffer hmb = cachedBuffers.remove(bufferId); + if (hmb != null) { + onHostBufferDone(hmb); + } else { + // Called from C++ destructor so avoid throwing... 
+ log.warn("Got a close callback for a buffer we could not find " + bufferId); + } + } + + // Called from JNI + private long hostRead(long offset, long amount, long dst) throws IOException { + if (amount < 0) { + throw new IllegalArgumentException("Cannot allocate more than " + Long.MAX_VALUE + " bytes"); + } + try (HostMemoryBuffer dstBuffer = new HostMemoryBuffer(dst, amount, cleaner)) { + return hostRead(offset, dstBuffer); + } + } + + // Called from JNI + private long[] hostReadBuff(long offset, long amount) throws IOException { + if (amount < 0) { + throw new IllegalArgumentException("Cannot read more than " + Long.MAX_VALUE + " bytes"); + } + HostMemoryBuffer buff = hostRead(offset, amount); + long[] ret = new long[3]; + if (buff != null) { + long id = buff.id; + if (cachedBuffers.put(id, buff) != null) { + throw new IllegalStateException("Already had a buffer cached for " + buff); + } + ret[0] = buff.address; + ret[1] = buff.length; + ret[2] = id; + } // else they are all 0 because java does that already + return ret; + } + + // Called from JNI + private long deviceRead(long offset, long amount, long dst, long stream) throws IOException { + if (amount < 0) { + throw new IllegalArgumentException("Cannot read more than " + Long.MAX_VALUE + " bytes"); + } + Cuda.Stream strm = Cuda.Stream.wrap(stream); + try (DeviceMemoryBuffer dstBuffer = new DeviceMemoryBuffer(dst, amount, cleaner)) { + return deviceRead(offset, dstBuffer, strm); + } + } +} diff --git a/java/src/main/java/ai/rapids/cudf/DataSourceHelper.java b/java/src/main/java/ai/rapids/cudf/DataSourceHelper.java new file mode 100644 index 00000000000..5d4dcb8e4e7 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/DataSourceHelper.java @@ -0,0 +1,44 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * This is here because we need some JNI methods to work with a DataSource, but + * we also want to cache callback methods at startup for performance reasons. If + * we put both in the same class we will get a deadlock because of how we load + * the JNI. We have a static block that blocks loading the class until the JNI + * library is loaded and the JNI library cannot load until the class is loaded + * and cached. This breaks the loop. + */ +class DataSourceHelper { + static { + NativeDepsLoader.loadNativeDeps(); + } + + static long createWrapperDataSource(DataSource ds) { + return createWrapperDataSource(ds, ds.size(), ds.supportsDeviceRead(), + ds.getDeviceReadCutoff()); + } + + private static native long createWrapperDataSource(DataSource ds, long size, + boolean deviceReadSupport, + long deviceReadCutoff); + + static native void destroyWrapperDataSource(long handle); +} diff --git a/java/src/main/java/ai/rapids/cudf/DeviceMemoryBuffer.java b/java/src/main/java/ai/rapids/cudf/DeviceMemoryBuffer.java index c4d9bdb8f91..9eab607ed0b 100644 --- a/java/src/main/java/ai/rapids/cudf/DeviceMemoryBuffer.java +++ b/java/src/main/java/ai/rapids/cudf/DeviceMemoryBuffer.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -112,6 +112,10 @@ public static DeviceMemoryBuffer fromRmm(long address, long lengthInBytes, long return new DeviceMemoryBuffer(address, lengthInBytes, rmmBufferAddress); } + DeviceMemoryBuffer(long address, long lengthInBytes, MemoryBufferCleaner cleaner) { + super(address, lengthInBytes, cleaner); + } + DeviceMemoryBuffer(long address, long lengthInBytes, long rmmBufferAddress) { super(address, lengthInBytes, new RmmDeviceBufferCleaner(rmmBufferAddress)); } diff --git a/java/src/main/java/ai/rapids/cudf/MultiBufferDataSource.java b/java/src/main/java/ai/rapids/cudf/MultiBufferDataSource.java new file mode 100644 index 00000000000..6986b6a7fec --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/MultiBufferDataSource.java @@ -0,0 +1,230 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * This is a DataSource that can take multiple HostMemoryBuffers. They + * are treated as if they are all part of a single file connected end to end. 
+ */ +public class MultiBufferDataSource extends DataSource { + private final long sizeInBytes; + private final HostMemoryBuffer[] hostBuffers; + private final long[] startOffsets; + private final HostMemoryAllocator allocator; + + // Metrics + private long hostReads = 0; + private long hostReadBytes = 0; + private long devReads = 0; + private long devReadBytes = 0; + + /** + * Create a new data source backed by multiple buffers. + * @param buffers the buffers that will back the data source. + */ + public MultiBufferDataSource(HostMemoryBuffer ... buffers) { + this(DefaultHostMemoryAllocator.get(), buffers); + } + + /** + * Create a new data source backed by multiple buffers. + * @param allocator the allocator to use for host buffers, if needed. + * @param buffers the buffers that will back the data source. + */ + public MultiBufferDataSource(HostMemoryAllocator allocator, HostMemoryBuffer ... buffers) { + int numBuffers = buffers.length; + hostBuffers = new HostMemoryBuffer[numBuffers]; + startOffsets = new long[numBuffers]; + + long currentOffset = 0; + for (int i = 0; i < numBuffers; i++) { + HostMemoryBuffer hmb = buffers[i]; + hmb.incRefCount(); + hostBuffers[i] = hmb; + startOffsets[i] = currentOffset; + currentOffset += hmb.getLength(); + } + sizeInBytes = currentOffset; + this.allocator = allocator; + } + + @Override + public long size() { + return sizeInBytes; + } + + private int getStartBufferIndexForOffset(long offset) { + assert (offset >= 0); + + // It is super common to read from the start or end of a file (the header or footer) + // so special case them + if (offset == 0) { + return 0; + } + int startIndex = 0; + int endIndex = startOffsets.length - 1; + if (offset >= startOffsets[endIndex]) { + return endIndex; + } + while (startIndex != endIndex) { + int midIndex = (int)(((long)startIndex + endIndex) / 2); + long midStartOffset = startOffsets[midIndex]; + if (offset >= midStartOffset) { + // It is either in mid or after mid. 
+ if (midIndex == endIndex || offset <= startOffsets[midIndex + 1]) { + // We found it in mid + return midIndex; + } else { + // It is after mid + startIndex = midIndex + 1; + } + } else { + // It is before mid + endIndex = midIndex - 1; + } + } + return startIndex; + } + + + interface DoCopy { + void copyFromHostBuffer(T dest, long destOffset, HostMemoryBuffer src, + long srcOffset, long srcAmount); + } + + private long read(long offset, T dest, DoCopy doCopy) { + assert (offset >= 0); + long realOffset = Math.min(offset, sizeInBytes); + long realAmount = Math.min(sizeInBytes - realOffset, dest.getLength()); + + int index = getStartBufferIndexForOffset(realOffset); + + HostMemoryBuffer buffer = hostBuffers[index]; + long bufferOffset = realOffset - startOffsets[index]; + long bufferAmount = Math.min(buffer.length - bufferOffset, realAmount); + long remainingAmount = realAmount; + long currentOffset = realOffset; + long outputOffset = 0; + + while (remainingAmount > 0) { + doCopy.copyFromHostBuffer(dest, outputOffset, buffer, + bufferOffset, bufferAmount); + remainingAmount -= bufferAmount; + outputOffset += bufferAmount; + currentOffset += bufferAmount; + index++; + if (index < hostBuffers.length) { + buffer = hostBuffers[index]; + bufferOffset = currentOffset - startOffsets[index]; + bufferAmount = Math.min(buffer.length - bufferOffset, remainingAmount); + } + } + + return realAmount; + } + + @Override + public HostMemoryBuffer hostRead(long offset, long amount) { + assert (offset >= 0); + assert (amount >= 0); + long realOffset = Math.min(offset, sizeInBytes); + long realAmount = Math.min(sizeInBytes - realOffset, amount); + + int index = getStartBufferIndexForOffset(realOffset); + + HostMemoryBuffer buffer = hostBuffers[index]; + long bufferOffset = realOffset - startOffsets[index]; + long bufferAmount = Math.min(buffer.length - bufferOffset, realAmount); + if (bufferAmount == realAmount) { + hostReads += 1; + hostReadBytes += realAmount; + // It all fits in a 
single buffer, so do a zero copy operation + return buffer.slice(bufferOffset, bufferAmount); + } else { + // We will have to allocate a new buffer and copy data into it. + boolean success = false; + HostMemoryBuffer ret = allocator.allocate(realAmount, true); + try { + long amountRead = read(offset, ret, HostMemoryBuffer::copyFromHostBuffer); + assert(amountRead == realAmount); + hostReads += 1; + hostReadBytes += amountRead; + success = true; + return ret; + } finally { + if (!success) { + ret.close(); + } + } + } + } + + @Override + public long hostRead(long offset, HostMemoryBuffer dest) { + long ret = read(offset, dest, HostMemoryBuffer::copyFromHostBuffer); + hostReads += 1; + hostReadBytes += ret; + return ret; + } + + @Override + public boolean supportsDeviceRead() { + return true; + } + + @Override + public long deviceRead(long offset, DeviceMemoryBuffer dest, + Cuda.Stream stream) { + long ret = read(offset, dest, (destParam, destOffset, src, srcOffset, srcAmount) -> + destParam.copyFromHostBufferAsync(destOffset, src, srcOffset, srcAmount, stream)); + devReads += 1; + devReadBytes += ret; + return ret; + } + + + @Override + public void close() { + try { + super.close(); + } finally { + for (HostMemoryBuffer hmb: hostBuffers) { + if (hmb != null) { + hmb.close(); + } + } + } + } + + public long getHostReads() { + return hostReads; + } + + public long getHostReadBytes() { + return hostReadBytes; + } + + public long getDevReads() { + return devReads; + } + + public long getDevReadBytes() { + return devReadBytes; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ParquetChunkedReader.java b/java/src/main/java/ai/rapids/cudf/ParquetChunkedReader.java index c34336ac73f..17d59b757c3 100644 --- a/java/src/main/java/ai/rapids/cudf/ParquetChunkedReader.java +++ b/java/src/main/java/ai/rapids/cudf/ParquetChunkedReader.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,7 +51,7 @@ public ParquetChunkedReader(long chunkSizeByteLimit, ParquetOptions opts, File f handle = create(chunkSizeByteLimit, opts.getIncludeColumnNames(), opts.getReadBinaryAsString(), filePath.getAbsolutePath(), 0, 0, opts.timeUnit().typeId.getNativeId()); - if(handle == 0) { + if (handle == 0) { throw new IllegalStateException("Cannot create native chunked Parquet reader object."); } } @@ -71,18 +71,45 @@ public ParquetChunkedReader(long chunkSizeByteLimit, ParquetOptions opts, HostMe handle = create(chunkSizeByteLimit, opts.getIncludeColumnNames(), opts.getReadBinaryAsString(), null, buffer.getAddress() + offset, len, opts.timeUnit().typeId.getNativeId()); - if(handle == 0) { + if (handle == 0) { throw new IllegalStateException("Cannot create native chunked Parquet reader object."); } } + /** + * Construct a reader instance from a DataSource + * @param chunkSizeByteLimit Limit on total number of bytes to be returned per read, + * or 0 if there is no limit. + * @param opts The options for Parquet reading. + * @param ds the data source to read from + */ + public ParquetChunkedReader(long chunkSizeByteLimit, ParquetOptions opts, DataSource ds) { + dataSourceHandle = DataSourceHelper.createWrapperDataSource(ds); + if (dataSourceHandle == 0) { + throw new IllegalStateException("Cannot create native datasource object"); + } + + boolean passed = false; + try { + handle = createWithDataSource(chunkSizeByteLimit, opts.getIncludeColumnNames(), + opts.getReadBinaryAsString(), opts.timeUnit().typeId.getNativeId(), + dataSourceHandle); + passed = true; + } finally { + if (!passed) { + DataSourceHelper.destroyWrapperDataSource(dataSourceHandle); + dataSourceHandle = 0; + } + } + } + /** * Check if the given file has anything left to read. * * @return A boolean value indicating if there is more data to read from file. 
*/ public boolean hasNext() { - if(handle == 0) { + if (handle == 0) { throw new IllegalStateException("Native chunked Parquet reader object may have been closed."); } @@ -104,7 +131,7 @@ public boolean hasNext() { * @return A table of new rows reading from the given file. */ public Table readChunk() { - if(handle == 0) { + if (handle == 0) { throw new IllegalStateException("Native chunked Parquet reader object may have been closed."); } @@ -118,6 +145,10 @@ public void close() { close(handle); handle = 0; } + if (dataSourceHandle != 0) { + DataSourceHelper.destroyWrapperDataSource(dataSourceHandle); + dataSourceHandle = 0; + } } @@ -131,6 +162,7 @@ public void close() { */ private long handle; + private long dataSourceHandle = 0; /** * Create a native chunked Parquet reader object on heap and return its memory address. @@ -147,6 +179,9 @@ public void close() { private static native long create(long chunkSizeByteLimit, String[] filterColumnNames, boolean[] binaryToString, String filePath, long bufferAddrs, long length, int timeUnit); + private static native long createWithDataSource(long chunkedSizeByteLimit, + String[] filterColumnNames, boolean[] binaryToString, int timeUnit, long dataSourceHandle); + private static native boolean hasNext(long handle); private static native long[] readChunk(long handle); diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 51a33ebb72f..3bd1e3f25a7 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -235,6 +235,14 @@ private static native long[] readCSV(String[] columnNames, byte comment, String[] nullValues, String[] trueValues, String[] falseValues) throws CudfException; + private static native long[] readCSVFromDataSource(String[] columnNames, + int[] dTypeIds, int[] dTypeScales, + String[] filterColumnNames, + int headerRow, byte delim, int quoteStyle, byte quote, + byte comment, String[] nullValues, + String[] 
trueValues, String[] falseValues, + long dataSourceHandle) throws CudfException; + /** * read JSON data and return a pointer to a TableWithMeta object. */ @@ -244,6 +252,12 @@ private static native long readJSON(String[] columnNames, boolean dayFirst, boolean lines, boolean recoverWithNulls) throws CudfException; + private static native long readJSONFromDataSource(String[] columnNames, + int[] dTypeIds, int[] dTypeScales, + boolean dayFirst, boolean lines, + boolean recoverWithNulls, + long dsHandle) throws CudfException; + private static native long readAndInferJSON(long address, long length, boolean dayFirst, boolean lines, boolean recoverWithNulls) throws CudfException; @@ -260,6 +274,10 @@ private static native long readAndInferJSON(long address, long length, private static native long[] readParquet(String[] filterColumnNames, boolean[] binaryToString, String filePath, long address, long length, int timeUnit) throws CudfException; + private static native long[] readParquetFromDataSource(String[] filterColumnNames, + boolean[] binaryToString, int timeUnit, + long dataSourceHandle) throws CudfException; + /** * Read in Avro formatted data. * @param filterColumnNames name of the columns to read, or an empty array if we want to read @@ -271,6 +289,9 @@ private static native long[] readParquet(String[] filterColumnNames, boolean[] b private static native long[] readAvro(String[] filterColumnNames, String filePath, long address, long length) throws CudfException; + private static native long[] readAvroFromDataSource(String[] filterColumnNames, + long dataSourceHandle) throws CudfException; + /** * Setup everything to write parquet formatted data to a file. 
* @param columnNames names that correspond to the table columns @@ -372,6 +393,11 @@ private static native long[] readORC(String[] filterColumnNames, boolean usingNumPyTypes, int timeUnit, String[] decimal128Columns) throws CudfException; + private static native long[] readORCFromDataSource(String[] filterColumnNames, + boolean usingNumPyTypes, int timeUnit, + String[] decimal128Columns, + long dataSourceHandle) throws CudfException; + /** * Setup everything to write ORC formatted data to a file. * @param columnNames names that correspond to the table columns @@ -881,6 +907,27 @@ public static Table readCSV(Schema schema, CSVOptions opts, HostMemoryBuffer buf opts.getFalseValues())); } + public static Table readCSV(Schema schema, CSVOptions opts, DataSource ds) { + long dsHandle = DataSourceHelper.createWrapperDataSource(ds); + try { + return new Table(readCSVFromDataSource(schema.getColumnNames(), + schema.getTypeIds(), + schema.getTypeScales(), + opts.getIncludeColumnNames(), + opts.getHeaderRow(), + opts.getDelim(), + opts.getQuoteStyle().nativeId, + opts.getQuote(), + opts.getComment(), + opts.getNullValues(), + opts.getTrueValues(), + opts.getFalseValues(), + dsHandle)); + } finally { + DataSourceHelper.destroyWrapperDataSource(dsHandle); + } + } + private static native void writeCSVToFile(long table, String[] columnNames, boolean includeHeader, @@ -1128,6 +1175,24 @@ public static Table readJSON(Schema schema, JSONOptions opts, HostMemoryBuffer b } } + /** + * Read JSON formatted data. + * @param schema the schema of the data. You may use Schema.INFERRED to infer the schema. + * @param opts various JSON parsing options. + * @param ds the DataSource to read from. + * @return the data parsed as a table on the GPU. 
+ */ + public static Table readJSON(Schema schema, JSONOptions opts, DataSource ds) { + long dsHandle = DataSourceHelper.createWrapperDataSource(ds); + try (TableWithMeta twm = new TableWithMeta(readJSONFromDataSource(schema.getColumnNames(), + schema.getTypeIds(), schema.getTypeScales(), opts.isDayFirst(), opts.isLines(), + opts.isRecoverWithNull(), dsHandle))) { + return gatherJSONColumns(schema, twm); + } finally { + DataSourceHelper.destroyWrapperDataSource(dsHandle); + } + } + /** * Read a Parquet file using the default ParquetOptions. * @param path the local file to read. @@ -1214,6 +1279,17 @@ public static Table readParquet(ParquetOptions opts, HostMemoryBuffer buffer, null, buffer.getAddress() + offset, len, opts.timeUnit().typeId.getNativeId())); } + public static Table readParquet(ParquetOptions opts, DataSource ds) { + long dataSourceHandle = DataSourceHelper.createWrapperDataSource(ds); + try { + return new Table(readParquetFromDataSource(opts.getIncludeColumnNames(), + opts.getReadBinaryAsString(), opts.timeUnit().typeId.getNativeId(), + dataSourceHandle)); + } finally { + DataSourceHelper.destroyWrapperDataSource(dataSourceHandle); + } + } + /** * Read an Avro file using the default AvroOptions. * @param path the local file to read. @@ -1297,6 +1373,16 @@ public static Table readAvro(AvroOptions opts, HostMemoryBuffer buffer, null, buffer.getAddress() + offset, len)); } + public static Table readAvro(AvroOptions opts, DataSource ds) { + long dataSourceHandle = DataSourceHelper.createWrapperDataSource(ds); + try { + return new Table(readAvroFromDataSource(opts.getIncludeColumnNames(), + dataSourceHandle)); + } finally { + DataSourceHelper.destroyWrapperDataSource(dataSourceHandle); + } + } + /** * Read a ORC file using the default ORCOptions. * @param path the local file to read. 
@@ -1388,6 +1474,17 @@ public static Table readORC(ORCOptions opts, HostMemoryBuffer buffer, opts.getDecimal128Columns())); } + public static Table readORC(ORCOptions opts, DataSource ds) { + long dataSourceHandle = DataSourceHelper.createWrapperDataSource(ds); + try { + return new Table(readORCFromDataSource(opts.getIncludeColumnNames(), + opts.usingNumPyTypes(), opts.timeUnit().typeId.getNativeId(), + opts.getDecimal128Columns(), dataSourceHandle)); + } finally { + DataSourceHelper.destroyWrapperDataSource(dataSourceHandle); + } + } + private static class ParquetTableWriter extends TableWriter { HostBufferConsumer consumer; @@ -2262,7 +2359,7 @@ public Table dropDuplicates(int[] keyColumns, DuplicateKeepOption keep, boolean /** * Count how many rows in the table are distinct from one another. - * @param nullEqual if nulls should be considered equal to each other or not. + * @param nullsEqual if nulls should be considered equal to each other or not. */ public int distinctCount(NullEquality nullsEqual) { return distinctCount(nativeHandle, nullsEqual.nullsEqual); diff --git a/java/src/main/native/CMakeLists.txt b/java/src/main/native/CMakeLists.txt index 0dcfee2cffe..01161a03dd4 100644 --- a/java/src/main/native/CMakeLists.txt +++ b/java/src/main/native/CMakeLists.txt @@ -135,6 +135,7 @@ add_library( src/ColumnViewJni.cu src/CompiledExpression.cpp src/ContiguousTableJni.cpp + src/DataSourceHelperJni.cpp src/HashJoinJni.cpp src/HostMemoryBufferNativeUtilsJni.cpp src/NvcompJni.cpp diff --git a/java/src/main/native/src/ChunkedReaderJni.cpp b/java/src/main/native/src/ChunkedReaderJni.cpp index 8d0a8bdbfe7..0044385f267 100644 --- a/java/src/main/native/src/ChunkedReaderJni.cpp +++ b/java/src/main/native/src/ChunkedReaderJni.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -85,6 +85,40 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ParquetChunkedReader_create( CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ParquetChunkedReader_createWithDataSource( + JNIEnv *env, jclass, jlong chunk_read_limit, jobjectArray filter_col_names, + jbooleanArray j_col_binary_read, jint unit, jlong ds_handle) { + JNI_NULL_CHECK(env, j_col_binary_read, "Null col_binary_read", 0); + JNI_NULL_CHECK(env, ds_handle, "Null DataSource", 0); + + try { + cudf::jni::auto_set_device(env); + + cudf::jni::native_jstringArray n_filter_col_names(env, filter_col_names); + + // TODO: This variable is unused now, but we still don't know what to do with it yet. + // As such, it needs to stay here for a little more time before we decide to use it again, + // or remove it completely. + cudf::jni::native_jbooleanArray n_col_binary_read(env, j_col_binary_read); + (void)n_col_binary_read; + + auto ds = reinterpret_cast<cudf::io::datasource *>(ds_handle); + cudf::io::source_info source{ds}; + + auto opts_builder = cudf::io::parquet_reader_options::builder(source); + if (n_filter_col_names.size() > 0) { + opts_builder = opts_builder.columns(n_filter_col_names.as_cpp_vector()); + } + auto const read_opts = opts_builder.convert_strings_to_categories(false) + .timestamp_type(cudf::data_type(static_cast<cudf::type_id>(unit))) + .build(); + + return reinterpret_cast<jlong>(new cudf::io::chunked_parquet_reader( + static_cast<std::size_t>(chunk_read_limit), read_opts)); + } + CATCH_STD(env, 0); +} + JNIEXPORT jboolean JNICALL Java_ai_rapids_cudf_ParquetChunkedReader_hasNext(JNIEnv *env, jclass, jlong handle) { JNI_NULL_CHECK(env, handle, "handle is null", false); diff --git a/java/src/main/native/src/CudfJni.cpp b/java/src/main/native/src/CudfJni.cpp index 0f143086451..d0a25d449a6 100644 --- a/java/src/main/native/src/CudfJni.cpp +++ b/java/src/main/native/src/CudfJni.cpp @@ -175,6 +175,14 @@ JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *) { return JNI_ERR; } + if (!cudf::jni::cache_data_source_jni(env)) { + if
(!env->ExceptionCheck()) { + env->ThrowNew(env->FindClass("java/lang/RuntimeException"), + "Unable to locate data source helper methods needed by JNI"); + } + return JNI_ERR; + } + return cudf::jni::MINIMUM_JNI_VERSION; } diff --git a/java/src/main/native/src/DataSourceHelperJni.cpp b/java/src/main/native/src/DataSourceHelperJni.cpp new file mode 100644 index 00000000000..8d0e4d36413 --- /dev/null +++ b/java/src/main/native/src/DataSourceHelperJni.cpp @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <cudf/io/datasource.hpp> + +#include "cudf_jni_apis.hpp" +#include "jni_utils.hpp" + +namespace { + +#define DATA_SOURCE_CLASS "ai/rapids/cudf/DataSource" + +jclass DataSource_jclass; +jmethodID hostRead_method; +jmethodID hostReadBuff_method; +jmethodID onHostBufferDone_method; +jmethodID deviceRead_method; + +} // anonymous namespace + +namespace cudf { +namespace jni { +bool cache_data_source_jni(JNIEnv *env) { + jclass cls = env->FindClass(DATA_SOURCE_CLASS); + if (cls == nullptr) { + return false; + } + + hostRead_method = env->GetMethodID(cls, "hostRead", "(JJJ)J"); + if (hostRead_method == nullptr) { + return false; + } + + hostReadBuff_method = env->GetMethodID(cls, "hostReadBuff", "(JJ)[J"); + if (hostReadBuff_method == nullptr) { + return false; + } + + onHostBufferDone_method = env->GetMethodID(cls, "onHostBufferDone", "(J)V"); + if (onHostBufferDone_method == nullptr) { + return false; + } + + deviceRead_method = env->GetMethodID(cls, "deviceRead", "(JJJJ)J"); + if (deviceRead_method == nullptr) { + return false; + } + + // Convert local reference to global so it cannot be garbage collected.
+ DataSource_jclass = static_cast<jclass>(env->NewGlobalRef(cls)); + if (DataSource_jclass == nullptr) { + return false; + } + return true; +} + +void release_data_source_jni(JNIEnv *env) { + DataSource_jclass = cudf::jni::del_global_ref(env, DataSource_jclass); +} + +class host_buffer_done_callback { +public: + explicit host_buffer_done_callback(JavaVM *jvm, jobject ds, long id) : jvm(jvm), ds(ds), id(id) {} + + host_buffer_done_callback(host_buffer_done_callback const &other) = delete; + host_buffer_done_callback(host_buffer_done_callback &&other) + : jvm(other.jvm), ds(other.ds), id(other.id) { + other.jvm = nullptr; + other.ds = nullptr; + other.id = -1; + } + + host_buffer_done_callback &operator=(host_buffer_done_callback &&other) = delete; + host_buffer_done_callback &operator=(host_buffer_done_callback const &other) = delete; + + ~host_buffer_done_callback() { + // because we are in a destructor we cannot throw an exception, so for now we are + // just going to keep the java exceptions around and have them be thrown when this + // thread returns to the JVM. It might be kind of confusing, but we will not lose + // them.
+ if (jvm != nullptr) { + // We cannot throw an exception in the destructor, so this is really best effort + JNIEnv *env = nullptr; + if (jvm->GetEnv(reinterpret_cast<void **>(&env), cudf::jni::MINIMUM_JNI_VERSION) == JNI_OK) { + env->CallVoidMethod(this->ds, onHostBufferDone_method, id); + } + } + } + +private: + JavaVM *jvm; + jobject ds; + long id; +}; + +class jni_datasource : public cudf::io::datasource { +public: + explicit jni_datasource(JNIEnv *env, jobject ds, size_t ds_size, bool device_read_supported, + size_t device_read_cutoff) + : ds_size(ds_size), device_read_supported(device_read_supported), + device_read_cutoff(device_read_cutoff) { + if (env->GetJavaVM(&jvm) < 0) { + throw std::runtime_error("GetJavaVM failed"); + } + this->ds = add_global_ref(env, ds); + } + + virtual ~jni_datasource() { + JNIEnv *env = nullptr; + if (jvm->GetEnv(reinterpret_cast<void **>(&env), cudf::jni::MINIMUM_JNI_VERSION) == JNI_OK) { + ds = del_global_ref(env, ds); + } + ds = nullptr; + } + + std::unique_ptr<buffer> host_read(size_t offset, size_t size) override { + JNIEnv *env = nullptr; + if (jvm->GetEnv(reinterpret_cast<void **>(&env), cudf::jni::MINIMUM_JNI_VERSION) != JNI_OK) { + throw cudf::jni::jni_exception("Could not load JNIEnv"); + } + + jlongArray jbuffer_info = + static_cast<jlongArray>(env->CallObjectMethod(this->ds, hostReadBuff_method, offset, size)); + if (env->ExceptionOccurred()) { + throw cudf::jni::jni_exception("Java exception in hostRead"); + } + + cudf::jni::native_jlongArray buffer_info(env, jbuffer_info); + auto ptr = reinterpret_cast<uint8_t const *>(buffer_info[0]); + size_t length = buffer_info[1]; + long id = buffer_info[2]; + + cudf::jni::host_buffer_done_callback cb(this->jvm, this->ds, id); + return std::make_unique<owning_buffer<cudf::jni::host_buffer_done_callback>>(std::move(cb), ptr, + length); + } + + size_t host_read(size_t offset, size_t size, uint8_t *dst) override { + JNIEnv *env = nullptr; + if (jvm->GetEnv(reinterpret_cast<void **>(&env), cudf::jni::MINIMUM_JNI_VERSION) != JNI_OK) { + throw cudf::jni::jni_exception("Could not load JNIEnv"); + } + +
jlong amount_read = + env->CallLongMethod(this->ds, hostRead_method, offset, size, reinterpret_cast<jlong>(dst)); + if (env->ExceptionOccurred()) { + throw cudf::jni::jni_exception("Java exception in hostRead"); + } + return amount_read; + } + + size_t size() const override { return ds_size; } + + bool supports_device_read() const override { return device_read_supported; } + + bool is_device_read_preferred(size_t size) const override { + return device_read_supported && size >= device_read_cutoff; + } + + size_t device_read(size_t offset, size_t size, uint8_t *dst, + rmm::cuda_stream_view stream) override { + JNIEnv *env = nullptr; + if (jvm->GetEnv(reinterpret_cast<void **>(&env), cudf::jni::MINIMUM_JNI_VERSION) != JNI_OK) { + throw cudf::jni::jni_exception("Could not load JNIEnv"); + } + + jlong amount_read = + env->CallLongMethod(this->ds, deviceRead_method, offset, size, reinterpret_cast<jlong>(dst), + reinterpret_cast<jlong>(stream.value())); + if (env->ExceptionOccurred()) { + throw cudf::jni::jni_exception("Java exception in deviceRead"); + } + return amount_read; + } + + std::future<size_t> device_read_async(size_t offset, size_t size, uint8_t *dst, + rmm::cuda_stream_view stream) override { + auto amount_read = device_read(offset, size, dst, stream); + // This is a bit ugly, but we don't have a good way or a need to return + // a future for the read + std::promise<size_t> ret; + ret.set_value(amount_read); + return ret.get_future(); + } + +private: + size_t ds_size; + bool device_read_supported; + size_t device_read_cutoff; + JavaVM *jvm; + jobject ds; +}; +} // namespace jni +} // namespace cudf + +extern "C" { + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_DataSourceHelper_createWrapperDataSource( + JNIEnv *env, jclass, jobject ds, jlong ds_size, jboolean device_read_supported, + jlong device_read_cutoff) { + JNI_NULL_CHECK(env, ds, "Null data source", 0); + try { + cudf::jni::auto_set_device(env); + auto source = + new cudf::jni::jni_datasource(env, ds, ds_size, device_read_supported,
device_read_cutoff); + return reinterpret_cast<jlong>(source); + } + CATCH_STD(env, 0); +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_DataSourceHelper_destroyWrapperDataSource(JNIEnv *env, + jclass, + jlong handle) { + try { + cudf::jni::auto_set_device(env); + if (handle != 0) { + auto source = reinterpret_cast<cudf::jni::jni_datasource *>(handle); + delete (source); + } + } + CATCH_STD(env, ); +} + +} // extern "C" diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index b208ef8f381..fad19bdf895 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -1135,6 +1135,67 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_merge(JNIEnv *env, jclass CATCH_STD(env, NULL); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSVFromDataSource( + JNIEnv *env, jclass, jobjectArray col_names, jintArray j_types, jintArray j_scales, + jobjectArray filter_col_names, jint header_row, jbyte delim, jint j_quote_style, jbyte quote, + jbyte comment, jobjectArray null_values, jobjectArray true_values, jobjectArray false_values, + jlong ds_handle) { + JNI_NULL_CHECK(env, null_values, "null_values must be supplied, even if it is empty", NULL); + JNI_NULL_CHECK(env, ds_handle, "no data source handle given", NULL); + + try { + cudf::jni::auto_set_device(env); + cudf::jni::native_jstringArray n_col_names(env, col_names); + cudf::jni::native_jintArray n_types(env, j_types); + cudf::jni::native_jintArray n_scales(env, j_scales); + if (n_types.is_null() != n_scales.is_null()) { + JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match null", + NULL); + } + std::vector<cudf::data_type> data_types; + if (!n_types.is_null()) { + if (n_types.size() != n_scales.size()) { + JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match size", + NULL); + } + data_types.reserve(n_types.size()); + std::transform(n_types.begin(), n_types.end(), n_scales.begin(), + std::back_inserter(data_types),
[](auto type, auto scale) { + return cudf::data_type{static_cast<cudf::type_id>(type), scale}; + }); + } + + cudf::jni::native_jstringArray n_null_values(env, null_values); + cudf::jni::native_jstringArray n_true_values(env, true_values); + cudf::jni::native_jstringArray n_false_values(env, false_values); + cudf::jni::native_jstringArray n_filter_col_names(env, filter_col_names); + + auto ds = reinterpret_cast<cudf::io::datasource *>(ds_handle); + cudf::io::source_info source{ds}; + + auto const quote_style = static_cast<cudf::io::quote_style>(j_quote_style); + + cudf::io::csv_reader_options opts = cudf::io::csv_reader_options::builder(source) + .delimiter(delim) + .header(header_row) + .names(n_col_names.as_cpp_vector()) + .dtypes(data_types) + .use_cols_names(n_filter_col_names.as_cpp_vector()) + .true_values(n_true_values.as_cpp_vector()) + .false_values(n_false_values.as_cpp_vector()) + .na_values(n_null_values.as_cpp_vector()) + .keep_default_na(false) + .na_filter(n_null_values.size() > 0) + .quoting(quote_style) + .quotechar(quote) + .comment(comment) + .build(); + + return convert_table_for_return(env, cudf::io::read_csv(opts).tbl); + } + CATCH_STD(env, NULL); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSV( JNIEnv *env, jclass, jobjectArray col_names, jintArray j_types, jintArray j_scales, jobjectArray filter_col_names, jstring inputfilepath, jlong buffer, jlong buffer_length, @@ -1407,6 +1468,72 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_TableWithMeta_releaseTable(JNIE CATCH_STD(env, nullptr); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSONFromDataSource( + JNIEnv *env, jclass, jobjectArray col_names, jintArray j_types, jintArray j_scales, + jboolean day_first, jboolean lines, jboolean recover_with_null, jlong ds_handle) { + + JNI_NULL_CHECK(env, ds_handle, "no data source handle given", 0); + + try { + cudf::jni::auto_set_device(env); + cudf::jni::native_jstringArray n_col_names(env, col_names); + cudf::jni::native_jintArray n_types(env, j_types); +
cudf::jni::native_jintArray n_scales(env, j_scales); + if (n_types.is_null() != n_scales.is_null()) { + JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match null", + 0); + } + std::vector<cudf::data_type> data_types; + if (!n_types.is_null()) { + if (n_types.size() != n_scales.size()) { + JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match size", + 0); + } + data_types.reserve(n_types.size()); + std::transform(n_types.begin(), n_types.end(), n_scales.begin(), + std::back_inserter(data_types), [](auto const &type, auto const &scale) { + return cudf::data_type{static_cast<cudf::type_id>(type), scale}; + }); + } + + auto ds = reinterpret_cast<cudf::io::datasource *>(ds_handle); + cudf::io::source_info source{ds}; + + cudf::io::json_recovery_mode_t recovery_mode = + recover_with_null ? cudf::io::json_recovery_mode_t::RECOVER_WITH_NULL : + cudf::io::json_recovery_mode_t::FAIL; + cudf::io::json_reader_options_builder opts = cudf::io::json_reader_options::builder(source) + .dayfirst(static_cast<bool>(day_first)) + .lines(static_cast<bool>(lines)) + .recovery_mode(recovery_mode); + + if (!n_col_names.is_null() && data_types.size() > 0) { + if (n_col_names.size() != n_types.size()) { + JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", + "types and column names must match size", 0); + } + + std::map<std::string, cudf::data_type> map; + + auto col_names_vec = n_col_names.as_cpp_vector(); + std::transform(col_names_vec.begin(), col_names_vec.end(), data_types.begin(), + std::inserter(map, map.end()), + [](std::string a, cudf::data_type b) { return std::make_pair(a, b); }); + opts.dtypes(map); + } else if (data_types.size() > 0) { + opts.dtypes(data_types); + } else { + // should infer the types + } + + auto result = + std::make_unique<cudf::io::table_with_metadata>(cudf::io::read_json(opts.build())); + + return reinterpret_cast<jlong>(result.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSON( JNIEnv *env, jclass, jobjectArray col_names, jintArray j_types, jintArray j_scales, jstring
inputfilepath, jlong buffer, jlong buffer_length, jboolean day_first, jboolean lines, @@ -1489,6 +1616,36 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSON( CATCH_STD(env, 0); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readParquetFromDataSource( + JNIEnv *env, jclass, jobjectArray filter_col_names, jbooleanArray j_col_binary_read, jint unit, + jlong ds_handle) { + + JNI_NULL_CHECK(env, ds_handle, "no data source handle given", 0); + JNI_NULL_CHECK(env, j_col_binary_read, "null col_binary_read", 0); + + try { + cudf::jni::auto_set_device(env); + + cudf::jni::native_jstringArray n_filter_col_names(env, filter_col_names); + cudf::jni::native_jbooleanArray n_col_binary_read(env, j_col_binary_read); + + auto ds = reinterpret_cast<cudf::io::datasource *>(ds_handle); + cudf::io::source_info source{ds}; + + auto builder = cudf::io::parquet_reader_options::builder(source); + if (n_filter_col_names.size() > 0) { + builder = builder.columns(n_filter_col_names.as_cpp_vector()); + } + + cudf::io::parquet_reader_options opts = + builder.convert_strings_to_categories(false) + .timestamp_type(cudf::data_type(static_cast<cudf::type_id>(unit))) + .build(); + return convert_table_for_return(env, cudf::io::read_parquet(opts).tbl); + } + CATCH_STD(env, NULL); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readParquet( JNIEnv *env, jclass, jobjectArray filter_col_names, jbooleanArray j_col_binary_read, jstring inputfilepath, jlong buffer, jlong buffer_length, jint unit) { @@ -1535,10 +1692,31 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readParquet( CATCH_STD(env, NULL); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readAvroFromDataSource( + JNIEnv *env, jclass, jobjectArray filter_col_names, jlong ds_handle) { + + JNI_NULL_CHECK(env, ds_handle, "no data source handle given", 0); + + try { + cudf::jni::auto_set_device(env); + + cudf::jni::native_jstringArray n_filter_col_names(env, filter_col_names); + + auto ds = reinterpret_cast<cudf::io::datasource *>(ds_handle); +
cudf::io::source_info source{ds}; + + cudf::io::avro_reader_options opts = cudf::io::avro_reader_options::builder(source) + .columns(n_filter_col_names.as_cpp_vector()) + .build(); + return convert_table_for_return(env, cudf::io::read_avro(opts).tbl); + } + CATCH_STD(env, NULL); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readAvro(JNIEnv *env, jclass, jobjectArray filter_col_names, jstring inputfilepath, jlong buffer, - jlong buffer_length, jint unit) { + jlong buffer_length) { const bool read_buffer = (buffer != 0); if (!read_buffer) { @@ -1715,6 +1893,38 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_writeParquetEnd(JNIEnv *env, jc CATCH_STD(env, ) } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readORCFromDataSource( + JNIEnv *env, jclass, jobjectArray filter_col_names, jboolean usingNumPyTypes, jint unit, + jobjectArray dec128_col_names, jlong ds_handle) { + + JNI_NULL_CHECK(env, ds_handle, "no data source handle given", 0); + + try { + cudf::jni::auto_set_device(env); + + cudf::jni::native_jstringArray n_filter_col_names(env, filter_col_names); + + cudf::jni::native_jstringArray n_dec128_col_names(env, dec128_col_names); + + auto ds = reinterpret_cast<cudf::io::datasource *>(ds_handle); + cudf::io::source_info source{ds}; + + auto builder = cudf::io::orc_reader_options::builder(source); + if (n_filter_col_names.size() > 0) { + builder = builder.columns(n_filter_col_names.as_cpp_vector()); + } + + cudf::io::orc_reader_options opts = + builder.use_index(false) + .use_np_dtypes(static_cast<bool>(usingNumPyTypes)) + .timestamp_type(cudf::data_type(static_cast<cudf::type_id>(unit))) + .decimal128_columns(n_dec128_col_names.as_cpp_vector()) + .build(); + return convert_table_for_return(env, cudf::io::read_orc(opts).tbl); + } + CATCH_STD(env, NULL); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readORC( JNIEnv *env, jclass, jobjectArray filter_col_names, jstring inputfilepath, jlong buffer, jlong buffer_length, jboolean usingNumPyTypes, jint unit, jobjectArray
dec128_col_names) { diff --git a/java/src/main/native/src/cudf_jni_apis.hpp b/java/src/main/native/src/cudf_jni_apis.hpp index 867df80b722..bd82bbd2899 100644 --- a/java/src/main/native/src/cudf_jni_apis.hpp +++ b/java/src/main/native/src/cudf_jni_apis.hpp @@ -134,5 +134,13 @@ void auto_set_device(JNIEnv *env); */ void device_memset_async(JNIEnv *env, rmm::device_buffer &buf, char value); +// +// DataSource APIs +// + +bool cache_data_source_jni(JNIEnv *env); + +void release_data_source_jni(JNIEnv *env); + } // namespace jni } // namespace cudf diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index faa73ac4322..b0dd4122b0e 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -327,6 +327,25 @@ void testReadJSONFile() { } } + @Test + void testReadJSONFromDataSource() throws IOException { + Schema schema = Schema.builder() + .column(DType.STRING, "name") + .column(DType.INT32, "age") + .build(); + JSONOptions opts = JSONOptions.builder() + .withLines(true) + .build(); + try (Table expected = new Table.TestBuilder() + .column("Michael", "Andy", "Justin") + .column(null, 30, 19) + .build(); + MultiBufferDataSource source = sourceFrom(TEST_SIMPLE_JSON_FILE); + Table table = Table.readJSON(schema, opts, source)) { + assertTablesAreEqual(expected, table); + } + } + @Test void testReadJSONFileWithInvalidLines() { Schema schema = Schema.builder() @@ -560,6 +579,126 @@ void testReadCSVBuffer() { } } + byte[][] sliceBytes(byte[] data, int slices) { + slices = Math.min(data.length, slices); + // We are not going to worry about making it super even here. + // The last one gets the extras. 
+ int bytesPerSlice = data.length / slices; + byte[][] ret = new byte[slices][]; + int startingAt = 0; + for (int i = 0; i < (slices - 1); i++) { + ret[i] = new byte[bytesPerSlice]; + System.arraycopy(data, startingAt, ret[i], 0, bytesPerSlice); + startingAt += bytesPerSlice; + } + // Now for the last one + ret[slices - 1] = new byte[data.length - startingAt]; + System.arraycopy(data, startingAt, ret[slices - 1], 0, data.length - startingAt); + return ret; + } + + @Test + void testReadCSVBufferMultiBuffer() { + CSVOptions opts = CSVOptions.builder() + .includeColumn("A") + .includeColumn("B") + .hasHeader() + .withDelim('|') + .withQuote('\'') + .withNullValue("NULL") + .build(); + byte[][] data = sliceBytes(CSV_DATA_BUFFER, 10); + try (Table expected = new Table.TestBuilder() + .column(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .column(110.0, 111.0, 112.0, 113.0, 114.0, 115.0, 116.0, null, 118.2, 119.8) + .build(); + MultiBufferDataSource source = sourceFrom(data); + Table table = Table.readCSV(TableTest.CSV_DATA_BUFFER_SCHEMA, opts, source)) { + assertTablesAreEqual(expected, table); + } + } + + public static byte[] arrayFrom(File f) throws IOException { + long len = f.length(); + if (len > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Sorry cannot read " + f + + " into an array it does not fit"); + } + int remaining = (int)len; + byte[] ret = new byte[remaining]; + try (java.io.FileInputStream fin = new java.io.FileInputStream(f)) { + int at = 0; + while (remaining > 0) { + int amount = fin.read(ret, at, remaining); + at += amount; + remaining -= amount; + } + } + return ret; + } + + public static MultiBufferDataSource sourceFrom(File f) throws IOException { + long len = f.length(); + byte[] tmp = new byte[(int)Math.min(32 * 1024, len)]; + try (HostMemoryBuffer buffer = HostMemoryBuffer.allocate(len)) { + try (java.io.FileInputStream fin = new java.io.FileInputStream(f)) { + long at = 0; + while (at < len) { + int amount = fin.read(tmp); + buffer.setBytes(at, 
tmp, 0, amount); + at += amount; + } + } + return new MultiBufferDataSource(buffer); + } + } + + public static MultiBufferDataSource sourceFrom(byte[] data) { + long len = data.length; + try (HostMemoryBuffer buffer = HostMemoryBuffer.allocate(len)) { + buffer.setBytes(0, data, 0, len); + return new MultiBufferDataSource(buffer); + } + } + + public static MultiBufferDataSource sourceFrom(byte[][] data) { + HostMemoryBuffer[] buffers = new HostMemoryBuffer[data.length]; + try { + for (int i = 0; i < data.length; i++) { + byte[] subData = data[i]; + buffers[i] = HostMemoryBuffer.allocate(subData.length); + buffers[i].setBytes(0, subData, 0, subData.length); + } + return new MultiBufferDataSource(buffers); + } finally { + for (HostMemoryBuffer buffer: buffers) { + if (buffer != null) { + buffer.close(); + } + } + } + } + + @Test + void testReadCSVDataSource() { + CSVOptions opts = CSVOptions.builder() + .includeColumn("A") + .includeColumn("B") + .hasHeader() + .withDelim('|') + .withQuote('\'') + .withNullValue("NULL") + .build(); + try (Table expected = new Table.TestBuilder() + .column(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .column(110.0, 111.0, 112.0, 113.0, 114.0, 115.0, 116.0, null, 118.2, 119.8) + .build(); + MultiBufferDataSource source = sourceFrom(TableTest.CSV_DATA_BUFFER); + Table table = Table.readCSV(TableTest.CSV_DATA_BUFFER_SCHEMA, opts, source)) { + assertTablesAreEqual(expected, table); + } + } + @Test void testReadCSVWithOffset() { CSVOptions opts = CSVOptions.builder() @@ -864,6 +1003,37 @@ void testReadParquet() { } } + @Test + void testReadParquetFromDataSource() throws IOException { + ParquetOptions opts = ParquetOptions.builder() + .includeColumn("loan_id") + .includeColumn("zip") + .includeColumn("num_units") + .build(); + try (MultiBufferDataSource source = sourceFrom(TEST_PARQUET_FILE); + Table table = Table.readParquet(opts, source)) { + long rows = table.getRowCount(); + assertEquals(1000, rows); + assertTableTypes(new DType[]{DType.INT64, 
DType.INT32, DType.INT32}, table); + } + } + + @Test + void testReadParquetMultiBuffer() throws IOException { + ParquetOptions opts = ParquetOptions.builder() + .includeColumn("loan_id") + .includeColumn("zip") + .includeColumn("num_units") + .build(); + byte [][] data = sliceBytes(arrayFrom(TEST_PARQUET_FILE), 10); + try (MultiBufferDataSource source = sourceFrom(data); + Table table = Table.readParquet(opts, source)) { + long rows = table.getRowCount(); + assertEquals(1000, rows); + assertTableTypes(new DType[]{DType.INT64, DType.INT32, DType.INT32}, table); + } + } + @Test void testReadParquetBinary() { ParquetOptions opts = ParquetOptions.builder() @@ -1018,6 +1188,23 @@ void testChunkedReadParquet() { } } + @Test + void testChunkedReadParquetFromDataSource() throws IOException { + try (MultiBufferDataSource source = sourceFrom(TEST_PARQUET_FILE_CHUNKED_READ); + ParquetChunkedReader reader = new ParquetChunkedReader(240000, ParquetOptions.DEFAULT, source)) { + int numChunks = 0; + long totalRows = 0; + while(reader.hasNext()) { + ++numChunks; + try(Table chunk = reader.readChunk()) { + totalRows += chunk.getRowCount(); + } + } + assertEquals(2, numChunks); + assertEquals(40000, totalRows); + } + } + @Test void testReadAvro() { AvroOptions opts = AvroOptions.builder() @@ -1037,6 +1224,26 @@ void testReadAvro() { } } + @Test + void testReadAvroFromDataSource() throws IOException { + AvroOptions opts = AvroOptions.builder() + .includeColumn("bool_col") + .includeColumn("int_col") + .includeColumn("timestamp_col") + .build(); + + try (Table expected = new Table.TestBuilder() + .column(true, false, true, false, true, false, true, false) + .column(0, 1, 0, 1, 0, 1, 0, 1) + .column(1235865600000000L, 1235865660000000L, 1238544000000000L, 1238544060000000L, + 1233446400000000L, 1233446460000000L, 1230768000000000L, 1230768060000000L) + .build(); + MultiBufferDataSource source = sourceFrom(TEST_ALL_TYPES_PLAIN_AVRO_FILE); + Table table = Table.readAvro(opts, source)) { 
+ assertTablesAreEqual(expected, table); + } + } + @Test void testReadAvroBuffer() throws IOException{ AvroOptions opts = AvroOptions.builder() @@ -1094,6 +1301,24 @@ void testReadORC() { } } + @Test + void testReadORCFromDataSource() throws IOException { + ORCOptions opts = ORCOptions.builder() + .includeColumn("string1") + .includeColumn("float1") + .includeColumn("int1") + .build(); + try (Table expected = new Table.TestBuilder() + .column("hi","bye") + .column(1.0f,2.0f) + .column(65536,65536) + .build(); + MultiBufferDataSource source = sourceFrom(TEST_ORC_FILE); + Table table = Table.readORC(opts, source)) { + assertTablesAreEqual(expected, table); + } + } + @Test void testReadORCBuffer() throws IOException { ORCOptions opts = ORCOptions.builder()