Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Context Matching #2293

Merged
merged 22 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions haystack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
convert_labels_to_squad,
)
from haystack.utils.squad_data import SquadData
from haystack.utils.context_matching import calculate_context_similarity, match_context, match_contexts
213 changes: 213 additions & 0 deletions haystack/utils/context_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
from collections import namedtuple
import multiprocessing
from typing import Generator, Iterable, Optional, Tuple, List, Union
import re
from rapidfuzz import fuzz
from multiprocessing import Pool
from tqdm import tqdm
from itertools import groupby


_CandidateScore = namedtuple("_CandidateScore", ["context_id", "candidate_id", "score"])


def _score_candidate(args: Tuple[Union[str, Tuple[object, str]], Tuple[object, str], int, bool]):
    """
    Score one (context, candidate) pair.

    Worker function used both serially and inside multiprocessing pools, which is
    why all inputs arrive packed into a single ``args`` tuple.
    The context may be a bare string (no id) or an ``(id, text)`` tuple.
    """
    context, candidate, min_length, boost_split_overlaps = args
    candidate_id, candidate_text = candidate
    if isinstance(context, str):
        context_id, context_text = None, context
    else:
        context_id, context_text = context
    similarity = calculate_context_similarity(
        context=context_text,
        candidate=candidate_text,
        min_length=min_length,
        boost_split_overlaps=boost_split_overlaps,
    )
    return _CandidateScore(context_id=context_id, candidate_id=candidate_id, score=similarity)


def normalize_white_space_and_case(text: str) -> str:
    """
    Lower-case ``text`` and collapse every run of whitespace into a single space.

    :param text: The string to normalize.
    :return: The normalized string, with leading and trailing whitespace removed.
    """
    # Parameter renamed from `str` to avoid shadowing the builtin.
    return re.sub(r"\s+", " ", text).lower().strip()


def _no_processor(str: str) -> str:
return str


def calculate_context_similarity(
    context: str, candidate: str, min_length: int = 100, boost_split_overlaps: bool = True
) -> float:
    """
    Calculates the text similarity score of context and candidate.
    The score's value ranges between 0.0 and 100.0.

    :param context: The context to match.
    :param candidate: The candidate to match the context.
    :param min_length: The minimum string length context and candidate need to have in order to be scored.
                       Returns 0.0 otherwise.
    :param boost_split_overlaps: Whether to boost split overlaps (e.g. [AB] <-> [BC]) that result from different preprocessing params.
                                 If we detect that the score is near a half match and the matching part of the candidate is at its boundaries
                                 we cut the context on the same side, recalculate the score and take the mean of both.
                                 Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total.
    """
    # Normalize before any length check: similarity must be insensitive to
    # case and whitespace layout, and the short-string guard below must see
    # the normalized lengths.
    context = normalize_white_space_and_case(context)
    candidate = normalize_white_space_and_case(candidate)

    # Very short strings (e.g. a single word) produce spuriously high partial
    # scores, because their few characters are easily contained in the other
    # string — reject them outright.
    if len(context) < min_length or len(candidate) < min_length:
        return 0.0

    # partial_ratio matches the shorter string inside the longer one.
    if len(context) < len(candidate):
        shorter, longer = context, candidate
    else:
        shorter, longer = candidate, context

    alignment = fuzz.partial_ratio_alignment(shorter, longer, processor=_no_processor)
    score = alignment.score

    # Split-overlap boosting: a score near a half match (40..65) whose best
    # alignment touches a boundary of `longer` suggests the two strings are
    # adjacent splits of the same text ([AB] <-> [BC]). Re-score with the
    # non-overlapping half of `shorter` cut off and average the two scores.
    if boost_split_overlaps and 40 <= score < 65:
        half = len(shorter) // 2
        if alignment.dest_start == 0:
            left_cut_score = fuzz.partial_ratio(shorter[half:], longer, processor=_no_processor)
            if left_cut_score > score:
                score = (score + left_cut_score) / 2
        if alignment.dest_end == len(longer):
            right_cut_score = fuzz.partial_ratio(shorter[:-half], longer, processor=_no_processor)
            if right_cut_score > score:
                score = (score + right_cut_score) / 2

    return score


def match_context(
    context: str,
    candidates: Generator[Tuple[str, str], None, None],
    threshold: float = 65.0,
    show_progress: bool = False,
    num_processes: Optional[int] = None,
    chunksize: int = 1,
    min_length: int = 100,
    boost_split_overlaps: bool = True,
) -> List[Tuple[str, float]]:
    """
    Matches the context against multiple candidates. Candidates consist of a tuple of an id and its text.

    Returns a sorted list of the candidate ids and its scores filtered by the threshold in descending order.

    :param context: The context to match.
    :param candidates: The candidates to match the context.
                       A candidate consists of a tuple of candidate id and candidate text.
    :param threshold: Score threshold that candidates must surpass to be included into the result list.
    :param show_progress: Whether to show the progress of matching all candidates.
    :param num_processes: The number of processes to be used for matching in parallel.
                          If None, one process per CPU core is used; if 1, matching runs serially without a pool.
    :param chunksize: The chunksize used during parallel processing.
                      If not specified chunksize is 1.
                      For very long iterables using a large value for chunksize can make the job complete much faster than using the default value of 1.
    :param min_length: The minimum string length context and candidate need to have in order to be scored.
                       Returns 0.0 otherwise.
    :param boost_split_overlaps: Whether to boost split overlaps (e.g. [AB] <-> [BC]) that result from different preprocessing params.
                                 If we detect that the score is near a half match and the matching part of the candidate is at its boundaries
                                 we cut the context on the same side, recalculate the score and take the mean of both.
                                 Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total.
    """
    pool: Optional[multiprocessing.pool.Pool] = None
    try:
        score_candidate_args = ((context, candidate, min_length, boost_split_overlaps) for candidate in candidates)
        if num_processes is None or num_processes > 1:
            # num_processes=None lets Pool default to os.cpu_count() workers.
            pool = Pool(processes=num_processes)
            candidate_scores: Iterable = pool.imap_unordered(
                _score_candidate, score_candidate_args, chunksize=chunksize
            )
        else:
            candidate_scores = map(_score_candidate, score_candidate_args)

        if show_progress:
            candidate_scores = tqdm(candidate_scores)

        matches = (candidate for candidate in candidate_scores if candidate.score > threshold)
        sorted_matches = sorted(matches, key=lambda candidate: candidate.score, reverse=True)
        return [(candidate_score.candidate_id, candidate_score.score) for candidate_score in sorted_matches]

    finally:
        if pool:
            # Close AND join so worker processes are fully reaped before returning.
            pool.close()
            pool.join()


def match_contexts(
    contexts: List[str],
    candidates: Generator[Tuple[str, str], None, None],
    threshold: float = 65.0,
    show_progress: bool = False,
    num_processes: Optional[int] = None,
    chunksize: int = 1,
    min_length: int = 100,
    boost_split_overlaps: bool = True,
) -> List[List[Tuple[str, float]]]:
    """
    Matches the contexts against multiple candidates. Candidates consist of a tuple of an id and its string text.
    This method iterates over candidates only once.

    Returns for each context a sorted list of the candidate ids and its scores filtered by the threshold in descending order.
    The i-th list of the result corresponds to the i-th context; contexts without any match above the threshold get an empty list.

    :param contexts: The contexts to match.
    :param candidates: The candidates to match the context.
                       A candidate consists of a tuple of candidate id and candidate text.
    :param threshold: Score threshold that candidates must surpass to be included into the result list.
    :param show_progress: Whether to show the progress of matching all candidates.
    :param num_processes: The number of processes to be used for matching in parallel.
                          If None, one process per CPU core is used; if 1, matching runs serially without a pool.
    :param chunksize: The chunksize used during parallel processing.
                      If not specified chunksize is 1.
                      For very long iterables using a large value for chunksize can make the job complete much faster than using the default value of 1.
    :param min_length: The minimum string length context and candidate need to have in order to be scored.
                       Returns 0.0 otherwise.
    :param boost_split_overlaps: Whether to boost split overlaps (e.g. [AB] <-> [BC]) that result from different preprocessing params.
                                 If we detect that the score is near a half match and the matching part of the candidate is at its boundaries
                                 we cut the context on the same side, recalculate the score and take the mean of both.
                                 Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total.
    """
    pool: Optional[multiprocessing.pool.Pool] = None
    try:
        # Outer loop over candidates so the candidates generator is consumed only once;
        # enumerate supplies the context index as context_id for later regrouping.
        score_candidate_args = (
            (context, candidate, min_length, boost_split_overlaps)
            for candidate in candidates
            for context in enumerate(contexts)
        )

        if num_processes is None or num_processes > 1:
            pool = Pool(processes=num_processes)
            candidate_scores: Iterable = pool.imap_unordered(
                _score_candidate, score_candidate_args, chunksize=chunksize
            )
        else:
            candidate_scores = map(_score_candidate, score_candidate_args)

        if show_progress:
            candidate_scores = tqdm(candidate_scores)

        # Preallocate one (possibly empty) result list per context. Assigning by
        # context_id keeps results aligned with `contexts` even when some contexts
        # have no match above the threshold — `list.insert(context_id, ...)` would
        # misplace entries in that case, since groupby skips matchless contexts.
        match_lists: List[List[Tuple[str, float]]] = [[] for _ in contexts]
        matches = (candidate for candidate in candidate_scores if candidate.score > threshold)
        # groupby requires its input pre-sorted by the grouping key.
        group_sorted_matches = sorted(matches, key=lambda candidate: candidate.context_id)
        grouped_matches = groupby(group_sorted_matches, key=lambda candidate: candidate.context_id)
        for context_id, group in grouped_matches:
            sorted_group = sorted(group, key=lambda candidate: candidate.score, reverse=True)
            match_lists[context_id] = [
                (candidate_score.candidate_id, candidate_score.score) for candidate_score in sorted_group
            ]

        return match_lists

    finally:
        if pool:
            # Close AND join so worker processes are fully reaped before returning.
            pool.close()
            pool.join()
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ install_requires =
elasticsearch>=7.7,<=7.10
elastic-apm

# context matching
rapidfuzz

# Schema validation
jsonschema

Expand Down
Loading