Skip to content

Commit

Permalink
rename poolingMethod to poolingMode
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jan 9, 2023
1 parent 7f23bbe commit 25a0d07
Show file tree
Hide file tree
Showing 12 changed files with 35 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ public class TextEmbeddingModelConfig extends MLModelConfig {

private final Integer embeddingDimension;
private final FrameworkType frameworkType;
private final PoolingMethod poolingMode;
private final PoolingMode poolingMode;
private final boolean normalizeResult;
private final Integer modelMaxLength;

@Builder(toBuilder = true)
public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
PoolingMethod poolingMode, boolean normalizeResult, Integer modelMaxLength) {
PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
super(modelType, allConfig);
if (embeddingDimension == null) {
throw new IllegalArgumentException("embedding dimension is null");
Expand All @@ -58,7 +58,7 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr
if (poolingMode != null) {
this.poolingMode = poolingMode;
} else {
this.poolingMode = PoolingMethod.MEAN;
this.poolingMode = PoolingMode.MEAN;
}
this.normalizeResult = normalizeResult;
this.modelMaxLength = modelMaxLength;
Expand All @@ -69,7 +69,7 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
Integer embeddingDimension = null;
FrameworkType frameworkType = null;
String allConfig = null;
PoolingMethod poolingMethod = PoolingMethod.MEAN;
PoolingMode poolingMode = PoolingMode.MEAN;
boolean normalizeResult = false;
Integer modelMaxLength = null;

Expand All @@ -92,7 +92,7 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
allConfig = parser.text();
break;
case POOLING_MODE_FIELD:
poolingMethod = PoolingMethod.from(parser.text().toUpperCase(Locale.ROOT));
poolingMode = PoolingMode.from(parser.text().toUpperCase(Locale.ROOT));
break;
case NORMALIZE_RESULT_FIELD:
normalizeResult = parser.booleanValue();
Expand All @@ -105,7 +105,7 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
break;
}
}
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMethod, normalizeResult, modelMaxLength);
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength);
}

@Override
Expand All @@ -117,7 +117,7 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{
super(in);
embeddingDimension = in.readInt();
frameworkType = in.readEnum(FrameworkType.class);
poolingMode = in.readEnum(PoolingMethod.class);
poolingMode = in.readEnum(PoolingMode.class);
normalizeResult = in.readBoolean();
modelMaxLength = in.readOptionalInt();
}
Expand Down Expand Up @@ -156,7 +156,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

public enum PoolingMethod {
public enum PoolingMode {
MEAN("mean"),
MEAN_SQRT_LEN("mean_sqrt_len"),
MAX("max"),
Expand All @@ -169,13 +169,13 @@ public enum PoolingMethod {
public String getName() {
return name;
}
PoolingMethod(String name) {
PoolingMode(String name) {
this.name = name;
}

public static PoolingMethod from(String value) {
public static PoolingMode from(String value) {
try {
return PoolingMethod.valueOf(value);
return PoolingMode.valueOf(value);
} catch (Exception e) {
throw new IllegalArgumentException("Wrong pooling method");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class MLCreateModelMetaInputTest {
@Before
public void setup() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMethod.MEAN, true, 512);
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
mLCreateModelMetaInput = new MLCreateModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.LOADING, 200L, "123", config, 2);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class MLCreateModelMetaRequestTest {
@Before
public void setUp() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMethod.MEAN, true, 512);
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
mlCreateModelMetaInput = new MLCreateModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.LOADING, 200L, "123", config, 2);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public void downloadPrebuiltModelConfig(String taskId, MLUploadInput uploadInput
configBuilder.frameworkType(TextEmbeddingModelConfig.FrameworkType.from(configEntry.getValue().toString()));
break;
case TextEmbeddingModelConfig.POOLING_MODE_FIELD:
configBuilder.poolingMode(TextEmbeddingModelConfig.PoolingMethod.from(configEntry.getValue().toString()));
configBuilder.poolingMode(TextEmbeddingModelConfig.PoolingMode.from(configEntry.getValue().toString()));
break;
case TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD:
configBuilder.normalizeResult(Boolean.parseBoolean(configEntry.getValue().toString()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ public class HuggingfaceTextEmbeddingTranslatorFactory implements TranslatorFact
SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
}

private final TextEmbeddingModelConfig.PoolingMethod poolingMode;
private final TextEmbeddingModelConfig.PoolingMode poolingMode;
private boolean normalizeResult;
private final String modelType;
private final boolean neuron;

public HuggingfaceTextEmbeddingTranslatorFactory(TextEmbeddingModelConfig.PoolingMethod poolingMode, boolean normalizeResult, String modelType, boolean neuron) {
public HuggingfaceTextEmbeddingTranslatorFactory(TextEmbeddingModelConfig.PoolingMode poolingMode, boolean normalizeResult, String modelType, boolean neuron) {
this.poolingMode = poolingMode;
this.normalizeResult = normalizeResult;
this.modelType = modelType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
public class ONNXSentenceTransformerTextEmbeddingTranslator implements ServingTranslator {
private static final int[] AXIS = {0};
private HuggingFaceTokenizer tokenizer;
private TextEmbeddingModelConfig.PoolingMethod poolingMethod;
private TextEmbeddingModelConfig.PoolingMode poolingMode;
private boolean normalizeResult;
private String modelType;

public ONNXSentenceTransformerTextEmbeddingTranslator(TextEmbeddingModelConfig.PoolingMethod poolingMethod, boolean normalizeResult, String modelType) {
this.poolingMethod = poolingMethod;
public ONNXSentenceTransformerTextEmbeddingTranslator(TextEmbeddingModelConfig.PoolingMode poolingMode, boolean normalizeResult, String modelType) {
this.poolingMode = poolingMode;
this.normalizeResult = normalizeResult;
this.modelType = modelType;
}
Expand Down Expand Up @@ -90,7 +90,7 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
long[] attentionMask = encoding.getAttentionMask();
NDManager manager = ctx.getNDManager();
NDArray inputAttentionMask = manager.create(attentionMask);
switch (this.poolingMethod) {
switch (this.poolingMode) {
case MEAN:
embeddings = meanPool(embeddings, inputAttentionMask, false);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,15 @@ protected void loadTextEmbeddingModel(File modelZipFile, String modelId, String
TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig) modelConfig;
TextEmbeddingModelConfig.FrameworkType transformersType = textEmbeddingModelConfig.getFrameworkType();
String modelType = textEmbeddingModelConfig.getModelType();
TextEmbeddingModelConfig.PoolingMethod poolingMethod = textEmbeddingModelConfig.getPoolingMode();
TextEmbeddingModelConfig.PoolingMode poolingMode = textEmbeddingModelConfig.getPoolingMode();
boolean normalizeResult = textEmbeddingModelConfig.isNormalizeResult();
Integer modelMaxLength = textEmbeddingModelConfig.getModelMaxLength();
if (modelMaxLength != null) {
criteriaBuilder.optArgument("modelMaxLength", modelMaxLength);
}
//TODO: refactor this when we support more engine type
if (ONNX_ENGINE.equals(engine)) { //ONNX
criteriaBuilder.optTranslator(new ONNXSentenceTransformerTextEmbeddingTranslator(poolingMethod, normalizeResult, modelType));
criteriaBuilder.optTranslator(new ONNXSentenceTransformerTextEmbeddingTranslator(poolingMode, normalizeResult, modelType));
} else { // pytorch
if (transformersType == SENTENCE_TRANSFORMERS) {
criteriaBuilder.optTranslator(new SentenceTransformerTextEmbeddingTranslator());
Expand All @@ -205,7 +205,7 @@ protected void loadTextEmbeddingModel(File modelZipFile, String modelId, String
if (transformersType.name().endsWith("_NEURON")) {
neuron = true;
}
criteriaBuilder.optTranslatorFactory(new HuggingfaceTextEmbeddingTranslatorFactory(poolingMethod, normalizeResult, modelType, neuron));
criteriaBuilder.optTranslatorFactory(new HuggingfaceTextEmbeddingTranslatorFactory(poolingMode, normalizeResult, modelType, neuron));
}
}
Criteria<Input, Output> criteria = criteriaBuilder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,36 +162,36 @@ public void initModel_predict_TorchScript_SentenceTransformer_ResultFilter() {
public void initModel_predict_TorchScript_Huggingface() throws URISyntaxException {
String modelFile = "all-MiniLM-L6-v2_torchscript_huggingface.zip";
String modelType = "bert";
TextEmbeddingModelConfig.PoolingMethod poolingMethod = TextEmbeddingModelConfig.PoolingMethod.MEAN;
TextEmbeddingModelConfig.PoolingMode poolingMode = TextEmbeddingModelConfig.PoolingMode.MEAN;
boolean normalize = true;
int modelMaxLength = 512;
MLModelFormat modelFormat = MLModelFormat.TORCH_SCRIPT;
initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMethod, normalize, modelMaxLength, modelFormat, dimension);
initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension);
}

@Test
public void initModel_predict_ONNX_bert() throws URISyntaxException {
String modelFile = "all-MiniLM-L6-v2_onnx.zip";
String modelType = "bert";
TextEmbeddingModelConfig.PoolingMethod poolingMethod = TextEmbeddingModelConfig.PoolingMethod.MEAN;
TextEmbeddingModelConfig.PoolingMode poolingMode = TextEmbeddingModelConfig.PoolingMode.MEAN;
boolean normalize = true;
int modelMaxLength = 512;
MLModelFormat modelFormat = MLModelFormat.ONNX;
initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMethod, normalize, modelMaxLength, modelFormat, dimension);
initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension);
}

@Test
public void initModel_predict_ONNX_albert() throws URISyntaxException {
String modelFile = "paraphrase-albert-small-v2_onnx.zip";
String modelType = "albert";
TextEmbeddingModelConfig.PoolingMethod poolingMethod = TextEmbeddingModelConfig.PoolingMethod.MEAN;
TextEmbeddingModelConfig.PoolingMode poolingMode = TextEmbeddingModelConfig.PoolingMode.MEAN;
boolean normalize = false;
int modelMaxLength = 512;
MLModelFormat modelFormat = MLModelFormat.ONNX;
initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMethod, normalize, modelMaxLength, modelFormat, 768);
initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, 768);
}

private void initModel_predict_HuggingfaceModel(String modelFile, String modelType, TextEmbeddingModelConfig.PoolingMethod poolingMethod,
private void initModel_predict_HuggingfaceModel(String modelFile, String modelType, TextEmbeddingModelConfig.PoolingMode poolingMode,
boolean normalizeResult, Integer modelMaxLength,
MLModelFormat modelFormat, int dimension) throws URISyntaxException {
Map<String, Object> params = new HashMap<>();
Expand All @@ -201,7 +201,7 @@ private void initModel_predict_HuggingfaceModel(String modelFile, String modelTy
TextEmbeddingModelConfig onnxModelConfig = modelConfig.toBuilder()
.frameworkType(HUGGINGFACE_TRANSFORMERS)
.modelType(modelType)
.poolingMode(poolingMethod)
.poolingMode(poolingMode)
.normalizeResult(normalizeResult)
.modelMaxLength(modelMaxLength)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ private MLUploadInput prepareInput() {
123,
TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS,
"all config",
TextEmbeddingModelConfig.PoolingMethod.MEAN,
TextEmbeddingModelConfig.PoolingMode.MEAN,
true,
512
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ private MLUploadModelRequest prepareRequest(String url) {
123,
TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS,
"all config",
TextEmbeddingModelConfig.PoolingMethod.MEAN,
TextEmbeddingModelConfig.PoolingMode.MEAN,
true,
512
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ private MLCreateModelMetaInput prepareRequest() {
123,
FrameworkType.SENTENCE_TRANSFORMERS,
"all config",
TextEmbeddingModelConfig.PoolingMethod.MEAN,
TextEmbeddingModelConfig.PoolingMode.MEAN,
true,
512
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ private MLCreateModelMetaRequest prepareRequest() {
123,
FrameworkType.SENTENCE_TRANSFORMERS,
"all config",
TextEmbeddingModelConfig.PoolingMethod.MEAN,
TextEmbeddingModelConfig.PoolingMode.MEAN,
true,
512
)
Expand Down

0 comments on commit 25a0d07

Please sign in to comment.