Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JNI] rmm based pinned pool #15219

Merged
merged 8 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 29 additions & 252 deletions java/src/main/java/ai/rapids/cudf/PinnedMemoryPool.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,109 +22,46 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Comparator;
import java.util.Iterator;
import java.util.Objects;
import java.util.Optional;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
* This provides a pool of pinned memory similar to what RMM does for device memory.
* This is the JNI interface to a rmm::pool_memory_resource<rmm::pinned_host_memory_resource>.
*/
public final class PinnedMemoryPool implements AutoCloseable {
private static final Logger log = LoggerFactory.getLogger(PinnedMemoryPool.class);
private static final long ALIGNMENT = ColumnView.hostPaddingSizeInBytes();

// These static fields should only ever be accessed when class-synchronized.
// Do NOT use singleton_ directly! Use the getSingleton accessor instead.
private static volatile PinnedMemoryPool singleton_ = null;
private static Future<PinnedMemoryPool> initFuture = null;

private final long totalPoolSize;
private final long pinnedPoolBase;
private final SortedSet<MemorySection> freeHeap = new TreeSet<>(new SortedByAddress());
private int numAllocatedSections = 0;
private long availableBytes;

private static class SortedBySize implements Comparator<MemorySection> {
@Override
public int compare(MemorySection s0, MemorySection s1) {
return Long.compare(s0.size, s1.size);
}
}

private static class SortedByAddress implements Comparator<MemorySection> {
@Override
public int compare(MemorySection s0, MemorySection s1) {
return Long.compare(s0.baseAddress, s1.baseAddress);
}
}

private static class MemorySection {
private long baseAddress;
private long size;

MemorySection(long baseAddress, long size) {
this.baseAddress = baseAddress;
this.size = size;
}

boolean canCombine(MemorySection other) {
boolean ret = (other.baseAddress + other.size) == baseAddress ||
(baseAddress + size) == other.baseAddress;
log.trace("CAN {} COMBINE WITH {} ? {}", this, other, ret);
return ret;
}

void combineWith(MemorySection other) {
assert canCombine(other);
log.trace("COMBINING {} AND {}", this, other);
this.baseAddress = Math.min(baseAddress, other.baseAddress);
this.size = other.size + this.size;
log.trace("COMBINED TO {}\n", this);
}

MemorySection splitOff(long newSize) {
assert this.size > newSize;
MemorySection ret = new MemorySection(baseAddress, newSize);
this.baseAddress += newSize;
this.size -= newSize;
return ret;
}

@Override
public String toString() {
return "PINNED: " + size + " bytes (0x" + Long.toHexString(baseAddress)
+ " to 0x" + Long.toHexString(baseAddress + size) + ")";
}
}
private long poolHandle;
private long poolSize;

private static final class PinnedHostBufferCleaner extends MemoryBuffer.MemoryBufferCleaner {
private MemorySection section;
private long address;
jbrennan333 marked this conversation as resolved.
Show resolved Hide resolved
private final long origLength;

PinnedHostBufferCleaner(MemorySection section, long length) {
this.section = section;
PinnedHostBufferCleaner(long address, long length) {
this.address = address;
origLength = length;
}

@Override
protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) {
boolean neededCleanup = false;
long origAddress = 0;
if (section != null) {
origAddress = section.baseAddress;
if (address != -1) {
origAddress = address;
try {
PinnedMemoryPool.freeInternal(section);
PinnedMemoryPool.freeInternal(address, origLength);
} finally {
// Always mark the resource as freed even if an exception is thrown.
// We cannot know how far it progressed before the exception, and
// therefore it is unsafe to retry.
section = null;
address = -1;
}
neededCleanup = true;
}
Expand All @@ -137,7 +74,7 @@ protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) {

@Override
public boolean isClean() {
return section == null;
return address == -1;
}
}

Expand All @@ -161,16 +98,8 @@ private static PinnedMemoryPool getSingleton() {
return singleton_;
}

private static void freeInternal(MemorySection section) {
Objects.requireNonNull(getSingleton()).free(section);
}

/**
* Used to indicate that memory was allocated from a reservation. This primarily is for
* keeping track of outstanding allocations.
*/
private static void reserveAllocInternal(MemorySection section) {
Objects.requireNonNull(getSingleton()).reserveAllocHappened(section);
private static void freeInternal(long address, long origLength) {
Objects.requireNonNull(getSingleton()).free(address, origLength);
}

/**
Expand Down Expand Up @@ -209,12 +138,14 @@ public static boolean isInitialized() {
}

/**
* Shut down the pool of memory. If there are outstanding allocations this may fail.
* Shut down the RMM pool_memory_resource, nulling out our reference. Any allocation
* or free that is in flight will fail after this.
*/
public static synchronized void shutdown() {
PinnedMemoryPool pool = getSingleton();
if (pool != null) {
pool.close();
pool = null;
}
initFuture = null;
singleton_ = null;
Expand All @@ -235,21 +166,6 @@ public static HostMemoryBuffer tryAllocate(long bytes) {
return result;
}

/**
* Factory method to create a pinned host memory reservation.
*
* @param bytes size in bytes to reserve
* @return newly created reservation or null if insufficient pinned memory to cover it.
*/
public static HostMemoryReservation tryReserve(long bytes) {
HostMemoryReservation result = null;
PinnedMemoryPool pool = getSingleton();
if (pool != null) {
result = pool.tryReserveInternal(bytes);
}
return result;
}

/**
* Factory method to create a host buffer but preferably pointing to pinned memory.
* It is not guaranteed that the returned buffer will be pointer to pinned memory.
Expand All @@ -276,26 +192,13 @@ public static HostMemoryBuffer allocate(long bytes) {
return allocate(bytes, DefaultHostMemoryAllocator.get());
}

/**
* Get the number of bytes free in the pinned memory pool.
*
* @return amount of free memory in bytes or 0 if the pool is not initialized
*/
public static long getAvailableBytes() {
PinnedMemoryPool pool = getSingleton();
if (pool != null) {
return pool.getAvailableBytesInternal();
}
return 0;
}

/**
* Get the number of bytes that the pinned memory pool was allocated with.
*/
public static long getTotalPoolSizeBytes() {
PinnedMemoryPool pool = getSingleton();
if (pool != null) {
return pool.getTotalPoolSizeInternal();
return pool.poolSize;
}
return 0;
}
Expand All @@ -306,157 +209,31 @@ private PinnedMemoryPool(long poolSize, int gpuId) {
Cuda.setDevice(gpuId);
Cuda.freeZero();
}
this.totalPoolSize = poolSize;
this.pinnedPoolBase = Cuda.hostAllocPinned(poolSize);
freeHeap.add(new MemorySection(pinnedPoolBase, poolSize));
this.availableBytes = poolSize;
this.poolHandle = Rmm.newPinnedPoolMemoryResource(poolSize, poolSize);
this.poolSize = poolSize;
}

@Override
public void close() {
assert numAllocatedSections == 0 : "Leaked " + numAllocatedSections + " pinned allocations";
Cuda.freePinned(pinnedPoolBase);
Rmm.releasePinnedPoolMemoryResource(this.poolHandle);
this.poolHandle = -1;
}

/**
* Pads a length of bytes to the alignment the CPU wants in the worst case. This helps to
* calculate the size needed for a reservation if there are multiple buffers.
* @param bytes the size in bytes
* @return the new padded size in bytes.
* This makes an attempt to allocate pinned memory, and if the pinned memory allocation fails
* it will return null, instead of throw.
*/
public static long padToCpuAlignment(long bytes) {
return ((bytes + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT;
}

private synchronized MemorySection tryGetInternal(long bytes, String what) {
if (freeHeap.isEmpty()) {
log.debug("No free pinned memory left");
return null;
}
// Align the allocation
long alignedBytes = padToCpuAlignment(bytes);
Optional<MemorySection> firstFit = freeHeap.stream()
.filter(section -> section.size >= alignedBytes)
.findFirst();
if (!firstFit.isPresent()) {
if (log.isDebugEnabled()) {
MemorySection largest = freeHeap.stream()
.max(new SortedBySize())
.orElse(new MemorySection(0, 0));
log.debug("Insufficient pinned memory. {} needed, {} found", alignedBytes, largest.size);
}
return null;
}
MemorySection first = firstFit.get();
log.debug("{} {}/{} bytes pinned from {} FREE COUNT {} OUTSTANDING COUNT {}",
what, bytes, alignedBytes, first, freeHeap.size(), numAllocatedSections);
freeHeap.remove(first);
MemorySection allocated;
if (first.size == alignedBytes) {
allocated = first;
} else {
allocated = first.splitOff(alignedBytes);
freeHeap.add(first);
}
numAllocatedSections++;
availableBytes -= allocated.size;
log.debug("{} {} free {} outstanding {}", what, allocated, freeHeap, numAllocatedSections);
return allocated;
}

private synchronized HostMemoryBuffer tryAllocateInternal(long bytes) {
MemorySection allocated = tryGetInternal(bytes, "allocate");
if (allocated == null) {
long allocated = Rmm.allocFromPinnedPool(this.poolHandle, bytes);
if (allocated == -1) {
return null;
} else {
return new HostMemoryBuffer(allocated.baseAddress, bytes,
return new HostMemoryBuffer(allocated, bytes,
new PinnedHostBufferCleaner(allocated, bytes));
}
}

private class PinnedReservation implements HostMemoryReservation {
private MemorySection section = null;

public PinnedReservation(MemorySection section) {
this.section = section;
}

@Override
public synchronized HostMemoryBuffer allocate(long bytes, boolean preferPinned) {
return this.allocate(bytes);
}

@Override
public synchronized HostMemoryBuffer allocate(long bytes) {
if (section == null || section.size < bytes) {
throw new OutOfMemoryError("Reservation didn't have enough space " + bytes + " / " +
(section == null ? 0 : section.size));
}
long alignedSize = padToCpuAlignment(bytes);
MemorySection allocated;
if (section.size >= bytes && section.size <= alignedSize) {
allocated = section;
section = null;
// No need for reserveAllocInternal because the original section is already tracked
} else {
allocated = section.splitOff(alignedSize);
PinnedMemoryPool.reserveAllocInternal(allocated);
}
return new HostMemoryBuffer(allocated.baseAddress, bytes,
new PinnedHostBufferCleaner(allocated, bytes));
}

@Override
public synchronized void close() throws Exception {
if (section != null) {
try {
PinnedMemoryPool.freeInternal(section);
} finally {
// Always mark the resource as freed even if an exception is thrown.
// We cannot know how far it progressed before the exception, and
// therefore it is unsafe to retry.
section = null;
}
}
}
}

private HostMemoryReservation tryReserveInternal(long bytes) {
MemorySection allocated = tryGetInternal(bytes, "allocate");
if (allocated == null) {
return null;
} else {
return new PinnedReservation(allocated);
}
}

private synchronized void free(MemorySection section) {
log.debug("Freeing {} with {} outstanding {}", section, freeHeap, numAllocatedSections);
availableBytes += section.size;
Iterator<MemorySection> it = freeHeap.iterator();
while(it.hasNext()) {
MemorySection current = it.next();
if (section.canCombine(current)) {
it.remove();
section.combineWith(current);
}
}
freeHeap.add(section);
numAllocatedSections--;
log.debug("After freeing {} outstanding {}", freeHeap, numAllocatedSections);
}

private synchronized void reserveAllocHappened(MemorySection section) {
if (section != null && section.size > 0) {
numAllocatedSections++;
}
}

private synchronized long getAvailableBytesInternal() {
return this.availableBytes;
}

private long getTotalPoolSizeInternal() {
return this.totalPoolSize;
private synchronized void free(long address, long size) {
Rmm.freeFromPinnedPool(this.poolHandle, address, size);
}
}
Loading
Loading