Commit 29e27ed

Add tokenizer and sparse encoding (opensearch-project#1301) (opensearch-project#1393)

* add tokenizer and sparse encoding

Signed-off-by: xinyual <[email protected]>

* add tokenizer and sparse encoding

Signed-off-by: xinyual <[email protected]>

* add tokenizer and sparse encoding

Signed-off-by: xinyual <[email protected]>

* add tokenizer and sparse encoding

Signed-off-by: xinyual <[email protected]>

* add tokenizer and sparse encoding

Signed-off-by: xinyual <[email protected]>

* remove special token

Signed-off-by: xinyual <[email protected]>

* add filter

Signed-off-by: xinyual <[email protected]>

* try empty model

Signed-off-by: xinyual <[email protected]>

* remove warm up

Signed-off-by: xinyual <[email protected]>

* try empty model

Signed-off-by: xinyual <[email protected]>

* add block

Signed-off-by: xinyual <[email protected]>

* add log

Signed-off-by: xinyual <[email protected]>

* add log

Signed-off-by: xinyual <[email protected]>

* add log

Signed-off-by: xinyual <[email protected]>

* remove log

Signed-off-by: xinyual <[email protected]>

* remove pt file detect

Signed-off-by: xinyual <[email protected]>

* add log

Signed-off-by: xinyual <[email protected]>

* add functionName pipeline

Signed-off-by: xinyual <[email protected]>

* remove verify log

Signed-off-by: xinyual <[email protected]>

* skip special token in sparse encoding

Signed-off-by: xinyual <[email protected]>

* skip omit tokenize config

Signed-off-by: xinyual <[email protected]>

* skip omit tokenize config-change warm up logic

Signed-off-by: xinyual <[email protected]>

* reArch

Signed-off-by: xinyual <[email protected]>

* deduplicate

Signed-off-by: xinyual <[email protected]>

* omit ml config in sparse encoding

Signed-off-by: xinyual <[email protected]>

* add null config in warm up

Signed-off-by: xinyual <[email protected]>

* fix original test

Signed-off-by: xinyual <[email protected]>

* add tokenize ut half

Signed-off-by: xinyual <[email protected]>

* fix sparse encoding bug

Signed-off-by: xinyual <[email protected]>

* add UT for sparse encoding and tokenize

Signed-off-by: xinyual <[email protected]>

* remove useless framework type

Signed-off-by: xinyual <[email protected]>

* common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java

Signed-off-by: xinyual <[email protected]>

* change key for tokenize

Signed-off-by: xinyual <[email protected]>

* reArch DLModel

Signed-off-by: xinyual <[email protected]>

* reArch DLModel again

Signed-off-by: xinyual <[email protected]>

* response format

Signed-off-by: xinyual <[email protected]>

* tokenize only one output

Signed-off-by: xinyual <[email protected]>

* clean sparse output

Signed-off-by: xinyual <[email protected]>

* clean sparse output

Signed-off-by: xinyual <[email protected]>

* change UT number

Signed-off-by: xinyual <[email protected]>

* remove useless predict code

Signed-off-by: xinyual <[email protected]>

* remove useless part

Signed-off-by: xinyual <[email protected]>

* change tokenize way

Signed-off-by: xinyual <[email protected]>

* reArch add textEmbedding model

Signed-off-by: xinyual <[email protected]>

* add tokenize logic

Signed-off-by: xinyual <[email protected]>

* add abstract

Signed-off-by: xinyual <[email protected]>

* clear code

Signed-off-by: xinyual <[email protected]>

* fix IT class

Signed-off-by: xinyual <[email protected]>

* fix IT class

Signed-off-by: xinyual <[email protected]>

* add IT file

Signed-off-by: xinyual <[email protected]>

* reformulate

Signed-off-by: xinyual <[email protected]>

* reformulate remote inference

Signed-off-by: xinyual <[email protected]>

* reformulate remote inference

Signed-off-by: xinyual <[email protected]>

* reformulate remote inference json and array

Signed-off-by: xinyual <[email protected]>

* verify

Signed-off-by: xinyual <[email protected]>

* undo string utils

Signed-off-by: xinyual <[email protected]>

* skip dummy model

Signed-off-by: xinyual <[email protected]>

* skip dummy model

Signed-off-by: xinyual <[email protected]>

* skip dummy model

Signed-off-by: xinyual <[email protected]>

* skip dummy model

Signed-off-by: xinyual <[email protected]>

* skip dummy model

Signed-off-by: xinyual <[email protected]>

* skip dummy model

Signed-off-by: xinyual <[email protected]>

* add inner load Model

Signed-off-by: xinyual <[email protected]>

* rename variable

Signed-off-by: xinyual <[email protected]>

* add default for idf

Signed-off-by: xinyual <[email protected]>

* add ut for sparse encoding and tokenizer

Signed-off-by: xinyual <[email protected]>

* add close model

Signed-off-by: xinyual <[email protected]>

* change mock class

Signed-off-by: xinyual <[email protected]>

* remove buffer for sparse encoding output

Signed-off-by: xinyual <[email protected]>

* change tokenize model ready logic

Signed-off-by: xinyual <[email protected]>

* rewrite input functionName

Signed-off-by: xinyual <[email protected]>

* deduplicate

Signed-off-by: xinyual <[email protected]>

* change UT usage

Signed-off-by: xinyual <[email protected]>

* fix downloadAndSplit test

Signed-off-by: xinyual <[email protected]>

* fix Helper test

Signed-off-by: xinyual <[email protected]>

* remove meaningless change

Signed-off-by: xinyual <[email protected]>

* remove compile change

Signed-off-by: xinyual <[email protected]>

* rename

Signed-off-by: xinyual <[email protected]>

* fix typo and simplify wrap code

Signed-off-by: xinyual <[email protected]>

* add comment

Signed-off-by: xinyual <[email protected]>

* using gson and remove useless close logic

Signed-off-by: xinyual <[email protected]>

* update comment and fix import problem

Signed-off-by: xinyual <[email protected]>

* add static idf name

Signed-off-by: xinyual <[email protected]>

* fix format problem

Signed-off-by: xinyual <[email protected]>

* extract an abstract model for sparse and dense sentence transformer translator

Signed-off-by: xinyual <[email protected]>

* fix typo

Signed-off-by: xinyual <[email protected]>

* remove duplicate tokenizer file, fix import problem and add comment for tokenizer model

Signed-off-by: xinyual <[email protected]>

---------

Signed-off-by: xinyual <[email protected]>
(cherry picked from commit 31a4e25)

Co-authored-by: xinyual <[email protected]>
(cherry picked from commit 44946da)
opensearch-trigger-bot[bot] authored and xinyual committed Sep 27, 2023
1 parent 547ef21 commit 29e27ed
Showing 34 changed files with 1,121 additions and 245 deletions.
@@ -42,6 +42,7 @@ public class CommonValue {
     public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
     public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
     public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2;
+    public static final String ML_MAP_RESPONSE_KEY = "response";
     public static final String USER_FIELD_MAPPING = " \""
         + CommonValue.USER
         + "\": {\n"
@@ -17,6 +17,8 @@ public enum FunctionName {
     RCF_SUMMARIZE,
     LOGISTIC_REGRESSION,
     TEXT_EMBEDDING,
+    SPARSE_ENCODING,
+    SPARSE_TOKENIZE,
     METRICS_CORRELATION,
     REMOTE;

@@ -33,7 +35,7 @@ public static FunctionName from(String value) {
      * @return true for deep learning model.
      */
     public static boolean isDLModel(FunctionName functionName) {
-        if (functionName == TEXT_EMBEDDING) {
+        if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == SPARSE_TOKENIZE) {
             return true;
         }
         return false;
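
For illustration only — a minimal sketch (not part of the commit) of what the widened check means for callers. The enum values and the static helper are exactly as in the hunk above; the demo class itself is hypothetical.

import org.opensearch.ml.common.FunctionName;

public class IsDLModelDemo {
    public static void main(String[] args) {
        // Both new sparse function names now take the deep-learning model path.
        System.out.println(FunctionName.isDLModel(FunctionName.TEXT_EMBEDDING));  // true
        System.out.println(FunctionName.isDLModel(FunctionName.SPARSE_ENCODING)); // true
        System.out.println(FunctionName.isDLModel(FunctionName.SPARSE_TOKENIZE)); // true
        System.out.println(FunctionName.isDLModel(FunctionName.REMOTE));          // false
    }
}
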
@@ -23,6 +23,7 @@
 import java.util.Map;
 import java.util.Optional;

+import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;
 import static org.opensearch.ml.common.utils.StringUtils.isJson;

 @Getter
@@ -101,7 +102,7 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
             return;
         }
         if (response instanceof String && isJson((String)response)) {
-            Map<String, Object> data = StringUtils.fromJson((String) response, "response");
+            Map<String, Object> data = StringUtils.fromJson((String) response, ML_MAP_RESPONSE_KEY);
             modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
         } else {
             Map<String, Object> map = new HashMap<>();
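
A rough, self-contained sketch of that branch in isolation. The StringUtils helpers and the ModelTensor builder calls are taken from the hunk above; the sample payload and the ModelTensor import path are assumptions, not part of the commit.

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.utils.StringUtils;

import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

public class ParseResponseDemo {
    public static void main(String[] args) {
        String response = "{\"embedding\":[0.12,0.08]}"; // stand-in remote payload
        List<ModelTensor> modelTensors = new ArrayList<>();
        if (isJson(response)) {
            // Same call as the diff, now keyed by the shared constant instead of a literal.
            Map<String, Object> data = StringUtils.fromJson(response, ML_MAP_RESPONSE_KEY);
            modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
        }
        System.out.println(modelTensors.size()); // 1
    }
}
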
@@ -239,7 +239,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
             }
         }
         MLInputDataset inputDataSet = null;
-        if (algorithm == FunctionName.TEXT_EMBEDDING) {
+        if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.SPARSE_TOKENIZE) {
             ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
             inputDataSet = new TextDocsInputDataSet(textDocs, filter);
         }
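
For context, a hedged sketch of building the same text-docs input programmatically for one of the new function names. The builder calls mirror the tests later in this diff; the import paths and the target field name are assumptions.

import java.util.Arrays;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelResultFilter;

public class SparseInputDemo {
    public static void main(String[] args) {
        ModelResultFilter filter = ModelResultFilter.builder()
                .targetResponse(Arrays.asList("output"))   // hypothetical target field
                .targetResponsePositions(Arrays.asList(0))
                .build();
        TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder()
                .docs(Arrays.asList("hello world"))
                .resultFilter(filter)
                .build();
        // Before this commit, only TEXT_EMBEDDING accepted a text-docs dataset here.
        MLInput input = MLInput.builder()
                .algorithm(FunctionName.SPARSE_ENCODING)
                .inputDataset(dataSet)
                .build();
        System.out.println(input.getFunctionName()); // SPARSE_ENCODING
    }
}
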
@@ -25,7 +25,7 @@
  * ML input class which supports a list fo text docs.
  * This class can be used for TEXT_EMBEDDING model.
  */
-@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING})
+@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE})
 public class TextDocsMLInput extends MLInput {
     public static final String TEXT_DOCS_FIELD = "text_docs";
     public static final String RESULT_FILTER_FIELD = "result_filter";
@@ -103,7 +103,7 @@ public MLRegisterModelInput(FunctionName functionName,
         if (modelFormat == null) {
             throw new IllegalArgumentException("model format is null");
         }
-        if (url != null && modelConfig == null) {
+        if (url != null && modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration.
             throw new IllegalArgumentException("model config is null");
         }
     }
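
As a hedged illustration of the relaxed validation — the builder fields below are assumed from the class's Lombok-style builder, and the model name and URL are placeholders, not values from the commit:

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;

public class RegisterSparseTokenizerDemo {
    public static void main(String[] args) {
        // With the check above, a pretrained sparse tokenizer can be registered
        // from a URL without a model config; other function names still throw.
        MLRegisterModelInput input = MLRegisterModelInput.builder()
                .functionName(FunctionName.SPARSE_TOKENIZE)
                .modelName("demo-sparse-tokenizer")       // placeholder
                .version("1.0.0")                         // placeholder
                .modelFormat(MLModelFormat.TORCH_SCRIPT)
                .url("https://example.com/tokenizer.zip") // placeholder
                .build();                                 // no IllegalArgumentException
        System.out.println(input.getFunctionName());
    }
}
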
@@ -84,7 +84,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
         if (modelContentHashValue == null) {
             throw new IllegalArgumentException("model content hash value is null");
         }
-        if (modelConfig == null) {
+        if (modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration.
             throw new IllegalArgumentException("model config is null");
         }
         if (totalChunks == null) {
@@ -149,21 +149,27 @@ public void testClassLoader_ExecuteOutputMCorr() throws IOException {
         assertArrayEquals(new long[]{1, 2}, metrics);
     }

-    @Test
-    public void testClassLoader_MLInput() throws IOException {
-        assertTrue(MLCommonsClassLoader.canInitMLInput(FunctionName.TEXT_EMBEDDING));
+    private void testClassLoader_MLInput_DlModel(FunctionName functionName) throws IOException {
+        assertTrue(MLCommonsClassLoader.canInitMLInput(functionName));

         String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}";
         XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
                 Collections.emptyList()).getNamedXContents()), null, jsonStr);
         parser.nextToken();

-        TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(FunctionName.TEXT_EMBEDDING, new Object[]{parser, FunctionName.TEXT_EMBEDDING}, XContentParser.class, FunctionName.class);
+        TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(functionName, new Object[]{parser, functionName}, XContentParser.class, FunctionName.class);
         assertNotNull(mlInput);
-        assertEquals(FunctionName.TEXT_EMBEDDING, mlInput.getFunctionName());
+        assertEquals(functionName, mlInput.getFunctionName());
         assertEquals(2, ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs().size());
     }

+    @Test
+    public void testClassLoader_MLInput() throws IOException {
+        testClassLoader_MLInput_DlModel(FunctionName.TEXT_EMBEDDING);
+        testClassLoader_MLInput_DlModel(FunctionName.SPARSE_TOKENIZE);
+        testClassLoader_MLInput_DlModel(FunctionName.SPARSE_ENCODING);
+    }
+
     public enum TestEnum {
         TEST
     }
@@ -13,8 +13,8 @@
 import org.opensearch.common.io.stream.BytesStreamOutput;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.common.settings.Settings;
-import org.opensearch.common.xcontent.XContentFactory;
 import org.opensearch.common.xcontent.XContentType;
+import org.opensearch.core.xcontent.MediaTypeRegistry;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.ToXContent;
 import org.opensearch.core.xcontent.XContentBuilder;

@@ -68,11 +68,11 @@ public class MLInputTest {

     @Before
     public void setUp() throws Exception {
-        final ColumnMeta[] columnMetas = new ColumnMeta[]{new ColumnMeta("test", ColumnType.DOUBLE)};
+        final ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("test", ColumnType.DOUBLE) };
         List<Row> rows = new ArrayList<>();
-        rows.add(new Row(new ColumnValue[]{new DoubleValue(1.0)}));
-        rows.add(new Row(new ColumnValue[]{new DoubleValue(2.0)}));
-        rows.add(new Row(new ColumnValue[]{new DoubleValue(3.0)}));
+        rows.add(new Row(new ColumnValue[] { new DoubleValue(1.0) }));
+        rows.add(new Row(new ColumnValue[] { new DoubleValue(2.0) }));
+        rows.add(new Row(new ColumnValue[] { new DoubleValue(3.0) }));
         DataFrame dataFrame = new DefaultDataFrame(columnMetas, rows);
         input = MLInput.builder()
             .algorithm(algorithm)

@@ -96,35 +96,39 @@ public void parse_LinearRegression() throws IOException {
                 .searchSourceBuilder(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1))
                 .build();
         String expectedInputStr = "{\"algorithm\":\"LINEAR_REGRESSION\",\"input_index\":[\"index1\"],\"input_query\":{\"size\":1,\"query\":{\"match_all\":{\"boost\":1.0}}}}";
         testParse(FunctionName.LINEAR_REGRESSION, inputDataset, expectedInputStr, parsedInput -> {
             assertNotNull(parsedInput.getInputDataset());
-            assertEquals(1, ((SearchQueryInputDataset)parsedInput.getInputDataset()).getIndices().size());
-            assertEquals(indexName, ((SearchQueryInputDataset)parsedInput.getInputDataset()).getIndices().get(0));
+            assertEquals(1, ((SearchQueryInputDataset) parsedInput.getInputDataset()).getIndices().size());
+            assertEquals(indexName, ((SearchQueryInputDataset) parsedInput.getInputDataset()).getIndices().get(0));
         });

-        @NonNull DataFrame dataFrame = new DefaultDataFrame(new ColumnMeta[]{ColumnMeta.builder().name("value").columnType(ColumnType.FLOAT).build()});
-        dataFrame.appendRow(new Float[]{1.0f});
+        @NonNull
+        DataFrame dataFrame = new DefaultDataFrame(
+                new ColumnMeta[] { ColumnMeta.builder().name("value").columnType(ColumnType.FLOAT).build() });
+        dataFrame.appendRow(new Float[] { 1.0f });
         DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder().dataFrame(dataFrame).build();
         expectedInputStr = "{\"algorithm\":\"LINEAR_REGRESSION\",\"input_data\":{\"column_metas\":[{\"name\":\"value\",\"column_type\":\"FLOAT\"}],\"rows\":[{\"values\":[{\"column_type\":\"FLOAT\",\"value\":1.0}]}]}}";
         testParse(FunctionName.LINEAR_REGRESSION, dataFrameInputDataset, expectedInputStr, parsedInput -> {
             assertNotNull(parsedInput.getInputDataset());
-            assertEquals(1, ((DataFrameInputDataset)parsedInput.getInputDataset()).getDataFrame().size());
-            assertEquals(1.0f, ((DataFrameInputDataset)parsedInput.getInputDataset()).getDataFrame().getRow(0).getValue(0).floatValue(), 1e-5);
+            assertEquals(1, ((DataFrameInputDataset) parsedInput.getInputDataset()).getDataFrame().size());
+            assertEquals(1.0f, ((DataFrameInputDataset) parsedInput.getInputDataset()).getDataFrame().getRow(0)
+                    .getValue(0).floatValue(), 1e-5);
         });
     }

-    @Test
-    public void parse_TextEmbedding() throws IOException {
+    private void parse_NLPModel(FunctionName functionName) throws IOException {
         String sentence = "test sentence";
         String column = "column1";
         Integer position = 1;
         ModelResultFilter resultFilter = ModelResultFilter.builder()
                 .targetResponse(Arrays.asList(column))
                 .targetResponsePositions(Arrays.asList(position))
                 .build();

         TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).resultFilter(resultFilter).build();
-        String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
-        testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
+        String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
+        expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
+        testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
             assertNotNull(parsedInput.getInputDataset());
             TextDocsInputDataSet parsedInputDataSet = (TextDocsInputDataSet) parsedInput.getInputDataset();
             assertEquals(1, parsedInputDataSet.getDocs().size());

@@ -136,35 +140,50 @@ public void parse_TextEmbedding() throws IOException {
     }

     @Test
-    public void parse_TextEmbedding_NullResultFilter() throws IOException {
+    public void parse_NLP_Related() throws IOException {
+        parse_NLPModel(FunctionName.TEXT_EMBEDDING);
+        parse_NLPModel(FunctionName.SPARSE_TOKENIZE);
+        parse_NLPModel(FunctionName.SPARSE_ENCODING);
+    }
+
+    private void parse_NLPModel_NullResultFilter(FunctionName functionName) throws IOException {
         String sentence = "test sentence";
         TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).build();
-        String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"]}";
-        testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
+        String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"]}";
+        expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
+        testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
             assertNotNull(parsedInput.getInputDataset());
-            assertEquals(1, ((TextDocsInputDataSet)parsedInput.getInputDataset()).getDocs().size());
-            assertEquals(sentence, ((TextDocsInputDataSet)parsedInput.getInputDataset()).getDocs().get(0));
+            assertEquals(1, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().size());
+            assertEquals(sentence, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().get(0));
         });
     }

+    @Test
+    public void parse_NLPRelated_NullResultFilter() throws IOException {
+        parse_NLPModel_NullResultFilter(FunctionName.TEXT_EMBEDDING);
+        parse_NLPModel_NullResultFilter(FunctionName.SPARSE_TOKENIZE);
+        parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING);
+    }
+
     private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
         MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
-        XContentBuilder builder = XContentFactory.jsonBuilder();
+        XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
         input.toXContent(builder, ToXContent.EMPTY_PARAMS);
         assertNotNull(builder);
         String jsonStr = builder.toString();
         assertEquals(expectedInputStr, jsonStr);

-        XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
-                Collections.emptyList()).getNamedXContents()), null, jsonStr);
+        XContentParser parser = XContentType.JSON.xContent()
+                .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
+                        Collections.emptyList()).getNamedXContents()), null, jsonStr);
         parser.nextToken();
         MLInput parsedInput = MLInput.parse(parser, algorithm.name());
         assertEquals(input.getFunctionName(), parsedInput.getFunctionName());
         assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType());
         verify.accept(parsedInput);
     }

     @Test
     public void readInputStream_Success() throws IOException {
         readInputStream(input, parsedInput -> {
@@ -239,7 +239,7 @@ GET /_plugins/_ml/profile/models/zwla5YUB1qmVrJFlwzXJ
     "models": {
         "zwla5YUB1qmVrJFlwzXJ": { # model id
             "model_state": "LOADED",
-            "predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel@1a0b0793",
+            "predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel@1a0b0793",
             "target_worker_nodes": [ # plan to deploy model to these nodes
                 "0TLL4hHxRv6_G3n6y1l0BQ"
             ],
@@ -190,7 +190,7 @@ public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput re
      * @param modelContentHash model content hash value
      * @param listener action listener
      */
-    public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, ActionListener<Map<String, Object>> listener) {
+    public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, FunctionName functionName, ActionListener<Map<String, Object>> listener) {
         try {
             AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                 Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);

@@ -199,7 +199,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
                 File modelZipFile = new File(modelPath);
                 log.debug("download model to file {}", modelZipFile.getAbsolutePath());
                 DownloadUtils.download(url, modelPath, new ProgressBar());
-                verifyModelZipFile(modelFormat, modelPath, modelName);
+                verifyModelZipFile(modelFormat, modelPath, modelName, functionName);
                 String hash = calculateFileHash(modelZipFile);
                 if (hash.equals(modelContentHash)) {
                     List<String> chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE);

@@ -221,7 +221,7 @@
         }
     }

-    public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName) throws IOException {
+    public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName, FunctionName functionName) throws IOException {
         boolean hasPtFile = false;
         boolean hasOnnxFile = false;
         boolean hasTokenizerFile = false;

@@ -236,7 +236,7 @@ public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePat
                 }
             }
         }
-        if (!hasPtFile && !hasOnnxFile) {
+        if (!hasPtFile && !hasOnnxFile && functionName != FunctionName.SPARSE_TOKENIZE) { // sparse tokenizer model doesn't need model file.
            throw new IllegalArgumentException("Can't find model file");
         }
         if (!hasTokenizerFile) {
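
And a hedged sketch of the relaxed zip check in practice. ModelHelper's import path is assumed from the ml-engine module layout; the zip paths and model names are placeholders, and the comments describe the checks above under the assumption that such zip files exist on disk.

import java.io.IOException;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.engine.ModelHelper;

public class VerifyZipDemo {
    // A zip containing only tokenizer.json now verifies for SPARSE_TOKENIZE,
    // while a dense embedding model with no .pt/.onnx entry still fails.
    static void demo(ModelHelper helper) throws IOException {
        helper.verifyModelZipFile(MLModelFormat.TORCH_SCRIPT,
                "/tmp/tokenizer-only.zip", "demo-tokenizer", FunctionName.SPARSE_TOKENIZE); // passes
        try {
            helper.verifyModelZipFile(MLModelFormat.TORCH_SCRIPT,
                    "/tmp/tokenizer-only.zip", "demo-embedding", FunctionName.TEXT_EMBEDDING);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // "Can't find model file"
        }
    }
}
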
(Diff for the remaining changed files was not captured on this page.)
