-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update CAGRA serialization #1755
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -162,6 +162,31 @@ cdef class IndexFloat(Index): | |
attr_str = [m_str] + attr_str | ||
return "Index(type=CAGRA, " + (", ".join(attr_str)) + ")" | ||
|
||
@auto_sync_handle | ||
def update_dataset(self, dataset, handle=None): | ||
""" Replace the dataset with a new dataset. | ||
|
||
Parameters | ||
---------- | ||
dataset : array interface compliant matrix shape (n_samples, dim) | ||
{handle_docstring} | ||
""" | ||
cdef device_resources* handle_ = \ | ||
<device_resources*><size_t>handle.getHandle() | ||
|
||
dataset_ai = wrap_array(dataset) | ||
dataset_dt = dataset_ai.dtype | ||
_check_input_array(dataset_ai, [np.dtype("float32")]) | ||
|
||
if dataset_ai.from_cai: | ||
self.index[0].update_dataset(deref(handle_), | ||
get_dmv_float(dataset_ai, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here is where we could use make_const_mdspan. It would simplify things so that we don't need to make non-const functions everywhere (which kind of circumvents the const functions). |
||
check_shape=True)) | ||
else: | ||
self.index[0].update_dataset(deref(handle_), | ||
get_hmv_float(dataset_ai, | ||
check_shape=True)) | ||
|
||
@property | ||
def metric(self): | ||
return self.index[0].metric() | ||
|
@@ -195,6 +220,31 @@ cdef class IndexInt8(Index): | |
self.index = new c_cagra.index[int8_t, uint32_t]( | ||
deref(handle_)) | ||
|
||
@auto_sync_handle | ||
def update_dataset(self, dataset, handle=None): | ||
""" Replace the dataset with a new dataset. | ||
|
||
Parameters | ||
---------- | ||
dataset : array interface compliant matrix shape (n_samples, dim) | ||
{handle_docstring} | ||
""" | ||
cdef device_resources* handle_ = \ | ||
<device_resources*><size_t>handle.getHandle() | ||
|
||
dataset_ai = wrap_array(dataset) | ||
dataset_dt = dataset_ai.dtype | ||
_check_input_array(dataset_ai, [np.dtype("byte")]) | ||
|
||
if dataset_ai.from_cai: | ||
self.index[0].update_dataset(deref(handle_), | ||
get_dmv_int8(dataset_ai, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. make_const_mdspan here too. |
||
check_shape=True)) | ||
else: | ||
self.index[0].update_dataset(deref(handle_), | ||
get_hmv_int8(dataset_ai, | ||
check_shape=True)) | ||
|
||
def __repr__(self): | ||
m_str = "metric=" + _get_metric_string(self.index.metric()) | ||
attr_str = [attr + "=" + str(getattr(self, attr)) | ||
|
@@ -235,6 +285,31 @@ cdef class IndexUint8(Index): | |
self.index = new c_cagra.index[uint8_t, uint32_t]( | ||
deref(handle_)) | ||
|
||
@auto_sync_handle | ||
def update_dataset(self, dataset, handle=None): | ||
""" Replace the dataset with a new dataset. | ||
|
||
Parameters | ||
---------- | ||
dataset : array interface compliant matrix shape (n_samples, dim) | ||
{handle_docstring} | ||
""" | ||
cdef device_resources* handle_ = \ | ||
<device_resources*><size_t>handle.getHandle() | ||
|
||
dataset_ai = wrap_array(dataset) | ||
dataset_dt = dataset_ai.dtype | ||
_check_input_array(dataset_ai, [np.dtype("ubyte")]) | ||
|
||
if dataset_ai.from_cai: | ||
self.index[0].update_dataset(deref(handle_), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. make_const_mdspan |
||
get_dmv_uint8(dataset_ai, | ||
check_shape=True)) | ||
else: | ||
self.index[0].update_dataset(deref(handle_), | ||
get_hmv_uint8(dataset_ai, | ||
check_shape=True)) | ||
|
||
def __repr__(self): | ||
m_str = "metric=" + _get_metric_string(self.index.metric()) | ||
attr_str = [attr + "=" + str(getattr(self, attr)) | ||
|
@@ -693,7 +768,7 @@ def search(SearchParams search_params, | |
|
||
|
||
@auto_sync_handle | ||
def save(filename, Index index, handle=None): | ||
def save(filename, Index index, bool include_dataset=True, handle=None): | ||
""" | ||
Saves the index to a file. | ||
|
||
|
@@ -706,6 +781,8 @@ def save(filename, Index index, handle=None): | |
Name of the file. | ||
index : Index | ||
Trained CAGRA index. | ||
include_dataset : bool | ||
Whether or not to write out the dataset along with the index | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It might be useful to mention the implication here, just to make it more obvious for the uninformed - for example, a warning that a dataset can get quite large, so it's advisable to set this to false to shrink the size of the serialized index. |
||
{handle_docstring} | ||
|
||
Examples | ||
|
@@ -741,15 +818,17 @@ def save(filename, Index index, handle=None): | |
if index.active_index_type == "float32": | ||
idx_float = index | ||
c_cagra.serialize_file( | ||
deref(handle_), c_filename, deref(idx_float.index)) | ||
deref(handle_), c_filename, deref(idx_float.index), | ||
include_dataset) | ||
elif index.active_index_type == "byte": | ||
idx_int8 = index | ||
c_cagra.serialize_file( | ||
deref(handle_), c_filename, deref(idx_int8.index)) | ||
deref(handle_), c_filename, deref(idx_int8.index), include_dataset) | ||
elif index.active_index_type == "ubyte": | ||
idx_uint8 = index | ||
c_cagra.serialize_file( | ||
deref(handle_), c_filename, deref(idx_uint8.index)) | ||
deref(handle_), c_filename, deref(idx_uint8.index), | ||
include_dataset) | ||
else: | ||
raise ValueError( | ||
"Index dtype %s not supported" % index.active_index_type) | ||
|
@@ -785,12 +864,9 @@ def load(filename, handle=None): | |
cdef IndexInt8 idx_int8 | ||
cdef IndexUint8 idx_uint8 | ||
|
||
# we extract the dtype from the array interfaces in the file | ||
with open(filename, 'rb') as f: | ||
type_str = f.read(700).decode("utf-8", errors='ignore') | ||
|
||
# Read description of the 6th element to get the datatype | ||
dataset_dt = np.dtype(type_str.split('descr')[6][5:7]) | ||
with open(filename, "rb") as f: | ||
type_str = f.read(3).decode("utf8") | ||
dataset_dt = np.dtype(type_str) | ||
|
||
if dataset_dt == np.float32: | ||
idx_float = IndexFloat(handle) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We probably want to keep these const mdspans. If this is because of python, can we use make_const_mdspan() in that layer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that automatically discarding `const` would be bad - but this is doing the opposite and is automatically adding it (that is, converting a non-const mdspan to a const mdspan), which I feel is something that should be allowed with our APIs.
The issue I have is that Cython kinda sucks at respecting `const` identifiers, which is why all our Cython APIs use non-const mdspans right now. If I try to add a `get_const_hmv_float` (to parallel the non-const `get_hmv_float` we have now), I get an error message from Cython, where it doesn't recognize `const float` as a type inside template parameters.
I can get around this by adding a Cython typedef (like `ctypedef const float const_float`) - but that introduces the need for other hacks later on: Cython will treat `const_float` and `const float` as separate types, meaning that when we define `update_dataset` for Cython in c_cagra.pxd I can't just use `const T` as the type, and have to introduce a new template param =(. I've done this in the last commit - let me know what you think