adding generic inference binary+idk judges (#1316)
* generic llm as a judge for binary and idk

Signed-off-by: Roni Friedman-Melamed <[email protected]>

* unify generic inference label

Signed-off-by: Roni Friedman-Melamed <[email protected]>

* ruff

Signed-off-by: Roni Friedman-Melamed <[email protected]>

* fix processor bug

Signed-off-by: lilacheden <[email protected]>

---------

Signed-off-by: Roni Friedman-Melamed <[email protected]>
Co-authored-by: lilacheden <[email protected]>
Co-authored-by: Yotam Perlitz <[email protected]>
3 people authored Nov 4, 2024
1 parent 582d96f commit e5b8355
Showing 12 changed files with 163 additions and 49 deletions.
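
The new judges are built on GenericInferenceEngine, which resolves its backing engine at runtime instead of hard-wiring one model. A minimal sketch of using one of the catalog metrics added here, assuming the backing engine is selected through the UNITXT_INFERENCE_ENGINE environment variable (the engine reference below is illustrative, not part of this commit):

import os

from unitxt.artifact import fetch_artifact

# Assumption: GenericInferenceEngine reads the concrete engine to dispatch to
# from this environment variable; the value is an illustrative catalog reference.
os.environ["UNITXT_INFERENCE_ENGINE"] = "engines.classification.llama_3_1_70b_instruct_wml"

# One of the binary-judge catalog entries added by this commit.
metric, _ = fetch_artifact(
    "metrics.llm_as_judge.binary.generic_inference_engine_answer_correctness_q_a_gt_loose"
)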
47 changes: 30 additions & 17 deletions prepare/metrics/llm_as_judge/binary_judge.py
@@ -1,4 +1,5 @@
from unitxt import add_to_catalog
from unitxt.inference import GenericInferenceEngine
from unitxt.llm_as_judge import (
    TaskBasedLLMasJudge,
)
@@ -17,6 +18,13 @@
"answer_relevance": {"q_a": "judge_answer_relevance"},
}

generic_engine_label = "generic_inference_engine"

inference_models = {
    "llama_3_1_70b_instruct_wml": "engines.classification.llama_3_1_70b_instruct_wml",
    generic_engine_label: GenericInferenceEngine(),
}


def get_prediction_field(metric_type):
    return None if metric_type == "context_relevance" else "answer"
@@ -27,20 +35,25 @@ def get_prediction_field(metric_type):
    task_name = f"tasks.rag_eval.{metric_type}.binary"

    for use_logprobs in [True, False]:
        logprobs_label = "_logprobs" if use_logprobs else ""
        metric_label = f"{metric_type}_{template_short_name}{logprobs_label}"
        metric = TaskBasedLLMasJudge(
            inference_model="engines.classification.llama_3_1_70b_instruct_wml",
            template=f"templates.rag_eval.{metric_type}.{template_name}{logprobs_label}",
            task=task_name,
            format="formats.empty",
            main_score=metric_label,
            prediction_field=get_prediction_field(metric_type),
            infer_log_probs=use_logprobs,
        )

        add_to_catalog(
            metric,
            f"metrics.llm_as_judge.binary.llama_3_1_70b_instruct_wml_{metric_label}",
            overwrite=True,
        )
        for inf_label, inference_model in inference_models.items():
            if (
                use_logprobs and inf_label == generic_engine_label
            ):  # engine GenericInferenceEngine does not support logprobs
                continue
            logprobs_label = "_logprobs" if use_logprobs else ""
            metric_label = f"{metric_type}_{template_short_name}{logprobs_label}"
            metric = TaskBasedLLMasJudge(
                inference_model=inference_model,
                template=f"templates.rag_eval.{metric_type}.{template_name}{logprobs_label}",
                task=task_name,
                format="formats.empty",
                main_score=metric_label,
                prediction_field=get_prediction_field(metric_type),
                infer_log_probs=use_logprobs,
            )

            add_to_catalog(
                metric,
                f"metrics.llm_as_judge.binary.{inf_label}_{metric_label}",
                overwrite=True,
            )
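
Only the generic engine's logprobs variant is skipped, so each metric/template pair now yields three catalog entries. A small illustrative sketch (not part of the commit) that reproduces the naming scheme of the loop above for one pair from the mapping at the top of the file:

# Illustration of the catalog names the loop above registers for one metric type.
metric_type, template_short_name = "answer_relevance", "q_a"
for use_logprobs in [True, False]:
    for inf_label in ["llama_3_1_70b_instruct_wml", "generic_inference_engine"]:
        if use_logprobs and inf_label == "generic_inference_engine":
            continue  # GenericInferenceEngine does not support logprobs
        logprobs_label = "_logprobs" if use_logprobs else ""
        print(f"metrics.llm_as_judge.binary.{inf_label}_{metric_type}_{template_short_name}{logprobs_label}")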
65 changes: 37 additions & 28 deletions prepare/metrics/llm_as_judge/conversation_idk.py
@@ -1,37 +1,46 @@
from unitxt import add_to_catalog
from unitxt.inference import (
    GenericInferenceEngine,
    IbmGenAiInferenceEngine,
    IbmGenAiInferenceEngineParams,
)
from unitxt.llm_as_judge import LLMAsJudge

platform = "ibm_gen_ai"
gen_params = IbmGenAiInferenceEngineParams(max_new_tokens=256)

model_name = "meta-llama/llama-3-70b-instruct"
template_name = "templates.response_assessment.judges.idk.v1"

inference_model = IbmGenAiInferenceEngine(model_name=model_name, parameters=gen_params)

model_label = model_name.split("/")[1].replace("-", "")
template_label = template_name.split(".")[-1]

metric_label = (
    "metrics.llm_as_judge.rating." + model_label + "_template_" + template_label
)

cur_metric = LLMAsJudge(
    inference_model=inference_model,
    template=template_name,
    task="rating.single_turn",
    main_score=metric_label,
    prediction_type="str",
)

# _description__= "Does the model response say I don't know?"

add_to_catalog(
    cur_metric,
    "metrics.llm_as_judge.conversation_answer_idk.llama3_v1_ibmgenai_judges",
    overwrite=True,
)
inference_models = {
    "llama3_v1_ibmgenai": {
        "model_name": "llama370binstruct",
        "inference_model": IbmGenAiInferenceEngine(
            model_name="meta-llama/llama-3-70b-instruct",
            parameters=IbmGenAiInferenceEngineParams(max_new_tokens=256),
        ),
    },
    "generic_inference_engine": {
        "model_name": "generic",
        "inference_model": (GenericInferenceEngine()),
    },
}

for label, inference_model in inference_models.items():
    model_label = inference_model["model_name"]
    template_label = template_name.split(".")[-1]
    metric_label = (
        "metrics.llm_as_judge.rating." + model_label + "_template_" + template_label
    )

    cur_metric = LLMAsJudge(
        inference_model=inference_model["inference_model"],
        template=template_name,
        task="rating.single_turn",
        main_score=metric_label,
        prediction_type="str",
    )

    # _description__= "Does the model response say I don't know?"

    add_to_catalog(
        cur_metric,
        f"metrics.llm_as_judge.conversation_answer_idk.{label}_judges",
        overwrite=True,
    )
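
When UNITXT_INFERENCE_ENGINE is not set, GenericInferenceEngine is expected to fall back to its default field. A hedged sketch, assuming that field accepts a catalog engine reference (the reference below is illustrative, not part of this commit):

from unitxt.inference import GenericInferenceEngine

# Assumption: `default` names the engine used when UNITXT_INFERENCE_ENGINE
# is unset; the catalog reference here is illustrative.
engine = GenericInferenceEngine(
    default="engines.classification.llama_3_1_70b_instruct_wml"
)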
2 changes: 1 addition & 1 deletion prepare/processors/processors.py
@@ -221,7 +221,7 @@
)

add_to_catalog(
PostProcess(Cast(to="float", failure_default={"float": 0.5})),
PostProcess(Cast(to="float", failure_default=0.5)),
"processors.cast_to_float_return_0_5_if_failed",
overwrite=True,
)
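
This is the processor bug named in the commit message: failure_default takes the fallback value itself, not a dict keyed by target type, so the old form would have returned {"float": 0.5} verbatim on a failed cast. A quick sketch of the corrected behavior, assuming Cast exposes the usual field-operator process_value interface:

from unitxt.operators import Cast

# After the fix, an output that cannot be parsed as a float falls back to 0.5.
cast = Cast(to="float", failure_default=0.5)
print(cast.process_value("0.75"))          # 0.75
print(cast.process_value("not a number"))  # 0.5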
12 changes: 12 additions & 0 deletions src/unitxt/catalog/metrics/llm_as_judge/binary/generic_inference_engine_answer_correctness_q_a_gt_loose.json
@@ -0,0 +1,12 @@
{
    "__type__": "task_based_ll_mas_judge",
    "inference_model": {
        "__type__": "generic_inference_engine"
    },
    "template": "templates.rag_eval.answer_correctness.judge_loose_match_no_context",
    "task": "tasks.rag_eval.answer_correctness.binary",
    "format": "formats.empty",
    "main_score": "answer_correctness_q_a_gt_loose",
    "prediction_field": "answer",
    "infer_log_probs": false
}
12 changes: 12 additions & 0 deletions src/unitxt/catalog/metrics/llm_as_judge/binary/generic_inference_engine_answer_correctness_q_a_gt_strict.json
@@ -0,0 +1,12 @@
{
    "__type__": "task_based_ll_mas_judge",
    "inference_model": {
        "__type__": "generic_inference_engine"
    },
    "template": "templates.rag_eval.answer_correctness.judge_simplified_format",
    "task": "tasks.rag_eval.answer_correctness.binary",
    "format": "formats.empty",
    "main_score": "answer_correctness_q_a_gt_strict",
    "prediction_field": "answer",
    "infer_log_probs": false
}
12 changes: 12 additions & 0 deletions src/unitxt/catalog/metrics/llm_as_judge/binary/generic_inference_engine_answer_relevance_q_a.json
@@ -0,0 +1,12 @@
{
    "__type__": "task_based_ll_mas_judge",
    "inference_model": {
        "__type__": "generic_inference_engine"
    },
    "template": "templates.rag_eval.answer_relevance.judge_answer_relevance",
    "task": "tasks.rag_eval.answer_relevance.binary",
    "format": "formats.empty",
    "main_score": "answer_relevance_q_a",
    "prediction_field": "answer",
    "infer_log_probs": false
}
12 changes: 12 additions & 0 deletions src/unitxt/catalog/metrics/llm_as_judge/binary/generic_inference_engine_context_relevance_q_c_ares.json
@@ -0,0 +1,12 @@
{
    "__type__": "task_based_ll_mas_judge",
    "inference_model": {
        "__type__": "generic_inference_engine"
    },
    "template": "templates.rag_eval.context_relevance.judge_context_relevance_ares",
    "task": "tasks.rag_eval.context_relevance.binary",
    "format": "formats.empty",
    "main_score": "context_relevance_q_c_ares",
    "prediction_field": null,
    "infer_log_probs": false
}
12 changes: 12 additions & 0 deletions src/unitxt/catalog/metrics/llm_as_judge/binary/generic_inference_engine_correctness_holistic_q_c_a.json
@@ -0,0 +1,12 @@
{
    "__type__": "task_based_ll_mas_judge",
    "inference_model": {
        "__type__": "generic_inference_engine"
    },
    "template": "templates.rag_eval.correctness_holistic.judge_correctness_simple",
    "task": "tasks.rag_eval.correctness_holistic.binary",
    "format": "formats.empty",
    "main_score": "correctness_holistic_q_c_a",
    "prediction_field": "answer",
    "infer_log_probs": false
}
12 changes: 12 additions & 0 deletions src/unitxt/catalog/metrics/llm_as_judge/binary/generic_inference_engine_faithfulness_c_a.json
@@ -0,0 +1,12 @@
{
    "__type__": "task_based_ll_mas_judge",
    "inference_model": {
        "__type__": "generic_inference_engine"
    },
    "template": "templates.rag_eval.faithfulness.judge_no_question_simplified",
    "task": "tasks.rag_eval.faithfulness.binary",
    "format": "formats.empty",
    "main_score": "faithfulness_c_a",
    "prediction_field": "answer",
    "infer_log_probs": false
}
12 changes: 12 additions & 0 deletions src/unitxt/catalog/metrics/llm_as_judge/binary/generic_inference_engine_faithfulness_q_c_a.json
@@ -0,0 +1,12 @@
{
    "__type__": "task_based_ll_mas_judge",
    "inference_model": {
        "__type__": "generic_inference_engine"
    },
    "template": "templates.rag_eval.faithfulness.judge_with_question_simplified",
    "task": "tasks.rag_eval.faithfulness.binary",
    "format": "formats.empty",
    "main_score": "faithfulness_q_c_a",
    "prediction_field": "answer",
    "infer_log_probs": false
}
10 changes: 10 additions & 0 deletions src/unitxt/catalog/metrics/llm_as_judge/conversation_answer_idk/generic_inference_engine_judges.json
@@ -0,0 +1,10 @@
{
    "__type__": "llm_as_judge",
    "inference_model": {
        "__type__": "generic_inference_engine"
    },
    "template": "templates.response_assessment.judges.idk.v1",
    "task": "rating.single_turn",
    "main_score": "metrics.llm_as_judge.rating.generic_template_v1",
    "prediction_type": "str"
}
4 changes: 1 addition & 3 deletions src/unitxt/catalog/processors/cast_to_float_return_0_5_if_failed.json
@@ -3,8 +3,6 @@
"operator": {
"__type__": "cast",
"to": "float",
"failure_default": {
"float": 0.5
}
"failure_default": 0.5
}
}
