diff --git a/src/main/java/org/opensearch/geospatial/ip2geo/action/GetDatasourceTransportAction.java b/src/main/java/org/opensearch/geospatial/ip2geo/action/GetDatasourceTransportAction.java index 905d8165..6bac685e 100644 --- a/src/main/java/org/opensearch/geospatial/ip2geo/action/GetDatasourceTransportAction.java +++ b/src/main/java/org/opensearch/geospatial/ip2geo/action/GetDatasourceTransportAction.java @@ -5,6 +5,7 @@ package org.opensearch.geospatial.ip2geo.action; +import java.util.Collections; import java.util.List; import org.opensearch.OpenSearchException; @@ -12,8 +13,10 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; +import org.opensearch.geospatial.annotation.VisibleForTesting; import org.opensearch.geospatial.ip2geo.common.DatasourceFacade; import org.opensearch.geospatial.ip2geo.jobscheduler.Datasource; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -57,7 +60,8 @@ private boolean shouldGetAllDatasource(final GetDatasourceRequest request) { return request.getNames().length == 0 || (request.getNames().length == 1 && "_all".equals(request.getNames()[0])); } - private ActionListener> newActionListener(final ActionListener listener) { + @VisibleForTesting + protected ActionListener> newActionListener(final ActionListener listener) { return new ActionListener<>() { @Override public void onResponse(final List datasources) { @@ -66,6 +70,10 @@ public void onResponse(final List datasources) { @Override public void onFailure(final Exception e) { + if (e instanceof IndexNotFoundException) { + listener.onResponse(new GetDatasourceResponse(Collections.emptyList())); + return; + } listener.onFailure(e); } }; diff --git a/src/main/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacade.java b/src/main/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacade.java index a9ffa43d..4b6d9fe9 100644 --- a/src/main/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacade.java +++ b/src/main/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacade.java @@ -20,8 +20,10 @@ import java.util.Collections; import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Queue; import java.util.stream.Collectors; import java.util.zip.ZipEntry; import java.util.zip.ZipInputStream; @@ -35,6 +37,7 @@ import org.opensearch.OpenSearchException; import org.opensearch.SpecialPermission; import org.opensearch.action.ActionListener; +import org.opensearch.action.DocWriteRequest; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; @@ -51,8 +54,10 @@ import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.geospatial.annotation.VisibleForTesting; import org.opensearch.geospatial.shared.Constants; import org.opensearch.geospatial.shared.StashedThreadContext; @@ -187,7 +192,7 @@ protected CSVParser internalGetDatabaseReader(final DatasourceManifest manifest, } /** - * Create a document in json string format to ingest in datasource database index + * Create a document to ingest in datasource database index * * It assumes the first field as ip_range. The rest is added under data field. * @@ -204,31 +209,23 @@ protected CSVParser internalGetDatabaseReader(final DatasourceManifest manifest, * @param fields a list of field name * @param values a list of values * @return Document in json string format + * @throws IOException the exception */ - public String createDocument(final String[] fields, final String[] values) { + public XContentBuilder createDocument(final String[] fields, final String[] values) throws IOException { if (fields.length != values.length) { throw new OpenSearchException("header[{}] and record[{}] length does not match", fields, values); } - StringBuilder sb = new StringBuilder(); - sb.append("{\""); - sb.append(IP_RANGE_FIELD_NAME); - sb.append("\":\""); - sb.append(values[0]); - sb.append("\",\""); - sb.append(DATA_FIELD_NAME); - sb.append("\":{"); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field(IP_RANGE_FIELD_NAME, values[0]); + builder.startObject(DATA_FIELD_NAME); for (int i = 1; i < fields.length; i++) { - if (i != 1) { - sb.append(","); - } - sb.append("\""); - sb.append(fields[i]); - sb.append("\":\""); - sb.append(values[i]); - sb.append("\""); + builder.field(fields[i], values[i]); } - sb.append("}}"); - return sb.toString(); + builder.endObject(); + builder.endObject(); + builder.close(); + return builder; } /** @@ -368,14 +365,20 @@ public void putGeoIpData( @NonNull final Iterator iterator, final int bulkSize, @NonNull final Runnable renewLock - ) { + ) throws IOException { TimeValue timeout = clusterSettings.get(Ip2GeoSettings.TIMEOUT); final BulkRequest bulkRequest = new BulkRequest(); + Queue requests = new LinkedList<>(); + for (int i = 0; i < bulkSize; i++) { + requests.add(Requests.indexRequest(indexName)); + } while (iterator.hasNext()) { CSVRecord record = iterator.next(); - String document = createDocument(fields, record.values()); - IndexRequest request = Requests.indexRequest(indexName).source(document, XContentType.JSON); - bulkRequest.add(request); + XContentBuilder document = createDocument(fields, record.values()); + IndexRequest indexRequest = (IndexRequest) requests.poll(); + indexRequest.source(document); + indexRequest.id(record.get(0)); + bulkRequest.add(indexRequest); if (iterator.hasNext() == false || bulkRequest.requests().size() == bulkSize) { BulkResponse response = StashedThreadContext.run(client, () -> client.bulk(bulkRequest).actionGet(timeout)); if (response.hasFailures()) { @@ -385,6 +388,7 @@ public void putGeoIpData( response.buildFailureMessage() ); } + requests.addAll(bulkRequest.requests()); bulkRequest.requests().clear(); } renewLock.run(); diff --git a/src/test/java/org/opensearch/geospatial/ip2geo/Ip2GeoTestCase.java b/src/test/java/org/opensearch/geospatial/ip2geo/Ip2GeoTestCase.java index 931dfdc1..c57e36f9 100644 --- a/src/test/java/org/opensearch/geospatial/ip2geo/Ip2GeoTestCase.java +++ b/src/test/java/org/opensearch/geospatial/ip2geo/Ip2GeoTestCase.java @@ -5,6 +5,7 @@ package org.opensearch.geospatial.ip2geo; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -38,6 +39,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Randomness; import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.breaker.CircuitBreaker; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.OpenSearchExecutors; @@ -52,6 +54,7 @@ import org.opensearch.geospatial.ip2geo.jobscheduler.Datasource; import org.opensearch.geospatial.ip2geo.jobscheduler.DatasourceUpdateService; import org.opensearch.geospatial.ip2geo.processor.Ip2GeoProcessor; +import org.opensearch.geospatial.plugin.GeospatialPlugin; import org.opensearch.ingest.IngestMetadata; import org.opensearch.ingest.IngestService; import org.opensearch.jobscheduler.spi.LockModel; @@ -101,6 +104,7 @@ public abstract class Ip2GeoTestCase extends RestActionTestCase { @Before public void prepareIp2GeoTestCase() { + GeospatialPlugin.circuitBreaker = mock(CircuitBreaker.class); openMocks = MockitoAnnotations.openMocks(this); settings = Settings.EMPTY; client = new NoOpNodeClient(this.getTestName()); diff --git a/src/test/java/org/opensearch/geospatial/ip2geo/action/GetDatasourceTransportActionTests.java b/src/test/java/org/opensearch/geospatial/ip2geo/action/GetDatasourceTransportActionTests.java index 3dd4ddf5..f1dbd3fc 100644 --- a/src/test/java/org/opensearch/geospatial/ip2geo/action/GetDatasourceTransportActionTests.java +++ b/src/test/java/org/opensearch/geospatial/ip2geo/action/GetDatasourceTransportActionTests.java @@ -5,18 +5,22 @@ package org.opensearch.geospatial.ip2geo.action; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import javax.swing.*; + import org.junit.Before; -import org.mockito.ArgumentCaptor; import org.opensearch.action.ActionListener; import org.opensearch.geospatial.ip2geo.Ip2GeoTestCase; import org.opensearch.geospatial.ip2geo.jobscheduler.Datasource; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.tasks.Task; public class GetDatasourceTransportActionTests extends Ip2GeoTestCase { @@ -27,7 +31,7 @@ public void init() { action = new GetDatasourceTransportAction(transportService, actionFilters, datasourceFacade); } - public void testDoExecute_whenAll_thenSucceed() throws Exception { + public void testDoExecute_whenAll_thenSucceed() { Task task = mock(Task.class); GetDatasourceRequest request = new GetDatasourceRequest(new String[] { "_all" }); ActionListener listener = mock(ActionListener.class); @@ -36,22 +40,7 @@ public void testDoExecute_whenAll_thenSucceed() throws Exception { action.doExecute(task, request, listener); // Verify - ArgumentCaptor>> captor = ArgumentCaptor.forClass(ActionListener.class); - verify(datasourceFacade).getAllDatasources(captor.capture()); - - // Run - List datasources = Arrays.asList(randomDatasource(), randomDatasource()); - captor.getValue().onResponse(datasources); - - // Verify - verify(listener).onResponse(new GetDatasourceResponse(datasources)); - - // Run - RuntimeException exception = new RuntimeException(); - captor.getValue().onFailure(exception); - - // Verify - verify(listener).onFailure(exception); + verify(datasourceFacade).getAllDatasources(any(ActionListener.class)); } public void testDoExecute_whenNames_thenSucceed() { @@ -66,20 +55,37 @@ public void testDoExecute_whenNames_thenSucceed() { action.doExecute(task, request, listener); // Verify - ArgumentCaptor>> captor = ArgumentCaptor.forClass(ActionListener.class); - verify(datasourceFacade).getDatasources(eq(datasourceNames), captor.capture()); + verify(datasourceFacade).getDatasources(eq(datasourceNames), any(ActionListener.class)); + } + + public void testNewActionListener_whenOnResponse_thenSucceed() { + List datasources = Arrays.asList(randomDatasource(), randomDatasource()); + ActionListener actionListener = mock(ActionListener.class); + + // Run + action.newActionListener(actionListener).onResponse(datasources); + + // Verify + verify(actionListener).onResponse(new GetDatasourceResponse(datasources)); + } + + public void testNewActionListener_whenOnFailureWithNoSuchIndexException_thenEmptyDatasource() { + ActionListener actionListener = mock(ActionListener.class); // Run - captor.getValue().onResponse(datasources); + action.newActionListener(actionListener).onFailure(new IndexNotFoundException("no index")); // Verify - verify(listener).onResponse(new GetDatasourceResponse(datasources)); + verify(actionListener).onResponse(new GetDatasourceResponse(Collections.emptyList())); + } + + public void testNewActionListener_whenOnFailure_thenFails() { + ActionListener actionListener = mock(ActionListener.class); // Run - RuntimeException exception = new RuntimeException(); - captor.getValue().onFailure(exception); + action.newActionListener(actionListener).onFailure(new RuntimeException()); // Verify - verify(listener).onFailure(exception); + verify(actionListener).onFailure(any(RuntimeException.class)); } } diff --git a/src/test/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacadeTests.java b/src/test/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacadeTests.java index b11851a9..cd4c9bad 100644 --- a/src/test/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacadeTests.java +++ b/src/test/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacadeTests.java @@ -52,6 +52,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.IndicesOptions; import org.opensearch.common.Randomness; +import org.opensearch.common.Strings; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.bytes.BytesReference; import org.opensearch.geospatial.GeospatialTestHelper; @@ -103,12 +104,13 @@ public void testCreateIndexIfNotExistsWithoutExistingIndex() { verifyingGeoIpDataFacade.createIndexIfNotExists(index); } + @SneakyThrows public void testCreateDocument() { String[] names = { "ip", "country", "city" }; String[] values = { "1.0.0.0/25", "USA", "Seattle" }; assertEquals( "{\"_cidr\":\"1.0.0.0/25\",\"_data\":{\"country\":\"USA\",\"city\":\"Seattle\"}}", - noOpsGeoIpDataFacade.createDocument(names, values) + Strings.toString(noOpsGeoIpDataFacade.createDocument(names, values)) ); }