From 88a147da7d7faf381f407cfdd3e4b03b5016d9f6 Mon Sep 17 00:00:00 2001
From: Tyler W
Date: Fri, 21 Jun 2024 03:49:11 +0000
Subject: [PATCH] Bug Fix: chat completions API calls need `model_id` (#114)

* added default model

Signed-off-by: Tyler Wilbers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added instructions for environment variable

Signed-off-by: Tyler Wilbers

* added bash to codeblock

Signed-off-by: Tyler Wilbers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed typo

Signed-off-by: Tyler Wilbers

---------

Signed-off-by: Tyler Wilbers
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 comps/guardrails/README.md            |  6 ++++++
 .../langchain/guardrails_tgi_gaudi.py | 18 +++++++++++++++++-
 comps/guardrails/requirements.txt     |  1 +
 3 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/comps/guardrails/README.md b/comps/guardrails/README.md
index 6b0a72ab3..1107e7fe4 100644
--- a/comps/guardrails/README.md
+++ b/comps/guardrails/README.md
@@ -54,6 +54,12 @@ curl 127.0.0.1:8088/generate \
 
 ## 1.4 Start Guardrails Service
 
+Optional: If the Guardrails model you deployed with the TGI Gaudi Service [in section 1.2](#12-start-tgi-gaudi-service) is not the default model (i.e., `meta-llama/LlamaGuard-7b`), you will need to set the environment variable `SAFETY_GUARD_MODEL_ID` to the deployed model id. For example, the following informs the Guardrails Service that the deployed model is Llama Guard 2:
+
+```bash
+export SAFETY_GUARD_MODEL_ID="meta-llama/Meta-Llama-Guard-2-8B"
+```
+
 ```bash
 export SAFETY_GUARD_ENDPOINT="http://${your_ip}:8088"
 python langchain/guardrails_tgi_gaudi.py
diff --git a/comps/guardrails/langchain/guardrails_tgi_gaudi.py b/comps/guardrails/langchain/guardrails_tgi_gaudi.py
index 002cbc6ce..03d193505 100644
--- a/comps/guardrails/langchain/guardrails_tgi_gaudi.py
+++ b/comps/guardrails/langchain/guardrails_tgi_gaudi.py
@@ -3,14 +3,17 @@
 
 import os
 
+from langchain_community.utilities.requests import JsonRequestsWrapper
 from langchain_huggingface import ChatHuggingFace
 from langchain_huggingface.llms import HuggingFaceEndpoint
 from langsmith import traceable
 
 from comps import ServiceType, TextDoc, opea_microservices, register_microservice
 
+DEFAULT_MODEL = "meta-llama/LlamaGuard-7b"
 
-def get_unsafe_dict(model_id="meta-llama/LlamaGuard-7b"):
+
+def get_unsafe_dict(model_id=DEFAULT_MODEL):
     if model_id == "meta-llama/LlamaGuard-7b":
         return {
             "O1": "Violence and Hate",
@@ -38,6 +41,18 @@ def get_unsafe_dict(model_id="meta-llama/LlamaGuard-7b"):
     }
 
 
+def get_tgi_service_model_id(endpoint_url, default=DEFAULT_MODEL):
+    """Return the Hugging Face repo id reported by the deployed service's
+    info endpoint, otherwise return the default model."""
+    try:
+        requests = JsonRequestsWrapper()
+        info_endpoint = os.path.join(endpoint_url, "info")
+        model_info = requests.get(info_endpoint)
+        return model_info["model_id"]
+    except Exception:
+        return default
+
+
 @register_microservice(
     name="opea_service@guardrails_tgi_gaudi",
     service_type=ServiceType.GUARDRAIL,
@@ -64,6 +79,7 @@ def safety_guard(input: TextDoc) -> TextDoc:
 
 if __name__ == "__main__":
     safety_guard_endpoint = os.getenv("SAFETY_GUARD_ENDPOINT", "http://localhost:8080")
+    safety_guard_model = os.getenv("SAFETY_GUARD_MODEL_ID", get_tgi_service_model_id(safety_guard_endpoint))
     llm_guard = HuggingFaceEndpoint(
         endpoint_url=safety_guard_endpoint,
         max_new_tokens=100,
diff --git a/comps/guardrails/requirements.txt b/comps/guardrails/requirements.txt
index 38eff5b6c..6d44ec4a4 100644
--- a/comps/guardrails/requirements.txt
+++ b/comps/guardrails/requirements.txt
@@ -1,6 +1,7 @@
 docarray[full]
 fastapi
 huggingface_hub
+langchain-community
 langchain-huggingface
 langsmith
 opentelemetry-api
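A quick way to sanity-check the new fallback (not part of the patch; assumes the TGI Gaudi service from section 1.2 of the README is listening on port 8088): `get_tgi_service_model_id()` reads the `model_id` field from the same `/info` endpoint that TGI exposes, so you can confirm what the running service reports before starting the Guardrails Service:

```bash
# Query the TGI info endpoint; the JSON response should include a "model_id"
# field, e.g. {"model_id": "meta-llama/Meta-Llama-Guard-2-8B", ...}
curl http://${your_ip}:8088/info
```

If `SAFETY_GUARD_MODEL_ID` is unset and this request fails, the service falls back to `DEFAULT_MODEL` (`meta-llama/LlamaGuard-7b`).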