-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The `SelfQuery PGVectorTranslator` is not correct. The operator is "eq" and not "$eq". This patch implements a new version of `PGVectorTranslator`. It's necessary to release a new version before accepting [another PR](langchain-ai/langchain#23217) in langchain core.
- Loading branch information
Showing
5 changed files
with
142 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from typing import Dict, Tuple, Union | ||
|
||
from langchain_core.structured_query import ( | ||
Comparator, | ||
Comparison, | ||
Operation, | ||
Operator, | ||
StructuredQuery, | ||
Visitor, | ||
) | ||
|
||
|
||
class PGVectorTranslator(Visitor): | ||
"""Translate `PGVector` internal query language elements to valid filters.""" | ||
|
||
allowed_operators = [Operator.AND, Operator.OR] | ||
"""Subset of allowed logical operators.""" | ||
allowed_comparators = [ | ||
Comparator.EQ, | ||
Comparator.NE, | ||
Comparator.GT, | ||
Comparator.LT, | ||
Comparator.IN, | ||
Comparator.NIN, | ||
Comparator.CONTAIN, | ||
Comparator.LIKE, | ||
] | ||
"""Subset of allowed logical comparators.""" | ||
|
||
def _format_func(self, func: Union[Operator, Comparator]) -> str: | ||
self._validate_func(func) | ||
return f"${func.value}" | ||
|
||
def visit_operation(self, operation: Operation) -> Dict: | ||
args = [arg.accept(self) for arg in operation.arguments] | ||
return {self._format_func(operation.operator): args} | ||
|
||
def visit_comparison(self, comparison: Comparison) -> Dict: | ||
return { | ||
comparison.attribute: { | ||
self._format_func(comparison.comparator): comparison.value | ||
} | ||
} | ||
|
||
def visit_structured_query( | ||
self, structured_query: StructuredQuery | ||
) -> Tuple[str, dict]: | ||
if structured_query.filter is None: | ||
kwargs = {} | ||
else: | ||
kwargs = {"filter": structured_query.filter.accept(self)} | ||
return structured_query.query, kwargs |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from typing import Dict, Tuple | ||
|
||
import pytest as pytest | ||
from langchain_core.structured_query import ( | ||
Comparator, | ||
Comparison, | ||
Operation, | ||
Operator, | ||
StructuredQuery, | ||
) | ||
|
||
from langchain_postgres import PGVectorTranslator | ||
|
||
DEFAULT_TRANSLATOR = PGVectorTranslator() | ||
|
||
|
||
def test_visit_comparison() -> None: | ||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1) | ||
expected = {"foo": {"$lt": 1}} | ||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp) | ||
assert expected == actual | ||
|
||
|
||
@pytest.mark.skip("Not implemented") | ||
def test_visit_operation() -> None: | ||
op = Operation( | ||
operator=Operator.AND, | ||
arguments=[ | ||
Comparison(comparator=Comparator.LT, attribute="foo", value=2), | ||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), | ||
Comparison(comparator=Comparator.GT, attribute="abc", value=2.0), | ||
], | ||
) | ||
expected = { | ||
"foo": {"$lt": 2}, | ||
"bar": {"$eq": "baz"}, | ||
"abc": {"$gt": 2.0}, | ||
} | ||
actual = DEFAULT_TRANSLATOR.visit_operation(op) | ||
assert expected == actual | ||
|
||
|
||
def test_visit_structured_query() -> None: | ||
query = "What is the capital of France?" | ||
structured_query = StructuredQuery( | ||
query=query, | ||
filter=None, | ||
) | ||
expected: Tuple[str, Dict] = (query, {}) | ||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) | ||
assert expected == actual | ||
|
||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1) | ||
structured_query = StructuredQuery( | ||
query=query, | ||
filter=comp, | ||
) | ||
expected = (query, {"filter": {"foo": {"$lt": 1}}}) | ||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) | ||
assert expected == actual | ||
|
||
op = Operation( | ||
operator=Operator.AND, | ||
arguments=[ | ||
Comparison(comparator=Comparator.LT, attribute="foo", value=2), | ||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), | ||
Comparison(comparator=Comparator.GT, attribute="abc", value=2.0), | ||
], | ||
) | ||
structured_query = StructuredQuery( | ||
query=query, | ||
filter=op, | ||
) | ||
expected = ( | ||
query, | ||
{ | ||
"filter": { | ||
"$and": [ | ||
{"foo": {"$lt": 2}}, | ||
{"bar": {"$eq": "baz"}}, | ||
{"abc": {"$gt": 2.0}}, | ||
] | ||
} | ||
}, | ||
) | ||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) | ||
assert expected == actual |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
EXPECTED_ALL = [ | ||
"__version__", | ||
"PGVector", | ||
"PGVectorTranslator", | ||
"PostgresChatMessageHistory", | ||
] | ||
|
||
|