diff --git a/nano_vectordb/dbs.py b/nano_vectordb/dbs.py
index 77915e6..17822a8 100644
--- a/nano_vectordb/dbs.py
+++ b/nano_vectordb/dbs.py
@@ -138,19 +138,38 @@ def __len__(self):
         return len(self.__storage["data"])
 
     def query(
-        self, query: np.ndarray, top_k: int = 10, better_than_threshold: float = None
-    ):
-        return self.usable_metrics[self.metric](query, top_k, better_than_threshold)
+        self,
+        query: np.ndarray,
+        top_k: int = 10,
+        better_than_threshold: float = None,
+        filter_lambda: callable = None,
+    ) -> list[dict]:
+        """Run `query` through the metric implementation selected by `self.metric`."""
+        return self.usable_metrics[self.metric](
+            query, top_k, better_than_threshold, filter_lambda=filter_lambda
+        )
 
     def _cosine_query(
-        self, query: np.ndarray, top_k: int, better_than_threshold: float
+        self,
+        query: np.ndarray,
+        top_k: int,
+        better_than_threshold: float,
+        filter_lambda: callable = None,
     ):
         query = normalize(query)
-        scores = np.dot(self.__storage["matrix"], query)
+        if filter_lambda is None:
+            use_matrix = self.__storage["matrix"]
+            filter_index = np.arange(len(self.__storage["data"]))
+        else:
+            raise NotImplementedError("Filter lambda not implemented")
+        scores = np.dot(use_matrix, query)
         sort_index = np.argsort(scores)[-top_k:]
         sort_index = sort_index[::-1]
+        # NOTE: scores[k] belongs to row k of use_matrix, so indexing scores with
+        # sort_abs_index below is only valid while filter_index is the identity.
+        sort_abs_index = filter_index[sort_index]
         results = []
-        for i in sort_index:
+        for i in sort_abs_index:
             if better_than_threshold is not None and scores[i] < better_than_threshold:
                 break
             results.append({**self.__storage["data"][i], f_METRICS: scores[i]})
diff --git a/tests/test_init.py b/tests/test_init.py
index c5c82bd..56c50de 100644
--- a/tests/test_init.py
+++ b/tests/test_init.py
@@ -1,6 +1,7 @@
 import os
 import numpy as np
 from nano_vectordb import NanoVectorDB
+from nano_vectordb.dbs import f_METRICS, f_ID, f_VECTOR
 
 
 def test_init():
@@ -14,7 +15,8 @@ def test_init():
     print("Load", time() - start)
 
     fake_embeds = np.random.rand(data_len, fake_dim)
-    fakes_data = [{"__vector__": fake_embeds[i]} for i in range(data_len)]
+    fakes_data = [{f_VECTOR: fake_embeds[i], f_ID: i} for i in range(data_len)]
+    query_data = fake_embeds[data_len // 2]
     start = time()
     r = a.upsert(fakes_data)
     print("Upsert", time() - start)
@@ -23,9 +25,12 @@ def test_init():
     a = NanoVectorDB(fake_dim)
 
     start = time()
-    r = a.query(np.random.rand(fake_dim), 10, better_than_threshold=0.01)
-    print("Query", time() - start)
+    r = a.query(query_data, 10, better_than_threshold=0.01)
+    assert r[0][f_ID] == data_len // 2
     print(r)
+    assert len(r) <= 10
+    for d in r:
+        assert d[f_METRICS] >= 0.01
     os.remove("nano-vectordb.json")
 
 
@@ -40,10 +45,10 @@ def test_same_upsert():
     print("Load", time() - start)
 
     fake_embeds = np.random.rand(data_len, fake_dim)
-    fakes_data = [{"__vector__": fake_embeds[i]} for i in range(data_len)]
+    fakes_data = [{f_VECTOR: fake_embeds[i]} for i in range(data_len)]
     r1 = a.upsert(fakes_data)
     assert len(r1["insert"]) == len(fakes_data)
-    fakes_data = [{"__vector__": fake_embeds[i]} for i in range(data_len)]
+    fakes_data = [{f_VECTOR: fake_embeds[i]} for i in range(data_len)]
     r2 = a.upsert(fakes_data)
     assert r2["update"] == r1["insert"]
 
@@ -52,7 +57,7 @@ def test_get():
     a = NanoVectorDB(1024)
     a.upsert(
         [
-            {"__vector__": np.random.rand(1024), "__id__": str(i), "content": i}
+            {f_VECTOR: np.random.rand(1024), f_ID: str(i), "content": i}
             for i in range(100)
         ]
     )
@@ -67,7 +72,7 @@ def test_delete():
     a = NanoVectorDB(1024)
     a.upsert(
         [
-            {"__vector__": np.random.rand(1024), "__id__": str(i), "content": i}
+            {f_VECTOR: np.random.rand(1024), f_ID: str(i), "content": i}
             for i in range(100)
         ]
     )