Skip to content

Commit

Permalink
Merge pull request #10 from ACMILabs/add/9-ollama-embeddings
Browse files: browse the repository at this point in the history
#9 Add the ability to use Ollama Embeddings models
  • Loading branch information
sighmon authored May 21, 2024
2 parents 948738b + 0223607 commit 679dd3c
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 47 deletions.
4 changes: 3 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ disable=raw-checker-failed,
use-symbolic-message-instead,
# ACMI entries
missing-docstring,
no-name-in-module
no-name-in-module,
duplicate-code,
import-error

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
13 changes: 9 additions & 4 deletions api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from langchain.schema import format_document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableMap, RunnablePassthrough
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
Expand All @@ -30,7 +31,8 @@
DATABASE_PATH = os.getenv('DATABASE_PATH', '')
COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'works')
PERSIST_DIRECTORY = os.getenv('PERSIST_DIRECTORY', 'works_db')
MODEL = os.getenv('MODEL', 'gpt-4-turbo-2024-04-09')
MODEL = os.getenv('MODEL', 'gpt-4o')
EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL', None)
LLM_BASE_URL = os.getenv('LLM_BASE_URL', None)
LANGCHAIN_TRACING_V2 = os.getenv('LANGCHAIN_TRACING_V2', 'false').lower() == 'true'

Expand Down Expand Up @@ -72,18 +74,21 @@ def _format_chat_history(chat_history: List[Tuple]) -> str:
buffer = ""
for dialogue_turn in chat_history:
human = 'Human: ' + dialogue_turn[0]
ai = 'Assistant: ' + dialogue_turn[1]
buffer += '\n' + '\n'.join([human, ai])
assistant = 'Assistant: ' + dialogue_turn[1]
buffer += '\n' + '\n'.join([human, assistant])
return buffer


embeddings = OpenAIEmbeddings()
if MODEL.startswith('gpt'):
llm = ChatOpenAI(temperature=0, model=MODEL)
embeddings = OpenAIEmbeddings(model=EMBEDDINGS_MODEL or 'text-embedding-ada-002')
else:
llm = Ollama(model=MODEL)
embeddings = OllamaEmbeddings(model=EMBEDDINGS_MODEL or MODEL)
if LLM_BASE_URL:
llm.base_url = LLM_BASE_URL
embeddings.base_url = LLM_BASE_URL

docsearch = Chroma(
collection_name=COLLECTION_NAME,
embedding_function=embeddings,
Expand Down
96 changes: 54 additions & 42 deletions chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,34 @@
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import JSONLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

DATABASE_PATH = os.getenv('DATABASE_PATH', '')
COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'works')
PERSIST_DIRECTORY = os.getenv('PERSIST_DIRECTORY', 'works_db')
MODEL = os.getenv('MODEL', 'gpt-4-turbo-2024-04-09')
MODEL = os.getenv('MODEL', 'gpt-4o')
EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL', None)
LLM_BASE_URL = os.getenv('LLM_BASE_URL', None)
REBUILD = os.getenv('REBUILD', 'false').lower() == 'true'
HISTORY = os.getenv('HISTORY', 'true').lower() == 'true'
ALL = os.getenv('ALL', 'false').lower() == 'true'

# Set this to 'true' to enable LangChain tracing via LangSmith (https://smith.langchain.com)
os.environ['LANGCHAIN_TRACING_V2'] = 'false'

embeddings = OpenAIEmbeddings()
if MODEL.startswith('gpt'):
llm = ChatOpenAI(temperature=0, model=MODEL)
embeddings = OpenAIEmbeddings(model=EMBEDDINGS_MODEL or 'text-embedding-ada-002')
else:
llm = Ollama(model=MODEL)
embeddings = OllamaEmbeddings(model=EMBEDDINGS_MODEL or MODEL)
if LLM_BASE_URL:
llm.base_url = LLM_BASE_URL
embeddings.base_url = LLM_BASE_URL

docsearch = Chroma(
collection_name=COLLECTION_NAME,
embedding_function=embeddings,
Expand All @@ -45,40 +58,47 @@
'results': [],
}
params = {'page': ''}
if ALL:
print('Loading all of the works from the ACMI Public API')
while True:
page_data = requests.get(
'https://api.acmi.net.au/works/',
params=params,
timeout=10,
).json()
json_data['results'].extend(page_data['results'])
if not page_data.get('next'):
break
params['page'] = furl(page_data.get('next')).args.get('page')
if len(json_data['results']) % 1000 == 0:
print(f'Downloaded {len(json_data["results"])}...')
TMP_FILE_PATH = 'data.json'

if os.path.isfile(TMP_FILE_PATH):
print('Loading works from the ACMI Public API data.json file you have already created...')
with open(TMP_FILE_PATH, 'r', encoding='utf-8') as tmp_file:
json_data = json.load(tmp_file)
else:
print('Loading the first ten pages of works from the ACMI Public API')
PAGES = 10
json_data = {
'results': [],
}
for index in range(1, (PAGES + 1)):
page_data = requests.get(
'https://api.acmi.net.au/works/',
params=params,
timeout=10,
)
json_data['results'].extend(page_data.json()['results'])
print(f'Downloaded {page_data.request.url}')
params['page'] = furl(page_data.json().get('next')).args.get('page')
print(f'Finished downloading {len(json_data["results"])} works.')
if ALL:
print('Loading all of the works from the ACMI Public API')
while True:
page_data = requests.get(
'https://api.acmi.net.au/works/',
params=params,
timeout=10,
).json()
json_data['results'].extend(page_data['results'])
if not page_data.get('next'):
break
params['page'] = furl(page_data.get('next')).args.get('page')
if len(json_data['results']) % 1000 == 0:
print(f'Downloaded {len(json_data["results"])}...')
else:
print('Loading the first ten pages of works from the ACMI Public API')
PAGES = 10
json_data = {
'results': [],
}
for index in range(1, (PAGES + 1)):
page_data = requests.get(
'https://api.acmi.net.au/works/',
params=params,
timeout=10,
)
json_data['results'].extend(page_data.json()['results'])
print(f'Downloaded {page_data.request.url}')
params['page'] = furl(page_data.json().get('next')).args.get('page')
print(f'Finished downloading {len(json_data["results"])} works.')

with open(TMP_FILE_PATH, 'w', encoding='utf-8') as json_file:
json.dump(json_data, json_file)

TMP_FILE_PATH = 'data.json'
with open(TMP_FILE_PATH, 'w', encoding='utf-8') as json_file:
json.dump(json_data, json_file)
json_loader = JSONLoader(
file_path=TMP_FILE_PATH,
jq_schema='.results[]',
Expand All @@ -103,14 +123,6 @@ def chunks(input_list, number_per_chunk):
print(f'Added {len(sublist)} items to the database... total {(i + 1) * len(sublist)}')
print(f'Finished adding {len(data)} items to the database')

docsearch = Chroma.from_documents(
data,
embeddings,
collection_name=COLLECTION_NAME,
persist_directory=PERSIST_DIRECTORY,
)

llm = ChatOpenAI(temperature=0, model=MODEL)
qa_chain = create_qa_with_sources_chain(llm)
doc_prompt = PromptTemplate(
template='Content: {page_content}\nSource: {source}',
Expand Down
1 change: 1 addition & 0 deletions config.tmpl.env
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ OPENAI_API_KEY=your-openai-api-key
LANGCHAIN_API_KEY=your-langchain-api-key
MODEL=
LLM_BASE_URL=
EMBEDDINGS_MODEL=

0 comments on commit 679dd3c

Please sign in to comment.