Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gpu memory watermark apis to JNI #11950

Merged
merged 15 commits into from
Oct 24, 2022
Merged
86 changes: 86 additions & 0 deletions java/src/main/java/ai/rapids/cudf/GpuMemoryTracker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.rapids.cudf;

import java.util.Optional;

/**
 * This is a helper class to track the maximum amount of GPU memory outstanding
 * for the current thread (stream in PTDS). If a free occurs while tracking, and the
 * free is for memory that wasn't created in the scope, or it was created in a different
 * thread, it will be ignored.
 *
 * The constructor enables a new memory tracking scope and .close stops tracking, and collects
 * the result.
 *
 * If `ai.rapids.cudf.gpuMemoryTracking.enabled` is false (default), the result of
 * `getMaxOutstanding` is an empty java Optional&lt;Long&gt;.
 *
 * Usage:
 *
 * <pre>
 * try (GpuMemoryTracker a = new GpuMemoryTracker()) {
 *   ...
 *   try (GpuMemoryTracker b = new GpuMemoryTracker()) {
 *     ...
 *     // bMaxMemory is the maximum memory used while b is not closed
 *     Optional&lt;Long&gt; bMaxMemory = b.getMaxOutstanding();
 *   }
 *   ...
 *
 *   // aMaxMemory is the maximum memory used while a is not closed
 *   // which includes bMaxMemory.
 *   Optional&lt;Long&gt; aMaxMemory = a.getMaxOutstanding();
 * }
 * </pre>
 *
 * Instances should be associated with a single thread and should be at a fine
 * granularity. Tracking memory when there could be a free of buffers created in different
 * streams will have undesired results.
 */
public class GpuMemoryTracker implements AutoCloseable {
  // Tracking is opt-in via a system property; when disabled every method is a
  // no-op and getMaxOutstanding() reports an empty Optional.
  private static final boolean isEnabled =
      Boolean.getBoolean("ai.rapids.cudf.gpuMemoryTracking.enabled");

  // High watermark captured from the native layer when this scope is closed;
  // remains 0 until close() is called.
  private long maxOutstanding;

  static {
    if (isEnabled) {
      NativeDepsLoader.loadNativeDeps();
    }
  }

  /**
   * Opens a new native per-thread memory tracking scope.
   * No-op when tracking is disabled.
   */
  public GpuMemoryTracker() {
    if (isEnabled) {
      Rmm.pushThreadMemoryTracker();
    }
  }

  /**
   * Closes this tracking scope, capturing the maximum outstanding GPU memory
   * observed while it was open. No-op when tracking is disabled.
   */
  @Override
  public void close() {
    if (isEnabled) {
      maxOutstanding = Rmm.popThreadMemoryTracker();
    }
  }

  /**
   * @return the maximum outstanding GPU memory observed within this scope, or
   *         an empty Optional when tracking is disabled. Only meaningful after
   *         close() has been called.
   */
  public Optional<Long> getMaxOutstanding() {
    if (isEnabled) {
      return Optional.of(maxOutstanding);
    } else {
      return Optional.empty();
    }
  }
}
2 changes: 2 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Rmm.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ public static boolean isInitialized() throws RmmException {
* the result will always be a lower bound on the amount allocated.
*/
public static native long getTotalBytesAllocated();
/**
 * Pushes a new per-thread (per-stream in PTDS) GPU memory tracking scope onto
 * the native tracking stack. Must be balanced by a later call to
 * {@link #popThreadMemoryTracker()} on the same thread.
 */
public static native void pushThreadMemoryTracker();
/**
 * Pops the calling thread's innermost GPU memory tracking scope.
 * @return the maximum outstanding bytes observed while the scope was active
 *         (0 if the RMM tracking resource is not initialized).
 */
public static native long popThreadMemoryTracker();

/**
* Sets the event handler to be called on RMM events (e.g.: allocation failure).
Expand Down
69 changes: 69 additions & 0 deletions java/src/main/native/src/RmmJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <fstream>
#include <iostream>
#include <limits>
#include <stack>
#include <unordered_map>

#include <rmm/mr/device/aligned_resource_adaptor.hpp>
#include <rmm/mr/device/arena_memory_resource.hpp>
Expand Down Expand Up @@ -50,8 +51,18 @@ constexpr char const *RMM_EXCEPTION_CLASS = "ai/rapids/cudf/RmmException";
/**
 * @brief Type-erased interface to the tracking resource adaptor, so callers
 * that do not know the upstream resource type can query tracking state.
 */
class base_tracking_resource_adaptor : public device_memory_resource {
public:
  /** @brief Total bytes currently allocated through this resource. */
  virtual std::size_t get_total_allocated() = 0;
  /** @brief Open a new memory tracking scope for the calling thread. */
  virtual void push_thread_memory_tracker() = 0;
  /**
   * @brief Close the calling thread's innermost tracking scope.
   * @return the maximum outstanding bytes observed while the scope was active.
   */
  virtual long pop_thread_memory_tracker() = 0;
};

/**
 * @brief Per-scope bookkeeping for the thread-local memory watermark tracker.
 *
 * current_outstanding is net bytes allocated minus freed while the scope is
 * active; max_outstanding is the high watermark of current_outstanding.
 * Members carry explicit zero initializers so a tracker is valid however it is
 * constructed, not only via std::stack::emplace's value-initialization.
 */
struct memory_tracker {
  long current_outstanding{0};
  long max_outstanding{0};
};

// Stack of nested tracking scopes for the current thread (stream in PTDS).
thread_local std::stack<memory_tracker> memory_tracker_stack;
// Sizes of allocations recorded on this thread while at least one scope was
// active, keyed by allocation address; used to credit frees back to the scope.
thread_local std::unordered_map<long, std::size_t> alloc_map;

/**
* @brief An RMM device memory resource that delegates to another resource
* while tracking the amount of memory allocated.
Expand Down Expand Up @@ -79,18 +90,54 @@ class tracking_resource_adaptor final : public base_tracking_resource_adaptor {

// Snapshot of the total bytes currently allocated through this adaptor (all threads).
std::size_t get_total_allocated() override { return total_allocated.load(); }

// Open a new (possibly nested) tracking scope for the calling thread.
void push_thread_memory_tracker() override { memory_tracker_stack.emplace(); }

// Close the calling thread's innermost tracking scope and return its high watermark.
// Precondition: a matching push_thread_memory_tracker() happened on this thread —
// top()/pop() on an empty std::stack is undefined behavior.
long pop_thread_memory_tracker() override {
  auto top_tracker = memory_tracker_stack.top(); // copy is fine here: we pop it next
  auto ret = top_tracker.max_outstanding;
  memory_tracker_stack.pop();
  if (memory_tracker_stack.empty()) {
    // Outermost scope closed: drop this thread's allocation bookkeeping.
    alloc_map.clear();
  } else {
    // carry the max to the next level
    // NOTE(review): this adds the child's watermark onto the parent's watermark,
    // which can overestimate the parent's true peak if the child's memory was
    // freed before the parent reached its own peak — confirm the accumulation
    // semantics are intended.
    memory_tracker_stack.top().max_outstanding += ret;
  }
  return ret;
}

private:
Upstream *const resource;
std::size_t const size_align;
std::atomic_size_t total_allocated{0};

// Record an allocation against the calling thread's innermost tracking scope.
// Does nothing when no scope is active on this thread.
void thread_allocated(long addr, std::size_t num_bytes) {
  if (memory_tracker_stack.empty()) {
    return;
  }
  alloc_map[addr] = num_bytes;
  auto &scope = memory_tracker_stack.top();
  scope.current_outstanding += num_bytes;
  if (scope.current_outstanding > scope.max_outstanding) {
    scope.max_outstanding = scope.current_outstanding;
  }
}

/**
 * @brief Credit a free back to the calling thread's innermost tracking scope.
 *
 * Frees of memory this thread did not record while a scope was active
 * (allocated on another thread, or before tracking began) are ignored.
 *
 * @param addr address of the freed allocation
 * @param num_bytes caller-reported size (unused; the size recorded at
 *                  allocation time is what gets credited back)
 */
void thread_freed(long addr, std::size_t num_bytes) {
  if (!memory_tracker_stack.empty()) {
    auto it = alloc_map.find(addr);
    if (it != alloc_map.end()) {
      // BUGFIX: bind the top entry by reference. The original code copied it
      // (`auto tracker = ...top();`), so the decrement below mutated a
      // temporary and the scope on the stack was never updated, leaving
      // current_outstanding permanently inflated after any tracked free.
      auto &tracker = memory_tracker_stack.top();
      tracker.current_outstanding -= it->second;
      alloc_map.erase(it);
    }
  }
}

// Allocate from the upstream resource, rounding the request up to the
// configured size alignment and recording the bytes for tracking.
void *do_allocate(std::size_t num_bytes, rmm::cuda_stream_view stream) override {
  // Round the requested size up to a multiple of size_align.
  auto const aligned_bytes = (num_bytes + size_align - 1) / size_align * size_align;
  void *result = resource->allocate(aligned_bytes, stream);
  if (result != nullptr) {
    total_allocated += aligned_bytes;
    thread_allocated(reinterpret_cast<long>(result), aligned_bytes);
  }
  return result;
}
Expand All @@ -102,6 +149,7 @@ class tracking_resource_adaptor final : public base_tracking_resource_adaptor {

if (p) {
total_allocated -= size;
thread_freed(reinterpret_cast<long>(p), size);
}
}

Expand Down Expand Up @@ -132,6 +180,19 @@ std::size_t get_total_bytes_allocated() {
return 0;
}

// Open a tracking scope on the calling thread; no-op when the RMM tracking
// resource has not been initialized.
void push_thread_memory_tracker() {
  if (!Tracking_memory_resource) {
    return;
  }
  Tracking_memory_resource->push_thread_memory_tracker();
}

// Close the calling thread's innermost tracking scope and return its high
// watermark; returns 0 when the RMM tracking resource is not initialized.
long pop_thread_memory_tracker() {
  if (!Tracking_memory_resource) {
    return 0;
  }
  return Tracking_memory_resource->pop_thread_memory_tracker();
}

/**
* @brief An RMM device memory resource adaptor that delegates to the wrapped resource
* for most operations but will call Java to handle certain situations (e.g.: allocation failure).
Expand Down Expand Up @@ -455,6 +516,14 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_getTotalBytesAllocated(JNIEnv *e
return get_total_bytes_allocated();
}

// JNI entry point for Rmm.pushThreadMemoryTracker(): opens a tracking scope on
// the calling (JNI-attached) thread; no-op if RMM tracking isn't initialized.
JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_pushThreadMemoryTracker(JNIEnv *env, jclass) {
  push_thread_memory_tracker();
}

// JNI entry point for Rmm.popThreadMemoryTracker(): closes the calling thread's
// innermost tracking scope and returns its high watermark (0 if tracking is off).
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_popThreadMemoryTracker(JNIEnv *env, jclass) {
  return pop_thread_memory_tracker();
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_allocInternal(JNIEnv *env, jclass clazz, jlong size,
jlong stream) {
try {
Expand Down