diff --git a/validator/main.py b/validator/main.py
index fa73c25..5e3e988 100644
--- a/validator/main.py
+++ b/validator/main.py
@@ -1,14 +1,16 @@
 import os
+import re
 import itertools
 import warnings
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from warnings import warn
 
+from transformers import pipeline, AutoTokenizer
+import torch
 import nltk
 import numpy as np
 from guardrails.utils.docs_utils import get_chunks_from_text
-from guardrails.utils.validator_utils import PROVENANCE_V1_PROMPT
 from guardrails.validator_base import (
     FailResult,
     PassResult,
@@ -21,6 +23,47 @@
 from tenacity import retry, stop_after_attempt, wait_random_exponential
 from sentence_transformers import SentenceTransformer
 
+PROVENANCE_V1_PROMPT = """Instruction:
+As an Attribution Validator, your task is to determine if the given contexts provide irrefutable evidence to support the claim. Follow these strict guidelines:
+
+Respond "Yes" ONLY if ALL of the following conditions are met:
+
+The contexts explicitly and unambiguously state information that fully confirms ALL aspects of the claim.
+There is NO room for alternative interpretations or assumptions.
+The support is direct and doesn't require complex inference chains.
+If numbers or specific details are mentioned in the claim, they MUST be exactly matched in the contexts.
+
+
+Respond "No" if ANY of the following are true:
+
+The contexts do not provide explicit information that fully substantiates every part of the claim.
+The claim requires any degree of inference or assumption not directly stated in the contexts.
+The contexts only partially support the claim or support it with slight differences in details.
+There is any ambiguity, vagueness, or room for interpretation in how the contexts relate to the claim.
+The claim includes any information not present in the contexts, even if it seems common knowledge.
+The contexts contradict any part of the claim, no matter how minor.
+
+
+Treat the contexts as the ONLY source of truth. Do not use any outside knowledge or assumptions.
+For multi-part claims, EVERY single part must be explicitly supported by the contexts for a "Yes" response.
+If there is ANY doubt whatsoever, respond with "No".
+Be extremely literal in your interpretation. Do not extrapolate or generalize from the given information.
+
+Provide your analysis in this format:
+
+
+Point 1
+Point 2
+Point 3 (if needed)
+
+
+
+Yes OR No
+Claim:
+{}
+Contexts:
+{}
+Response:"""
 
 @register_validator(name="guardrails/provenance_llm", data_type="string")
 class ProvenanceLLM(Validator):
@@ -125,8 +168,32 @@ def call_llm(self, prompt: str) -> str:
             response (str): String representing the LLM response.
         """
         return self._llm_callable(prompt)
+
+    def evaluate_with_vectara(self, text: str, pass_on_invalid: bool) -> bool:
+        classifier = pipeline(
+            "text-classification",
+            model="vectara/hallucination_evaluation_model",
+            tokenizer=AutoTokenizer.from_pretrained("google/flan-t5-base"),
+            trust_remote_code=True,
+            device="cuda" if torch.cuda.is_available() else "cpu",
+        )
+        result = classifier(text, batch_size=1)
+        if result[0]['label'] == 'consistent':
+            return True
+        if result[0]['label'] == 'hallucinated':
+            return False
+        if pass_on_invalid:
+            warn(
+                "The Vectara model returned an invalid response. Considering the sentence as supported..."
+            )
+            return True
+        else:
+            warn(
+                "The Vectara model returned an invalid response. Considering the sentence as unsupported..."
+            )
+            return False
 
-    def evaluate_with_llm(self, text: str, query_function: Callable) -> str:
+    def evaluate_with_llm(self, text: str, query_function: Callable, pass_on_invalid: bool) -> bool:
         """Validate that the LLM-generated text is supported by the provided
         contexts.
 
@@ -145,35 +212,28 @@ def evaluate_with_llm(self, text: str, query_function: Callable) -> str:
 
         # Get evaluation response
         eval_response = self.call_llm(prompt)
-        return eval_response
+        return self.parse_response(eval_response, pass_on_invalid=pass_on_invalid)
 
     def validate_each_sentence(
         self, value: Any, query_function: Callable, metadata: Dict[str, Any]
     ) -> ValidationResult:
         """Validate each sentence in the response."""
         pass_on_invalid = metadata.get("pass_on_invalid", False)  # Default to False
+        use_vectara = metadata.get("use_vectara", False)
 
         # Split the value into sentences using nltk sentence tokenizer.
         sentences = nltk.sent_tokenize(value)
 
         unsupported_sentences, supported_sentences = [], []
         for sentence in sentences:
-            eval_response = self.evaluate_with_llm(sentence, query_function)
-            if eval_response == "yes":
+            if use_vectara:
+                eval_response = self.evaluate_with_vectara(sentence, pass_on_invalid=pass_on_invalid)
+            else:
+                eval_response = self.evaluate_with_llm(sentence, query_function, pass_on_invalid=pass_on_invalid)
+            if eval_response:
                 supported_sentences.append(sentence)
-            elif eval_response == "no":
+            else:
                 unsupported_sentences.append(sentence)
-            else:
-                if pass_on_invalid:
-                    warn(
-                        "The LLM returned an invalid response. Considering the sentence as supported..."
-                    )
-                    supported_sentences.append(sentence)
-                else:
-                    warn(
-                        "The LLM returned an invalid response. Considering the sentence as unsupported..."
-                    )
-                    unsupported_sentences.append(sentence)
 
         if unsupported_sentences:
             unsupported_sentences = "- " + "\n- ".join(unsupported_sentences)
@@ -187,18 +247,42 @@ def validate_each_sentence(
                 fix_value="\n".join(supported_sentences),
             )
         return PassResult(metadata=metadata)
+
+    def parse_response(self, response: str, pass_on_invalid: bool) -> bool:
+        response = response.lower()
+        # Extract decision: take the last yes/no, since the analysis precedes the verdict
+        matches = re.findall(r'\b(yes|no)\b', response)
+        decision = matches[-1] if matches else None
+        if decision == 'yes':
+            return True
+        elif decision == 'no':
+            return False
+        else:
+            if pass_on_invalid:
+                warn(
+                    "The LLM returned an invalid response. Considering the sentence as supported..."
+                )
+                return True
+            else:
+                warn(
+                    "The LLM returned an invalid response. Considering the sentence as unsupported..."
+                )
+                return False
 
     def validate_full_text(
         self, value: Any, query_function: Callable, metadata: Dict[str, Any]
     ) -> ValidationResult:
         """Validate the entire LLM text."""
         pass_on_invalid = metadata.get("pass_on_invalid", False)  # Default to False
-
+        use_vectara = metadata.get("use_vectara", False)
         # Self-evaluate LLM with entire text
-        eval_response = self.evaluate_with_llm(value, query_function)
-        if eval_response == "yes":
+        if use_vectara:
+            passed = self.evaluate_with_vectara(value, pass_on_invalid=pass_on_invalid)
+        else:
+            passed = self.evaluate_with_llm(value, query_function, pass_on_invalid=pass_on_invalid)
+        if passed:
             return PassResult(metadata=metadata)
-        if eval_response == "no":
+        if not passed:
             return FailResult(
                 metadata=metadata,
                 error_message=(