Skip to content

Commit

Permalink
Replace postgresql with sql in langchain
Browse files Browse the repository at this point in the history
Signed-off-by: Jael Gu <[email protected]>
  • Loading branch information
jaelgu committed Oct 25, 2023
1 parent 37db7d6 commit 2d4fc9c
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 32 deletions.
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,22 +159,21 @@ The option using LangChain employs the use of [Agent](https://python.langchain.c

- Vector Store: You need to prepare the service of vector database in advance. For example, you can refer to [Milvus Documents](https://milvus.io/docs) or [Zilliz Cloud](https://zilliz.com/doc/quick_start) to learn about how to start a Milvus service.
- Scalar Store (Optional): This is optional, only work when `USE_SCALAR` is true in [configuration](config.py). If this is enabled (i.e. USE_SCALAR=True), the default scalar store will use [Elastic](https://www.elastic.co/). In this case, you need to prepare the Elasticsearch service in advance.
- Memory Store: You need to prepare the database for memory storage as well. By default, LangChain mode supports [Postgresql](https://www.postgresql.org/) and Towhee mode allows interaction with any database supported by [SQLAlchemy 2.0](https://docs.sqlalchemy.org/en/20/dialects/).
- Memory Store: You need to prepare the database for memory storage as well. By default, both LangChain and Towhee mode allow interaction with any database supported by [SQLAlchemy 2.0](https://docs.sqlalchemy.org/en/20/dialects/).

The system will use default store configs.
To set up your special connections for each database, you can also export environment variables instead of modifying the configuration file.

For the Vector Store, set **MILVUS_URI**:
For the Vector Store, set **ZILLIZ_URI**:
```shell
$ export MILVUS_URI=https://localhost:19530
$ export ZILLIZ_URI=your_zilliz_cloud_endpoint
$ export ZILLIZ_TOKEN=your_zilliz_cloud_api_key # skip this if using Milvus instance
```

For the Memory Store, set **SQL_URI**:
```shell
$ export SQL_URI={database_type}://{user}:{password}@{host}/{database_name}
```
> LangChain mode only supports [Postgresql](https://www.postgresql.org/) as database type.

```

<details>
<summary>By default, scalar store (elastic) is disabled.
Expand Down
8 changes: 3 additions & 5 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,8 @@
# Vector db configs
VECTORDB_CONFIG = {
'connection_args': {
'uri': os.getenv('MILVUS_URI', 'http://localhost:19530'),
'user': os.getenv('MILVUS_USER', ''),
'password': os.getenv('MILVUS_PASSWORD', ''),
'secure': True if os.getenv('MILVUS_SECURE', 'False').lower() == 'true' else False
'uri': os.getenv('ZILLIZ_URI', 'http://localhost:19530'),
'token': os.getenv('ZILLIZ_TOKEN')
},
'top_k': 5,
'threshold': 0,
Expand Down Expand Up @@ -104,7 +102,7 @@

# Memory db configs
MEMORYDB_CONFIG = {
'connect_str': os.getenv('SQL_URI', 'postgresql://postgres:postgres@localhost/chat_history')
'connect_str': os.getenv('SQL_URI', 'sqlite:///./sqlite.db')
}


Expand Down
2 changes: 1 addition & 1 deletion src_langchain/store/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ The default module also works with [Zilliz Cloud](https://zilliz.com) by setting
# Vector db configs
VECTORDB_CONFIG = {
'connection_args': {
'uri': os.getenv('MILVUS_URI', 'your_endpoint'),
'uri': os.getenv('ZILLIZ_URI', 'your_endpoint'),
'user': os.getenv('MILVUS_USER', 'user_name'),
'password': os.getenv('MILVUS_PASSWORD', 'password_goes_here'),
'secure': True
Expand Down
2 changes: 1 addition & 1 deletion src_langchain/store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional, List

from .vector_store.milvus import VectorStore, Embeddings
from .memory_store.pg import MemoryStore
from .memory_store.sql import MemoryStore

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

Expand Down
84 changes: 84 additions & 0 deletions src_langchain/store/memory_store/sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
import sys
from typing import List

from sqlalchemy import create_engine, inspect, MetaData, Table

from langchain.schema import HumanMessage, AIMessage
from langchain.memory import SQLChatMessageHistory, ConversationBufferMemory

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from config import MEMORYDB_CONFIG # pylint: disable=C0413


CONNECT_STR = MEMORYDB_CONFIG.get(
'connect_str', 'sqlite:///./sqlite.db')


class MemoryStore:
'''Memory database APIs: add_history, get_history'''

def __init__(self, table_name: str, session_id: str):
'''Initialize memory storage: e.g. history_db'''
self.table_name = table_name
self.session_id = session_id

self.history_db = SQLChatMessageHistory(
table_name=self.table_name,
session_id=self.session_id,
connection_string=CONNECT_STR,
)
self.memory = ConversationBufferMemory(
memory_key='chat_history',
chat_memory=self.history_db,
return_messages=True
)

def add_history(self, messages: List[dict]):
for qa in messages:
if 'question' in qa:
self.history_db.add_user_message(qa['question'])
if 'answer' in qa:
self.history_db.add_ai_message(qa['answer'])

def get_history(self):
history = self.history_db.messages
messages = []
for x in history:
if isinstance(x, HumanMessage):
if len(messages) > 0 and messages[-1][0] is None:
a = messages[-1][-1]
del messages[-1]
else:
a = None
messages.append((x.content, a))
if isinstance(x, AIMessage):
if len(messages) > 0 and messages[-1][-1] is None:
q = messages[-1][0]
del messages[-1]
else:
q = None
messages.append((q, x.content))
return messages

@staticmethod
def drop(table_name, connect_str: str = CONNECT_STR, session_id: str = None):
engine = create_engine(connect_str, echo=False)
existence = MemoryStore.check(table_name)

if existence:
project_table = Table(table_name, MetaData(),
autoload_with=engine, extend_existing=True)
if session_id and len(session_id) > 0:
query = project_table.delete().where(project_table.c.session_id == session_id)
with engine.connect() as conn:
conn.execute(query)
conn.commit()
else:
query = project_table.drop(engine)

@staticmethod
def check(table_name, connect_str: str = CONNECT_STR):
engine = create_engine(connect_str, echo=False)
return inspect(engine).has_table(table_name)
67 changes: 64 additions & 3 deletions src_langchain/store/vector_store/milvus.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import sys
import logging
from typing import Optional, Any, Tuple, List, Dict
from typing import Optional, Any, Tuple, List, Dict, Union
from uuid import uuid4

from langchain.vectorstores import Milvus
from langchain.embeddings.base import Embeddings
Expand All @@ -14,7 +15,7 @@

logger = logging.getLogger('vector_store')

CONNECTION_ARGS = VECTORDB_CONFIG.get('connection_args', {'host': 'localhost', 'port': 19530})
CONNECTION_ARGS = VECTORDB_CONFIG.get('connection_args', {'uri': 'http://localhost:19530'})
TOP_K = VECTORDB_CONFIG.get('top_k', 3)
INDEX_PARAMS = VECTORDB_CONFIG.get('index_params', None)
SEARCH_PARAMS = VECTORDB_CONFIG.get('search_params', None)
Expand All @@ -26,7 +27,12 @@ class VectorStore(Milvus):
'''

def __init__(self, table_name: str, embedding_func: Embeddings = None, connection_args: dict = CONNECTION_ARGS):
'''Initialize vector db'''
'''Initialize vector db
connection_args:
uri: milvus or zilliz uri
token: zilliz token
'''
# assert isinstance(
# embedding_func, Embeddings), 'Invalid embedding function. Only accept langchain.embeddings.'
self.embedding_func = embedding_func
Expand All @@ -40,6 +46,61 @@ def __init__(self, table_name: str, embedding_func: Embeddings = None, connectio
search_params=SEARCH_PARAMS
)

def _create_connection_alias(self, connection_args: dict) -> str:
"""Create the connection to the Milvus server."""
from pymilvus import MilvusException, connections # pylint: disable = C0415

# Grab the connection arguments that are used for checking existing connection
host: str = connection_args.get('host', None)
port: Union[str, int] = connection_args.get('port', None)
uri: str = connection_args.get('uri', None)
user = connection_args.get('user', None)
password = connection_args.get('password', None)
token = connection_args.get('token', None)


_connection_args = {} # pylint: disable = C0103
# Order of use is uri > host/port
if uri is not None:
_connection_args['uri'] = uri
given_address = uri.split('://')[1]
elif host is not None and port is not None:
_connection_args['host'] = host
_connection_args['port'] = port
given_address = f'{host}:{port}'
else:
logger.debug('Missing standard address type for reuse attempt')
given_address = None

# Order of use is token > user/password
if token is not None:
_connection_args['token'] = token
_connection_args['secure'] = True
elif user is not None and password is not None:
_connection_args['user'] = user
_connection_args['password'] = password
_connection_args['secure'] = True
else:
_connection_args['secure'] = False

# If a valid address was given, then check if a connection exists
if given_address is not None:
for con in connections.list_connections():
addr = connections.get_connection_addr(con[0])
if addr == given_address:
logger.debug('Using previous connection: %s', con[0])
return con[0]

# Generate a new connection if one doesn't exist
alias = uuid4().hex
try:
connections.connect(alias=alias, **_connection_args)
logger.debug('Created new connection using: %s', alias)
return alias
except MilvusException as e:
logger.error('Failed to create new connection using: %s', alias)
raise e

def similarity_search_with_score_by_vector(
self,
embedding: List[float],
Expand Down
51 changes: 36 additions & 15 deletions src_towhee/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,42 @@ def __init__(self,
self.rerank_config = rerank_config
self.chunk_size = chunk_size

self.milvus_uri = vectordb_config['connection_args']['uri']
self.milvus_host = self.milvus_uri.split('https://')[1].split(':')[0]
self.milvus_port = self.milvus_uri.split('https://')[1].split(':')[1]
milvus_user = vectordb_config['connection_args'].get('user')
self.milvus_secure = vectordb_config['connection_args'].get('secure', False)
self.milvus_user = None if milvus_user == '' else milvus_user
milvus_password = vectordb_config['connection_args'].get('password')
self.milvus_password = None if milvus_password == '' else milvus_password
self.milvus_topk = vectordb_config.get('top_k', 5)
self.milvus_threshold = vectordb_config.get('threshold', 0)
self.milvus_index_params = vectordb_config.get('index_params', {})

connections.connect(
host=self.milvus_host,
port=self.milvus_port,
user=self.milvus_user,
secure=self.milvus_secure,
password=self.milvus_password
)
self.connection_args = vectordb_config['connection_args']
for k, v in self.connection_args.items():
if v is None:
del self.connection_args[k]
if isinstance(v, str) and len(v) == 0:
del self.connection_args[k]

if 'uri' in self.connection_args:
self.milvus_uri = self.connection_args['uri']
self.milvus_host = self.connection_args.pop('host', None)
self.milvus_port = self.connection_args.pop('port', None)
elif 'host' in self.connection_args and 'port' in self.connection_args:
self.milvus_uri = None
self.milvus_host = self.connection_args.get('host')
self.milvus_port = self.connection_args.get('port')
else:
raise AttributeError('Invalid connection args for milvus.')

if 'token' in self.connection_args:
self.milvus_token = self.connection_args.get('token', None)
self.milvus_user = self.connection_args.pop('user', None)
self.milvus_password = self.connection_args.pop('password', None)
else:
self.milvus_token = None
self.milvus_user = self.connection_args.get('user')
self.milvus_password = self.connection_args.get('password')

if self.milvus_token or self.milvus_user:
self.connection_args['secure'] = True
self.milvus_secure = True

connections.connect(**self.connection_args)

if self.use_scalar:
from elasticsearch import Elasticsearch # pylint: disable=C0415
Expand Down Expand Up @@ -98,6 +115,8 @@ def search_config(self):


# Configure vector store (Milvus/Zilliz)
search_config.milvus_uri = self.milvus_uri
search_config.milvus_token = self.milvus_token
search_config.milvus_host = self.milvus_host
search_config.milvus_port = self.milvus_port
search_config.milvus_user = self.milvus_user
Expand Down Expand Up @@ -130,6 +149,8 @@ def insert_config(self):
insert_config.embedding_device = self.textencoder_config['device']

# Configure vector store (Milvus/Zilliz)
insert_config.milvus_uri = self.milvus_uri
insert_config.milvus_token = self.milvus_token
insert_config.milvus_host = self.milvus_host
insert_config.milvus_port = self.milvus_port
insert_config.milvus_user = self.milvus_user
Expand Down
1 change: 0 additions & 1 deletion tests/unit_tests/src_towhee/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
# 'uri': 'https://localhost:19530',
'user': None,
'password': None,
'secure': False
}

RERANK_CONFIG['rerank'] = False
Expand Down

0 comments on commit 2d4fc9c

Please sign in to comment.