Skip to content

Commit

Permalink
added missing result filter to inference (#2367)
Browse files (browse the repository at this point in the history)
Signed-off-by: br3no <[email protected]>
  • Loading branch information
br3no authored Apr 28, 2024
1 parent 85d0c9e commit caf1d65
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ private TextDocsInputDataSet addPrefixesToData(AsymmetricTextEmbeddingParameters
: modelConfig.getQueryPrefix();
if (prefix != null) {
List<String> prefixedDocs = inputDataSet.getDocs().stream().map(s -> prefix + s).collect(Collectors.toList());
return TextDocsInputDataSet.builder().docs(prefixedDocs).build();
return TextDocsInputDataSet.builder().docs(prefixedDocs).resultFilter(inputDataSet.getResultFilter()).build();
}
return inputDataSet;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,27 @@ public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_As
.builder()
.algorithm(FunctionName.TEXT_EMBEDDING)
.inputDataset(
TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build()
TextDocsInputDataSet
.builder()
.docs(Arrays.asList("what is the meaning of life?", "who won this year's us open"))
.resultFilter(
ModelResultFilter.builder().targetResponse(List.of(SENTENCE_EMBEDDING)).returnBytes(true).returnNumber(true).build()
)
.build()
)
.parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY))
.build();
MLInput asymmetricMlInputPassages = MLInput
.builder()
.algorithm(FunctionName.TEXT_EMBEDDING)
.inputDataset(
TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build()
TextDocsInputDataSet
.builder()
.docs(Arrays.asList("The meaning of life is 42", "I won this year's us open"))
.resultFilter(
ModelResultFilter.builder().targetResponse(List.of(SENTENCE_EMBEDDING)).returnBytes(true).returnNumber(true).build()
)
.build()
)
.parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.PASSAGE))
.build();
Expand All @@ -285,20 +297,38 @@ public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_As
.builder()
.algorithm(FunctionName.TEXT_EMBEDDING)
.inputDataset(
TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build()
TextDocsInputDataSet
.builder()
.docs(Arrays.asList("what is the meaning of life?", "who won this year's us open"))
.resultFilter(
ModelResultFilter.builder().targetResponse(List.of(SENTENCE_EMBEDDING)).returnBytes(true).returnNumber(true).build()
)
.build()
)
.build();
MLInput symmetricMlInputPassages = MLInput
.builder()
.algorithm(FunctionName.TEXT_EMBEDDING)
.inputDataset(
TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build()
TextDocsInputDataSet
.builder()
.docs(Arrays.asList("The meaning of life is 42", "I won this year's us open"))
.resultFilter(
ModelResultFilter.builder().targetResponse(List.of(SENTENCE_EMBEDDING)).returnBytes(true).returnNumber(true).build()
)
.build()
)
.build();

ModelTensorOutput symmetricQueryEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputQueries);
ModelTensorOutput symmetricPassageEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputPassages);

assertTrue(
"asymmetric and symmetric embeddings should have the same number of tensors",
asymmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().size() == 1
&& symmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().size() == 1
);

assertTrue(
"asymmetric and symmetric query embeddings should be different",
areTensorsDifferent(
Expand Down

0 comments on commit caf1d65

Please sign in to comment.