Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add in java bindings for DataSource #14254

Merged
merged 12 commits into from
Oct 11, 2023
8 changes: 8 additions & 0 deletions cpp/src/io/utilities/datasource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,14 @@ class user_datasource_wrapper : public datasource {
return source->device_read(offset, size, stream);
}

std::future<size_t> device_read_async(size_t offset,
size_t size,
uint8_t* dst,
rmm::cuda_stream_view stream) override
{
return source->device_read_async(offset, size, dst, stream);
}

[[nodiscard]] size_t size() const override { return source->size(); }

private:
Expand Down
24 changes: 19 additions & 5 deletions java/src/main/java/ai/rapids/cudf/Cuda.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,9 +15,6 @@
*/
package ai.rapids.cudf;

import ai.rapids.cudf.NvtxColor;
import ai.rapids.cudf.NvtxRange;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -90,6 +87,21 @@ private Stream() {
this.id = -1;
}

private Stream(long id) {
this.cleaner = null;
this.id = id;
}

/**
* Wrap a given stream ID to make it accessible.
*/
static Stream wrap(long id) {
if (id == -1) {
return DEFAULT_STREAM;
}
return new Stream(id);
}

/**
* Have this stream not execute new work until the work recorded in event completes.
* @param event the event to wait on.
Expand Down Expand Up @@ -122,7 +134,9 @@ public synchronized void close() {
cleaner.delRef();
}
if (closed) {
cleaner.logRefCountDebug("double free " + this);
if (cleaner != null) {
cleaner.logRefCountDebug("double free " + this);
}
throw new IllegalStateException("Close called too many times " + this);
}
if (cleaner != null) {
Expand Down
189 changes: 189 additions & 0 deletions java/src/main/java/ai/rapids/cudf/DataSource.java
jlowe marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.HashMap;

/**
* Base class that can be used to provide data dynamically to CUDF. This follows somewhat
* closely with cudf::io::datasource. There are a few main differences.
* <br/>
* First this does not expose async device reads. It will call the non-async device read API
* instead. This might be added in the future, but there was no direct use case for it in java
* right now to warrant the added complexity.
* <br/>
* Second there is no implementation of the device read API that returns a buffer instead of
* writing into one. This is not used by CUDF yet so testing an implementation that isn't used
* didn't feel ideal. If it is needed we will add one in the future.
*/
public abstract class DataSource implements AutoCloseable {
private static final Logger log = LoggerFactory.getLogger(DataSource.class);

/**
* This is used to keep track of the HostMemoryBuffers in java land so the C++ layer
* does not have to do it.
*/
private final HashMap<Long, HostMemoryBuffer> cachedBuffers = new HashMap<>();

@Override
public void close() {
if (!cachedBuffers.isEmpty()) {
throw new IllegalStateException("DataSource closed before all returned host buffers were closed");
}
}

/**
* Get the size of the source in bytes.
*/
public abstract long size();

/**
* Read data from the source at the given offset. Return a HostMemoryBuffer for the data
* that was read.
* @param offset where to start reading from.
* @param amount the maximum number of bytes to read.
* @return a buffer that points to the data.
* @throws IOException on any error.
*/
public abstract HostMemoryBuffer hostRead(long offset, long amount) throws IOException;


/**
* Called when the buffer returned from hostRead is done. The default is to close the buffer.
*/
protected void onHostBufferDone(HostMemoryBuffer buffer) {
if (buffer != null) {
buffer.close();
}
}

/**
* Read data from the source at the given offset into dest. Note that dest should not be closed,
* and no reference to it can outlive the call to hostRead. The target amount to read is
* dest's length.
* @param offset the offset to start reading from in the source.
* @param dest where to write the data.
* @return the actual number of bytes written to dest.
*/
public abstract long hostRead(long offset, HostMemoryBuffer dest) throws IOException;

/**
* Return true if this supports reading directly to the device else false. The default is
* no device support. This cannot change dynamically. It is typically read just once.
*/
public boolean supportsDeviceRead() {
return false;
}

/**
* Get the size cutoff between device reads and host reads when device reads are supported.
* Anything larger than the cutoff will be a device read and anything smaller will be a
* host read. By default, the cutoff is 0 so all reads will be device reads if device reads
* are supported.
*/
public long getDeviceReadCutoff() {
return 0;
}

/**
* Read data from the source at the given offset into dest. Note that dest should not be closed,
* and no reference to it can outlive the call to hostRead. The target amount to read is
* dest's length.
* @param offset the offset to start reading from
* @param dest where to write the data.
* @param stream the stream to do the copy on.
* @return the actual number of bytes written to dest.
*/
public long deviceRead(long offset, DeviceMemoryBuffer dest,
Cuda.Stream stream) throws IOException {
throw new IllegalStateException("Device read is not implemented");
}

/////////////////////////////////////////////////
// Internal methods called from JNI
/////////////////////////////////////////////////

private static class NoopCleaner extends MemoryBuffer.MemoryBufferCleaner {
@Override
protected boolean cleanImpl(boolean logErrorIfNotClean) {
return true;
}

@Override
public boolean isClean() {
return true;
}
}
private static final NoopCleaner cleaner = new NoopCleaner();

// Called from JNI
private void onHostBufferDone(long bufferId) {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
HostMemoryBuffer hmb = cachedBuffers.remove(bufferId);
if (hmb != null) {
onHostBufferDone(hmb);
} else {
// Called from C++ destructor so avoid throwing...
log.warn("Got a close callback for a buffer we could not find " + bufferId);
}
}

// Called from JNI
private long hostRead(long offset, long amount, long dst) throws IOException {
if (amount < 0) {
throw new IllegalArgumentException("Cannot allocate more than " + Long.MAX_VALUE + " bytes");
}
try (HostMemoryBuffer dstBuffer = new HostMemoryBuffer(dst, amount, cleaner)) {
return hostRead(offset, dstBuffer);
}
}

// Called from JNI
private long[] hostReadBuff(long offset, long amount) throws IOException {
if (amount < 0) {
throw new IllegalArgumentException("Cannot read more than " + Long.MAX_VALUE + " bytes");
}
HostMemoryBuffer buff = hostRead(offset, amount);
long[] ret = new long[3];
if (buff != null) {
long id = buff.id;
if (cachedBuffers.put(id, buff) != null) {
throw new IllegalStateException("Already had a buffer cached for " + buff);
}
ret[0] = buff.address;
ret[1] = buff.length;
ret[2] = id;
} // else they are all 0 because java does that already
return ret;
}

// Called from JNI
private long deviceRead(long offset, long amount, long dst, long stream) throws IOException {
if (amount < 0) {
throw new IllegalArgumentException("Cannot read more than " + Long.MAX_VALUE + " bytes");
}
Cuda.Stream strm = Cuda.Stream.wrap(stream);
try (DeviceMemoryBuffer dstBuffer = new DeviceMemoryBuffer(dst, amount, cleaner)) {
return deviceRead(offset, dstBuffer, strm);
}
}
}
44 changes: 44 additions & 0 deletions java/src/main/java/ai/rapids/cudf/DataSourceHelper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

/**
* This is here because we need some JNI methods to work with a DataSource, but
* we also want to cache callback methods at startup for performance reasons. If
* we put both in the same class we will get a deadlock because of how we load
* the JNI. We have a static block that blocks loading the class until the JNI
* library is loaded and the JNI library cannot load until the class is loaded
* and cached. This breaks the loop.
*/
class DataSourceHelper {
static {
NativeDepsLoader.loadNativeDeps();
}

static long createWrapperDataSource(DataSource ds) {
return createWrapperDataSource(ds, ds.size(), ds.supportsDeviceRead(),
ds.getDeviceReadCutoff());
}

private static native long createWrapperDataSource(DataSource ds, long size,
boolean deviceReadSupport,
long deviceReadCutoff);

static native void destroyWrapperDataSource(long handle);
}
6 changes: 5 additions & 1 deletion java/src/main/java/ai/rapids/cudf/DeviceMemoryBuffer.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -112,6 +112,10 @@ public static DeviceMemoryBuffer fromRmm(long address, long lengthInBytes, long
return new DeviceMemoryBuffer(address, lengthInBytes, rmmBufferAddress);
}

DeviceMemoryBuffer(long address, long lengthInBytes, MemoryBufferCleaner cleaner) {
super(address, lengthInBytes, cleaner);
}

DeviceMemoryBuffer(long address, long lengthInBytes, long rmmBufferAddress) {
super(address, lengthInBytes, new RmmDeviceBufferCleaner(rmmBufferAddress));
}
Expand Down
Loading