Make superbooga & superboogav2 functional again #5656

Merged: 3 commits, Mar 7, 2024
52 changes: 12 additions & 40 deletions extensions/superbooga/chromadb.py
@@ -1,43 +1,24 @@
import random

import chromadb
import posthog
import torch
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

from modules.logging_colors import logger
from chromadb.utils import embedding_functions

logger.info('Intercepting all calls to posthog :)')
# Intercept calls to posthog
posthog.capture = lambda *args, **kwargs: None


class Collecter():
def __init__(self):
pass

def add(self, texts: list[str]):
pass

def get(self, search_strings: list[str], n_results: int) -> list[str]:
pass
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")

def clear(self):
pass


class Embedder():
class ChromaCollector():
def __init__(self):
pass

def embed(self, text: str) -> list[torch.Tensor]:
pass
name = ''.join(random.choice('ab') for _ in range(10))


class ChromaCollector(Collecter):
def __init__(self, embedder: Embedder):
super().__init__()
self.name = name
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
self.embedder = embedder
self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
self.ids = []

def add(self, texts: list[str]):
@@ -102,24 +83,15 @@ def get_ids_sorted(self, search_strings: list[str], n_results: int, n_initial: i
return sorted(ids)

def clear(self):
self.collection.delete(ids=self.ids)
self.ids = []


class SentenceTransformerEmbedder(Embedder):
def __init__(self) -> None:
self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
self.embed = self.model.encode
self.chroma_client.delete_collection(name=self.name)
self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)


def make_collector():
global embedder
return ChromaCollector(embedder)
return ChromaCollector()


def add_chunks_to_collector(chunks, collector):
collector.clear()
collector.add(chunks)


embedder = SentenceTransformerEmbedder()
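
For context, the rewritten collector drops the custom Collecter/Embedder wrappers and uses chromadb 0.4.x directly. A minimal standalone sketch of that pattern follows; the model name and the create_collection call come from the diff above, while the add/query lines at the end are illustrative and not part of the PR.

import random

import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions

# Embedding function provided by chromadb itself (replaces the old SentenceTransformerEmbedder class).
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")

client = chromadb.Client(Settings(anonymized_telemetry=False))

# Random per-instance collection name, as in the new ChromaCollector.__init__.
name = ''.join(random.choice('ab') for _ in range(10))
collection = client.create_collection(name=name, embedding_function=embedder)

# Illustrative usage: index two chunks, then retrieve the closest one.
collection.add(documents=["first chunk", "second chunk"], ids=["id0", "id1"])
print(collection.query(query_texts=["first"], n_results=1)["documents"])

Note that the new clear() drops and recreates the whole collection rather than deleting tracked ids one by one, which is why ChromaCollector now keeps the random name around in self.name.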
2 changes: 1 addition & 1 deletion extensions/superbooga/requirements.txt
@@ -1,5 +1,5 @@
beautifulsoup4==4.12.2
chromadb==0.3.18
chromadb==0.4.24
pandas==2.0.3
posthog==2.4.2
sentence_transformers==2.2.2
33 changes: 11 additions & 22 deletions extensions/superboogav2/api.py
@@ -12,17 +12,16 @@

import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import urlparse, parse_qs
from threading import Thread
from urllib.parse import parse_qs, urlparse

import extensions.superboogav2.parameters as parameters
from modules import shared
from modules.logging_colors import logger

from .chromadb import ChromaCollector
from .data_processor import process_and_add_to_collector

import extensions.superboogav2.parameters as parameters


class CustomThreadingHTTPServer(ThreadingHTTPServer):
def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True):
@@ -38,30 +37,26 @@ def __init__(self, request, client_address, server, collector: ChromaCollector):
self.collector = collector
super().__init__(request, client_address, server)


def _send_412_error(self, message):
self.send_response(412)
self.send_header("Content-type", "application/json")
self.end_headers()
response = json.dumps({"error": message})
self.wfile.write(response.encode('utf-8'))


def _send_404_error(self):
self.send_response(404)
self.send_header("Content-type", "application/json")
self.end_headers()
response = json.dumps({"error": "Resource not found"})
self.wfile.write(response.encode('utf-8'))


def _send_400_error(self, error_message: str):
self.send_response(400)
self.send_header("Content-type", "application/json")
self.end_headers()
response = json.dumps({"error": error_message})
self.wfile.write(response.encode('utf-8'))


def _send_200_response(self, message: str):
self.send_response(200)
@@ -75,24 +70,21 @@ def _send_200_response(self, message: str):

self.wfile.write(response.encode('utf-8'))


def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str):
if sort_param == parameters.SORT_DISTANCE:
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
elif sort_param == parameters.SORT_ID:
results = self.collector.get_sorted_by_id(search_strings, n_results, max_token_count)
else: # Default is dist
else: # Default is dist
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)

return {
"results": results
}


def do_GET(self):
self._send_404_error()


def do_POST(self):
try:
content_length = int(self.headers['Content-Length'])
@@ -107,7 +99,7 @@ def do_POST(self):
if corpus is None:
self._send_412_error("Missing parameter 'corpus'")
return

clear_before_adding = body.get('clear_before_adding', False)
metadata = body.get('metadata')
process_and_add_to_collector(corpus, self.collector, clear_before_adding, metadata)
@@ -118,7 +110,7 @@ def do_POST(self):
if corpus is None:
self._send_412_error("Missing parameter 'metadata'")
return

self.collector.delete(ids_to_delete=None, where=metadata)
self._send_200_response("Data successfully deleted")

@@ -127,15 +119,15 @@ def do_POST(self):
if search_strings is None:
self._send_412_error("Missing parameter 'search_strings'")
return

n_results = body.get('n_results')
if n_results is None:
n_results = parameters.get_chunk_count()

max_token_count = body.get('max_token_count')
if max_token_count is None:
max_token_count = parameters.get_max_token_count()

sort_param = query_params.get('sort', ['distance'])[0]

results = self._handle_get(search_strings, n_results, max_token_count, sort_param)
@@ -146,7 +138,6 @@ def do_POST(self):
except Exception as e:
self._send_400_error(str(e))


def do_DELETE(self):
try:
parsed_path = urlparse(self.path)
@@ -161,12 +152,10 @@ def do_DELETE(self):
except Exception as e:
self._send_400_error(str(e))


def do_OPTIONS(self):
self.send_response(200)
self.end_headers()


def end_headers(self):
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', '*')
@@ -197,11 +186,11 @@ def start_server(self, port: int):

def stop_server(self):
if self.server is not None:
logger.info(f'Stopping chromaDB API.')
logger.info('Stopping chromaDB API.')
self.server.shutdown()
self.server.server_close()
self.server = None
self.is_running = False

def is_server_running(self):
return self.is_running
return self.is_running
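
A hypothetical client-side sketch of how this API is exercised: the JSON body keys and the 'sort' query parameter are taken from the handlers above, but the route paths and port below are placeholders, since the dispatch on self.path is not visible in these hunks.

import requests

BASE = "http://localhost:5002"  # placeholder; use whatever port start_server() was given

# Index a corpus; the keys mirror what do_POST reads from the JSON body.
requests.post(f"{BASE}/api/add", json={  # route path is a placeholder
    "corpus": "A long document to be chunked and stored.",
    "clear_before_adding": True,
})

# Retrieve chunks; 'sort' is read from the query string and defaults to 'distance'.
resp = requests.post(f"{BASE}/api/get?sort=distance", json={  # route path is a placeholder
    "search_strings": ["long document"],
    "n_results": 5,
    "max_token_count": 512,
})
print(resp.json()["results"])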
14 changes: 7 additions & 7 deletions extensions/superboogav2/benchmark.py
@@ -9,23 +9,23 @@
import datetime
import json
import os

from pathlib import Path

from .data_processor import process_and_add_to_collector, preprocess_text
from .data_processor import preprocess_text, process_and_add_to_collector
from .parameters import get_chunk_count, get_max_token_count
from .utils import create_metadata_source


def benchmark(config_path, collector):
# Get the current system date
sysdate = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"benchmark_{sysdate}.txt"

# Open the log file in append mode
with open(filename, 'a') as log:
with open(config_path, 'r') as f:
data = json.load(f)

total_points = 0
max_points = 0

@@ -45,7 +45,7 @@ def benchmark(config_path, collector):
for question_group in item["questions"]:
question_variants = question_group["question_variants"]
criteria = question_group["criteria"]

for q in question_variants:
max_points += len(criteria)
processed_text = preprocess_text(q)
@@ -54,7 +54,7 @@ def benchmark(config_path, collector):
results = collector.get_sorted_by_dist(processed_text, n_results=get_chunk_count(), max_token_count=get_max_token_count())

points = 0

for c in criteria:
for p in results:
if c in p:
@@ -69,4 +69,4 @@ def benchmark(config_path, collector):

print(f'##Total points:\n\n{total_points}/{max_points}', file=log)

return total_points, max_points
return total_points, max_points
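
The shape of the benchmark config file is only partly visible in these hunks. A guess at its structure, written as a Python literal: the "questions", "question_variants" and "criteria" keys appear in the loop above, while the corpus key name is assumed.

# Hypothetical benchmark config. One point is awarded per criterion substring
# found in the retrieved chunks, so each question variant can contribute up to
# len(criteria) points toward max_points.
example_config = [
    {
        "text": "Corpus that is preprocessed and added to the collector.",  # assumed key name
        "questions": [
            {
                "question_variants": [
                    "What does superbooga store?",
                    "Where are the chunks kept?",
                ],
                "criteria": ["chromadb", "vector"],
            },
        ],
    },
]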
31 changes: 16 additions & 15 deletions extensions/superboogav2/chat_handler.py
@@ -4,34 +4,35 @@
import re

import extensions.superboogav2.parameters as parameters

from extensions.superboogav2.utils import (
create_context_text,
create_metadata_source
)
from modules import chat, shared
from modules.text_generation import get_encoded_length
from modules.logging_colors import logger
from modules.chat import load_character_memoized
from extensions.superboogav2.utils import create_context_text, create_metadata_source
from modules.logging_colors import logger
from modules.text_generation import get_encoded_length

from .data_processor import process_and_add_to_collector
from .chromadb import ChromaCollector

from .data_processor import process_and_add_to_collector

CHAT_METADATA = create_metadata_source('automatic-chat-insert')


def _remove_tag_if_necessary(user_input: str):
if not parameters.get_is_manual():
return user_input

return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', user_input)


def _should_query(input: str):
if not parameters.get_is_manual():
return True

if re.search(r'^\s*!c|!c\s*$', input, re.MULTILINE):
return True

return False


@@ -69,7 +70,7 @@ def _concatinate_history(history: dict, state: dict):
if len(exchange) >= 2:
full_history_text += _format_single_exchange(bot_name, exchange[1])

return full_history_text[:-1] # Remove the last new line.
return full_history_text[:-1] # Remove the last new line.


def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
@@ -82,20 +83,20 @@ def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
for i, messages in enumerate(reversed(history['internal'])):
for j, message in enumerate(reversed(messages)):
num_message_tokens = get_encoded_length(_format_single_exchange(names[j], message))

# TODO: This is an extremely naive solution. A more robust implementation must be made.
if history_tokens + num_context_tokens <= max_len:
# This message can be replaced
replace_position = (i, j)

history_tokens += num_message_tokens

if replace_position is None:
logger.warn("The provided context_text is too long to replace any message in the history.")
else:
# replace the message at replace_position with context_text
i, j = replace_position
history['internal'][-i-1][-j-1] = context_text
history['internal'][-i - 1][-j - 1] = context_text


def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs):
Expand All @@ -120,5 +121,5 @@ def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector
user_input = create_context_text(results) + user_input
elif parameters.get_injection_strategy() == parameters.HIJACK_LAST_IN_CONTEXT:
_hijack_last(create_context_text(results), kwargs['history'], state['truncation_length'], state)

return chat.generate_chat_prompt(user_input, state, **kwargs)
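
As a standalone illustration of the manual-trigger logic above: in manual mode, a leading or trailing "!c" marks a message that should query the collector, and the tag is stripped before the prompt is built. The regexes below are copied from the diff; the wrapper functions exist only for demonstration.

import re

def should_query(text: str) -> bool:
    # Same pattern as _should_query: "!c" at the start or end of any line triggers a lookup.
    return bool(re.search(r'^\s*!c|!c\s*$', text, re.MULTILINE))

def remove_tag(text: str) -> str:
    # Same pattern as _remove_tag_if_necessary: strip a leading or trailing "!c".
    return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', text)

print(should_query("!c what did the document say about chromadb?"))  # True
print(remove_tag("!c what did the document say about chromadb?"))    # "what did the document say about chromadb?"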