Commit

Remote TGI/TEI services with OAuth Client Credentials authentication (opea-project#836)

* Add get_access_token method in utils.py to fetch a token for OAuth-protected TGI and TEI remote endpoints (a configuration sketch follows the change summary below)

* Update embedding_tei.py to support authentication for TEI endpoints, using get_access_token from utils

* Update llm.py under llms/faq-generation to support authentication for TGI endpoints, using get_access_token from utils

* Update llm.py under llms/summarization to support authentication for TGI endpoints, using get_access_token from utils

* Update llm.py under llms/text-generation to support authentication for TGI endpoints, using get_access_token from utils

* Update reranking_tei.py to support authentication for TEI endpoints, using get_access_token from utils

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
sgurunat and pre-commit-ci[bot] authored Oct 31, 2024
1 parent d2e9c0a commit 74df6bb
Showing 6 changed files with 112 additions and 16 deletions.
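
All six services share one pattern: read TOKEN_URL, CLIENTID, and CLIENT_SECRET from the environment, exchange them for an access token via the client-credentials grant, and attach it as a Bearer header on outbound TGI/TEI requests. A minimal configuration sketch (the endpoint and credential values below are illustrative placeholders, not values from this commit):

import os

# Illustrative placeholders; substitute your identity provider's values.
os.environ["TOKEN_URL"] = "https://auth.example.com/oauth/token"
os.environ["CLIENTID"] = "my-client-id"
os.environ["CLIENT_SECRET"] = "my-client-secret"

# With all three set, each service below fetches a token per request and
# sends "Authorization: Bearer <token>" to the remote TGI/TEI endpoint.
# With any of them unset, requests go out unauthenticated, as before.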
22 changes: 22 additions & 0 deletions comps/cores/mega/utils.py
@@ -8,6 +8,10 @@
from socket import AF_INET, SOCK_STREAM, socket
from typing import List, Optional, Union

import requests

from .logger import CustomLogger


def is_port_free(host: str, port: int) -> bool:
"""Check if a given port on a host is free.
@@ -183,6 +187,24 @@ def _check_bind(port):
return _random_port()


def get_access_token(token_url: str, client_id: str, client_secret: str) -> str:
"""Get access token using OAuth client credentials flow."""
logger = CustomLogger("tgi_or_tei_service_auth")
data = {
"client_id": client_id,
"client_secret": client_secret,
"grant_type": "client_credentials",
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = requests.post(token_url, data=data, headers=headers)
if response.status_code == 200:
token_info = response.json()
return token_info.get("access_token", "")
else:
logger.error(f"Failed to retrieve access token: {response.status_code}, {response.text}")
return ""


class SafeContextManager:
"""This context manager ensures that the `__exit__` method of the
sub context is called, even when there is an Exception in the
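A quick way to exercise the new helper in isolation (the token URL and credentials are placeholders; on failure it logs an error and returns an empty string, so callers can fall back to unauthenticated requests):

from comps.cores.mega.utils import get_access_token

# Placeholder values; use your OAuth provider's token endpoint and credentials.
token = get_access_token(
    "https://auth.example.com/oauth/token",
    "my-client-id",
    "my-client-secret",
)
headers = {"Authorization": f"Bearer {token}"} if token else {}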
42 changes: 36 additions & 6 deletions comps/embeddings/tei/langchain/embedding_tei.py
@@ -1,11 +1,12 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os
import time
from typing import Union
from typing import List, Union

from langchain_huggingface import HuggingFaceEndpointEmbeddings
from huggingface_hub import AsyncInferenceClient

from comps import (
CustomLogger,
@@ -17,6 +18,7 @@
register_statistics,
statistics_dict,
)
from comps.cores.mega.utils import get_access_token
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
EmbeddingRequest,
@@ -27,6 +29,13 @@
logger = CustomLogger("embedding_tei_langchain")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
TEI_EMBEDDING_ENDPOINT = os.getenv("TEI_EMBEDDING_ENDPOINT", "http://localhost:8080")


@register_microservice(
name="opea_service@embedding_tei_langchain",
@@ -40,13 +49,17 @@ async def embedding(
input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest]
) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]:
start = time.time()
access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
async_client = get_async_inference_client(access_token)
if logflag:
logger.info(input)
if isinstance(input, TextDoc):
embed_vector = await embeddings.aembed_query(input.text)
embed_vector = await aembed_query(input.text, async_client)
res = EmbedDoc(text=input.text, embedding=embed_vector)
else:
embed_vector = await embeddings.aembed_query(input.input)
embed_vector = await aembed_query(input.input, async_client)
if input.dimensions is not None:
embed_vector = embed_vector[: input.dimensions]

@@ -64,8 +77,25 @@ async def embedding(
return res


async def aembed_query(text: str, async_client: AsyncInferenceClient, model_kwargs=None, task=None) -> List[float]:
response = (await aembed_documents([text], async_client, model_kwargs=model_kwargs, task=task))[0]
return response


async def aembed_documents(
texts: List[str], async_client: AsyncInferenceClient, model_kwargs=None, task=None
) -> List[List[float]]:
texts = [text.replace("\n", " ") for text in texts]
_model_kwargs = model_kwargs or {}
responses = await async_client.post(json={"inputs": texts, **_model_kwargs}, task=task)
return json.loads(responses.decode())


def get_async_inference_client(access_token: str) -> AsyncInferenceClient:
headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
return AsyncInferenceClient(model=TEI_EMBEDDING_ENDPOINT, token=HUGGINGFACEHUB_API_TOKEN, headers=headers)


if __name__ == "__main__":
tei_embedding_endpoint = os.getenv("TEI_EMBEDDING_ENDPOINT", "http://localhost:8080")
embeddings = HuggingFaceEndpointEmbeddings(model=tei_embedding_endpoint)
logger.info("TEI Gaudi Embedding initialized.")
opea_microservices["opea_service@embedding_tei_langchain"].start()
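
This file also swaps langchain's HuggingFaceEndpointEmbeddings for a direct huggingface_hub AsyncInferenceClient call. A standalone sketch of the same request, assuming a TEI server at http://localhost:8080 (the header line is only needed when OAuth is configured):

import asyncio
import json

from huggingface_hub import AsyncInferenceClient

async def embed(texts):
    client = AsyncInferenceClient(
        model="http://localhost:8080",
        # headers={"Authorization": "Bearer <token>"},  # when OAuth is configured
    )
    raw = await client.post(json={"inputs": texts})
    return json.loads(raw.decode())  # one embedding vector per input text

vectors = asyncio.run(embed(["hello world"]))
print(len(vectors[0]))  # embedding dimension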
13 changes: 13 additions & 0 deletions comps/llms/faq-generation/tgi/langchain/llm.py
@@ -11,10 +11,16 @@
from langchain_community.llms import HuggingFaceEndpoint

from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
from comps.cores.mega.utils import get_access_token

logger = CustomLogger("llm_faqgen")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")


def post_process_text(text: str):
if text == " ":
@@ -37,6 +43,12 @@ def post_process_text(text: str):
async def llm_generate(input: LLMParamsDoc):
if logflag:
logger.info(input)
access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
server_kwargs = {}
if access_token:
server_kwargs["headers"] = {"Authorization": f"Bearer {access_token}"}
llm = HuggingFaceEndpoint(
endpoint_url=llm_endpoint,
max_new_tokens=input.max_tokens,
@@ -46,6 +58,7 @@ async def llm_generate(input: LLMParamsDoc):
temperature=input.temperature,
repetition_penalty=input.repetition_penalty,
streaming=input.streaming,
server_kwargs=server_kwargs,
)
templ = """Create a concise FAQs (frequently asked questions and answers) for following text:
TEXT: {text}
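HuggingFaceEndpoint passes server_kwargs through to its underlying inference client, which is how the Authorization header reaches the TGI server while leaving deployments without OAuth untouched. A trimmed sketch of the call above, with a placeholder endpoint and token:

from langchain_community.llms import HuggingFaceEndpoint

# Left empty when OAuth is not configured, so unauthenticated setups are unchanged.
server_kwargs = {"headers": {"Authorization": "Bearer <token>"}}

llm = HuggingFaceEndpoint(
    endpoint_url="http://localhost:8080",  # placeholder TGI endpoint
    max_new_tokens=128,
    temperature=0.01,
    streaming=False,
    server_kwargs=server_kwargs,
)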
20 changes: 14 additions & 6 deletions comps/llms/summarization/tgi/langchain/llm.py
@@ -8,15 +8,15 @@
from langchain.prompts import PromptTemplate

from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
from comps.cores.mega.utils import get_access_token

logger = CustomLogger("llm_docsum")
logflag = os.getenv("LOGFLAG", False)

llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = AsyncInferenceClient(
model=llm_endpoint,
timeout=600,
)
# Environment variables
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")

templ_en = """Write a concise summary of the following:
@@ -45,7 +45,6 @@
async def llm_generate(input: LLMParamsDoc):
if logflag:
logger.info(input)

if input.language in ["en", "auto"]:
templ = templ_en
elif input.language in ["zh"]:
@@ -60,6 +59,15 @@ async def llm_generate(input: LLMParamsDoc):
logger.info("After prompting:")
logger.info(prompt)

access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
headers = {}
if access_token:
headers = {"Authorization": f"Bearer {access_token}"}
llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = AsyncInferenceClient(model=llm_endpoint, timeout=600, headers=headers)

text_generation = await llm.text_generation(
prompt=prompt,
stream=input.streaming,
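The client is now built per request so each call can carry a fresh token. The same construction in isolation, with a placeholder endpoint and no OAuth configured:

import asyncio
from typing import Optional

from huggingface_hub import AsyncInferenceClient

async def generate(prompt: str, access_token: Optional[str] = None) -> str:
    headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    llm = AsyncInferenceClient(model="http://localhost:8080", timeout=600, headers=headers)
    return await llm.text_generation(prompt=prompt, max_new_tokens=64)

print(asyncio.run(generate("Write a one-line summary of TGI.")))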
20 changes: 16 additions & 4 deletions comps/llms/text-generation/tgi/llm.py
@@ -22,16 +22,18 @@
register_statistics,
statistics_dict,
)
from comps.cores.mega.utils import get_access_token
from comps.cores.proto.api_protocol import ChatCompletionRequest

logger = CustomLogger("llm_tgi")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")

llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = AsyncInferenceClient(
model=llm_endpoint,
timeout=600,
)


@register_microservice(
@@ -45,6 +47,16 @@
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
if logflag:
logger.info(input)

access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
headers = {}
if access_token:
headers = {"Authorization": f"Bearer {access_token}"}

llm = AsyncInferenceClient(model=llm_endpoint, timeout=600, headers=headers)

prompt_template = None
if not isinstance(input, SearchedDoc) and input.chat_template:
prompt_template = PromptTemplate.from_template(input.chat_template)
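The text-generation service applies the same per-request pattern as the summarization sketch above: the AsyncInferenceClient moves from module scope into llm_generate, so every request can pick up a freshly fetched token, at the cost of constructing a client per call.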
11 changes: 11 additions & 0 deletions comps/reranks/tei/reranking_tei.py
@@ -18,6 +18,7 @@
register_statistics,
statistics_dict,
)
from comps.cores.mega.utils import get_access_token
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
RerankingRequest,
@@ -28,6 +29,11 @@
logger = CustomLogger("reranking_tei")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")


@register_microservice(
name="opea_service@reranking_tei",
@@ -46,6 +52,9 @@ async def reranking(
logger.info(input)
start = time.time()
reranking_results = []
access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
if input.retrieved_docs:
docs = [doc.text for doc in input.retrieved_docs]
url = tei_reranking_endpoint + "/rerank"
@@ -56,6 +65,8 @@
query = input.input
data = {"query": query, "texts": docs}
headers = {"Content-Type": "application/json"}
if access_token:
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession() as session:
async with session.post(url, data=json.dumps(data), headers=headers) as response:
response_data = await response.json()
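For the reranker, the Bearer header is added to a plain aiohttp POST. A standalone sketch against a placeholder rerank endpoint (the service itself reads tei_reranking_endpoint from its configuration):

import asyncio
import json

import aiohttp

async def rerank(query, texts, access_token=None):
    headers = {"Content-Type": "application/json"}
    if access_token:  # only attach the header when OAuth is configured
        headers["Authorization"] = f"Bearer {access_token}"
    data = {"query": query, "texts": texts}
    async with aiohttp.ClientSession() as session:
        async with session.post("http://localhost:8808/rerank", data=json.dumps(data), headers=headers) as resp:
            return await resp.json()

print(asyncio.run(rerank("best pasta", ["pasta in Rome", "tax law basics"])))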
