
Add the ability to use Ollama Embeddings models #10

Merged: 6 commits, May 21, 2024
Changes from 4 commits
4 changes: 3 additions & 1 deletion .pylintrc
@@ -70,7 +70,9 @@ disable=raw-checker-failed,
         use-symbolic-message-instead,
         # ACMI entries
         missing-docstring,
-        no-name-in-module
+        no-name-in-module,
+        duplicate-code,
+        import-error

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
13 changes: 9 additions & 4 deletions api/server.py
@@ -21,6 +21,7 @@
 from langchain.schema import format_document
 from langchain.schema.output_parser import StrOutputParser
 from langchain.schema.runnable import RunnableMap, RunnablePassthrough
+from langchain_community.embeddings import OllamaEmbeddings
 from langchain_community.llms import Ollama
 from langchain_community.vectorstores import Chroma
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
@@ -30,7 +31,8 @@
 DATABASE_PATH = os.getenv('DATABASE_PATH', '')
 COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'works')
 PERSIST_DIRECTORY = os.getenv('PERSIST_DIRECTORY', 'works_db')
-MODEL = os.getenv('MODEL', 'gpt-4-turbo-2024-04-09')
+MODEL = os.getenv('MODEL', 'gpt-4o')
+EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL', 'nomic-embed-text')
 LLM_BASE_URL = os.getenv('LLM_BASE_URL', None)
 LANGCHAIN_TRACING_V2 = os.getenv('LANGCHAIN_TRACING_V2', 'false').lower() == 'true'

@@ -72,18 +74,21 @@ def _format_chat_history(chat_history: List[Tuple]) -> str:
     buffer = ""
     for dialogue_turn in chat_history:
         human = 'Human: ' + dialogue_turn[0]
-        ai = 'Assistant: ' + dialogue_turn[1]
-        buffer += '\n' + '\n'.join([human, ai])
+        assistant = 'Assistant: ' + dialogue_turn[1]
+        buffer += '\n' + '\n'.join([human, assistant])
     return buffer


-embeddings = OpenAIEmbeddings()
 if MODEL.startswith('gpt'):
     llm = ChatOpenAI(temperature=0, model=MODEL)
+    embeddings = OpenAIEmbeddings()
 else:
     llm = Ollama(model=MODEL)
+    embeddings = OllamaEmbeddings(model=EMBEDDINGS_MODEL)
 if LLM_BASE_URL:
     llm.base_url = LLM_BASE_URL
+    embeddings.base_url = LLM_BASE_URL

 docsearch = Chroma(
     collection_name=COLLECTION_NAME,
     embedding_function=embeddings,
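A note on how these pieces combine: the embeddings client is now provider-matched to the LLM, and the LLM_BASE_URL override applies to both. Below is a minimal sketch of querying the persisted store via the Ollama branch, assuming a local Ollama server on its usual default address and a 'works' collection that has already been indexed; the query string and printed fields are illustrative only.

import os

from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma

# Assumption: Ollama is serving locally and the works_db directory was
# already populated (for example by running chat.py with REBUILD=true).
embeddings = OllamaEmbeddings(model=os.getenv('EMBEDDINGS_MODEL', 'nomic-embed-text'))
embeddings.base_url = os.getenv('LLM_BASE_URL', 'http://localhost:11434')

docsearch = Chroma(
    collection_name=os.getenv('COLLECTION_NAME', 'works'),
    embedding_function=embeddings,
    persist_directory=os.getenv('PERSIST_DIRECTORY', 'works_db'),
)

# Queries must be embedded with the same model the index was built with;
# switching EMBEDDINGS_MODEL between indexing and querying degrades results.
for document in docsearch.similarity_search('films about time travel', k=3):
    print(document.metadata.get('source'), document.page_content[:80])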
96 changes: 54 additions & 42 deletions chat.py
@@ -19,21 +19,34 @@
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts import PromptTemplate
 from langchain_community.document_loaders import JSONLoader
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.llms import Ollama
 from langchain_community.vectorstores import Chroma
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings

 DATABASE_PATH = os.getenv('DATABASE_PATH', '')
 COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'works')
 PERSIST_DIRECTORY = os.getenv('PERSIST_DIRECTORY', 'works_db')
-MODEL = os.getenv('MODEL', 'gpt-4-turbo-2024-04-09')
+MODEL = os.getenv('MODEL', 'gpt-4o')
+EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL', 'nomic-embed-text')
+LLM_BASE_URL = os.getenv('LLM_BASE_URL', None)
 REBUILD = os.getenv('REBUILD', 'false').lower() == 'true'
 HISTORY = os.getenv('HISTORY', 'true').lower() == 'true'
 ALL = os.getenv('ALL', 'false').lower() == 'true'

 # Set true if you'd like langchain tracing via LangSmith https://smith.langchain.com
 os.environ['LANGCHAIN_TRACING_V2'] = 'false'

-embeddings = OpenAIEmbeddings()
+if MODEL.startswith('gpt'):
+    llm = ChatOpenAI(temperature=0, model=MODEL)
+    embeddings = OpenAIEmbeddings()
+else:
+    llm = Ollama(model=MODEL)
+    embeddings = OllamaEmbeddings(model=EMBEDDINGS_MODEL)
+if LLM_BASE_URL:
+    llm.base_url = LLM_BASE_URL
+    embeddings.base_url = LLM_BASE_URL

 docsearch = Chroma(
     collection_name=COLLECTION_NAME,
     embedding_function=embeddings,
@@ -45,40 +58,47 @@
     'results': [],
 }
 params = {'page': ''}
-if ALL:
-    print('Loading all of the works from the ACMI Public API')
-    while True:
-        page_data = requests.get(
-            'https://api.acmi.net.au/works/',
-            params=params,
-            timeout=10,
-        ).json()
-        json_data['results'].extend(page_data['results'])
-        if not page_data.get('next'):
-            break
-        params['page'] = furl(page_data.get('next')).args.get('page')
-        if len(json_data['results']) % 1000 == 0:
-            print(f'Downloaded {len(json_data["results"])}...')
+TMP_FILE_PATH = 'data.json'
+
+if os.path.isfile(TMP_FILE_PATH):
+    print('Loading works from the ACMI Public API data.json file you have already created...')
+    with open(TMP_FILE_PATH, 'r') as tmp_file:
+        json_data = json.load(tmp_file)
 else:
-    print('Loading the first ten pages of works from the ACMI Public API')
-    PAGES = 10
-    json_data = {
-        'results': [],
-    }
-    for index in range(1, (PAGES + 1)):
-        page_data = requests.get(
-            'https://api.acmi.net.au/works/',
-            params=params,
-            timeout=10,
-        )
-        json_data['results'].extend(page_data.json()['results'])
-        print(f'Downloaded {page_data.request.url}')
-        params['page'] = furl(page_data.json().get('next')).args.get('page')
-print(f'Finished downloading {len(json_data["results"])} works.')
+    if ALL:
+        print('Loading all of the works from the ACMI Public API')
+        while True:
+            page_data = requests.get(
+                'https://api.acmi.net.au/works/',
+                params=params,
+                timeout=10,
+            ).json()
+            json_data['results'].extend(page_data['results'])
+            if not page_data.get('next'):
+                break
+            params['page'] = furl(page_data.get('next')).args.get('page')
+            if len(json_data['results']) % 1000 == 0:
+                print(f'Downloaded {len(json_data["results"])}...')
+    else:
+        print('Loading the first ten pages of works from the ACMI Public API')
+        PAGES = 10
+        json_data = {
+            'results': [],
+        }
+        for index in range(1, (PAGES + 1)):
+            page_data = requests.get(
+                'https://api.acmi.net.au/works/',
+                params=params,
+                timeout=10,
+            )
+            json_data['results'].extend(page_data.json()['results'])
+            print(f'Downloaded {page_data.request.url}')
+            params['page'] = furl(page_data.json().get('next')).args.get('page')
+    print(f'Finished downloading {len(json_data["results"])} works.')
+
+    with open(TMP_FILE_PATH, 'w', encoding='utf-8') as json_file:
+        json.dump(json_data, json_file)

-TMP_FILE_PATH = 'data.json'
-with open(TMP_FILE_PATH, 'w', encoding='utf-8') as json_file:
-    json.dump(json_data, json_file)
 json_loader = JSONLoader(
     file_path=TMP_FILE_PATH,
     jq_schema='.results[]',
@@ -103,14 +123,6 @@ def chunks(input_list, number_per_chunk):
     print(f'Added {len(sublist)} items to the database... total {(i + 1) * len(sublist)}')
 print(f'Finished adding {len(data)} items to the database')

-docsearch = Chroma.from_documents(
-    data,
-    embeddings,
-    collection_name=COLLECTION_NAME,
-    persist_directory=PERSIST_DIRECTORY,
-)
-
-llm = ChatOpenAI(temperature=0, model=MODEL)
 qa_chain = create_qa_with_sources_chain(llm)
 doc_prompt = PromptTemplate(
     template='Content: {page_content}\nSource: {source}',
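The hunk above drops the one-shot Chroma.from_documents() call in favour of batched inserts driven by a chunks() helper whose definition is collapsed in this view. A hypothetical minimal form of such a helper is sketched below (an assumption, not the repository's exact code); batching keeps each embedding request small and yields incremental progress output.

def chunks(input_list, number_per_chunk):
    """Hypothetical stand-in: yield successive fixed-size slices of input_list."""
    for start in range(0, len(input_list), number_per_chunk):
        yield input_list[start:start + number_per_chunk]


# Standalone usage; in the diff above each sublist would instead be passed
# to docsearch.add_documents(sublist) before printing progress.
data = list(range(300))
for i, sublist in enumerate(chunks(data, 100)):
    print(f'Added {len(sublist)} items to the database... total {(i + 1) * len(sublist)}')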
1 change: 1 addition & 0 deletions config.tmpl.env
@@ -2,3 +2,4 @@ OPENAI_API_KEY=your-openai-api-key
 LANGCHAIN_API_KEY=your-langchain-api-key
 MODEL=
 LLM_BASE_URL=
+EMBEDDINGS_MODEL=
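With the new variable in place, a fully local configuration might look like the following (illustrative values: any MODEL name not starting with 'gpt' selects the Ollama branch, 'llama3' is only an example model tag, and http://localhost:11434 is the address a default Ollama install typically listens on):

OPENAI_API_KEY=your-openai-api-key
LANGCHAIN_API_KEY=your-langchain-api-key
MODEL=llama3
LLM_BASE_URL=http://localhost:11434
EMBEDDINGS_MODEL=nomic-embed-text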