Skip to content

Commit

Permalink
Model updates, Code & Log cleanup, Fix lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Raj725 committed Aug 21, 2024
1 parent 3303a25 commit 330d10e
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 67 deletions.
33 changes: 18 additions & 15 deletions libs/community/langchain_community/chains/pebblo_retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
App,
AuthContext,
ChainInfo,
Model,
SemanticContext,
VectorDB,
)
from langchain_community.chains.pebblo_retrieval.utilities import (
PLUGIN_VERSION,
Expand Down Expand Up @@ -318,7 +320,9 @@ def set_discover_sent(cls) -> None:
cls._discover_sent = True

@classmethod
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs) -> List[ChainInfo]:
def get_chain_details(
cls, llm: BaseLanguageModel, **kwargs: Any
) -> List[ChainInfo]:
"""
Get chain details.
Expand All @@ -327,20 +331,20 @@ def get_chain_details(cls, llm: BaseLanguageModel, **kwargs) -> List[ChainInfo]:
**kwargs: Additional keyword arguments.
Returns:
List[Dict[str, Any]]: Chain details.
List[ChainInfo]: Chain details.
"""
llm_dict = llm.__dict__
chains = [
{
"name": cls.__name__,
"model": {
"name": llm_dict.get("model_name", llm_dict.get("model")),
"vendor": llm.__class__.__name__,
},
"vector_dbs": [
{
"name": kwargs["retriever"].vectorstore.__class__.__name__,
"embedding_model": str(
ChainInfo(
name=cls.__name__,
model=Model(
name=llm_dict.get("model_name", llm_dict.get("model")),
vendor=llm.__class__.__name__,
),
vector_dbs=[
VectorDB(
name=kwargs["retriever"].vectorstore.__class__.__name__,
embedding_model=str(
kwargs["retriever"].vectorstore._embeddings.model
)
if hasattr(kwargs["retriever"].vectorstore, "_embeddings")
Expand All @@ -349,9 +353,8 @@ def get_chain_details(cls, llm: BaseLanguageModel, **kwargs) -> List[ChainInfo]:
if hasattr(kwargs["retriever"].vectorstore, "_embedding")
else None
),
}
)
],
},
),
]
chains = [ChainInfo(**chain) for chain in chains]
return chains
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ class Context(BaseModel):

class Prompt(BaseModel):
data: Optional[Union[list, str]]
entityCount: Optional[int]
entities: Optional[dict]
prompt_gov_enabled: Optional[bool]
entityCount: Optional[int] = None
entities: Optional[dict] = None
prompt_gov_enabled: Optional[bool] = None


class Qa(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from langchain_community.chains.pebblo_retrieval.models import (
App,
AuthContext,
Context,
Framework,
Prompt,
Qa,
Runtime,
)
Expand Down Expand Up @@ -45,7 +47,9 @@ def get_runtime() -> Tuple[Framework, Runtime]:
Tuple[Framework, Runtime]: Framework and Runtime for the current app instance.
"""
runtime_env = get_runtime_environment()
framework = Framework(name="langchain", version=runtime_env.get("library_version"))
framework = Framework(
name="langchain", version=runtime_env.get("library_version", None)
)
uname = platform.uname()
runtime = Runtime(
host=uname.node,
Expand Down Expand Up @@ -188,6 +192,7 @@ def send_prompt(
if self.classifier_location == "local":
# If classifier location is local, then response, context and prompt
# should be fetched from pebblo_resp and replaced in payload.
pebblo_resp = pebblo_resp.json() if pebblo_resp else None
self.update_cloud_payload(payload, pebblo_resp)

headers = self._make_headers(cloud_request=True)
Expand Down Expand Up @@ -278,13 +283,11 @@ def make_request(
method=method, url=url, headers=headers, json=payload, timeout=timeout
)
logger.debug(
"Request: method %s, url %s, body %s len %s response status %s body %s",
"Request: method %s, url %s, len %s response status %s",
method,
response.request.url,
str(response.request.body),
str(len(response.request.body if response.request.body else [])),
str(response.status_code),
response.json(),
)

if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
Expand All @@ -305,29 +308,25 @@ def make_request(
return None

@staticmethod
def update_cloud_payload(payload: dict, pebblo_resp: Optional[Response]) -> None:
def update_cloud_payload(payload: dict, pebblo_resp: Optional[dict]) -> None:
"""
Update the payload with response, prompt and context from Pebblo response.
Args:
payload (dict): Payload to be updated.
pebblo_resp (Optional[Response]): Response from Pebblo server.
pebblo_resp (Optional[dict]): Response from Pebblo server.
"""
if pebblo_resp:
resp = json.loads(pebblo_resp.text)
if resp:
payload["response"].update(
resp.get("retrieval_data", {}).get("response", {})
)
payload["response"].pop("data")
payload["prompt"].update(
resp.get("retrieval_data", {}).get("prompt", {})
)
payload["prompt"].pop("data")
context = payload["context"]
for context_data in context:
context_data.pop("doc")
payload["context"] = context
# Update response, prompt and context from pebblo response
response = payload.get("response", {})
response.update(pebblo_resp.get("retrieval_data", {}).get("response", {}))
response.pop("data", None)
prompt = payload.get("prompt", {})
prompt.update(pebblo_resp.get("retrieval_data", {}).get("prompt", {}))
prompt.pop("data", None)
context = payload.get("context", [])
for context_data in context:
context_data.pop("doc", None)
else:
payload["response"] = {}
payload["prompt"] = {}
Expand Down Expand Up @@ -362,39 +361,32 @@ def build_prompt_qa_payload(
Returns:
dict: The QA payload for the prompt.
"""
qa = {
"name": app_name,
"context": [
{
"retrieved_from": doc.metadata.get(
qa = Qa(
name=app_name,
context=[
Context(
retrieved_from=doc.metadata.get(
"full_path", doc.metadata.get("source")
),
"doc": doc.page_content,
"vector_db": retriever.vectorstore.__class__.__name__,
**(
{"pb_checksum": doc.metadata.get("pb_checksum")}
if doc.metadata.get("pb_checksum")
else {}
),
}
doc=doc.page_content,
vector_db=retriever.vectorstore.__class__.__name__,
pb_checksum=doc.metadata.get("pb_checksum"),
)
for doc in docs
if isinstance(doc, Document)
],
"prompt": {
"data": question,
"entities": prompt_entities.get("entities", {}),
"entityCount": prompt_entities.get("entityCount", 0),
"prompt_gov_enabled": prompt_gov_enabled,
},
"response": {
"data": answer,
},
"prompt_time": prompt_time,
"user": auth_context.user_id if auth_context else "unknown",
"user_identities": auth_context.user_auth
prompt=Prompt(
data=question,
entities=prompt_entities.get("entities", {}),
entityCount=prompt_entities.get("entityCount", 0),
prompt_gov_enabled=prompt_gov_enabled,
),
response=Prompt(data=answer),
prompt_time=prompt_time,
user=auth_context.user_id if auth_context else "unknown",
user_identities=auth_context.user_auth
if auth_context and hasattr(auth_context, "user_auth")
else [],
"classifier_location": self.classifier_location,
}
qa_payload = Qa(**qa)
return qa_payload.dict(exclude_unset=True)
classifier_location=self.classifier_location,
)
return qa.dict(exclude_unset=True)

0 comments on commit 330d10e

Please sign in to comment.