diff --git a/src/main/java/org/opensearch/geospatial/ip2geo/common/DatasourceFacade.java b/src/main/java/org/opensearch/geospatial/ip2geo/common/DatasourceFacade.java index 71d8122b..20e19bba 100644 --- a/src/main/java/org/opensearch/geospatial/ip2geo/common/DatasourceFacade.java +++ b/src/main/java/org/opensearch/geospatial/ip2geo/common/DatasourceFacade.java @@ -52,6 +52,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.geospatial.ip2geo.jobscheduler.Datasource; import org.opensearch.geospatial.ip2geo.jobscheduler.DatasourceExtension; +import org.opensearch.geospatial.shared.StashedThreadContext; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; @@ -92,7 +93,7 @@ public void createIndexIfNotExists(final StepListener stepListener) { indexSettings.put(INDEX_SETTING_HIDDEN.v1(), INDEX_SETTING_HIDDEN.v2()); final CreateIndexRequest createIndexRequest = new CreateIndexRequest(DatasourceExtension.JOB_INDEX_NAME).mapping(getIndexMapping()) .settings(indexSettings); - client.admin().indices().create(createIndexRequest, new ActionListener<>() { + StashedThreadContext.run(client, () -> client.admin().indices().create(createIndexRequest, new ActionListener<>() { @Override public void onResponse(final CreateIndexResponse createIndexResponse) { stepListener.onResponse(null); @@ -107,7 +108,7 @@ public void onFailure(final Exception e) { } stepListener.onFailure(e); } - }); + })); } private String getIndexMapping() { @@ -126,17 +127,22 @@ private String getIndexMapping() { * Update datasource in an index {@code DatasourceExtension.JOB_INDEX_NAME} * @param datasource the datasource * @return index response - * @throws IOException exception */ - public IndexResponse updateDatasource(final Datasource datasource) throws IOException { + public IndexResponse updateDatasource(final Datasource datasource) { datasource.setLastUpdateTime(Instant.now()); - return client.prepareIndex(DatasourceExtension.JOB_INDEX_NAME) - .setId(datasource.getName()) - .setOpType(DocWriteRequest.OpType.INDEX) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .setSource(datasource.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) - .execute() - .actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT)); + return StashedThreadContext.run(client, () -> { + try { + return client.prepareIndex(DatasourceExtension.JOB_INDEX_NAME) + .setId(datasource.getName()) + .setOpType(DocWriteRequest.OpType.INDEX) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource(datasource.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .execute() + .actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); } /** @@ -144,16 +150,21 @@ public IndexResponse updateDatasource(final Datasource datasource) throws IOExce * * @param datasource the datasource * @param listener the listener - * @throws IOException exception */ - public void putDatasource(final Datasource datasource, final ActionListener listener) throws IOException { + public void putDatasource(final Datasource datasource, final ActionListener listener) { datasource.setLastUpdateTime(Instant.now()); - client.prepareIndex(DatasourceExtension.JOB_INDEX_NAME) - .setId(datasource.getName()) - .setOpType(DocWriteRequest.OpType.CREATE) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .setSource(datasource.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) - .execute(listener); + StashedThreadContext.run(client, () -> { + try { + client.prepareIndex(DatasourceExtension.JOB_INDEX_NAME) + .setId(datasource.getName()) + .setOpType(DocWriteRequest.OpType.CREATE) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource(datasource.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .execute(listener); + } catch (IOException e) { + new RuntimeException(e); + } + }); } /** @@ -166,7 +177,7 @@ public Datasource getDatasource(final String name) throws IOException { GetRequest request = new GetRequest(DatasourceExtension.JOB_INDEX_NAME, name); GetResponse response; try { - response = client.get(request).actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT)); + response = StashedThreadContext.run(client, () -> client.get(request).actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT))); if (response.isExists() == false) { log.error("Datasource[{}] does not exist in an index[{}]", name, DatasourceExtension.JOB_INDEX_NAME); return null; @@ -191,7 +202,7 @@ public Datasource getDatasource(final String name) throws IOException { */ public void getDatasource(final String name, final ActionListener actionListener) { GetRequest request = new GetRequest(DatasourceExtension.JOB_INDEX_NAME, name); - client.get(request, new ActionListener() { + StashedThreadContext.run(client, () -> client.get(request, new ActionListener<>() { @Override public void onResponse(final GetResponse response) { if (response.isExists() == false) { @@ -215,7 +226,7 @@ public void onResponse(final GetResponse response) { public void onFailure(final Exception e) { actionListener.onFailure(e); } - }); + })); } /** @@ -224,9 +235,12 @@ public void onFailure(final Exception e) { * @param actionListener the action listener */ public void getDatasources(final String[] names, final ActionListener> actionListener) { - client.prepareMultiGet() - .add(DatasourceExtension.JOB_INDEX_NAME, names) - .execute(createGetDataSourceQueryActionLister(MultiGetResponse.class, actionListener)); + StashedThreadContext.run( + client, + () -> client.prepareMultiGet() + .add(DatasourceExtension.JOB_INDEX_NAME, names) + .execute(createGetDataSourceQueryActionLister(MultiGetResponse.class, actionListener)) + ); } /** @@ -234,10 +248,13 @@ public void getDatasources(final String[] names, final ActionListener> actionListener) { - client.prepareSearch(DatasourceExtension.JOB_INDEX_NAME) - .setQuery(QueryBuilders.matchAllQuery()) - .setSize(MAX_SIZE) - .execute(createGetDataSourceQueryActionLister(SearchResponse.class, actionListener)); + StashedThreadContext.run( + client, + () -> client.prepareSearch(DatasourceExtension.JOB_INDEX_NAME) + .setQuery(QueryBuilders.matchAllQuery()) + .setSize(MAX_SIZE) + .execute(createGetDataSourceQueryActionLister(SearchResponse.class, actionListener)) + ); } private ActionListener createGetDataSourceQueryActionLister( diff --git a/src/main/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacade.java b/src/main/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacade.java index dd498d1b..3002cc89 100644 --- a/src/main/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacade.java +++ b/src/main/java/org/opensearch/geospatial/ip2geo/common/GeoIpDataFacade.java @@ -55,6 +55,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.geospatial.shared.StashedThreadContext; import org.opensearch.index.query.QueryBuilders; /** @@ -95,7 +96,10 @@ public void createIndexIfNotExists(final String indexName) { indexSettings.put(INDEX_SETTING_AUTO_EXPAND_REPLICAS.v1(), INDEX_SETTING_AUTO_EXPAND_REPLICAS.v2()); indexSettings.put(INDEX_SETTING_HIDDEN.v1(), INDEX_SETTING_HIDDEN.v2()); final CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).settings(indexSettings).mapping(getIndexMapping()); - client.admin().indices().create(createIndexRequest).actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT)); + StashedThreadContext.run( + client, + () -> client.admin().indices().create(createIndexRequest).actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT)) + ); } /** @@ -210,31 +214,34 @@ public String createDocument(final String[] fields, final String[] values) { * @param actionListener action listener */ public void getGeoIpData(final String indexName, final String ip, final ActionListener> actionListener) { - client.prepareSearch(indexName) - .setSize(1) - .setQuery(QueryBuilders.termQuery(IP_RANGE_FIELD_NAME, ip)) - .setPreference("_local") - .setRequestCache(true) - .execute(new ActionListener<>() { - @Override - public void onResponse(final SearchResponse searchResponse) { - if (searchResponse.getHits().getHits().length == 0) { - actionListener.onResponse(Collections.emptyMap()); - } else { - Map geoIpData = (Map) XContentHelper.convertToMap( - searchResponse.getHits().getAt(0).getSourceRef(), - false, - XContentType.JSON - ).v2().get(DATA_FIELD_NAME); - actionListener.onResponse(geoIpData); + StashedThreadContext.run( + client, + () -> client.prepareSearch(indexName) + .setSize(1) + .setQuery(QueryBuilders.termQuery(IP_RANGE_FIELD_NAME, ip)) + .setPreference("_local") + .setRequestCache(true) + .execute(new ActionListener<>() { + @Override + public void onResponse(final SearchResponse searchResponse) { + if (searchResponse.getHits().getHits().length == 0) { + actionListener.onResponse(Collections.emptyMap()); + } else { + Map geoIpData = (Map) XContentHelper.convertToMap( + searchResponse.getHits().getAt(0).getSourceRef(), + false, + XContentType.JSON + ).v2().get(DATA_FIELD_NAME); + actionListener.onResponse(geoIpData); + } } - } - @Override - public void onFailure(final Exception e) { - actionListener.onFailure(e); - } - }); + @Override + public void onFailure(final Exception e) { + actionListener.onFailure(e); + } + }) + ); } /** @@ -284,7 +291,7 @@ public void getGeoIpData( return; } - mRequestBuilder.execute(new ActionListener<>() { + StashedThreadContext.run(client, () -> mRequestBuilder.execute(new ActionListener<>() { @Override public void onResponse(final MultiSearchResponse items) { for (int i = 0; i < ipsToSearch.size(); i++) { @@ -318,7 +325,7 @@ public void onResponse(final MultiSearchResponse items) { public void onFailure(final Exception e) { actionListener.onFailure(e); } - }); + })); } /** @@ -331,14 +338,14 @@ public void onFailure(final Exception e) { */ public void putGeoIpData(final String indexName, final String[] fields, final Iterator iterator, final int bulkSize) { TimeValue timeout = clusterSettings.get(Ip2GeoSettings.TIMEOUT); - BulkRequest bulkRequest = new BulkRequest().setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + final BulkRequest bulkRequest = new BulkRequest().setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); while (iterator.hasNext()) { CSVRecord record = iterator.next(); String document = createDocument(fields, record.values()); IndexRequest request = Requests.indexRequest(indexName).source(document, XContentType.JSON); bulkRequest.add(request); if (iterator.hasNext() == false || bulkRequest.requests().size() == bulkSize) { - BulkResponse response = client.bulk(bulkRequest).actionGet(timeout); + BulkResponse response = StashedThreadContext.run(client, () -> client.bulk(bulkRequest).actionGet(timeout)); if (response.hasFailures()) { throw new OpenSearchException( "error occurred while ingesting GeoIP data in {} with an error {}", @@ -346,17 +353,19 @@ public void putGeoIpData(final String indexName, final String[] fields, final It response.buildFailureMessage() ); } - bulkRequest = new BulkRequest().setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + bulkRequest.requests().clear(); } } - client.admin().indices().prepareRefresh(indexName).execute().actionGet(timeout); - client.admin().indices().prepareForceMerge(indexName).setMaxNumSegments(1).execute().actionGet(timeout); - client.admin() - .indices() - .prepareUpdateSettings(indexName) - .setSettings(Map.of(INDEX_SETTING_READ_ONLY_ALLOW_DELETE.v1(), INDEX_SETTING_READ_ONLY_ALLOW_DELETE.v2())) - .execute() - .actionGet(timeout); + StashedThreadContext.run(client, () -> { + client.admin().indices().prepareRefresh(indexName).execute().actionGet(timeout); + client.admin().indices().prepareForceMerge(indexName).setMaxNumSegments(1).execute().actionGet(timeout); + client.admin() + .indices() + .prepareUpdateSettings(indexName) + .setSettings(Map.of(INDEX_SETTING_READ_ONLY_ALLOW_DELETE.v1(), INDEX_SETTING_READ_ONLY_ALLOW_DELETE.v2())) + .execute() + .actionGet(timeout); + }); } public AcknowledgedResponse deleteIp2GeoDataIndex(final String index) { @@ -367,11 +376,14 @@ public AcknowledgedResponse deleteIp2GeoDataIndex(final String index) { IP2GEO_DATA_INDEX_NAME_PREFIX ); } - return client.admin() - .indices() - .prepareDelete(index) - .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) - .execute() - .actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT)); + return StashedThreadContext.run( + client, + () -> client.admin() + .indices() + .prepareDelete(index) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) + .execute() + .actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT)) + ); } } diff --git a/src/main/java/org/opensearch/geospatial/shared/StashedThreadContext.java b/src/main/java/org/opensearch/geospatial/shared/StashedThreadContext.java new file mode 100644 index 00000000..1ee59297 --- /dev/null +++ b/src/main/java/org/opensearch/geospatial/shared/StashedThreadContext.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.geospatial.shared; + +import java.util.function.Supplier; + +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; + +/** + * Helper class to run code with stashed thread context + * + * Code need to be run with stashed thread context if it interacts with system index + * when security plugin is enabled. + */ +public class StashedThreadContext { + /** + * Set the thread context to default, this is needed to allow actions on model system index + * when security plugin is enabled + * @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing + */ + public static void run(final Client client, final Runnable function) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + function.run(); + } + } + + /** + * Set the thread context to default, this is needed to allow actions on model system index + * when security plugin is enabled + * @param function supplier function that needs to be executed after thread context has been stashed, return object + */ + public static T run(final Client client, final Supplier function) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + return function.get(); + } + } +}