From f164f0d7768c7f2463e11679785b9c7d7e93a19c Mon Sep 17 00:00:00 2001
From: Jincheng Miao
Date: Mon, 15 Jul 2024 12:02:06 +0800
Subject: [PATCH] embeddings: adaptive detect embedding model arguments in
 mosec (#296)

* embeddings: adaptive detect embedding model arguments in mosec

Signed-off-by: Jincheng Miao

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Jincheng Miao
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../langchain-mosec/mosec-docker/server-ipex.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/comps/embeddings/langchain-mosec/mosec-docker/server-ipex.py b/comps/embeddings/langchain-mosec/mosec-docker/server-ipex.py
index 03d926f6b..05c1c63f3 100644
--- a/comps/embeddings/langchain-mosec/mosec-docker/server-ipex.py
+++ b/comps/embeddings/langchain-mosec/mosec-docker/server-ipex.py
@@ -34,9 +34,14 @@ def __init__(self):
         d = torch.randint(vocab_size, size=[batch_size, seq_length])
         t = torch.randint(0, 1, size=[batch_size, seq_length])
         m = torch.randint(1, 2, size=[batch_size, seq_length])
-        self.model = torch.jit.trace(self.model, [d, t, m], check_trace=False, strict=False)
+        model_inputs = [d]
+        if "token_type_ids" in self.tokenizer.model_input_names:
+            model_inputs.append(t)
+        if "attention_mask" in self.tokenizer.model_input_names:
+            model_inputs.append(m)
+        self.model = torch.jit.trace(self.model, model_inputs, check_trace=False, strict=False)
         self.model = torch.jit.freeze(self.model)
-        self.model(d, t, m)
+        self.model(*model_inputs)
 
     def get_embedding_with_token_count(self, sentences: Union[str, List[Union[str, List[int]]]]):
         # Mean Pooling - Take attention mask into account for correct averaging