diff --git a/python/pylibraft/pylibraft/neighbors/__init__.pxd b/python/pylibraft/pylibraft/neighbors/__init__.pxd index 273b4497cc..ac9c3224ed 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.pxd +++ b/python/pylibraft/pylibraft/neighbors/__init__.pxd @@ -12,3 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +from c_ivf_pq cimport cudaDataType_t.CUDA_R_32F as CUDA_R_32F +from c_ivf_pq cimport cudaDataType_t.CUDA_R_16F as CUDA_R_16F +from c_ivf_pq cimport cudaDataType_t.CUDA_R_8U as CUDA_R_8U + +from c_ivf_pq cimport codebook_gen.PER_SUBSPACE as PER_SUBSPACE +from c_ivf_pq cimport codebook_gen.PER_CLUSTER as PER_CLUSTER + +from c_ivf_pq cimport index_params +from c_ivf_pq cimport search_params \ No newline at end of file diff --git a/python/pylibraft/pylibraft/neighbors/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/c_ivf_pq.pxd index 1556c25e62..91e1eae836 100644 --- a/python/pylibraft/pylibraft/neighbors/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/c_ivf_pq.pxd @@ -59,7 +59,7 @@ cdef extern from "raft/neighbors/ivf_pq_types.hpp" \ PER_CLUSTER "raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER" - cdef cppclass index_params(ann_index_params): + cpdef cppclass index_params(ann_index_params): uint32_t n_lists uint32_t kmeans_n_iters double kmeans_trainset_fraction @@ -78,7 +78,7 @@ cdef extern from "raft/neighbors/ivf_pq_types.hpp" \ uint32_t pq_dim, uint32_t n_nonempty_lists) - cdef cppclass search_params(ann_search_params): + cpdef cppclass search_params(ann_search_params): uint32_t n_probes cudaDataType_t lut_dtype cudaDataType_t internal_distance_dtype diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq.pyx index d6553f3cd7..5ca75da794 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq.pyx @@ -32,6 +32,9 @@ from rmm._lib.memory_resource cimport device_memory_resource cimport pylibraft.neighbors.c_ivf_pq as c_ivf_pq +from pylibraft.neighbors.c_ivf_pq cimport index_params +from pylibraft.neighbors.c_ivf_pq cimport search_params + def is_c_cont(cai): dt = np.dtype(cai["typestr"]) return "strides" not in cai or \ diff --git a/python/pylibraft/pylibraft/test/test_ivf_pq.py b/python/pylibraft/pylibraft/test/test_ivf_pq.py index 7201d23ffa..c94e2e44c3 100644 --- a/python/pylibraft/pylibraft/test/test_ivf_pq.py +++ b/python/pylibraft/pylibraft/test/test_ivf_pq.py @@ -403,3 +403,8 @@ def test_search_inputs(params): with pytest.raises(Exception): nn.search(queries_device, k, out_idx_device, out_dist_device, n_probes=50) + + +def test_new_api(): + params = IvfPq.index_params + assert params.n_litst > 0