Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/conversation memory feature flag #1271

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.opensearch.ml.common.conversation;

import org.opensearch.common.settings.Setting;

/**
* Class containing a bunch of constant defining how the conversational indices are formatted
*/
Expand Down Expand Up @@ -97,4 +99,7 @@ public class ConversationalIndexConstants {
+ " }\n"
+ "}";

/** Feature Flag setting for conversational memory */
public static final Setting<Boolean> ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting
.boolSetting("plugins.ml_commons.memory_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
*/
package org.opensearch.ml.memory.action.conversation;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand All @@ -39,27 +42,45 @@ public class CreateConversationTransportAction extends HandledTransportAction<Cr
private ConversationalMemoryHandler cmHandler;
private Client client;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public CreateConversationTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
Client client,
ClusterService clusterService
) {
super(CreateConversationAction.NAME, transportService, actionFilters, CreateConversationRequest::new);
this.cmHandler = cmHandler;
this.client = client;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
protected void doExecute(Task task, CreateConversationRequest request, ActionListener<CreateConversationResponse> actionListener) {
if (!featureIsEnabled) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
}
String name = request.getName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateConversationResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
*/
package org.opensearch.ml.memory.action.conversation;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand All @@ -39,27 +42,45 @@ public class CreateInteractionTransportAction extends HandledTransportAction<Cre
private ConversationalMemoryHandler cmHandler;
private Client client;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for doing intra-cluster communication
* @param actionFilters for filtering actions
* @param cmHandler handler for conversational memory
* @param client client for general opensearch ops
* @param clusterService for some cluster ops
*/
@Inject
public CreateInteractionTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
Client client,
ClusterService clusterService
) {
super(CreateInteractionAction.NAME, transportService, actionFilters, CreateInteractionRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
protected void doExecute(Task task, CreateInteractionRequest request, ActionListener<CreateInteractionResponse> actionListener) {
if (!featureIsEnabled) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
}
String cid = request.getConversationId();
String inp = request.getInput();
String rsp = request.getResponse();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
*/
package org.opensearch.ml.memory.action.conversation;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand All @@ -39,27 +42,45 @@ public class DeleteConversationTransportAction extends HandledTransportAction<De
private Client client;
private ConversationalMemoryHandler cmHandler;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public DeleteConversationTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
Client client,
ClusterService clusterService
) {
super(DeleteConversationAction.NAME, transportService, actionFilters, DeleteConversationRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
public void doExecute(Task task, DeleteConversationRequest request, ActionListener<DeleteConversationResponse> listener) {
if (!featureIsEnabled) {
listener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
}
String conversationId = request.getId();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<DeleteConversationResponse> internalListener = ActionListener.runBefore(listener, () -> context.restore());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@

import java.util.List;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationMeta;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand All @@ -42,27 +45,45 @@ public class GetConversationsTransportAction extends HandledTransportAction<GetC
private Client client;
private ConversationalMemoryHandler cmHandler;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public GetConversationsTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
Client client,
ClusterService clusterService
) {
super(GetConversationsAction.NAME, transportService, actionFilters, GetConversationsRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
public void doExecute(Task task, GetConversationsRequest request, ActionListener<GetConversationsResponse> actionListener) {
if (!featureIsEnabled) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
}
int maxResults = request.getMaxResults();
int from = request.getFrom();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@

import java.util.List;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
Expand All @@ -42,27 +45,45 @@ public class GetInteractionsTransportAction extends HandledTransportAction<GetIn
private Client client;
private ConversationalMemoryHandler cmHandler;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public GetInteractionsTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
Client client,
ClusterService clusterService
) {
super(GetInteractionsAction.NAME, transportService, actionFilters, GetInteractionsRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
public void doExecute(Task task, GetInteractionsRequest request, ActionListener<GetInteractionsResponse> actionListener) {
if (!featureIsEnabled) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
}
int maxResults = request.getMaxResults();
int from = request.getFrom();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Set;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand All @@ -33,10 +34,12 @@
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -89,12 +92,16 @@ public void setup() throws IOException {
this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class);

this.request = new CreateConversationRequest("test");
this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client));

Settings settings = Settings.builder().build();
Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build();
this.threadContext = new ThreadContext(settings);
when(this.client.threadPool()).thenReturn(this.threadPool);
when(this.threadPool.getThreadContext()).thenReturn(this.threadContext);
when(this.clusterService.getSettings()).thenReturn(settings);
when(this.clusterService.getClusterSettings())
.thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED)));

this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService));
}

public void testCreateConversation() {
Expand Down Expand Up @@ -144,4 +151,15 @@ public void testDoExecuteFails_thenFail() {
assert (argCaptor.getValue().getMessage().equals("Test doExecute Error"));
}

public void testFeatureDisabled_ThenFail() {
when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY);
when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED)));
this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

action.doExecute(null, request, actionListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled."));
}

}
Loading