Skip to content

Commit

Permalink
feat: add more ruff rules (#138)
Browse files Browse the repository at this point in the history
* feat: add more ruff rules

Signed-off-by: 盐粒 Yanli <[email protected]>

* chore: modified readme

Signed-off-by: 盐粒 Yanli <[email protected]>

* rename error class

Signed-off-by: 盐粒 Yanli <[email protected]>

---------

Signed-off-by: 盐粒 Yanli <[email protected]>
  • Loading branch information
BeautyyuYanli authored Nov 17, 2023
1 parent f8344dd commit f6e382d
Show file tree
Hide file tree
Showing 16 changed files with 138 additions and 86 deletions.
1 change: 1 addition & 0 deletions bindings/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ pdm sync
Run lint:
```bash
pdm run format
pdm run fix
pdm run check
```

Expand Down
7 changes: 4 additions & 3 deletions bindings/python/examples/psycopg_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pgvecto_rs.psycopg import register_vector

URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
port=os.getenv("DB_PORT", 5432),
port=os.getenv("DB_PORT", "5432"),
host=os.getenv("DB_HOST", "localhost"),
username=os.getenv("DB_USER", "postgres"),
password=os.getenv("DB_PASS", "mysecretpassword"),
Expand All @@ -18,7 +18,7 @@
conn.execute("CREATE EXTENSION IF NOT EXISTS vectors;")
register_vector(conn)
conn.execute(
"CREATE TABLE documents (id SERIAL PRIMARY KEY, text TEXT NOT NULL, embedding vector(3) NOT NULL);"
"CREATE TABLE documents (id SERIAL PRIMARY KEY, text TEXT NOT NULL, embedding vector(3) NOT NULL);",
)
conn.commit()
try:
Expand All @@ -39,7 +39,8 @@

# Select the row "hello pgvecto.rs"
cur = conn.execute(
"SELECT * FROM documents WHERE text = %s;", ("hello pgvecto.rs",)
"SELECT * FROM documents WHERE text = %s;",
("hello pgvecto.rs",),
)
target = cur.fetchone()[2]

Expand Down
8 changes: 5 additions & 3 deletions bindings/python/examples/sdk_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pgvecto_rs.sdk import PGVectoRs, Record, filters

URL = "postgresql+psycopg://{username}:{password}@{host}:{port}/{db_name}".format(
port=os.getenv("DB_PORT", 5432),
port=os.getenv("DB_PORT", "5432"),
host=os.getenv("DB_HOST", "localhost"),
username=os.getenv("DB_USER", "postgres"),
password=os.getenv("DB_PASS", "mysecretpassword"),
Expand Down Expand Up @@ -43,15 +43,17 @@ def embed(text: str):
# Query (With a filter from the filters module)
print("#################### First Query ####################")
for record, dis in client.search(
target, filter=filters.meta_contains({"src": "one"})
target,
filter=filters.meta_contains({"src": "one"}),
):
print(f"DISTANCE SCORE: {dis}")
print(record)

# Another Query (Equivalent to the first one, but with a lambda filter written by hand)
print("#################### Second Query ####################")
for record, dis in client.search(
target, filter=lambda r: r.meta.contains({"src": "one"})
target,
filter=lambda r: r.meta.contains({"src": "one"}),
):
print(f"DISTANCE SCORE: {dis}")
print(record)
Expand Down
4 changes: 2 additions & 2 deletions bindings/python/examples/sqlalchemy_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pgvecto_rs.sqlalchemy import Vector

URL = "postgresql+psycopg://{username}:{password}@{host}:{port}/{db_name}".format(
port=os.getenv("DB_PORT", 5432),
port=os.getenv("DB_PORT", "5432"),
host=os.getenv("DB_HOST", "localhost"),
username=os.getenv("DB_USER", "postgres"),
password=os.getenv("DB_PASS", "mysecretpassword"),
Expand Down Expand Up @@ -53,7 +53,7 @@ def __repr__(self) -> str:
stmt = select(
Document.text,
Document.embedding.squared_euclidean_distance(target.embedding).label(
"distance"
"distance",
),
).order_by("distance")
for doc in session.execute(stmt):
Expand Down
40 changes: 20 additions & 20 deletions bindings/python/pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 27 additions & 12 deletions bindings/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ name = "pgvecto-rs"
version = "0.1.3"
description = "Python binding for pgvecto.rs"
authors = [
{ name = "TensorChord", email = "[email protected]" },
{ name = "盐粒 Yanli", email = "[email protected]" },
{ name = "TensorChord", email = "[email protected]" },
{ name = "盐粒 Yanli", email = "[email protected]" },
]
dependencies = [
"numpy>=1.23",
Expand All @@ -23,15 +23,15 @@ classifiers = [

[build-system]
build-backend = "pdm.backend"
requires = [
requires = [
"pdm-backend",
]

[project.optional-dependencies]
psycopg3 = [
psycopg3 = [
"psycopg[binary]>=3.1.12",
]
sdk = [
sdk = [
"openai>=1.2.2",
"pgvecto_rs[sqlalchemy]",
]
Expand All @@ -40,19 +40,34 @@ sqlalchemy = [
"SQLAlchemy>=2.0.23",
]
[tool.pdm.dev-dependencies]
lint = ["ruff>=0.1.1"]
lint = ["ruff>=0.1.5"]
test = ["pytest>=7.4.3"]

[tool.pdm.scripts]
test = "pytest tests/"
test = "pytest tests/"
format = "ruff format ."
fix = "ruff --fix ."
check = { composite = ["ruff format . --check", "ruff ."] }
fix = "ruff --fix ."
check = { composite = ["ruff format . --check", "ruff ."] }

[tool.ruff]
select = ["E", "F", "I", "TID"]
ignore = ["E731", "E501"]
src = ["src"]
select = [
"E", #https://docs.astral.sh/ruff/rules/#error-e
"F", #https://docs.astral.sh/ruff/rules/#pyflakes-f
"I", #https://docs.astral.sh/ruff/rules/#isort-i
"TID", #https://docs.astral.sh/ruff/rules/#flake8-tidy-imports-tid
"S", #https://docs.astral.sh/ruff/rules/#flake8-bandit-s
"B", #https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
"SIM", #https://docs.astral.sh/ruff/rules/#flake8-simplify-sim
"N", #https://docs.astral.sh/ruff/rules/#pep8-naming-n
"PT", #https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
"TRY", #https://docs.astral.sh/ruff/rules/#tryceratops-try
"FLY", #https://docs.astral.sh/ruff/rules/#flynt-fly
"PL", #https://docs.astral.sh/ruff/rules/#pylint-pl
"NPY", #https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy
"RUF", #https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
]
ignore = ["S101", "E731", "E501"]
src = ["src"]

[tool.pytest.ini_options]
addopts = "-r aR"
25 changes: 25 additions & 0 deletions bindings/python/src/pgvecto_rs/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np


class PGVectoRsError(ValueError):
    """Common base class for the error types declared in this module.

    It derives from ``ValueError`` and adds no behaviour of its own.
    """


class NDArrayDimensionError(PGVectoRsError):
    """Error reporting an ndarray that is not one-dimensional.

    Args:
        dim: The number of dimensions that was found; it is interpolated
            into the error message.
    """

    def __init__(self, dim: int) -> None:
        message = f"ndarray must be 1D for vector, got {dim}D"
        super().__init__(message)


class NDArrayDtypeError(PGVectoRsError):
    """Error reporting an ndarray whose data type is not numeric.

    Args:
        dtype: The offending data type; it is interpolated into the error
            message.
    """

    def __init__(self, dtype: np.dtype) -> None:
        message = f"ndarray data type must be numeric for vector, got {dtype}"
        super().__init__(message)


class BuiltinListTypeError(PGVectoRsError):
    """Error reporting a built-in list that holds non-numeric items.

    The message is fixed, so the constructor takes no arguments.
    """

    def __init__(self) -> None:
        message = "list data type must be numeric for vector"
        super().__init__(message)


class VectorDimensionError(PGVectoRsError):
    """Error reporting an invalid vector dimension.

    Args:
        dim: The rejected dimension; it is interpolated into the error
            message.
    """

    # NOTE(review): the message says "> 0", but the SQLAlchemy ``Vector``
    # constructor raises this only when ``dim < 0`` -- confirm whether a
    # dimension of 0 is meant to be accepted before changing either side.
    def __init__(self, dim: int) -> None:
        message = f"vector dimension must be > 0, got {dim}"
        super().__init__(message)
2 changes: 1 addition & 1 deletion bindings/python/src/pgvecto_rs/psycopg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def register_vector_async(context: Connection):

def register_vector_info(context: Connection, info: TypeInfo):
if info is None:
raise ProgrammingError("vector type not found in the database")
raise ProgrammingError(info="vector type not found in the database")
info.register(context)

class VectorTextDumper(VectorDumper):
Expand Down
10 changes: 7 additions & 3 deletions bindings/python/src/pgvecto_rs/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
"""Connect to an existing table or create a new empty one.
Args:
----
db_url (str): url to the database.
table_name (str): name of the table.
dimension (int): dimension of the embeddings.
Expand All @@ -36,7 +37,8 @@ def __init__(
class _Table(RecordORM):
__tablename__ = f"collection_{collection_name}"
id: Mapped[UUID] = mapped_column(
postgresql.UUID(as_uuid=True), primary_key=True
postgresql.UUID(as_uuid=True),
primary_key=True,
)
text: Mapped[str] = mapped_column(String)
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
Expand All @@ -59,7 +61,7 @@ def insert(self, records: List[Record]) -> None:
text=record.text,
meta=record.meta,
embedding=record.embedding,
)
),
)
session.commit()

Expand All @@ -73,13 +75,15 @@ def search(
"""Search for the nearest records.
Args:
----
embedding : Target embedding.
distance_op : Distance op.
top_k : Max records to return. Defaults to 4.
filter : Read our document. Defaults to None.
order_by_dis : Order by distance. Defaults to True.
Returns:
-------
List of records and corresponding distances.
"""
Expand All @@ -88,7 +92,7 @@ def search(
select(
self._table,
self._table.embedding.op(distance_op, return_type=Float)(
embedding
embedding,
).label("distance"),
)
.limit(top_k)
Expand Down
9 changes: 5 additions & 4 deletions bindings/python/src/pgvecto_rs/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sqlalchemy.types as types
from sqlalchemy import types

from pgvecto_rs.errors import VectorDimensionError
from pgvecto_rs.utils import serializer


Expand All @@ -8,13 +9,13 @@ class Vector(types.UserDefinedType):

def __init__(self, dim):
if dim < 0:
raise ValueError("negative dim is not allowed")
raise VectorDimensionError(dim)
self.dim = dim

def get_col_spec(self, **kw):
if self.dim is None or self.dim == 0:
return "VECTOR"
return "VECTOR({})".format(self.dim)
return f"VECTOR({self.dim})"

def bind_processor(self, dialect):
def _processor(value):
Expand All @@ -28,7 +29,7 @@ def _processor(value):

return _processor

class comparator_factory(types.UserDefinedType.Comparator):
class comparator_factory(types.UserDefinedType.Comparator): # noqa: N801
def squared_euclidean_distance(self, other):
return self.op("<->", return_type=types.Float)(other)

Expand Down
12 changes: 9 additions & 3 deletions bindings/python/src/pgvecto_rs/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

import numpy as np

from pgvecto_rs.errors import (
BuiltinListTypeError,
NDArrayDimensionError,
NDArrayDtypeError,
)


def ignore_none(func):
@wraps(func)
Expand All @@ -26,9 +32,9 @@ def validate_ndarray(func):
def _func(value: np.ndarray, *args, **kwargs):
if isinstance(value, np.ndarray):
if value.ndim != 1:
raise ValueError("ndarray must be 1D for vector")
raise NDArrayDimensionError(value.ndim)
if not np.issubdtype(value.dtype, np.number):
raise ValueError("ndarray data type must be numeric for vector")
raise NDArrayDtypeError(value.dtype)
return func(value, *args, **kwargs)

return _func
Expand All @@ -41,7 +47,7 @@ def validate_builtin_list(func):
def _func(value: list, *args, **kwargs):
if isinstance(value, list):
if not all(isinstance(x, (int, float)) for x in value):
raise ValueError("list data type must be numeric for vector")
raise BuiltinListTypeError()
value = np.array(value, dtype=np.float32)
return func(value, *args, **kwargs)

Expand Down
Loading

0 comments on commit f6e382d

Please sign in to comment.