Skip to content

Commit

Permalink
[fix][ml] Fix memory leak due to duplicated RangeCache value retain o…
Browse files Browse the repository at this point in the history
…perations (apache#23955)

Co-authored-by: Lari Hotari <[email protected]>
  • Loading branch information
BewareMyPower and lhotari authored Feb 10, 2025
1 parent 215b36d commit 20b3b22
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,19 @@ private Value getValueFromWrapper(Key key, EntryWrapper<Key, Value> valueWrapper
}
}

/**
* @apiNote the returned value must be released if it's not null
*/
private Value getValueMatchingEntry(Map.Entry<Key, EntryWrapper<Key, Value>> entry) {
Value valueMatchingEntry = EntryWrapper.getValueMatchingMapEntry(entry);
return getRetainedValueMatchingKey(entry.getKey(), valueMatchingEntry);
}

// validates that the value matches the key and that the value has not been recycled
// which are possible due to the lack of exclusive locks in the cache and the use of reference counted objects
/**
* @apiNote the returned value must be released if it's not null
*/
private Value getRetainedValueMatchingKey(Key key, Value value) {
if (value == null) {
// the wrapper has been recycled and contains another key
Expand Down Expand Up @@ -350,7 +356,7 @@ public Pair<Integer, Long> removeRange(Key first, Key last, boolean lastInclusiv
RemovalCounters counters = RemovalCounters.create();
Map<Key, EntryWrapper<Key, Value>> subMap = entries.subMap(first, true, last, lastInclusive);
for (Map.Entry<Key, EntryWrapper<Key, Value>> entry : subMap.entrySet()) {
removeEntry(entry, counters, true);
removeEntry(entry, counters);
}
return handleRemovalResult(counters);
}
Expand All @@ -361,84 +367,48 @@ enum RemoveEntryResult {
BREAK_LOOP;
}

private RemoveEntryResult removeEntry(Map.Entry<Key, EntryWrapper<Key, Value>> entry, RemovalCounters counters,
boolean skipInvalid) {
return removeEntry(entry, counters, skipInvalid, x -> true);
private RemoveEntryResult removeEntry(Map.Entry<Key, EntryWrapper<Key, Value>> entry, RemovalCounters counters) {
return removeEntry(entry, counters, x -> true);
}

private RemoveEntryResult removeEntry(Map.Entry<Key, EntryWrapper<Key, Value>> entry, RemovalCounters counters,
boolean skipInvalid, Predicate<Value> removeCondition) {
Predicate<Value> removeCondition) {
Key key = entry.getKey();
EntryWrapper<Key, Value> entryWrapper = entry.getValue();
Value value = getValueMatchingEntry(entry);
if (value == null) {
// the wrapper has already been recycled and contains another key
if (!skipInvalid) {
EntryWrapper<Key, Value> removed = entries.remove(key);
if (removed != null) {
// log and remove the entry without releasing the value
log.info("Key {} does not match the entry's value wrapper's key {}, removed entry by key without "
+ "releasing the value", key, entryWrapper.getKey());
counters.entryRemoved(removed.getSize());
return RemoveEntryResult.ENTRY_REMOVED;
}
}
return RemoveEntryResult.CONTINUE_LOOP;
}
try {
// add extra retain to avoid value being released while we are removing it
value.retain();
} catch (IllegalReferenceCountException e) {
// Value was already released
if (!skipInvalid) {
// remove the specific entry without releasing the value
if (entries.remove(key, entryWrapper)) {
log.info("Value was already released for key {}, removed entry without releasing the value", key);
counters.entryRemoved(entryWrapper.getSize());
return RemoveEntryResult.ENTRY_REMOVED;
}
}
// the wrapper has already been recycled or contains another key
entries.remove(key, entryWrapper);
return RemoveEntryResult.CONTINUE_LOOP;
}
if (!value.matchesKey(key)) {
// this is unexpected since the IdentityWrapper.getValue(key) already checked that the value matches the key
log.warn("Unexpected race condition. Value {} does not match the key {}. Removing entry.", value, key);
}
try {
if (!removeCondition.test(value)) {
return RemoveEntryResult.BREAK_LOOP;
}
if (!skipInvalid) {
// remove the specific entry
boolean entryRemoved = entries.remove(key, entryWrapper);
if (entryRemoved) {
counters.entryRemoved(entryWrapper.getSize());
// check that the value hasn't been recycled in between
// there should be at least 2 references since this method adds one and the cache should have
// one reference. it is valid that the value contains references even after the key has been
// removed from the cache
if (value.refCnt() > 1) {
entryWrapper.recycle();
// remove the cache reference
value.release();
} else {
log.info("Unexpected refCnt {} for key {}, removed entry without releasing the value",
value.refCnt(), key);
}
}
} else if (skipInvalid && value.refCnt() > 1 && entries.remove(key, entryWrapper)) {
// when skipInvalid is true, we don't remove the entry if it doesn't match matches the key
// or the refCnt is invalid
// remove the specific entry
boolean entryRemoved = entries.remove(key, entryWrapper);
if (entryRemoved) {
counters.entryRemoved(entryWrapper.getSize());
entryWrapper.recycle();
// remove the cache reference
value.release();
// check that the value hasn't been recycled in between
// there should be at least 2 references since this method adds one and the cache should have
// one reference. it is valid that the value contains references even after the key has been
// removed from the cache
if (value.refCnt() > 1) {
entryWrapper.recycle();
// remove the cache reference
value.release();
} else {
log.info("Unexpected refCnt {} for key {}, removed entry without releasing the value",
value.refCnt(), key);
}
return RemoveEntryResult.ENTRY_REMOVED;
} else {
return RemoveEntryResult.CONTINUE_LOOP;
}
} finally {
// remove the extra retain
value.release();
}
return RemoveEntryResult.ENTRY_REMOVED;
}

private Pair<Integer, Long> handleRemovalResult(RemovalCounters counters) {
Expand All @@ -464,7 +434,7 @@ public Pair<Integer, Long> evictLeastAccessedEntries(long minSize) {
if (entry == null) {
break;
}
removeEntry(entry, counters, false);
removeEntry(entry, counters);
}
return handleRemovalResult(counters);
}
Expand All @@ -484,7 +454,7 @@ public Pair<Integer, Long> evictLEntriesBeforeTimestamp(long maxTimestamp) {
if (entry == null) {
break;
}
if (removeEntry(entry, counters, false, value -> timestampExtractor.getTimestamp(value) <= maxTimestamp)
if (removeEntry(entry, counters, value -> timestampExtractor.getTimestamp(value) <= maxTimestamp)
== RemoveEntryResult.BREAK_LOOP) {
break;
}
Expand Down Expand Up @@ -518,7 +488,7 @@ public Pair<Integer, Long> clear() {
if (entry == null) {
break;
}
removeEntry(entry, counters, false);
removeEntry(entry, counters);
}
return handleRemovalResult(counters);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@
import com.google.common.collect.Lists;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCounted;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import lombok.Cleanup;
import lombok.Data;
import org.apache.commons.lang3.tuple.Pair;
import org.awaitility.Awaitility;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

public class RangeCacheTest {
Expand Down Expand Up @@ -140,9 +143,14 @@ public void customWeighter() {
assertEquals(cache.getNumberOfEntries(), 2);
}

@DataProvider
public static Object[][] retainBeforeEviction() {
return new Object[][]{ { true }, { false } };
}

@Test
public void customTimeExtraction() {

@Test(dataProvider = "retainBeforeEviction")
public void customTimeExtraction(boolean retain) {
RangeCache<Integer, RefString> cache = new RangeCache<>(value -> value.s.length(), x -> x.s.length());

cache.put(1, new RefString("1"));
Expand All @@ -152,13 +160,30 @@ public void customTimeExtraction() {

assertEquals(cache.getSize(), 10);
assertEquals(cache.getNumberOfEntries(), 4);
final var retainedEntries = cache.getRange(1, 4444);
for (final var entry : retainedEntries) {
assertEquals(entry.refCnt(), 2);
if (!retain) {
entry.release();
}
}

Pair<Integer, Long> evictedSize = cache.evictLEntriesBeforeTimestamp(3);
assertEquals(evictedSize.getRight().longValue(), 6);
assertEquals(evictedSize.getLeft().longValue(), 3);

assertEquals(cache.getSize(), 4);
assertEquals(cache.getNumberOfEntries(), 1);

if (retain) {
final var valueToRefCnt = retainedEntries.stream().collect(Collectors.toMap(RefString::getS,
AbstractReferenceCounted::refCnt));
assertEquals(valueToRefCnt, Map.of("1", 1, "22", 1, "333", 1, "4444", 2));
retainedEntries.forEach(AbstractReferenceCounted::release);
} else {
final var valueToRefCnt = retainedEntries.stream().filter(v -> v.refCnt() > 0).collect(Collectors.toMap(
RefString::getS, AbstractReferenceCounted::refCnt));
assertEquals(valueToRefCnt, Map.of("4444", 1));
}
}

@Test
Expand Down Expand Up @@ -355,4 +380,4 @@ public void testGetKeyWithDifferentInstance() {
// the value should be found
assertEquals(s.s, "129");
}
}
}

0 comments on commit 20b3b22

Please sign in to comment.