Skip to content

Commit

Permalink
Search: increase embedding context to 5 messages
Browse files Browse the repository at this point in the history
  • Loading branch information
YuraLukashik committed Nov 21, 2023
1 parent 5fceabe commit 1961f5e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 25 deletions.
2 changes: 2 additions & 0 deletions src/semantic_search/semantic_search/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

load_dotenv()

CONTEXT_LENGTH = 5


def get_openai_key() -> 'str | None':
    """Return the OpenAI API key from the OPENAI_API_KEY environment variable.

    Returns None when the variable is unset — os.environ.get does not raise —
    so callers must be prepared to handle a missing key. (The previous
    annotation claimed a plain str, which was incorrect for the unset case.)
    """
    return os.environ.get('OPENAI_API_KEY')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from slack_sdk import WebClient

from .internal_api import get_team_data
from ..config import CONTEXT_LENGTH

slack_names = {}

Expand Down Expand Up @@ -113,12 +114,12 @@ def slack_names_map(team_id):


def load_previous_messages(team_id: str, channel_id: str, last_message_id: str, number: int):
    """Load up to `number` messages that precede `last_message_id` in a channel.

    Delegates to load_previous_messages_with_pointer and keeps only the
    message list, discarding the pagination pointer (and any extra fields)
    that variant also returns.
    """
    # BUG FIX: the with-pointer variant returns a 3-element list
    # [messages, pointer, context_count], so the previous single-element
    # unpack `(messages, ) = ...` raised ValueError on every call.
    # Starred unpacking takes the first element and ignores the rest,
    # staying correct even if more fields are appended later.
    (messages, *_) = load_previous_messages_with_pointer(team_id, channel_id, last_message_id, number)
    return messages[-number:]


def load_subsequent_messages(team_id: str, channel_id: str, first_message_id: str, number: int):
    """Load up to `number` messages that follow `first_message_id` in a channel.

    Delegates to load_subsequent_messages_with_pointer and keeps only the
    message list, discarding the pagination pointer it also returns.
    """
    # BUG FIX: the with-pointer variant returns at least two elements
    # (the old code unpacked `(messages, _)`), so the single-element unpack
    # `(messages, ) = ...` raised ValueError. Starred unpacking takes the
    # message list and ignores any trailing fields.
    (messages, *_) = load_subsequent_messages_with_pointer(team_id, channel_id, first_message_id, number)
    return messages[:number]


Expand All @@ -129,11 +130,11 @@ def load_previous_messages_with_pointer(team_id: str, channel_id: str, last_mess
messages = fetch_several_messages_before(team_id, channel_id, last_message_id, bulk_size)
actual_messages = filter_messages(messages)
if len(messages) < bulk_size:
return [actual_messages, None]
return [actual_messages, None, 0]
if len(actual_messages) < minimum_number:
return load_previous_messages_with_pointer(team_id, channel_id, last_message_id, minimum_number, bulk_size * 2)
pointer = actual_messages[0]['ts'] if len(actual_messages) == 1 else actual_messages[1]['ts']
return [actual_messages, pointer]
context_tail = actual_messages[:CONTEXT_LENGTH - 1]
return [actual_messages, context_tail[-1]['ts'], len(context_tail)]


def load_subsequent_messages_with_pointer(team_id: str, channel_id: str, first_message_id: str, minimum_number: int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@ def handle_task():
namespace = payload['namespace']
channel_id = payload['channel_id']
last_message_id = payload['last_message_id']
[messages, next_last_message] = load_previous_messages_with_pointer(namespace, channel_id, last_message_id, BULK_SIZE)
[messages, next_last_message, start_from] = load_previous_messages_with_pointer(namespace, channel_id, last_message_id, BULK_SIZE)
logging.info(f"Task: {task_id}, Iteration Number: {iteration_number}")
logging.info(f"Task: {task_id}, Number of Actual Messages: {len(messages)}")
start_from = 0 if next_last_message is None else 2
index_messages(channel_id, messages, start_from, get_pinecone_index(), namespace)
if next_last_message is not None:
queue_task({
Expand Down
31 changes: 13 additions & 18 deletions src/semantic_search/semantic_search/load_messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import copy
from typing import List, Dict

from .config import CONTEXT_LENGTH
from .external_services.pinecone import get_pinecone_index
from .external_services.openai import create_embeddings, gpt_summarize_thread
import datetime
Expand Down Expand Up @@ -101,12 +103,8 @@ def replace_ids_with_names(embeddings: List[Embedding], team_id: str) -> List[Em
def enrich_with_adjacent_messages(embeddings: List[Embedding]) -> List[Embedding]:
    """Attach each embedding's preceding messages as conversational context.

    For every embedding, up to CONTEXT_LENGTH - 1 immediately preceding
    embeddings (fewer near the start of the list) are passed to
    add_adjacent_messages_context, and the enriched copies are returned
    in the original order.
    """
    enriched = []
    for index, embedding in enumerate(embeddings):
        window_start = max(0, index - CONTEXT_LENGTH + 1)
        preceding = embeddings[window_start:index]
        enriched.append(embedding.add_adjacent_messages_context(preceding))
    return enriched


Expand All @@ -115,8 +113,8 @@ def enrich_with_datetime(embeddings: List[Embedding]) -> List[Embedding]:


def attach_header(embeddings: List[Embedding], header: Embedding) -> List[Embedding]:
    """Prepend the thread header to every embedding past the leading context window.

    The first CONTEXT_LENGTH - 1 embeddings serve as context only and are
    returned unchanged; each remaining embedding gets `header` attached via
    add_header. Order is preserved.
    """
    split_at = CONTEXT_LENGTH - 1
    context_only = embeddings[:split_at]
    headered = [embedding.add_header(header) for embedding in embeddings[split_at:]]
    return context_only + headered


Expand Down Expand Up @@ -144,10 +142,7 @@ def index_messages(channel_id, messages, start_from, pinecone_index, pinecone_na
thread_header = thread_embeddings[0]
raw_messages_for_summary = list(map(lambda e: e.text, thread_embeddings))

additional_context_embeddings = [
embeddings_without_context[counter - 2] if counter >= 2 else None,
embeddings_without_context[counter - 1] if counter >= 1 else None,
]
additional_context_embeddings = embeddings_without_context[max(0, counter - CONTEXT_LENGTH + 1):counter]
additional_context_embeddings = list(filter(None, additional_context_embeddings))
thread_embeddings = additional_context_embeddings + thread_embeddings
thread_embeddings = enrich_with_adjacent_messages(thread_embeddings)[len(additional_context_embeddings):]
Expand Down Expand Up @@ -235,9 +230,9 @@ def handle_message_update_and_reindex(body):
index_messages(channel_id, load_previous_messages(team_id, channel_id, message.get('thread_ts'), 1), 0, get_pinecone_index(), team_id)
return
message_ts = message['ts']
messages_for_reindex = load_previous_messages(team_id, channel_id, message_ts, 2) + load_subsequent_messages(team_id, channel_id, message_ts, 2)
messages_for_reindex = load_previous_messages(team_id, channel_id, message_ts, CONTEXT_LENGTH - 1) + load_subsequent_messages(team_id, channel_id, message_ts, CONTEXT_LENGTH - 1)
# reindex surrounding messages
index_messages(channel_id, messages_for_reindex, 2, get_pinecone_index(), team_id)
index_messages(channel_id, messages_for_reindex, CONTEXT_LENGTH - 1, get_pinecone_index(), team_id)
return
if 'subtype' in event and event['subtype'] == 'message_changed':
# processing a message update
Expand All @@ -250,9 +245,9 @@ def handle_message_update_and_reindex(body):
index_messages(channel_id, load_previous_messages(team_id, channel_id, message.get('thread_ts'), 1), 0, get_pinecone_index(), team_id)
return
message_ts = message['ts']
messages_for_reindex = load_previous_messages(team_id, channel_id, message_ts, 3) + load_subsequent_messages(team_id, channel_id, message_ts, 3)[1:]
messages_for_reindex = load_previous_messages(team_id, channel_id, message_ts, CONTEXT_LENGTH) + load_subsequent_messages(team_id, channel_id, message_ts, CONTEXT_LENGTH)[1:]
# reindex surrounding messages
index_messages(channel_id, messages_for_reindex, 2, get_pinecone_index(), team_id)
index_messages(channel_id, messages_for_reindex, CONTEXT_LENGTH - 1, get_pinecone_index(), team_id)
return
if 'subtype' not in event:
message = event
Expand Down Expand Up @@ -289,15 +284,15 @@ def generate_embedding_for_message(team_id, channel_id, message_id, thread_ts) -
if message_index is None:
return []
messages = thread_messages[:message_index + 1]
additional_messages = [] if len(messages) > 2 else load_previous_messages(team_id, channel_id, thread_head['ts'], 4 - len(messages))[:-1]
additional_messages = [] if len(messages) >= CONTEXT_LENGTH else load_previous_messages(team_id, channel_id, thread_head['ts'], CONTEXT_LENGTH + 1 - len(messages))[:-1]
embeddings = generate_embeddings(channel_id, additional_messages) + generate_embeddings(channel_id, messages)
embeddings = replace_ids_with_names(embeddings, team_id)
embeddings = enrich_with_datetime(embeddings)
embeddings = enrich_with_adjacent_messages(embeddings)[len(additional_messages):]
embeddings = attach_header(embeddings, head_embedding)
return embeddings[-1:]

embeddings = generate_embeddings(channel_id, load_previous_messages(team_id, channel_id, message_id, 3))
embeddings = generate_embeddings(channel_id, load_previous_messages(team_id, channel_id, message_id, CONTEXT_LENGTH))
embeddings = replace_ids_with_names(embeddings, team_id)
embeddings = enrich_with_datetime(embeddings)
embeddings = enrich_with_adjacent_messages(embeddings)
Expand Down

0 comments on commit 1961f5e

Please sign in to comment.