Skip to content

Commit

Permalink
refactor the config
Browse files Browse the repository at this point in the history
  • Loading branch information
stevenpyzhang committed Dec 13, 2019
1 parent 919924e commit 286a5f4
Show file tree
Hide file tree
Showing 12 changed files with 143 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import io.confluent.ksql.parser.KsqlParser.ParsedStatement;
import io.confluent.ksql.parser.KsqlParser.PreparedStatement;
import io.confluent.ksql.query.id.SpecificQueryIdGenerator;
import io.confluent.ksql.rest.ErrorMessages;
import io.confluent.ksql.rest.Errors;
import io.confluent.ksql.rest.client.RestResponse;
import io.confluent.ksql.rest.entity.KsqlEntityList;
Expand Down Expand Up @@ -382,10 +383,10 @@ protected void registerWebSocketEndpoints(final ServerContainer container) {
final StatementParser statementParser = new StatementParser(ksqlEngine);
final Optional<KsqlAuthorizationValidator> authorizationValidator =
KsqlAuthorizationValidatorFactory.create(ksqlConfigNoPort, serviceContext);
final Errors errorHandler = restConfig.getConfiguredInstance(
final Errors errorHandler = new Errors(restConfig.getConfiguredInstance(
KsqlRestConfig.KSQL_SERVER_ERRORS,
Errors.class
);
ErrorMessages.class
));

container.addEndpoint(
ServerEndpointConfig.Builder
Expand Down Expand Up @@ -504,10 +505,10 @@ static KsqlRestApplication buildApplication(
final Optional<KsqlAuthorizationValidator> authorizationValidator =
KsqlAuthorizationValidatorFactory.create(ksqlConfig, serviceContext);

final Errors errorHandler = restConfig.getConfiguredInstance(
final Errors errorHandler = new Errors(restConfig.getConfiguredInstance(
KsqlRestConfig.KSQL_SERVER_ERRORS,
Errors.class
);
ErrorMessages.class
));

final StreamedQueryResource streamedQueryResource = new StreamedQueryResource(
ksqlEngine,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

package io.confluent.ksql.rest.server;

import io.confluent.ksql.rest.DefaultErrorsImpl;
import io.confluent.ksql.rest.DefaultErrorMessages;
import io.confluent.ksql.util.KsqlException;
import io.confluent.rest.RestConfig;
import java.util.Map;
Expand Down Expand Up @@ -146,7 +146,7 @@ public class KsqlRestConfig extends RestConfig {
).define(
KSQL_SERVER_ERRORS,
Type.CLASS,
DefaultErrorsImpl.class,
DefaultErrorMessages.class,
Importance.LOW,
KSQL_SERVER_ERRORS_DOC
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ private Response handleStatement(
"Statement type `%s' not supported for this resource",
statement.getClass().getName()));
} catch (final TopicAuthorizationException e) {
return errorHandler.accessDeniedFromKafkaResponse(e);
return errorHandler.accessDeniedFromKafka(e);
} catch (final KsqlStatementException e) {
return Errors.badStatement(e.getRawMessage(), e.getSqlStatement());
} catch (final KsqlException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ public void onOpen(final Session session, final EndpointConfig unused) {
} catch (final TopicAuthorizationException e) {
log.debug("Error processing request", e);
SessionUtil.closeSilently(
session, CloseCodes.CANNOT_ACCEPT, errorHandler.webSocketAuthorizationErrorMessage(e));
session,
CloseCodes.CANNOT_ACCEPT,
errorHandler.webSocketKafkaAuthorizationErrorMessage(e));
} catch (final Exception e) {
log.debug("Error processing request", e);
SessionUtil.closeSilently(session, CloseCodes.CANNOT_ACCEPT, e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public static Response generateResponse(
final Errors errorHandler
) {
if (ExceptionUtils.indexOfType(e, TopicAuthorizationException.class) >= 0) {
return errorHandler.accessDeniedFromKafkaResponse(e);
return errorHandler.accessDeniedFromKafka(e);
} else {
return defaultResponse;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.query.QueryId;
import io.confluent.ksql.query.id.SpecificQueryIdGenerator;
import io.confluent.ksql.rest.DefaultErrorsImpl;
import io.confluent.ksql.rest.Errors;
import io.confluent.ksql.rest.entity.CommandId;
import io.confluent.ksql.rest.entity.CommandId.Action;
Expand Down Expand Up @@ -223,7 +222,7 @@ private class KsqlServer {
Duration.ofMillis(0),
()->{},
Optional.of((sc, metastore, statement) -> { }),
new DefaultErrorsImpl()
mock(Errors.class)
);

this.statementExecutor.configure(ksqlConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static io.confluent.ksql.parser.ParserMatchers.configured;
import static io.confluent.ksql.parser.ParserMatchers.preparedStatement;
import static io.confluent.ksql.parser.ParserMatchers.preparedStatementText;
import static io.confluent.ksql.rest.Errors.ERROR_CODE_FORBIDDEN_KAFKA_ACCESS;
import static io.confluent.ksql.rest.entity.CommandId.Action.CREATE;
import static io.confluent.ksql.rest.entity.CommandId.Action.DROP;
import static io.confluent.ksql.rest.entity.CommandId.Action.EXECUTE;
Expand All @@ -31,6 +32,7 @@
import static io.confluent.ksql.rest.server.resources.KsqlRestExceptionMatchers.exceptionStatementErrorMessage;
import static io.confluent.ksql.rest.server.resources.KsqlRestExceptionMatchers.exceptionStatusCode;
import static java.util.Collections.emptyMap;
import static javax.ws.rs.core.Response.Status.FORBIDDEN;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.hasItem;
Expand Down Expand Up @@ -87,7 +89,6 @@
import io.confluent.ksql.parser.tree.TableElement.Namespace;
import io.confluent.ksql.parser.tree.TableElements;
import io.confluent.ksql.parser.tree.TerminateQuery;
import io.confluent.ksql.rest.DefaultErrorsImpl;
import io.confluent.ksql.rest.Errors;
import io.confluent.ksql.rest.entity.ClusterTerminateRequest;
import io.confluent.ksql.rest.entity.CommandId;
Expand Down Expand Up @@ -238,6 +239,11 @@ public class KsqlResourceTest {
.valueColumn(ColumnName.of("f1"), SqlTypes.STRING)
.build();

private static Response AUTHORIZATION_ERROR_RESPONSE = Response
.status(FORBIDDEN)
.entity(new KsqlErrorMessage(ERROR_CODE_FORBIDDEN_KAFKA_ACCESS, "some error"))
.build();

@Rule
public final ExpectedException expectedException = ExpectedException.none();

Expand Down Expand Up @@ -334,6 +340,8 @@ public void setUp() throws IOException, RestClientException {
when(topicInjector.inject(any()))
.thenAnswer(inv -> inv.getArgument(0));

when(errorsHandler.accessDeniedFromKafka(any(Exception.class))).thenReturn(AUTHORIZATION_ERROR_RESPONSE);

setUpKsqlResource();
}

Expand Down Expand Up @@ -782,9 +790,6 @@ public void shouldReturnForbiddenKafkaAccessIfRootCauseKsqlTopicAuthorizationExc
// Then:
assertThat(result, is(instanceOf(KsqlErrorMessage.class)));
assertThat(result.getErrorCode(), is(Errors.ERROR_CODE_FORBIDDEN_KAFKA_ACCESS));
assertThat(result.getMessage(), is(
"Could not delete the corresponding kafka topic: topic\n" +
"Caused by: Authorization denied to Delete on topic(s): [topic]"));
}

@Test
Expand Down Expand Up @@ -2084,7 +2089,7 @@ private void setUpKsqlResource() {
topicInjectorFactory.apply(ec),
new TopicDeleteInjector(ec, sc)),
Optional.of(authorizationValidator),
new DefaultErrorsImpl()
errorsHandler
);

ksqlResource.configure(ksqlConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import static io.confluent.ksql.rest.entity.KsqlErrorMessageMatchers.errorCode;
import static io.confluent.ksql.rest.entity.KsqlErrorMessageMatchers.errorMessage;
import static io.confluent.ksql.rest.Errors.ERROR_CODE_FORBIDDEN_KAFKA_ACCESS;
import static io.confluent.ksql.rest.server.resources.KsqlRestExceptionMatchers.exceptionErrorMessage;
import static io.confluent.ksql.rest.server.resources.KsqlRestExceptionMatchers.exceptionStatusCode;
import static javax.ws.rs.core.Response.Status.FORBIDDEN;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -46,7 +48,6 @@
import io.confluent.ksql.parser.tree.PrintTopic;
import io.confluent.ksql.parser.tree.Query;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.rest.DefaultErrorsImpl;
import io.confluent.ksql.rest.Errors;
import io.confluent.ksql.rest.entity.KsqlErrorMessage;
import io.confluent.ksql.rest.entity.KsqlRequest;
Expand Down Expand Up @@ -112,6 +113,11 @@ public class StreamedQueryResourceTest {
StreamsConfig.APPLICATION_SERVER_CONFIG, "something:1"
));
private static final Long closeTimeout = KsqlConfig.KSQL_SHUTDOWN_TIMEOUT_MS_DEFAULT;

private static Response AUTHORIZATION_ERROR_RESPONSE = Response
.status(FORBIDDEN)
.entity(new KsqlErrorMessage(ERROR_CODE_FORBIDDEN_KAFKA_ACCESS, "some error"))
.build();

private static final String TOPIC_NAME = "test_stream";
private static final String PUSH_QUERY_STRING = "SELECT * FROM " + TOPIC_NAME + " EMIT CHANGES;";
Expand All @@ -137,6 +143,8 @@ public class StreamedQueryResourceTest {
private Consumer<QueryMetadata> queryCloseCallback;
@Mock
private KsqlAuthorizationValidator authorizationValidator;
@Mock
private Errors errorsHandler;
private StreamedQueryResource testResource;
private PreparedStatement<Statement> invalid;
private PreparedStatement<Query> query;
Expand All @@ -153,6 +161,7 @@ public void setup() {
when(pullQuery.isPullQuery()).thenReturn(true);
final PreparedStatement<Statement> pullQueryStatement = PreparedStatement.of(PULL_QUERY_STRING, pullQuery);
when(mockStatementParser.parseSingleStatement(PULL_QUERY_STRING)).thenReturn(pullQueryStatement);
when(errorsHandler.accessDeniedFromKafka(any(Exception.class))).thenReturn(AUTHORIZATION_ERROR_RESPONSE);

testResource = new StreamedQueryResource(
mockKsqlEngine,
Expand All @@ -162,7 +171,7 @@ public void setup() {
COMMAND_QUEUE_CATCHUP_TIMOEUT,
activenessRegistrar,
Optional.of(authorizationValidator),
new DefaultErrorsImpl()
errorsHandler
);

testResource.configure(VALID_CONFIG);
Expand All @@ -188,7 +197,7 @@ public void shouldThrowOnHandleStatementIfNotConfigured() {
COMMAND_QUEUE_CATCHUP_TIMOEUT,
activenessRegistrar,
Optional.of(authorizationValidator),
new DefaultErrorsImpl()
errorsHandler
);

// Then:
Expand Down Expand Up @@ -553,12 +562,9 @@ public void shouldReturnForbiddenKafkaAccessIfKsqlTopicAuthorizationException()
new KsqlRequest(PUSH_QUERY_STRING, Collections.emptyMap(), null)
);

final Response expected = Errors.accessDeniedFromKafka(
new KsqlTopicAuthorizationException(AclOperation.READ, Collections.singleton(TOPIC_NAME)));

final KsqlErrorMessage responseEntity = (KsqlErrorMessage) response.getEntity();
final KsqlErrorMessage expectedEntity = (KsqlErrorMessage) expected.getEntity();
assertEquals(response.getStatus(), expected.getStatus());
final KsqlErrorMessage expectedEntity = (KsqlErrorMessage) AUTHORIZATION_ERROR_RESPONSE.getEntity();
assertEquals(response.getStatus(), AUTHORIZATION_ERROR_RESPONSE.getStatus());
assertEquals(responseEntity.getMessage(), expectedEntity.getMessage());
}

Expand All @@ -578,14 +584,9 @@ public void shouldReturnForbiddenKafkaAccessIfRootCauseKsqlTopicAuthorizationExc
new KsqlRequest(PUSH_QUERY_STRING, Collections.emptyMap(), null)
);

final Response expected = Errors.accessDeniedFromKafka(
new KsqlException(
"",
new KsqlTopicAuthorizationException(AclOperation.READ, Collections.singleton(TOPIC_NAME))));

final KsqlErrorMessage responseEntity = (KsqlErrorMessage) response.getEntity();
final KsqlErrorMessage expectedEntity = (KsqlErrorMessage) expected.getEntity();
assertEquals(response.getStatus(), expected.getStatus());
final KsqlErrorMessage expectedEntity = (KsqlErrorMessage) AUTHORIZATION_ERROR_RESPONSE.getEntity();
assertEquals(response.getStatus(), AUTHORIZATION_ERROR_RESPONSE.getStatus());
assertEquals(responseEntity.getMessage(), expectedEntity.getMessage());
}

Expand All @@ -606,11 +607,8 @@ public void shouldReturnForbiddenKafkaAccessIfPrintTopicKsqlTopicAuthorizationEx
new KsqlRequest(PRINT_TOPIC, Collections.emptyMap(), null)
);

final Response expected = Errors.accessDeniedFromKafka(
new KsqlTopicAuthorizationException(AclOperation.READ, Collections.singleton(TOPIC_NAME)));

assertEquals(response.getStatus(), expected.getStatus());
assertEquals(response.getEntity(), expected.getEntity());
assertEquals(response.getStatus(), AUTHORIZATION_ERROR_RESPONSE.getStatus());
assertEquals(response.getEntity(), AUTHORIZATION_ERROR_RESPONSE.getEntity());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient;
import io.confluent.ksql.engine.KsqlEngine;
import io.confluent.ksql.exception.KsqlTopicAuthorizationException;
import io.confluent.ksql.json.JsonMapper;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.parser.KsqlParser.PreparedStatement;
import io.confluent.ksql.parser.tree.Query;
import io.confluent.ksql.parser.tree.Relation;
import io.confluent.ksql.parser.tree.ResultMaterialization;
import io.confluent.ksql.parser.tree.Select;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.rest.DefaultErrorsImpl;
import io.confluent.ksql.rest.Errors;
import io.confluent.ksql.rest.entity.KsqlErrorMessage;
import io.confluent.ksql.rest.entity.KsqlRequest;
import io.confluent.ksql.rest.entity.Versions;
Expand Down Expand Up @@ -81,6 +83,8 @@
import javax.websocket.Session;
import javax.ws.rs.core.Response;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.common.acl.AclOperation;
import org.apache.kafka.common.errors.TopicAuthorizationException;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -148,12 +152,16 @@ public class WSQueryEndpointTest {
@Mock
private ServiceContext serviceContext;
@Mock
private MetaStore metaStore;
@Mock
private UserServiceContextFactory serviceContextFactory;
@Mock
private ServerState serverState;
@Mock
private KsqlUserContextProvider userContextProvider;
@Mock
private Errors errorsHandler;
@Mock
private DefaultServiceContextFactory defaultServiceContextProvider;
@Captor
private ArgumentCaptor<CloseReason> closeReasonCaptor;
Expand All @@ -178,6 +186,7 @@ public void setUp() {
when(defaultServiceContextProvider.create(any(), any())).thenReturn(serviceContext);
when(serviceContext.getTopicClient()).thenReturn(topicClient);
when(serverState.checkReady()).thenReturn(Optional.empty());
when(ksqlEngine.getMetaStore()).thenReturn(metaStore);
givenRequest(VALID_REQUEST);

wsQueryEndpoint = new WSQueryEndpoint(
Expand All @@ -193,7 +202,7 @@ public void setUp() {
activenessRegistrar,
COMMAND_QUEUE_CATCHUP_TIMEOUT,
Optional.of(authorizationValidator),
new DefaultErrorsImpl(),
errorsHandler,
securityExtension,
serviceContextFactory,
defaultServiceContextProvider,
Expand Down Expand Up @@ -385,6 +394,26 @@ public void shouldHandlePushQuery() {
any());
}

@Test
public void shouldReturnErrorMessageWhenTopicAuthorizationException() throws Exception {
// Given:
final String errorMessage = "authorization error";
givenRequestIs(query);
when(errorsHandler.webSocketKafkaAuthorizationErrorMessage(any(TopicAuthorizationException.class)))
.thenReturn(errorMessage);
doThrow(new KsqlTopicAuthorizationException(AclOperation.CREATE, Collections.singleton("topic")))
.when(authorizationValidator).checkAuthorization(serviceContext, metaStore, query);

// When:
wsQueryEndpoint.onOpen(session, null);

// Then:
verifyClosedContainingReason(
errorMessage,
CloseCodes.CANNOT_ACCEPT
);
}

@Test
public void shouldHandlePullQuery() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,10 @@

package io.confluent.ksql.rest;

import javax.ws.rs.core.Response;

public class DefaultErrorsImpl implements Errors {

@Override
public Response accessDeniedFromKafkaResponse(final Throwable t) {
return Errors.accessDeniedFromKafka(t);
}
public class DefaultErrorMessages implements ErrorMessages {

@Override
public String webSocketAuthorizationErrorMessage(final Throwable t) {
return t.getMessage();
public String kafkaAuthorizationErrorMessage(final Exception e) {
return e.getMessage();
}
}
Loading

0 comments on commit 286a5f4

Please sign in to comment.