Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Nov 8, 2022
1 parent 989dfb9 commit 94e1adf
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 2 deletions.
10 changes: 10 additions & 0 deletions python/pylibraft/pylibraft/neighbors/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from c_ivf_pq cimport cudaDataType_t.CUDA_R_32F as CUDA_R_32F
from c_ivf_pq cimport cudaDataType_t.CUDA_R_16F as CUDA_R_16F
from c_ivf_pq cimport cudaDataType_t.CUDA_R_8U as CUDA_R_8U

from c_ivf_pq cimport codebook_gen.PER_SUBSPACE as PER_SUBSPACE
from c_ivf_pq cimport codebook_gen.PER_CLUSTER as PER_CLUSTER

from c_ivf_pq cimport index_params
from c_ivf_pq cimport search_params
4 changes: 2 additions & 2 deletions python/pylibraft/pylibraft/neighbors/c_ivf_pq.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ cdef extern from "raft/neighbors/ivf_pq_types.hpp" \
PER_CLUSTER "raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER"


cdef cppclass index_params(ann_index_params):
cpdef cppclass index_params(ann_index_params):
uint32_t n_lists
uint32_t kmeans_n_iters
double kmeans_trainset_fraction
Expand All @@ -78,7 +78,7 @@ cdef extern from "raft/neighbors/ivf_pq_types.hpp" \
uint32_t pq_dim,
uint32_t n_nonempty_lists)

cdef cppclass search_params(ann_search_params):
cpdef cppclass search_params(ann_search_params):
uint32_t n_probes
cudaDataType_t lut_dtype
cudaDataType_t internal_distance_dtype
Expand Down
3 changes: 3 additions & 0 deletions python/pylibraft/pylibraft/neighbors/ivf_pq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ from rmm._lib.memory_resource cimport device_memory_resource

cimport pylibraft.neighbors.c_ivf_pq as c_ivf_pq

from pylibraft.neighbors.c_ivf_pq cimport index_params
from pylibraft.neighbors.c_ivf_pq cimport search_params

def is_c_cont(cai):
dt = np.dtype(cai["typestr"])
return "strides" not in cai or \
Expand Down
5 changes: 5 additions & 0 deletions python/pylibraft/pylibraft/test/test_ivf_pq.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,8 @@ def test_search_inputs(params):

with pytest.raises(Exception):
nn.search(queries_device, k, out_idx_device, out_dist_device, n_probes=50)


def test_new_api():
params = IvfPq.index_params
assert params.n_litst > 0

0 comments on commit 94e1adf

Please sign in to comment.