Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Community]: Added classification_location parameter in PebbloSafeLoader. #22565

Merged
167 changes: 111 additions & 56 deletions libs/community/langchain_community/chains/pebblo_retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import datetime
import inspect
import json
import logging
from http import HTTPStatus
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -72,7 +73,9 @@ class PebbloRetrievalQA(Chain):
"""Pebblo cloud API key for app."""
classifier_url: str = CLASSIFIER_URL #: :meta private:
"""Classifier endpoint."""
_discover_sent: bool = False #: :meta private:
classifier_location: str = "local" #: :meta private:
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
_discover_sent = False #: :meta private:
"""Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent."""
Expand All @@ -94,6 +97,7 @@ def _call(
answer, docs = res['result'], res['source_documents']
"""
prompt_time = datetime.datetime.now().isoformat()
PebbloRetrievalQA.set_prompt_sent(value=False)
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key, {})
Expand All @@ -115,7 +119,9 @@ def _call(
"name": self.app_name,
"context": [
{
"retrieved_from": doc.metadata.get("source"),
"retrieved_from": doc.metadata.get(
"full_path", doc.metadata.get("source")
),
"doc": doc.page_content,
"vector_db": self.retriever.vectorstore.__class__.__name__,
}
Expand All @@ -133,6 +139,7 @@ def _call(
else []
if auth_context
else [],
"classifier_location": self.classifier_location,
}
qa_payload = Qa(**qa)
self._send_prompt(qa_payload)
Expand Down Expand Up @@ -222,6 +229,7 @@ def from_chain_type(
chain_type_kwargs: Optional[dict] = None,
api_key: Optional[str] = None,
classifier_url: str = CLASSIFIER_URL,
classifier_location: str = "local",
**kwargs: Any,
) -> "PebbloRetrievalQA":
"""Load chain from chain type."""
Expand All @@ -242,7 +250,10 @@ def from_chain_type(
)

PebbloRetrievalQA._send_discover(
app, api_key=api_key, classifier_url=classifier_url
app,
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
)

return cls(
Expand All @@ -252,6 +263,7 @@ def from_chain_type(
description=description,
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
**kwargs,
)

Expand Down Expand Up @@ -321,38 +333,44 @@ def _get_app_details(app_name, owner, description, llm, **kwargs) -> App: # typ
return app

@staticmethod
def _send_discover(app, api_key, classifier_url) -> None: # type: ignore
def _send_discover(app, api_key, classifier_url, classifier_location) -> None: # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type annotations missing here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added missing type annotations throughout file.

"""Send app discovery payload to pebblo-server. Internal method."""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = app.dict(exclude_unset=True)
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
if classifier_location == "local":
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)

if api_key:
try:
Expand Down Expand Up @@ -387,48 +405,82 @@ def set_discover_sent(cls) -> None:
cls._discover_sent = True

@classmethod
def set_prompt_sent(cls) -> None:
cls._prompt_sent = True
def set_prompt_sent(cls, value: bool = True) -> None:
cls._prompt_sent = value

def _send_prompt(self, qa_payload: Qa) -> None:
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
app_discover_url = f"{self.classifier_url}{PROMPT_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=qa_payload.dict(), timeout=20
)
logger.debug("prompt-payload: %s", qa_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
pebblo_resp = None
payload = qa_payload.dict(exclude_unset=True)
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
app_discover_url,
headers=headers,
json=payload,
timeout=20,
)
logger.debug("prompt-payload: %s", payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)

# If classifier location is local, then response, context and prompt
# should be fetched from pebblo_resp and replaced in payload.
if self.api_key:
if self.classifier_location == "local":
if pebblo_resp:
payload["response"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("response", {})
)
payload["context"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("context", [])
)
payload["prompt"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("prompt", {})
)
else:
payload["response"] = None
payload["context"] = None
payload["prompt"] = None
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
try:
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
pebblo_cloud_response = requests.post(
pebblo_cloud_url,
headers=headers,
json=qa_payload.dict(),
json=payload,
timeout=20,
)

Expand All @@ -451,6 +503,9 @@ def _send_prompt(self, qa_payload: Qa) -> None:
logger.warning("Unable to reach Pebblo cloud server.")
except Exception as e:
logger.warning("An Exception caught in _send_prompt: cloud %s", e)
elif self.classifier_location == "pebblo-cloud":
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise ValueError("API key is missing for sending prompt to Pebblo cloud.")

@classmethod
def get_chain_details(cls, llm, **kwargs): # type: ignore
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Models for the PebbloRetrievalQA chain."""

from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from langchain_core.pydantic_v1 import BaseModel

Expand Down Expand Up @@ -137,9 +137,10 @@ class Prompt(BaseModel):

class Qa(BaseModel):
name: str
context: List[Optional[Context]]
prompt: Prompt
response: Prompt
context: Union[List[Optional[Context]], Optional[Context]]
prompt: Optional[Prompt]
response: Optional[Prompt]
prompt_time: str
user: str
user_identities: Optional[List[str]]
classifier_location: str
Loading
Loading