Skip to content

Commit

Permalink
Fix wrong 503 error response code
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <[email protected]>
  • Loading branch information
vmmusings committed Jan 31, 2024
1 parent 6fcf31b commit 1213b54
Show file tree
Hide file tree
Showing 14 changed files with 137 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
package org.opensearch.sql.datasources.rest;

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.rest.RestStatus.NOT_FOUND;
import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.rest.RestRequest.Method.*;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -293,7 +293,7 @@ private void handleException(Exception e, RestChannel restChannel) {
reportError(restChannel, e, BAD_REQUEST);
} else {
MetricUtils.incrementNumericalMetric(MetricName.DATASOURCE_FAILED_REQ_COUNT_SYS);
reportError(restChannel, e, SERVICE_UNAVAILABLE);
reportError(restChannel, e, INTERNAL_SERVER_ERROR);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public void queryExceedResourceLimitShouldFail() throws IOException {
String query = String.format("search source=%s age=20", TEST_INDEX_DOG);

ResponseException exception = expectThrows(ResponseException.class, () -> executeQuery(query));
assertEquals(503, exception.getResponse().getStatusLine().getStatusCode());
assertEquals(500, exception.getResponse().getStatusLine().getStatusCode());
assertThat(
exception.getMessage(),
Matchers.containsString("resource is not enough to run the" + " query, quit."));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
package org.opensearch.sql.legacy.plugin;

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.rest.RestStatus.OK;
import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;

import com.alibaba.druid.sql.parser.ParserException;
import com.google.common.collect.ImmutableList;
Expand All @@ -23,6 +23,7 @@
import java.util.regex.Pattern;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.inject.Injector;
Expand Down Expand Up @@ -171,21 +172,23 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
QueryAction queryAction = explainRequest(client, sqlRequest, format);
executeSqlRequest(request, queryAction, client, restChannel);
} catch (Exception e) {
logAndPublishMetrics(e);
reportError(restChannel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE);
handleException(restChannel, e);
}
},
(restChannel, exception) -> {
logAndPublishMetrics(exception);
reportError(
restChannel,
exception,
isClientError(exception) ? BAD_REQUEST : SERVICE_UNAVAILABLE);
});
this::handleException);
} catch (Exception e) {
logAndPublishMetrics(e);
return channel ->
reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE);
return channel -> handleException(channel, e);
}
}

private void handleException(RestChannel restChannel, Exception exception) {
logAndPublishMetrics(exception);
if (exception instanceof OpenSearchSecurityException) {
OpenSearchSecurityException securityException = (OpenSearchSecurityException) exception;
reportError(restChannel, exception, securityException.status());
} else {
reportError(
restChannel, exception, isClientError(exception) ? BAD_REQUEST : INTERNAL_SERVER_ERROR);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.sql.legacy.plugin;

import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
Expand Down Expand Up @@ -77,8 +77,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
return channel ->
channel.sendResponse(
new BytesRestResponse(
SERVICE_UNAVAILABLE,
ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus())
INTERNAL_SERVER_ERROR,
ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus())
.toString()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,15 @@ public void collect(
indexToType.put(tableName, null);
} else if (sqlExprTableSource.getExpr() instanceof SQLBinaryOpExpr) {
SQLBinaryOpExpr sqlBinaryOpExpr = (SQLBinaryOpExpr) sqlExprTableSource.getExpr();
tableName = ((SQLIdentifierExpr) sqlBinaryOpExpr.getLeft()).getName();
SQLExpr leftSideOfExpression = sqlBinaryOpExpr.getLeft();
if (leftSideOfExpression instanceof SQLIdentifierExpr) {
tableName = ((SQLIdentifierExpr) sqlBinaryOpExpr.getLeft()).getName();
} else {
throw new ParserException(
"Left side of the expression ["
+ leftSideOfExpression.toString()
+ "] is expected to be an identifier");
}
SQLExpr rightSideOfExpression = sqlBinaryOpExpr.getRight();

// This assumes that right side of the expression is different name in query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.expr.SQLQueryExpr;
import com.alibaba.druid.sql.parser.ParserException;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -100,6 +101,15 @@ public void testSelectTheFieldWithConflictMappingShouldThrowException() {
rewriteTerm(sql);
}

@Test
public void testIssue2391_WithWrongBinaryOperation() {
String sql = "SELECT * from I_THINK/IM/A_URL";
exception.expect(ParserException.class);
exception.expectMessage(
"Left side of the expression [I_THINK / IM] is expected to be an identifier");
rewriteTerm(sql);
}

private String rewriteTerm(String sql) {
SQLQueryExpr sqlQueryExpr = SqlParserUtils.parse(sql);
sqlQueryExpr.accept(new TermFieldRewriter());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService(
StateStore stateStore = new StateStore(client, clusterService);
registerStateStoreMetrics(stateStore);
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
new OpensearchAsyncQueryJobMetadataStorageService(stateStore, this.dataSourceService);
EMRServerlessClient emrServerlessClient =
createEMRServerlessClient(sparkExecutionEngineConfig.getRegion());
JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.rest.RestStatus.OK;
import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
Expand Down Expand Up @@ -130,7 +129,7 @@ public void onFailure(Exception e) {
Metrics.getInstance()
.getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS)
.increment();
reportError(channel, e, SERVICE_UNAVAILABLE);
reportError(channel, e, INTERNAL_SERVER_ERROR);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.sql.plugin.rest;

import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
Expand Down Expand Up @@ -75,8 +75,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
return channel ->
channel.sendResponse(
new BytesRestResponse(
SERVICE_UNAVAILABLE,
ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus())
INTERNAL_SERVER_ERROR,
ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus())
.toString()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.execution.statestore.StateStore;
Expand All @@ -22,6 +25,8 @@ public class OpensearchAsyncQueryJobMetadataStorageService

private final StateStore stateStore;

private final DataSourceService dataSourceService;

@Override
public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) {
AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId();
Expand All @@ -31,6 +36,10 @@ public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) {
@Override
public Optional<AsyncQueryJobMetadata> getJobMetadata(String qid) {
AsyncQueryId queryId = new AsyncQueryId(qid);
if (!dataSourceService.dataSourceExists(queryId.getDataSourceName())
|| StringUtils.isEmpty(queryId.getId())) {
throw new AsyncQueryNotFoundException(String.format("Invalid queryId: %s", qid));
}
return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName())
.apply(queryId.docId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.sql.spark.rest;

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.rest.RestStatus.TOO_MANY_REQUESTS;
import static org.opensearch.rest.RestRequest.Method.DELETE;
import static org.opensearch.rest.RestRequest.Method.GET;
Expand All @@ -26,10 +26,12 @@
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException;
import org.opensearch.sql.datasources.exceptions.ErrorMessage;
import org.opensearch.sql.datasources.utils.Scheduler;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.utils.MetricUtils;
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.leasemanager.ConcurrencyLimitExceededException;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction;
Expand Down Expand Up @@ -112,12 +114,12 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
}
}

private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient)
throws IOException {
MetricUtils.incrementNumericalMetric(MetricName.ASYNC_QUERY_CREATE_API_REQUEST_COUNT);
CreateAsyncQueryRequest submitJobRequest =
CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser());
return restChannel ->
private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient) {
return restChannel -> {
try {
MetricUtils.incrementNumericalMetric(MetricName.ASYNC_QUERY_CREATE_API_REQUEST_COUNT);
CreateAsyncQueryRequest submitJobRequest =
CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser());
Scheduler.schedule(
nodeClient,
() ->
Expand All @@ -140,6 +142,10 @@ public void onFailure(Exception e) {
handleException(e, restChannel, restRequest.method());
}
}));
} catch (Exception e) {
handleException(e, restChannel, restRequest.method());
}
};
}

private RestChannelConsumer executeGetAsyncQueryResultRequest(
Expand Down Expand Up @@ -187,7 +193,7 @@ private void handleException(
reportError(restChannel, e, BAD_REQUEST);
addCustomerErrorMetric(requestMethod);
} else {
reportError(restChannel, e, SERVICE_UNAVAILABLE);
reportError(restChannel, e, INTERNAL_SERVER_ERROR);
addSystemErrorMetric(requestMethod);
}
}
Expand Down Expand Up @@ -227,7 +233,10 @@ private void reportError(final RestChannel channel, final Exception e, final Res
}

private static boolean isClientError(Exception e) {
return e instanceof IllegalArgumentException || e instanceof IllegalStateException;
return e instanceof IllegalArgumentException
|| e instanceof IllegalStateException
|| e instanceof DataSourceNotFoundException
|| e instanceof AsyncQueryNotFoundException;
}

private void addSystemErrorMetric(RestRequest.Method requestMethod) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,28 @@ public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser)
LangType lang = null;
String datasource = null;
String sessionId = null;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
if (fieldName.equals("query")) {
query = parser.textOrNull();
} else if (fieldName.equals("lang")) {
String langString = parser.textOrNull();
lang = LangType.fromString(langString);
} else if (fieldName.equals("datasource")) {
datasource = parser.textOrNull();
} else if (fieldName.equals(SESSION_ID)) {
sessionId = parser.textOrNull();
} else {
throw new IllegalArgumentException("Unknown field: " + fieldName);
try {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
if (fieldName.equals("query")) {
query = parser.textOrNull();
} else if (fieldName.equals("lang")) {
String langString = parser.textOrNull();
lang = LangType.fromString(langString);
} else if (fieldName.equals("datasource")) {
datasource = parser.textOrNull();
} else if (fieldName.equals(SESSION_ID)) {
sessionId = parser.textOrNull();
} else {
throw new IllegalArgumentException("Unknown field: " + fieldName);
}
}
return new CreateAsyncQueryRequest(query, datasource, lang, sessionId);
} catch (Exception e) {
throw new IllegalArgumentException(
String.format("Error while parsing the request body: %s", e.getMessage()));
}
return new CreateAsyncQueryRequest(query, datasource, lang, sessionId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService(
JobExecutionResponseReader jobExecutionResponseReader) {
StateStore stateStore = new StateStore(client, clusterService);
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
new OpensearchAsyncQueryJobMetadataStorageService(stateStore, dataSourceService);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClient,
Expand Down
Loading

0 comments on commit 1213b54

Please sign in to comment.