Skip to content

Commit

Permalink
Clean up REST API (Part 1) (opensearch-project#2900)
Browse files Browse the repository at this point in the history
The aim of this PR is to start cleaning code in REST API since with the
current implementation is difficult to understand and support.

Changes:

- Implemented new `RequestConetnValidator` class which uses the same
validation logic as `AbstractConfigurationValidator`
- Removed all redundant `AbstractConfigurationValidator` extensions

Signed-off-by: Andrey Pleskach <[email protected]>
(cherry picked from commit 6bac470)
Signed-off-by: Andrey Pleskach <[email protected]>

Signed-off-by: Andrey Pleskach <[email protected]>
  • Loading branch information
willyborankin committed Aug 21, 2023
1 parent 478b27d commit 36fc3c7
Show file tree
Hide file tree
Showing 46 changed files with 1,405 additions and 1,308 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
package org.opensearch.security.api;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import org.apache.hc.core5.http.HttpStatus;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.Objects;

Expand All @@ -30,7 +29,6 @@
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext.StoredContext;
import org.opensearch.common.xcontent.XContentHelper;
Expand All @@ -47,14 +45,12 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.security.DefaultObjectMapper;
import org.opensearch.security.action.configupdate.ConfigUpdateAction;
import org.opensearch.security.action.configupdate.ConfigUpdateNodeResponse;
import org.opensearch.security.action.configupdate.ConfigUpdateRequest;
import org.opensearch.security.action.configupdate.ConfigUpdateResponse;
import org.opensearch.security.auditlog.AuditLog;
import org.opensearch.security.configuration.AdminDNs;
import org.opensearch.security.configuration.ConfigurationRepository;
import org.opensearch.security.dlic.rest.validation.AbstractConfigurationValidator;
import org.opensearch.security.dlic.rest.validation.AbstractConfigurationValidator.ErrorType;
import org.opensearch.security.dlic.rest.validation.RequestContentValidator;
import org.opensearch.security.privileges.PrivilegesEvaluator;
import org.opensearch.security.securityconf.DynamicConfigFactory;
import org.opensearch.security.securityconf.impl.CType;
Expand All @@ -68,7 +64,7 @@

public abstract class AbstractApiAction extends BaseRestHandler {

protected final Logger log = LogManager.getLogger(this.getClass());
private final static Logger LOGGER = LogManager.getLogger(AbstractApiAction.class);

protected final ConfigurationRepository cl;
protected final ClusterService cs;
Expand Down Expand Up @@ -119,7 +115,7 @@ protected AbstractApiAction(
this.auditLog = auditLog;
}

protected abstract AbstractConfigurationValidator getValidator(RestRequest request, BytesReference ref, Object... params);
protected abstract RequestContentValidator createValidator(final Object... params);

protected abstract String getResourceName();

Expand All @@ -128,25 +124,22 @@ protected AbstractApiAction(
protected void handleApiRequest(final RestChannel channel, final RestRequest request, final Client client) throws IOException {

try {
// validate additional settings, if any
AbstractConfigurationValidator validator = getValidator(request, request.content());
if (!validator.validate()) {
request.params().clear();
badRequestResponse(channel, validator);
return;
}
switch (request.method()) {
case DELETE:
handleDelete(channel, request, client, validator.getContentAsNode());
handleDelete(channel, request, client, null);
break;
case POST:
handlePost(channel, request, client, validator.getContentAsNode());
createValidator().validate(request)
.valid(jsonContent -> handlePost(channel, request, client, jsonContent))
.error(toXContent -> requestContentInvalid(request, channel, toXContent));
break;
case PUT:
handlePut(channel, request, client, validator.getContentAsNode());
createValidator().validate(request)
.valid(jsonContent -> handlePut(channel, request, client, jsonContent))
.error(toXContent -> requestContentInvalid(request, channel, toXContent));
break;
case GET:
handleGet(channel, request, client, validator.getContentAsNode());
handleGet(channel, request, client, null);
break;
default:
throw new IllegalArgumentException(request.method() + " not supported");
Expand All @@ -160,6 +153,11 @@ protected void handleApiRequest(final RestChannel channel, final RestRequest req
}
}

protected void requestContentInvalid(final RestRequest request, final RestChannel channel, final ToXContent toXContent) {
request.params().clear();
badRequestResponse(channel, toXContent);
}

protected void handleDelete(final RestChannel channel, final RestRequest request, final Client client, final JsonNode content)
throws IOException {
final String name = request.param("name");
Expand Down Expand Up @@ -200,16 +198,12 @@ public void onResponse(IndexResponse response) {

protected void handlePut(final RestChannel channel, final RestRequest request, final Client client, final JsonNode content)
throws IOException {

final String name = request.param("name");

if (name == null || name.length() == 0) {
badRequestResponse(channel, "No " + getResourceName() + " specified.");
return;
}

final SecurityDynamicConfiguration<?> existingConfiguration = load(getConfigName(), false);

if (existingConfiguration.getSeqNo() < 0) {
forbidden(
channel,
Expand All @@ -227,8 +221,8 @@ protected void handlePut(final RestChannel channel, final RestRequest request, f
return;
}

if (log.isTraceEnabled() && content != null) {
log.trace(content.toString());
if (LOGGER.isTraceEnabled() && content != null) {
LOGGER.trace(content.toString());
}

boolean existed = existingConfiguration.exists(name);
Expand Down Expand Up @@ -274,28 +268,21 @@ protected boolean hasPermissionsToCreate(
}

protected void handleGet(final RestChannel channel, RestRequest request, Client client, final JsonNode content) throws IOException {

final String resourcename = request.param("name");

final SecurityDynamicConfiguration<?> configuration = load(getConfigName(), true);
filter(configuration);

// no specific resource requested, return complete config
if (resourcename == null || resourcename.length() == 0) {

successResponse(channel, configuration);
return;
}

if (!configuration.exists(resourcename)) {
notFound(channel, "Resource '" + resourcename + "' not found.");
return;
}

configuration.removeOthers(resourcename);
successResponse(channel, configuration);

return;
}

protected final SecurityDynamicConfiguration<?> load(final CType config, boolean logComplianceEvent) {
Expand All @@ -305,15 +292,6 @@ protected final SecurityDynamicConfiguration<?> load(final CType config, boolean
return DynamicConfigFactory.addStatics(loaded);
}

protected final SecurityDynamicConfiguration<?> load(final CType config, boolean logComplianceEvent, boolean acceptInvalid) {
SecurityDynamicConfiguration<?> loaded = cl.getConfigurationsFromIndex(
Collections.singleton(config),
logComplianceEvent,
acceptInvalid
).get(config).deepClone();
return DynamicConfigFactory.addStatics(loaded);
}

protected boolean ensureIndexExists() {
if (!cs.state().metadata().hasConcreteIndex(this.securityIndexName)) {
return false;
Expand Down Expand Up @@ -434,7 +412,7 @@ protected final RestChannelConsumer prepareRequest(RestRequest request, NodeClie

// check if .opendistro_security index has been initialized
if (!ensureIndexExists()) {
return channel -> internalErrorResponse(channel, ErrorType.SECURITY_NOT_INITIALIZED.getMessage());
return channel -> internalErrorResponse(channel, RequestContentValidator.ValidationError.SECURITY_NOT_INITIALIZED.message());
}

// check if request is authorized
Expand All @@ -443,7 +421,7 @@ protected final RestChannelConsumer prepareRequest(RestRequest request, NodeClie
final User user = (User) threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER);
final String userName = user == null ? null : user.getName();
if (authError != null) {
log.error("No permission to access REST API: " + authError);
LOGGER.error("No permission to access REST API: " + authError);
auditLog.logMissingPrivileges(authError, userName, request);
// for rest request
request.params().clear();
Expand All @@ -465,7 +443,7 @@ protected final RestChannelConsumer prepareRequest(RestRequest request, NodeClie

handleApiRequest(channel, request, client);
} catch (Exception e) {
log.error("Error processing request {}", request, e);
LOGGER.error("Error processing request {}", request, e);
try {
channel.sendResponse(new BytesRestResponse(channel, e));
} catch (IOException ioe) {
Expand All @@ -475,37 +453,6 @@ protected final RestChannelConsumer prepareRequest(RestRequest request, NodeClie
});
}

protected boolean checkConfigUpdateResponse(final ConfigUpdateResponse response) {

final int nodeCount = cs.state().getNodes().getNodes().size();
final int expectedConfigCount = 1;

boolean success = response.getNodes().size() == nodeCount;
if (!success) {
log.error("Expected " + nodeCount + " nodes to return response, but got only " + response.getNodes().size());
}

for (final String nodeId : response.getNodesMap().keySet()) {
final ConfigUpdateNodeResponse node = response.getNodesMap().get(nodeId);
final boolean successNode = node.getUpdatedConfigTypes() != null && node.getUpdatedConfigTypes().length == expectedConfigCount;

if (!successNode) {
log.error(
"Expected "
+ expectedConfigCount
+ " config types for node "
+ nodeId
+ " but got only "
+ Arrays.toString(node.getUpdatedConfigTypes())
);
}

success = success && successNode;
}

return success;
}

protected static XContentBuilder convertToJson(RestChannel channel, ToXContent toxContent) {
try {
XContentBuilder builder = channel.newBuilder();
Expand Down Expand Up @@ -541,12 +488,12 @@ protected void successResponse(RestChannel channel) {
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
} catch (IOException e) {
internalErrorResponse(channel, "Unable to fetch license: " + e.getMessage());
log.error("Cannot fetch convert license to XContent due to", e);
LOGGER.error("Cannot fetch convert license to XContent due to", e);
}
}

protected void badRequestResponse(RestChannel channel, AbstractConfigurationValidator validator) {
channel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, validator.errorsAsXContent(channel)));
protected void badRequestResponse(RestChannel channel, ToXContent validationResult) {
channel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, convertToJson(channel, validationResult)));
}

protected void successResponse(RestChannel channel, String message) {
Expand All @@ -573,10 +520,6 @@ protected void internalErrorResponse(RestChannel channel, String message) {
response(channel, RestStatus.INTERNAL_SERVER_ERROR, message);
}

protected void unprocessable(RestChannel channel, String message) {
response(channel, RestStatus.UNPROCESSABLE_ENTITY, message);
}

protected void conflict(RestChannel channel, String message) {
response(channel, RestStatus.CONFLICT, message);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,20 @@
import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.bouncycastle.crypto.generators.OpenBSDBCrypt;

import org.opensearch.action.index.IndexResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.common.util.concurrent.ThreadContext;
Expand All @@ -38,8 +42,8 @@
import org.opensearch.security.auditlog.AuditLog;
import org.opensearch.security.configuration.AdminDNs;
import org.opensearch.security.configuration.ConfigurationRepository;
import org.opensearch.security.dlic.rest.validation.AbstractConfigurationValidator;
import org.opensearch.security.dlic.rest.validation.AccountValidator;
import org.opensearch.security.dlic.rest.validation.RequestContentValidator;
import org.opensearch.security.dlic.rest.validation.RequestContentValidator.DataType;
import org.opensearch.security.privileges.PrivilegesEvaluator;
import org.opensearch.security.securityconf.Hashed;
import org.opensearch.security.securityconf.impl.CType;
Expand All @@ -58,6 +62,9 @@
* Currently this action serves GET and PUT request for /_opendistro/_security/api/account endpoint
*/
public class AccountApiAction extends AbstractApiAction {

private final static Logger LOGGER = LogManager.getLogger(AccountApiAction.class);

private static final String RESOURCE_NAME = "account";
private static final List<Route> routes = addRoutesPrefix(
ImmutableList.of(new Route(Method.GET, "/account"), new Route(Method.PUT, "/account"))
Expand Down Expand Up @@ -154,7 +161,7 @@ protected void handleGet(RestChannel channel, RestRequest request, Client client

response = new BytesRestResponse(RestStatus.OK, builder);
} catch (final Exception exception) {
log.error(exception.toString());
LOGGER.error(exception.toString());

builder.startObject().field("error", exception.toString()).endObject();

Expand Down Expand Up @@ -241,9 +248,29 @@ public void onResponse(IndexResponse response) {
}

@Override
protected AbstractConfigurationValidator getValidator(RestRequest request, BytesReference ref, Object... params) {
protected RequestContentValidator createValidator(final Object... params) {
final User user = threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER);
return new AccountValidator(request, ref, this.settings, user.getName());
return RequestContentValidator.of(new RequestContentValidator.ValidationContext() {
@Override
public Object[] params() {
return new Object[] { user.getName() };
}

@Override
public Settings settings() {
return settings;
}

@Override
public Set<String> mandatoryKeys() {
return ImmutableSet.of("current_password");
}

@Override
public Map<String, RequestContentValidator.DataType> allowedKeys() {
return ImmutableMap.of("hash", DataType.STRING, "password", DataType.STRING, "current_password", DataType.STRING);
}
});
}

@Override
Expand Down
Loading

0 comments on commit 36fc3c7

Please sign in to comment.