Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change ApplicationContext lifecycle to per node level #822

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import org.opensearch.sql.expression.operator.predicate.UnaryPredicateOperator;
import org.opensearch.sql.expression.text.TextFunction;
import org.opensearch.sql.expression.window.WindowFunctions;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Scope;

/**
* Expression Config for Spring IoC.
Expand All @@ -32,6 +34,7 @@ public class ExpressionConfig {
* BuiltinFunctionRepository constructor.
*/
@Bean
@Scope(value = ConfigurableBeanFactory.SCOPE_PROTOTYPE)
public BuiltinFunctionRepository functionRepository() {
BuiltinFunctionRepository builtinFunctionRepository =
new BuiltinFunctionRepository(new HashMap<>());
Expand All @@ -50,6 +53,7 @@ public BuiltinFunctionRepository functionRepository() {
}

@Bean
@Scope(value = ConfigurableBeanFactory.SCOPE_PROTOTYPE)
public DSL dsl(BuiltinFunctionRepository repository) {
return new DSL(repository);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,17 @@
import static org.opensearch.sql.executor.ExecutionEngine.QueryResponse;
import static org.opensearch.sql.protocol.response.format.JsonResponseFormatter.Style.PRETTY;

import java.io.IOException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import javax.xml.catalog.Catalog;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestStatus;
import org.opensearch.sql.catalog.CatalogService;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.common.response.ResponseListener;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.metrics.Metrics;
Expand All @@ -41,7 +35,6 @@
import org.opensearch.sql.protocol.response.format.RawResponseFormatter;
import org.opensearch.sql.protocol.response.format.ResponseFormatter;
import org.opensearch.sql.sql.SQLService;
import org.opensearch.sql.sql.config.SQLServiceConfig;
import org.opensearch.sql.sql.domain.SQLQueryRequest;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;

Expand All @@ -56,23 +49,14 @@ public class RestSQLQueryAction extends BaseRestHandler {

public static final RestChannelConsumer NOT_SUPPORTED_YET = null;

private final ClusterService clusterService;

/**
* Settings required by been initialization.
*/
private final Settings pluginSettings;

private final CatalogService catalogService;
private final AnnotationConfigApplicationContext applicationContext;

/**
* Constructor of RestSQLQueryAction.
*/
public RestSQLQueryAction(ClusterService clusterService, Settings pluginSettings, CatalogService catalogService) {
public RestSQLQueryAction(AnnotationConfigApplicationContext applicationContext) {
super();
this.clusterService = clusterService;
this.pluginSettings = pluginSettings;
this.catalogService = catalogService;
this.applicationContext = applicationContext;
}

@Override
Expand Down Expand Up @@ -101,7 +85,8 @@ public RestChannelConsumer prepareRequest(SQLQueryRequest request, NodeClient no
return NOT_SUPPORTED_YET;
}

SQLService sqlService = createSQLService(nodeClient);
SQLService sqlService =
SecurityAccess.doPrivileged(() -> applicationContext.getBean(SQLService.class));
PhysicalPlan plan;
try {
// For now analyzing and planning stage may throw syntax exception as well
Expand All @@ -123,20 +108,6 @@ public RestChannelConsumer prepareRequest(SQLQueryRequest request, NodeClient no
return channel -> sqlService.execute(plan, createQueryResponseListener(channel, request));
}

private SQLService createSQLService(NodeClient client) {
return doPrivileged(() -> {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.registerBean(ClusterService.class, () -> clusterService);
context.registerBean(NodeClient.class, () -> client);
context.registerBean(Settings.class, () -> pluginSettings);
context.registerBean(CatalogService.class, () -> catalogService);
context.register(OpenSearchSQLPluginConfig.class);
context.register(SQLServiceConfig.class);
context.refresh();
return context.getBean(SQLService.class);
});
}

private ResponseListener<ExplainResponse> createExplainResponseListener(RestChannel channel) {
return new ResponseListener<ExplainResponse>() {
@Override
Expand Down Expand Up @@ -185,14 +156,6 @@ public void onFailure(Exception e) {
};
}

private <T> T doPrivileged(PrivilegedExceptionAction<T> action) {
try {
return SecurityAccess.doPrivileged(action);
} catch (IOException e) {
throw new IllegalStateException("Failed to perform privileged action", e);
}
}

private void sendResponse(RestChannel channel, RestStatus status, String content) {
channel.sendResponse(new BytesRestResponse(
status, "application/json; charset=UTF-8", content));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestStatus;
import org.opensearch.sql.catalog.CatalogService;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.common.utils.QueryContext;
import org.opensearch.sql.exception.ExpressionEvaluationException;
Expand Down Expand Up @@ -65,6 +63,7 @@
import org.opensearch.sql.legacy.utils.JsonPrettyFormatter;
import org.opensearch.sql.legacy.utils.QueryDataAnonymizer;
import org.opensearch.sql.sql.domain.SQLQueryRequest;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;

public class RestSqlAction extends BaseRestHandler {

Expand All @@ -89,12 +88,16 @@ public class RestSqlAction extends BaseRestHandler {
*/
private final RestSQLQueryAction newSqlQueryHandler;

public RestSqlAction(Settings settings, ClusterService clusterService,
org.opensearch.sql.common.setting.Settings pluginSettings,
CatalogService catalogService) {
/**
* Application context used to create SQLService for each request.
*/
private final AnnotationConfigApplicationContext applicationContext;

public RestSqlAction(Settings settings, AnnotationConfigApplicationContext applicationContext) {
super();
this.allowExplicitIndex = MULTI_ALLOW_EXPLICIT_INDEX.get(settings);
this.newSqlQueryHandler = new RestSQLQueryAction(clusterService, pluginSettings, catalogService);
this.newSqlQueryHandler = new RestSQLQueryAction(applicationContext);
this.applicationContext = applicationContext;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.legacy.plugin.RestSQLQueryAction.NOT_SUPPORTED_YET;
import static org.opensearch.sql.legacy.plugin.RestSqlAction.EXPLAIN_API_ENDPOINT;
import static org.opensearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT;
Expand All @@ -18,36 +17,38 @@
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.sql.catalog.CatalogService;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.executor.ExecutionEngine;
import org.opensearch.sql.sql.config.SQLServiceConfig;
import org.opensearch.sql.sql.domain.SQLQueryRequest;
import org.opensearch.sql.storage.StorageEngine;
import org.opensearch.threadpool.ThreadPool;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;

@RunWith(MockitoJUnitRunner.class)
public class RestSQLQueryActionTest {

@Mock
private ClusterService clusterService;

private NodeClient nodeClient;

@Mock
private ThreadPool threadPool;

@Mock
private Settings settings;

@Mock
private CatalogService catalogService;
private AnnotationConfigApplicationContext context;

@Before
public void setup() {
nodeClient = new NodeClient(org.opensearch.common.settings.Settings.EMPTY, threadPool);
when(threadPool.getThreadContext())
context = new AnnotationConfigApplicationContext();
context.registerBean(StorageEngine.class, () -> Mockito.mock(StorageEngine.class));
context.registerBean(ExecutionEngine.class, () -> Mockito.mock(ExecutionEngine.class));
context.registerBean(CatalogService.class, () -> Mockito.mock(CatalogService.class));
context.register(SQLServiceConfig.class);
context.refresh();
Mockito.lenient().when(threadPool.getThreadContext())
.thenReturn(new ThreadContext(org.opensearch.common.settings.Settings.EMPTY));
}

Expand All @@ -59,7 +60,7 @@ public void handleQueryThatCanSupport() {
QUERY_API_ENDPOINT,
"");

RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings, catalogService);
RestSQLQueryAction queryAction = new RestSQLQueryAction(context);
assertNotSame(NOT_SUPPORTED_YET, queryAction.prepareRequest(request, nodeClient));
}

Expand All @@ -71,7 +72,7 @@ public void handleExplainThatCanSupport() {
EXPLAIN_API_ENDPOINT,
"");

RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings, catalogService);
RestSQLQueryAction queryAction = new RestSQLQueryAction(context);
assertNotSame(NOT_SUPPORTED_YET, queryAction.prepareRequest(request, nodeClient));
}

Expand All @@ -84,7 +85,7 @@ public void skipQueryThatNotSupport() {
QUERY_API_ENDPOINT,
"");

RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings, catalogService);
RestSQLQueryAction queryAction = new RestSQLQueryAction(context);
assertSame(NOT_SUPPORTED_YET, queryAction.prepareRequest(request, nodeClient));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

package org.opensearch.sql.opensearch.security;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
Expand All @@ -21,13 +20,12 @@ public class SecurityAccess {
/**
* Execute the operation in privileged mode.
*/
public static <T> T doPrivileged(final PrivilegedExceptionAction<T> operation)
throws IOException {
public static <T> T doPrivileged(final PrivilegedExceptionAction<T> operation) {
SpecialPermission.check();
try {
return AccessController.doPrivileged(operation);
} catch (final PrivilegedActionException e) {
throw (IOException) e.getCause();
throw new IllegalStateException("Failed to perform privileged action", e);
}
}
}
Loading