# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import os
import warnings
from typing import Optional

import numpy as np

from . import (
    DistanceMetric,
    QueryResponse,
    QueryResponseBatch,
    VectorDType,
    VectorLike,
    VectorLikeBatch,
)
from . import _diskannpy as _native_dap
from ._common import (
    _assert,
    _assert_2d,
    _assert_is_nonnegative_uint32,
    _assert_is_positive_uint32,
    _castable_dtype_or_raise,
    _ensure_index_metadata,
    _valid_index_prefix,
    _valid_metric,
)

__ALL__ = ["StaticDiskIndex"]

class StaticDiskIndex:
    """
    A StaticDiskIndex is a disk-backed index that is not mutable.
    """

    def __init__(
        self,
        index_directory: str,
        num_threads: int,
        num_nodes_to_cache: int,
        cache_mechanism: int = 1,
        distance_metric: Optional[DistanceMetric] = None,
        vector_dtype: Optional[VectorDType] = None,
        dimensions: Optional[int] = None,
        index_prefix: str = "ann",
    ):
"""
### Parameters
- **index_directory**: The directory containing the index files. This directory must contain the following
files:
- `{index_prefix}_sample_data.bin`
- `{index_prefix}_mem.index.data`
- `{index_prefix}_pq_compressed.bin`
- `{index_prefix}_pq_pivots.bin`
- `{index_prefix}_sample_ids.bin`
- `{index_prefix}_disk.index`
It may also include the following optional files:
- `{index_prefix}_vectors.bin`: Optional. `diskannpy` builder functions may create this file in the
`index_directory` if the index was created from a numpy array
- `{index_prefix}_metadata.bin`: Optional. `diskannpy` builder functions create this file to store metadata
about the index, such as vector dtype, distance metric, number of vectors and vector dimensionality.
If an index is built from the `diskann` cli tools, this file will not exist.
- **num_threads**: Number of threads to use when searching this index. (>= 0), 0 = num_threads in system
- **num_nodes_to_cache**: Number of nodes to cache in memory (> -1)
- **cache_mechanism**: 1 -> use the generated sample_data.bin file for
the index to initialize a set of cached nodes, up to `num_nodes_to_cache`, 2 -> ready the cache for up to
`num_nodes_to_cache`, but do not initialize it with any nodes. Any other value disables node caching.
- **distance_metric**: A `str`, strictly one of {"l2", "mips", "cosine"}. `l2` and `cosine` are supported for all 3
vector dtypes, but `mips` is only available for single precision floats. Default is `None`. **This
value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it does not exist,
you are required to provide it.
- **vector_dtype**: The vector dtype this index has been built with. **This value is only used if a
`{index_prefix}_metadata.bin` file does not exist.** If it does not exist, you are required to provide it.
- **dimensions**: The vector dimensionality of this index. All new vectors inserted must be the same
dimensionality. **This value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it
does not exist, you are required to provide it.
- **index_prefix**: The prefix of the index files. Defaults to "ann".
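        ### Example

        A minimal, illustrative sketch of loading an existing index. The path, cache size, and the assumption
        that the class is importable as `diskannpy.StaticDiskIndex` are hypothetical; it also assumes the index
        was built by `diskannpy`, so an `ann_metadata.bin` file exists and dtype/metric/dimensions can be inferred:

        ```python
        import diskannpy

        index = diskannpy.StaticDiskIndex(
            index_directory="/path/to/index",  # hypothetical directory containing the ann_* index files
            num_threads=0,                     # 0 = use all available system threads
            num_nodes_to_cache=100_000,        # cache up to 100k nodes from the sample data (cache_mechanism=1)
        )
        ```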
"""
index_prefix = _valid_index_prefix(index_directory, index_prefix)
vector_dtype, metric, _, _ = _ensure_index_metadata(
index_prefix,
vector_dtype,
distance_metric,
1, # it doesn't matter because we don't need it in this context anyway
dimensions,
)
dap_metric = _valid_metric(metric)
_assert_is_nonnegative_uint32(num_threads, "num_threads")
_assert_is_nonnegative_uint32(num_nodes_to_cache, "num_nodes_to_cache")
self._vector_dtype = vector_dtype
if vector_dtype == np.uint8:
_index = _native_dap.StaticDiskUInt8Index
elif vector_dtype == np.int8:
_index = _native_dap.StaticDiskInt8Index
else:
_index = _native_dap.StaticDiskFloatIndex
self._index = _index(
distance_metric=dap_metric,
index_path_prefix=os.path.join(index_directory, index_prefix),
num_threads=num_threads,
num_nodes_to_cache=num_nodes_to_cache,
cache_mechanism=cache_mechanism,
)
def search(
self, query: VectorLike, k_neighbors: int, complexity: int, beam_width: int = 2
) -> QueryResponse:
"""
Searches the index by a single query vector.
### Parameters
- **query**: 1d numpy array of the same dimensionality and dtype of the index.
- **k_neighbors**: Number of neighbors to be returned. If query vector exists in index, it almost definitely
will be returned as well, so adjust your ``k_neighbors`` as appropriate. Must be > 0.
- **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size
increases accuracy at the cost of latency. Must be at least k_neighbors in size.
- **beam_width**: The beamwidth to be used for search. This is the maximum number of IO requests each query
will issue per iteration of search code. Larger beamwidth will result in fewer IO round-trips per query,
but might result in slightly higher total number of IO requests to SSD per query. For the highest query
throughput with a fixed SSD IOps rating, use W=1. For best latency, use W=4,8 or higher complexity search.
Specifying 0 will optimize the beamwidth depending on the number of threads performing search, but will
involve some tuning overhead.
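        ### Example

        A minimal, illustrative sketch of a single-query search. It assumes `index` is an already-loaded
        `StaticDiskIndex` over 128-dimensional float32 vectors (both values hypothetical); the returned
        `QueryResponse` contains the neighbor identifiers and their distances:

        ```python
        import numpy as np

        query = np.random.default_rng().random(128, dtype=np.float32)  # must match the index dtype and dims
        response = index.search(query, k_neighbors=10, complexity=64)  # complexity >= k_neighbors
        ```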
"""
_query = _castable_dtype_or_raise(query, expected=self._vector_dtype)
_assert(len(_query.shape) == 1, "query vector must be 1-d")
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
_assert_is_positive_uint32(complexity, "complexity")
_assert_is_positive_uint32(beam_width, "beam_width")
if k_neighbors > complexity:
warnings.warn(
f"{k_neighbors=} asked for, but {complexity=} was smaller. Increasing {complexity} to {k_neighbors}"
)
complexity = k_neighbors
return self._index.search(
query=_query,
knn=k_neighbors,
complexity=complexity,
beam_width=beam_width,
)
def batch_search(
self,
queries: VectorLikeBatch,
k_neighbors: int,
complexity: int,
num_threads: int,
beam_width: int = 2,
) -> QueryResponseBatch:
"""
Searches the index by a batch of query vectors.
This search is parallelized and far more efficient than searching for each vector individually.
### Parameters
- **queries**: 2d numpy array, with column dimensionality matching the index and row dimensionality being the
number of queries intended to search for in parallel. Dtype must match dtype of the index.
- **k_neighbors**: Number of neighbors to be returned. If query vector exists in index, it almost definitely
will be returned as well, so adjust your ``k_neighbors`` as appropriate. Must be > 0.
- **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size
increases accuracy at the cost of latency. Must be at least k_neighbors in size.
- **num_threads**: Number of threads to use when searching this index. (>= 0), 0 = num_threads in system
- **beam_width**: The beamwidth to be used for search. This is the maximum number of IO requests each query
will issue per iteration of search code. Larger beamwidth will result in fewer IO round-trips per query,
but might result in slightly higher total number of IO requests to SSD per query. For the highest query
throughput with a fixed SSD IOps rating, use W=1. For best latency, use W=4,8 or higher complexity search.
Specifying 0 will optimize the beamwidth depending on the number of threads performing search, but will
involve some tuning overhead.
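        ### Example

        A minimal, illustrative sketch of a batch search. It assumes `index` is an already-loaded
        `StaticDiskIndex` over 128-dimensional float32 vectors (both values hypothetical); the returned
        `QueryResponseBatch` contains one row of neighbor identifiers and distances per query:

        ```python
        import numpy as np

        queries = np.random.default_rng().random((1_000, 128), dtype=np.float32)  # one query per row
        response = index.batch_search(
            queries, k_neighbors=10, complexity=64, num_threads=0  # 0 = use all available system threads
        )
        ```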
"""
_queries = _castable_dtype_or_raise(queries, expected=self._vector_dtype)
_assert_2d(_queries, "queries")
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
_assert_is_positive_uint32(complexity, "complexity")
_assert_is_nonnegative_uint32(num_threads, "num_threads")
_assert_is_positive_uint32(beam_width, "beam_width")
if k_neighbors > complexity:
warnings.warn(
f"{k_neighbors=} asked for, but {complexity=} was smaller. Increasing {complexity} to {k_neighbors}"
)
complexity = k_neighbors
num_queries, dim = _queries.shape
return self._index.batch_search(
queries=_queries,
num_queries=num_queries,
knn=k_neighbors,
complexity=complexity,
beam_width=beam_width,
num_threads=num_threads,
)