Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Nov 5, 2023
1 parent fa8339b commit 69cd31f
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 79 deletions.
2 changes: 1 addition & 1 deletion client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ plugins {

dependencies {
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
compileOnly project(':opensearch-ml-common')
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'
Expand Down
3 changes: 1 addition & 2 deletions memory/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ plugins {
}

dependencies {
// implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation group: 'org.opensearch', name:'opensearch-ml-common', version: "${opensearch_build}"
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.1'
implementation "org.opensearch:common-utils:${common_utils_version}"
Expand Down
2 changes: 1 addition & 1 deletion ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ repositories {

dependencies {
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
// implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation group: 'org.opensearch', name:'opensearch-ml-common', version: "${opensearch_build}"
implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
testImplementation "org.opensearch.test:framework:${opensearch_version}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
Expand All @@ -18,7 +17,6 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.search.SearchHit;
Expand All @@ -29,10 +27,7 @@
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.ml.common.utils.StringUtils.gson;

/**
Expand All @@ -52,80 +47,23 @@ public class VectorDBTool implements Tool {
private NamedXContentRegistry xContentRegistry;
private String index;
private String embeddingField;
private String sourceField;
private String[] sourceFields;
private String modelId;
private Integer docSize ;
private Integer k;

@Builder
public VectorDBTool(Client client, NamedXContentRegistry xContentRegistry, String index, String embeddingField, String sourceField, Integer k, Integer docSize, String modelId) {
public VectorDBTool(Client client, NamedXContentRegistry xContentRegistry, String index, String embeddingField, String[] sourceFields, Integer k, Integer docSize, String modelId) {
this.client = client;
this.xContentRegistry = xContentRegistry;
this.index = index;
this.embeddingField = embeddingField;
this.sourceField = sourceField;
this.sourceFields = sourceFields;
this.modelId = modelId;
this.docSize = docSize == null? 2 : docSize;
this.k = k == null? 10 : k;
}

@Override
public <T> T run(Map<String, String> parameters) {
try {
String question = parameters.get("input");
try {
question = gson.fromJson(question, String.class);
} catch (Exception e) {
throw new IllegalArgumentException("wrong input");
}
String query = "{\"query\":{\"neural\":{\""+ embeddingField +"\":{\"query_text\":\"" + question + "\",\"model_id\":\""
+ modelId + "\",\"k\":" + k + "}}},\"size\":\"" + docSize
+ "\",\"_source\":[\"" + sourceField + "\"]}";
AtomicReference<String> contextRef = new AtomicReference<>("");
AtomicReference<Exception> exceptionRef = new AtomicReference<>(null);

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query);
searchSourceBuilder.parseXContent(queryParser);
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);
CountDownLatch latch = new CountDownLatch(1);
LatchedActionListener listener = new LatchedActionListener<SearchResponse>(ActionListener.wrap(r -> {
SearchHit[] hits = r.getHits().getHits();

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
Map<String, Object> docContent = new HashMap<>();
docContent.put("_id", hit.getId());
docContent.put("_source", hit.getSourceAsMap());
return gson.toJson(docContent);
});
contextBuilder.append(doc).append("\n");
}
contextRef.set(gson.toJson(contextBuilder.toString()));
}
}, e -> {
log.error("Failed to search index", e);
exceptionRef.set(e);
}), latch);
client.search(searchRequest, listener);

try {
latch.await(50, SECONDS);
} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
if (exceptionRef.get() != null) {
throw new MLException(exceptionRef.get());
}
return (T)contextRef.get();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
try {
Expand All @@ -135,13 +73,14 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
} catch (Exception e) {
//throw new IllegalArgumentException("wrong input");
}
String query = "{\"query\":{\"neural\":{\""+ embeddingField +"\":{\"query_text\":\"" + question + "\",\"model_id\":\""
+ modelId + "\",\"k\":" + k + "}}},\"size\":\"" + docSize
+ "\",\"_source\":[\"" + sourceField + "\"]}";
String query = "{\"query\":{\"neural\":{\"" + embeddingField + "\":{\"query_text\":\"" + question + "\",\"model_id\":\""
+ modelId + "\",\"k\":" + k + "}}}" + " }";

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query);
searchSourceBuilder.parseXContent(queryParser);
searchSourceBuilder.fetchSource(sourceFields, null);
searchSourceBuilder.size(docSize);
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);
ActionListener actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();
Expand Down Expand Up @@ -217,15 +156,15 @@ public void init(Client client, NamedXContentRegistry xContentRegistry) {
public VectorDBTool create(Map<String, Object> params) {
String index = (String)params.get("index");
String embeddingField = (String)params.get("embedding_field");
String sourceField = (String)params.get("source_field");
String[] sourceFields = gson.fromJson((String)params.get("source_field"), String[].class);
String modelId = (String)params.get("model_id");
Integer docSize = params.containsKey("doc_size")? Integer.parseInt((String)params.get("doc_size")) : 2;
return VectorDBTool.builder()
.client(client)
.xContentRegistry(xContentRegistry)
.index(index)
.embeddingField(embeddingField)
.sourceField(sourceField)
.sourceFields(sourceFields)
.modelId(modelId)
.docSize(docSize)
.build();
Expand All @@ -236,5 +175,4 @@ public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}

}
}
2 changes: 1 addition & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ opensearchplugin {

dependencies {
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
// implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation group: 'org.opensearch', name:'opensearch-ml-common', version: "${opensearch_build}"
implementation project(':opensearch-ml-algorithms')
implementation project(':opensearch-ml-search-processors')
Expand Down
3 changes: 1 addition & 2 deletions search-processors/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ repositories {
}

dependencies {
// implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation group: 'org.opensearch', name:'opensearch-ml-common', version: "${opensearch_build}"
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
implementation 'org.apache.commons:commons-lang3:3.12.0'
Expand Down

0 comments on commit 69cd31f

Please sign in to comment.