Skip to content

Commit

Permalink
tune model config: change pooling mode to optional (opensearch-projec…
Browse files Browse the repository at this point in the history
…t#724)

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Mar 3, 2023
1 parent bcaf3c0 commit 74170fc
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,7 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr
}
this.embeddingDimension = embeddingDimension;
this.frameworkType = frameworkType;
if (poolingMode != null) {
this.poolingMode = poolingMode;
} else {
this.poolingMode = PoolingMode.MEAN;
}
this.poolingMode = poolingMode;
this.normalizeResult = normalizeResult;
this.modelMaxLength = modelMaxLength;
}
Expand All @@ -69,7 +65,7 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
Integer embeddingDimension = null;
FrameworkType frameworkType = null;
String allConfig = null;
PoolingMode poolingMode = PoolingMode.MEAN;
PoolingMode poolingMode = null;
boolean normalizeResult = false;
Integer modelMaxLength = null;

Expand Down Expand Up @@ -117,7 +113,11 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{
super(in);
embeddingDimension = in.readInt();
frameworkType = in.readEnum(FrameworkType.class);
poolingMode = in.readEnum(PoolingMode.class);
if (in.readBoolean()) {
poolingMode = in.readEnum(PoolingMode.class);
} else {
poolingMode = null;
}
normalizeResult = in.readBoolean();
modelMaxLength = in.readOptionalInt();
}
Expand All @@ -127,7 +127,12 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeInt(embeddingDimension);
out.writeEnum(frameworkType);
out.writeEnum(poolingMode);
if (poolingMode != null) {
out.writeBoolean(true);
out.writeEnum(poolingMode);
} else {
out.writeBoolean(false);
}
out.writeBoolean(normalizeResult);
out.writeOptionalInt(modelMaxLength);
}
Expand All @@ -150,8 +155,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (modelMaxLength != null) {
builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength);
}
builder.field(POOLING_MODE_FIELD, poolingMode);
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
if (poolingMode != null) {
builder.field(POOLING_MODE_FIELD, poolingMode);
}
if (normalizeResult) {
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public void toXContent() throws IOException {
config.toXContent(builder, EMPTY_PARAMS);
String configContent = TestHelper.xContentBuilderToString(builder);
System.out.println(configContent);
assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"pooling_mode\":\"MEAN\",\"normalize_result\":false}", configContent);
assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public class MLUploadInputTest {
private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"url\":\"url\",\"model_format\":\"ONNX\"," +
"\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," +
"\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" +
",\"pooling_mode\":\"MEAN\",\"normalize_result\":false},\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";
"},\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";
private final FunctionName functionName = FunctionName.LINEAR_REGRESSION;
private final String modelName = "modelName";
private final String version = "version";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class HuggingfaceTextEmbeddingTranslatorFactory implements TranslatorFact
private final boolean neuron;

public HuggingfaceTextEmbeddingTranslatorFactory(TextEmbeddingModelConfig.PoolingMode poolingMode, boolean normalizeResult, String modelType, boolean neuron) {
this.poolingMode = poolingMode;
this.poolingMode = poolingMode == null ? TextEmbeddingModelConfig.PoolingMode.MEAN : poolingMode;
this.normalizeResult = normalizeResult;
this.modelType = modelType;
this.neuron = neuron;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class ONNXSentenceTransformerTextEmbeddingTranslator implements ServingTr
private String modelType;

public ONNXSentenceTransformerTextEmbeddingTranslator(TextEmbeddingModelConfig.PoolingMode poolingMode, boolean normalizeResult, String modelType) {
this.poolingMode = poolingMode;
this.poolingMode = poolingMode == null ? TextEmbeddingModelConfig.PoolingMode.MEAN : poolingMode;
this.normalizeResult = normalizeResult;
this.modelType = modelType;
}
Expand Down

0 comments on commit 74170fc

Please sign in to comment.