diff --git a/src/main/cpp/src/SparkResourceAdaptorJni.cpp b/src/main/cpp/src/SparkResourceAdaptorJni.cpp index 8eeb047ddc..e09ef0dfdb 100644 --- a/src/main/cpp/src/SparkResourceAdaptorJni.cpp +++ b/src/main/cpp/src/SparkResourceAdaptorJni.cpp @@ -1780,6 +1780,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const thread = threads.find(tid); if (thread != threads.end()) { log_status("DEALLOC", tid, thread->second.task_id, thread->second.state); + if (!is_for_cpu) { thread->second.gpu_memory_allocated_bytes -= num_bytes; } } else { log_status("DEALLOC", tid, -2, thread_state::UNKNOWN); } @@ -1802,7 +1803,6 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (is_for_cpu == t_state.is_cpu_alloc) { transition(t_state, thread_state::THREAD_ALLOC_FREE); } - if (!is_for_cpu) { t_state.gpu_memory_allocated_bytes -= num_bytes; } break; default: break; } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java b/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java index 987dd58534..270a4266cd 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java @@ -360,7 +360,7 @@ public void testInsertOOMsGpu() { assertThrows(GpuSplitAndRetryOOM.class, () -> Rmm.alloc(100).close()); assertEquals(0, RmmSpark.getAndResetNumRetryThrow(taskid)); assertEquals(1, RmmSpark.getAndResetNumSplitRetryThrow(taskid)); - assertEquals(ALIGNMENT * 2, RmmSpark.getAndResetGpuMaxMemoryAllocated(taskid)); + assertEquals(ALIGNMENT, RmmSpark.getAndResetGpuMaxMemoryAllocated(taskid)); // Verify that injecting OOM does not cause the block to actually happen assertEquals(RmmSparkThreadState.THREAD_RUNNING, RmmSpark.getStateOf(threadId)); @@ -818,6 +818,11 @@ public void testBasicMixedBlocking() throws ExecutionException, InterruptedExcep secondGpuAlloc.waitForAlloc(); secondGpuAlloc.freeAndWait(); } + // Do one more alloc after freeing on same task to show the max allocation metric is unimpacted + try (AllocOnAnotherThread secondGpuAlloc = new GpuAllocOnAnotherThread(taskThree, FIVE_MB)) { + secondGpuAlloc.waitForAlloc(); + secondGpuAlloc.freeAndWait(); + } } } } finally {