feat: added logic to handle labeled answers (#742)
* added logic to handle labeled answers

* metrics fix

* updated is_page logic to be pair of doc + page_num. Added new versioning

* added notes for rules on what is correct
brianshen3 authored Jan 16, 2025
1 parent 81fe74f commit acfbcad
Showing 3 changed files with 50 additions and 14 deletions.
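The last two commit bullets refer to correctness rules that the metrics diff below encodes: a document counts as retrieved if any labeled document name appears among the retrieved locators, and a page counts as retrieved if any labeled (doc_name, page_number) pair was retrieved. A minimal sketch of the two rules, using made-up document names and page numbers:

```python
# Made-up ground truth for one question, as extracted from a labeling task.
doc_names = ["3M_2018_10K"]
relevant_pages = [("3M_2018_10K", 24), ("3M_2018_10K", 25)]  # (doc_name, page_number) pairs

# Made-up retrieval result, reduced to locators and (doc_name, page_number) pairs.
retrieved_locators = ["s3://bucket/docs/3M_2018_10K.pdf", "s3://bucket/docs/AMCOR_2020_10K.pdf"]
retrieved_pairs = [("3M_2018_10K", 24), ("AMCOR_2020_10K", 7)]

# Rule 1: the doc is retrieved if ANY labeled doc name appears in ANY retrieved locator.
doc_ok = any(any(name in locator for locator in retrieved_locators) for name in doc_names)
# Rule 2: the page is retrieved if ANY labeled (doc_name, page_number) pair was retrieved.
page_ok = any(pair in retrieved_pairs for pair in relevant_pages)

print(doc_ok, page_ok)  # True True
```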
```diff
@@ -9,7 +9,7 @@ license = "Apache-2.0"
 requires-python = ">=3.8,<3.12"

 dependencies = [
-    "kolena>=1.50.0,<2",
+    "kolena>=1.51.1,<2",
     "s3fs>=2024.10.0",
 ]
```

```diff
@@ -15,31 +15,64 @@
 from retrieval_augmented_generation.constants import ID_FIELDS


-def is_doc_retrieved(retrieved_contents: list, doc_name: str) -> bool:
-    return any([doc_name in content.locator for content in retrieved_contents])
+def is_doc_retrieved(retrieved_contents: list, doc_names: list) -> bool:
+    # NOTE: if one doc is retrieved, it is considered correct
+    return any(any(doc_name in content.locator for content in retrieved_contents) for doc_name in doc_names)


-def is_page_retrieved(retrieved_contents: list, doc_name: str, relevant_pages: list) -> bool:
-    retrieved_pages = [content.page_number for content in retrieved_contents if doc_name in content.locator]
-
-    # NOTE: all relevant pages must be retrieved to be considered correct.
-    return set(relevant_pages).issubset(retrieved_pages)
+def is_page_retrieved(retrieved_contents: list, doc_names: list, relevant_pages: list) -> bool:
+    # NOTE: if one page of any doc is retrieved, it is considered correct
+    retrieved_pairs = [
+        (content.locator.split("/")[-1].replace(".pdf", ""), content.page_number) for content in retrieved_contents
+    ]
+    return any(pair in retrieved_pairs for pair in relevant_pages)
+
+
+def extract_doc_names(labeling_task: dict) -> list:
+    if labeling_task is None:
+        return []
+    return [
+        content.locator.split("/")[-1].replace(".pdf", "")
+        for content in labeling_task.get("retrieved_contents", []) or []
+    ]
+
+
+def extract_relevant_pages(labeling_task: dict) -> list:
+    if labeling_task is None:
+        return []
+    contents = labeling_task.get("retrieved_contents", []) or []
+    # Create pairs of (doc_name, page_number)
+    return [(content.locator.split("/")[-1].replace(".pdf", ""), content.page_number) for content in contents]


 def compute_metrics(df_dataset: pd.DataFrame, df_results: pd.DataFrame) -> pd.DataFrame:
-    ground_truth_columns = ["doc_name", "relevant_pages", "financebench_id"]
+    ground_truth_columns = ["doc_names", "relevant_pages", "financebench_id"]
+
+    is_labeled = "labeling_task" in df_dataset.columns
+    if is_labeled:
+        # Extract document names and page numbers from labeling task
+        df_dataset["doc_names"] = df_dataset["labeling_task"].apply(extract_doc_names)
+        df_dataset["relevant_pages"] = df_dataset["labeling_task"].apply(extract_relevant_pages)
+    else:
+        df_dataset["doc_names"] = df_dataset["doc_name"].apply(lambda x: [x])
+        # Create pairs of (doc_name, page_number) for non-labeled data
+        df_dataset["relevant_pages"] = df_dataset.apply(
+            lambda row: [(row["doc_name"], page) for page in row["relevant_pages"]],
+            axis=1,
+        )
     assert set(ground_truth_columns).issubset(
         df_dataset.columns,
     ), f"ground truth columns {ground_truth_columns} cannot be found in dataset dataframe."

-    df = df_results.merge(df_dataset[ground_truth_columns], on=ID_FIELDS, how="left")
+    # Include doc_names in the merge
+    columns_to_merge = ground_truth_columns + ["doc_names"]
+    df = df_results.merge(df_dataset[columns_to_merge], on=ID_FIELDS, how="left")

     metrics = []
     for record in df.itertuples():
         metrics.append(
             dict(
-                is_doc_retrieved=is_doc_retrieved(record.retrieved_contents, record.doc_name),
-                is_page_retrieved=is_page_retrieved(record.retrieved_contents, record.doc_name, record.relevant_pages),
+                is_doc_retrieved=is_doc_retrieved(record.retrieved_contents, record.doc_names),
+                is_page_retrieved=is_page_retrieved(record.retrieved_contents, record.doc_names, record.relevant_pages),
             ),
         )
```
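For reference, a minimal usage sketch of the new helpers, assuming the four functions from the diff above are in scope (e.g. imported from the metrics module); `Content` is a hypothetical stand-in for the retrieved-content objects, which only need `locator` and `page_number` attributes, and the locators are invented:

```python
from collections import namedtuple

# Hypothetical stand-in for the retrieved-content objects used by the helpers above.
Content = namedtuple("Content", ["locator", "page_number"])

retrieved_contents = [
    Content("s3://bucket/docs/3M_2018_10K.pdf", 24),
    Content("s3://bucket/docs/AMCOR_2020_10K.pdf", 7),
]
labeling_task = {"retrieved_contents": [Content("s3://bucket/docs/3M_2018_10K.pdf", 24)]}

doc_names = extract_doc_names(labeling_task)            # ["3M_2018_10K"]
relevant_pages = extract_relevant_pages(labeling_task)  # [("3M_2018_10K", 24)]

print(is_doc_retrieved(retrieved_contents, doc_names))                   # True
print(is_page_retrieved(retrieved_contents, doc_names, relevant_pages))  # True
```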

```diff
@@ -45,10 +45,13 @@ def to_documents(retrieved_contents: list[dict[str, Any]]) -> list:

 def run(args: Namespace) -> None:
     model_name = MODEL_NAME[args.model]
-    df_results = pd.read_json(f"{S3_BUCKET}/{DATASET}/results/raw/{model_name}.jsonl", lines=True)
+    df_results = pd.read_json(
+        f"{S3_BUCKET}/{DATASET}/results/raw/{model_name}.jsonl",
+        lines=True,
+    )
     df_results["retrieved_contents"] = df_results["retrieved_contents"].apply(to_documents)
     if args.evaluate:
-        df_dataset = download_dataset(args.dataset_name)
+        df_dataset = download_dataset(args.dataset_name, include_extracted_properties=True)
         df_metrics = compute_metrics(df_dataset, df_results)
         df_results = pd.concat([df_results, df_metrics], axis=1)
         upload_results(args.dataset_name, model_name, df_results)
```