Skip to content

Commit

Permalink
Added classification_location parameter in PebbloSafeLoader.
Browse files Browse the repository at this point in the history
Signed-off-by: Rahul Tripathi <[email protected]>
  • Loading branch information
Rahul Tripathi committed Jun 7, 2024
1 parent debf821 commit 8ca58df
Showing 1 changed file with 98 additions and 53 deletions.
151 changes: 98 additions & 53 deletions libs/community/langchain_community/chains/pebblo_retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import datetime
import inspect
import json
import logging
from http import HTTPStatus
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -72,6 +73,8 @@ class PebbloRetrievalQA(Chain):
"""Pebblo cloud API key for app."""
classifier_url: str = CLASSIFIER_URL #: :meta private:
"""Classifier endpoint."""
classifier_location: Optional[str] = None #: :meta private:
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
_discover_sent = False #: :meta private:
"""Flag to check if discover payload has been sent."""
_prompt_sent = False #: :meta private:
Expand Down Expand Up @@ -222,6 +225,7 @@ def from_chain_type(
chain_type_kwargs: Optional[dict] = None,
api_key: Optional[str] = None,
classifier_url: str = CLASSIFIER_URL,
classifier_location: Optional[str] = None,
**kwargs: Any,
) -> "PebbloRetrievalQA":
"""Load chain from chain type."""
Expand All @@ -242,7 +246,10 @@ def from_chain_type(
)

PebbloRetrievalQA._send_discover(
app, api_key=api_key, classifier_url=classifier_url
app,
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
)

return cls(
Expand All @@ -252,6 +259,7 @@ def from_chain_type(
description=description,
api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location,
**kwargs,
)

Expand Down Expand Up @@ -321,40 +329,46 @@ def _get_app_details(app_name, owner, description, llm, **kwargs) -> App: # typ
return app

@staticmethod
def _send_discover(app, api_key, classifier_url) -> None: # type: ignore
def _send_discover(app, api_key, classifier_url, classifier_location) -> None: # type: ignore
"""Send app discovery payload to pebblo-server. Internal method."""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = app.dict(exclude_unset=True)
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
if classifier_location == "local":
payload = app.dict(exclude_unset=True)
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
"Received unexpected HTTP response code:" +
f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)

if api_key:
if api_key and classifier_location == "pebblo-cloud":
try:
headers.update({"x-api-key": api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
Expand Down Expand Up @@ -396,35 +410,66 @@ def _send_prompt(self, qa_payload: Qa) -> None:
"Content-Type": "application/json",
}
app_discover_url = f"{self.classifier_url}{PROMPT_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=qa_payload.dict(), timeout=20
)
logger.debug("prompt-payload: %s", qa_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
pebblo_resp = None
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
app_discover_url,
headers=headers,
json=qa_payload.dict(),
timeout=20,
)
logger.debug("prompt-payload: %s", qa_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
"Received unexpected HTTP response code:" +
f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)

if self.api_key:
if self.classifier_location == "local":
if pebblo_resp:
qa_payload.response = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("response", {})
)
qa_payload.context = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("context", [])
)
qa_payload.prompt = (
json.loads(pebblo_resp.text)
.get("retrieval_data", {})
.get("prompt", {})
)
else:
qa_payload.response = {}
qa_payload.context = []
qa_payload.prompt = {}
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
try:
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
pebblo_cloud_response = requests.post(
pebblo_cloud_url,
headers=headers,
Expand Down

0 comments on commit 8ca58df

Please sign in to comment.