From d715ac8be048963158af97faa990a4bc496928b6 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Tue, 31 Jan 2023 10:20:46 -0600 Subject: [PATCH] Pay off some JNI RMM API tech debt (#12632) This makes the java APIs for RMM more closely match the C++ APIs. Authors: - Robert (Bobby) Evans (https://github.com/revans2) Approvers: - Jason Lowe (https://github.com/jlowe) URL: https://github.com/rapidsai/cudf/pull/12632 --- java/src/main/java/ai/rapids/cudf/Rmm.java | 339 +++++++++-- .../rapids/cudf/RmmArenaMemoryResource.java | 67 +++ .../cudf/RmmCudaAsyncMemoryResource.java | 59 ++ .../ai/rapids/cudf/RmmCudaMemoryResource.java | 44 ++ .../rapids/cudf/RmmDeviceMemoryResource.java | 31 + .../cudf/RmmEventHandlerResourceAdaptor.java | 76 +++ .../cudf/RmmLimitingResourceAdaptor.java | 59 ++ .../cudf/RmmLoggingResourceAdaptor.java | 58 ++ .../rapids/cudf/RmmManagedMemoryResource.java | 45 ++ .../ai/rapids/cudf/RmmPoolMemoryResource.java | 64 ++ .../cudf/RmmTrackingResourceAdaptor.java | 69 +++ .../cudf/RmmWrappingDeviceMemoryResource.java | 56 ++ java/src/main/native/src/RmmJni.cpp | 562 +++++++++++------- .../src/test/java/ai/rapids/cudf/RmmTest.java | 39 ++ 14 files changed, 1288 insertions(+), 280 deletions(-) create mode 100644 java/src/main/java/ai/rapids/cudf/RmmArenaMemoryResource.java create mode 100644 java/src/main/java/ai/rapids/cudf/RmmCudaAsyncMemoryResource.java create mode 100644 java/src/main/java/ai/rapids/cudf/RmmCudaMemoryResource.java create mode 100644 java/src/main/java/ai/rapids/cudf/RmmDeviceMemoryResource.java create mode 100644 java/src/main/java/ai/rapids/cudf/RmmEventHandlerResourceAdaptor.java create mode 100644 java/src/main/java/ai/rapids/cudf/RmmLimitingResourceAdaptor.java create mode 100644 java/src/main/java/ai/rapids/cudf/RmmLoggingResourceAdaptor.java create mode 100644 java/src/main/java/ai/rapids/cudf/RmmManagedMemoryResource.java create mode 100644 java/src/main/java/ai/rapids/cudf/RmmPoolMemoryResource.java create mode 100644 java/src/main/java/ai/rapids/cudf/RmmTrackingResourceAdaptor.java create mode 100644 java/src/main/java/ai/rapids/cudf/RmmWrappingDeviceMemoryResource.java diff --git a/java/src/main/java/ai/rapids/cudf/Rmm.java b/java/src/main/java/ai/rapids/cudf/Rmm.java index d59311ebca0..66c053f15b2 100755 --- a/java/src/main/java/ai/rapids/cudf/Rmm.java +++ b/java/src/main/java/ai/rapids/cudf/Rmm.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,14 @@ package ai.rapids.cudf; import java.io.File; -import java.util.Arrays; import java.util.concurrent.TimeUnit; /** * This is the binding class for rmm lib. */ public class Rmm { + private static volatile RmmTrackingResourceAdaptor tracker = null; + private static volatile RmmDeviceMemoryResource deviceResource = null; private static volatile boolean initialized = false; private static volatile long poolSize = -1; private static volatile boolean poolingEnabled = false; @@ -30,7 +31,7 @@ public class Rmm { NativeDepsLoader.loadNativeDeps(); } - private enum LogLoc { + enum LogLoc { NONE(0), FILE(1), STDOUT(2), @@ -47,8 +48,8 @@ private enum LogLoc { * What to send RMM alloc and free logs to. */ public static class LogConf { - private final File file; - private final LogLoc loc; + final File file; + final LogLoc loc; private LogConf(File file, LogLoc loc) { this.file = file; @@ -77,6 +78,123 @@ public static LogConf logToStderr() { return new LogConf(null, LogLoc.STDERR); } + + /** + * Get the RmmDeviceMemoryResource that was last set through the java APIs. This will + * not return the correct value if the resource was not set using the java APIs. It will + * return a null if the resource was never set through the java APIs. + */ + public static synchronized RmmDeviceMemoryResource getCurrentDeviceResource() { + return deviceResource; + } + + /** + * Get the currently set RmmTrackingResourceAdaptor that is set. This might return null if + * RMM has nto been initialized. + */ + public static synchronized RmmTrackingResourceAdaptor getTracker() { + return tracker; + } + + /** + * Set the current device resource that RMM should use for all allocations and de-allocations. + * This should only be done if you feel comfortable that the current device resource has no + * pending allocations. Note that the caller of this is responsible for closing the current + * RmmDeviceMemoryResource that is returned by this. Assuming that it was not used to create + * the newResource. Please use the `shutdown` API to clear the resource as it does best + * effort clean up before shutting it down. If `newResource` is not null this will initialize + * the CUDA context for the calling thread if it is not already set. The caller is responsible + * for setting the desired CUDA device prior to this call if a specific device is already set. + *

NOTE: All cudf methods will set the chosen CUDA device in the CUDA context of the calling + * thread after this returns and `newResource` was not null. + *

If `newResource` is null this will unset the default CUDA device and mark RMM as not + * initialized. + *

Be aware that for many of these APIs to work the RmmDeviceMemoryResource will need an + * `RmmTrackingResourceAdaptor`. If one is not found and `newResource` is not null it will + * be added to `newResource`. + *

Also be very careful with how you set this up. It is possible to set up an + * RmmDeviceMemoryResource that is just bad, like multiple pools or pools on top of an + * RmmAsyncMemoryResource, that does pooling already. Unless you know what you are doing it is + * best to just use the `initialize` API instead. + * + * @param newResource the new resource to set. If it is null an RmmCudaMemoryResource will be + * used, and RMM will be set as not initialized. + * @param expectedResource the resource that we expect to be set. This is to let us avoid race + * conditions with multiple things trying to set this at once. It should + * never happen, but just to be careful. + * @param forceChange if true then the expectedResource check is not done. + */ + public static synchronized RmmDeviceMemoryResource setCurrentDeviceResource( + RmmDeviceMemoryResource newResource, + RmmDeviceMemoryResource expectedResource, + boolean forceChange) { + boolean shouldInit = false; + boolean shouldDeinit = false; + RmmDeviceMemoryResource newResourceToSet = newResource; + if (newResourceToSet == null) { + // We always want it to be set to something or else it can cause problems... + newResourceToSet = new RmmCudaMemoryResource(); + if (initialized) { + shouldDeinit = true; + } + } else if (!initialized) { + shouldInit = true; + } + + RmmDeviceMemoryResource oldResource = deviceResource; + if (!forceChange && expectedResource != null && deviceResource != null) { + long expectedOldHandle = expectedResource.getHandle(); + long oldHandle = deviceResource.getHandle(); + if (oldHandle != expectedOldHandle) { + throw new RmmException("The expected device resource is not correct " + + Long.toHexString(oldHandle) + " != " + Long.toHexString(expectedOldHandle)); + } + } + + poolSize = -1; + poolingEnabled = false; + setGlobalValsFromResource(newResourceToSet); + if (newResource != null && tracker == null) { + // No tracker was set, but we need one + tracker = new RmmTrackingResourceAdaptor<>(newResourceToSet, 256); + newResourceToSet = tracker; + } + long newHandle = newResourceToSet.getHandle(); + setCurrentDeviceResourceInternal(newHandle); + deviceResource = newResource; + if (shouldInit) { + initDefaultCudaDevice(); + MemoryCleaner.setDefaultGpu(Cuda.getDevice()); + initialized = true; + } + + if (shouldDeinit) { + cleanupDefaultCudaDevice(); + initialized = false; + } + return oldResource; + } + + private static void setGlobalValsFromResource(RmmDeviceMemoryResource resource) { + if (resource instanceof RmmTrackingResourceAdaptor) { + Rmm.tracker = (RmmTrackingResourceAdaptor) resource; + } else if (resource instanceof RmmPoolMemoryResource) { + Rmm.poolSize = Math.max(((RmmPoolMemoryResource)resource).getMaxSize(), Rmm.poolSize); + Rmm.poolingEnabled = true; + } else if (resource instanceof RmmArenaMemoryResource) { + Rmm.poolSize = Math.max(((RmmArenaMemoryResource)resource).getSize(), Rmm.poolSize); + Rmm.poolingEnabled = true; + } else if (resource instanceof RmmCudaAsyncMemoryResource) { + Rmm.poolSize = Math.max(((RmmCudaAsyncMemoryResource)resource).getSize(), Rmm.poolSize); + Rmm.poolingEnabled = true; + } + + // Recurse as needed + if (resource instanceof RmmWrappingDeviceMemoryResource) { + setGlobalValsFromResource(((RmmWrappingDeviceMemoryResource)resource).getWrapped()); + } + } + /** * Initialize memory manager state and storage. This will always initialize * the CUDA context for the calling thread if it is not already set. The @@ -109,20 +227,43 @@ public static synchronized void initialize(int allocationMode, LogConf logConf, throw new IllegalArgumentException( "CUDA Unified Memory is not supported in CUDA_ASYNC allocation mode"); } - LogLoc loc = LogLoc.NONE; - String path = null; - if (logConf != null) { - if (logConf.file != null) { - path = logConf.file.getAbsolutePath(); + + RmmDeviceMemoryResource resource = null; + boolean succeeded = false; + try { + if (isPool) { + if (isManaged) { + resource = new RmmPoolMemoryResource<>(new RmmManagedMemoryResource(), poolSize, poolSize); + } else { + resource = new RmmPoolMemoryResource<>(new RmmCudaMemoryResource(), poolSize, poolSize); + } + } else if (isArena) { + if (isManaged) { + resource = new RmmArenaMemoryResource<>(new RmmManagedMemoryResource(), poolSize, false); + } else { + resource = new RmmArenaMemoryResource<>(new RmmCudaMemoryResource(), poolSize, false); + } + } else if (isAsync) { + resource = new RmmLimitingResourceAdaptor<>( + new RmmCudaAsyncMemoryResource(poolSize, poolSize), poolSize, 512); + } else if (isManaged) { + resource = new RmmManagedMemoryResource(); + } else { + resource = new RmmCudaMemoryResource(); } - loc = logConf.loc; - } - initializeInternal(allocationMode, loc.internalId, path, poolSize); - MemoryCleaner.setDefaultGpu(Cuda.getDevice()); - initialized = true; - Rmm.poolingEnabled = isPool || isArena || isAsync; - Rmm.poolSize = Rmm.poolingEnabled ? poolSize : -1; + if (logConf != null && logConf.loc != LogLoc.NONE) { + resource = new RmmLoggingResourceAdaptor<>(resource, logConf, true); + } + + resource = new RmmTrackingResourceAdaptor<>(resource, 256); + setCurrentDeviceResource(resource, null, false); + succeeded = true; + } finally { + if (!succeeded && resource != null) { + resource.close(); + } + } } /** @@ -150,16 +291,28 @@ public static boolean isInitialized() throws RmmException { /** * Return the amount of RMM memory allocated in bytes. Note that the result * may be less than the actual amount of allocated memory if underlying RMM - * allocator decides to return more memory than what was requested. However + * allocator decides to return more memory than what was requested. However, * the result will always be a lower bound on the amount allocated. */ - public static native long getTotalBytesAllocated(); + public static synchronized long getTotalBytesAllocated() { + if (tracker == null) { + return 0; + } else { + return tracker.getTotalBytesAllocated(); + } + } /** * Returns the maximum amount of RMM memory (Bytes) outstanding during the * lifetime of the process. */ - public static native long getMaximumTotalBytesAllocated(); + public static synchronized long getMaximumTotalBytesAllocated() { + if (tracker == null) { + return 0; + } else { + return tracker.getMaxTotalBytesAllocated(); + } + } /** * Resets a scoped maximum counter of RMM memory used to keep track of usage between @@ -167,8 +320,10 @@ public static boolean isInitialized() throws RmmException { * * @param initialValue an initial value (in Bytes) to use for this scoped counter */ - public static void resetScopedMaximumBytesAllocated(long initialValue) { - resetScopedMaximumBytesAllocatedInternal(initialValue); + public static synchronized void resetScopedMaximumBytesAllocated(long initialValue) { + if (tracker != null) { + tracker.resetScopedMaxTotalBytesAllocated(initialValue); + } } /** @@ -177,26 +332,32 @@ public static void resetScopedMaximumBytesAllocated(long initialValue) { * * This resets the counter to 0 Bytes. */ - public static void resetScopedMaximumBytesAllocated() { - resetScopedMaximumBytesAllocatedInternal(0L); + public static synchronized void resetScopedMaximumBytesAllocated() { + if (tracker != null) { + tracker.resetScopedMaxTotalBytesAllocated(0L); + } } - private static native void resetScopedMaximumBytesAllocatedInternal(long initialValue); - /** * Returns the maximum amount of RMM memory (Bytes) outstanding since the last * `resetScopedMaximumOutstanding` call was issued (it is "scoped" because it's the * maximum amount seen since the last reset). - * + *

* If the memory used is net negative (for example if only frees happened since * reset, and we reset to 0), then result will be 0. - * + *

* If `resetScopedMaximumBytesAllocated` is never called, the scope is the whole * program and is equivalent to `getMaximumTotalBytesAllocated`. * * @return the scoped maximum bytes allocated */ - public static native long getScopedMaximumBytesAllocated(); + public static synchronized long getScopedMaximumBytesAllocated() { + if (tracker == null) { + return 0L; + } else { + return tracker.getScopedMaxTotalBytesAllocated(); + } + } /** * Sets the event handler to be called on RMM events (e.g.: allocation failure). @@ -210,7 +371,7 @@ public static void setEventHandler(RmmEventHandler handler) throws RmmException /** * Sets the event handler to be called on RMM events (e.g.: allocation failure) and * optionally enable debug mode (callbacks on every allocate and deallocate) - * + *

* NOTE: Only enable debug mode when necessary, as code will run much slower! * * @param handler event handler to invoke on RMM events or null to clear an existing handler @@ -218,29 +379,51 @@ public static void setEventHandler(RmmEventHandler handler) throws RmmException * (onAllocated, onDeallocated) * @throws RmmException if an active handler is already set */ - public static void setEventHandler(RmmEventHandler handler, + public static synchronized void setEventHandler(RmmEventHandler handler, boolean enableDebug) throws RmmException { - long[] allocThresholds = (handler != null) ? sortThresholds(handler.getAllocThresholds()) : null; - long[] deallocThresholds = (handler != null) ? sortThresholds(handler.getDeallocThresholds()) : null; - setEventHandlerInternal(handler, allocThresholds, deallocThresholds, enableDebug); + if (!initialized) { + throw new RmmException("RMM has not been initialized"); + } + if (deviceResource instanceof RmmEventHandlerResourceAdaptor) { + throw new RmmException("Another event handler is already set"); + } + if (tracker == null) { + // This is just to be safe it should always be true if this is initialized. + throw new RmmException("A tracker must be set for the event handler to work"); + } + RmmEventHandlerResourceAdaptor newResource = + new RmmEventHandlerResourceAdaptor<>(deviceResource, tracker, handler, enableDebug); + boolean success = false; + try { + setCurrentDeviceResource(newResource, deviceResource, false); + success = true; + } finally { + if (!success) { + newResource.releaseWrapped(); + } + } } /** Clears the active RMM event handler if one is set. */ - public static void clearEventHandler() throws RmmException { - setEventHandlerInternal(null, null, null, false); - } - - private static long[] sortThresholds(long[] thresholds) { - if (thresholds == null) { - return null; + public static synchronized void clearEventHandler() throws RmmException { + if (deviceResource != null && deviceResource instanceof RmmEventHandlerResourceAdaptor) { + RmmEventHandlerResourceAdaptor orig = + (RmmEventHandlerResourceAdaptor)deviceResource; + boolean success = false; + try { + setCurrentDeviceResource(orig.wrapped, orig, false); + success = true; + } finally { + if (success) { + orig.releaseWrapped(); + } + } } - long[] result = Arrays.copyOf(thresholds, thresholds.length); - Arrays.sort(result); - return result; } - private static native void initializeInternal(int allocationMode, int logTo, String path, - long poolSize) throws RmmException; + public static native void initDefaultCudaDevice(); + + public static native void cleanupDefaultCudaDevice(); /** * Shut down any initialized RMM instance. This should be used very rarely. It does not need to @@ -297,15 +480,12 @@ public static synchronized void shutdown(long forceGCInterval, long maxWaitTime, throw new RmmException("Could not shut down RMM there appear to be outstanding allocations"); } if (initialized) { - shutdownInternal(); - initialized = false; - poolSize = -1; - poolingEnabled = false; + if (deviceResource != null) { + setCurrentDeviceResource(null, deviceResource, true).close(); + } } } - private static native void shutdownInternal() throws RmmException; - /** * Allocate device memory and return a pointer to device memory, using stream 0. * @param size The size in bytes of the allocated memory region @@ -336,10 +516,6 @@ public static DeviceMemoryBuffer alloc(long size, Cuda.Stream stream) { */ static native void freeDeviceBuffer(long rmmBufferAddress) throws RmmException; - static native void setEventHandlerInternal(RmmEventHandler handler, - long[] allocThresholds, long[] deallocThresholds, - boolean enableDebug) throws RmmException; - /** * Allocate device memory using `cudaMalloc` and return a pointer to device memory. * @param size The size in bytes of the allocated memory region @@ -354,4 +530,55 @@ public static CudaMemoryBuffer allocCuda(long size, Cuda.Stream stream) { private static native long allocCudaInternal(long size, long stream) throws RmmException; static native void freeCuda(long ptr, long length, long stream) throws RmmException; + + static native long newCudaMemoryResource() throws RmmException; + + static native void releaseCudaMemoryResource(long handle); + + static native long newManagedMemoryResource() throws RmmException; + + static native void releaseManagedMemoryResource(long handle); + + static native long newPoolMemoryResource(long childHandle, + long initSize, long maxSize) throws RmmException; + + static native void releasePoolMemoryResource(long handle); + + static native long newArenaMemoryResource(long childHandle, + long size, boolean dumpOnOOM) throws RmmException; + + static native void releaseArenaMemoryResource(long handle); + + static native long newCudaAsyncMemoryResource(long size, long release) throws RmmException; + + static native void releaseCudaAsyncMemoryResource(long handle); + + static native long newLimitingResourceAdaptor(long handle, long limit, long align) throws RmmException; + + static native void releaseLimitingResourceAdaptor(long handle); + + static native long newLoggingResourceAdaptor(long handle, int type, String path, + boolean autoFlush) throws RmmException; + + static native void releaseLoggingResourceAdaptor(long handle); + + + static native long newTrackingResourceAdaptor(long handle, long alignment) throws RmmException; + + static native void releaseTrackingResourceAdaptor(long handle); + + static native long nativeGetTotalBytesAllocated(long handle); + + static native long nativeGetMaxTotalBytesAllocated(long handle); + + static native void nativeResetScopedMaxTotalBytesAllocated(long handle, long initValue); + + static native long nativeGetScopedMaxTotalBytesAllocated(long handle); + + static native long newEventHandlerResourceAdaptor(long handle, long trackerHandle, + RmmEventHandler handler, long[] allocThresholds, long[] deallocThresholds, boolean debug); + + static native long releaseEventHandlerResourceAdaptor(long handle, boolean debug); + + private static native void setCurrentDeviceResourceInternal(long newHandle); } diff --git a/java/src/main/java/ai/rapids/cudf/RmmArenaMemoryResource.java b/java/src/main/java/ai/rapids/cudf/RmmArenaMemoryResource.java new file mode 100644 index 00000000000..4638f52a1b1 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RmmArenaMemoryResource.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +/** + * A device memory resource that will pre-allocate a pool of resources and sub-allocate from this + * pool to improve memory performance. This uses an algorithm to try and reduce fragmentation + * much more than the RmmPoolMemoryResource does. + */ +public class RmmArenaMemoryResource + extends RmmWrappingDeviceMemoryResource { + private final long size; + private final boolean dumpLogOnFailure; + private long handle = 0; + + + /** + * Create a new arena memory resource taking ownership of the RmmDeviceMemoryResource that it is + * wrapping. + * @param wrapped the memory resource to use for the pool. This should not be reused. + * @param size the size of the pool + * @param dumpLogOnFailure if true, dump memory log when running out of memory. + */ + public RmmArenaMemoryResource(C wrapped, long size, boolean dumpLogOnFailure) { + super(wrapped); + this.size = size; + this.dumpLogOnFailure = dumpLogOnFailure; + handle = Rmm.newArenaMemoryResource(wrapped.getHandle(), size, dumpLogOnFailure); + } + + @Override + public long getHandle() { + return handle; + } + + public long getSize() { + return size; + } + + @Override + public void close() { + if (handle != 0) { + Rmm.releaseArenaMemoryResource(handle); + handle = 0; + } + super.close(); + } + + @Override + public String toString() { + return Long.toHexString(getHandle()) + "/ARENA(" + wrapped + + ", " + size + ", " + dumpLogOnFailure + ")"; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RmmCudaAsyncMemoryResource.java b/java/src/main/java/ai/rapids/cudf/RmmCudaAsyncMemoryResource.java new file mode 100644 index 00000000000..fa1f13cb7ed --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RmmCudaAsyncMemoryResource.java @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +/** + * A device memory resource that uses `cudaMallocAsync` and `cudaFreeAsync` for allocation and + * deallocation. + */ +public class RmmCudaAsyncMemoryResource implements RmmDeviceMemoryResource { + private final long releaseThreshold; + private final long size; + private long handle = 0; + + /** + * Create a new async memory resource + * @param size the initial size of the pool + * @param releaseThreshold size in bytes for when memory is released back to cuda + */ + public RmmCudaAsyncMemoryResource(long size, long releaseThreshold) { + this.size = size; + this.releaseThreshold = releaseThreshold; + handle = Rmm.newCudaAsyncMemoryResource(size, releaseThreshold); + } + + @Override + public long getHandle() { + return handle; + } + + public long getSize() { + return size; + } + + @Override + public void close() { + if (handle != 0) { + Rmm.releaseCudaAsyncMemoryResource(handle); + handle = 0; + } + } + + @Override + public String toString() { + return Long.toHexString(getHandle()) + "/ASYNC(" + size + ", " + releaseThreshold + ")"; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RmmCudaMemoryResource.java b/java/src/main/java/ai/rapids/cudf/RmmCudaMemoryResource.java new file mode 100644 index 00000000000..f31e9206b91 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RmmCudaMemoryResource.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +/** + * A device memory resource that uses `cudaMalloc` and `cudaFree` for allocation and deallocation. + */ +public class RmmCudaMemoryResource implements RmmDeviceMemoryResource { + private long handle = 0; + + public RmmCudaMemoryResource() { + handle = Rmm.newCudaMemoryResource(); + } + @Override + public long getHandle() { + return handle; + } + + @Override + public void close() { + if (handle != 0) { + Rmm.releaseCudaMemoryResource(handle); + handle = 0; + } + } + + @Override + public String toString() { + return Long.toHexString(getHandle()) + "/CUDA()"; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RmmDeviceMemoryResource.java b/java/src/main/java/ai/rapids/cudf/RmmDeviceMemoryResource.java new file mode 100644 index 00000000000..f44631396df --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RmmDeviceMemoryResource.java @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +/** + * A resource that allocates/deallocates device memory. This is not intended to be something that + * a user will just subclass. This is intended to be a wrapper around a C++ class that RMM will + * use directly. + */ +public interface RmmDeviceMemoryResource extends AutoCloseable { + /** + * Returns a pointer to the underlying C++ class that implements rmm::mr::device_memory_resource + */ + long getHandle(); + + // Remove the exception... + void close(); +} diff --git a/java/src/main/java/ai/rapids/cudf/RmmEventHandlerResourceAdaptor.java b/java/src/main/java/ai/rapids/cudf/RmmEventHandlerResourceAdaptor.java new file mode 100644 index 00000000000..30d1d8a0f6b --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RmmEventHandlerResourceAdaptor.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +import java.util.Arrays; + +/** + * A device memory resource that will give callbacks in specific situations. + */ +public class RmmEventHandlerResourceAdaptor + extends RmmWrappingDeviceMemoryResource { + private long handle = 0; + private final long [] allocThresholds; + private final long [] deallocThresholds; + private final boolean debug; + + /** + * Create a new logging resource adaptor. + * @param wrapped the memory resource to get callbacks for. This should not be reused. + * @param handler the handler that will get the callbacks + * @param tracker the tracking event handler + * @param debug true if you want all the callbacks, else false + */ + public RmmEventHandlerResourceAdaptor(C wrapped, RmmTrackingResourceAdaptor tracker, + RmmEventHandler handler, boolean debug) { + super(wrapped); + this.debug = debug; + allocThresholds = sortThresholds(handler.getAllocThresholds()); + deallocThresholds = sortThresholds(handler.getDeallocThresholds()); + handle = Rmm.newEventHandlerResourceAdaptor(wrapped.getHandle(), tracker.getHandle(), handler, + allocThresholds, deallocThresholds, debug); + } + + private static long[] sortThresholds(long[] thresholds) { + if (thresholds == null) { + return null; + } + long[] result = Arrays.copyOf(thresholds, thresholds.length); + Arrays.sort(result); + return result; + } + + @Override + public long getHandle() { + return handle; + } + + @Override + public void close() { + if (handle != 0) { + Rmm.releaseEventHandlerResourceAdaptor(handle, debug); + handle = 0; + } + super.close(); + } + + @Override + public String toString() { + return Long.toHexString(getHandle()) + "/EVENT(" + wrapped + + ", " + debug + ", " + Arrays.toString(allocThresholds) + ", " + + Arrays.toString(deallocThresholds) + ")"; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RmmLimitingResourceAdaptor.java b/java/src/main/java/ai/rapids/cudf/RmmLimitingResourceAdaptor.java new file mode 100644 index 00000000000..0b0aa6d14a5 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RmmLimitingResourceAdaptor.java @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +/** + * A device memory resource that will limit the maximum amount allocated. + */ +public class RmmLimitingResourceAdaptor + extends RmmWrappingDeviceMemoryResource { + private final long limit; + private final long alignment; + private long handle = 0; + + /** + * Create a new limiting resource adaptor. + * @param wrapped the memory resource to limit. This should not be reused. + * @param limit the allocation limit in bytes + * @param alignment the alignment + */ + public RmmLimitingResourceAdaptor(C wrapped, long limit, long alignment) { + super(wrapped); + this.limit = limit; + this.alignment = alignment; + handle = Rmm.newLimitingResourceAdaptor(wrapped.getHandle(), limit, alignment); + } + + @Override + public long getHandle() { + return handle; + } + + @Override + public void close() { + if (handle != 0) { + Rmm.releaseLimitingResourceAdaptor(handle); + handle = 0; + } + super.close(); + } + + @Override + public String toString() { + return Long.toHexString(getHandle()) + "/LIMIT(" + wrapped + + ", " + limit + ", " + alignment + ")"; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RmmLoggingResourceAdaptor.java b/java/src/main/java/ai/rapids/cudf/RmmLoggingResourceAdaptor.java new file mode 100644 index 00000000000..fe5d7e43b4f --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RmmLoggingResourceAdaptor.java @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +/** + * A device memory resource that will log interactions. + */ +public class RmmLoggingResourceAdaptor + extends RmmWrappingDeviceMemoryResource { + private long handle = 0; + + /** + * Create a new logging resource adaptor. + * @param wrapped the memory resource to log interactions with. This should not be reused. + * @param conf the config of where this should be logged to + * @param autoFlush should the results be flushed after each entry or not. + */ + public RmmLoggingResourceAdaptor(C wrapped, Rmm.LogConf conf, boolean autoFlush) { + super(wrapped); + if (conf.loc == Rmm.LogLoc.NONE) { + throw new RmmException("Cannot initialize RmmLoggingResourceAdaptor with no logging"); + } + handle = Rmm.newLoggingResourceAdaptor(wrapped.getHandle(), conf.loc.internalId, + conf.file == null ? null : conf.file.getAbsolutePath(), autoFlush); + } + + @Override + public long getHandle() { + return handle; + } + + @Override + public void close() { + if (handle != 0) { + Rmm.releaseLoggingResourceAdaptor(handle); + handle = 0; + } + super.close(); + } + + @Override + public String toString() { + return Long.toHexString(getHandle()) + "/LOG(" + wrapped + ")"; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RmmManagedMemoryResource.java b/java/src/main/java/ai/rapids/cudf/RmmManagedMemoryResource.java new file mode 100644 index 00000000000..7a2f41c1d87 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RmmManagedMemoryResource.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +/** + * A device memory resource that uses `cudaMallocManaged` and `cudaFreeManaged` for allocation and + * deallocation. + */ +public class RmmManagedMemoryResource implements RmmDeviceMemoryResource { + private long handle = 0; + + public RmmManagedMemoryResource() { + handle = Rmm.newManagedMemoryResource(); + } + @Override + public long getHandle() { + return handle; + } + + @Override + public void close() { + if (handle != 0) { + Rmm.releaseManagedMemoryResource(handle); + handle = 0; + } + } + + @Override + public String toString() { + return Long.toHexString(getHandle()) + "/CUDA_MANAGED()"; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RmmPoolMemoryResource.java b/java/src/main/java/ai/rapids/cudf/RmmPoolMemoryResource.java new file mode 100644 index 00000000000..7febd680bb3 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RmmPoolMemoryResource.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +/** + * A device memory resource that will pre-allocate a pool of resources and sub-allocate from this + * pool to improve memory performance. + */ +public class RmmPoolMemoryResource + extends RmmWrappingDeviceMemoryResource { + private long handle = 0; + private final long initSize; + private final long maxSize; + + /** + * Create a new pooled memory resource taking ownership of the RmmDeviceMemoryResource that it is + * wrapping. + * @param wrapped the memory resource to use for the pool. This should not be reused. + * @param initSize the size of the initial pool + * @param maxSize the size of the maximum pool + */ + public RmmPoolMemoryResource(C wrapped, long initSize, long maxSize) { + super(wrapped); + this.initSize = initSize; + this.maxSize = maxSize; + handle = Rmm.newPoolMemoryResource(wrapped.getHandle(), initSize, maxSize); + } + + public long getMaxSize() { + return maxSize; + } + + @Override + public long getHandle() { + return handle; + } + + @Override + public void close() { + if (handle != 0) { + Rmm.releasePoolMemoryResource(handle); + handle = 0; + } + super.close(); + } + + @Override + public String toString() { + return Long.toHexString(getHandle()) + "/POOL(" + wrapped + ")"; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RmmTrackingResourceAdaptor.java b/java/src/main/java/ai/rapids/cudf/RmmTrackingResourceAdaptor.java new file mode 100644 index 00000000000..e9f1b08e9f4 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RmmTrackingResourceAdaptor.java @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +/** + * A device memory resource that will track some basic statistics about the memory usage. + */ +public class RmmTrackingResourceAdaptor + extends RmmWrappingDeviceMemoryResource { + private long handle = 0; + + /** + * Create a new tracking resource adaptor. + * @param wrapped the memory resource to track allocations. This should not be reused. + * @param alignment the alignment to apply. + */ + public RmmTrackingResourceAdaptor(C wrapped, long alignment) { + super(wrapped); + handle = Rmm.newTrackingResourceAdaptor(wrapped.getHandle(), alignment); + } + + @Override + public long getHandle() { + return handle; + } + + public long getTotalBytesAllocated() { + return Rmm.nativeGetTotalBytesAllocated(getHandle()); + } + + public long getMaxTotalBytesAllocated() { + return Rmm.nativeGetMaxTotalBytesAllocated(getHandle()); + } + + public void resetScopedMaxTotalBytesAllocated(long initValue) { + Rmm.nativeResetScopedMaxTotalBytesAllocated(getHandle(), initValue); + } + + public long getScopedMaxTotalBytesAllocated() { + return Rmm.nativeGetScopedMaxTotalBytesAllocated(getHandle()); + } + + @Override + public void close() { + if (handle != 0) { + Rmm.releaseTrackingResourceAdaptor(handle); + handle = 0; + } + super.close(); + } + + @Override + public String toString() { + return Long.toHexString(getHandle()) + "/TRACK(" + wrapped + ")"; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RmmWrappingDeviceMemoryResource.java b/java/src/main/java/ai/rapids/cudf/RmmWrappingDeviceMemoryResource.java new file mode 100644 index 00000000000..b764798a3ae --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RmmWrappingDeviceMemoryResource.java @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +/** + * A resource that wraps another RmmDeviceMemoryResource + */ +public abstract class RmmWrappingDeviceMemoryResource + implements RmmDeviceMemoryResource { + protected C wrapped = null; + + public RmmWrappingDeviceMemoryResource(C wrapped) { + this.wrapped = wrapped; + } + + /** + * Get the resource that this is wrapping. Be very careful when using this as the returned value + * should not be added to another resource until it has been released. + * @return the resource that this is wrapping. + */ + public C getWrapped() { + return this.wrapped; + } + + /** + * Release the wrapped device memory resource and close this. + * @return the wrapped DeviceMemoryResource. + */ + public C releaseWrapped() { + C ret = this.wrapped; + this.wrapped = null; + close(); + return ret; + } + + @Override + public void close() { + if (wrapped != null) { + wrapped.close(); + wrapped = null; + } + } +} diff --git a/java/src/main/native/src/RmmJni.cpp b/java/src/main/native/src/RmmJni.cpp index b12f1ed0841..1ce69414c98 100644 --- a/java/src/main/native/src/RmmJni.cpp +++ b/java/src/main/native/src/RmmJni.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,9 +39,6 @@ using rmm::mr::logging_resource_adaptor; namespace { -// Alignment to which the RMM memory resource will round allocation sizes -constexpr std::size_t RMM_ALLOC_SIZE_ALIGNMENT = 256; - constexpr char const *RMM_EXCEPTION_CLASS = "ai/rapids/cudf/RmmException"; /** @@ -99,6 +96,10 @@ class tracking_resource_adaptor final : public base_tracking_resource_adaptor { return scoped_max_total_allocated; } + bool supports_get_mem_info() const noexcept override { return resource->supports_get_mem_info(); } + + bool supports_streams() const noexcept override { return resource->supports_streams(); } + private: Upstream *const resource; std::size_t const size_align; @@ -144,13 +145,9 @@ class tracking_resource_adaptor final : public base_tracking_resource_adaptor { } } - bool supports_get_mem_info() const noexcept override { return resource->supports_get_mem_info(); } - std::pair do_get_mem_info(rmm::cuda_stream_view stream) const override { return resource->get_mem_info(stream); } - - bool supports_streams() const noexcept override { return resource->supports_streams(); } }; template @@ -159,38 +156,6 @@ tracking_resource_adaptor *make_tracking_adaptor(Upstream *upstream, return new tracking_resource_adaptor{upstream, size_alignment}; } -std::unique_ptr Tracking_memory_resource{}; - -/** - * @brief Return the total amount of device memory allocated via RMM - */ -std::size_t get_total_bytes_allocated() { - if (Tracking_memory_resource) { - return Tracking_memory_resource->get_total_allocated(); - } - return 0; -} - -std::size_t get_max_total_allocated() { - if (Tracking_memory_resource) { - return Tracking_memory_resource->get_max_total_allocated(); - } - return 0; -} - -void reset_scoped_max_total_allocated(std::size_t initial_value) { - if (Tracking_memory_resource) { - return Tracking_memory_resource->reset_scoped_max_total_allocated(initial_value); - } -} - -std::size_t get_scoped_max_total_allocated() { - if (Tracking_memory_resource) { - return Tracking_memory_resource->get_scoped_max_total_allocated(); - } - return 0; -} - /** * @brief An RMM device memory resource adaptor that delegates to the wrapped resource * for most operations but will call Java to handle certain situations (e.g.: allocation failure). @@ -199,8 +164,9 @@ class java_event_handler_memory_resource : public device_memory_resource { public: java_event_handler_memory_resource(JNIEnv *env, jobject jhandler, jlongArray jalloc_thresholds, jlongArray jdealloc_thresholds, - device_memory_resource *resource_to_wrap) - : resource(resource_to_wrap) { + device_memory_resource *resource_to_wrap, + base_tracking_resource_adaptor *tracker) + : resource(resource_to_wrap), tracker(tracker) { if (env->GetJavaVM(&jvm) < 0) { throw std::runtime_error("GetJavaVM failed"); } @@ -250,8 +216,13 @@ class java_event_handler_memory_resource : public device_memory_resource { device_memory_resource *get_wrapped_resource() { return resource; } + bool supports_get_mem_info() const noexcept override { return resource->supports_get_mem_info(); } + + bool supports_streams() const noexcept override { return resource->supports_streams(); } + private: device_memory_resource *const resource; + base_tracking_resource_adaptor *const tracker; jmethodID on_alloc_fail_method; bool use_old_alloc_fail_interface; jmethodID on_alloc_threshold_method; @@ -309,14 +280,10 @@ class java_event_handler_memory_resource : public device_memory_resource { } } - bool supports_get_mem_info() const noexcept override { return resource->supports_get_mem_info(); } - std::pair do_get_mem_info(rmm::cuda_stream_view stream) const override { return resource->get_mem_info(stream); } - bool supports_streams() const noexcept override { return resource->supports_streams(); } - protected: JavaVM *jvm; jobject handler_obj; @@ -330,7 +297,7 @@ class java_event_handler_memory_resource : public device_memory_resource { int retry_count = 0; while (true) { try { - total_before = get_total_bytes_allocated(); + total_before = tracker->get_total_allocated(); result = resource->allocate(num_bytes, stream); break; } catch (rmm::out_of_memory const &e) { @@ -339,7 +306,7 @@ class java_event_handler_memory_resource : public device_memory_resource { } } } - auto total_after = get_total_bytes_allocated(); + auto total_after = tracker->get_total_allocated(); try { check_for_threshold_callback(total_before, total_after, alloc_thresholds, @@ -354,9 +321,9 @@ class java_event_handler_memory_resource : public device_memory_resource { } void do_deallocate(void *p, std::size_t size, rmm::cuda_stream_view stream) override { - auto total_before = get_total_bytes_allocated(); + auto total_before = tracker->get_total_allocated(); resource->deallocate(p, size, stream); - auto total_after = get_total_bytes_allocated(); + auto total_after = tracker->get_total_allocated(); check_for_threshold_callback(total_after, total_before, dealloc_thresholds, on_dealloc_threshold_method, "onDeallocThreshold", total_after); } @@ -367,9 +334,10 @@ class java_debug_event_handler_memory_resource final : public java_event_handler java_debug_event_handler_memory_resource(JNIEnv *env, jobject jhandler, jlongArray jalloc_thresholds, jlongArray jdealloc_thresholds, - device_memory_resource *resource_to_wrap) + device_memory_resource *resource_to_wrap, + base_tracking_resource_adaptor *tracker) : java_event_handler_memory_resource(env, jhandler, jalloc_thresholds, jdealloc_thresholds, - resource_to_wrap) { + resource_to_wrap, tracker) { jclass cls = env->GetObjectClass(jhandler); if (cls == nullptr) { throw cudf::jni::jni_exception("class not found"); @@ -415,241 +383,387 @@ class java_debug_event_handler_memory_resource final : public java_event_handler } }; -std::unique_ptr Java_memory_resource{}; - -void set_java_device_memory_resource(JNIEnv *env, jobject handler_obj, jlongArray jalloc_thresholds, - jlongArray jdealloc_thresholds, jboolean enable_debug) { - if (Java_memory_resource && handler_obj != nullptr) { - JNI_THROW_NEW(env, RMM_EXCEPTION_CLASS, "Another event handler is already set", ) - } - if (Java_memory_resource) { - auto java_resource = Java_memory_resource.get(); - auto old_resource = - rmm::mr::set_current_device_resource(Java_memory_resource->get_wrapped_resource()); - Java_memory_resource.reset(nullptr); - if (old_resource != java_resource) { - rmm::mr::set_current_device_resource(old_resource); - JNI_THROW_NEW(env, RMM_EXCEPTION_CLASS, - "Concurrent modification detected while removing memory resource", ); - } +} // anonymous namespace + +extern "C" { + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_initDefaultCudaDevice(JNIEnv *env, jclass clazz) { + // make sure the CUDA device is setup in the context + cudaError_t cuda_status = cudaFree(0); + cudf::jni::jni_cuda_check(env, cuda_status); + int device_id; + cuda_status = cudaGetDevice(&device_id); + cudf::jni::jni_cuda_check(env, cuda_status); + // Now that RMM has successfully initialized, setup all threads calling + // cudf to use the same device RMM is using. + cudf::jni::set_cudf_device(device_id); +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_cleanupDefaultCudaDevice(JNIEnv *env, jclass clazz) { + cudf::jni::set_cudf_device(cudaInvalidDeviceId); +} + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_allocInternal(JNIEnv *env, jclass clazz, jlong size, + jlong stream) { + try { + cudf::jni::auto_set_device(env); + rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource(); + auto c_stream = rmm::cuda_stream_view(reinterpret_cast(stream)); + void *ret = mr->allocate(size, c_stream); + return reinterpret_cast(ret); } - if (handler_obj != nullptr) { - auto resource = rmm::mr::get_current_device_resource(); - if (enable_debug) { - Java_memory_resource.reset(new java_debug_event_handler_memory_resource( - env, handler_obj, jalloc_thresholds, jdealloc_thresholds, resource)); - } else { - Java_memory_resource.reset(new java_event_handler_memory_resource( - env, handler_obj, jalloc_thresholds, jdealloc_thresholds, resource)); - } - auto replaced_resource = rmm::mr::set_current_device_resource(Java_memory_resource.get()); - if (resource != replaced_resource) { - rmm::mr::set_current_device_resource(replaced_resource); - Java_memory_resource.reset(nullptr); - JNI_THROW_NEW(env, RMM_EXCEPTION_CLASS, - "Concurrent modification detected while installing memory resource", ); - } + CATCH_STD(env, 0) +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_free(JNIEnv *env, jclass clazz, jlong ptr, + jlong size, jlong stream) { + try { + cudf::jni::auto_set_device(env); + rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource(); + void *cptr = reinterpret_cast(ptr); + auto c_stream = rmm::cuda_stream_view(reinterpret_cast(stream)); + mr->deallocate(cptr, size, c_stream); } + CATCH_STD(env, ) } -// Need to keep both separate so we can shut them down appropriately -std::unique_ptr> Logging_memory_resource{}; -std::shared_ptr Initialized_resource{}; -std::unique_ptr Cuda_memory_resource{}; -} // anonymous namespace +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_freeDeviceBuffer(JNIEnv *env, jclass clazz, + jlong ptr) { + try { + cudf::jni::auto_set_device(env); + rmm::device_buffer *cptr = reinterpret_cast(ptr); + delete cptr; + } + CATCH_STD(env, ); +} -extern "C" { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_allocCudaInternal(JNIEnv *env, jclass clazz, + jlong size, jlong stream) { + try { + cudf::jni::auto_set_device(env); + void *ptr{nullptr}; + RMM_CUDA_TRY_ALLOC(cudaMalloc(&ptr, size)); + return reinterpret_cast(ptr); + } + CATCH_STD(env, 0) +} -JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_initializeInternal(JNIEnv *env, jclass clazz, - jint allocation_mode, jint log_to, - jstring jpath, jlong pool_size) { +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_freeCuda(JNIEnv *env, jclass clazz, jlong ptr, + jlong size, jlong stream) { try { - // make sure the CUDA device is setup in the context - cudaError_t cuda_status = cudaFree(0); - cudf::jni::jni_cuda_check(env, cuda_status); - int device_id; - cuda_status = cudaGetDevice(&device_id); - cudf::jni::jni_cuda_check(env, cuda_status); - - bool use_pool_alloc = allocation_mode & 1; - bool use_managed_mem = allocation_mode & 2; - bool use_arena_alloc = allocation_mode & 4; - bool use_cuda_async_alloc = allocation_mode & 8; - if (use_pool_alloc) { - if (use_managed_mem) { - Initialized_resource = rmm::mr::make_owning_wrapper( - std::make_shared(), pool_size, pool_size); - } else { - Initialized_resource = rmm::mr::make_owning_wrapper( - std::make_shared(), pool_size, pool_size); - } - } else if (use_arena_alloc) { - if (use_managed_mem) { - Initialized_resource = rmm::mr::make_owning_wrapper( - std::make_shared(), pool_size); - } else { - Initialized_resource = rmm::mr::make_owning_wrapper( - std::make_shared(), pool_size); - } - } else if (use_cuda_async_alloc) { - // Use `limiting_resource_adaptor` to set a hard limit on the max pool size since - // `cuda_async_memory_resource` only has a release threshold. - auto const alignment = 512; // Async allocator aligns to 512. - Initialized_resource = rmm::mr::make_owning_wrapper( - std::make_shared(pool_size, pool_size), pool_size, - alignment); - } else if (use_managed_mem) { - Initialized_resource = std::make_shared(); - } else { - Initialized_resource = std::make_shared(); - } + cudf::jni::auto_set_device(env); + void *cptr = reinterpret_cast(ptr); + RMM_ASSERT_CUDA_SUCCESS(cudaFree(cptr)); + } + CATCH_STD(env, ) +} - auto wrapped = make_tracking_adaptor(Initialized_resource.get(), RMM_ALLOC_SIZE_ALIGNMENT); - Tracking_memory_resource.reset(wrapped); +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newCudaMemoryResource(JNIEnv *env, jclass clazz) { + try { + cudf::jni::auto_set_device(env); + auto ret = new rmm::mr::cuda_memory_resource(); + return reinterpret_cast(ret); + } + CATCH_STD(env, 0) +} - auto resource = Tracking_memory_resource.get(); - rmm::mr::set_current_device_resource(resource); +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseCudaMemoryResource(JNIEnv *env, jclass clazz, + jlong ptr) { + try { + cudf::jni::auto_set_device(env); + auto mr = reinterpret_cast(ptr); + delete mr; + } + CATCH_STD(env, ) +} - std::unique_ptr> log_result; - switch (log_to) { - case 1: // File - { - cudf::jni::native_jstring path(env, jpath); - log_result.reset(new logging_resource_adaptor( - resource, path.get(), /*auto_flush=*/true)); - } break; - case 2: // stdout - log_result.reset(new logging_resource_adaptor( - resource, std::cout, /*auto_flush=*/true)); - break; - case 3: // stderr - log_result.reset(new logging_resource_adaptor( - resource, std::cerr, /*auto_flush=*/true)); - break; - } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newManagedMemoryResource(JNIEnv *env, + jclass clazz) { + try { + cudf::jni::auto_set_device(env); + auto ret = new rmm::mr::managed_memory_resource(); + return reinterpret_cast(ret); + } + CATCH_STD(env, 0) +} - if (log_result) { - if (Logging_memory_resource) { - JNI_THROW_NEW(env, RMM_EXCEPTION_CLASS, "Internal Error logging is double enabled", ) - } +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseManagedMemoryResource(JNIEnv *env, + jclass clazz, + jlong ptr) { + try { + cudf::jni::auto_set_device(env); + auto mr = reinterpret_cast(ptr); + delete mr; + } + CATCH_STD(env, ) +} - Logging_memory_resource = std::move(log_result); - auto replaced_resource = rmm::mr::set_current_device_resource(Logging_memory_resource.get()); - if (resource != replaced_resource) { - rmm::mr::set_current_device_resource(replaced_resource); - Logging_memory_resource.reset(nullptr); - JNI_THROW_NEW(env, RMM_EXCEPTION_CLASS, - "Concurrent modification detected while installing memory resource", ); - } - } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newPoolMemoryResource(JNIEnv *env, jclass clazz, + jlong child, jlong init, + jlong max) { + JNI_NULL_CHECK(env, child, "child is null", 0); + try { + cudf::jni::auto_set_device(env); + auto wrapped = reinterpret_cast(child); + auto ret = + new rmm::mr::pool_memory_resource(wrapped, init, max); + return reinterpret_cast(ret); + } + CATCH_STD(env, 0) +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releasePoolMemoryResource(JNIEnv *env, jclass clazz, + jlong ptr) { + try { + cudf::jni::auto_set_device(env); + auto mr = + reinterpret_cast *>(ptr); + delete mr; + } + CATCH_STD(env, ) +} - Cuda_memory_resource = std::make_unique(); +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newArenaMemoryResource(JNIEnv *env, jclass clazz, + jlong child, jlong init, + jboolean dump_on_oom) { + JNI_NULL_CHECK(env, child, "child is null", 0); + try { + cudf::jni::auto_set_device(env); + auto wrapped = reinterpret_cast(child); + auto ret = new rmm::mr::arena_memory_resource(wrapped, init, + dump_on_oom); + return reinterpret_cast(ret); + } + CATCH_STD(env, 0) +} - // Now that RMM has successfully initialized, setup all threads calling - // cudf to use the same device RMM is using. - cudf::jni::set_cudf_device(device_id); +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseArenaMemoryResource(JNIEnv *env, jclass clazz, + jlong ptr) { + try { + cudf::jni::auto_set_device(env); + auto mr = + reinterpret_cast *>(ptr); + delete mr; } CATCH_STD(env, ) } -JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_shutdownInternal(JNIEnv *env, jclass clazz) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newCudaAsyncMemoryResource(JNIEnv *env, + jclass clazz, + jlong child, jlong init, + jlong release) { + JNI_NULL_CHECK(env, child, "child is null", 0); try { cudf::jni::auto_set_device(env); - set_java_device_memory_resource(env, nullptr, nullptr, nullptr, false); - // Instead of trying to undo all of the adaptors that we added in reverse order - // we just reset the base adaptor so the others will not be called any more - // and then clean them up in really any order. There should be no interaction with - // RMM during this time anyways. - Initialized_resource = std::make_shared(); - rmm::mr::set_current_device_resource(Initialized_resource.get()); - Logging_memory_resource.reset(nullptr); - Tracking_memory_resource.reset(nullptr); - Cuda_memory_resource.reset(nullptr); - cudf::jni::set_cudf_device(cudaInvalidDeviceId); + auto ret = new rmm::mr::cuda_async_memory_resource(init, release); + return reinterpret_cast(ret); + } + CATCH_STD(env, 0) +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseCudaAsyncMemoryResource(JNIEnv *env, + jclass clazz, + jlong ptr) { + try { + cudf::jni::auto_set_device(env); + auto mr = reinterpret_cast(ptr); + delete mr; } CATCH_STD(env, ) } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_getTotalBytesAllocated(JNIEnv *env, jclass) { - return get_total_bytes_allocated(); +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newLimitingResourceAdaptor(JNIEnv *env, + jclass clazz, + jlong child, jlong limit, + jlong align) { + JNI_NULL_CHECK(env, child, "child is null", 0); + try { + cudf::jni::auto_set_device(env); + auto wrapped = reinterpret_cast(child); + auto ret = new rmm::mr::limiting_resource_adaptor( + wrapped, limit, align); + return reinterpret_cast(ret); + } + CATCH_STD(env, 0) } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_getMaximumTotalBytesAllocated(JNIEnv *env, jclass) { - return get_max_total_allocated(); +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseLimitingResourceAdaptor(JNIEnv *env, + jclass clazz, + jlong ptr) { + try { + cudf::jni::auto_set_device(env); + auto mr = + reinterpret_cast *>( + ptr); + delete mr; + } + CATCH_STD(env, ) } -JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_resetScopedMaximumBytesAllocatedInternal( - JNIEnv *env, jclass, jlong initialValue) { - reset_scoped_max_total_allocated(initialValue); +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newLoggingResourceAdaptor(JNIEnv *env, jclass clazz, + jlong child, jint type, + jstring jpath, + jboolean auto_flush) { + JNI_NULL_CHECK(env, child, "child is null", 0); + try { + cudf::jni::auto_set_device(env); + auto wrapped = reinterpret_cast(child); + switch (type) { + case 1: // File + { + cudf::jni::native_jstring path(env, jpath); + auto ret = new logging_resource_adaptor( + wrapped, path.get(), auto_flush); + return reinterpret_cast(ret); + } + case 2: // stdout + { + auto ret = new logging_resource_adaptor(wrapped, std::cout, + auto_flush); + return reinterpret_cast(ret); + } + case 3: // stderr + { + auto ret = new logging_resource_adaptor(wrapped, std::cerr, + auto_flush); + return reinterpret_cast(ret); + } + default: throw std::logic_error("unsupported logging location type"); + } + } + CATCH_STD(env, 0) } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_getScopedMaximumBytesAllocated(JNIEnv *env, - jclass) { - return get_scoped_max_total_allocated(); +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseLoggingResourceAdaptor(JNIEnv *env, + jclass clazz, + jlong ptr) { + try { + cudf::jni::auto_set_device(env); + auto mr = + reinterpret_cast *>(ptr); + delete mr; + } + CATCH_STD(env, ) } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_allocInternal(JNIEnv *env, jclass clazz, jlong size, - jlong stream) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newTrackingResourceAdaptor(JNIEnv *env, + jclass clazz, + jlong child, + jlong align) { + JNI_NULL_CHECK(env, child, "child is null", 0); try { cudf::jni::auto_set_device(env); - rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource(); - auto c_stream = rmm::cuda_stream_view(reinterpret_cast(stream)); - void *ret = mr->allocate(size, c_stream); + auto wrapped = reinterpret_cast(child); + auto ret = new tracking_resource_adaptor(wrapped, align); return reinterpret_cast(ret); } CATCH_STD(env, 0) } -JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_free(JNIEnv *env, jclass clazz, jlong ptr, - jlong size, jlong stream) { +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseTrackingResourceAdaptor(JNIEnv *env, + jclass clazz, + jlong ptr) { try { cudf::jni::auto_set_device(env); - rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource(); - void *cptr = reinterpret_cast(ptr); - auto c_stream = rmm::cuda_stream_view(reinterpret_cast(stream)); - mr->deallocate(cptr, size, c_stream); + auto mr = reinterpret_cast *>(ptr); + delete mr; } CATCH_STD(env, ) } -JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_freeDeviceBuffer(JNIEnv *env, jclass clazz, - jlong ptr) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_nativeGetTotalBytesAllocated(JNIEnv *env, + jclass clazz, + jlong ptr) { + JNI_NULL_CHECK(env, ptr, "adaptor is null", 0); try { cudf::jni::auto_set_device(env); - rmm::device_buffer *cptr = reinterpret_cast(ptr); - delete cptr; + auto mr = reinterpret_cast *>(ptr); + return mr->get_total_allocated(); } - CATCH_STD(env, ); + CATCH_STD(env, 0) } -JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_setEventHandlerInternal( - JNIEnv *env, jclass, jobject handler_obj, jlongArray jalloc_thresholds, - jlongArray jdealloc_thresholds, jboolean enable_debug) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_nativeGetMaxTotalBytesAllocated(JNIEnv *env, + jclass clazz, + jlong ptr) { + JNI_NULL_CHECK(env, ptr, "adaptor is null", 0); try { - set_java_device_memory_resource(env, handler_obj, jalloc_thresholds, jdealloc_thresholds, - enable_debug); + cudf::jni::auto_set_device(env); + auto mr = reinterpret_cast *>(ptr); + return mr->get_max_total_allocated(); + } + CATCH_STD(env, 0) +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_nativeResetScopedMaxTotalBytesAllocated(JNIEnv *env, + jclass clazz, + jlong ptr, + jlong init) { + JNI_NULL_CHECK(env, ptr, "adaptor is null", ); + try { + cudf::jni::auto_set_device(env); + auto mr = reinterpret_cast *>(ptr); + mr->reset_scoped_max_total_allocated(init); } CATCH_STD(env, ) } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_allocCudaInternal(JNIEnv *env, jclass clazz, - jlong size, jlong stream) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_nativeGetScopedMaxTotalBytesAllocated(JNIEnv *env, + jclass clazz, + jlong ptr) { + JNI_NULL_CHECK(env, ptr, "adaptor is null", 0); try { cudf::jni::auto_set_device(env); - auto c_stream = rmm::cuda_stream_view(reinterpret_cast(stream)); - void *ret = Cuda_memory_resource->allocate(size, c_stream); - return reinterpret_cast(ret); + auto mr = reinterpret_cast *>(ptr); + return mr->get_scoped_max_total_allocated(); } CATCH_STD(env, 0) } -JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_freeCuda(JNIEnv *env, jclass clazz, jlong ptr, - jlong size, jlong stream) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newEventHandlerResourceAdaptor( + JNIEnv *env, jclass, jlong child, jlong tracker, jobject handler_obj, + jlongArray jalloc_thresholds, jlongArray jdealloc_thresholds, jboolean enable_debug) { + JNI_NULL_CHECK(env, child, "child is null", 0); + JNI_NULL_CHECK(env, tracker, "tracker is null", 0); + try { + auto wrapped = reinterpret_cast(child); + auto t = + reinterpret_cast *>(tracker); + if (enable_debug) { + auto ret = new java_debug_event_handler_memory_resource(env, handler_obj, jalloc_thresholds, + jdealloc_thresholds, wrapped, t); + return reinterpret_cast(ret); + } else { + auto ret = new java_event_handler_memory_resource(env, handler_obj, jalloc_thresholds, + jdealloc_thresholds, wrapped, t); + return reinterpret_cast(ret); + } + } + CATCH_STD(env, 0) +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseEventHandlerResourceAdaptor( + JNIEnv *env, jclass clazz, jlong ptr, jboolean enable_debug) { try { cudf::jni::auto_set_device(env); - void *cptr = reinterpret_cast(ptr); - auto c_stream = rmm::cuda_stream_view(reinterpret_cast(stream)); - Cuda_memory_resource->deallocate(cptr, size, c_stream); + if (enable_debug) { + auto mr = reinterpret_cast(ptr); + delete mr; + } else { + auto mr = reinterpret_cast(ptr); + delete mr; + } + } + CATCH_STD(env, ) +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_setCurrentDeviceResourceInternal(JNIEnv *env, + jclass clazz, + jlong new_handle) { + try { + cudf::jni::auto_set_device(env); + auto mr = reinterpret_cast(new_handle); + rmm::mr::set_current_device_resource(mr); } CATCH_STD(env, ) } diff --git a/java/src/test/java/ai/rapids/cudf/RmmTest.java b/java/src/test/java/ai/rapids/cudf/RmmTest.java index c081f51c9f2..352f17e6174 100644 --- a/java/src/test/java/ai/rapids/cudf/RmmTest.java +++ b/java/src/test/java/ai/rapids/cudf/RmmTest.java @@ -51,6 +51,45 @@ public void teardown() { } } + @Test + public void testCreateAdaptors() { + final long poolSize = 32 * 1024 * 1024; // 32 MiB + try (RmmCudaMemoryResource r = new RmmCudaMemoryResource()) { + assert(r.getHandle() != 0); + } + try (RmmCudaAsyncMemoryResource r = new RmmCudaAsyncMemoryResource(poolSize, poolSize)) { + assert(r.getHandle() != 0); + } + try (RmmManagedMemoryResource r = new RmmManagedMemoryResource()) { + assert(r.getHandle() != 0); + } + try (RmmArenaMemoryResource r = + new RmmArenaMemoryResource<>(new RmmCudaMemoryResource(), poolSize, false)) { + assert(r.getHandle() != 0); + } + try (RmmPoolMemoryResource r = + new RmmPoolMemoryResource<>(new RmmCudaMemoryResource(), poolSize, poolSize)) { + assert(r.getHandle() != 0); + } + try (RmmLimitingResourceAdaptor r = + new RmmLimitingResourceAdaptor<>(new RmmCudaMemoryResource(), poolSize, 64)) { + assert(r.getHandle() != 0); + } + try (RmmLoggingResourceAdaptor r = + new RmmLoggingResourceAdaptor<>(new RmmCudaMemoryResource(), Rmm.logToStderr(), true)) { + assert(r.getHandle() != 0); + } + try (RmmTrackingResourceAdaptor r = + new RmmTrackingResourceAdaptor<>(new RmmCudaMemoryResource(), 64)) { + assert(r.getHandle() != 0); + assert(r.getTotalBytesAllocated() == 0); + assert(r.getMaxTotalBytesAllocated() == 0); + assert(r.getScopedMaxTotalBytesAllocated() == 0); + r.resetScopedMaxTotalBytesAllocated(1024); + assert(r.getScopedMaxTotalBytesAllocated() == 1024); + } + } + @ParameterizedTest @ValueSource(ints = { RmmAllocationMode.CUDA_DEFAULT,