diff --git a/validator/main.py b/validator/main.py
index fa73c25..5e3e988 100644
--- a/validator/main.py
+++ b/validator/main.py
@@ -1,14 +1,16 @@
import os
+import re
import itertools
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from warnings import warn
+from transformers import pipeline, AutoTokenizer
+import torch
import nltk
import numpy as np
from guardrails.utils.docs_utils import get_chunks_from_text
-from guardrails.utils.validator_utils import PROVENANCE_V1_PROMPT
from guardrails.validator_base import (
FailResult,
PassResult,
@@ -21,6 +23,47 @@
from tenacity import retry, stop_after_attempt, wait_random_exponential
from sentence_transformers import SentenceTransformer
+
+PROVENANCE_V1_PROMPT = """Instruction:
+As an Attribution Validator, your task is to determine if the given contexts provide irrefutable evidence to support the claim. Follow these strict guidelines:
+
+Respond "Yes" ONLY if ALL of the following conditions are met:
+
+The contexts explicitly and unambiguously state information that fully confirms ALL aspects of the claim.
+There is NO room for alternative interpretations or assumptions.
+The support is direct and doesn't require complex inference chains.
+If numbers or specific details are mentioned in the claim, they MUST be exactly matched in the contexts.
+
+
+Respond "No" if ANY of the following are true:
+
+The contexts do not provide explicit information that fully substantiates every part of the claim.
+The claim requires any degree of inference or assumption not directly stated in the contexts.
+The contexts only partially support the claim or support it with slight differences in details.
+There is any ambiguity, vagueness, or room for interpretation in how the contexts relate to the claim.
+The claim includes any information not present in the contexts, even if it seems common knowledge.
+The contexts contradict any part of the claim, no matter how minor.
+
+
+Treat the contexts as the ONLY source of truth. Do not use any outside knowledge or assumptions.
+For multi-part claims, EVERY single part must be explicitly supported by the contexts for a "Yes" response.
+If there is ANY doubt whatsoever, respond with "No".
+Be extremely literal in your interpretation. Do not extrapolate or generalize from the given information.
+
+Provide your analysis in this format:
+
+Analysis:
+- Point 1
+- Point 2
+- Point 3 (if needed)
+
+Decision:
+Yes OR No
+
+Claim:
+{}
+
+Contexts:
+{}
+
+Response:"""
@register_validator(name="guardrails/provenance_llm", data_type="string")
class ProvenanceLLM(Validator):
@@ -125,8 +168,32 @@ def call_llm(self, prompt: str) -> str:
response (str): String representing the LLM response.
"""
return self._llm_callable(prompt)
+
+    def evaluate_with_vectara(self, text: str, pass_on_invalid: bool) -> bool:
+        """Evaluate the text with Vectara's hallucination evaluation model.
+
+        Returns True if the model labels the text as consistent, False if it
+        labels it as hallucinated, and falls back to the ``pass_on_invalid``
+        policy for any other label.
+        """
+        classifier = pipeline(
+            "text-classification",
+            model="vectara/hallucination_evaluation_model",
+            tokenizer=AutoTokenizer.from_pretrained("google/flan-t5-base"),
+            trust_remote_code=True,
+            device="cuda" if torch.cuda.is_available() else "cpu",
+        )
+        result = classifier(text, batch_size=1)
+        if result[0]["label"] == "consistent":
+            return True
+        if result[0]["label"] == "hallucinated":
+            return False
+        if pass_on_invalid:
+            warn(
+                "The Vectara model returned an invalid response. Considering the sentence as supported..."
+            )
+            return True
+        warn(
+            "The Vectara model returned an invalid response. Considering the sentence as unsupported..."
+        )
+        return False
+
- def evaluate_with_llm(self, text: str, query_function: Callable) -> str:
+ def evaluate_with_llm(self, text: str, query_function: Callable, pass_on_invalid: bool) -> bool:
"""Validate that the LLM-generated text is supported by the provided
contexts.
@@ -145,35 +212,28 @@ def evaluate_with_llm(self, text: str, query_function: Callable) -> str:
# Get evaluation response
eval_response = self.call_llm(prompt)
- return eval_response
+ return self.parse_response(eval_response, pass_on_invalid=pass_on_invalid)
def validate_each_sentence(
self, value: Any, query_function: Callable, metadata: Dict[str, Any]
) -> ValidationResult:
"""Validate each sentence in the response."""
pass_on_invalid = metadata.get("pass_on_invalid", False) # Default to False
+ use_vectara = metadata.get("use_vectara", False)
# Split the value into sentences using nltk sentence tokenizer.
sentences = nltk.sent_tokenize(value)
unsupported_sentences, supported_sentences = [], []
for sentence in sentences:
- eval_response = self.evaluate_with_llm(sentence, query_function)
- if eval_response == "yes":
+ if use_vectara:
+ eval_response = self.evaluate_with_vectara(sentence, pass_on_invalid=pass_on_invalid)
+ else:
+ eval_response = self.evaluate_with_llm(sentence, query_function, pass_on_invalid=pass_on_invalid)
+            if eval_response:
supported_sentences.append(sentence)
- elif eval_response == "no":
+            else:
unsupported_sentences.append(sentence)
- else:
- if pass_on_invalid:
- warn(
- "The LLM returned an invalid response. Considering the sentence as supported..."
- )
- supported_sentences.append(sentence)
- else:
- warn(
- "The LLM returned an invalid response. Considering the sentence as unsupported..."
- )
- unsupported_sentences.append(sentence)
if unsupported_sentences:
unsupported_sentences = "- " + "\n- ".join(unsupported_sentences)
@@ -187,18 +247,42 @@ def validate_each_sentence(
fix_value="\n".join(supported_sentences),
)
return PassResult(metadata=metadata)
+
+    def parse_response(self, response: str, pass_on_invalid: bool) -> bool:
+        """Parse the LLM evaluation response into a boolean decision."""
+        response = response.lower()
+        # Extract the first explicit yes/no decision from the response.
+        decision_match = re.search(r"\b(yes|no)\b", response)
+        decision = decision_match.group(1) if decision_match else None
+        if decision == "yes":
+            return True
+        if decision == "no":
+            return False
+        # No parsable decision was found; fall back to the pass_on_invalid policy.
+        if pass_on_invalid:
+            warn(
+                "The LLM returned an invalid response. Considering the sentence as supported..."
+            )
+            return True
+        warn(
+            "The LLM returned an invalid response. Considering the sentence as unsupported..."
+        )
+        return False
+
def validate_full_text(
self, value: Any, query_function: Callable, metadata: Dict[str, Any]
) -> ValidationResult:
"""Validate the entire LLM text."""
pass_on_invalid = metadata.get("pass_on_invalid", False) # Default to False
-
+ use_vectara = metadata.get("use_vectara", False)
# Self-evaluate LLM with entire text
- eval_response = self.evaluate_with_llm(value, query_function)
- if eval_response == "yes":
+ if use_vectara:
+ passed = self.evaluate_with_vectara(value, pass_on_invalid=pass_on_invalid)
+ else:
+ passed = self.evaluate_with_llm(value, query_function, pass_on_invalid=pass_on_invalid)
+        if passed:
return PassResult(metadata=metadata)
- if eval_response == "no":
+        if not passed:
return FailResult(
metadata=metadata,
error_message=(