Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CELEBORN-1792][FOLLOWUP] Keep resume for a while after resumeByPinnedMemory #3099

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ public class MemoryManager {
private boolean pinnedMemoryCheckEnabled;
private long pinnedMemoryCheckInterval;
private long pinnedMemoryLastCheckTime = 0;
private boolean resumingByPinnedMemory = false;

@VisibleForTesting
public static MemoryManager initialize(CelebornConf conf) {
Expand Down Expand Up @@ -336,23 +337,23 @@ public void switchServingState() {
}
switch (servingState) {
case PUSH_PAUSED:
if (canResumeByPinnedMemory()) {
resumeByPinnedMemory(servingState);
} else {
if (!tryResumeByPinnedMemory(servingState, lastState)) {
pausePushDataCounter.increment();
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
logger.info("Trigger action: RESUME REPLICATE");
resumeReplicate();
} else {
logger.info("Trigger action: PAUSE PUSH");
pausePushDataStartTime = System.currentTimeMillis();
resumingByPinnedMemory = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resumingByPinnedMemory also needs to be reset to false when servingState changes to the NONE_PAUSED state.

memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
// trimCounter cannot be increased when channels resume by PinnedMemory, otherwise
// PauseSpentTime will be increased unexpectedly
trimCounter += 1;
}
}
logger.debug("Trigger action: TRIM");
trimCounter += 1;
trimAllListeners();
if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
logger.debug(
Expand All @@ -361,22 +362,21 @@ public void switchServingState() {
}
break;
case PUSH_AND_REPLICATE_PAUSED:
if (canResumeByPinnedMemory()) {
resumeByPinnedMemory(servingState);
} else {
if (!tryResumeByPinnedMemory(servingState, lastState)) {
pausePushDataAndReplicateCounter.increment();
logger.info("Trigger action: PAUSE PUSH");
pausePushDataAndReplicateStartTime = System.currentTimeMillis();
resumingByPinnedMemory = false;
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
logger.info("Trigger action: PAUSE REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
trimCounter += 1;
}
logger.debug("Trigger action: TRIM");
trimCounter += 1;
trimAllListeners();
if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
logger.debug(
Expand All @@ -386,6 +386,7 @@ public void switchServingState() {
break;
case NONE_PAUSED:
// resume from paused mode, append pause spent time
resumingByPinnedMemory = false;
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
resumeReplicate();
resumePush();
Expand Down Expand Up @@ -599,15 +600,32 @@ private void resumeByPinnedMemory(ServingState servingState) {
}
}

private boolean canResumeByPinnedMemory() {
if (pinnedMemoryCheckEnabled
&& System.currentTimeMillis() - pinnedMemoryLastCheckTime >= pinnedMemoryCheckInterval
&& getPinnedMemory() / (double) (maxDirectMemory) < pinnedMemoryResumeRatio) {
pinnedMemoryLastCheckTime = System.currentTimeMillis();
return true;
} else {
/**
 * Decides whether the channels should be (or stay) resumed because pinned memory is low.
 *
 * <p>When the pinned memory check is disabled this always returns {@code false}. Otherwise there
 * are two cases, depending on how long ago {@code pinnedMemoryLastCheckTime} was recorded:
 *
 * <ul>
 *   <li>At least {@code pinnedMemoryCheckInterval} ms ago: if the pinned-memory ratio is below
 *       {@code pinnedMemoryResumeRatio}, the check time is refreshed, {@code
 *       resumingByPinnedMemory} is set and {@code resumeByPinnedMemory(currentState)} is invoked.
 *   <li>More recently than that: nothing is triggered, but the call still succeeds if a previous
 *       resume by pinned memory is flagged, {@code lastState} is not {@code NONE_PAUSED} and the
 *       pinned-memory ratio is still below {@code pinnedMemoryResumeRatio}.
 * </ul>
 *
 * @param currentState the serving state passed on to {@code resumeByPinnedMemory} and logged
 * @param lastState the previous serving state; {@code NONE_PAUSED} disables the keep-resumed path
 * @return {@code true} if a resume was triggered or kept, {@code false} otherwise
 */
private boolean tryResumeByPinnedMemory(ServingState currentState, ServingState lastState) {
  if (!pinnedMemoryCheckEnabled) {
    return false;
  }

  final long now = System.currentTimeMillis();
  final long elapsedSinceLastCheck = now - pinnedMemoryLastCheckTime;

  if (elapsedSinceLastCheck >= pinnedMemoryCheckInterval) {
    // The check interval has elapsed: evaluate pinned memory and trigger a resume if it is low.
    final boolean belowResumeRatio =
        getPinnedMemory() / (double) (maxDirectMemory) < pinnedMemoryResumeRatio;
    if (!belowResumeRatio) {
      return false;
    }
    pinnedMemoryLastCheckTime = now;
    resumingByPinnedMemory = true;
    resumeByPinnedMemory(currentState);
    return true;
  }

  // Still inside the check interval: only keep an already-triggered resume alive.
  final boolean keepResumed =
      resumingByPinnedMemory
          && lastState != ServingState.NONE_PAUSED
          && getPinnedMemory() / (double) (maxDirectMemory) < pinnedMemoryResumeRatio;
  if (keepResumed) {
    logger.info(
        "currentState: {}, keep resume for {}ms after last resumeByPinnedMemory",
        currentState,
        elapsedSinceLastCheck);
  }
  return keepResumed;
}

private void resumePush() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.celeborn.service.deploy.memory

import java.util.concurrent.TimeUnit

import scala.concurrent.duration.DurationInt

import org.mockito.{Mockito, MockitoSugar}
Expand Down Expand Up @@ -216,6 +218,47 @@ class MemoryManagerSuite extends CelebornFunSuite {
MemoryManager.reset()
}

// Drives MemoryManager.switchServingState() by hand with stubbed memory readings and checks,
// via a MockChannelsLimiter listener, whether channels are resumed or paused after each step.
test("[CELEBORN-1792] Test MemoryManager keep resume a while by pinned memory") {
val conf = new CelebornConf()
// NOTE(review): the 300s direct-memory check interval presumably keeps any periodic check from
// interfering with the manual switchServingState() calls below — confirm.
conf.set(CelebornConf.WORKER_DIRECT_MEMORY_CHECK_INTERVAL.key, "300s")
conf.set(CelebornConf.WORKER_PINNED_MEMORY_CHECK_INTERVAL.key, "1s")
MemoryManager.reset()
// Spy on the real manager so getNettyPinnedDirectMemory / getMemoryUsage can be stubbed.
val memoryManager = MockitoSugar.spy(MemoryManager.initialize(conf))
val maxDirectorMemory = memoryManager.maxDirectMemory
// Thresholds derived from the configured ratios and the manager's max direct memory.
val pushThreshold =
(conf.workerDirectMemoryRatioToPauseReceive * maxDirectorMemory).longValue()
val pinnedMemoryResumeThreshold =
(conf.workerPinnedMemoryRatioToResume * maxDirectorMemory).longValue()
val channelsLimiter = new MockChannelsLimiter()
memoryManager.registerMemoryListener(channelsLimiter)

// NONE PAUSED -> PAUSE PUSH
// Memory usage is just above the push threshold while pinned memory is 0: the state becomes
// PUSH_PAUSED, yet the listener is expected to end up resumed.
Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(0L)
Mockito.when(memoryManager.getMemoryUsage).thenReturn(pushThreshold + 1)
memoryManager.switchServingState()
assert(channelsLimiter.isResume)
assert(memoryManager.servingState == ServingState.PUSH_PAUSED)

// keep pause push, but channels keep resume when pinnedMemory still less than threshold
// NOTE(review): presumably this back-to-back call lands inside the 1s pinned-memory check
// interval, exercising the "keep resume" path — confirm.
Mockito.when(memoryManager.getMemoryUsage).thenReturn(pushThreshold + 1)
memoryManager.switchServingState()
assert(channelsLimiter.isResume)
assert(memoryManager.servingState == ServingState.PUSH_PAUSED)

// exit keepResumeByPinnedMemory because pinnedMemory is greater than threshold
Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(
pinnedMemoryResumeThreshold + 1)
memoryManager.switchServingState()
assert(!channelsLimiter.isResume)
assert(memoryManager.servingState == ServingState.PUSH_PAUSED)

// Memory usage drops to 0: the state returns to NONE_PAUSED and the listener is resumed.
Mockito.when(memoryManager.getMemoryUsage).thenReturn(0L)
memoryManager.switchServingState()
assert(channelsLimiter.isResume)
assert(memoryManager.servingState == ServingState.NONE_PAUSED)

}

class MockMemoryPressureListener(
val belongModuleName: String,
var isPause: Boolean = false) extends MemoryPressureListener {
Expand All @@ -235,4 +278,19 @@ class MemoryManagerSuite extends CelebornFunSuite {
// do nothing
}
}

/**
 * Test listener that remembers whether the most recent pause/resume notification it received
 * was a resume.
 *
 * @param isResume true after onResume, false after onPause; starts as false by default
 */
class MockChannelsLimiter(var isResume: Boolean = false) extends MemoryPressureListener {

  /** Any pause notification, for any module, marks the limiter as not resumed. */
  override def onPause(moduleName: String): Unit = { isResume = false }

  /** Any resume notification, for any module, marks the limiter as resumed. */
  override def onResume(moduleName: String): Unit = { isResume = true }

  /** Trim notifications are ignored. */
  override def onTrim(): Unit = ()
}

}
Loading