diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java index 13ffc367b28..795d93e9a4f 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java @@ -99,6 +99,7 @@ public class MemoryManager { private boolean pinnedMemoryCheckEnabled; private long pinnedMemoryCheckInterval; private long pinnedMemoryLastCheckTime = 0; + private boolean resumingByPinnedMemory = false; @VisibleForTesting public static MemoryManager initialize(CelebornConf conf) { @@ -336,23 +337,23 @@ public void switchServingState() { } switch (servingState) { case PUSH_PAUSED: - if (canResumeByPinnedMemory()) { - resumeByPinnedMemory(servingState); - } else { + if (!tryResumeByPinnedMemory(servingState, lastState)) { pausePushDataCounter.increment(); if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) { - logger.info("Trigger action: RESUME REPLICATE"); resumeReplicate(); } else { logger.info("Trigger action: PAUSE PUSH"); pausePushDataStartTime = System.currentTimeMillis(); + resumingByPinnedMemory = false; memoryPressureListeners.forEach( memoryPressureListener -> memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE)); + // trimCounter cannot be increased when channels resume by PinnedMemory, otherwise + // PauseSpentTime will be increased unexpectedly + trimCounter += 1; } } logger.debug("Trigger action: TRIM"); - trimCounter += 1; trimAllListeners(); if (trimCounter >= forceAppendPauseSpentTimeThreshold) { logger.debug( @@ -361,12 +362,11 @@ public void switchServingState() { } break; case PUSH_AND_REPLICATE_PAUSED: - if (canResumeByPinnedMemory()) { - resumeByPinnedMemory(servingState); - } else { + if (!tryResumeByPinnedMemory(servingState, lastState)) { pausePushDataAndReplicateCounter.increment(); 
logger.info("Trigger action: PAUSE PUSH"); pausePushDataAndReplicateStartTime = System.currentTimeMillis(); + resumingByPinnedMemory = false; memoryPressureListeners.forEach( memoryPressureListener -> memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE)); @@ -374,9 +374,9 @@ public void switchServingState() { memoryPressureListeners.forEach( memoryPressureListener -> memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE)); + trimCounter += 1; } logger.debug("Trigger action: TRIM"); - trimCounter += 1; trimAllListeners(); if (trimCounter >= forceAppendPauseSpentTimeThreshold) { logger.debug( @@ -386,6 +386,7 @@ public void switchServingState() { break; case NONE_PAUSED: // resume from paused mode, append pause spent time + resumingByPinnedMemory = false; if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) { resumeReplicate(); resumePush(); @@ -599,15 +600,32 @@ private void resumeByPinnedMemory(ServingState servingState) { } } - private boolean canResumeByPinnedMemory() { - if (pinnedMemoryCheckEnabled - && System.currentTimeMillis() - pinnedMemoryLastCheckTime >= pinnedMemoryCheckInterval - && getPinnedMemory() / (double) (maxDirectMemory) < pinnedMemoryResumeRatio) { - pinnedMemoryLastCheckTime = System.currentTimeMillis(); - return true; - } else { + private boolean tryResumeByPinnedMemory(ServingState currentState, ServingState lastState) { + boolean success = false; + if (!pinnedMemoryCheckEnabled) { return false; } + long currentTime = System.currentTimeMillis(); + if (currentTime - pinnedMemoryLastCheckTime >= pinnedMemoryCheckInterval) { + if (getPinnedMemory() / (double) (maxDirectMemory) < pinnedMemoryResumeRatio) { + pinnedMemoryLastCheckTime = currentTime; + resumingByPinnedMemory = true; + resumeByPinnedMemory(currentState); + success = true; + } + } else { + if (resumingByPinnedMemory + && lastState != ServingState.NONE_PAUSED + && getPinnedMemory() / (double) (maxDirectMemory) < pinnedMemoryResumeRatio) { + // 
do nothing, keep the channels resumed for a while + logger.info( + "currentState: {}, keep resume for {}ms after last resumeByPinnedMemory", + currentState, + currentTime - pinnedMemoryLastCheckTime); + success = true; + } + } + return success; } private void resumePush() { diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala index c0fe08e6173..d74f1889e3f 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala @@ -17,6 +17,8 @@ package org.apache.celeborn.service.deploy.memory +import java.util.concurrent.TimeUnit + import scala.concurrent.duration.DurationInt import org.mockito.{Mockito, MockitoSugar} @@ -216,6 +218,47 @@ class MemoryManagerSuite extends CelebornFunSuite { MemoryManager.reset() } + test("[CELEBORN-1792] Test MemoryManager keep resume a while by pinned memory") { + val conf = new CelebornConf() + conf.set(CelebornConf.WORKER_DIRECT_MEMORY_CHECK_INTERVAL.key, "300s") + conf.set(CelebornConf.WORKER_PINNED_MEMORY_CHECK_INTERVAL.key, "1s") + MemoryManager.reset() + val memoryManager = MockitoSugar.spy(MemoryManager.initialize(conf)) + val maxDirectorMemory = memoryManager.maxDirectMemory + val pushThreshold = + (conf.workerDirectMemoryRatioToPauseReceive * maxDirectorMemory).longValue() + val pinnedMemoryResumeThreshold = + (conf.workerPinnedMemoryRatioToResume * maxDirectorMemory).longValue() + val channelsLimiter = new MockChannelsLimiter() + memoryManager.registerMemoryListener(channelsLimiter) + + // NONE PAUSED -> PAUSE PUSH + Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(0L) + Mockito.when(memoryManager.getMemoryUsage).thenReturn(pushThreshold + 1) + memoryManager.switchServingState() + assert(channelsLimiter.isResume) + assert(memoryManager.servingState == 
ServingState.PUSH_PAUSED) + + // push stays paused, but channels stay resumed while pinnedMemory is still below threshold + Mockito.when(memoryManager.getMemoryUsage).thenReturn(pushThreshold + 1) + memoryManager.switchServingState() + assert(channelsLimiter.isResume) + assert(memoryManager.servingState == ServingState.PUSH_PAUSED) + + // stop keeping channels resumed because pinnedMemory is now greater than threshold + Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn( + pinnedMemoryResumeThreshold + 1) + memoryManager.switchServingState() + assert(!channelsLimiter.isResume) + assert(memoryManager.servingState == ServingState.PUSH_PAUSED) + + Mockito.when(memoryManager.getMemoryUsage).thenReturn(0L) + memoryManager.switchServingState() + assert(channelsLimiter.isResume) + assert(memoryManager.servingState == ServingState.NONE_PAUSED) + + } + class MockMemoryPressureListener( val belongModuleName: String, var isPause: Boolean = false) extends MemoryPressureListener { @@ -235,4 +278,19 @@ class MemoryManagerSuite extends CelebornFunSuite { // do nothing } } + + class MockChannelsLimiter(var isResume: Boolean = false) extends MemoryPressureListener { + override def onPause(moduleName: String): Unit = { + isResume = false + } + + override def onResume(moduleName: String): Unit = { + isResume = true + } + + override def onTrim(): Unit = { + // do nothing + } + } + }