Fix RAG performance issues #132

Merged: 11 commits, Jun 8, 2024
2 changes: 1 addition & 1 deletion comps/embeddings/langchain/local_embedding.py
@@ -17,11 +17,11 @@
 )
 @opea_telemetry
 def embedding(input: TextDoc) -> EmbedDoc1024:
-    embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-large-en-v1.5")
     embed_vector = embeddings.embed_query(input.text)
     res = EmbedDoc1024(text=input.text, embedding=embed_vector)
     return res


 if __name__ == "__main__":
+    embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-large-en-v1.5")
     opea_microservices["opea_service@local_embedding"].start()
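
This file is the heart of the fix: HuggingFaceBgeEmbeddings loads the BGE model weights when it is constructed, so instantiating it inside embedding() reloaded the model on every request. The PR hoists the construction to startup; note that a name bound under if __name__ == "__main__": still lands at module scope when the script is run directly, so the handler can see it. A minimal sketch of the before/after pattern (the langchain_community import path and the helper names are assumptions, not code from this PR):

    from langchain_community.embeddings import HuggingFaceBgeEmbeddings

    # Anti-pattern removed by this PR: the model is loaded on every call.
    def embed_per_request(text: str) -> list[float]:
        embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-large-en-v1.5")
        return embeddings.embed_query(text)

    # Pattern adopted by this PR: load once at startup, reuse across requests.
    embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-large-en-v1.5")

    def embed_shared(text: str) -> list[float]:
        return embeddings.embed_query(text)
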
4 changes: 2 additions & 2 deletions comps/guardrails/langchain/guardrails_tgi_gaudi.py
@@ -49,8 +49,6 @@ def get_unsafe_dict(model_id="meta-llama/LlamaGuard-7b"):
 )
 @traceable(run_type="llm")
 def safety_guard(input: TextDoc) -> TextDoc:
-    # chat engine for server-side prompt templating
-    llm_engine_hf = ChatHuggingFace(llm=llm_guard)
     response_input_guard = llm_engine_hf.invoke([{"role": "user", "content": input.text}]).content
     if "unsafe" in response_input_guard:
         unsafe_dict = get_unsafe_dict(llm_engine_hf.model_id)
@@ -75,5 +73,7 @@ def safety_guard(input: TextDoc) -> TextDoc:
         temperature=0.01,
         repetition_penalty=1.03,
     )
+    # chat engine for server-side prompt templating
+    llm_engine_hf = ChatHuggingFace(llm=llm_guard)
     print("guardrails - router] LLM initialized.")
     opea_microservices["opea_service@guardrails_tgi_gaudi"].start()
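
The guardrails service gets the same treatment: ChatHuggingFace(llm=llm_guard) is now built once at startup instead of on every safety_guard call. A rough sketch of the reused-engine flow, assuming the langchain_community import paths; the endpoint URL and max_new_tokens are placeholders, while the invoke pattern and the "unsafe" check mirror the file:

    from langchain_community.chat_models.huggingface import ChatHuggingFace
    from langchain_community.llms import HuggingFaceEndpoint

    # Built once at startup; every request reuses the same chat engine.
    llm_guard = HuggingFaceEndpoint(
        endpoint_url="http://localhost:8088",  # placeholder TGI Gaudi endpoint
        max_new_tokens=100,                    # placeholder; real values live in the file
        temperature=0.01,
        repetition_penalty=1.03,
    )
    llm_engine_hf = ChatHuggingFace(llm=llm_guard)

    def is_unsafe(text: str) -> bool:
        # Llama Guard replies "safe", or "unsafe" plus a policy category.
        verdict = llm_engine_hf.invoke([{"role": "user", "content": text}]).content
        return "unsafe" in verdict
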
18 changes: 9 additions & 9 deletions comps/retrievers/langchain/retriever_redis.py
@@ -22,6 +22,15 @@
 )
 @traceable(run_type="retriever")
 def retrieve(input: EmbedDoc768) -> SearchedDoc:
+    search_res = vector_db.similarity_search_by_vector(embedding=input.embedding)
+    searched_docs = []
+    for r in search_res:
+        searched_docs.append(TextDoc(text=r.page_content))
+    result = SearchedDoc(retrieved_docs=searched_docs, initial_query=input.text)
+    return result
+
+
+if __name__ == "__main__":
     # Create vectorstore
     if tei_embedding_endpoint:
         # create embeddings using TEI endpoint service
@@ -36,13 +45,4 @@ def retrieve(input: EmbedDoc768) -> SearchedDoc:
         redis_url=REDIS_URL,
         schema=INDEX_SCHEMA,
     )
-    search_res = vector_db.similarity_search_by_vector(embedding=input.embedding)
-    searched_docs = []
-    for r in search_res:
-        searched_docs.append(TextDoc(text=r.page_content))
-    result = SearchedDoc(retrieved_docs=searched_docs, initial_query=input.text)
-    return result
-
-
-if __name__ == "__main__":
     opea_microservices["opea_service@retriever_redis"].start()
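
Same refactor, third service: retrieve() used to rebuild the embedder and reconnect to Redis before every search; now the vectorstore is created once under __main__ and each request only runs the similarity search. A sketch of the startup/request split (the Redis.from_existing_index arguments stand in for the script's REDIS_URL, INDEX_NAME, and INDEX_SCHEMA variables; the schema path and model name here are hypothetical):

    from langchain_community.embeddings import HuggingFaceBgeEmbeddings
    from langchain_community.vectorstores import Redis

    # Startup: attach to an index the ingestion pipeline already populated.
    embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")
    vector_db = Redis.from_existing_index(
        embedding=embeddings,
        index_name="rag-redis",              # INDEX_NAME
        schema="redis_schema.yml",           # INDEX_SCHEMA (hypothetical path)
        redis_url="redis://localhost:6379",  # REDIS_URL
    )

    # Per request: the handler only performs the vector search.
    def retrieve(query_vector: list[float]) -> list[str]:
        docs = vector_db.similarity_search_by_vector(embedding=query_vector)
        return [d.page_content for d in docs]
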
8 changes: 5 additions & 3 deletions tests/test_llms_text-generation_tgi.sh
@@ -26,10 +26,10 @@ function start_service() {
 
     # check whether tgi is fully ready
     n=0
-    until [[ "$n" -ge 100 ]]; do
-        docker logs test-comps-llm-tgi-endpoint > test-comps-llm-tgi-endpoint.log
+    until [[ "$n" -ge 100 ]] || [[ $ready == true ]]; do
+        docker logs test-comps-llm-tgi-endpoint > ${WORKPATH}/tests/test-comps-llm-tgi-endpoint.log
         n=$((n+1))
-        if grep -q Connected test-comps-llm-tgi-endpoint.log; then
+        if grep -q Connected ${WORKPATH}/tests/test-comps-llm-tgi-endpoint.log; then
             break
         fi
         sleep 5s
@@ -43,6 +43,8 @@ function validate_microservice() {
         -X POST \
         -d '{"query":"What is Deep Learning?"}' \
         -H 'Content-Type: application/json'
+    docker logs test-comps-llm-tgi-endpoint
+    docker logs test-comps-llm-tgi-server
 }
 
 function stop_docker() {
2 changes: 2 additions & 0 deletions tests/test_reranks_langchain.sh
@@ -33,6 +33,8 @@ function validate_microservice() {
         -X POST \
         -d '{"initial_query":"What is Deep Learning?", "retrieved_docs": [{"text":"Deep Learning is not..."}, {"text":"Deep learning is..."}]}' \
         -H 'Content-Type: application/json'
+    docker logs test-comps-reranking-tei-server
+    docker logs test-comps-reranking-tei-endpoint
 }
 
 function stop_docker() {
27 changes: 20 additions & 7 deletions tests/test_retrievers_langchain.sh
@@ -8,19 +8,23 @@ WORKPATH=$(dirname "$PWD")
 ip_address=$(hostname -I | awk '{print $1}')
 function build_docker_images() {
     cd $WORKPATH
-    docker build --no-cache -t opea/retriever-redis:comps -f comps/retrievers/langchain/docker/Dockerfile .
+    docker build --no-cache -t opea/retriever-redis:comps --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/retrievers/langchain/docker/Dockerfile .
 }
 
 function start_service() {
+    # redis
+    docker run -d --name test-redis-vector-db -p 5010:6379 -p 5011:8001 -e HTTPS_PROXY=$https_proxy -e HTTP_PROXY=$https_proxy redis/redis-stack:7.2.0-v9
+    sleep 10s
+
     # tei endpoint
     tei_endpoint=5008
-    model="BAAI/bge-large-en-v1.5"
-    revision="refs/pr/5"
-    docker run -d --name="test-comps-retriever-tei-endpoint" -p $tei_endpoint:80 -v ./data:/data --pull always ghcr.io/huggingface/text-embeddings-inference:cpu-1.2 --model-id $model --revision $revision
+    model="BAAI/bge-base-en-v1.5"
+    docker run -d --name="test-comps-retriever-tei-endpoint" -p $tei_endpoint:80 -v ./data:/data --pull always ghcr.io/huggingface/text-embeddings-inference:cpu-1.2 --model-id $model
     sleep 30s
     export TEI_EMBEDDING_ENDPOINT="http://${ip_address}:${tei_endpoint}"
+
     # redis retriever
-    export REDIS_URL="redis://${ip_address}:6379"
+    export REDIS_URL="redis://${ip_address}:5010"
     export INDEX_NAME="rag-redis"
     retriever_port=5009
     unset http_proxy
@@ -38,11 +42,20 @@ function validate_microservice() {
         -X POST \
         -d "{\"text\":\"test\",\"embedding\":${test_embedding}}" \
         -H 'Content-Type: application/json'
+    docker logs test-comps-retriever-redis-server
+    docker logs test-comps-retriever-tei-endpoint
 }
 
 function stop_docker() {
-    cid=$(docker ps -aq --filter "name=test-comps-retrievers*")
-    if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid && sleep 1s; fi
+    cid_retrievers=$(docker ps -aq --filter "name=test-comps-retrievers*")
+    if [[ ! -z "$cid_retrievers" ]]; then
+        docker stop $cid_retrievers && docker rm $cid_retrievers && sleep 1s
+    fi
+
+    cid_redis=$(docker ps -aq --filter "name=test-redis-vector-db")
+    if [[ ! -z "$cid_redis" ]]; then
+        docker stop $cid_redis && docker rm $cid_redis && sleep 1s
+    fi
 }
 
 function main() {