Make superbooga & superboogav2 functional again (oobabooga#5656)
oobabooga authored and bartowski1182 committed Mar 23, 2024
1 parent 1c8330e commit 2cbfd5a
Showing 15 changed files with 185 additions and 253 deletions.
52 changes: 12 additions & 40 deletions extensions/superbooga/chromadb.py
@@ -1,43 +1,24 @@
+import random
+
 import chromadb
 import posthog
-import torch
 from chromadb.config import Settings
-from sentence_transformers import SentenceTransformer
 
 from modules.logging_colors import logger
+from chromadb.utils import embedding_functions
 
-logger.info('Intercepting all calls to posthog :)')
+# Intercept calls to posthog
 posthog.capture = lambda *args, **kwargs: None
 
-
-class Collecter():
-    def __init__(self):
-        pass
-
-    def add(self, texts: list[str]):
-        pass
-
-    def get(self, search_strings: list[str], n_results: int) -> list[str]:
-        pass
-
-    def clear(self):
-        pass
+embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")
 
 
-class Embedder():
+class ChromaCollector():
     def __init__(self):
-        pass
-
-    def embed(self, text: str) -> list[torch.Tensor]:
-        pass
-
-
-class ChromaCollector(Collecter):
-    def __init__(self, embedder: Embedder):
-        super().__init__()
+        name = ''.join(random.choice('ab') for _ in range(10))
+        self.name = name
         self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
-        self.embedder = embedder
-        self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
+        self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
         self.ids = []
 
     def add(self, texts: list[str]):
@@ -102,24 +83,15 @@ def get_ids_sorted(self, search_strings: list[str], n_results: int, n_initial: i
         return sorted(ids)
 
     def clear(self):
-        self.collection.delete(ids=self.ids)
+        self.chroma_client.delete_collection(name=self.name)
+        self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)
         self.ids = []
 
 
-class SentenceTransformerEmbedder(Embedder):
-    def __init__(self) -> None:
-        self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
-        self.embed = self.model.encode
-
-
 def make_collector():
-    global embedder
-    return ChromaCollector(embedder)
+    return ChromaCollector()
 
 
 def add_chunks_to_collector(chunks, collector):
     collector.clear()
     collector.add(chunks)
-
-
-embedder = SentenceTransformerEmbedder()
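
Note: the rewrite above drops the Collecter/Embedder wrapper classes and relies directly on chromadb 0.4.x, where embedding functions come from chromadb.utils.embedding_functions and a collection is wiped by deleting and recreating it (the random collection name presumably avoids clashes when several collectors share one client). A minimal standalone sketch of that API, not part of the commit — the collection name and documents below are illustrative:

import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions

# Same model as the extension; downloads the weights on first use.
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")
client = chromadb.Client(Settings(anonymized_telemetry=False))

# Chroma calls the embedding function itself for both documents and queries.
collection = client.create_collection(name="demo", embedding_function=embedder)
collection.add(documents=["the cat sat on the mat", "chroma stores embeddings"], ids=["id1", "id2"])
print(collection.query(query_texts=["where is the cat?"], n_results=1))

# "Clearing" follows the new clear() above: drop the collection and recreate it.
client.delete_collection(name="demo")
collection = client.create_collection(name="demo", embedding_function=embedder)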
2 changes: 1 addition & 1 deletion extensions/superbooga/requirements.txt
@@ -1,5 +1,5 @@
 beautifulsoup4==4.12.2
-chromadb==0.3.18
+chromadb==0.4.24
 pandas==2.0.3
 posthog==2.4.2
 sentence_transformers==2.2.2
33 changes: 11 additions & 22 deletions extensions/superboogav2/api.py
@@ -12,17 +12,16 @@

import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import urlparse, parse_qs
from threading import Thread
from urllib.parse import parse_qs, urlparse

import extensions.superboogav2.parameters as parameters
from modules import shared
from modules.logging_colors import logger

from .chromadb import ChromaCollector
from .data_processor import process_and_add_to_collector

import extensions.superboogav2.parameters as parameters


class CustomThreadingHTTPServer(ThreadingHTTPServer):
def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True):
@@ -38,30 +37,26 @@ def __init__(self, request, client_address, server, collector: ChromaCollector):
self.collector = collector
super().__init__(request, client_address, server)


def _send_412_error(self, message):
self.send_response(412)
self.send_header("Content-type", "application/json")
self.end_headers()
response = json.dumps({"error": message})
self.wfile.write(response.encode('utf-8'))


def _send_404_error(self):
self.send_response(404)
self.send_header("Content-type", "application/json")
self.end_headers()
response = json.dumps({"error": "Resource not found"})
self.wfile.write(response.encode('utf-8'))


def _send_400_error(self, error_message: str):
self.send_response(400)
self.send_header("Content-type", "application/json")
self.end_headers()
response = json.dumps({"error": error_message})
self.wfile.write(response.encode('utf-8'))


def _send_200_response(self, message: str):
self.send_response(200)
@@ -75,24 +70,21 @@ def _send_200_response(self, message: str):

self.wfile.write(response.encode('utf-8'))


def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str):
if sort_param == parameters.SORT_DISTANCE:
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
elif sort_param == parameters.SORT_ID:
results = self.collector.get_sorted_by_id(search_strings, n_results, max_token_count)
else: # Default is dist
else: # Default is dist
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)

return {
"results": results
}


def do_GET(self):
self._send_404_error()


def do_POST(self):
try:
content_length = int(self.headers['Content-Length'])
@@ -107,7 +99,7 @@ def do_POST(self):
if corpus is None:
self._send_412_error("Missing parameter 'corpus'")
return

clear_before_adding = body.get('clear_before_adding', False)
metadata = body.get('metadata')
process_and_add_to_collector(corpus, self.collector, clear_before_adding, metadata)
@@ -118,7 +110,7 @@ def do_POST(self):
if corpus is None:
self._send_412_error("Missing parameter 'metadata'")
return

self.collector.delete(ids_to_delete=None, where=metadata)
self._send_200_response("Data successfully deleted")

@@ -127,15 +119,15 @@ def do_POST(self):
if search_strings is None:
self._send_412_error("Missing parameter 'search_strings'")
return

n_results = body.get('n_results')
if n_results is None:
n_results = parameters.get_chunk_count()

max_token_count = body.get('max_token_count')
if max_token_count is None:
max_token_count = parameters.get_max_token_count()

sort_param = query_params.get('sort', ['distance'])[0]

results = self._handle_get(search_strings, n_results, max_token_count, sort_param)
@@ -146,7 +138,6 @@ def do_POST(self):
except Exception as e:
self._send_400_error(str(e))


def do_DELETE(self):
try:
parsed_path = urlparse(self.path)
@@ -161,12 +152,10 @@ def do_DELETE(self):
except Exception as e:
self._send_400_error(str(e))


def do_OPTIONS(self):
self.send_response(200)
self.end_headers()


def end_headers(self):
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', '*')
@@ -197,11 +186,11 @@ def start_server(self, port: int):

def stop_server(self):
if self.server is not None:
logger.info(f'Stopping chromaDB API.')
logger.info('Stopping chromaDB API.')
self.server.shutdown()
self.server.server_close()
self.server = None
self.is_running = False

def is_server_running(self):
return self.is_running
return self.is_running
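
Note: the handler above reads a JSON body whose keys ('corpus', 'clear_before_adding', 'metadata', 'search_strings', 'n_results', 'max_token_count') and the 'sort' query parameter are visible in the diff, but the URL routes and port are not part of this excerpt. A hypothetical client sketch — the /api/add and /api/get paths and the port are assumptions; only the JSON keys come from the handler:

import requests

BASE = "http://localhost:5002"  # assumed port, not taken from the diff

# Feed a corpus to the collector.
requests.post(f"{BASE}/api/add", json={
    "corpus": "ChromaDB is the vector store used by superboogav2.",
    "clear_before_adding": True,
})

# Query it; sorting falls back to distance when 'sort' is omitted or unrecognized.
r = requests.post(f"{BASE}/api/get?sort=distance", json={
    "search_strings": ["which vector store does superboogav2 use?"],
    "n_results": 5,
    "max_token_count": 512,
})
print(r.json()["results"])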
14 changes: 7 additions & 7 deletions extensions/superboogav2/benchmark.py
@@ -9,23 +9,23 @@
import datetime
import json
import os

from pathlib import Path

from .data_processor import process_and_add_to_collector, preprocess_text
from .data_processor import preprocess_text, process_and_add_to_collector
from .parameters import get_chunk_count, get_max_token_count
from .utils import create_metadata_source


def benchmark(config_path, collector):
# Get the current system date
sysdate = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"benchmark_{sysdate}.txt"

# Open the log file in append mode
with open(filename, 'a') as log:
with open(config_path, 'r') as f:
data = json.load(f)

total_points = 0
max_points = 0

@@ -45,7 +45,7 @@ def benchmark(config_path, collector):
for question_group in item["questions"]:
question_variants = question_group["question_variants"]
criteria = question_group["criteria"]

for q in question_variants:
max_points += len(criteria)
processed_text = preprocess_text(q)
@@ -54,7 +54,7 @@ def benchmark(config_path, collector):
results = collector.get_sorted_by_dist(processed_text, n_results=get_chunk_count(), max_token_count=get_max_token_count())

points = 0

for c in criteria:
for p in results:
if c in p:
Expand All @@ -69,4 +69,4 @@ def benchmark(config_path, collector):

print(f'##Total points:\n\n{total_points}/{max_points}', file=log)

return total_points, max_points
return total_points, max_points
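
Note: benchmark() walks a JSON config whose items carry "questions", each with "question_variants" and "criteria", and awards one point per criterion substring found in the retrieved chunks. A hypothetical config built under that assumption — the top-level list and the source-text field name are guesses, since the part of the loop that ingests each item's corpus is collapsed in this excerpt:

import json

config = [
    {
        # Field name assumed; the visible code only reads item["questions"].
        "text": "The Eiffel Tower is 330 metres tall and stands in Paris.",
        "questions": [
            {
                "question_variants": [
                    "How tall is the Eiffel Tower?",
                    "What is the height of the Eiffel Tower?",
                ],
                # One point per criterion that appears in a retrieved chunk.
                "criteria": ["330", "Paris"],
            }
        ],
    }
]

with open("benchmark_config.json", "w") as f:
    json.dump(config, f, indent=2)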
31 changes: 16 additions & 15 deletions extensions/superboogav2/chat_handler.py
@@ -4,34 +4,35 @@
import re

import extensions.superboogav2.parameters as parameters

from extensions.superboogav2.utils import (
create_context_text,
create_metadata_source
)
from modules import chat, shared
from modules.text_generation import get_encoded_length
from modules.logging_colors import logger
from modules.chat import load_character_memoized
from extensions.superboogav2.utils import create_context_text, create_metadata_source
from modules.logging_colors import logger
from modules.text_generation import get_encoded_length

from .data_processor import process_and_add_to_collector
from .chromadb import ChromaCollector

from .data_processor import process_and_add_to_collector

CHAT_METADATA = create_metadata_source('automatic-chat-insert')


def _remove_tag_if_necessary(user_input: str):
if not parameters.get_is_manual():
return user_input

return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', user_input)


def _should_query(input: str):
if not parameters.get_is_manual():
return True

if re.search(r'^\s*!c|!c\s*$', input, re.MULTILINE):
return True

return False


@@ -69,7 +70,7 @@ def _concatinate_history(history: dict, state: dict):
if len(exchange) >= 2:
full_history_text += _format_single_exchange(bot_name, exchange[1])

return full_history_text[:-1] # Remove the last new line.
return full_history_text[:-1] # Remove the last new line.


def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
@@ -82,20 +83,20 @@ def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
for i, messages in enumerate(reversed(history['internal'])):
for j, message in enumerate(reversed(messages)):
num_message_tokens = get_encoded_length(_format_single_exchange(names[j], message))

# TODO: This is an extremely naive solution. A more robust implementation must be made.
if history_tokens + num_context_tokens <= max_len:
# This message can be replaced
replace_position = (i, j)

history_tokens += num_message_tokens

if replace_position is None:
logger.warn("The provided context_text is too long to replace any message in the history.")
else:
# replace the message at replace_position with context_text
i, j = replace_position
history['internal'][-i-1][-j-1] = context_text
history['internal'][-i - 1][-j - 1] = context_text


def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs):
@@ -120,5 +121,5 @@ def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector
user_input = create_context_text(results) + user_input
elif parameters.get_injection_strategy() == parameters.HIJACK_LAST_IN_CONTEXT:
_hijack_last(create_context_text(results), kwargs['history'], state['truncation_length'], state)

return chat.generate_chat_prompt(user_input, state, **kwargs)
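
Note: in manual mode the collector is only queried when the message carries a !c tag at the start or end, and the tag is stripped before the prompt is built. A standalone illustration of the two regexes above (re-implemented here for demonstration, not imported from the extension):

import re

def should_query_manual(user_input: str) -> bool:
    # Mirrors _should_query with manual mode enabled.
    return bool(re.search(r'^\s*!c|!c\s*$', user_input, re.MULTILINE))

def strip_tag(user_input: str) -> str:
    # Mirrors _remove_tag_if_necessary with manual mode enabled.
    return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', user_input)

print(should_query_manual("!c what does the manual say about RAID?"))  # True
print(should_query_manual("hello there"))                              # False
print(strip_tag("!c what does the manual say about RAID?"))            # -> what does the manual say about RAID?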