From 84c2ce8b1f529068772fca3435e8dfd1d53f41b8 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jul 2023 13:26:23 -0700 Subject: [PATCH 01/10] init master key automatically Signed-off-by: Yaliang Wu --- .../org/opensearch/ml/common/CommonValue.java | 20 ++++++ ml-algorithms/build.gradle | 1 + .../org/opensearch/ml/engine/MLEngine.java | 66 ++++++++++++++++--- .../ml/engine/encryptor/Encryptor.java | 6 ++ .../ml/engine/encryptor/EncryptorImpl.java | 22 ++++--- .../opensearch/ml/engine/MLEngineTest.java | 5 +- .../MetricsCorrelationTest.java | 12 ++-- .../remote/AwsConnectorExecutorTest.java | 3 +- .../algorithms/remote/RemoteModelTest.java | 3 +- .../text_embedding/ModelHelperTest.java | 8 ++- .../TextEmbeddingModelTest.java | 13 ++-- .../MLCommonsClusterManagerEventListener.java | 8 ++- .../opensearch/ml/cluster/MLSyncUpCron.java | 49 +++++++++++++- .../org/opensearch/ml/indices/MLIndex.java | 6 +- .../ml/indices/MLIndicesHandler.java | 4 ++ .../opensearch/ml/model/MLModelManager.java | 8 --- .../ml/plugin/MachineLearningPlugin.java | 20 ++++-- .../ml/settings/MLCommonsSettings.java | 3 - .../TransportDeployModelActionTests.java | 5 +- .../ml/cluster/MLSyncUpCronTests.java | 63 +++++++++++++++++- .../ml/model/MLModelManagerTests.java | 10 ++- .../ml/task/MLExecuteTaskRunnerTests.java | 9 ++- .../ml/task/MLPredictTaskRunnerTests.java | 5 +- .../MLTrainAndPredictTaskRunnerTests.java | 5 +- .../ml/task/MLTrainingTaskRunnerTests.java | 9 ++- 25 files changed, 294 insertions(+), 69 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 9fc2294d3b..16554933b5 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -24,6 +24,9 @@ public class CommonValue { public static final String UNDEPLOYED = "undeployed"; public static final String NOT_FOUND = "not_found"; + public static final String MASTER_KEY = "master_key"; + public static final String CREATE_TIME_FIELD = "create_time"; + public static final String BOX_TYPE_KEY = "box_type"; //hot node public static String HOT_BOX_TYPE = "hot"; @@ -37,6 +40,8 @@ public class CommonValue { public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1; public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 1; + public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; + public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 1; public static final String USER_FIELD_MAPPING = " \"" + CommonValue.USER + "\": {\n" @@ -301,4 +306,19 @@ public class CommonValue { + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + " }\n" + "}"; + + + public static final String ML_CONFIG_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_CONFIG_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MASTER_KEY + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; } diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 8c6a9ca0db..7b66b4e00e 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -64,6 +64,7 @@ dependencies { implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.0' implementation 'com.jayway.jsonpath:json-path:2.8.0' implementation group: 'org.json', name: 'json', version: '20230227' + implementation group: 'org.yaml', name: 'snakeyaml', version: '2.0' } configurations.all { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index b0ed953bd1..41f6d794c9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -5,12 +5,17 @@ package org.opensearch.ml.engine; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.stream.JsonReader; import lombok.Getter; +import lombok.extern.log4j.Log4j2; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.MLInput; @@ -18,30 +23,78 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.Output; import org.opensearch.ml.engine.encryptor.Encryptor; - +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.yaml.snakeyaml.DumperOptions; +import org.yaml.snakeyaml.Yaml; + +import java.io.FileInputStream; +import java.io.FileReader; +import java.io.FileWriter; +import java.nio.file.Files; import java.nio.file.Path; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.security.SecureRandom; +import java.util.Base64; +import java.util.HashMap; import java.util.Locale; import java.util.Map; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; + /** * This is the interface to all ml algorithms. */ +@Log4j2 public class MLEngine { public static final String REGISTER_MODEL_FOLDER = "register"; public static final String DEPLOY_MODEL_FOLDER = "deploy"; private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models"; + private final Path mlUserConfigPath; + @Getter + private final Path mlConfigPath; + @Getter private final Path mlCachePath; private final Path mlModelsCachePath; - private final Encryptor encryptor; + private Encryptor encryptor; - public MLEngine(Path opensearchDataFolder, Encryptor encryptor) { - mlCachePath = opensearchDataFolder.resolve("ml_cache"); - mlModelsCachePath = mlCachePath.resolve("models_cache"); + public MLEngine(Path opensearchDataFolder, Path opensearchConfigFolder, Encryptor encryptor) { + this.mlCachePath = opensearchDataFolder.resolve("ml_cache"); + this.mlModelsCachePath = mlCachePath.resolve("models_cache"); + this.mlUserConfigPath = opensearchConfigFolder.resolve("opensearch-ml"); + this.mlConfigPath = mlCachePath.resolve("config"); this.encryptor = encryptor; + initMasterKey(); + } + + private synchronized void initMasterKey() { + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + Path userConfigFilePath = mlUserConfigPath.resolve("security_config.json"); + Map config = null; + if (Files.exists(userConfigFilePath)) { + try (FileInputStream fis = new FileInputStream(userConfigFilePath.toFile());) { + Yaml yaml = new Yaml(); + config = yaml.load(fis); + } + } + if (config == null) { + config = new HashMap<>(); + } + + if (config.containsKey(MASTER_KEY)) { + encryptor.setMasterKey(config.get(MASTER_KEY)); + } + return null; + }); + } catch (Exception e) { + log.error("Failed to save master key", e); + throw new MLException(e); + } } public String getPrebuiltModelMetaListPath() { @@ -195,7 +248,4 @@ public String encrypt(String credential) { return encryptor.encrypt(credential); } - public void setMasterKey(String masterKey) { - encryptor.setMasterKey(masterKey); - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java index df8e43d887..7bbe58cec5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java @@ -5,6 +5,9 @@ package org.opensearch.ml.engine.encryptor; +import java.security.SecureRandom; +import java.util.Base64; + public interface Encryptor { /** @@ -29,4 +32,7 @@ public interface Encryptor { * @param masterKey masterKey to be set. */ void setMasterKey(String masterKey); + + String generateMasterKey(); + } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index 3e9d9175b4..16abde0a24 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -13,14 +13,15 @@ import javax.crypto.spec.SecretKeySpec; import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; import java.util.Base64; public class EncryptorImpl implements Encryptor { private volatile String masterKey; - public EncryptorImpl(String masterKey) { - this.masterKey = masterKey; + public EncryptorImpl() { + this.masterKey = null; } @Override @@ -60,14 +61,17 @@ public String decrypt(String encryptedText) { return new String(decryptedResult.getResult()); } + @Override + public String generateMasterKey() { + byte[] keyBytes = new byte[16]; + new SecureRandom().nextBytes(keyBytes); + String base64Key = Base64.getEncoder().encodeToString(keyBytes); + return base64Key; + } + private void checkMasterKey() { - if (masterKey == "0000000000000000" || masterKey == null) { - throw new MetaDataException("Please provide a masterKey for credential encryption! Example: PUT /_cluster/settings\n" + - "{\n" + - " \"persistent\" : {\n" + - " \"plugins.ml_commons.encryption.master_key\" : \"1234567x\" \n" + - " }\n" + - "}"); + if (masterKey == null) { + throw new MetaDataException("Encryption key not created yet."); } } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index b761cddd90..24115d4f46 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -53,8 +53,9 @@ public class MLEngineTest { @Before public void setUp() { - Encryptor encryptor = new EncryptorImpl("0000000000000000"); - mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); + Encryptor encryptor = new EncryptorImpl(); + encryptor.setMasterKey("0000000000000000"); + mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), Path.of("/tmp/test" + UUID.randomUUID()), encryptor); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java index 32d1df3a01..e860a811e2 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java @@ -58,6 +58,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -128,7 +129,8 @@ public class MetricsCorrelationTest { ActionListener mlDeployModelResponseActionListener; private MetricsCorrelation metricsCorrelation; private MetricsCorrelationInput input, extendedInput; - private Path djlCachePath; + private Path mlCachePath; + private Path mlConfigPath; private MLModel model; private MetricsCorrelationModelConfig modelConfig; @@ -144,7 +146,6 @@ public class MetricsCorrelationTest { Map params = new HashMap<>(); - @Mock private Encryptor encryptor; public MetricsCorrelationTest() { @@ -155,8 +156,11 @@ public void setUp() throws IOException, URISyntaxException { System.setProperty("testMode", "true"); - djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID()); - mlEngine = new MLEngine(djlCachePath, encryptor); + mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID()); + mlConfigPath = Path.of("/tmp/djl_cache_" + UUID.randomUUID()); + encryptor = new EncryptorImpl(); + encryptor.setMasterKey("0000000000000001"); + mlEngine = new MLEngine(mlCachePath, mlConfigPath, encryptor); modelConfig = MetricsCorrelationModelConfig.builder() .modelType(MetricsCorrelation.MODEL_TYPE) .allConfig(null) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 8d6130566a..49c83d40c5 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -67,7 +67,8 @@ public class AwsConnectorExecutorTest { @Before public void setUp() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("0000000000000001"); + encryptor = new EncryptorImpl(); + encryptor.setMasterKey("0000000000000001"); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index bef3e1da71..b39ae6bc9e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -48,7 +48,8 @@ public class RemoteModelTest { public void setUp() { MockitoAnnotations.openMocks(this); remoteModel = new RemoteModel(); - encryptor = spy(new EncryptorImpl("0000000000000001")); + encryptor = spy(new EncryptorImpl()); + encryptor.setMasterKey("0000000000000001"); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java index 3d2043fc24..8f4dab8963 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java @@ -19,6 +19,8 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; import java.io.IOException; import java.net.URISyntaxException; @@ -50,12 +52,16 @@ public class ModelHelperTest { @Mock ActionListener registerModelListener; + Encryptor encryptor; + @Before public void setup() throws URISyntaxException { MockitoAnnotations.openMocks(this); modelFormat = MLModelFormat.TORCH_SCRIPT; modelId = "model_id"; - mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), null); + encryptor = new EncryptorImpl(); + encryptor.setMasterKey("0000000000000001"); + mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), Path.of("/tmp/test_config"), encryptor); modelHelper = new ModelHelper(mlEngine); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java index e8c46d37a9..15d75f4bca 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java @@ -27,6 +27,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.engine.utils.FileUtils; import java.io.File; @@ -62,7 +63,8 @@ public class TextEmbeddingModelTest { private ModelHelper modelHelper; private Map params; private TextEmbeddingModel textEmbeddingModel; - private Path djlCachePath; + private Path mlCachePath; + private Path mlConfigPath; private TextDocsInputDataSet inputDataSet; private int dimension = 384; private MLEngine mlEngine; @@ -70,8 +72,11 @@ public class TextEmbeddingModelTest { @Before public void setUp() throws URISyntaxException { - djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID()); - mlEngine = new MLEngine(djlCachePath, encryptor); + mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); + mlConfigPath = Path.of("/tmp/ml_config" + UUID.randomUUID()); + encryptor = new EncryptorImpl(); + encryptor.setMasterKey("0000000000000001"); + mlEngine = new MLEngine(mlCachePath, mlConfigPath, encryptor); modelId = "test_model_id"; modelName = "test_model_name"; functionName = FunctionName.TEXT_EMBEDDING; @@ -329,7 +334,7 @@ public void predict_BeforeInitingModel() { @After public void tearDown() { - FileUtils.deleteFileQuietly(djlCachePath); + FileUtils.deleteFileQuietly(mlCachePath); } private int findSentenceEmbeddingPosition(ModelTensors modelTensors) { diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java index 050d475dd2..f4abfea4df 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java @@ -14,6 +14,7 @@ import org.opensearch.common.component.LifecycleListener; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; @@ -30,6 +31,7 @@ public class MLCommonsClusterManagerEventListener implements LocalNodeClusterMan private Scheduler.Cancellable syncModelRoutingCron; private DiscoveryNodeHelper nodeHelper; private final MLIndicesHandler mlIndicesHandler; + private final Encryptor encryptor; private volatile Integer jobInterval; @@ -39,7 +41,8 @@ public MLCommonsClusterManagerEventListener( Settings settings, ThreadPool threadPool, DiscoveryNodeHelper nodeHelper, - MLIndicesHandler mlIndicesHandler + MLIndicesHandler mlIndicesHandler, + Encryptor encryptor ) { this.clusterService = clusterService; this.client = client; @@ -47,6 +50,7 @@ public MLCommonsClusterManagerEventListener( this.clusterService.addListener(this); this.nodeHelper = nodeHelper; this.mlIndicesHandler = mlIndicesHandler; + this.encryptor = encryptor; this.jobInterval = ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS, it -> { @@ -67,7 +71,7 @@ private void startSyncModelRoutingCron() { if (jobInterval > 0) { syncModelRoutingCron = threadPool .scheduleWithFixedDelay( - new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler), + new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor), TimeValue.timeValueSeconds(jobInterval), GENERAL_THREAD_POOL ); diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 0ba118bd29..95c5ec037c 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -5,6 +5,9 @@ package org.opensearch.ml.cluster; +import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import java.time.Instant; @@ -20,7 +23,10 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.WriteRequest; import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; @@ -33,6 +39,7 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpInput; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; +import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; @@ -50,19 +57,30 @@ public class MLSyncUpCron implements Runnable { private ClusterService clusterService; private DiscoveryNodeHelper nodeHelper; private MLIndicesHandler mlIndicesHandler; + private Encryptor encryptor; + private volatile Boolean mlConfigInited; @VisibleForTesting Semaphore updateModelStateSemaphore; - public MLSyncUpCron(Client client, ClusterService clusterService, DiscoveryNodeHelper nodeHelper, MLIndicesHandler mlIndicesHandler) { + public MLSyncUpCron( + Client client, + ClusterService clusterService, + DiscoveryNodeHelper nodeHelper, + MLIndicesHandler mlIndicesHandler, + Encryptor encryptor + ) { this.client = client; this.clusterService = clusterService; this.nodeHelper = nodeHelper; this.mlIndicesHandler = mlIndicesHandler; this.updateModelStateSemaphore = new Semaphore(1); + this.mlConfigInited = false; + this.encryptor = encryptor; } @Override public void run() { + initMLConfig(); if (!clusterService.state().metadata().indices().containsKey(ML_MODEL_INDEX)) { // no need to run sync up job if no model index return; @@ -71,6 +89,7 @@ public void run() { DiscoveryNode[] allNodes = nodeHelper.getAllNodes(); MLSyncUpInput gatherInfoInput = MLSyncUpInput.builder().getDeployedModels(true).build(); MLSyncUpNodesRequest gatherInfoRequest = new MLSyncUpNodesRequest(allNodes, gatherInfoInput); + // gather running model/tasks on nodes client.execute(MLSyncUpAction.INSTANCE, gatherInfoRequest, ActionListener.wrap(r -> { List responses = r.getNodes(); @@ -142,6 +161,34 @@ public void run() { }, e -> { log.error("Failed to sync model routing", e); })); } + @VisibleForTesting + void initMLConfig() { + if (mlConfigInited) { + return; + } + mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> { + GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (!getResponse.isExists()) { + IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + final String masterKey = encryptor.generateMasterKey(); + indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + log.info("ML configuration initialized successfully"); + encryptor.setMasterKey(masterKey); + mlConfigInited = true; + }, e -> { log.debug("Failed to save ML encryption master key", e); })); + } else { + final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY); + encryptor.setMasterKey(masterKey); + mlConfigInited = true; + log.info("ML configuration already initialized, no action needed"); + } + }, e -> { log.debug("Failed to get ML encryption master key", e); })); + }, e -> { log.debug("Failed to init ML config index", e); })); + } + @VisibleForTesting void refreshModelState(Map> modelWorkerNodes, Map> deployingModels) { if (!updateModelStateSemaphore.tryAcquire()) { diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java index 668306b763..b81682f07e 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java @@ -5,6 +5,9 @@ package org.opensearch.ml.indices; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_SCHEMA_VERSION; @@ -22,7 +25,8 @@ public enum MLIndex { MODEL_GROUP(ML_MODEL_GROUP_INDEX, false, ML_MODEL_GROUP_INDEX_MAPPING, ML_MODEL_GROUP_INDEX_SCHEMA_VERSION), MODEL(ML_MODEL_INDEX, false, ML_MODEL_INDEX_MAPPING, ML_MODEL_INDEX_SCHEMA_VERSION), TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION), - CONNECTOR(ML_CONNECTOR_INDEX, false, ML_CONNECTOR_INDEX_MAPPING, ML_CONNECTOR_SCHEMA_VERSION); + CONNECTOR(ML_CONNECTOR_INDEX, false, ML_CONNECTOR_INDEX_MAPPING, ML_CONNECTOR_SCHEMA_VERSION), + CONFIG(ML_CONFIG_INDEX, false, ML_CONFIG_INDEX_MAPPING, ML_CONFIG_INDEX_SCHEMA_VERSION); private final String indexName; // whether we use an alias for the index diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java index 3235d27f29..12954a62a2 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java @@ -61,6 +61,10 @@ public void initMLConnectorIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.CONNECTOR, listener); } + public void initMLConfigIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.CONFIG, listener); + } + public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) { String indexName = index.getIndexName(); String mapping = index.getMapping(); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 02d3889464..07720d45f8 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -34,7 +34,6 @@ import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; @@ -154,7 +153,6 @@ public class MLModelManager { private volatile Integer maxModelPerNode; private volatile Integer maxRegisterTasksPerNode; private volatile Integer maxDeployTasksPerNode; - private volatile String masterKey; public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet .of( @@ -208,12 +206,6 @@ public MLModelManager( clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it); - - this.masterKey = ML_COMMONS_MASTER_SECRET_KEY.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MASTER_SECRET_KEY, it -> { - masterKey = it; - mlEngine.setMasterKey(masterKey); - }); } public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener listener) { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 6b525d8695..84fecebd31 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -7,8 +7,8 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY; +import java.nio.file.Path; import java.util.Collection; import java.util.List; import java.util.Map; @@ -276,17 +276,22 @@ public Collection createComponents( this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; Settings settings = environment.settings(); - String masterKey = ML_COMMONS_MASTER_SECRET_KEY.get(clusterService.getSettings()); - Encryptor encryptor = new EncryptorImpl(masterKey); + Path dataPath = environment.dataFiles()[0]; + Path configFile = environment.configFile(); + System.out.println("----------------ylwwwdebug"); + System.out.println(configFile); + System.out.println("----------------ylwwwdebugend"); - mlEngine = new MLEngine(environment.dataFiles()[0], encryptor); + Encryptor encryptor = new EncryptorImpl(); + + mlEngine = new MLEngine(dataPath, configFile, encryptor); nodeHelper = new DiscoveryNodeHelper(clusterService, settings); modelCacheHelper = new MLModelCacheHelper(clusterService, settings); JvmService jvmService = new JvmService(environment.settings()); OsService osService = new OsService(environment.settings()); MLCircuitBreakerService mlCircuitBreakerService = new MLCircuitBreakerService(jvmService, osService, settings, clusterService) - .init(environment.dataFiles()[0]); + .init(dataPath); Map> stats = new ConcurrentHashMap<>(); // cluster level stats @@ -408,11 +413,13 @@ public Collection createComponents( settings, threadPool, nodeHelper, - mlIndicesHandler + mlIndicesHandler, + encryptor ); return ImmutableList .of( + encryptor, mlEngine, nodeHelper, modelCacheHelper, @@ -601,7 +608,6 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL, MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED, - MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY, MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED, MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX ); diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 856b819380..f4bc1da757 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -111,9 +111,6 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting .boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); - public static final Setting ML_COMMONS_MASTER_SECRET_KEY = Setting - .simpleString("plugins.ml_commons.encryption.master_key", "0000000000000000", Setting.Property.NodeScope, Setting.Property.Dynamic); - public static final Setting ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED = Setting .boolSetting("plugins.ml_commons.connector_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index 38dc5a605d..193435270b 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -142,8 +142,9 @@ public void setup() { clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN))); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - encryptor = new EncryptorImpl("0000000000000000"); - mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); + encryptor = new EncryptorImpl(); + encryptor.setMasterKey("0000000000000001"); + mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); modelHelper = new ModelHelper(mlEngine); when(mlDeployModelRequest.getModelId()).thenReturn("mockModelId"); when(mlDeployModelRequest.getModelNodeIds()).thenReturn(new String[] { "node1" }); diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index 417e9d2e77..a3497af6e2 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -10,10 +10,14 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import static org.opensearch.ml.utils.TestHelper.setupTestClusterState; @@ -30,6 +34,7 @@ import java.util.concurrent.atomic.AtomicInteger; import org.apache.lucene.search.TotalHits; +import org.junit.Assert; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -37,6 +42,8 @@ import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; @@ -57,6 +64,9 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpAction; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -66,6 +76,7 @@ import org.opensearch.search.suggest.Suggest; import org.opensearch.test.OpenSearchTestCase; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; public class MLSyncUpCronTests extends OpenSearchTestCase { @@ -76,6 +87,8 @@ public class MLSyncUpCronTests extends OpenSearchTestCase { private ClusterService clusterService; @Mock private DiscoveryNodeHelper nodeHelper; + @Mock + private MLIndicesHandler mlIndicesHandler; private DiscoveryNode mlNode1; private DiscoveryNode mlNode2; @@ -85,16 +98,64 @@ public class MLSyncUpCronTests extends OpenSearchTestCase { private final String mlNode2Id = "mlNode2"; private ClusterState testState; + private Encryptor encryptor; @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlNode1 = new DiscoveryNode(mlNode1Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); mlNode2 = new DiscoveryNode(mlNode2Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); - syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, null); + encryptor = spy(new EncryptorImpl()); + syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor); testState = setupTestClusterState(); when(clusterService.state()).thenReturn(testState); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + } + + public void testInitMlConfig_MasterKeyNotExist() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + IndexResponse indexResponse = mock(IndexResponse.class); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + + syncUpCron.initMLConfig(); + Assert.assertNotNull(encryptor.encrypt("test")); + syncUpCron.initMLConfig(); + verify(encryptor, times(1)).setMasterKey(any()); + } + + public void testInitMlConfig_MasterKeyExists() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(true); + String masterKey = encryptor.generateMasterKey(); + when(response.getSourceAsMap()) + .thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + syncUpCron.initMLConfig(); + Assert.assertNotNull(encryptor.encrypt("test")); + syncUpCron.initMLConfig(); + verify(encryptor, times(1)).setMasterKey(any()); } public void testRun_NoMLModelIndex() { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 07c87c65fd..456068391f 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -26,7 +26,6 @@ import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; @@ -173,20 +172,19 @@ public class MLModelManagerTests extends OpenSearchTestCase { public void setup() throws URISyntaxException { String masterKey = "0000000000000001"; MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl(masterKey); - mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); + encryptor = new EncryptorImpl(); + encryptor.setMasterKey(masterKey); + mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); settings = Settings.builder().put(ML_COMMONS_MAX_MODELS_PER_NODE.getKey(), 10).build(); settings = Settings.builder().put(ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE.getKey(), 10).build(); settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), 10).build(); settings = Settings.builder().put(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE.getKey(), 10).build(); - settings = Settings.builder().put(ML_COMMONS_MASTER_SECRET_KEY.getKey(), masterKey).build(); ClusterSettings clusterSettings = clusterSetting( settings, ML_COMMONS_MAX_MODELS_PER_NODE, ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE, ML_COMMONS_MONITORING_REQUEST_COUNT, - ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, - ML_COMMONS_MASTER_SECRET_KEY + ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE ); clusterService = spy(new ClusterService(settings, clusterSettings, null)); xContentRegistry = NamedXContentRegistry.EMPTY; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 56f6d09fd1..6813000638 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -93,8 +93,13 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("0000000000000000"); - mlEngine = new MLEngine(Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), encryptor); + encryptor = new EncryptorImpl(); + encryptor.setMasterKey("0000000000000000"); + mlEngine = new MLEngine( + Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), + Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), + encryptor + ); when(threadPool.executor(anyString())).thenReturn(executorService); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 1002ee7d8f..689b15626b 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -127,8 +127,9 @@ public class MLPredictTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("0000000000000000"); - mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); + encryptor = new EncryptorImpl(); + encryptor.setMasterKey("0000000000000001"); + mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT); when(clusterService.localNode()).thenReturn(localNode); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index 73df81252c..39add53b8e 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -101,8 +101,9 @@ public class MLTrainAndPredictTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { - encryptor = new EncryptorImpl("0000000000000000"); - mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); + encryptor = new EncryptorImpl(); + encryptor.setMasterKey("0000000000000001"); + mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); settings = Settings.builder().build(); MockitoAnnotations.openMocks(this); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index f64faf59cc..12a63d4f6f 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -111,8 +111,13 @@ public class MLTrainingTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("0000000000000000"); - mlEngine = new MLEngine(Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), encryptor); + encryptor = new EncryptorImpl(); + encryptor.setMasterKey("0000000000000001"); + mlEngine = new MLEngine( + Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), + Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), + encryptor + ); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT); when(clusterService.localNode()).thenReturn(localNode); From d002a59f7bc66a1db746213a0a82d727ce53e1d1 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jul 2023 15:32:01 -0700 Subject: [PATCH 02/10] remove unnecessary escape Signed-off-by: Yaliang Wu --- .../ml/engine/algorithms/remote/ConnectorUtils.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index f3bfed3c3e..7eccd6155d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -95,11 +95,6 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto } else { throw new IllegalArgumentException("Wrong input type"); } - Map escapedParameters = new HashMap<>(); - inputData.getParameters().entrySet().forEach(entry -> { - escapedParameters.put(entry.getKey(), escapeJava(entry.getValue())); - }); - inputData.setParameters(escapedParameters); return inputData; } From 3af33ba31d54591196123f2ed4046f5d9af339b8 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jul 2023 16:28:24 -0700 Subject: [PATCH 03/10] fix failed ut Signed-off-by: Yaliang Wu --- .../ml/engine/algorithms/remote/ConnectorUtilsTest.java | 2 +- .../java/org/opensearch/ml/plugin/MachineLearningPlugin.java | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 9c3057b3a5..857cbe997f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -84,7 +84,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc() processInput_TextDocsInputDataSet_PreprocessFunction( "{\"input\": ${parameters.input}}", "{\"parameters\": { \"input\": [\"test_value1\", \"test_value2\"] } }", - "[\\\"test_value1\\\",\\\"test_value2\\\"]"); + "[\"test_value1\",\"test_value2\"]"); } @Test diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 84fecebd31..74c2227b37 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -278,9 +278,6 @@ public Collection createComponents( Settings settings = environment.settings(); Path dataPath = environment.dataFiles()[0]; Path configFile = environment.configFile(); - System.out.println("----------------ylwwwdebug"); - System.out.println(configFile); - System.out.println("----------------ylwwwdebugend"); Encryptor encryptor = new EncryptorImpl(); From c2e5e7468a1da35d6f0801b430bad6a31bace350 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jul 2023 17:12:13 -0700 Subject: [PATCH 04/10] tune syncup jot interval Signed-off-by: Yaliang Wu --- .../main/java/org/opensearch/ml/settings/MLCommonsSettings.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index f4bc1da757..9f1a5308a2 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -42,7 +42,7 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS = Setting .intSetting( "plugins.ml_commons.sync_up_job_interval_in_seconds", - 3, + 10, 0, 86400, Setting.Property.NodeScope, From f854e23cd33725d8fef7272ff1601c74485061e8 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jul 2023 17:21:30 -0700 Subject: [PATCH 05/10] tune syncup jot interval Signed-off-by: Yaliang Wu --- .../main/java/org/opensearch/ml/settings/MLCommonsSettings.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 9f1a5308a2..bb7deca49f 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -42,7 +42,7 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS = Setting .intSetting( "plugins.ml_commons.sync_up_job_interval_in_seconds", - 10, + 15, 0, 86400, Setting.Property.NodeScope, From 15d55533d4948c1d2ecc2bb1a636869e7790c211 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jul 2023 17:51:14 -0700 Subject: [PATCH 06/10] remove local config file code Signed-off-by: Yaliang Wu --- ml-algorithms/build.gradle | 1 - .../org/opensearch/ml/engine/MLEngine.java | 46 ------------------- 2 files changed, 47 deletions(-) diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 7b66b4e00e..8c6a9ca0db 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -64,7 +64,6 @@ dependencies { implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.0' implementation 'com.jayway.jsonpath:json-path:2.8.0' implementation group: 'org.json', name: 'json', version: '20230227' - implementation group: 'org.yaml', name: 'snakeyaml', version: '2.0' } configurations.all { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index 41f6d794c9..a273a6bead 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -5,9 +5,6 @@ package org.opensearch.ml.engine; -import com.google.gson.Gson; -import com.google.gson.GsonBuilder; -import com.google.gson.stream.JsonReader; import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.opensearch.ml.common.FunctionName; @@ -15,7 +12,6 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; -import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.MLInput; @@ -23,25 +19,10 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.Output; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import org.yaml.snakeyaml.DumperOptions; -import org.yaml.snakeyaml.Yaml; - -import java.io.FileInputStream; -import java.io.FileReader; -import java.io.FileWriter; -import java.nio.file.Files; import java.nio.file.Path; -import java.security.AccessController; -import java.security.PrivilegedExceptionAction; -import java.security.SecureRandom; -import java.util.Base64; -import java.util.HashMap; import java.util.Locale; import java.util.Map; -import static org.opensearch.ml.common.CommonValue.MASTER_KEY; - /** * This is the interface to all ml algorithms. */ @@ -68,33 +49,6 @@ public MLEngine(Path opensearchDataFolder, Path opensearchConfigFolder, Encrypto this.mlUserConfigPath = opensearchConfigFolder.resolve("opensearch-ml"); this.mlConfigPath = mlCachePath.resolve("config"); this.encryptor = encryptor; - initMasterKey(); - } - - private synchronized void initMasterKey() { - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - Path userConfigFilePath = mlUserConfigPath.resolve("security_config.json"); - Map config = null; - if (Files.exists(userConfigFilePath)) { - try (FileInputStream fis = new FileInputStream(userConfigFilePath.toFile());) { - Yaml yaml = new Yaml(); - config = yaml.load(fis); - } - } - if (config == null) { - config = new HashMap<>(); - } - - if (config.containsKey(MASTER_KEY)) { - encryptor.setMasterKey(config.get(MASTER_KEY)); - } - return null; - }); - } catch (Exception e) { - log.error("Failed to save master key", e); - throw new MLException(e); - } } public String getPrebuiltModelMetaListPath() { From 2ab81f36929f45002c119836e3d5ec2c75ac8348 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jul 2023 19:22:58 -0700 Subject: [PATCH 07/10] set master key when init remote model Signed-off-by: Yaliang Wu --- .../engine/algorithms/remote/RemoteModel.java | 47 +++++- .../ml/engine/encryptor/Encryptor.java | 1 + .../ml/engine/encryptor/EncryptorImpl.java | 5 + .../algorithms/remote/RemoteModelTest.java | 139 +++++++++++++++++- 4 files changed, 188 insertions(+), 4 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 4449ee6996..191d5ad9c2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -7,6 +7,11 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -23,6 +28,11 @@ import org.opensearch.script.ScriptService; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; @Log4j2 @Function(FunctionName.REMOTE) @@ -77,11 +87,42 @@ public boolean isModelReady() { public void initModel(MLModel model, Map params, Encryptor encryptor) { try { Connector connector = model.getConnector().cloneConnector(); - connector.decrypt((credential) -> encryptor.decrypt(credential)); + + ClusterService clusterService = (ClusterService) params.get(CLUSTER_SERVICE); + Client client = (Client) params.get(CLIENT); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference exceptionRef = new AtomicReference<>(); + if (encryptor.getMasterKey() == null) { + if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) { + GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, new LatchedActionListener(ActionListener.< GetResponse >wrap(r-> { + if (r.isExists()) { + String masterKey = (String)r.getSourceAsMap().get(MASTER_KEY); + encryptor.setMasterKey(masterKey); + } else { + exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); + } + }, e-> { + log.error("Failed to get ML encryption master key", e); + exceptionRef.set(e); + }), latch)); + } else { + exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); + } + } + + if (exceptionRef.get() != null) { + throw exceptionRef.get(); + } + if (encryptor.getMasterKey() != null) { + connector.decrypt((credential) -> encryptor.decrypt(credential)); + } else { + throw new MLException("ML encryptor not initialized"); + } this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE)); - this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); - this.connectorExecutor.setClient((Client) params.get(CLIENT)); + this.connectorExecutor.setClusterService(clusterService); + this.connectorExecutor.setClient(client); this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY)); } catch (RuntimeException e) { log.error("Failed to init remote model", e); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java index 7bbe58cec5..2316869ffd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java @@ -32,6 +32,7 @@ public interface Encryptor { * @param masterKey masterKey to be set. */ void setMasterKey(String masterKey); + String getMasterKey(); String generateMasterKey(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index 16abde0a24..6179e0e297 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -29,6 +29,11 @@ public void setMasterKey(String masterKey) { this.masterKey = masterKey; } + @Override + public String getMasterKey() { + return masterKey; + } + @Override public String encrypt(String plainText) { checkMasterKey(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index b39ae6bc9e..b1fabaa0c8 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -13,6 +13,16 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; @@ -22,19 +32,38 @@ import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import java.time.Instant; import java.util.Arrays; +import java.util.HashMap; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE; public class RemoteModelTest { @Mock MLInput mlInput; + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + @Mock MLModel mlModel; @@ -44,12 +73,46 @@ public class RemoteModelTest { RemoteModel remoteModel; Encryptor encryptor; + String masterKey; + + Map params; + private static final AtomicInteger portGenerator = new AtomicInteger(); + @Before public void setUp() { MockitoAnnotations.openMocks(this); remoteModel = new RemoteModel(); encryptor = spy(new EncryptorImpl()); - encryptor.setMasterKey("0000000000000001"); + masterKey = "0000000000000001"; + encryptor.setMasterKey(masterKey); + params = new HashMap<>(); + params.put(CLIENT, client); + params.put(CLUSTER_SERVICE, clusterService); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(true); + when(response.getSourceAsMap()) + .thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + + when(clusterService.state()).thenReturn(clusterState); + + Metadata metadata = new Metadata.Builder() + .indices(ImmutableMap + .builder() + .put(ML_CONFIG_INDEX, IndexMetadata.builder(ML_CONFIG_INDEX) + .settings(Settings.builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id)) + .build()) + .build()).build(); + when(clusterState.metadata()).thenReturn(metadata); } @Test @@ -112,6 +175,80 @@ public void initModel_WithHeader() { Assert.assertNull(remoteModel.getConnectorExecutor()); } + @Test + public void initModel_WithHeader_NullMasterKey_MasterKeyExistInIndex() { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + Encryptor encryptor = new EncryptorImpl(); + remoteModel.initModel(mlModel, params, encryptor); + Map decryptedHeaders = connector.getDecryptedHeaders(); + RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); + Assert.assertNotNull(executor); + Assert.assertNull(decryptedHeaders); + Assert.assertNotNull(executor.getConnector().getDecryptedHeaders()); + Assert.assertEquals(1, executor.getConnector().getDecryptedHeaders().size()); + Assert.assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization")); + + remoteModel.close(); + Assert.assertNull(remoteModel.getConnectorExecutor()); + } + + @Test + public void initModel_WithHeader_NullMasterKey_MasterKeyNotExistInIndex() { + exceptionRule.expect(ResourceNotFoundException.class); + exceptionRule.expectMessage("ML encryption master key not initialized yet"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + Encryptor encryptor = new EncryptorImpl(); + remoteModel.initModel(mlModel, params, encryptor); + } + + @Test + public void initModel_WithHeader_GetMasterKey_Exception() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("test error"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("test error")); + return null; + }).when(client).get(any(), any()); + + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + Encryptor encryptor = new EncryptorImpl(); + remoteModel.initModel(mlModel, params, encryptor); + } + + @Test + public void initModel_WithHeader_IndexNotFound() { + exceptionRule.expect(ResourceNotFoundException.class); + exceptionRule.expectMessage("ML encryption master key not initialized yet"); + + Metadata metadata = new Metadata.Builder().indices(ImmutableMap.of()).build(); + when(clusterState.metadata()).thenReturn(metadata); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("test error")); + return null; + }).when(client).get(any(), any()); + + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + Encryptor encryptor = new EncryptorImpl(); + remoteModel.initModel(mlModel, params, encryptor); + } + private Connector createConnector(Map headers) { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) From e22f5145655b6283418f81da2cafab3325c0dfac Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jul 2023 20:35:14 -0700 Subject: [PATCH 08/10] move init master key to encryptor Signed-off-by: Yaliang Wu --- .../engine/algorithms/remote/RemoteModel.java | 47 +----- .../ml/engine/encryptor/EncryptorImpl.java | 65 +++++++- .../algorithms/remote/RemoteModelTest.java | 139 +--------------- .../engine/encryptor/EncryptorImplTest.java | 148 ++++++++++++++++++ .../ml/plugin/MachineLearningPlugin.java | 2 +- .../ml/settings/MLCommonsSettings.java | 2 +- 6 files changed, 212 insertions(+), 191 deletions(-) create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 191d5ad9c2..4449ee6996 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -7,11 +7,6 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; -import org.opensearch.ResourceNotFoundException; -import org.opensearch.action.ActionListener; -import org.opensearch.action.LatchedActionListener; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -28,11 +23,6 @@ import org.opensearch.script.ScriptService; import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; - -import static org.opensearch.ml.common.CommonValue.MASTER_KEY; -import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; @Log4j2 @Function(FunctionName.REMOTE) @@ -87,42 +77,11 @@ public boolean isModelReady() { public void initModel(MLModel model, Map params, Encryptor encryptor) { try { Connector connector = model.getConnector().cloneConnector(); - - ClusterService clusterService = (ClusterService) params.get(CLUSTER_SERVICE); - Client client = (Client) params.get(CLIENT); - CountDownLatch latch = new CountDownLatch(1); - AtomicReference exceptionRef = new AtomicReference<>(); - if (encryptor.getMasterKey() == null) { - if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) { - GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); - client.get(getRequest, new LatchedActionListener(ActionListener.< GetResponse >wrap(r-> { - if (r.isExists()) { - String masterKey = (String)r.getSourceAsMap().get(MASTER_KEY); - encryptor.setMasterKey(masterKey); - } else { - exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); - } - }, e-> { - log.error("Failed to get ML encryption master key", e); - exceptionRef.set(e); - }), latch)); - } else { - exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); - } - } - - if (exceptionRef.get() != null) { - throw exceptionRef.get(); - } - if (encryptor.getMasterKey() != null) { - connector.decrypt((credential) -> encryptor.decrypt(credential)); - } else { - throw new MLException("ML encryptor not initialized"); - } + connector.decrypt((credential) -> encryptor.decrypt(credential)); this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE)); - this.connectorExecutor.setClusterService(clusterService); - this.connectorExecutor.setClient(client); + this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); + this.connectorExecutor.setClient((Client) params.get(CLIENT)); this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY)); } catch (RuntimeException e) { log.error("Failed to init remote model", e); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index 6179e0e297..c690304161 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -9,19 +9,42 @@ import com.amazonaws.encryptionsdk.CommitmentPolicy; import com.amazonaws.encryptionsdk.CryptoResult; import com.amazonaws.encryptionsdk.jce.JceMasterKey; -import org.opensearch.ml.engine.exceptions.MetaDataException; +import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.ml.common.exception.MLException; import javax.crypto.spec.SecretKeySpec; import java.nio.charset.StandardCharsets; import java.security.SecureRandom; import java.util.Base64; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; + +@Log4j2 public class EncryptorImpl implements Encryptor { + private ClusterService clusterService; + private Client client; private volatile String masterKey; - public EncryptorImpl() { + public EncryptorImpl(ClusterService clusterService, Client client) { this.masterKey = null; + this.clusterService = clusterService; + this.client = client; + } + + @VisibleForTesting + public EncryptorImpl() { } @Override @@ -36,7 +59,7 @@ public String getMasterKey() { @Override public String encrypt(String plainText) { - checkMasterKey(); + initMasterKey(); final AwsCrypto crypto = AwsCrypto.builder() .withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt) .build(); @@ -52,7 +75,7 @@ public String encrypt(String plainText) { @Override public String decrypt(String encryptedText) { - checkMasterKey(); + initMasterKey(); final AwsCrypto crypto = AwsCrypto.builder() .withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt) .build(); @@ -74,9 +97,37 @@ public String generateMasterKey() { return base64Key; } - private void checkMasterKey() { - if (masterKey == null) { - throw new MetaDataException("Encryption key not created yet."); + private void initMasterKey() { + if (masterKey != null) { + return; + } + AtomicReference exceptionRef = new AtomicReference<>(); + + CountDownLatch latch = new CountDownLatch(1); + if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) { + GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, new LatchedActionListener(ActionListener.wrap(r -> { + if (r.isExists()) { + String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY); + setMasterKey(masterKey); + } else { + exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); + } + }, e -> { + log.error("Failed to get ML encryption master key", e); + exceptionRef.set(e); + }), latch)); + } else { + exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); + } + + if (exceptionRef.get() != null) { + log.debug("Failed to init master key", exceptionRef.get()); + if (exceptionRef.get() instanceof RuntimeException) { + throw (RuntimeException) exceptionRef.get(); + } else { + throw new MLException(exceptionRef.get()); + } } } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index b1fabaa0c8..b39ae6bc9e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -13,16 +13,6 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.ResourceNotFoundException; -import org.opensearch.Version; -import org.opensearch.action.ActionListener; -import org.opensearch.action.get.GetResponse; -import org.opensearch.client.Client; -import org.opensearch.cluster.ClusterState; -import org.opensearch.cluster.metadata.IndexMetadata; -import org.opensearch.cluster.metadata.Metadata; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; @@ -32,38 +22,19 @@ import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; -import java.time.Instant; import java.util.Arrays; -import java.util.HashMap; import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; -import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; -import static org.opensearch.ml.common.CommonValue.MASTER_KEY; -import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; -import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT; -import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE; public class RemoteModelTest { @Mock MLInput mlInput; - @Mock - Client client; - - @Mock - ClusterService clusterService; - - @Mock - ClusterState clusterState; - @Mock MLModel mlModel; @@ -73,46 +44,12 @@ public class RemoteModelTest { RemoteModel remoteModel; Encryptor encryptor; - String masterKey; - - Map params; - private static final AtomicInteger portGenerator = new AtomicInteger(); - @Before public void setUp() { MockitoAnnotations.openMocks(this); remoteModel = new RemoteModel(); encryptor = spy(new EncryptorImpl()); - masterKey = "0000000000000001"; - encryptor.setMasterKey(masterKey); - params = new HashMap<>(); - params.put(CLIENT, client); - params.put(CLUSTER_SERVICE, clusterService); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(true); - when(response.getSourceAsMap()) - .thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); - listener.onResponse(response); - return null; - }).when(client).get(any(), any()); - - - when(clusterService.state()).thenReturn(clusterState); - - Metadata metadata = new Metadata.Builder() - .indices(ImmutableMap - .builder() - .put(ML_CONFIG_INDEX, IndexMetadata.builder(ML_CONFIG_INDEX) - .settings(Settings.builder() - .put("index.number_of_shards", 1) - .put("index.number_of_replicas", 1) - .put("index.version.created", Version.CURRENT.id)) - .build()) - .build()).build(); - when(clusterState.metadata()).thenReturn(metadata); + encryptor.setMasterKey("0000000000000001"); } @Test @@ -175,80 +112,6 @@ public void initModel_WithHeader() { Assert.assertNull(remoteModel.getConnectorExecutor()); } - @Test - public void initModel_WithHeader_NullMasterKey_MasterKeyExistInIndex() { - Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); - when(mlModel.getConnector()).thenReturn(connector); - Encryptor encryptor = new EncryptorImpl(); - remoteModel.initModel(mlModel, params, encryptor); - Map decryptedHeaders = connector.getDecryptedHeaders(); - RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); - Assert.assertNotNull(executor); - Assert.assertNull(decryptedHeaders); - Assert.assertNotNull(executor.getConnector().getDecryptedHeaders()); - Assert.assertEquals(1, executor.getConnector().getDecryptedHeaders().size()); - Assert.assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization")); - - remoteModel.close(); - Assert.assertNull(remoteModel.getConnectorExecutor()); - } - - @Test - public void initModel_WithHeader_NullMasterKey_MasterKeyNotExistInIndex() { - exceptionRule.expect(ResourceNotFoundException.class); - exceptionRule.expectMessage("ML encryption master key not initialized yet"); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); - listener.onResponse(response); - return null; - }).when(client).get(any(), any()); - - Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); - when(mlModel.getConnector()).thenReturn(connector); - Encryptor encryptor = new EncryptorImpl(); - remoteModel.initModel(mlModel, params, encryptor); - } - - @Test - public void initModel_WithHeader_GetMasterKey_Exception() { - exceptionRule.expect(RuntimeException.class); - exceptionRule.expectMessage("test error"); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("test error")); - return null; - }).when(client).get(any(), any()); - - Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); - when(mlModel.getConnector()).thenReturn(connector); - Encryptor encryptor = new EncryptorImpl(); - remoteModel.initModel(mlModel, params, encryptor); - } - - @Test - public void initModel_WithHeader_IndexNotFound() { - exceptionRule.expect(ResourceNotFoundException.class); - exceptionRule.expectMessage("ML encryption master key not initialized yet"); - - Metadata metadata = new Metadata.Builder().indices(ImmutableMap.of()).build(); - when(clusterState.metadata()).thenReturn(metadata); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("test error")); - return null; - }).when(client).get(any(), any()); - - Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); - when(mlModel.getConnector()).thenReturn(connector); - Encryptor encryptor = new EncryptorImpl(); - remoteModel.initModel(mlModel, params, encryptor); - } - private Connector createConnector(Map headers) { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java new file mode 100644 index 0000000000..2e0980bc0f --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java @@ -0,0 +1,148 @@ +package org.opensearch.ml.engine.encryptor; + +import com.google.common.collect.ImmutableMap; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; + +import java.time.Instant; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; + +public class EncryptorImplTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + + String masterKey; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + masterKey = "0000000000000001"; + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(true); + when(response.getSourceAsMap()) + .thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + + when(clusterService.state()).thenReturn(clusterState); + + Metadata metadata = new Metadata.Builder() + .indices(ImmutableMap + .builder() + .put(ML_CONFIG_INDEX, IndexMetadata.builder(ML_CONFIG_INDEX) + .settings(Settings.builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id)) + .build()) + .build()).build(); + when(clusterState.metadata()).thenReturn(metadata); + } + + @Test + public void encrypt() { + Encryptor encryptor = new EncryptorImpl(clusterService, client); + Assert.assertNull(encryptor.getMasterKey()); + String encrypted = encryptor.encrypt("test"); + Assert.assertNotNull(encrypted); + Assert.assertEquals(masterKey, encryptor.getMasterKey()); + } + + @Test + public void decrypt() { + Encryptor encryptor = new EncryptorImpl(clusterService, client); + Assert.assertNull(encryptor.getMasterKey()); + String encrypted = encryptor.encrypt("test"); + String decrypted = encryptor.decrypt(encrypted); + Assert.assertEquals("test", decrypted); + Assert.assertEquals(masterKey, encryptor.getMasterKey()); + } + + @Test + public void encrypt_NullMasterKey_NullMasterKey_MasterKeyNotExistInIndex() { + exceptionRule.expect(ResourceNotFoundException.class); + exceptionRule.expectMessage("ML encryption master key not initialized yet"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client); + Assert.assertNull(encryptor.getMasterKey()); + encryptor.encrypt("test"); + } + + @Test + public void decrypt_NullMasterKey_GetMasterKey_Exception() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("test error"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("test error")); + return null; + }).when(client).get(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client); + Assert.assertNull(encryptor.getMasterKey()); + encryptor.decrypt("test"); + } + + @Test + public void decrypt_MLConfigIndexNotFound() { + exceptionRule.expect(ResourceNotFoundException.class); + exceptionRule.expectMessage("ML encryption master key not initialized yet"); + + Metadata metadata = new Metadata.Builder().indices(ImmutableMap.of()).build(); + when(clusterState.metadata()).thenReturn(metadata); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("test error")); + return null; + }).when(client).get(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client); + Assert.assertNull(encryptor.getMasterKey()); + encryptor.decrypt("test"); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 74c2227b37..61265d5a35 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -279,7 +279,7 @@ public Collection createComponents( Path dataPath = environment.dataFiles()[0]; Path configFile = environment.configFile(); - Encryptor encryptor = new EncryptorImpl(); + Encryptor encryptor = new EncryptorImpl(clusterService, client); mlEngine = new MLEngine(dataPath, configFile, encryptor); nodeHelper = new DiscoveryNodeHelper(clusterService, settings); diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index bb7deca49f..9f1a5308a2 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -42,7 +42,7 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS = Setting .intSetting( "plugins.ml_commons.sync_up_job_interval_in_seconds", - 15, + 10, 0, 86400, Setting.Property.NodeScope, From 5d7a8987995dbb0a614ec0dc096614381e38d4bd Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jul 2023 21:02:11 -0700 Subject: [PATCH 09/10] fine tune code Signed-off-by: Yaliang Wu --- .../src/main/java/org/opensearch/ml/engine/MLEngine.java | 4 +--- .../opensearch/ml/engine/encryptor/EncryptorImpl.java | 5 ++--- .../test/java/org/opensearch/ml/engine/MLEngineTest.java | 5 ++--- .../metrics_correlation/MetricsCorrelationTest.java | 6 ++---- .../algorithms/remote/AwsConnectorExecutorTest.java | 3 +-- .../ml/engine/algorithms/remote/RemoteModelTest.java | 3 +-- .../algorithms/text_embedding/ModelHelperTest.java | 5 ++--- .../text_embedding/TextEmbeddingModelTest.java | 6 ++---- .../org/opensearch/ml/plugin/MachineLearningPlugin.java | 2 +- .../action/deploy/TransportDeployModelActionTests.java | 5 ++--- .../org/opensearch/ml/cluster/MLSyncUpCronTests.java | 2 +- .../org/opensearch/ml/model/MLModelManagerTests.java | 5 ++--- .../org/opensearch/ml/task/MLExecuteTaskRunnerTests.java | 9 ++------- .../org/opensearch/ml/task/MLPredictTaskRunnerTests.java | 5 ++--- .../ml/task/MLTrainAndPredictTaskRunnerTests.java | 5 ++--- .../opensearch/ml/task/MLTrainingTaskRunnerTests.java | 3 +-- 16 files changed, 26 insertions(+), 47 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index a273a6bead..0c49e83bac 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -33,7 +33,6 @@ public class MLEngine { public static final String DEPLOY_MODEL_FOLDER = "deploy"; private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models"; - private final Path mlUserConfigPath; @Getter private final Path mlConfigPath; @@ -43,10 +42,9 @@ public class MLEngine { private Encryptor encryptor; - public MLEngine(Path opensearchDataFolder, Path opensearchConfigFolder, Encryptor encryptor) { + public MLEngine(Path opensearchDataFolder, Encryptor encryptor) { this.mlCachePath = opensearchDataFolder.resolve("ml_cache"); this.mlModelsCachePath = mlCachePath.resolve("models_cache"); - this.mlUserConfigPath = opensearchConfigFolder.resolve("opensearch-ml"); this.mlConfigPath = mlCachePath.resolve("config"); this.encryptor = encryptor; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index c690304161..0778af444a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -9,7 +9,6 @@ import com.amazonaws.encryptionsdk.CommitmentPolicy; import com.amazonaws.encryptionsdk.CryptoResult; import com.amazonaws.encryptionsdk.jce.JceMasterKey; -import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionListener; @@ -43,8 +42,8 @@ public EncryptorImpl(ClusterService clusterService, Client client) { this.client = client; } - @VisibleForTesting - public EncryptorImpl() { + public EncryptorImpl(String masterKey) { + this.masterKey = masterKey; } @Override diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 24115d4f46..b761cddd90 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -53,9 +53,8 @@ public class MLEngineTest { @Before public void setUp() { - Encryptor encryptor = new EncryptorImpl(); - encryptor.setMasterKey("0000000000000000"); - mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), Path.of("/tmp/test" + UUID.randomUUID()), encryptor); + Encryptor encryptor = new EncryptorImpl("0000000000000000"); + mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java index e860a811e2..35d782bc67 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java @@ -157,10 +157,8 @@ public void setUp() throws IOException, URISyntaxException { System.setProperty("testMode", "true"); mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID()); - mlConfigPath = Path.of("/tmp/djl_cache_" + UUID.randomUUID()); - encryptor = new EncryptorImpl(); - encryptor.setMasterKey("0000000000000001"); - mlEngine = new MLEngine(mlCachePath, mlConfigPath, encryptor); + encryptor = new EncryptorImpl("0000000000000001"); + mlEngine = new MLEngine(mlCachePath, encryptor); modelConfig = MetricsCorrelationModelConfig.builder() .modelType(MetricsCorrelation.MODEL_TYPE) .allConfig(null) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 49c83d40c5..8d6130566a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -67,8 +67,7 @@ public class AwsConnectorExecutorTest { @Before public void setUp() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl(); - encryptor.setMasterKey("0000000000000001"); + encryptor = new EncryptorImpl("0000000000000001"); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index b39ae6bc9e..bef3e1da71 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -48,8 +48,7 @@ public class RemoteModelTest { public void setUp() { MockitoAnnotations.openMocks(this); remoteModel = new RemoteModel(); - encryptor = spy(new EncryptorImpl()); - encryptor.setMasterKey("0000000000000001"); + encryptor = spy(new EncryptorImpl("0000000000000001")); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java index 8f4dab8963..245afdfe7f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java @@ -59,9 +59,8 @@ public void setup() throws URISyntaxException { MockitoAnnotations.openMocks(this); modelFormat = MLModelFormat.TORCH_SCRIPT; modelId = "model_id"; - encryptor = new EncryptorImpl(); - encryptor.setMasterKey("0000000000000001"); - mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), Path.of("/tmp/test_config"), encryptor); + encryptor = new EncryptorImpl("0000000000000001"); + mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), encryptor); modelHelper = new ModelHelper(mlEngine); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java index 15d75f4bca..70d7fd2e35 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java @@ -73,10 +73,8 @@ public class TextEmbeddingModelTest { @Before public void setUp() throws URISyntaxException { mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); - mlConfigPath = Path.of("/tmp/ml_config" + UUID.randomUUID()); - encryptor = new EncryptorImpl(); - encryptor.setMasterKey("0000000000000001"); - mlEngine = new MLEngine(mlCachePath, mlConfigPath, encryptor); + encryptor = new EncryptorImpl("0000000000000001"); + mlEngine = new MLEngine(mlCachePath, encryptor); modelId = "test_model_id"; modelName = "test_model_name"; functionName = FunctionName.TEXT_EMBEDDING; diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 61265d5a35..5870cbd3f1 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -281,7 +281,7 @@ public Collection createComponents( Encryptor encryptor = new EncryptorImpl(clusterService, client); - mlEngine = new MLEngine(dataPath, configFile, encryptor); + mlEngine = new MLEngine(dataPath, encryptor); nodeHelper = new DiscoveryNodeHelper(clusterService, settings); modelCacheHelper = new MLModelCacheHelper(clusterService, settings); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index 193435270b..1cb0670e14 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -142,9 +142,8 @@ public void setup() { clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN))); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - encryptor = new EncryptorImpl(); - encryptor.setMasterKey("0000000000000001"); - mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); + encryptor = new EncryptorImpl("0000000000000001"); + mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); modelHelper = new ModelHelper(mlEngine); when(mlDeployModelRequest.getModelId()).thenReturn("mockModelId"); when(mlDeployModelRequest.getModelNodeIds()).thenReturn(new String[] { "node1" }); diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index a3497af6e2..5a6913f576 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -105,7 +105,7 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlNode1 = new DiscoveryNode(mlNode1Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); mlNode2 = new DiscoveryNode(mlNode2Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); - encryptor = spy(new EncryptorImpl()); + encryptor = spy(new EncryptorImpl(null)); syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor); testState = setupTestClusterState(); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 456068391f..24fd02bf71 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -172,9 +172,8 @@ public class MLModelManagerTests extends OpenSearchTestCase { public void setup() throws URISyntaxException { String masterKey = "0000000000000001"; MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl(); - encryptor.setMasterKey(masterKey); - mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); + encryptor = new EncryptorImpl(masterKey); + mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); settings = Settings.builder().put(ML_COMMONS_MAX_MODELS_PER_NODE.getKey(), 10).build(); settings = Settings.builder().put(ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE.getKey(), 10).build(); settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), 10).build(); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 6813000638..56f6d09fd1 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -93,13 +93,8 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl(); - encryptor.setMasterKey("0000000000000000"); - mlEngine = new MLEngine( - Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), - Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), - encryptor - ); + encryptor = new EncryptorImpl("0000000000000000"); + mlEngine = new MLEngine(Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), encryptor); when(threadPool.executor(anyString())).thenReturn(executorService); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 689b15626b..530033197c 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -127,9 +127,8 @@ public class MLPredictTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl(); - encryptor.setMasterKey("0000000000000001"); - mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); + encryptor = new EncryptorImpl("0000000000000001"); + mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT); when(clusterService.localNode()).thenReturn(localNode); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index 39add53b8e..0714bf0234 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -101,9 +101,8 @@ public class MLTrainAndPredictTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { - encryptor = new EncryptorImpl(); - encryptor.setMasterKey("0000000000000001"); - mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); + encryptor = new EncryptorImpl("0000000000000001"); + mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); settings = Settings.builder().build(); MockitoAnnotations.openMocks(this); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index 12a63d4f6f..c6a6b9a886 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -111,8 +111,7 @@ public class MLTrainingTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl(); - encryptor.setMasterKey("0000000000000001"); + encryptor = new EncryptorImpl("0000000000000001"); mlEngine = new MLEngine( Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), From d6fb41c8297fc400225bb0b910d21488072a41de Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jul 2023 21:05:43 -0700 Subject: [PATCH 10/10] fine tune code Signed-off-by: Yaliang Wu --- .../org/opensearch/ml/task/MLTrainingTaskRunnerTests.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index c6a6b9a886..6565a41b95 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -112,11 +112,7 @@ public class MLTrainingTaskRunnerTests extends OpenSearchTestCase { public void setup() { MockitoAnnotations.openMocks(this); encryptor = new EncryptorImpl("0000000000000001"); - mlEngine = new MLEngine( - Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), - Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), - encryptor - ); + mlEngine = new MLEngine(Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), encryptor); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT); when(clusterService.localNode()).thenReturn(localNode);