Skip to content

Commit

Permalink
feat: add filter lambda (50%)
Browse files (browse the repository at this point in the history)
  • Loading branch information
gusye1234 committed Sep 18, 2024
1 parent 3673728 commit e8750bd
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
24 changes: 19 additions & 5 deletions nano_vectordb/dbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,33 @@ def __len__(self):
return len(self.__storage["data"])

def query(
self, query: np.ndarray, top_k: int = 10, better_than_threshold: float = None
):
self,
query: np.ndarray,
top_k: int = 10,
better_than_threshold: float = None,
filter_lambda: callable = None,
) -> list[dict]:
return self.usable_metrics[self.metric](query, top_k, better_than_threshold)

def _cosine_query(
self, query: np.ndarray, top_k: int, better_than_threshold: float
self,
query: np.ndarray,
top_k: int,
better_than_threshold: float,
filter_lambda: callable = None,
):
query = normalize(query)
scores = np.dot(self.__storage["matrix"], query)
if filter_lambda is None:
use_matrix = self.__storage["matrix"]
filter_index = np.arange(len(self.__storage["data"]))
else:
raise NotImplementedError("Filter lambda not implemented")
scores = np.dot(use_matrix, query)
sort_index = np.argsort(scores)[-top_k:]
sort_index = sort_index[::-1]
sort_abs_index = filter_index[sort_index]
results = []
for i in sort_index:
for i in sort_abs_index:
if better_than_threshold is not None and scores[i] < better_than_threshold:
break
results.append({**self.__storage["data"][i], f_METRICS: scores[i]})
Expand Down
19 changes: 12 additions & 7 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import numpy as np
from nano_vectordb import NanoVectorDB
from nano_vectordb.dbs import f_METRICS, f_ID, f_VECTOR


def test_init():
Expand All @@ -14,7 +15,8 @@ def test_init():
print("Load", time() - start)

fake_embeds = np.random.rand(data_len, fake_dim)
fakes_data = [{"__vector__": fake_embeds[i]} for i in range(data_len)]
fakes_data = [{f_VECTOR: fake_embeds[i], f_ID: i} for i in range(data_len)]
query_data = fake_embeds[data_len // 2]
start = time()
r = a.upsert(fakes_data)
print("Upsert", time() - start)
Expand All @@ -23,9 +25,12 @@ def test_init():
a = NanoVectorDB(fake_dim)

start = time()
r = a.query(np.random.rand(fake_dim), 10, better_than_threshold=0.01)
print("Query", time() - start)
r = a.query(query_data, 10, better_than_threshold=0.01)
assert r[0][f_ID] == data_len // 2
print(r)
assert len(r) <= 10
for d in r:
assert d[f_METRICS] >= 0.01
os.remove("nano-vectordb.json")


Expand All @@ -40,10 +45,10 @@ def test_same_upsert():
print("Load", time() - start)

fake_embeds = np.random.rand(data_len, fake_dim)
fakes_data = [{"__vector__": fake_embeds[i]} for i in range(data_len)]
fakes_data = [{f_VECTOR: fake_embeds[i]} for i in range(data_len)]
r1 = a.upsert(fakes_data)
assert len(r1["insert"]) == len(fakes_data)
fakes_data = [{"__vector__": fake_embeds[i]} for i in range(data_len)]
fakes_data = [{f_VECTOR: fake_embeds[i]} for i in range(data_len)]
r2 = a.upsert(fakes_data)
assert r2["update"] == r1["insert"]

Expand All @@ -52,7 +57,7 @@ def test_get():
a = NanoVectorDB(1024)
a.upsert(
[
{"__vector__": np.random.rand(1024), "__id__": str(i), "content": i}
{f_VECTOR: np.random.rand(1024), f_ID: str(i), "content": i}
for i in range(100)
]
)
Expand All @@ -67,7 +72,7 @@ def test_delete():
a = NanoVectorDB(1024)
a.upsert(
[
{"__vector__": np.random.rand(1024), "__id__": str(i), "content": i}
{f_VECTOR: np.random.rand(1024), f_ID: str(i), "content": i}
for i in range(100)
]
)
Expand Down

0 comments on commit e8750bd

Please sign in to comment.