diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index 0778af444a..83502fb85c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -17,6 +17,7 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.ml.common.exception.MLException; import javax.crypto.spec.SecretKeySpec; @@ -62,9 +63,9 @@ public String encrypt(String plainText) { final AwsCrypto crypto = AwsCrypto.builder() .withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt) .build(); - + byte[] bytes = Base64.getDecoder().decode(masterKey); JceMasterKey jceMasterKey - = JceMasterKey.getInstance(new SecretKeySpec(masterKey.getBytes(), "AES"), "Custom", "", + = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding"); final CryptoResult encryptResult = crypto.encryptData(jceMasterKey, @@ -79,8 +80,9 @@ public String decrypt(String encryptedText) { .withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt) .build(); + byte[] bytes = Base64.getDecoder().decode(masterKey); JceMasterKey jceMasterKey - = JceMasterKey.getInstance(new SecretKeySpec(masterKey.getBytes(), "AES"), "Custom", "", + = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding"); final CryptoResult decryptedResult @@ -90,7 +92,7 @@ public String decrypt(String encryptedText) { @Override public String generateMasterKey() { - byte[] keyBytes = new byte[16]; + byte[] keyBytes = new byte[32]; new SecureRandom().nextBytes(keyBytes); String base64Key = Base64.getEncoder().encodeToString(keyBytes); return base64Key; @@ -104,18 +106,20 @@ private void initMasterKey() { CountDownLatch latch = new CountDownLatch(1); if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) { - GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); - client.get(getRequest, new LatchedActionListener(ActionListener.wrap(r -> { - if (r.isExists()) { - String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY); - setMasterKey(masterKey); - } else { - exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); - } - }, e -> { - log.error("Failed to get ML encryption master key", e); - exceptionRef.set(e); - }), latch)); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, new LatchedActionListener(ActionListener.wrap(r -> { + if (r.isExists()) { + String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY); + setMasterKey(masterKey); + } else { + exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); + } + }, e -> { + log.error("Failed to get ML encryption master key", e); + exceptionRef.set(e); + }), latch)); + } } else { exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index b761cddd90..789781308b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -53,7 +53,7 @@ public class MLEngineTest { @Before public void setUp() { - Encryptor encryptor = new EncryptorImpl("0000000000000000"); + Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java index 35d782bc67..72c0d687b0 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java @@ -157,7 +157,7 @@ public void setUp() throws IOException, URISyntaxException { System.setProperty("testMode", "true"); mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID()); - encryptor = new EncryptorImpl("0000000000000001"); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(mlCachePath, encryptor); modelConfig = MetricsCorrelationModelConfig.builder() .modelType(MetricsCorrelation.MODEL_TYPE) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 8d6130566a..ecc143ea6f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -67,7 +67,7 @@ public class AwsConnectorExecutorTest { @Before public void setUp() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("0000000000000001"); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index bef3e1da71..6016748a1e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -48,7 +48,7 @@ public class RemoteModelTest { public void setUp() { MockitoAnnotations.openMocks(this); remoteModel = new RemoteModel(); - encryptor = spy(new EncryptorImpl("0000000000000001")); + encryptor = spy(new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=")); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java index 245afdfe7f..a03fbfd46b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java @@ -59,7 +59,7 @@ public void setup() throws URISyntaxException { MockitoAnnotations.openMocks(this); modelFormat = MLModelFormat.TORCH_SCRIPT; modelId = "model_id"; - encryptor = new EncryptorImpl("0000000000000001"); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), encryptor); modelHelper = new ModelHelper(mlEngine); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java index 70d7fd2e35..b49938b50e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java @@ -73,7 +73,7 @@ public class TextEmbeddingModelTest { @Before public void setUp() throws URISyntaxException { mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); - encryptor = new EncryptorImpl("0000000000000001"); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(mlCachePath, encryptor); modelId = "test_model_id"; modelName = "test_model_name"; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java index 2e0980bc0f..aa5d7bd0b0 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java @@ -18,6 +18,9 @@ import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.threadpool.ThreadPool; import java.time.Instant; @@ -43,10 +46,15 @@ public class EncryptorImplTest { String masterKey; + @Mock + ThreadPool threadPool; + ThreadContext threadContext; + final String USER_STRING = "myuser|role1,role2|myTenant"; + @Before public void setUp() { MockitoAnnotations.openMocks(this); - masterKey = "0000000000000001"; + masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="; doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -72,6 +80,12 @@ public void setUp() { .build()) .build()).build(); when(clusterState.metadata()).thenReturn(metadata); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); } @Test @@ -83,6 +97,17 @@ public void encrypt() { Assert.assertEquals(masterKey, encryptor.getMasterKey()); } + @Test + public void encrypt_DifferentMasterKey() { + Encryptor encryptor = new EncryptorImpl(masterKey); + Assert.assertNotNull(encryptor.getMasterKey()); + String encrypted1 = encryptor.encrypt("test"); + + encryptor.setMasterKey(encryptor.generateMasterKey()); + String encrypted2 = encryptor.encrypt("test"); + Assert.assertNotEquals(encrypted1, encrypted2); + } + @Test public void decrypt() { Encryptor encryptor = new EncryptorImpl(clusterService, client); diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 95c5ec037c..bd0d813a0a 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -31,6 +31,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.ml.common.MLModel; @@ -168,24 +169,26 @@ void initMLConfig() { } mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> { GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); - client.get(getRequest, ActionListener.wrap(getResponse -> { - if (!getResponse.isExists()) { - IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY); - final String masterKey = encryptor.generateMasterKey(); - indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(indexRequest, ActionListener.wrap(indexResponse -> { - log.info("ML configuration initialized successfully"); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (!getResponse.isExists()) { + IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + final String masterKey = encryptor.generateMasterKey(); + indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + log.info("ML configuration initialized successfully"); + encryptor.setMasterKey(masterKey); + mlConfigInited = true; + }, e -> { log.debug("Failed to save ML encryption master key", e); })); + } else { + final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY); encryptor.setMasterKey(masterKey); mlConfigInited = true; - }, e -> { log.debug("Failed to save ML encryption master key", e); })); - } else { - final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY); - encryptor.setMasterKey(masterKey); - mlConfigInited = true; - log.info("ML configuration already initialized, no action needed"); - } - }, e -> { log.debug("Failed to get ML encryption master key", e); })); + log.info("ML configuration already initialized, no action needed"); + } + }, e -> { log.debug("Failed to get ML encryption master key", e); })); + } }, e -> { log.debug("Failed to init ML config index", e); })); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index 1cb0670e14..9ec83e3a6a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -142,7 +142,7 @@ public void setup() { clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN))); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - encryptor = new EncryptorImpl("0000000000000001"); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); modelHelper = new ModelHelper(mlEngine); when(mlDeployModelRequest.getModelId()).thenReturn("mockModelId"); diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index 5a6913f576..c860412ad7 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -57,7 +57,10 @@ import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; @@ -75,12 +78,12 @@ import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.suggest.Suggest; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; public class MLSyncUpCronTests extends OpenSearchTestCase { - @Mock private Client client; @Mock @@ -100,6 +103,11 @@ public class MLSyncUpCronTests extends OpenSearchTestCase { private ClusterState testState; private Encryptor encryptor; + @Mock + ThreadPool threadPool; + ThreadContext threadContext; + final String USER_STRING = "myuser|role1,role2|myTenant"; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); @@ -116,6 +124,12 @@ public void setup() throws IOException { actionListener.onResponse(true); return null; }).when(mlIndicesHandler).initMLConfigIndex(any()); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); } public void testInitMlConfig_MasterKeyNotExist() { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 24fd02bf71..9f5cb8c441 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -170,7 +170,7 @@ public class MLModelManagerTests extends OpenSearchTestCase { @Before public void setup() throws URISyntaxException { - String masterKey = "0000000000000001"; + String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="; MockitoAnnotations.openMocks(this); encryptor = new EncryptorImpl(masterKey); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 56f6d09fd1..37679e7820 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -93,7 +93,7 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("0000000000000000"); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), encryptor); when(threadPool.executor(anyString())).thenReturn(executorService); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 530033197c..f91aee8d9b 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -127,7 +127,7 @@ public class MLPredictTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("0000000000000001"); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index 0714bf0234..3241e522bc 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -101,7 +101,7 @@ public class MLTrainAndPredictTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { - encryptor = new EncryptorImpl("0000000000000001"); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); settings = Settings.builder().build(); MockitoAnnotations.openMocks(this); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index 6565a41b95..9144a118bf 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -111,7 +111,7 @@ public class MLTrainingTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("0000000000000001"); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), encryptor); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT);