Skip to content

Commit

Permalink
Address some comments
Browse files Browse the repository at this point in the history
Signed-off-by: Liyun Xiu <[email protected]>
  • Loading branch information
chishui committed Aug 6, 2024
1 parent 3ed1de9 commit a90203f
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ protected String registerModelGroupAndGetModelId(final String requestBody) throw

protected void createPipelineProcessor(final String modelId, final String pipelineName) throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/PipelineConfiguration.json").toURI()));
createPipelineProcessor(requestBody, pipelineName, modelId);
createPipelineProcessor(requestBody, pipelineName, modelId, null);
}

protected String uploadSparseEncodingModel() throws Exception {
Expand All @@ -90,22 +90,15 @@ protected void createPipelineForTextImageProcessor(final String modelId, final S
String requestBody = Files.readString(
Path.of(classLoader.getResource("processor/PipelineForTextImageProcessorConfiguration.json").toURI())
);
createPipelineProcessor(requestBody, pipelineName, modelId);
createPipelineProcessor(requestBody, pipelineName, modelId, null);
}

protected void createPipelineForSparseEncodingProcessor(String modelId, String pipelineName, Integer batchSize) throws Exception {
protected void createPipelineForSparseEncodingProcessor(final String modelId, final String pipelineName, final Integer batchSize)
throws Exception {
String requestBody = Files.readString(
Path.of(classLoader.getResource("processor/PipelineForSparseEncodingProcessorConfiguration.json").toURI())
);
final String batchSizeTag = "{{batch_size}}";
if (requestBody.contains(batchSizeTag)) {
if (batchSize != null) {
requestBody = requestBody.replace(batchSizeTag, String.format(LOCALE, "\n\"batch_size\": %d,\n", batchSize));
} else {
requestBody = requestBody.replace(batchSizeTag, "");
}
}
createPipelineProcessor(requestBody, pipelineName, modelId);
createPipelineProcessor(requestBody, pipelineName, modelId, batchSize);
}

protected void createPipelineForSparseEncodingProcessor(final String modelId, final String pipelineName) throws Exception {
Expand All @@ -116,6 +109,6 @@ protected void createPipelineForTextChunkingProcessor(String pipelineName) throw
String requestBody = Files.readString(
Path.of(classLoader.getResource("processor/PipelineForTextChunkingProcessorConfiguration.json").toURI())
);
createPipelineProcessor(requestBody, pipelineName, "");
createPipelineProcessor(requestBody, pipelineName, "", null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ protected String registerModelGroupAndGetModelId(String requestBody) throws Exce

protected void createPipelineProcessor(String modelId, String pipelineName) throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/PipelineConfiguration.json").toURI()));
createPipelineProcessor(requestBody, pipelineName, modelId);
createPipelineProcessor(requestBody, pipelineName, modelId, null);
}

protected String uploadTextImageEmbeddingModel() throws Exception {
Expand All @@ -114,7 +114,7 @@ protected void createPipelineForTextImageProcessor(String modelId, String pipeli
String requestBody = Files.readString(
Path.of(classLoader.getResource("processor/PipelineForTextImageProcessorConfiguration.json").toURI())
);
createPipelineProcessor(requestBody, pipelineName, modelId);
createPipelineProcessor(requestBody, pipelineName, modelId, null);
}

protected String uploadSparseEncodingModel() throws Exception {
Expand All @@ -136,7 +136,7 @@ protected void createPipelineForSparseEncodingProcessor(String modelId, String p
requestBody = requestBody.replace(batchSizeTag, "");
}
}
createPipelineProcessor(requestBody, pipelineName, modelId);
createPipelineProcessor(requestBody, pipelineName, modelId, null);
}

protected void createPipelineForSparseEncodingProcessor(String modelId, String pipelineName) throws Exception {
Expand All @@ -147,6 +147,6 @@ protected void createPipelineForTextChunkingProcessor(String pipelineName) throw
String requestBody = Files.readString(
Path.of(classLoader.getResource("processor/PipelineForTextChunkingProcessorConfiguration.json").toURI())
);
createPipelineProcessor(requestBody, pipelineName, "");
createPipelineProcessor(requestBody, pipelineName, "", null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"processors": [
{
"sparse_encoding": {
"model_id": "%s",{{batch_size}}
"model_id": "%s",
"batch_size": "%d",
"field_map": {
"passage_text": "passage_embedding"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ private void createPipelineProcessor(String pipelineName) throws Exception {
URL pipelineURLPath = classLoader.getResource(PIPELINE_CONFIGS_BY_NAME.get(pipelineName));
Objects.requireNonNull(pipelineURLPath);
String requestBody = Files.readString(Path.of(pipelineURLPath.toURI()));
createPipelineProcessor(requestBody, pipelineName, "");
createPipelineProcessor(requestBody, pipelineName, "", null);
}

private void createTextChunkingIndex(String indexName, String pipelineName) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ public void testTextEmbeddingProcessor_withBatchSizeInProcessor() throws Excepti
URL pipelineURLPath = classLoader.getResource("processor/PipelineConfigurationWithBatchSize.json");
Objects.requireNonNull(pipelineURLPath);
String requestBody = Files.readString(Path.of(pipelineURLPath.toURI()));
createPipelineProcessor(requestBody, PIPELINE_NAME, modelId);
createPipelineProcessor(requestBody, PIPELINE_NAME, modelId, null);
createTextEmbeddingIndex();
int docCount = 5;
ingestBatchDocumentWithBulk("batch_", docCount, Collections.emptySet(), Collections.emptySet());
Expand Down Expand Up @@ -214,7 +214,7 @@ public void testTextEmbeddingProcessor_withFailureAndSkip() throws Exception {
URL pipelineURLPath = classLoader.getResource("processor/PipelineConfigurationWithBatchSize.json");
Objects.requireNonNull(pipelineURLPath);
String requestBody = Files.readString(Path.of(pipelineURLPath.toURI()));
createPipelineProcessor(requestBody, PIPELINE_NAME, modelId);
createPipelineProcessor(requestBody, PIPELINE_NAME, modelId, null);
createTextEmbeddingIndex();
int docCount = 5;
ingestBatchDocumentWithBulk("batch_", docCount, Set.of(0), Set.of(1));
Expand Down
3 changes: 2 additions & 1 deletion src/test/resources/processor/PipelineConfiguration.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"processors": [
{
"text_embedding": {
"model_id": "%s",{{batch_size}}
"model_id": "%s",
"batch_size": "%d",
"field_map": {
"title": "title_knn",
"favor_list": "favor_list_knn",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"processors" : [
{
"sparse_encoding": {
"model_id": "%s",{{batch_size}}
"model_id": "%s",
"batch_size": "%d",
"field_map": {
"title": "title_sparse",
"favor_list": "favor_list_sparse",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,24 +304,21 @@ protected void createPipelineProcessor(
final Integer batchSize
) throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource(PIPELINE_CONFIGS_BY_TYPE.get(processorType)).toURI()));
final String batchSizeTag = "{{batch_size}}";
if (requestBody.contains(batchSizeTag)) {
if (batchSize != null) {
requestBody = requestBody.replace(batchSizeTag, String.format(LOCALE, "\n\"batch_size\": %d,\n", batchSize));
} else {
requestBody = requestBody.replace(batchSizeTag, "");
}
}
createPipelineProcessor(requestBody, pipelineName, modelId);
createPipelineProcessor(requestBody, pipelineName, modelId, batchSize);
}

protected void createPipelineProcessor(final String requestBody, final String pipelineName, final String modelId) throws Exception {
protected void createPipelineProcessor(
final String requestBody,
final String pipelineName,
final String modelId,
final Integer batchSize
) throws Exception {
Response pipelineCreateResponse = makeRequest(
client(),
"PUT",
"/_ingest/pipeline/" + pipelineName,
null,
toHttpEntity(String.format(LOCALE, requestBody, modelId)),
toHttpEntity(String.format(LOCALE, requestBody, modelId, batchSize == null ? 1 : batchSize)),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> node = XContentHelper.convertToMap(
Expand Down

0 comments on commit a90203f

Please sign in to comment.