diff --git a/src/main/java/org/opensearch/geospatial/ip2geo/dao/Ip2GeoCachedDao.java b/src/main/java/org/opensearch/geospatial/ip2geo/dao/Ip2GeoCachedDao.java index cd645cd1..23e98279 100644 --- a/src/main/java/org/opensearch/geospatial/ip2geo/dao/Ip2GeoCachedDao.java +++ b/src/main/java/org/opensearch/geospatial/ip2geo/dao/Ip2GeoCachedDao.java @@ -7,16 +7,24 @@ import java.io.IOException; import java.time.Instant; +import java.util.Iterator; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.function.Function; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.extern.log4j.Log4j2; +import org.opensearch.common.cache.Cache; +import org.opensearch.common.cache.CacheBuilder; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.geospatial.annotation.VisibleForTesting; import org.opensearch.geospatial.ip2geo.common.DatasourceState; import org.opensearch.geospatial.ip2geo.jobscheduler.Datasource; import org.opensearch.index.engine.Engine; @@ -119,4 +127,65 @@ public DatasourceMetadata(final Datasource datasource) { this.state = datasource.getState(); } } + + /** + * Cache to hold geo data + * + * GeoData in an index in immutable. Therefore, invalidation is not needed. + */ + @VisibleForTesting + protected static class GeoDataCache { + private Cache> cache; + + public GeoDataCache(final long maxSize) { + if (maxSize < 0) { + throw new IllegalArgumentException("ip2geo max cache size must be 0 or greater"); + } + this.cache = CacheBuilder.>builder().setMaximumWeight(maxSize).build(); + } + + public Map putIfAbsent( + final String indexName, + final String ip, + final Function> retrieveFunction + ) throws ExecutionException { + CacheKey cacheKey = new CacheKey(indexName, ip); + return cache.computeIfAbsent(cacheKey, key -> retrieveFunction.apply(key.ip)); + } + + public Map get(final String indexName, final String ip) { + return cache.get(new CacheKey(indexName, ip)); + } + + /** + * Create a new cache with give size and replace existing cache + * + * Try to populate the existing value from previous cache to the new cache in best effort + * + * @param maxSize + */ + public void updateMaxSize(final long maxSize) { + if (maxSize < 0) { + throw new IllegalArgumentException("ip2geo max cache size must be 0 or greater"); + } + Cache> temp = CacheBuilder.>builder() + .setMaximumWeight(maxSize) + .build(); + int count = 0; + Iterator it = cache.keys().iterator(); + while (it.hasNext() && count < maxSize) { + CacheKey key = it.next(); + temp.put(key, cache.get(key)); + count++; + } + cache = temp; + } + + @AllArgsConstructor + @EqualsAndHashCode + private static class CacheKey { + private final String indexName; + private final String ip; + } + } } diff --git a/src/test/java/org/opensearch/geospatial/ip2geo/dao/Ip2GeoCachedDaoTests.java b/src/test/java/org/opensearch/geospatial/ip2geo/dao/Ip2GeoCachedDaoTests.java index 5aba0f62..4906fba9 100644 --- a/src/test/java/org/opensearch/geospatial/ip2geo/dao/Ip2GeoCachedDaoTests.java +++ b/src/test/java/org/opensearch/geospatial/ip2geo/dao/Ip2GeoCachedDaoTests.java @@ -9,12 +9,16 @@ import static org.mockito.Mockito.when; import java.time.Instant; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.List; import lombok.SneakyThrows; import org.junit.Before; import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.network.NetworkAddress; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.geospatial.GeospatialTestHelper; import org.opensearch.geospatial.ip2geo.Ip2GeoTestCase; @@ -178,4 +182,59 @@ public void testPostDelete_whenSucceed_thenUpdate() { // Verify assertFalse(ip2GeoCachedDao.has(datasource.getName())); } + + @SneakyThrows + public void testUpdateMaxSize_whenBiggerSize_thenContainsAllData() { + int cacheSize = 10; + String datasource = GeospatialTestHelper.randomLowerCaseString(); + Ip2GeoCachedDao.GeoDataCache geoDataCache = new Ip2GeoCachedDao.GeoDataCache(cacheSize); + List ips = new ArrayList<>(cacheSize); + for (int i = 0; i < cacheSize; i++) { + String ip = NetworkAddress.format(randomIp(false)); + ips.add(ip); + geoDataCache.putIfAbsent(datasource, ip, addr -> Collections.emptyMap()); + } + + // Verify all data exist in the cache + assertTrue(ips.stream().allMatch(ip -> geoDataCache.get(datasource, ip) != null)); + + // Update cache size + int newCacheSize = 15; + geoDataCache.updateMaxSize(newCacheSize); + + // Verify all data exist in the cache + assertTrue(ips.stream().allMatch(ip -> geoDataCache.get(datasource, ip) != null)); + + // Add (newCacheSize - cacheSize + 1) data and the first data should not be available in the cache + for (int i = 0; i < newCacheSize - cacheSize + 1; i++) { + geoDataCache.putIfAbsent(datasource, NetworkAddress.format(randomIp(false)), addr -> Collections.emptyMap()); + } + assertNull(geoDataCache.get(datasource, ips.get(0))); + } + + @SneakyThrows + public void testUpdateMaxSize_whenSmallerSize_thenContainsPartialData() { + int cacheSize = 10; + String datasource = GeospatialTestHelper.randomLowerCaseString(); + Ip2GeoCachedDao.GeoDataCache geoDataCache = new Ip2GeoCachedDao.GeoDataCache(cacheSize); + List ips = new ArrayList<>(cacheSize); + for (int i = 0; i < cacheSize; i++) { + String ip = NetworkAddress.format(randomIp(false)); + ips.add(ip); + geoDataCache.putIfAbsent(datasource, ip, addr -> Collections.emptyMap()); + } + + // Verify all data exist in the cache + assertTrue(ips.stream().allMatch(ip -> geoDataCache.get(datasource, ip) != null)); + + // Update cache size + int newCacheSize = 5; + geoDataCache.updateMaxSize(newCacheSize); + + // Verify the last (cacheSize - newCacheSize) data is available in the cache + List deleted = ips.subList(0, ips.size() - newCacheSize); + List retained = ips.subList(ips.size() - newCacheSize, ips.size()); + assertTrue(deleted.stream().allMatch(ip -> geoDataCache.get(datasource, ip) == null)); + assertTrue(retained.stream().allMatch(ip -> geoDataCache.get(datasource, ip) != null)); + } }