From 8ca58df6099c44b4a1a3b06a1d89513367211d20 Mon Sep 17 00:00:00 2001 From: Rahul Tripathi Date: Fri, 7 Jun 2024 17:49:43 +0530 Subject: [PATCH] Added classifier_location parameter in PebbloRetrievalQA. Signed-off-by: Rahul Tripathi --- .../chains/pebblo_retrieval/base.py | 151 ++++++++++++------ 1 file changed, 98 insertions(+), 53 deletions(-) diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/base.py b/libs/community/langchain_community/chains/pebblo_retrieval/base.py index e5ecb0f5254ec..2f96b5707a587 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/base.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/base.py @@ -5,6 +5,7 @@ import datetime import inspect +import json import logging from http import HTTPStatus from typing import Any, Dict, List, Optional @@ -72,6 +73,8 @@ class PebbloRetrievalQA(Chain): """Pebblo cloud API key for app.""" classifier_url: str = CLASSIFIER_URL #: :meta private: """Classifier endpoint.""" + classifier_location: Optional[str] = None #: :meta private: + """Classifier location. 
It could be either of 'local' or 'pebblo-cloud'.""" _discover_sent = False #: :meta private: """Flag to check if discover payload has been sent.""" _prompt_sent = False #: :meta private: @@ -222,6 +225,7 @@ def from_chain_type( chain_type_kwargs: Optional[dict] = None, api_key: Optional[str] = None, classifier_url: str = CLASSIFIER_URL, + classifier_location: Optional[str] = None, **kwargs: Any, ) -> "PebbloRetrievalQA": """Load chain from chain type.""" @@ -242,7 +246,10 @@ def from_chain_type( ) PebbloRetrievalQA._send_discover( - app, api_key=api_key, classifier_url=classifier_url + app, + api_key=api_key, + classifier_url=classifier_url, + classifier_location=classifier_location, ) return cls( @@ -252,6 +259,7 @@ def from_chain_type( description=description, api_key=api_key, classifier_url=classifier_url, + classifier_location=classifier_location, **kwargs, ) @@ -321,40 +329,46 @@ def _get_app_details(app_name, owner, description, llm, **kwargs) -> App: # typ return app @staticmethod - def _send_discover(app, api_key, classifier_url) -> None: # type: ignore + def _send_discover(app, api_key, classifier_url, classifier_location) -> None: # type: ignore """Send app discovery payload to pebblo-server. 
Internal method.""" headers = { "Accept": "application/json", "Content-Type": "application/json", } - payload = app.dict(exclude_unset=True) - app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}" - try: - pebblo_resp = requests.post( - app_discover_url, headers=headers, json=payload, timeout=20 - ) - logger.debug("discover-payload: %s", payload) - logger.debug( - "send_discover[local]: request url %s, body %s len %s\ - response status %s body %s", - pebblo_resp.request.url, - str(pebblo_resp.request.body), - str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])), - str(pebblo_resp.status_code), - pebblo_resp.json(), - ) - if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]: - PebbloRetrievalQA.set_discover_sent() - else: - logger.warning( - f"Received unexpected HTTP response code: {pebblo_resp.status_code}" + if classifier_location == "local": + payload = app.dict(exclude_unset=True) + app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}" + try: + pebblo_resp = requests.post( + app_discover_url, headers=headers, json=payload, timeout=20 ) - except requests.exceptions.RequestException: - logger.warning("Unable to reach pebblo server.") - except Exception as e: - logger.warning("An Exception caught in _send_discover: local %s", e) + logger.debug("discover-payload: %s", payload) + logger.debug( + "send_discover[local]: request url %s, body %s len %s\ + response status %s body %s", + pebblo_resp.request.url, + str(pebblo_resp.request.body), + str( + len( + pebblo_resp.request.body if pebblo_resp.request.body else [] + ) + ), + str(pebblo_resp.status_code), + pebblo_resp.json(), + ) + if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]: + PebbloRetrievalQA.set_discover_sent() + else: + logger.warning( + "Received unexpected HTTP response code:" + + f"{pebblo_resp.status_code}" + ) + except requests.exceptions.RequestException: + logger.warning("Unable to reach pebblo server.") + except Exception as e: + 
logger.warning("An Exception caught in _send_discover: local %s", e) - if api_key: + if api_key and classifier_location == "pebblo-cloud": try: headers.update({"x-api-key": api_key}) pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}" @@ -396,35 +410,66 @@ def _send_prompt(self, qa_payload: Qa) -> None: "Content-Type": "application/json", } app_discover_url = f"{self.classifier_url}{PROMPT_URL}" - try: - pebblo_resp = requests.post( - app_discover_url, headers=headers, json=qa_payload.dict(), timeout=20 - ) - logger.debug("prompt-payload: %s", qa_payload) - logger.debug( - "send_prompt[local]: request url %s, body %s len %s\ - response status %s body %s", - pebblo_resp.request.url, - str(pebblo_resp.request.body), - str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])), - str(pebblo_resp.status_code), - pebblo_resp.json(), - ) - if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]: - PebbloRetrievalQA.set_prompt_sent() - else: - logger.warning( - f"Received unexpected HTTP response code: {pebblo_resp.status_code}" + pebblo_resp = None + if self.classifier_location == "local": + try: + pebblo_resp = requests.post( + app_discover_url, + headers=headers, + json=qa_payload.dict(), + timeout=20, + ) + logger.debug("prompt-payload: %s", qa_payload) + logger.debug( + "send_prompt[local]: request url %s, body %s len %s\ + response status %s body %s", + pebblo_resp.request.url, + str(pebblo_resp.request.body), + str( + len( + pebblo_resp.request.body if pebblo_resp.request.body else [] + ) + ), + str(pebblo_resp.status_code), + pebblo_resp.json(), ) - except requests.exceptions.RequestException: - logger.warning("Unable to reach pebblo server.") - except Exception as e: - logger.warning("An Exception caught in _send_discover: local %s", e) + if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]: + PebbloRetrievalQA.set_prompt_sent() + else: + logger.warning( + "Received unexpected HTTP response code:" + + 
f"{pebblo_resp.status_code}" + ) + except requests.exceptions.RequestException: + logger.warning("Unable to reach pebblo server.") + except Exception as e: + logger.warning("An Exception caught in _send_discover: local %s", e) if self.api_key: + if self.classifier_location == "local": + if pebblo_resp: + qa_payload.response = ( + json.loads(pebblo_resp.text) + .get("retrieval_data", {}) + .get("response", {}) + ) + qa_payload.context = ( + json.loads(pebblo_resp.text) + .get("retrieval_data", {}) + .get("context", []) + ) + qa_payload.prompt = ( + json.loads(pebblo_resp.text) + .get("retrieval_data", {}) + .get("prompt", {}) + ) + else: + qa_payload.response = {} + qa_payload.context = [] + qa_payload.prompt = {} + headers.update({"x-api-key": self.api_key}) + pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}" try: - headers.update({"x-api-key": self.api_key}) - pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}" pebblo_cloud_response = requests.post( pebblo_cloud_url, headers=headers,