From 358dbd665cdccc7734e572c6b78db312690b499a Mon Sep 17 00:00:00 2001
From: Liangyx2
Date: Wed, 19 Jun 2024 09:51:25 +0800
Subject: [PATCH] Use parameter for retriever (#159)

* fix

Signed-off-by: Liangyx2

for more information, see https://pre-commit.ci

---------

Signed-off-by: Liangyx2
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 comps/cores/proto/docarray.py               |  8 +++++
 comps/dataprep/milvus/README.md             |  6 ++++
 comps/dataprep/milvus/prepare_doc_milvus.py | 10 +++---
 comps/dataprep/qdrant/README.md             |  6 ++++
 comps/dataprep/qdrant/prepare_doc_qdrant.py | 10 +++---
 comps/dataprep/redis/README.md              | 11 ++++++
 .../redis/langchain/prepare_doc_redis.py    | 17 ++++++----
 comps/retrievers/langchain/milvus/README.md | 34 +++++++++++++++++++
 .../langchain/milvus/retriever_milvus.py    | 18 +++++++++-
 comps/retrievers/langchain/redis/README.md  | 34 +++++++++++++++++++
 .../langchain/redis/retriever_redis.py      | 18 +++++++++-
 11 files changed, 156 insertions(+), 16 deletions(-)

diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py
index c657d0c2e..9854cfff7 100644
--- a/comps/cores/proto/docarray.py
+++ b/comps/cores/proto/docarray.py
@@ -20,11 +20,19 @@ class Base64ByteStrDoc(BaseDoc):
 
 class DocPath(BaseDoc):
     path: str
+    chunk_size: int = 1500
+    chunk_overlap: int = 100
 
 
 class EmbedDoc768(BaseDoc):
     text: str
     embedding: conlist(float, min_length=768, max_length=768)
+    search_type: str = "similarity"
+    k: int = 4
+    distance_threshold: Optional[float] = None
+    fetch_k: int = 20
+    lambda_mult: float = 0.5
+    score_threshold: float = 0.2
 
 
 class Audio2TextDoc(AudioDoc):
diff --git a/comps/dataprep/milvus/README.md b/comps/dataprep/milvus/README.md
index d57b35cc6..e698c400e 100644
--- a/comps/dataprep/milvus/README.md
+++ b/comps/dataprep/milvus/README.md
@@ -53,3 +53,9 @@ Once document preparation microservice for Qdrant is started, user can use below
 ```bash
 curl -X POST -H "Content-Type: application/json" -d '{"path":"/home/user/doc/your_document_name"}' http://localhost:6010/v1/dataprep
 ```
+
+You can specify chunk_size and chunk_overlap with the following command.
+
+```bash
+curl -X POST -H "Content-Type: application/json" -d '{"path":"/home/user/doc/your_document_name","chunk_size":1500,"chunk_overlap":100}' http://localhost:6010/v1/dataprep
+```
diff --git a/comps/dataprep/milvus/prepare_doc_milvus.py b/comps/dataprep/milvus/prepare_doc_milvus.py
index 2f7068712..7f8028e00 100644
--- a/comps/dataprep/milvus/prepare_doc_milvus.py
+++ b/comps/dataprep/milvus/prepare_doc_milvus.py
@@ -31,11 +31,13 @@
 # @opea_telemetry
 def ingest_documents(doc_path: DocPath):
     """Ingest document to Milvus."""
-    doc_path = doc_path.path
-    print(f"Parsing document {doc_path}.")
+    path = doc_path.path
+    print(f"Parsing document {path}.")
 
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100, add_start_index=True)
-    content = document_loader(doc_path)
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=doc_path.chunk_size, chunk_overlap=doc_path.chunk_overlap, add_start_index=True
+    )
+    content = document_loader(path)
     chunks = text_splitter.split_text(content)
 
     print("Done preprocessing. Created ", len(chunks), " chunks of the original pdf")
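The same pattern repeats in each dataprep backend below: the new DocPath fields replace the previously hard-coded splitter arguments. As a minimal sketch of how those fields flow into chunking (illustrative only, not part of the patch; the langchain import path is assumed from typical usage at the time):

```python
# Sketch only: the new DocPath fields parameterize the text splitter.
from langchain.text_splitter import RecursiveCharacterTextSplitter

from comps.cores.proto.docarray import DocPath

doc = DocPath(path="/home/user/doc/your_document_name", chunk_size=1000, chunk_overlap=50)
splitter = RecursiveCharacterTextSplitter(
    chunk_size=doc.chunk_size,        # defaults to 1500 when the client omits it
    chunk_overlap=doc.chunk_overlap,  # defaults to 100 when the client omits it
    add_start_index=True,
)
chunks = splitter.split_text("some long document text ...")
```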
Created ", len(chunks), " chunks of the original pdf") diff --git a/comps/dataprep/qdrant/README.md b/comps/dataprep/qdrant/README.md index 9e4cd63e7..108a4c2fa 100644 --- a/comps/dataprep/qdrant/README.md +++ b/comps/dataprep/qdrant/README.md @@ -69,3 +69,9 @@ Once document preparation microservice for Qdrant is started, user can use below ```bash curl -X POST -H "Content-Type: application/json" -d '{"path":"/path/to/document"}' http://localhost:6000/v1/dataprep ``` + +You can specify chunk_size and chunk_size by the following commands. + +```bash +curl -X POST -H "Content-Type: application/json" -d '{"path":"/path/to/document","chunk_size":1500,"chunk_overlap":100}' http://localhost:6000/v1/dataprep +``` diff --git a/comps/dataprep/qdrant/prepare_doc_qdrant.py b/comps/dataprep/qdrant/prepare_doc_qdrant.py index 7949b5814..ba379a71e 100644 --- a/comps/dataprep/qdrant/prepare_doc_qdrant.py +++ b/comps/dataprep/qdrant/prepare_doc_qdrant.py @@ -25,11 +25,13 @@ @opea_telemetry def ingest_documents(doc_path: DocPath): """Ingest document to Qdrant.""" - doc_path = doc_path.path - print(f"Parsing document {doc_path}.") + path = doc_path.path + print(f"Parsing document {path}.") - text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100, add_start_index=True) - content = document_loader(doc_path) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=doc_path.chunk_size, chunk_overlap=doc_path.chunk_size, add_start_index=True + ) + content = document_loader(path) chunks = text_splitter.split_text(content) print("Done preprocessing. Created ", len(chunks), " chunks of the original pdf") diff --git a/comps/dataprep/redis/README.md b/comps/dataprep/redis/README.md index d6265db30..b548af29b 100644 --- a/comps/dataprep/redis/README.md +++ b/comps/dataprep/redis/README.md @@ -140,6 +140,17 @@ curl -X POST \ http://localhost:6007/v1/dataprep ``` +You can specify chunk_size and chunk_size by the following commands. + +```bash +curl -X POST \ + -H "Content-Type: multipart/form-data" \ + -F "files=@/home/sdp/yuxiang/opea_intent/GenAIComps4/comps/table_extraction/LLAMA2_page6.pdf" \ + -F "chunk_size=1500" \ + -F "chunk_overlap=100" \ + http://localhost:6007/v1/dataprep +``` + - Multiple file upload ```bash diff --git a/comps/dataprep/redis/langchain/prepare_doc_redis.py b/comps/dataprep/redis/langchain/prepare_doc_redis.py index 321a3fb21..717d48be6 100644 --- a/comps/dataprep/redis/langchain/prepare_doc_redis.py +++ b/comps/dataprep/redis/langchain/prepare_doc_redis.py @@ -33,11 +33,13 @@ async def save_file_to_local_disk(save_path: str, file): def ingest_data_to_redis(doc_path: DocPath): """Ingest document to Redis.""" - doc_path = doc_path.path - print(f"Parsing document {doc_path}.") + path = doc_path.path + print(f"Parsing document {path}.") - text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100, add_start_index=True) - content = document_loader(doc_path) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=doc_path.chunk_size, chunk_overlap=doc_path.chunk_size, add_start_index=True + ) + content = document_loader(path) chunks = text_splitter.split_text(content) print("Done preprocessing. 
Created ", len(chunks), " chunks of the original pdf") @@ -99,7 +101,10 @@ def ingest_link_to_redis(link_list: List[str]): @register_microservice(name="opea_service@prepare_doc_redis", endpoint="/v1/dataprep", host="0.0.0.0", port=6007) @traceable(run_type="tool") async def ingest_documents( - files: Optional[Union[UploadFile, List[UploadFile]]] = File(None), link_list: Optional[str] = Form(None) + files: Optional[Union[UploadFile, List[UploadFile]]] = File(None), + link_list: Optional[str] = Form(None), + chunk_size: int = Form(1500), + chunk_overlap: int = Form(100), ): print(f"files:{files}") print(f"link_list:{link_list}") @@ -115,7 +120,7 @@ async def ingest_documents( for file in files: save_path = upload_folder + file.filename await save_file_to_local_disk(save_path, file) - ingest_data_to_redis(DocPath(path=save_path)) + ingest_data_to_redis(DocPath(path=save_path, chunk_size=chunk_size, chunk_overlap=chunk_overlap)) print(f"Successfully saved file {save_path}") return {"status": 200, "message": "Data preparation succeeded"} diff --git a/comps/retrievers/langchain/milvus/README.md b/comps/retrievers/langchain/milvus/README.md index ccba5b0e7..3df06b2be 100644 --- a/comps/retrievers/langchain/milvus/README.md +++ b/comps/retrievers/langchain/milvus/README.md @@ -66,3 +66,37 @@ curl http://${your_ip}:7000/v1/retrieval \ -d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding}}" \ -H 'Content-Type: application/json' ``` + +You can set the parameters for the retriever. + +```bash +your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)") +curl http://localhost:7000/v1/retrieval \ + -X POST \ + -d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"similarity\", \"k\":4}" \ + -H 'Content-Type: application/json' +``` + +```bash +your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)") +curl http://localhost:7000/v1/retrieval \ + -X POST \ + -d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"similarity_distance_threshold\", \"k\":4, \"distance_threshold\":1.0}" \ + -H 'Content-Type: application/json' +``` + +```bash +your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)") +curl http://localhost:7000/v1/retrieval \ + -X POST \ + -d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"similarity_score_threshold\", \"k\":4, \"score_threshold\":0.2}" \ + -H 'Content-Type: application/json' +``` + +```bash +your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)") +curl http://localhost:7000/v1/retrieval \ + -X POST \ + -d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"mmr\", \"k\":4, \"fetch_k\":20, \"lambda_mult\":0.5}" \ + -H 'Content-Type: application/json' +``` diff --git a/comps/retrievers/langchain/milvus/retriever_milvus.py b/comps/retrievers/langchain/milvus/retriever_milvus.py index 8fed84306..4f0b14b8a 100644 --- a/comps/retrievers/langchain/milvus/retriever_milvus.py +++ b/comps/retrievers/langchain/milvus/retriever_milvus.py @@ -38,7 +38,23 @@ def retrieve(input: EmbedDoc768) -> SearchedDoc: collection_name=COLLECTION_NAME, ) start = time.time() - search_res = 
diff --git a/comps/retrievers/langchain/milvus/retriever_milvus.py b/comps/retrievers/langchain/milvus/retriever_milvus.py
index 8fed84306..4f0b14b8a 100644
--- a/comps/retrievers/langchain/milvus/retriever_milvus.py
+++ b/comps/retrievers/langchain/milvus/retriever_milvus.py
@@ -38,7 +38,23 @@ def retrieve(input: EmbedDoc768) -> SearchedDoc:
         collection_name=COLLECTION_NAME,
     )
     start = time.time()
-    search_res = vector_db.similarity_search_by_vector(embedding=input.embedding)
+    if input.search_type == "similarity":
+        search_res = vector_db.similarity_search_by_vector(embedding=input.embedding, k=input.k)
+    elif input.search_type == "similarity_distance_threshold":
+        if input.distance_threshold is None:
+            raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever")
+        search_res = vector_db.similarity_search_by_vector(
+            embedding=input.embedding, k=input.k, distance_threshold=input.distance_threshold
+        )
+    elif input.search_type == "similarity_score_threshold":
+        docs_and_similarities = vector_db.similarity_search_with_relevance_scores(
+            query=input.text, k=input.k, score_threshold=input.score_threshold
+        )
+        search_res = [doc for doc, _ in docs_and_similarities]
+    elif input.search_type == "mmr":
+        search_res = vector_db.max_marginal_relevance_search(
+            query=input.text, k=input.k, fetch_k=input.fetch_k, lambda_mult=input.lambda_mult
+        )
     searched_docs = []
     for r in search_res:
         searched_docs.append(TextDoc(text=r.page_content))
diff --git a/comps/retrievers/langchain/redis/README.md b/comps/retrievers/langchain/redis/README.md
index 38f64f33e..d9c1916b2 100644
--- a/comps/retrievers/langchain/redis/README.md
+++ b/comps/retrievers/langchain/redis/README.md
@@ -117,3 +117,37 @@ curl http://${your_ip}:7000/v1/retrieval \
   -d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding}}" \
   -H 'Content-Type: application/json'
 ```
+
+You can set the parameters for the retriever.
+
+```bash
+your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)")
+curl http://localhost:7000/v1/retrieval \
+  -X POST \
+  -d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"similarity\", \"k\":4}" \
+  -H 'Content-Type: application/json'
+```
+
+```bash
+your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)")
+curl http://localhost:7000/v1/retrieval \
+  -X POST \
+  -d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"similarity_distance_threshold\", \"k\":4, \"distance_threshold\":1.0}" \
+  -H 'Content-Type: application/json'
+```
+
+```bash
+your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)")
+curl http://localhost:7000/v1/retrieval \
+  -X POST \
+  -d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"similarity_score_threshold\", \"k\":4, \"score_threshold\":0.2}" \
+  -H 'Content-Type: application/json'
+```
+
+```bash
+your_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)")
+curl http://localhost:7000/v1/retrieval \
+  -X POST \
+  -d "{\"text\":\"What is the revenue of Nike in 2023?\",\"embedding\":${your_embedding},\"search_type\":\"mmr\", \"k\":4, \"fetch_k\":20, \"lambda_mult\":0.5}" \
+  -H 'Content-Type: application/json'
+```
diff --git a/comps/retrievers/langchain/redis/retriever_redis.py b/comps/retrievers/langchain/redis/retriever_redis.py
index 50b461d34..0474741da 100644
--- a/comps/retrievers/langchain/redis/retriever_redis.py
+++ b/comps/retrievers/langchain/redis/retriever_redis.py
@@ -34,7 +34,23 @@
 @register_statistics(names=["opea_service@retriever_redis"])
 def retrieve(input: EmbedDoc768) -> SearchedDoc:
     start = time.time()
-    search_res = vector_db.similarity_search_by_vector(embedding=input.embedding)
+    if input.search_type == "similarity":
+        search_res = vector_db.similarity_search_by_vector(embedding=input.embedding, k=input.k)
+    elif input.search_type == "similarity_distance_threshold":
+        if input.distance_threshold is None:
+            raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever")
+        search_res = vector_db.similarity_search_by_vector(
+            embedding=input.embedding, k=input.k, distance_threshold=input.distance_threshold
+        )
+    elif input.search_type == "similarity_score_threshold":
+        docs_and_similarities = vector_db.similarity_search_with_relevance_scores(
+            query=input.text, k=input.k, score_threshold=input.score_threshold
+        )
+        search_res = [doc for doc, _ in docs_and_similarities]
+    elif input.search_type == "mmr":
+        search_res = vector_db.max_marginal_relevance_search(
+            query=input.text, k=input.k, fetch_k=input.fetch_k, lambda_mult=input.lambda_mult
+        )
     searched_docs = []
     for r in search_res:
         searched_docs.append(TextDoc(text=r.page_content))
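Taken together, the patch threads the new EmbedDoc768 fields from the REST payload through to the corresponding LangChain search calls. A hedged end-to-end smoke check (assumes a retriever running locally on the README's default port, and the requests library; not part of the patch):

```python
# Sketch only: exercise all four search types against a local retriever.
import random

import requests

embedding = [random.uniform(-1, 1) for _ in range(768)]
variants = [
    {"search_type": "similarity", "k": 4},
    {"search_type": "similarity_distance_threshold", "k": 4, "distance_threshold": 1.0},
    {"search_type": "similarity_score_threshold", "k": 4, "score_threshold": 0.2},
    {"search_type": "mmr", "k": 4, "fetch_k": 20, "lambda_mult": 0.5},
]
for params in variants:
    body = {"text": "What is the revenue of Nike in 2023?", "embedding": embedding, **params}
    resp = requests.post("http://localhost:7000/v1/retrieval", json=body)
    print(params["search_type"], resp.status_code)
```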