Skip to content

Commit

Permalink
Bug Fix: chat completions API calls need model_id (opea-project#114)
Browse files Browse the repository at this point in the history
* added default model

Signed-off-by: Tyler Wilbers <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added instructions for enviornment variable

Signed-off-by: Tyler Wilbers <[email protected]>

* added bash to codeblock

Signed-off-by: Tyler Wilbers <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed typo

Signed-off-by: Tyler Wilbers <[email protected]>

---------

Signed-off-by: Tyler Wilbers <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Tyler W and pre-commit-ci[bot] authored Jun 21, 2024
1 parent df0c119 commit 88a147d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
6 changes: 6 additions & 0 deletions comps/guardrails/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ curl 127.0.0.1:8088/generate \

## 1.4 Start Guardrails Service

Optional: If you have deployed a Guardrails model with TGI Gaudi Service other than default model (i.e., `meta-llama/LlamaGuard-7b`) [from section 1.2](## 1.2 Start TGI Gaudi Service), you will need to add the eviornment variable `SAFETY_GUARD_MODEL_ID` containing the model id. For example, the following informs the Guardrails Service the deployed model used LlamaGuard2:

```bash
export SAFETY_GUARD_MODEL_ID="meta-llama/Meta-Llama-Guard-2-8"
```

```bash
export SAFETY_GUARD_ENDPOINT="http://${your_ip}:8088"
python langchain/guardrails_tgi_gaudi.py
Expand Down
18 changes: 17 additions & 1 deletion comps/guardrails/langchain/guardrails_tgi_gaudi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@

import os

from langchain_community.utilities.requests import JsonRequestsWrapper
from langchain_huggingface import ChatHuggingFace
from langchain_huggingface.llms import HuggingFaceEndpoint
from langsmith import traceable

from comps import ServiceType, TextDoc, opea_microservices, register_microservice

DEFAULT_MODEL = "meta-llama/LlamaGuard-7b"

def get_unsafe_dict(model_id="meta-llama/LlamaGuard-7b"):

def get_unsafe_dict(model_id=DEFAULT_MODEL):
if model_id == "meta-llama/LlamaGuard-7b":
return {
"O1": "Violence and Hate",
Expand Down Expand Up @@ -38,6 +41,18 @@ def get_unsafe_dict(model_id="meta-llama/LlamaGuard-7b"):
}


def get_tgi_service_model_id(endpoint_url, default=DEFAULT_MODEL):
"""Returns Hugging Face repo id for deployed service's info endpoint
otherwise return default model."""
try:
requests = JsonRequestsWrapper()
info_endpoint = os.path.join(endpoint_url, "info")
model_info = requests.get(info_endpoint)
return model_info["model_id"]
except Exception as e:
return default


@register_microservice(
name="opea_service@guardrails_tgi_gaudi",
service_type=ServiceType.GUARDRAIL,
Expand All @@ -64,6 +79,7 @@ def safety_guard(input: TextDoc) -> TextDoc:

if __name__ == "__main__":
safety_guard_endpoint = os.getenv("SAFETY_GUARD_ENDPOINT", "http://localhost:8080")
safety_guard_model = os.getenv("SAFETY_GUARD_MODEL_ID", get_tgi_service_model_id(safety_guard_endpoint))
llm_guard = HuggingFaceEndpoint(
endpoint_url=safety_guard_endpoint,
max_new_tokens=100,
Expand Down
1 change: 1 addition & 0 deletions comps/guardrails/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
docarray[full]
fastapi
huggingface_hub
langchain-community
langchain-huggingface
langsmith
opentelemetry-api
Expand Down

0 comments on commit 88a147d

Please sign in to comment.