
Remote TGI/TEI services with OAuth Client Credentials authentication #836

Merged
7 commits merged on Oct 31, 2024
22 changes: 22 additions & 0 deletions comps/cores/mega/utils.py
@@ -8,6 +8,10 @@
from socket import AF_INET, SOCK_STREAM, socket
from typing import List, Optional, Union

import requests

from .logger import CustomLogger


def is_port_free(host: str, port: int) -> bool:
"""Check if a given port on a host is free.
@@ -183,6 +187,24 @@
return _random_port()


def get_access_token(token_url: str, client_id: str, client_secret: str) -> str:
    """Get access token using OAuth client credentials flow."""
    logger = CustomLogger("tgi_or_tei_service_auth")
    data = {
        "client_id": client_id,
        "client_secret": client_secret,
        "grant_type": "client_credentials",
    }
    headers = {"Content-Type": "application/x-www-form-urlencoded"}
    response = requests.post(token_url, data=data, headers=headers)
    if response.status_code == 200:
        token_info = response.json()
        return token_info.get("access_token", "")
    else:
        logger.error(f"Failed to retrieve access token: {response.status_code}, {response.text}")
        return ""


class SafeContextManager:
"""This context manager ensures that the `__exit__` method of the
sub context is called, even when there is an Exception in the
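Roughly, the new helper is meant to be called with the token endpoint and client credentials taken from the environment; the sketch below uses placeholder values rather than anything defined in this PR:

from comps.cores.mega.utils import get_access_token

# Hypothetical identity-provider endpoint and credentials, for illustration only.
token = get_access_token(
    token_url="https://idp.example.com/oauth2/token",
    client_id="opea-client",
    client_secret="change-me",
)
if token:
    # The same header shape the patched services attach to their TGI/TEI requests.
    headers = {"Authorization": f"Bearer {token}"}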
42 changes: 36 additions & 6 deletions comps/embeddings/tei/langchain/embedding_tei.py
@@ -1,11 +1,12 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os
import time
from typing import Union
from typing import List, Union

from langchain_huggingface import HuggingFaceEndpointEmbeddings
from huggingface_hub import AsyncInferenceClient

from comps import (
CustomLogger,
@@ -17,6 +18,7 @@
register_statistics,
statistics_dict,
)
from comps.cores.mega.utils import get_access_token
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
EmbeddingRequest,
@@ -27,6 +29,13 @@
logger = CustomLogger("embedding_tei_langchain")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
TEI_EMBEDDING_ENDPOINT = os.getenv("TEI_EMBEDDING_ENDPOINT", "http://localhost:8080")


@register_microservice(
name="opea_service@embedding_tei_langchain",
@@ -40,13 +49,17 @@ async def embedding(
input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest]
) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]:
start = time.time()
access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
async_client = get_async_inference_client(access_token)
if logflag:
logger.info(input)
if isinstance(input, TextDoc):
embed_vector = await embeddings.aembed_query(input.text)
embed_vector = await aembed_query(input.text, async_client)
res = EmbedDoc(text=input.text, embedding=embed_vector)
else:
embed_vector = await embeddings.aembed_query(input.input)
embed_vector = await aembed_query(input.input, async_client)
if input.dimensions is not None:
embed_vector = embed_vector[: input.dimensions]

@@ -64,8 +77,25 @@ async def embedding(
return res


async def aembed_query(text: str, async_client: AsyncInferenceClient, model_kwargs=None, task=None) -> List[float]:
    response = (await aembed_documents([text], async_client, model_kwargs=model_kwargs, task=task))[0]
    return response


async def aembed_documents(
    texts: List[str], async_client: AsyncInferenceClient, model_kwargs=None, task=None
) -> List[List[float]]:
    texts = [text.replace("\n", " ") for text in texts]
    _model_kwargs = model_kwargs or {}
    responses = await async_client.post(json={"inputs": texts, **_model_kwargs}, task=task)
    return json.loads(responses.decode())


def get_async_inference_client(access_token: str) -> AsyncInferenceClient:
    headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    return AsyncInferenceClient(model=TEI_EMBEDDING_ENDPOINT, token=HUGGINGFACEHUB_API_TOKEN, headers=headers)


if __name__ == "__main__":
tei_embedding_endpoint = os.getenv("TEI_EMBEDDING_ENDPOINT", "http://localhost:8080")
embeddings = HuggingFaceEndpointEmbeddings(model=tei_embedding_endpoint)
logger.info("TEI Gaudi Embedding initialized.")
opea_microservices["opea_service@embedding_tei_langchain"].start()
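As a quick sanity check, the new helpers can also be driven directly. This is only a sketch, assuming the module is importable as comps.embeddings.tei.langchain.embedding_tei, a TEI server is reachable at the default TEI_EMBEDDING_ENDPOINT (http://localhost:8080), and no OAuth variables are set, so the client is built without an Authorization header:

import asyncio

from comps.embeddings.tei.langchain.embedding_tei import aembed_query, get_async_inference_client

async def _demo():
    client = get_async_inference_client(None)  # no access token -> empty headers
    vector = await aembed_query("What is deep learning?", client)
    print(len(vector))  # embedding dimension reported by the TEI server

asyncio.run(_demo())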
13 changes: 13 additions & 0 deletions comps/llms/faq-generation/tgi/langchain/llm.py
@@ -11,10 +11,16 @@
from langchain_community.llms import HuggingFaceEndpoint

from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
from comps.cores.mega.utils import get_access_token

logger = CustomLogger("llm_faqgen")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")


def post_process_text(text: str):
if text == " ":
@@ -37,6 +43,12 @@ def post_process_text(text: str):
async def llm_generate(input: LLMParamsDoc):
if logflag:
logger.info(input)
access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
server_kwargs = {}
if access_token:
server_kwargs["headers"] = {"Authorization": f"Bearer {access_token}"}
llm = HuggingFaceEndpoint(
endpoint_url=llm_endpoint,
max_new_tokens=input.max_tokens,
@@ -46,6 +58,7 @@ async def llm_generate(input: LLMParamsDoc):
temperature=input.temperature,
repetition_penalty=input.repetition_penalty,
streaming=input.streaming,
server_kwargs=server_kwargs,
)
templ = """Create a concise FAQs (frequently asked questions and answers) for following text:
TEXT: {text}
20 changes: 14 additions & 6 deletions comps/llms/summarization/tgi/langchain/llm.py
@@ -8,15 +8,15 @@
from langchain.prompts import PromptTemplate

from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
from comps.cores.mega.utils import get_access_token

logger = CustomLogger("llm_docsum")
logflag = os.getenv("LOGFLAG", False)

llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = AsyncInferenceClient(
model=llm_endpoint,
timeout=600,
)
# Environment variables
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")

templ_en = """Write a concise summary of the following:

@@ -45,7 +45,6 @@
async def llm_generate(input: LLMParamsDoc):
if logflag:
logger.info(input)

if input.language in ["en", "auto"]:
templ = templ_en
elif input.language in ["zh"]:
@@ -60,6 +59,15 @@ async def llm_generate(input: LLMParamsDoc):
logger.info("After prompting:")
logger.info(prompt)

access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
headers = {}
if access_token:
headers = {"Authorization": f"Bearer {access_token}"}
llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = AsyncInferenceClient(model=llm_endpoint, timeout=600, headers=headers)

text_generation = await llm.text_generation(
prompt=prompt,
stream=input.streaming,
20 changes: 16 additions & 4 deletions comps/llms/text-generation/tgi/llm.py
@@ -22,16 +22,18 @@
register_statistics,
statistics_dict,
)
from comps.cores.mega.utils import get_access_token
from comps.cores.proto.api_protocol import ChatCompletionRequest

logger = CustomLogger("llm_tgi")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")

llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = AsyncInferenceClient(
model=llm_endpoint,
timeout=600,
)


@register_microservice(
@@ -45,6 +47,16 @@
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
if logflag:
logger.info(input)

access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
headers = {}
if access_token:
headers = {"Authorization": f"Bearer {access_token}"}

llm = AsyncInferenceClient(model=llm_endpoint, timeout=600, headers=headers)

prompt_template = None
if not isinstance(input, SearchedDoc) and input.chat_template:
prompt_template = PromptTemplate.from_template(input.chat_template)
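For reference, the per-request client pattern shared by the TGI services looks roughly like the sketch below when pulled out of the microservice; it assumes a TGI server at http://localhost:8080 and an access token already obtained via get_access_token (pass None to skip authentication):

import asyncio

from huggingface_hub import AsyncInferenceClient

async def _demo(access_token=None):
    headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    client = AsyncInferenceClient(model="http://localhost:8080", timeout=600, headers=headers)
    # Non-streaming call for simplicity; the service itself forwards input.streaming and the sampling parameters.
    text = await client.text_generation(prompt="What is deep learning?", max_new_tokens=64, stream=False)
    print(text)

asyncio.run(_demo())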
11 changes: 11 additions & 0 deletions comps/reranks/tei/reranking_tei.py
@@ -18,6 +18,7 @@
register_statistics,
statistics_dict,
)
from comps.cores.mega.utils import get_access_token
from comps.cores.proto.api_protocol import (
ChatCompletionRequest,
RerankingRequest,
@@ -28,6 +29,11 @@
logger = CustomLogger("reranking_tei")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")


@register_microservice(
name="opea_service@reranking_tei",
@@ -46,6 +52,9 @@ async def reranking(
logger.info(input)
start = time.time()
reranking_results = []
access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
if input.retrieved_docs:
docs = [doc.text for doc in input.retrieved_docs]
url = tei_reranking_endpoint + "/rerank"
@@ -56,6 +65,8 @@ async def reranking(
query = input.input
data = {"query": query, "texts": docs}
headers = {"Content-Type": "application/json"}
if access_token:
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession() as session:
async with session.post(url, data=json.dumps(data), headers=headers) as response:
response_data = await response.json()
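Taken together, the OAuth path is opt-in across all of the touched services: it only activates when TOKEN_URL, CLIENTID and CLIENT_SECRET are all present, otherwise each service keeps its previous unauthenticated behaviour. A hypothetical configuration, with illustrative values only, could look like:

import os

# Must be set before the service modules are imported, since they read these variables at import time.
os.environ["TOKEN_URL"] = "https://idp.example.com/oauth2/token"  # assumed OAuth token endpoint
os.environ["CLIENTID"] = "opea-service"                           # assumed client id
os.environ["CLIENT_SECRET"] = "<client-secret>"                   # inject from a secret store in practice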