Skip to content

Commit

Permalink
Add data count
Browse files Browse the repository at this point in the history
Signed-off-by: Mengjia Gu <[email protected]>
  • Loading branch information
jaelgu committed Aug 23, 2023
1 parent 62c94a6 commit d5053b9
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 47 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
**/__pycache__
**/tmp
2 changes: 2 additions & 0 deletions OWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ filters:
reviewers:
- jaelgu
- zc277584121
- junjiejiangjjj
approvers:
- jaelgu
- zc277584121
- junjiejiangjjj
- codingjaguar
- xiaofan-luan
13 changes: 9 additions & 4 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,19 @@
# Scalar db configs
SCALARDB_CONFIG = {
'connection_args': {
# 'hosts': os.getenv('ES_HOSTS', 'https://localhost:9200'),
'cloud_id': os.getenv('ES_CLOUD_ID'),
'ca_certs': os.getenv('ES_CA_CERTS', None),
'basic_auth': (os.getenv('ES_USER', 'user_name'), os.getenv('ES_PASSWORD', 'es_password'))
'hosts': os.getenv('ES_HOSTS', 'http://localhost:9200'),
},
'top_k': 3
}

# Optional Elasticsearch connection settings: each value is copied into
# connection_args only when its environment variable is set and non-empty.
for arg in ('ES_CLOUD_ID', 'ES_CA_CERTS'):
    # 'ES_CLOUD_ID' -> 'cloud_id', 'ES_CA_CERTS' -> 'ca_certs'
    arg_k = arg.replace('ES_', '').lower()
    arg_v = os.getenv(arg)
    if arg_v:
        SCALARDB_CONFIG['connection_args'][arg_k] = arg_v

# When a cloud deployment is requested and ES_HOSTS was not set explicitly,
# drop the localhost fallback so the client is not handed both targets.
# NOTE(review): assumes the Elasticsearch client expects either 'hosts' or
# 'cloud_id', not both - confirm against the installed client version.
if os.getenv('ES_CLOUD_ID') and not os.getenv('ES_HOSTS'):
    SCALARDB_CONFIG['connection_args'].pop('hosts', None)

# Credentials are only attached when a user name is provided.
if os.getenv('ES_USER'):
    SCALARDB_CONFIG['connection_args']['basic_auth'] = (os.getenv('ES_USER'), os.getenv('ES_PASSWORD'))

# Memory db configs
MEMORYDB_CONFIG = {
'connect_str': os.getenv('SQL_URI', 'postgresql://postgres:postgres@localhost/chat_history')
Expand Down
13 changes: 11 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
'The service should start with either "--langchain" or "--towhee".'

if USE_LANGCHAIN:
from src_langchain.operations import chat, insert, drop, check, get_history, clear_history # pylint: disable=C0413
from src_langchain.operations import chat, insert, drop, check, get_history, clear_history, count # pylint: disable=C0413
if USE_TOWHEE:
from src_towhee.operations import chat, insert, drop, check, get_history, clear_history # pylint: disable=C0413
from src_towhee.operations import chat, insert, drop, check, get_history, clear_history, count # pylint: disable=C0413

app = FastAPI()
origins = ['*']
Expand Down Expand Up @@ -90,6 +90,15 @@ def do_project_check_api(project: str):
return jsonable_encoder({'status': False, 'msg': f'Failed to check project:\n{e}'}), 400


@app.get('/project/count')
def do_project_count_api(project: str):
    '''Report the entity counts of a project.

    Returns a (payload, status) pair: the result of `count(project)` with 200
    on success, or the failure message with 400 when anything raises.
    '''
    try:
        body = jsonable_encoder({'status': True, 'msg': count(project)})
    except Exception as e:  # pylint: disable=W0703
        failure = {'status': False, 'msg': f'Failed to count entities:\n{e}'}
        return jsonable_encoder(failure), 400
    return body, 200


@app.get('/history/get')
def do_history_get_api(project: str, session_id: str = None):
try:
Expand Down
15 changes: 12 additions & 3 deletions src_langchain/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,29 @@ def drop(project):


def check(project):
'''Check existences of project tables in both vector and memory stores.'''
'''Check existences of project tables in both doc stores and memory stores.'''
try:
doc_check = DocStore.has_project(project)
except Exception as e:
logger.error('Failed to check table in vector db:\n%s', e)
logger.error('Failed to check doc stores:\n%s', e)
raise RuntimeError from e
# Clear memory
# Check memory
try:
memory_check = MemoryStore.check(project)
except Exception as e:
logger.error('Failed to clean memory for the project:\n%s', e)
raise RuntimeError from e
return {'store': doc_check, 'memory': memory_check}

def count(project):
    '''Count entities stored for a project.

    Returns whatever `DocStore.count_entities` returns for `project`.
    Raises RuntimeError (chained to the original exception) on any failure.
    '''
    try:
        return DocStore.count_entities(project=project)
    except Exception as e:
        logger.error('Failed to count entities:\n%s', e)
        # Carry the cause's text: a bare RuntimeError stringifies to '', so
        # anything formatting this exception would lose the failure detail.
        raise RuntimeError(str(e)) from e


def get_history(project, session_id):
'''Get conversation history from memory store.'''
Expand Down
12 changes: 12 additions & 0 deletions src_langchain/store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,15 @@ def has_project(project):
if USE_SCALAR:
assert ScalarStore.has_project(project) == status
return status

@staticmethod
def count_entities(project):
    '''Count a project's entities in the vector store and the scalar store.

    Both counts are None when the vector store has no such project; the
    scalar count is also None when USE_SCALAR is off.
    '''
    counts = {'vector store': None, 'scalar store': None}
    if VectorStore.has_project(project):
        counts['vector store'] = VectorStore.count_entities(project)
        if USE_SCALAR:
            counts['scalar store'] = ScalarStore.count_entities(project)
    return counts
5 changes: 5 additions & 0 deletions src_langchain/store/scalar_store/es.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,8 @@ def drop(project: str, connection_args: dict = CONNECTION_ARGS):
def has_project(project: str, connection_args: dict = CONNECTION_ARGS):
client = ScalarStore.connect(connection_args)
return client.indices.exists(index=project)

@staticmethod
def count_entities(project: str, connection_args: dict = CONNECTION_ARGS):
    '''Return the 'count' field of the client's count response for index `project`.'''
    es_client = ScalarStore.connect(connection_args)
    response = es_client.count(index=project)
    return response['count']
9 changes: 9 additions & 0 deletions src_langchain/store/vector_store/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,12 @@ def has_project(project: str, connection_args: dict = CONNECTION_ARGS):

VectorStore.connect(connection_args)
return utility.has_collection(project)


@staticmethod
def count_entities(project: str, connection_args: dict = CONNECTION_ARGS):
    '''Return `num_entities` of the Milvus collection named `project`.'''
    # Imported lazily, inside the method, as the original does.
    from pymilvus import Collection  # pylint: disable=C0415

    VectorStore.connect(connection_args)
    return Collection(project).num_entities
14 changes: 12 additions & 2 deletions src_towhee/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def insert(data_src, project, source_type: str = 'file'): # pylint: disable=W061
if not towhee_pipelines.check(project):
towhee_pipelines.create(project)
res = insert_pipeline(data_src, project).to_list()
num = towhee_pipelines.count_entities(project)
num = towhee_pipelines.count_entities(project)['vector store']
assert len(res) <= num, 'Failed to insert data.'
return len(res)

Expand Down Expand Up @@ -79,7 +79,7 @@ def check(project):
try:
doc_check = towhee_pipelines.check(project)
except Exception as e:
logger.error('Failed to check table in vector db:\n%s', e)
logger.error('Failed to check doc stores:\n%s', e)
raise RuntimeError from e
# Check memory
try:
Expand All @@ -90,6 +90,16 @@ def check(project):
return {'store': doc_check, 'memory': memory_check}


def count(project):
    '''Count entities stored for a project.

    Returns whatever `towhee_pipelines.count_entities` returns for `project`.
    Raises RuntimeError (chained to the original exception) on any failure.
    '''
    try:
        return towhee_pipelines.count_entities(project)
    except Exception as e:
        logger.error('Failed to count entities:\n%s', e)
        # Carry the cause's text: a bare RuntimeError stringifies to '', so
        # anything formatting this exception would lose the failure detail.
        raise RuntimeError(str(e)) from e


def get_history(project, session_id):
'''Get conversation history from memory store.'''
try:
Expand Down
18 changes: 11 additions & 7 deletions src_towhee/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,14 @@ def check(self, project):
return status

def count_entities(self, project):
collection = Collection(project)
collection.flush()
milvus_count = collection.num_entities
if self.use_scalar:
es_count = self.es_client.count(index=project)['count']
assert es_count == milvus_count, 'Mismatched data count in Milvus vs Elastic.'
return milvus_count
if not self.check(project):
milvus_count = es_count = None
else:
collection = Collection(project)
collection.flush()
milvus_count = collection.num_entities
if self.use_scalar:
es_count = self.es_client.count(index=project)['count']
else:
es_count = None
return {'vector store': milvus_count, 'scalar store': es_count}
14 changes: 7 additions & 7 deletions tests/unit_tests/src_towhee/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_openai(self):

insert_pipeline = pipelines.insert_pipeline
res = insert_pipeline(self.data_src, self.project).to_list()
num = pipelines.count_entities(self.project)
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

# Check search
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_chatglm(self):

insert_pipeline = pipelines.insert_pipeline
res = insert_pipeline(self.data_src, self.project).to_list()
num = pipelines.count_entities(self.project)
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

# Check search
Expand Down Expand Up @@ -149,7 +149,7 @@ def json(self):

insert_pipeline = pipelines.insert_pipeline
res = insert_pipeline(self.data_src, self.project).to_list()
num = pipelines.count_entities(self.project)
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

# Check search
Expand Down Expand Up @@ -190,7 +190,7 @@ def output(self):

insert_pipeline = pipelines.insert_pipeline
res = insert_pipeline(self.data_src, self.project).to_list()
num = pipelines.count_entities(self.project)
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

# Check search
Expand Down Expand Up @@ -225,7 +225,7 @@ def json(self):

insert_pipeline = pipelines.insert_pipeline
res = insert_pipeline(self.data_src, self.project).to_list()
num = pipelines.count_entities(self.project)
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

# Check search
Expand Down Expand Up @@ -262,7 +262,7 @@ def iter_lines(self):

insert_pipeline = pipelines.insert_pipeline
res = insert_pipeline(self.data_src, self.project).to_list()
num = pipelines.count_entities(self.project)
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

# Check search
Expand Down Expand Up @@ -300,7 +300,7 @@ def __call__(self, *args, **kwargs):

insert_pipeline = pipelines.insert_pipeline
res = insert_pipeline(self.data_src, self.project).to_list()
num = pipelines.count_entities(self.project)
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

# Check search
Expand Down
44 changes: 22 additions & 22 deletions tests/unit_tests/src_towhee/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, *args, **kwargs):

def get_history(self, *args, **kwargs):
return self.memory

def add_history(self, project, session_id, messages: list):
self.memory += messages

Expand All @@ -25,6 +25,7 @@ def drop(self, *args, **kwargs):
def check(self, *args, **kwargs):
return len(self.memory) > 0


class MockPipeline:
def __init__(self, *args, **kwargs):
self.search_que = DataQueue([('answer', ColumnType.SCALAR)])
Expand All @@ -44,14 +45,14 @@ def insert_pipeline(self, data_src, project, source_type='file'):

def check(self, project):
return project in self.projects

def create(self, project):
if not self.check(project):
self.projects[project] = ''

def count_entities(self, project):
return len(self.projects[project])
return {'vector store': len(self.projects[project]), 'scalar store': None}

def drop(self, project):
del self.projects[project]

Expand All @@ -64,26 +65,26 @@ class TestOperations(unittest.TestCase):
expect_len = 1
question = 'the first question'
expect_answer = 'mock answer'

def test_chat(self):

with patch('src_towhee.pipelines.TowheePipelines') as mock_pipelines, \
patch('src_towhee.memory.MemoryStore') as mock_memory:
patch('src_towhee.memory.MemoryStore') as mock_memory:
mock_pipelines.return_value = MockPipeline()
mock_memory.return_value = MockStore()

from src_towhee.pipelines import TowheePipelines
from src_towhee.memory import MemoryStore


with patch.object(TowheePipelines, 'search_pipeline', mock_pipelines.search_pipeline), \
patch.object(MemoryStore, 'add_history', mock_memory.add_history), \
patch.object(MemoryStore, 'get_history', mock_memory.get_history), \
patch.object(MemoryStore, 'drop', mock_memory.drop):
patch.object(MemoryStore, 'add_history', mock_memory.add_history), \
patch.object(MemoryStore, 'get_history', mock_memory.get_history), \
patch.object(MemoryStore, 'drop', mock_memory.drop):

from src_towhee.operations import chat, get_history, clear_history

question, answer = chat(self.session_id, self.project, self.question)
question, answer = chat(
self.session_id, self.project, self.question)
assert answer == self.expect_answer

history = get_history(self.project, self.session_id)
Expand All @@ -93,24 +94,23 @@ def test_chat(self):
clean_history = get_history(self.project, self.session_id)
assert clean_history == []


def test_insert(self):

with patch('src_towhee.pipelines.TowheePipelines') as mock_pipelines, \
patch('src_towhee.memory.MemoryStore') as mock_memory:
patch('src_towhee.memory.MemoryStore') as mock_memory:
mock_pipelines.return_value = MockPipeline()
mock_memory.return_value = MockStore()

from src_towhee.pipelines import TowheePipelines
from src_towhee.memory import MemoryStore

with patch.object(TowheePipelines, 'insert_pipeline', mock_pipelines.insert_pipeline), \
patch.object(TowheePipelines, 'count_entities', mock_pipelines.count_entities), \
patch.object(TowheePipelines, 'check', mock_pipelines.check), \
patch.object(TowheePipelines, 'drop', mock_pipelines.drop), \
patch.object(MemoryStore, 'check', mock_memory.check), \
patch.object(MemoryStore, 'drop', mock_memory.drop):
patch.object(TowheePipelines, 'count_entities', mock_pipelines.count_entities), \
patch.object(TowheePipelines, 'check', mock_pipelines.check), \
patch.object(TowheePipelines, 'drop', mock_pipelines.drop), \
patch.object(MemoryStore, 'check', mock_memory.check), \
patch.object(MemoryStore, 'drop', mock_memory.drop):

from src_towhee.operations import insert, check, drop

count = insert(self.test_src, self.project)
Expand All @@ -124,5 +124,5 @@ def test_insert(self):
assert not status['memory']


if __name__== '__main__':
unittest.main()
if __name__ == '__main__':
unittest.main()

0 comments on commit d5053b9

Please sign in to comment.