diff --git a/comps/embeddings/predictionguard/embedding_predictionguard.py b/comps/embeddings/predictionguard/embedding_predictionguard.py index e77749fbe..793f35d2d 100644 --- a/comps/embeddings/predictionguard/embedding_predictionguard.py +++ b/comps/embeddings/predictionguard/embedding_predictionguard.py @@ -5,6 +5,7 @@ import os import time from typing import List, Optional, Union + from predictionguard import PredictionGuard from comps import ( @@ -72,12 +73,14 @@ async def embedding( logger.info(res) return res + async def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]: texts = [text] if isinstance(text, str) else text response = client.embeddings.create(model=pg_embedding_model_name, input=texts)["data"] embed_vector = [response[i]["embedding"] for i in range(len(response))] return embed_vector + if __name__ == "__main__": pg_embedding_model_name = os.getenv("PG_EMBEDDING_MODEL_NAME", "bridgetower-large-itm-mlm-itc") print("Prediction Guard Embedding initialized.")