From 54ba9dbeaad423a789dd5a4ee65d39ad8b089cae Mon Sep 17 00:00:00 2001 From: TimAdams84 Date: Thu, 13 Jun 2024 17:01:00 +0200 Subject: [PATCH] Fix: Add chunking for GPT4Adapter requests bigger than 2048 tokens --- datastew/embedding.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/datastew/embedding.py b/datastew/embedding.py index bf6250a..3d4875a 100644 --- a/datastew/embedding.py +++ b/datastew/embedding.py @@ -23,7 +23,7 @@ def get_embedding(self, text: str, model="text-embedding-ada-002"): logging.info(f"Getting embedding for {text}") try: if text is None or text == "" or text is np.nan: - logging.warn(f"Empty text passed to get_embedding") + logging.warning(f"Empty text passed to get_embedding") return None if isinstance(text, str): text = text.replace("\n", " ") @@ -32,10 +32,18 @@ def get_embedding(self, text: str, model="text-embedding-ada-002"): logging.error(f"Error getting embedding for {text}: {e}") return None - def get_embeddings(self, messages: [str], model="text-embedding-ada-002"): - # store index of nan entries - response = openai.Embedding.create(input=messages, model=model) - return [item["embedding"] for item in response["data"]] + def get_embeddings(self, messages: [str], model="text-embedding-ada-002", max_chunk_length=2048): + embeddings = [] + for message in messages: + if len(message) <= max_chunk_length: + embeddings.append(self.get_embedding(message, model)) + else: + # Split message into chunks + chunks = [message[i:i+max_chunk_length] for i in range(0, len(message), max_chunk_length)] + for idx, chunk in enumerate(chunks): + logging.info(f'Processing chunk {idx}/{len(chunks)}') + embeddings.append(self.get_embedding(chunk, model)) + return embeddings class MPNetAdapter(EmbeddingModel):