Skip to content

Commit

Permalink
Fix bug in get datasource API and improve memory usage
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed May 19, 2023
1 parent 0d65260 commit 3a80586
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@

package org.opensearch.geospatial.ip2geo.action;

import java.util.Collections;
import java.util.List;

import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.geospatial.annotation.VisibleForTesting;
import org.opensearch.geospatial.ip2geo.common.DatasourceFacade;
import org.opensearch.geospatial.ip2geo.jobscheduler.Datasource;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -57,7 +60,8 @@ private boolean shouldGetAllDatasource(final GetDatasourceRequest request) {
return request.getNames().length == 0 || (request.getNames().length == 1 && "_all".equals(request.getNames()[0]));
}

private ActionListener<List<Datasource>> newActionListener(final ActionListener<GetDatasourceResponse> listener) {
@VisibleForTesting
protected ActionListener<List<Datasource>> newActionListener(final ActionListener<GetDatasourceResponse> listener) {
return new ActionListener<>() {
@Override
public void onResponse(final List<Datasource> datasources) {
Expand All @@ -66,6 +70,10 @@ public void onResponse(final List<Datasource> datasources) {

@Override
public void onFailure(final Exception e) {
if (e instanceof IndexNotFoundException) {
listener.onResponse(new GetDatasourceResponse(Collections.emptyList()));
return;
}
listener.onFailure(e);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
Expand All @@ -35,6 +37,7 @@
import org.opensearch.OpenSearchException;
import org.opensearch.SpecialPermission;
import org.opensearch.action.ActionListener;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
Expand All @@ -51,8 +54,10 @@
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.geospatial.annotation.VisibleForTesting;
import org.opensearch.geospatial.shared.Constants;
import org.opensearch.geospatial.shared.StashedThreadContext;
Expand Down Expand Up @@ -187,7 +192,7 @@ protected CSVParser internalGetDatabaseReader(final DatasourceManifest manifest,
}

/**
* Create a document in json string format to ingest in datasource database index
* Create a document to ingest in datasource database index
*
* It assumes the first field as ip_range. The rest is added under data field.
*
Expand All @@ -204,31 +209,23 @@ protected CSVParser internalGetDatabaseReader(final DatasourceManifest manifest,
* @param fields a list of field name
* @param values a list of values
* @return Document in json string format
* @throws IOException the exception
*/
public String createDocument(final String[] fields, final String[] values) {
public XContentBuilder createDocument(final String[] fields, final String[] values) throws IOException {
if (fields.length != values.length) {
throw new OpenSearchException("header[{}] and record[{}] length does not match", fields, values);
}
StringBuilder sb = new StringBuilder();
sb.append("{\"");
sb.append(IP_RANGE_FIELD_NAME);
sb.append("\":\"");
sb.append(values[0]);
sb.append("\",\"");
sb.append(DATA_FIELD_NAME);
sb.append("\":{");
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.field(IP_RANGE_FIELD_NAME, values[0]);
builder.startObject(DATA_FIELD_NAME);
for (int i = 1; i < fields.length; i++) {
if (i != 1) {
sb.append(",");
}
sb.append("\"");
sb.append(fields[i]);
sb.append("\":\"");
sb.append(values[i]);
sb.append("\"");
builder.field(fields[i], values[i]);
}
sb.append("}}");
return sb.toString();
builder.endObject();
builder.endObject();
builder.close();
return builder;
}

/**
Expand Down Expand Up @@ -368,14 +365,20 @@ public void putGeoIpData(
@NonNull final Iterator<CSVRecord> iterator,
final int bulkSize,
@NonNull final Runnable renewLock
) {
) throws IOException {
TimeValue timeout = clusterSettings.get(Ip2GeoSettings.TIMEOUT);
final BulkRequest bulkRequest = new BulkRequest();
Queue<DocWriteRequest> requests = new LinkedList<>();
for (int i = 0; i < bulkSize; i++) {
requests.add(Requests.indexRequest(indexName));
}
while (iterator.hasNext()) {
CSVRecord record = iterator.next();
String document = createDocument(fields, record.values());
IndexRequest request = Requests.indexRequest(indexName).source(document, XContentType.JSON);
bulkRequest.add(request);
XContentBuilder document = createDocument(fields, record.values());
IndexRequest indexRequest = (IndexRequest) requests.poll();
indexRequest.source(document);
indexRequest.id(record.get(0));
bulkRequest.add(indexRequest);
if (iterator.hasNext() == false || bulkRequest.requests().size() == bulkSize) {
BulkResponse response = StashedThreadContext.run(client, () -> client.bulk(bulkRequest).actionGet(timeout));
if (response.hasFailures()) {
Expand All @@ -385,6 +388,7 @@ public void putGeoIpData(
response.buildFailureMessage()
);
}
requests.addAll(bulkRequest.requests());
bulkRequest.requests().clear();
}
renewLock.run();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.geospatial.ip2geo;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -38,6 +39,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Randomness;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.breaker.CircuitBreaker;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
Expand All @@ -52,6 +54,7 @@
import org.opensearch.geospatial.ip2geo.jobscheduler.Datasource;
import org.opensearch.geospatial.ip2geo.jobscheduler.DatasourceUpdateService;
import org.opensearch.geospatial.ip2geo.processor.Ip2GeoProcessor;
import org.opensearch.geospatial.plugin.GeospatialPlugin;
import org.opensearch.ingest.IngestMetadata;
import org.opensearch.ingest.IngestService;
import org.opensearch.jobscheduler.spi.LockModel;
Expand Down Expand Up @@ -101,6 +104,7 @@ public abstract class Ip2GeoTestCase extends RestActionTestCase {

@Before
public void prepareIp2GeoTestCase() {
GeospatialPlugin.circuitBreaker = mock(CircuitBreaker.class);
openMocks = MockitoAnnotations.openMocks(this);
settings = Settings.EMPTY;
client = new NoOpNodeClient(this.getTestName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@

package org.opensearch.geospatial.ip2geo.action;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import javax.swing.*;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.opensearch.action.ActionListener;
import org.opensearch.geospatial.ip2geo.Ip2GeoTestCase;
import org.opensearch.geospatial.ip2geo.jobscheduler.Datasource;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.tasks.Task;

public class GetDatasourceTransportActionTests extends Ip2GeoTestCase {
Expand All @@ -27,7 +31,7 @@ public void init() {
action = new GetDatasourceTransportAction(transportService, actionFilters, datasourceFacade);
}

public void testDoExecute_whenAll_thenSucceed() throws Exception {
public void testDoExecute_whenAll_thenSucceed() {
Task task = mock(Task.class);
GetDatasourceRequest request = new GetDatasourceRequest(new String[] { "_all" });
ActionListener<GetDatasourceResponse> listener = mock(ActionListener.class);
Expand All @@ -36,22 +40,7 @@ public void testDoExecute_whenAll_thenSucceed() throws Exception {
action.doExecute(task, request, listener);

// Verify
ArgumentCaptor<ActionListener<List<Datasource>>> captor = ArgumentCaptor.forClass(ActionListener.class);
verify(datasourceFacade).getAllDatasources(captor.capture());

// Run
List<Datasource> datasources = Arrays.asList(randomDatasource(), randomDatasource());
captor.getValue().onResponse(datasources);

// Verify
verify(listener).onResponse(new GetDatasourceResponse(datasources));

// Run
RuntimeException exception = new RuntimeException();
captor.getValue().onFailure(exception);

// Verify
verify(listener).onFailure(exception);
verify(datasourceFacade).getAllDatasources(any(ActionListener.class));
}

public void testDoExecute_whenNames_thenSucceed() {
Expand All @@ -66,20 +55,37 @@ public void testDoExecute_whenNames_thenSucceed() {
action.doExecute(task, request, listener);

// Verify
ArgumentCaptor<ActionListener<List<Datasource>>> captor = ArgumentCaptor.forClass(ActionListener.class);
verify(datasourceFacade).getDatasources(eq(datasourceNames), captor.capture());
verify(datasourceFacade).getDatasources(eq(datasourceNames), any(ActionListener.class));
}

public void testNewActionListener_whenOnResponse_thenSucceed() {
List<Datasource> datasources = Arrays.asList(randomDatasource(), randomDatasource());
ActionListener<GetDatasourceResponse> actionListener = mock(ActionListener.class);

// Run
action.newActionListener(actionListener).onResponse(datasources);

// Verify
verify(actionListener).onResponse(new GetDatasourceResponse(datasources));
}

public void testNewActionListener_whenOnFailureWithNoSuchIndexException_thenEmptyDatasource() {
ActionListener<GetDatasourceResponse> actionListener = mock(ActionListener.class);

// Run
captor.getValue().onResponse(datasources);
action.newActionListener(actionListener).onFailure(new IndexNotFoundException("no index"));

// Verify
verify(listener).onResponse(new GetDatasourceResponse(datasources));
verify(actionListener).onResponse(new GetDatasourceResponse(Collections.emptyList()));
}

public void testNewActionListener_whenOnFailure_thenFails() {
ActionListener<GetDatasourceResponse> actionListener = mock(ActionListener.class);

// Run
RuntimeException exception = new RuntimeException();
captor.getValue().onFailure(exception);
action.newActionListener(actionListener).onFailure(new RuntimeException());

// Verify
verify(listener).onFailure(exception);
verify(actionListener).onFailure(any(RuntimeException.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.common.Randomness;
import org.opensearch.common.Strings;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.geospatial.GeospatialTestHelper;
Expand Down Expand Up @@ -103,12 +104,13 @@ public void testCreateIndexIfNotExistsWithoutExistingIndex() {
verifyingGeoIpDataFacade.createIndexIfNotExists(index);
}

@SneakyThrows
public void testCreateDocument() {
String[] names = { "ip", "country", "city" };
String[] values = { "1.0.0.0/25", "USA", "Seattle" };
assertEquals(
"{\"_cidr\":\"1.0.0.0/25\",\"_data\":{\"country\":\"USA\",\"city\":\"Seattle\"}}",
noOpsGeoIpDataFacade.createDocument(names, values)
Strings.toString(noOpsGeoIpDataFacade.createDocument(names, values))
);
}

Expand Down

0 comments on commit 3a80586

Please sign in to comment.