Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

init master key automatically #1075

Merged
merged 10 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ public class CommonValue {
public static final String UNDEPLOYED = "undeployed";
public static final String NOT_FOUND = "not_found";

public static final String MASTER_KEY = "master_key";
public static final String CREATE_TIME_FIELD = "create_time";

public static final String BOX_TYPE_KEY = "box_type";
//hot node
public static String HOT_BOX_TYPE = "hot";
Expand All @@ -37,6 +40,8 @@ public class CommonValue {
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 1;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
+ "\": {\n"
Expand Down Expand Up @@ -301,4 +306,19 @@ public class CommonValue {
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";


public static final String ML_CONFIG_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_CONFIG_INDEX_SCHEMA_VERSION
+ "},\n"
+ " \"properties\": {\n"
+ " \""
+ MASTER_KEY
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ CREATE_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.engine;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
Expand All @@ -18,29 +19,35 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.engine.encryptor.Encryptor;

import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

/**
* This is the interface to all ml algorithms.
*/
@Log4j2
public class MLEngine {

public static final String REGISTER_MODEL_FOLDER = "register";
public static final String DEPLOY_MODEL_FOLDER = "deploy";
private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models";

private final Path mlUserConfigPath;
@Getter
private final Path mlConfigPath;

@Getter
private final Path mlCachePath;
private final Path mlModelsCachePath;

private final Encryptor encryptor;
private Encryptor encryptor;

public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
mlCachePath = opensearchDataFolder.resolve("ml_cache");
mlModelsCachePath = mlCachePath.resolve("models_cache");
public MLEngine(Path opensearchDataFolder, Path opensearchConfigFolder, Encryptor encryptor) {
this.mlCachePath = opensearchDataFolder.resolve("ml_cache");
this.mlModelsCachePath = mlCachePath.resolve("models_cache");
this.mlUserConfigPath = opensearchConfigFolder.resolve("opensearch-ml");
this.mlConfigPath = mlCachePath.resolve("config");
this.encryptor = encryptor;
}

Expand Down Expand Up @@ -195,7 +202,4 @@ public String encrypt(String credential) {
return encryptor.encrypt(credential);
}

public void setMasterKey(String masterKey) {
encryptor.setMasterKey(masterKey);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
} else {
throw new IllegalArgumentException("Wrong input type");
}
Map<String, String> escapedParameters = new HashMap<>();
inputData.getParameters().entrySet().forEach(entry -> {
escapedParameters.put(entry.getKey(), escapeJava(entry.getValue()));
});
inputData.setParameters(escapedParameters);
return inputData;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.engine.encryptor;

import java.security.SecureRandom;
import java.util.Base64;

public interface Encryptor {

/**
Expand All @@ -29,4 +32,7 @@ public interface Encryptor {
* @param masterKey masterKey to be set.
*/
void setMasterKey(String masterKey);

String generateMasterKey();

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@

import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Base64;

public class EncryptorImpl implements Encryptor {

private volatile String masterKey;

public EncryptorImpl(String masterKey) {
this.masterKey = masterKey;
public EncryptorImpl() {
this.masterKey = null;
}

@Override
Expand Down Expand Up @@ -60,14 +61,17 @@ public String decrypt(String encryptedText) {
return new String(decryptedResult.getResult());
}

@Override
public String generateMasterKey() {
byte[] keyBytes = new byte[16];
new SecureRandom().nextBytes(keyBytes);
String base64Key = Base64.getEncoder().encodeToString(keyBytes);
return base64Key;
}

private void checkMasterKey() {
if (masterKey == "0000000000000000" || masterKey == null) {
throw new MetaDataException("Please provide a masterKey for credential encryption! Example: PUT /_cluster/settings\n" +
"{\n" +
" \"persistent\" : {\n" +
" \"plugins.ml_commons.encryption.master_key\" : \"1234567x\" \n" +
" }\n" +
"}");
if (masterKey == null) {
throw new MetaDataException("Encryption key not created yet.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ public class MLEngineTest {

@Before
public void setUp() {
Encryptor encryptor = new EncryptorImpl("0000000000000000");
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
Encryptor encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000000");
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -128,7 +129,8 @@ public class MetricsCorrelationTest {
ActionListener<MLDeployModelResponse> mlDeployModelResponseActionListener;
private MetricsCorrelation metricsCorrelation;
private MetricsCorrelationInput input, extendedInput;
private Path djlCachePath;
private Path mlCachePath;
private Path mlConfigPath;
private MLModel model;

private MetricsCorrelationModelConfig modelConfig;
Expand All @@ -144,7 +146,6 @@ public class MetricsCorrelationTest {

Map<String, Object> params = new HashMap<>();

@Mock
private Encryptor encryptor;

public MetricsCorrelationTest() {
Expand All @@ -155,8 +156,11 @@ public void setUp() throws IOException, URISyntaxException {

System.setProperty("testMode", "true");

djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
mlEngine = new MLEngine(djlCachePath, encryptor);
mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
mlConfigPath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(mlCachePath, mlConfigPath, encryptor);
modelConfig = MetricsCorrelationModelConfig.builder()
.modelType(MetricsCorrelation.MODEL_TYPE)
.allConfig(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ public class AwsConnectorExecutorTest {
@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl("0000000000000001");
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc()
processInput_TextDocsInputDataSet_PreprocessFunction(
"{\"input\": ${parameters.input}}",
"{\"parameters\": { \"input\": [\"test_value1\", \"test_value2\"] } }",
"[\\\"test_value1\\\",\\\"test_value2\\\"]");
"[\"test_value1\",\"test_value2\"]");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ public class RemoteModelTest {
public void setUp() {
MockitoAnnotations.openMocks(this);
remoteModel = new RemoteModel();
encryptor = spy(new EncryptorImpl("0000000000000001"));
encryptor = spy(new EncryptorImpl());
encryptor.setMasterKey("0000000000000001");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;

import java.io.IOException;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -50,12 +52,16 @@ public class ModelHelperTest {
@Mock
ActionListener<MLRegisterModelInput> registerModelListener;

Encryptor encryptor;

@Before
public void setup() throws URISyntaxException {
MockitoAnnotations.openMocks(this);
modelFormat = MLModelFormat.TORCH_SCRIPT;
modelId = "model_id";
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), null);
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), Path.of("/tmp/test_config"), encryptor);
modelHelper = new ModelHelper(mlEngine);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.utils.FileUtils;

import java.io.File;
Expand Down Expand Up @@ -62,16 +63,20 @@ public class TextEmbeddingModelTest {
private ModelHelper modelHelper;
private Map<String, Object> params;
private TextEmbeddingModel textEmbeddingModel;
private Path djlCachePath;
private Path mlCachePath;
private Path mlConfigPath;
private TextDocsInputDataSet inputDataSet;
private int dimension = 384;
private MLEngine mlEngine;
private Encryptor encryptor;

@Before
public void setUp() throws URISyntaxException {
djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
mlEngine = new MLEngine(djlCachePath, encryptor);
mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID());
mlConfigPath = Path.of("/tmp/ml_config" + UUID.randomUUID());
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(mlCachePath, mlConfigPath, encryptor);
modelId = "test_model_id";
modelName = "test_model_name";
functionName = FunctionName.TEXT_EMBEDDING;
Expand Down Expand Up @@ -329,7 +334,7 @@ public void predict_BeforeInitingModel() {

@After
public void tearDown() {
FileUtils.deleteFileQuietly(djlCachePath);
FileUtils.deleteFileQuietly(mlCachePath);
}

private int findSentenceEmbeddingPosition(ModelTensors modelTensors) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.common.component.LifecycleListener;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.threadpool.Scheduler;
import org.opensearch.threadpool.ThreadPool;
Expand All @@ -30,6 +31,7 @@ public class MLCommonsClusterManagerEventListener implements LocalNodeClusterMan
private Scheduler.Cancellable syncModelRoutingCron;
private DiscoveryNodeHelper nodeHelper;
private final MLIndicesHandler mlIndicesHandler;
private final Encryptor encryptor;

private volatile Integer jobInterval;

Expand All @@ -39,14 +41,16 @@ public MLCommonsClusterManagerEventListener(
Settings settings,
ThreadPool threadPool,
DiscoveryNodeHelper nodeHelper,
MLIndicesHandler mlIndicesHandler
MLIndicesHandler mlIndicesHandler,
Encryptor encryptor
) {
this.clusterService = clusterService;
this.client = client;
this.threadPool = threadPool;
this.clusterService.addListener(this);
this.nodeHelper = nodeHelper;
this.mlIndicesHandler = mlIndicesHandler;
this.encryptor = encryptor;

this.jobInterval = ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS, it -> {
Expand All @@ -67,7 +71,7 @@ private void startSyncModelRoutingCron() {
if (jobInterval > 0) {
syncModelRoutingCron = threadPool
.scheduleWithFixedDelay(
new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler),
new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor),
TimeValue.timeValueSeconds(jobInterval),
GENERAL_THREAD_POOL
);
Expand Down
Loading