Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CELEBORN-1792][FOLLOWUP] Keep resume for a while after resumeByPinnedMemory #3099

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ public void switchServingState() {
case PUSH_PAUSED:
if (canResumeByPinnedMemory()) {
resumeByPinnedMemory(servingState);
} else if (keepResumeByPinnedMemory(lastState)) {
// do nothing, keep resume for a while
logger.info(
"keep resume for {}ms after last resumeByPinnedMemory",
System.currentTimeMillis() - pinnedMemoryLastCheckTime);
} else {
pausePushDataCounter.increment();
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
Expand All @@ -363,6 +368,11 @@ public void switchServingState() {
case PUSH_AND_REPLICATE_PAUSED:
if (canResumeByPinnedMemory()) {
resumeByPinnedMemory(servingState);
} else if (keepResumeByPinnedMemory(lastState)) {
// do nothing, keep resume for a while
logger.info(
"keep resume for {}ms after last resumeByPinnedMemory",
System.currentTimeMillis() - pinnedMemoryLastCheckTime);
} else {
pausePushDataAndReplicateCounter.increment();
logger.info("Trigger action: PAUSE PUSH");
Expand Down Expand Up @@ -610,6 +620,14 @@ && getPinnedMemory() / (double) (maxDirectMemory) < pinnedMemoryResumeRatio) {
}
}

private boolean keepResumeByPinnedMemory(ServingState lastState) {
return pinnedMemoryCheckEnabled
&& (lastState == ServingState.PUSH_PAUSED
|| lastState == ServingState.PUSH_AND_REPLICATE_PAUSED)
&& getPinnedMemory() / (double) (maxDirectMemory) < pinnedMemoryResumeRatio
&& System.currentTimeMillis() - pinnedMemoryLastCheckTime < pinnedMemoryCheckInterval;
}

private void resumePush() {
logger.info("Trigger action: RESUME PUSH");
memoryPressureListeners.forEach(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.celeborn.service.deploy.memory

import java.util.concurrent.TimeUnit

import scala.concurrent.duration.DurationInt

import org.mockito.{Mockito, MockitoSugar}
Expand Down Expand Up @@ -216,6 +218,47 @@ class MemoryManagerSuite extends CelebornFunSuite {
MemoryManager.reset()
}

test("[CELEBORN-1792] Test MemoryManager keep resume a while by pinned memory") {
val conf = new CelebornConf()
conf.set(CelebornConf.WORKER_DIRECT_MEMORY_CHECK_INTERVAL.key, "300s")
conf.set(CelebornConf.WORKER_PINNED_MEMORY_CHECK_INTERVAL.key, "1s")
MemoryManager.reset()
val memoryManager = MockitoSugar.spy(MemoryManager.initialize(conf))
val maxDirectorMemory = memoryManager.maxDirectMemory
val pushThreshold =
(conf.workerDirectMemoryRatioToPauseReceive * maxDirectorMemory).longValue()
val pinnedMemoryResumeThreshold =
(conf.workerPinnedMemoryRatioToResume * maxDirectorMemory).longValue()
val channelsLimiter = new MockChannelsLimiter()
memoryManager.registerMemoryListener(channelsLimiter)

// NONE PAUSED -> PAUSE PUSH
Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(0L)
Mockito.when(memoryManager.getMemoryUsage).thenReturn(pushThreshold + 1)
memoryManager.switchServingState()
assert(channelsLimiter.isResume)
assert(memoryManager.servingState == ServingState.PUSH_PAUSED)

// keep pause push, but channels keep resume when pinnedMemory still less than threshold
Mockito.when(memoryManager.getMemoryUsage).thenReturn(pushThreshold + 1)
memoryManager.switchServingState()
assert(channelsLimiter.isResume)
assert(memoryManager.servingState == ServingState.PUSH_PAUSED)

// exit keepResumeByPinnedMemory because pinnedMemory is greater than threshold
Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(
pinnedMemoryResumeThreshold + 1)
memoryManager.switchServingState()
assert(!channelsLimiter.isResume)
assert(memoryManager.servingState == ServingState.PUSH_PAUSED)

Mockito.when(memoryManager.getMemoryUsage).thenReturn(0L)
memoryManager.switchServingState()
assert(channelsLimiter.isResume)
assert(memoryManager.servingState == ServingState.NONE_PAUSED)

}

class MockMemoryPressureListener(
val belongModuleName: String,
var isPause: Boolean = false) extends MemoryPressureListener {
Expand All @@ -235,4 +278,19 @@ class MemoryManagerSuite extends CelebornFunSuite {
// do nothing
}
}

class MockChannelsLimiter(var isResume: Boolean = false) extends MemoryPressureListener {
override def onPause(moduleName: String): Unit = {
isResume = false
}

override def onResume(moduleName: String): Unit = {
isResume = true
}

override def onTrim(): Unit = {
// do nothing
}
}

}
Loading