Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add __cointains__ to Span #3446

Merged
merged 2 commits into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions haystack/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,47 @@ class Span:
:param end: Position where the spand ends
"""

def __contains__(self, value):
"""
Checks for inclusion of the given value into the interval defined by Span.
```
assert 10 in Span(5, 15) # True
assert 20 in Span(1, 15) # False
```
Includes the left edge, but not the right edge.
```
assert 5 in Span(5, 15) # True
assert 15 in Span(5, 15) # False
```
Works for numbers and all values that can be safely converted into floats.
```
assert 10.0 in Span(5, 15) # True
assert "10" in Span(5, 15) # True
```
It also works for Span objects, returning True only if the given
Span is fully contained into the original Span.
As for numerical values, the left edge is included, the right edge is not.
```
assert Span(10, 11) in Span(5, 15) # True
assert Span(5, 10) in Span(5, 15) # True
assert Span(10, 15) in Span(5, 15) # False
assert Span(5, 15) in Span(5, 15) # False
assert Span(5, 14) in Span(5, 15) # True
assert Span(0, 1) in Span(5, 15) # False
assert Span(0, 10) in Span(5, 15) # False
assert Span(10, 20) in Span(5, 15) # False
```
"""
if isinstance(value, Span):
return self.start <= value.start and self.end > value.end
try:
value = float(value)
return self.start <= value < self.end
except Exception as e:
raise ValueError(
f"Cannot use 'in' with a value of type {type(value)}. Use numeric values or Span objects."
) from e


@dataclass
class Answer:
Expand Down
29 changes: 29 additions & 0 deletions test/others/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,32 @@ def test_deserialize_speech_answer():
context_audio=SAMPLES_PATH / "audio" / "the context for this answer is here.wav",
)
assert speech_answer == SpeechAnswer.from_dict(speech_answer.to_dict())


def test_span_in():
assert 10 in Span(5, 15)
assert not 20 in Span(1, 15)


def test_span_in_edges():
assert 5 in Span(5, 15)
assert not 15 in Span(5, 15)


def test_span_in_other_values():
assert 10.0 in Span(5, 15)
assert "10" in Span(5, 15)
with pytest.raises(ValueError):
"hello" in Span(5, 15)


def test_assert_span_vs_span():
assert Span(10, 11) in Span(5, 15)
assert Span(5, 10) in Span(5, 15)
assert not Span(10, 15) in Span(5, 15)
assert not Span(5, 15) in Span(5, 15)
assert Span(5, 14) in Span(5, 15)

assert not Span(0, 1) in Span(5, 15)
assert not Span(0, 10) in Span(5, 15)
assert not Span(10, 20) in Span(5, 15)