From 527e8a5b91e3435d1b118bb3c8e7dcc4ec3f3f0f Mon Sep 17 00:00:00 2001 From: "hs.zhang" <22708345+cangfengzhs@users.noreply.github.com> Date: Mon, 1 Aug 2022 12:18:12 +0800 Subject: [PATCH] add point store and candidate set --- example/random/large_schema.ngql | 0 example/random/middle_schema.ngql | 0 example/random/small_schema.ngql | 6 + merak/candidate_set.py | 64 +++++++ merak/client.py | 14 +- merak/hnsw.py | 267 ++++++++++++++++-------------- merak/point.py | 104 ++---------- merak/point_store.py | 123 ++++++++++++++ tests/gen_data.py | 25 +++ tests/small_dataset_test.py | 78 ++++++--- 10 files changed, 436 insertions(+), 245 deletions(-) create mode 100644 example/random/large_schema.ngql create mode 100644 example/random/middle_schema.ngql create mode 100644 example/random/small_schema.ngql create mode 100644 merak/candidate_set.py create mode 100644 merak/point_store.py create mode 100644 tests/gen_data.py diff --git a/example/random/large_schema.ngql b/example/random/large_schema.ngql new file mode 100644 index 0000000..e69de29 diff --git a/example/random/middle_schema.ngql b/example/random/middle_schema.ngql new file mode 100644 index 0000000..e69de29 diff --git a/example/random/small_schema.ngql b/example/random/small_schema.ngql new file mode 100644 index 0000000..ef033c9 --- /dev/null +++ b/example/random/small_schema.ngql @@ -0,0 +1,6 @@ +CREATE SPACE `random_small`( partition_num=8, replica_factor=1, vid_type=INT64); + +create tag point(vector string); +create edge e(); + +insert vertex point(vector) VALUES 0:("")) diff --git a/merak/candidate_set.py b/merak/candidate_set.py new file mode 100644 index 0000000..f59446a --- /dev/null +++ b/merak/candidate_set.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 + +from asyncio import futures +from concurrent.futures.thread import _worker +from merak.point import Point +from typing import List, Tuple +from merak.point_store import PointStore +import heapq +from concurrent.futures import ThreadPoolExecutor,ProcessPoolExecutor +import concurrent + +executor = ThreadPoolExecutor(max_workers=40) + +class CandidateSet(object): + def __init__(self, target: Point, points: List[Point], max_size: int, point_store: PointStore) -> None: + global executor + self.target = target + self.max_size = max_size + self.point_store = point_store + self.executor = executor + self.points: List[Tuple(float, Point)] = [( + p.distance(self.target), p) for p in points] + self.points.sort(key=lambda x: x[0]) + self.visited = set([p.id for p in points]) + self.furthest = self.points[-1] + + def pop(self) -> Point: + + # if len(self.futures) != 0: + # ret = concurrent.futures.wait( + # self.futures, return_when=concurrent.futures.FIRST_COMPLETED) + # done = ret.done + # self.futures = ret.not_done + # for f in done: + # point: Point = f.result() + # if point.id in self.visited: + # continue + # if len(self.points) < self.max_size: + # heapq.heappush( + # self.points, (point.distance(self.target), point)) + # else: + # heapq.heappushpop( + # self.points, point.distance(self.target), point) + if len(self.points) == 0: + return None + ret = heapq.heappop(self.points) + return ret[1] + + def add(self, id_list: List[int]): + points = self.executor.map(self.worker,id_list) + for point in points: + if point.id in self.visited: + return + self.visited.add(point.id) + x = (point.distance(self.target), point) + self.points.append(x) + self.points.sort(key=lambda x:x[0]) + self.points = self.points[0:self.max_size] + + + + def worker(self, id: int) -> Point: + point = self.point_store.get_point(id, True) + return point diff --git a/merak/client.py b/merak/client.py index bb08f35..7f72890 100644 --- a/merak/client.py +++ b/merak/client.py @@ -43,7 +43,8 @@ def get_neighbors(self, vid) -> Tuple[str, Dict]: ''' # todo: replace t1 as tag, col1 as property # todo: use int id? - query = "FETCH PROP ON t1 \'{}\' YIELD properties(vertex).col1".format(vid) + query = "FETCH PROP ON t1 \'{}\' YIELD properties(vertex).col1".format( + vid) result = self.session.execute(query) if not result.is_succeeded(): raise RuntimeError("fetch failed") @@ -54,7 +55,8 @@ def get_neighbors(self, vid) -> Tuple[str, Dict]: raise RuntimeError("fetch no result") # todo: replace e1 as edge - query = "GO FROM \'{}\' OVER e1 YIELD rank(edge) as rank, dst(edge) as dst".format(vid) + query = "GO FROM \'{}\' OVER e1 YIELD rank(edge) as rank, dst(edge) as dst".format( + vid) result = self.session.execute(query) if not result.is_succeeded(): raise RuntimeError("go failed") @@ -76,7 +78,8 @@ def get_neighbors(self, vid) -> Tuple[str, Dict]: return (vec, neighbors) def insert_vertex(self, vid, vector): - query = "INSERT VERTEX t1(col1) VALUES \'{}\': (\'{}\')".format(vid, vector) + query = "INSERT VERTEX t1(col1) VALUES \'{}\': (\'{}\')".format( + vid, vector) result = self.session.execute(query) if not result.is_succeeded(): raise RuntimeError("insert vertex failed") @@ -97,7 +100,10 @@ def insert(self, batch: InsertBatch): raise RuntimeError(f"insert batch {batchStr} failed") def execute(self, query): - return self.session.execute(query) + session = self.pool.get_session('root', 'nebula') + result = session.execute(query) + session.release() + return result def close(self): self.session.release() diff --git a/merak/hnsw.py b/merak/hnsw.py index 0ecc94a..457113a 100644 --- a/merak/hnsw.py +++ b/merak/hnsw.py @@ -1,17 +1,36 @@ #!/usr/bin/env python3 +import pprint import random -from typing import List, Dict - -from merak.graph import LayeredGraph -from merak.point import Point, Points +from typing import List, Tuple +import heapq +from merak.point import Point +from merak.point_store import PointStore +from merak.candidate_set import CandidateSet +import logging + + +class HNSWConfig: + def __init__(self) -> None: + self.root_point = 0 + self.degree = 5 + self.max_degree = 10 + self.layer_factor = 4 + self.max_layer = 5 + self.candidate_set_size = 100 + self.insert_candidate_set_size = 1 + self.enable_heuristic = False class HNSW: - def __init__(self, max_top_layer: int) -> None: - self._graph = LayeredGraph(max_top_layer) + def __init__(self, config: HNSWConfig, point_store: PointStore) -> None: + self.config_ = config + self.point_store = point_store + + # def __init__(self, max_top_layer: int) -> None: + # self._graph = LayeredGraph(max_top_layer) - def __search_layer(self, q: Point, ep: List[int], ef: int, l: int) -> List[Point]: + def __search_layer(self, q: Point, candidates: List[Point], candidate_count: int, layer: int) -> List[Point]: ''' Search closest ef points in layer l, with ep as the entry point set Args: @@ -23,36 +42,43 @@ def __search_layer(self, q: Point, ep: List[int], ef: int, l: int) -> List[Point Returns: ef closest neighbors to q ''' - assert isinstance(ep, List) - - visited = {point_id for point_id in ep} - ep = [self._graph.get_point(id) for id in ep] # transform from id to point - result = Points(q, False, ep) - candidates = Points(q, True, ep) + # assert isinstance(ep, List) - while len(candidates) > 0: - curr = candidates.pop_nearest() - furthest_res = result.furthest() - - if curr.distance(q) > furthest_res.distance(q): - break + visited = set([p.id for p in candidates]) - for next_id in self._graph.get_neighbor_ids(l, curr.id): - if next_id in visited: - continue - visited.add(next_id) + result: List[Tuple(float, Point)] = [] - furthest_res = result.furthest() - next_point = self._graph.get_point(next_id) - if next_point.distance(q) < furthest_res.distance(q) or len(result) < ef: - candidates.push(next_point) - result.push(next_point) - while len(result) > ef: - result.pop_furthest() + # ep = [self._graph.get_point(id) + # for id in ep] # transform from id to point + # result = Points(q, False, ep) + # candidates = Points(q, True, ep) - return result.values - - def __select_neighbors_simple(self, q: Point, candidates: List[Point], m: int) -> List[Point]: + candidate_set = CandidateSet( + q, candidates, candidate_count, self.point_store) + while True: + p = candidate_set.pop() + if p is None: + break + pushed = False + pair = (-p.distance(q), p) + if len(result) < candidate_count: + heapq.heappush(result, pair) + pushed = True + else: + x = heapq.heappushpop(result, pair) + pushed = x[1]!=p + if pushed: + new_point_ids = [] + for n in p.neighbors[layer]: + if n in visited: + continue + visited.add(n) + new_point_ids.append(n) + candidate_set.add(new_point_ids) + result.sort(key=lambda x: -x[0]) + return [x[1] for x in result[:candidate_count]] + + def __select_neighbors_simple(self, q: Point, points: List[Point], m: int) -> List[Point]: ''' Select m nearest points from candidates to q Args: @@ -63,59 +89,56 @@ def __select_neighbors_simple(self, q: Point, candidates: List[Point], m: int) - Returns: m nearest points to q ''' - assert q is not None - assert len(candidates) >= m - - points = Points(q, True, candidates) - return [points.pop_nearest() for i in range(0, m)] - - def __select_neighbors_heuristic(self, q: Point, c: List[int], - m: int, l: int, extend: bool = True, keep: bool = True) -> List[Point]: - ''' Select nearest neighbors heuristically - - Args: - q: base element - c: candidate points - m: number of neighbors to return - l: layer number - extend: flag indicating whether or not to extend candidate list - keep: flag indicating whether or not to add discarded points - - Returns: - m points selected by the heuristic - ''' - - candidate_points = [self._graph.get_point(id) for id in c] - result = Points(q, nearest=True) - candidates = Points(q, nearest=True, points=candidate_points) - - if extend: - for p in candidate_points: - for next_id in self._graph.get_neighbor_ids(l, p.id): - if next_id not in candidates: - next_point = self._graph.get_point(next_id) - candidates.push(next_point) - - to_discard = Points(q, nearest=True) - while len(candidates) > 0 and len(result) < m: - curr = candidates.pop_nearest() - # TODO(spw): is this condition right? - if len(result) == 0 or curr.distance(q) < result.nearest().distance(q): - if curr != q: - result.push(curr) - else: - to_discard.push(curr) - - if keep: - while len(to_discard) > 0 and len(result) < m: - curr = to_discard.pop_nearest() - if curr != 1: - result.push(curr) - - return result.values - - def knn_search(self, q: Point, k: int, ef: int) -> List[Point]: + points.sort(key=lambda x: x.distance(q)) + return points[:m] + + # def __select_neighbors_heuristic(self, q: Point, c: List[int], + # m: int, l: int, extend: bool = True, keep: bool = True) -> List[Point]: + # ''' Select nearest neighbors heuristically + + # Args: + # q: base element + # c: candidate points + # m: number of neighbors to return + # l: layer number + # extend: flag indicating whether or not to extend candidate list + # keep: flag indicating whether or not to add discarded points + + # Returns: + # m points selected by the heuristic + # ''' + + # candidate_points = [self._graph.get_point(id) for id in c] + # result = Points(q, nearest=True) + # candidates = Points(q, nearest=True, points=candidate_points) + + # if extend: + # for p in candidate_points: + # for next_id in self._graph.get_neighbor_ids(l, p.id): + # if next_id not in candidates: + # next_point = self._graph.get_point(next_id) + # candidates.push(next_point) + + # to_discard = Points(q, nearest=True) + # while len(candidates) > 0 and len(result) < m: + # curr = candidates.pop_nearest() + # # TODO(spw): is this condition right? + # if len(result) == 0 or curr.distance(q) < result.nearest().distance(q): + # if curr != q: + # result.push(curr) + # else: + # to_discard.push(curr) + + # if keep: + # while len(to_discard) > 0 and len(result) < m: + # curr = to_discard.pop_nearest() + # if curr != 1: + # result.push(curr) + + # return result.values + + def knn_search(self, q: Point, k: int, candidate_count: int = None, high_layer_condidate_count: int = 1) -> List[Point]: ''' Search the nearest k points for q Args: @@ -125,57 +148,51 @@ def knn_search(self, q: Point, k: int, ef: int) -> List[Point]: Returns: K nearest elements to q ''' - entry_points = [] if self._graph.entry_point is None else [self._graph.entry_point] - for l in range(self._graph.top_layer, 0, -1): - nearest_points = self.__search_layer(q, entry_points, 1, l) - entry_points = [Points(q, True, nearest_points).pop_nearest()] - nearest_points = self.__search_layer(q, entry_points, ef, 0) + if candidate_count is None: + candidate_count = k*2 - num = min(len(nearest_points), k) - points = Points(q, True, nearest_points) - return [points.pop_nearest() for _ in range(num)] + root: Point = self.point_store.get_point(0, True) + entry_points: List[Point] = [root] + for l in range(self.config_.max_layer, 0, -1): + nearest_points = self.__search_layer( + q, entry_points, high_layer_condidate_count, l) + entry_points = nearest_points + nearest_points = self.__search_layer( + q, entry_points, candidate_count, 0) - def insert(self, q: Point, m: int, m_max: int, ef: int, ml: int): - ''' Insert element to graph with + return self.__select_neighbors_simple(q, nearest_points, k) - Args: - q: new element - m: number of established connections - m_max: maximum number of connections for each element per layer - ef: size of the dynamic candidate list - ml: normalization factor for level generation + def insert(self, q: Point,): + ''' Insert element to graph with ''' - add_batch = self._graph.add_batch() - - entry_points: List[Point] = [] if self._graph.entry_point is None else [ - self._graph.entry_point] new_layer = 0 - while new_layer < ml and random.randint(0, 10000) % 2 == 1: + + while new_layer < self.config_.max_layer and random.randint(0, self.config_.layer_factor-1) == 0: new_layer += 1 - add_batch.add_point(p) # l in [new_layer+1, top_layer], from top to bottom. # only find one entry point for next layer - for l in range(self._graph.top_layer, new_layer, -1): - nearest_points = self.__search_layer(q, entry_points, 1, l) - entry_points = [Points(q, nearest=True, points=nearest_points).nearest()] + root: Point = self.point_store.get_point(self.config_.root_point, True) + entry_points: List[Point] = [root] + + for layer in range(self.config_.max_layer, new_layer, -1): + nearest_points: List[Point] = self.__search_layer( + q, entry_points, self.config_.insert_candidate_set_size, layer) + entry_points = nearest_points # l in [0, new_layer], from top to bottom. # Find a entry point set for next layer - for l in range(min(self._graph.top_layer, new_layer), -1, -1): - nearest_points = self.__search_layer(q, entry_points, ef, l) - neighbors = self.__select_neighbors_heuristic(q, nearest_points, m, l) - for e in neighbors: - add_batch.add_edge(l, q, e) - - # TODO(spw): remove this temp for concurrent - # shrink connections if old element's edges greater than m_max - # for e in neighbors: - # curr_neighbors = self._graph.get_neighbors(l, e) - # if len(curr_neighbors) > m_max: - # new_neighbors = self.__select_neighbors_heuristic( - # e, curr_neighbors, m_max, l) - # self._graph.set_neighbors(l, e, new_neighbors) + for layer in range(min(self.config_.max_layer, new_layer), -1, -1): + nearest_points: List[Point] = self.__search_layer( + q, entry_points, self.config_.candidate_set_size, layer) + if self.config_.enable_heuristic: + neighbors: List[Point] = self.__select_neighbors_heuristic( + q, nearest_points, self.config_.degree, layer) + else: + neighbors: List[Point] = self.__select_neighbors_simple( + q, nearest_points, self.config_.degree) + for n in neighbors: + q.neighbors[layer].append(n.id) entry_points = nearest_points - self._graph.add(add_batch) + self.point_store.save_point(q) diff --git a/merak/point.py b/merak/point.py index 8254972..eb10c1d 100644 --- a/merak/point.py +++ b/merak/point.py @@ -3,109 +3,33 @@ import heapq import numpy as np from typing import List, Set, Union +from collections import defaultdict class Point(object): - def __init__(self, id: int, vec: np.ndarray, *args, **kwargs) -> None: - self._id = id - self._vec = vec - self.__dict__.update(kwargs) + def __init__(self, id: int, vec: np.ndarray = []) -> None: + self.id = id + self.vector = vec + self.neighbors = defaultdict(list) def __hash__(self) -> int: - return self._id + return self.id def __eq__(self, other: 'Point') -> bool: - if type(self) != type(other): - return False + assert type(self) == type(other) return self.id == other.id def __gt__(self, other: 'Point') -> bool: - if type(self) != type(other): - return False + assert type(self) == type(other) return self.id > other.id def __str__(self) -> str: - return f'point-{self._id}' + return f'point-{self.id}' - @property - def id(self): - return self._id - - @property - def vec(self): - return self._vec - - @property - def vec_str(self): - ''' just simple encoding now - ''' - return np.array2string(self._vec) + def __repr__(self) -> str: + return f'point-{self.id}' def distance(self, other: 'Point') -> float: - return np.linalg.norm(self.vec - other.vec) - - -class Points: - ''' Helper class to get the nearest and furthest element in an element vector. - ''' - - def __init__(self, base: Point, nearest: bool = True, points: List[Point] = []): - self._points_pair: List[(int, Point)] = [] - self._points_set: Set[int] = set() - self._base = base - self._nearest = nearest - for p in points: - self.push(p) - - def __len__(self) -> int: - return len(self._points_pair) - - def __contains__(self, p: Union[Point, int]) -> bool: - if type(p) == Point: - return p.id in self._points_set - if type(p) == int: - return p in self._points_set - return False - - @property - def base(self) -> Point: - return self._base - - @property - def values(self) -> List[Point]: - return [pair[1] for pair in self._points_pair] - - def push(self, p: Point): - if p.id in self._points_set: - return - - if self._nearest: - heapq.heappush(self._points_pair, (p.distance(self._base), p)) - else: - heapq.heappush(self._points_pair, (-p.distance(self._base), p)) - - self._points_set.add(p.id) - - def pop_nearest(self) -> Point: - assert self._nearest is True - _, p = heapq.heappop(self._points_pair) - self._points_set.remove(p.id) - return p - - def nearest(self) -> Point: - if self._nearest: - return self._points_pair[0][1] - else: - return max(self._points_pair)[1] - - def pop_furthest(self) -> Point: - assert self._nearest is False - _, p = heapq.heappop(self._points_pair) - self._points_set.remove(p.id) - return p - - def furthest(self) -> Point: - if self._nearest is False: - return self._points_pair[0][1] - else: - return max(self._points_pair)[1] + if self.id ==0 or other.id == 0: + return float('inf') + return np.linalg.norm(self.vector - other.vector) diff --git a/merak/point_store.py b/merak/point_store.py new file mode 100644 index 0000000..0dc0c50 --- /dev/null +++ b/merak/point_store.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 + +import numpy as np +from transformers import RetriBertConfig +from merak.client import Client +from merak.point import Point +from typing import List, Dict +from collections import defaultdict +import pprint +import json +from abc import ABC, abstractmethod + + +class PointStore(ABC): + def __init__(self) -> None: + pass + + @abstractmethod + def get_point(self, id: int, get_neighbors=False) -> Point: + pass + + @abstractmethod + def save_point(self, point: Point): + pass + + +class MemoryPointStore(PointStore): + def __init__(self) -> None: + super().__init__() + self.points: Dict[int, Point] = dict() + self.points[0] = Point(0, None) + + def get_point(self, id: int, get_neighbors=False) -> Point: + return self.points[id] + + def save_point(self, point: Point): + self.points[point.id] = point + for layer, neighbors in point.neighbors.items(): + for n in neighbors: + self.points[n].neighbors[layer].append(point.id) + + +class NebulaPointStore(PointStore): + def __init__(self, client: Client, space: str) -> None: + super().__init__() + self.client_ = client + self.space = space + result = self.client_.execute("use {}".format(space)) + + def __encode_vector(self, vector: np.ndarray) -> str: + return json.dumps(vector.tolist()) + + def __decode_vector(self, string: str) -> np.ndarray: + return np.array(json.loads(string)) + + def save_point(self, point: Point) -> bool: + # insert vertex + vector_string = self.__encode_vector(point.vector) + ngql = "USE {} ;INSERT VERTEX point(vector) VALUES {}:(\"{}\")".format(self.space, + point.id, vector_string) + result = self.client_.execute(ngql) + + assert result.is_succeeded() + + # insert edges + ngql = "USE {} ;INSERT EDGE e() VALUES ".format(self.space) + edges = [] + for layer, neighbors in point.neighbors.items(): + for n in neighbors: + edges.append("{}->{}@{}:()".format(point.id, n, layer)) + ngql += ",".join(edges) + result = self.client_.execute(ngql) + assert result.is_succeeded() + + def get_point(self, id: int, get_neighbors=False) -> Point: + point: Point = Point(id, None) + # get vertex + ngql = "USE {} ;FETCH PROP ON point {} YIELD properties(vertex).vector".format(self.space, + id) + result = self.client_.execute(ngql) + if not result.is_succeeded(): + + raise RuntimeError("Fetch vector of {} failed".format(id)) + + vec = result.row_values(0)[0].as_string() + if len(vec) == 0: + point.vector = None + else: + point.vector = self.__decode_vector(vec) + + # get neighbors + if not get_neighbors: + return point + + neighbors = defaultdict(set) + + ngql = "USE {} ;GO FROM {} OVER e YIELD rank(edge) as layer,dst(edge) as dst".format(self.space, + id) + result = self.client_.execute(ngql) + + if not result.is_succeeded(): + raise RuntimeError("Get neighbors failed") + + for i in range(result.row_size()): + rank = result.row_values(i)[0].as_int() + dst = result.row_values(i)[1].as_int() + neighbors[rank].add(dst) + + ngql = "USE {} ;GO FROM {} OVER e REVERSELY YIELD rank(edge) as layer,src(edge) as dst".format(self.space, + id) + result = self.client_.execute(ngql) + + if not result.is_succeeded(): + raise RuntimeError("Get neighbors failed") + + for i in range(result.row_size()): + rank = result.row_values(i)[0].as_int() + dst = result.row_values(i)[1].as_int() + neighbors[rank].add(dst) + + point.neighbors = neighbors + + return point diff --git a/tests/gen_data.py b/tests/gen_data.py new file mode 100644 index 0000000..b964e31 --- /dev/null +++ b/tests/gen_data.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +import numpy as np + + + +count = 10000 +dim = 100 + + +arr = np.random.random((count, dim)) + +nearest=[] +for i in range(len(arr)): + vec = arr[i] + distance = np.linalg.norm(vec-arr, axis=1) + index = np.array([k for k in range(len(arr))]) + dis_pairs = list(zip(distance, index)) + dis_pairs.sort(key=lambda x: x[0]) + nearest.append([x[1] for x in dis_pairs[1:21]]) +nearest = np.array(nearest) + +np.savez("data_{}_{}.npz".format(count,dim),arr=arr,nearest=nearest) + + + \ No newline at end of file diff --git a/tests/small_dataset_test.py b/tests/small_dataset_test.py index c18be72..561793f 100644 --- a/tests/small_dataset_test.py +++ b/tests/small_dataset_test.py @@ -2,51 +2,77 @@ import unittest from typing import Dict, List - +from tqdm import tqdm from merak.point import Point -from merak.hnsw import HNSW +from merak.hnsw import HNSW,HNSWConfig +from merak.point_store import MemoryPointStore,NebulaPointStore +from merak.client import Client +import time +import ipdb +point_count=1000 +dim = 20 -class TestHNSWSmallData(unittest.TestCase): - def setUp(self) -> None: - self._ml = 4 - self._ef = 100 - self._m = 3 - self._m_max = 8 - self._hnsw = HNSW(self._ml) +data = np.load("data_{}_{}.npz".format(point_count,dim)) + +client = Client("192.168.8.212",9669) + +point_store = MemoryPointStore() + +#point_store = NebulaPointStore(client,"random_small") +config = HNSWConfig() + +hnsw = HNSW(config,point_store) - self._point_num = 1000 - self._point_dim = 5 +def import_data(array): + #ipdb.set_trace() + with tqdm(desc="Import",total=len(array)) as pbar: + for i in range(len(array)): + p = Point(i+1,array[i]) + hnsw.insert(p) + pbar.update(1) + print("Finish import data") + +def search(): + pass + + +class TestHNSWSmallData(unittest.TestCase): + def setUp(self) -> None: # cook data - arr = np.random.random((self._point_num, self._point_dim)) - self._points = [] + arr = data["arr"] + self._points = {} self._nearest = dict() - self._k = 10 for i in range(len(arr)): + id = i+1 vec = arr[i] - distance = np.linalg.norm(vec-arr, axis=1) - index = np.array([k for k in range(len(arr))]) - dis_pairs = list(zip(distance, index)) - dis_pairs.sort(key=lambda x: x[0]) + self._nearest[id] =[Point(x+1,[]) for x in data["nearest"][i].tolist()] + p = Point(id, vec) + self._points[id] = p - self._nearest[i] = [Point(x[1], arr[x[1]]) - for x in dis_pairs[1:self._k+1]] # first is distance with itself - p = Point(i, vec) - self._hnsw.insert(p, self._m, self._m_max, self._ef, self._ml) - self._points.append(p) def test_search(self): - idx = 1 - knns = self._hnsw.knn_search(self._points[idx], self._k, self._ef) + start_time = time.time() + id= 1 + knns = hnsw.knn_search(self._points[id], 20+1) + end_time = time.time() for p in knns: print(p, sep=" ") print() - for p in self._nearest[idx]: + for p in self._nearest[id]: print(p, sep=" ") print() + count = 0 + for p in knns: + if p in self._nearest[id]: + count+=1 + print("use time: {}".format(end_time-start_time)) + print("acc: {}".format(count/len(self._nearest[id]))) + if __name__ == '__main__': + import_data(data["arr"]) unittest.main()