diff --git a/.gitignore b/.gitignore
index eeb8a6e..1af3908 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 **/__pycache__
+**/tmp
diff --git a/OWNERS b/OWNERS
index 2ed47b1..d971342 100644
--- a/OWNERS
+++ b/OWNERS
@@ -3,8 +3,10 @@ filters:
     reviewers:
       - jaelgu
       - zc277584121
+      - junjiejiangjjj
     approvers:
       - jaelgu
       - zc277584121
+      - junjiejiangjjj
       - codingjaguar
       - xiaofan-luan
diff --git a/config.py b/config.py
index 683b73c..1f81635 100644
--- a/config.py
+++ b/config.py
@@ -88,14 +88,19 @@
 # Scalar db configs
 SCALARDB_CONFIG = {
     'connection_args': {
-        # 'hosts': os.getenv('ES_HOSTS', 'https://localhost:9200'),
-        'cloud_id': os.getenv('ES_CLOUD_ID'),
-        'ca_certs': os.getenv('ES_CA_CERTS', None),
-        'basic_auth': (os.getenv('ES_USER', 'user_name'), os.getenv('ES_PASSWORD', 'es_password'))
+        'hosts': os.getenv('ES_HOSTS', 'http://localhost:9200'),
     },
     'top_k': 3
 }
 
+for arg in ['ES_CLOUD_ID', 'ES_CA_CERTS']:
+    arg_k = arg.replace('ES_', '').lower()
+    arg_v = os.getenv(arg)
+    if arg_v:
+        SCALARDB_CONFIG['connection_args'][arg_k] = arg_v
+if os.getenv('ES_USER'):
+    SCALARDB_CONFIG['connection_args']['basic_auth'] = (os.getenv('ES_USER'), os.getenv('ES_PASSWORD'))
+
 # Memory db configs
 MEMORYDB_CONFIG = {
     'connect_str': os.getenv('SQL_URI', 'postgresql://postgres:postgres@localhost/chat_history')
diff --git a/main.py b/main.py
index 4c640ab..be8aa46 100644
--- a/main.py
+++ b/main.py
@@ -24,9 +24,9 @@
     'The service should start with either "--langchain" or "--towhee".'
 
 if USE_LANGCHAIN:
-    from src_langchain.operations import chat, insert, drop, check, get_history, clear_history  # pylint: disable=C0413
+    from src_langchain.operations import chat, insert, drop, check, get_history, clear_history, count  # pylint: disable=C0413
 if USE_TOWHEE:
-    from src_towhee.operations import chat, insert, drop, check, get_history, clear_history  # pylint: disable=C0413
+    from src_towhee.operations import chat, insert, drop, check, get_history, clear_history, count  # pylint: disable=C0413
 
 app = FastAPI()
 origins = ['*']
@@ -90,6 +90,15 @@ def do_project_check_api(project: str):
         return jsonable_encoder({'status': False, 'msg': f'Failed to check project:\n{e}'}), 400
 
 
+@app.get('/project/count')
+def do_project_count_api(project: str):
+    try:
+        counts = count(project)
+        return jsonable_encoder({'status': True, 'msg': counts}), 200
+    except Exception as e:  # pylint: disable=W0703
+        return jsonable_encoder({'status': False, 'msg': f'Failed to count entities:\n{e}'}), 400
+
+
 @app.get('/history/get')
 def do_history_get_api(project: str, session_id: str = None):
     try:
diff --git a/src_langchain/operations.py b/src_langchain/operations.py
index 925be58..2b9eb5f 100644
--- a/src_langchain/operations.py
+++ b/src_langchain/operations.py
@@ -79,13 +79,13 @@ def drop(project):
 
 
 def check(project):
-    '''Check existences of project tables in both vector and memory stores.'''
+    '''Check existences of project tables in both doc stores and memory stores.'''
     try:
         doc_check = DocStore.has_project(project)
     except Exception as e:
-        logger.error('Failed to check table in vector db:\n%s', e)
+        logger.error('Failed to check doc stores:\n%s', e)
         raise RuntimeError from e
-    # Clear memory
+    # Check memory
     try:
         memory_check = MemoryStore.check(project)
     except Exception as e:
@@ -93,6 +93,15 @@ def check(project):
         raise RuntimeError from e
     return {'store': doc_check, 'memory': memory_check}
 
+def count(project):
+    '''Count entities.'''
+    try:
+        counts = DocStore.count_entities(project=project)
+        return counts
+    except Exception as e:
+        logger.error('Failed to count entities:\n%s', e)
+        raise RuntimeError from e
+
 
 def get_history(project, session_id):
     '''Get conversation history from memory store.'''
diff --git a/src_langchain/store/__init__.py b/src_langchain/store/__init__.py
index 30caeed..df7d31d 100644
--- a/src_langchain/store/__init__.py
+++ b/src_langchain/store/__init__.py
@@ -97,3 +97,15 @@ def has_project(project):
         if USE_SCALAR:
             assert ScalarStore.has_project(project) == status
         return status
+
+    @staticmethod
+    def count_entities(project):
+        if not VectorStore.has_project(project):
+            milvus_count = es_count = None
+        else:
+            milvus_count = VectorStore.count_entities(project)
+            if USE_SCALAR:
+                es_count = ScalarStore.count_entities(project)
+            else:
+                es_count = None
+        return {'vector store': milvus_count, 'scalar store': es_count}
diff --git a/src_langchain/store/scalar_store/es.py b/src_langchain/store/scalar_store/es.py
index 56a40d1..ec63402 100644
--- a/src_langchain/store/scalar_store/es.py
+++ b/src_langchain/store/scalar_store/es.py
@@ -48,3 +48,8 @@ def drop(project: str, connection_args: dict = CONNECTION_ARGS):
     def has_project(project: str, connection_args: dict = CONNECTION_ARGS):
         client = ScalarStore.connect(connection_args)
         return client.indices.exists(index=project)
+
+    @staticmethod
+    def count_entities(project: str, connection_args: dict = CONNECTION_ARGS):
+        client = ScalarStore.connect(connection_args)
+        return client.count(index=project)['count']
diff --git a/src_langchain/store/vector_store/milvus.py b/src_langchain/store/vector_store/milvus.py
index 9a65da3..958dcd7 100644
--- a/src_langchain/store/vector_store/milvus.py
+++ b/src_langchain/store/vector_store/milvus.py
@@ -195,3 +195,12 @@ def has_project(project: str, connection_args: dict = CONNECTION_ARGS):
 
         VectorStore.connect(connection_args)
         return utility.has_collection(project)
+
+
+    @staticmethod
+    def count_entities(project: str, connection_args: dict = CONNECTION_ARGS):
+        from pymilvus import Collection  # pylint: disable=C0415
+
+        VectorStore.connect(connection_args)
+        collection = Collection(project)
+        return collection.num_entities
diff --git a/src_towhee/operations.py b/src_towhee/operations.py
index 88b2495..49b2029 100644
--- a/src_towhee/operations.py
+++ b/src_towhee/operations.py
@@ -46,7 +46,7 @@ def insert(data_src, project, source_type: str = 'file'):  # pylint: disable=W061
     if not towhee_pipelines.check(project):
         towhee_pipelines.create(project)
     res = insert_pipeline(data_src, project).to_list()
-    num = towhee_pipelines.count_entities(project)
+    num = towhee_pipelines.count_entities(project)['vector store']
     assert len(res) <= num, 'Failed to insert data.'
     return len(res)
 
@@ -79,7 +79,7 @@ def check(project):
     try:
         doc_check = towhee_pipelines.check(project)
     except Exception as e:
-        logger.error('Failed to check table in vector db:\n%s', e)
+        logger.error('Failed to check doc stores:\n%s', e)
         raise RuntimeError from e
     # Check memory
     try:
@@ -90,6 +90,16 @@
     return {'store': doc_check, 'memory': memory_check}
 
 
+def count(project):
+    '''Count entities.'''
+    try:
+        counts = towhee_pipelines.count_entities(project)
+        return counts
+    except Exception as e:
+        logger.error('Failed to count entities:\n%s', e)
+        raise RuntimeError from e
+
+
 def get_history(project, session_id):
     '''Get conversation history from memory store.'''
     try:
diff --git a/src_towhee/pipelines/__init__.py b/src_towhee/pipelines/__init__.py
index a599289..afee7ea 100644
--- a/src_towhee/pipelines/__init__.py
+++ b/src_towhee/pipelines/__init__.py
@@ -190,10 +190,14 @@ def check(self, project):
         return status
 
     def count_entities(self, project):
-        collection = Collection(project)
-        collection.flush()
-        milvus_count = collection.num_entities
-        if self.use_scalar:
-            es_count = self.es_client.count(index=project)['count']
-            assert es_count == milvus_count, 'Mismatched data count in Milvus vs Elastic.'
-        return milvus_count
+        if not self.check(project):
+            milvus_count = es_count = None
+        else:
+            collection = Collection(project)
+            collection.flush()
+            milvus_count = collection.num_entities
+            if self.use_scalar:
+                es_count = self.es_client.count(index=project)['count']
+            else:
+                es_count = None
+        return {'vector store': milvus_count, 'scalar store': es_count}
diff --git a/tests/unit_tests/src_towhee/pipelines/test_pipelines.py b/tests/unit_tests/src_towhee/pipelines/test_pipelines.py
index 4240693..7570ddb 100644
--- a/tests/unit_tests/src_towhee/pipelines/test_pipelines.py
+++ b/tests/unit_tests/src_towhee/pipelines/test_pipelines.py
@@ -81,7 +81,7 @@ def test_openai(self):
 
         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
-        num = pipelines.count_entities(self.project)
+        num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
 
         # Check search
@@ -114,7 +114,7 @@ def test_chatglm(self):
 
         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
-        num = pipelines.count_entities(self.project)
+        num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
 
         # Check search
@@ -149,7 +149,7 @@ def json(self):
 
         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
-        num = pipelines.count_entities(self.project)
+        num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
 
         # Check search
@@ -190,7 +190,7 @@ def output(self):
 
         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
-        num = pipelines.count_entities(self.project)
+        num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
 
         # Check search
@@ -225,7 +225,7 @@ def json(self):
 
         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
-        num = pipelines.count_entities(self.project)
+        num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
 
         # Check search
@@ -262,7 +262,7 @@ def iter_lines(self):
 
         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
-        num = pipelines.count_entities(self.project)
+        num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
 
         # Check search
@@ -300,7 +300,7 @@ def __call__(self, *args, **kwargs):
 
         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
-        num = pipelines.count_entities(self.project)
+        num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
 
         # Check search
diff --git a/tests/unit_tests/src_towhee/test_operations.py b/tests/unit_tests/src_towhee/test_operations.py
index 41473c5..1a27cfc 100644
--- a/tests/unit_tests/src_towhee/test_operations.py
+++ b/tests/unit_tests/src_towhee/test_operations.py
@@ -15,7 +15,7 @@ def __init__(self, *args, **kwargs):
 
     def get_history(self, *args, **kwargs):
         return self.memory
-    
+
     def add_history(self, project, session_id, messages: list):
         self.memory += messages
 
@@ -25,6 +25,7 @@ def drop(self, *args, **kwargs):
     def check(self, *args, **kwargs):
         return len(self.memory) > 0
 
+
 class MockPipeline:
     def __init__(self, *args, **kwargs):
         self.search_que = DataQueue([('answer', ColumnType.SCALAR)])
@@ -44,14 +45,14 @@ def insert_pipeline(self, data_src, project, source_type='file'):
 
     def check(self, project):
         return project in self.projects
-    
+
     def create(self, project):
         if not self.check(project):
             self.projects[project] = ''
 
     def count_entities(self, project):
-        return len(self.projects[project])
-    
+        return {'vector store': len(self.projects[project]), 'scalar store': None}
+
     def drop(self, project):
         del self.projects[project]
 
@@ -64,26 +65,26 @@ class TestOperations(unittest.TestCase):
     expect_len = 1
     question = 'the first question'
     expect_answer = 'mock answer'
-    
+
     def test_chat(self):
         with patch('src_towhee.pipelines.TowheePipelines') as mock_pipelines, \
-            patch('src_towhee.memory.MemoryStore') as mock_memory:
+                patch('src_towhee.memory.MemoryStore') as mock_memory:
             mock_pipelines.return_value = MockPipeline()
             mock_memory.return_value = MockStore()
 
             from src_towhee.pipelines import TowheePipelines
             from src_towhee.memory import MemoryStore
 
             with patch.object(TowheePipelines, 'search_pipeline', mock_pipelines.search_pipeline), \
-                patch.object(MemoryStore, 'add_history', mock_memory.add_history), \
-                patch.object(MemoryStore, 'get_history', mock_memory.get_history), \
-                patch.object(MemoryStore, 'drop', mock_memory.drop):
-    
+                    patch.object(MemoryStore, 'add_history', mock_memory.add_history), \
+                    patch.object(MemoryStore, 'get_history', mock_memory.get_history), \
+                    patch.object(MemoryStore, 'drop', mock_memory.drop):
+
                 from src_towhee.operations import chat, get_history, clear_history
 
-                question, answer = chat(self.session_id, self.project, self.question)
+                question, answer = chat(
+                    self.session_id, self.project, self.question)
 
                 assert answer == self.expect_answer
 
                 history = get_history(self.project, self.session_id)
@@ -93,11 +94,10 @@ def test_chat(self):
                 clean_history = get_history(self.project, self.session_id)
                 assert clean_history == []
 
-
 
     def test_insert(self):
         with patch('src_towhee.pipelines.TowheePipelines') as mock_pipelines, \
-            patch('src_towhee.memory.MemoryStore') as mock_memory:
+                patch('src_towhee.memory.MemoryStore') as mock_memory:
             mock_pipelines.return_value = MockPipeline()
             mock_memory.return_value = MockStore()
 
@@ -105,12 +105,12 @@ def test_insert(self):
             from src_towhee.memory import MemoryStore
 
             with patch.object(TowheePipelines, 'insert_pipeline', mock_pipelines.insert_pipeline), \
-                patch.object(TowheePipelines, 'count_entities', mock_pipelines.count_entities), \
-                patch.object(TowheePipelines, 'check', mock_pipelines.check), \
-                patch.object(TowheePipelines, 'drop', mock_pipelines.drop), \
-                patch.object(MemoryStore, 'check', mock_memory.check), \
-                patch.object(MemoryStore, 'drop', mock_memory.drop):
-    
+                    patch.object(TowheePipelines, 'count_entities', mock_pipelines.count_entities), \
+                    patch.object(TowheePipelines, 'check', mock_pipelines.check), \
+                    patch.object(TowheePipelines, 'drop', mock_pipelines.drop), \
+                    patch.object(MemoryStore, 'check', mock_memory.check), \
+                    patch.object(MemoryStore, 'drop', mock_memory.drop):
+
                 from src_towhee.operations import insert, check, drop
 
                 count = insert(self.test_src, self.project)
@@ -124,5 +124,5 @@
                 assert not status['memory']
 
 
-if __name__== '__main__':
-    unittest.main()
\ No newline at end of file
+if __name__ == '__main__':
+    unittest.main()