Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Community]: Added classification_location parameter in PebbloSafeLoader. #22565

Merged
182 changes: 122 additions & 60 deletions libs/community/langchain_community/chains/pebblo_retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import datetime
import inspect
import json
import logging
from http import HTTPStatus
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -72,7 +73,9 @@ class PebbloRetrievalQA(Chain):
"""Pebblo cloud API key for app."""
classifier_url: str = CLASSIFIER_URL #: :meta private:
"""Classifier endpoint."""
_discover_sent: bool = False #: :meta private:
classifier_location: str = "local" #: :meta private:
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
_discover_sent = False #: :meta private:
"""Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent."""
Expand All @@ -94,6 +97,7 @@ def _call(
answer, docs = res['result'], res['source_documents']
"""
prompt_time = datetime.datetime.now().isoformat()
PebbloRetrievalQA.set_prompt_sent(value=False)
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key, {})
Expand All @@ -115,7 +119,9 @@ def _call(
"name": self.app_name,
"context": [
{
"retrieved_from": doc.metadata.get("source"),
"retrieved_from": doc.metadata.get(
"full_path", doc.metadata.get("source")
),
"doc": doc.page_content,
"vector_db": self.retriever.vectorstore.__class__.__name__,
}
Expand All @@ -133,6 +139,7 @@ def _call(
else []
if auth_context
else [],
"classifier_location": self.classifier_location,
}
qa_payload = Qa(**qa)
self._send_prompt(qa_payload)
Expand Down Expand Up @@ -222,6 +229,7 @@ def from_chain_type(
chain_type_kwargs: Optional[dict] = None,
api_key: Optional[str] = None,
classifier_url: str = CLASSIFIER_URL,
classifier_location: str = "local",
**kwargs: Any,
) -> "PebbloRetrievalQA":
"""Load chain from chain type."""
Expand All @@ -233,7 +241,7 @@ def from_chain_type(
)

# generate app
app = PebbloRetrievalQA._get_app_details(
app: App = PebbloRetrievalQA._get_app_details(
app_name=app_name,
description=description,
owner=owner,
Expand All @@ -242,7 +250,10 @@ def from_chain_type(
)

PebbloRetrievalQA._send_discover(
app, api_key=api_key, classifier_url=classifier_url
app,
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
)

return cls(
Expand All @@ -252,6 +263,7 @@ def from_chain_type(
description=description,
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
**kwargs,
)

Expand Down Expand Up @@ -302,7 +314,9 @@ async def _aget_docs(
)

@staticmethod
def _get_app_details(app_name, owner, description, llm, **kwargs) -> App: # type: ignore
def _get_app_details( # type: ignore
app_name: str, owner: str, description: str, llm: BaseLanguageModel, **kwargs
) -> App:
"""Fetch app details. Internal method.
Returns:
App: App details.
Expand All @@ -321,38 +335,49 @@ def _get_app_details(app_name, owner, description, llm, **kwargs) -> App: # typ
return app

@staticmethod
def _send_discover(app, api_key, classifier_url) -> None: # type: ignore
def _send_discover(
app: App,
api_key: Optional[str],
classifier_url: str,
classifier_location: str,
) -> None: # type: ignore
"""Send app discovery payload to pebblo-server. Internal method."""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = app.dict(exclude_unset=True)
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
if classifier_location == "local":
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)

if api_key:
try:
Expand Down Expand Up @@ -387,48 +412,82 @@ def set_discover_sent(cls) -> None:
cls._discover_sent = True

@classmethod
def set_prompt_sent(cls) -> None:
cls._prompt_sent = True
def set_prompt_sent(cls, value: bool = True) -> None:
cls._prompt_sent = value

def _send_prompt(self, qa_payload: Qa) -> None:
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
app_discover_url = f"{self.classifier_url}{PROMPT_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=qa_payload.dict(), timeout=20
)
logger.debug("prompt-payload: %s", qa_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
pebblo_resp = None
payload = qa_payload.dict(exclude_unset=True)
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
app_discover_url,
headers=headers,
json=payload,
timeout=20,
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
logger.debug("prompt-payload: %s", payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)

# If classifier location is local, then response, context and prompt
# should be fetched from pebblo_resp and replaced in payload.
if self.api_key:
if self.classifier_location == "local":
if pebblo_resp:
payload["response"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("response", {})
)
payload["context"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("context", [])
)
payload["prompt"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("prompt", {})
)
else:
payload["response"] = None
payload["context"] = None
payload["prompt"] = None
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
try:
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
pebblo_cloud_response = requests.post(
pebblo_cloud_url,
headers=headers,
json=qa_payload.dict(),
json=payload,
timeout=20,
)

Expand All @@ -451,9 +510,12 @@ def _send_prompt(self, qa_payload: Qa) -> None:
logger.warning("Unable to reach Pebblo cloud server.")
except Exception as e:
logger.warning("An Exception caught in _send_prompt: cloud %s", e)
elif self.classifier_location == "pebblo-cloud":
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.")

@classmethod
def get_chain_details(cls, llm, **kwargs): # type: ignore
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
llm_dict = llm.__dict__
chain = [
{
Expand All @@ -476,6 +538,6 @@ def get_chain_details(cls, llm, **kwargs): # type: ignore
),
}
],
}
},
]
return chain
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Models for the PebbloRetrievalQA chain."""

from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from langchain_core.pydantic_v1 import BaseModel

Expand Down Expand Up @@ -137,9 +137,10 @@ class Prompt(BaseModel):

class Qa(BaseModel):
name: str
context: List[Optional[Context]]
prompt: Prompt
response: Prompt
context: Union[List[Optional[Context]], Optional[Context]]
prompt: Optional[Prompt]
response: Optional[Prompt]
prompt_time: str
user: str
user_identities: Optional[List[str]]
classifier_location: str
Loading
Loading