Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added classification_location parameter in PebbloSafeLoader. #47

159 changes: 104 additions & 55 deletions libs/community/langchain_community/chains/pebblo_retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import datetime
import inspect
import json
import logging
from http import HTTPStatus
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -72,7 +73,9 @@ class PebbloRetrievalQA(Chain):
"""Pebblo cloud API key for app."""
classifier_url: str = CLASSIFIER_URL #: :meta private:
"""Classifier endpoint."""
_discover_sent: bool = False #: :meta private:
classifier_location: str = "local" #: :meta private:
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
_discover_sent = False #: :meta private:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the use of _discover_sent and _prompt_sent? They are only set and never used in the code. Additionally, since these are class-level variables, how will they work with multiple prompt requests?

Copy link
Author

@rahul-trip rahul-trip Jun 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These flags are not intended to be used within the code, but by the end user.
The end user can check whether the discovery or prompt/document payload was sent successfully.
Currently they apply only to pebblo-server; they can be extended to pebblo-cloud as well.

To fix _prompt_sent, I now intentionally reset it to False at the beginning of the _call method.

"""Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent."""
Expand Down Expand Up @@ -115,7 +118,9 @@ def _call(
"name": self.app_name,
"context": [
{
"retrieved_from": doc.metadata.get("source"),
"retrieved_from": doc.metadata.get(
"full_path", doc.metadata.get("source")
),
"doc": doc.page_content,
"vector_db": self.retriever.vectorstore.__class__.__name__,
}
Expand All @@ -133,6 +138,7 @@ def _call(
else []
if auth_context
else [],
"classifier_location": self.classifier_location,
}
qa_payload = Qa(**qa)
self._send_prompt(qa_payload)
Expand Down Expand Up @@ -222,6 +228,7 @@ def from_chain_type(
chain_type_kwargs: Optional[dict] = None,
api_key: Optional[str] = None,
classifier_url: str = CLASSIFIER_URL,
classifier_location: str = "local",
**kwargs: Any,
) -> "PebbloRetrievalQA":
"""Load chain from chain type."""
Expand All @@ -242,7 +249,10 @@ def from_chain_type(
)

PebbloRetrievalQA._send_discover(
app, api_key=api_key, classifier_url=classifier_url
app,
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
)

return cls(
Expand All @@ -252,6 +262,7 @@ def from_chain_type(
description=description,
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
**kwargs,
)

Expand Down Expand Up @@ -321,40 +332,46 @@ def _get_app_details(app_name, owner, description, llm, **kwargs) -> App: # typ
return app

@staticmethod
def _send_discover(app, api_key, classifier_url) -> None: # type: ignore
def _send_discover(app, api_key, classifier_url, classifier_location) -> None: # type: ignore
"""Send app discovery payload to pebblo-server. Internal method."""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = app.dict(exclude_unset=True)
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
if classifier_location == "local":
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)

if api_key:
if api_key and classifier_location == "pebblo-cloud":
try:
headers.update({"x-api-key": api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
Expand Down Expand Up @@ -396,39 +413,71 @@ def _send_prompt(self, qa_payload: Qa) -> None:
"Content-Type": "application/json",
}
app_discover_url = f"{self.classifier_url}{PROMPT_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=qa_payload.dict(), timeout=20
)
logger.debug("prompt-payload: %s", qa_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
pebblo_resp = None
payload = qa_payload.dict(exclude_unset=True)
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
app_discover_url,
headers=headers,
json=payload,
timeout=20,
)
logger.debug("prompt-payload: %s", payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)

if self.api_key:
if self.classifier_location == "local":
if pebblo_resp:
payload["response"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("response", {})
)
payload["context"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("context", [])
)
payload["prompt"] = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("prompt", {})
)
else:
payload["response"] = None
payload["context"] = None
payload["prompt"] = None
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
try:
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
pebblo_cloud_response = requests.post(
pebblo_cloud_url,
headers=headers,
json=qa_payload.dict(),
json=payload,
timeout=20,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Models for the PebbloRetrievalQA chain."""

from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from langchain_core.pydantic_v1 import BaseModel

Expand Down Expand Up @@ -137,9 +137,10 @@ class Prompt(BaseModel):

class Qa(BaseModel):
name: str
context: List[Optional[Context]]
prompt: Prompt
response: Prompt
context: Union[List[Optional[Context]], Optional[Context]]
prompt: Optional[Prompt]
response: Optional[Prompt]
prompt_time: str
user: str
user_identities: Optional[List[str]]
classifier_location: str
Loading
Loading