diff --git a/.pylintrc b/.pylintrc
index da8e250..a1cfee7 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -70,7 +70,9 @@ disable=raw-checker-failed,
         use-symbolic-message-instead,
         # ACMI entries
         missing-docstring,
-        no-name-in-module
+        no-name-in-module,
+        duplicate-code,
+        import-error
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
diff --git a/api/server.py b/api/server.py
index 17716ea..02c671f 100644
--- a/api/server.py
+++ b/api/server.py
@@ -21,6 +21,7 @@
 from langchain.schema import format_document
 from langchain.schema.output_parser import StrOutputParser
 from langchain.schema.runnable import RunnableMap, RunnablePassthrough
+from langchain_community.embeddings import OllamaEmbeddings
 from langchain_community.llms import Ollama
 from langchain_community.vectorstores import Chroma
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
@@ -30,7 +31,8 @@
 DATABASE_PATH = os.getenv('DATABASE_PATH', '')
 COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'works')
 PERSIST_DIRECTORY = os.getenv('PERSIST_DIRECTORY', 'works_db')
-MODEL = os.getenv('MODEL', 'gpt-4-turbo-2024-04-09')
+MODEL = os.getenv('MODEL', 'gpt-4o')
+EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL', None)
 LLM_BASE_URL = os.getenv('LLM_BASE_URL', None)
 
 LANGCHAIN_TRACING_V2 = os.getenv('LANGCHAIN_TRACING_V2', 'false').lower() == 'true'
@@ -72,18 +74,21 @@ def _format_chat_history(chat_history: List[Tuple]) -> str:
     buffer = ""
     for dialogue_turn in chat_history:
         human = 'Human: ' + dialogue_turn[0]
-        ai = 'Assistant: ' + dialogue_turn[1]
-        buffer += '\n' + '\n'.join([human, ai])
+        assistant = 'Assistant: ' + dialogue_turn[1]
+        buffer += '\n' + '\n'.join([human, assistant])
     return buffer
 
 
-embeddings = OpenAIEmbeddings()
 if MODEL.startswith('gpt'):
     llm = ChatOpenAI(temperature=0, model=MODEL)
+    embeddings = OpenAIEmbeddings(model=EMBEDDINGS_MODEL or 'text-embedding-ada-002')
 else:
     llm = Ollama(model=MODEL)
+    embeddings = OllamaEmbeddings(model=EMBEDDINGS_MODEL or MODEL)
 if LLM_BASE_URL:
     llm.base_url = LLM_BASE_URL
+    embeddings.base_url = LLM_BASE_URL
+
 docsearch = Chroma(
     collection_name=COLLECTION_NAME,
     embedding_function=embeddings,
diff --git a/chat.py b/chat.py
index dfde666..15cc3dd 100644
--- a/chat.py
+++ b/chat.py
@@ -19,13 +19,17 @@
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts import PromptTemplate
 from langchain_community.document_loaders import JSONLoader
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.llms import Ollama
 from langchain_community.vectorstores import Chroma
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 
 DATABASE_PATH = os.getenv('DATABASE_PATH', '')
 COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'works')
 PERSIST_DIRECTORY = os.getenv('PERSIST_DIRECTORY', 'works_db')
-MODEL = os.getenv('MODEL', 'gpt-4-turbo-2024-04-09')
+MODEL = os.getenv('MODEL', 'gpt-4o')
+EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL', None)
+LLM_BASE_URL = os.getenv('LLM_BASE_URL', None)
 REBUILD = os.getenv('REBUILD', 'false').lower() == 'true'
 HISTORY = os.getenv('HISTORY', 'true').lower() == 'true'
 ALL = os.getenv('ALL', 'false').lower() == 'true'
@@ -33,7 +37,16 @@
 # Set true if you'd like langchain tracing via LangSmith https://smith.langchain.com
 os.environ['LANGCHAIN_TRACING_V2'] = 'false'
 
-embeddings = OpenAIEmbeddings()
+if MODEL.startswith('gpt'):
+    llm = ChatOpenAI(temperature=0, model=MODEL)
+    embeddings = OpenAIEmbeddings(model=EMBEDDINGS_MODEL or 'text-embedding-ada-002')
+else:
+    llm = Ollama(model=MODEL)
+    embeddings = OllamaEmbeddings(model=EMBEDDINGS_MODEL or MODEL)
+if LLM_BASE_URL:
+    llm.base_url = LLM_BASE_URL
+    embeddings.base_url = LLM_BASE_URL
+
 docsearch = Chroma(
     collection_name=COLLECTION_NAME,
     embedding_function=embeddings,
@@ -45,40 +58,47 @@
         'results': [],
     }
     params = {'page': ''}
-    if ALL:
-        print('Loading all of the works from the ACMI Public API')
-        while True:
-            page_data = requests.get(
-                'https://api.acmi.net.au/works/',
-                params=params,
-                timeout=10,
-            ).json()
-            json_data['results'].extend(page_data['results'])
-            if not page_data.get('next'):
-                break
-            params['page'] = furl(page_data.get('next')).args.get('page')
-            if len(json_data['results']) % 1000 == 0:
-                print(f'Downloaded {len(json_data["results"])}...')
+    TMP_FILE_PATH = 'data.json'
+
+    if os.path.isfile(TMP_FILE_PATH):
+        print('Loading works from the ACMI Public API data.json file you have already created...')
+        with open(TMP_FILE_PATH, 'r', encoding='utf-8') as tmp_file:
+            json_data = json.load(tmp_file)
     else:
-        print('Loading the first ten pages of works from the ACMI Public API')
-        PAGES = 10
-        json_data = {
-            'results': [],
-        }
-        for index in range(1, (PAGES + 1)):
-            page_data = requests.get(
-                'https://api.acmi.net.au/works/',
-                params=params,
-                timeout=10,
-            )
-            json_data['results'].extend(page_data.json()['results'])
-            print(f'Downloaded {page_data.request.url}')
-            params['page'] = furl(page_data.json().get('next')).args.get('page')
-        print(f'Finished downloading {len(json_data["results"])} works.')
+        if ALL:
+            print('Loading all of the works from the ACMI Public API')
+            while True:
+                page_data = requests.get(
+                    'https://api.acmi.net.au/works/',
+                    params=params,
+                    timeout=10,
+                ).json()
+                json_data['results'].extend(page_data['results'])
+                if not page_data.get('next'):
+                    break
+                params['page'] = furl(page_data.get('next')).args.get('page')
+                if len(json_data['results']) % 1000 == 0:
+                    print(f'Downloaded {len(json_data["results"])}...')
+        else:
+            print('Loading the first ten pages of works from the ACMI Public API')
+            PAGES = 10
+            json_data = {
+                'results': [],
+            }
+            for index in range(1, (PAGES + 1)):
+                page_data = requests.get(
+                    'https://api.acmi.net.au/works/',
+                    params=params,
+                    timeout=10,
+                )
+                json_data['results'].extend(page_data.json()['results'])
+                print(f'Downloaded {page_data.request.url}')
+                params['page'] = furl(page_data.json().get('next')).args.get('page')
+            print(f'Finished downloading {len(json_data["results"])} works.')
+
+        with open(TMP_FILE_PATH, 'w', encoding='utf-8') as json_file:
+            json.dump(json_data, json_file)
 
-    TMP_FILE_PATH = 'data.json'
-    with open(TMP_FILE_PATH, 'w', encoding='utf-8') as json_file:
-        json.dump(json_data, json_file)
     json_loader = JSONLoader(
         file_path=TMP_FILE_PATH,
         jq_schema='.results[]',
@@ -103,14 +123,6 @@ def chunks(input_list, number_per_chunk):
             print(f'Added {len(sublist)} items to the database... total {(i + 1) * len(sublist)}')
         print(f'Finished adding {len(data)} items to the database')
 
-        docsearch = Chroma.from_documents(
-            data,
-            embeddings,
-            collection_name=COLLECTION_NAME,
-            persist_directory=PERSIST_DIRECTORY,
-        )
-
-llm = ChatOpenAI(temperature=0, model=MODEL)
 qa_chain = create_qa_with_sources_chain(llm)
 doc_prompt = PromptTemplate(
     template='Content: {page_content}\nSource: {source}',
diff --git a/config.tmpl.env b/config.tmpl.env
index 24298fc..b7ecaaf 100644
--- a/config.tmpl.env
+++ b/config.tmpl.env
@@ -2,3 +2,4 @@ OPENAI_API_KEY=your-openai-api-key
 LANGCHAIN_API_KEY=your-langchain-api-key
 MODEL=
 LLM_BASE_URL=
+EMBEDDINGS_MODEL=
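
For anyone trying this change locally, a minimal config.env sketch for the new Ollama branch. The model names below are illustrative assumptions, not values from this diff; http://localhost:11434 is Ollama's default local address. Any MODEL value that does not start with 'gpt' selects the Ollama code path, where EMBEDDINGS_MODEL falls back to MODEL if unset (on the OpenAI path it falls back to text-embedding-ada-002):

    # Hypothetical model names; use any models you have pulled into Ollama.
    MODEL=llama3
    EMBEDDINGS_MODEL=nomic-embed-text
    # Ollama's default address; when set, applied to both llm and embeddings.
    LLM_BASE_URL=http://localhost:11434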