Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improve RaBitQ performance #576

Merged
merged 15 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ memmap2 = "0.9.4"
parking_lot = "0.12.1"
paste = "1.0.14"
rand = "0.8.5"
rand_distr = "0.4.3"
rustix = { version = "0.38.31", features = ["fs", "mm", "net"] }
serde = "1"
serde_json = "1"
Expand Down
34 changes: 25 additions & 9 deletions crates/base/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,43 @@ pub enum DistanceKind {
pub struct Distance(i32);

impl Distance {
pub const ZERO: Self = Self(0);
pub const INFINITY: Self = Self(2139095040);
pub const NEG_INFINITY: Self = Self(-2139095041);
pub const ZERO: Self = Distance::from_f32(0.0f32);
pub const INFINITY: Self = Distance::from_f32(f32::INFINITY);
pub const NEG_INFINITY: Self = Distance::from_f32(f32::NEG_INFINITY);

pub fn to_f32(self) -> f32 {
self.into()
/// Encodes an `f32` as the integer key stored inside `Distance`.
///
/// Non-negative floats keep their IEEE-754 bit pattern. Floats with the sign
/// bit set have their lower 31 bits inverted, so that (for non-NaN inputs)
/// comparing the resulting `i32` keys orders values the same way as
/// comparing the floats, with `-0.0` sorting just below `+0.0`.
#[inline(always)]
pub const fn from_f32(value: f32) -> Self {
    let raw = value.to_bits() as i32;
    // The arithmetic shift smears the sign bit across the word; the logical
    // shift then clears bit 31, giving 0x7FFF_FFFF for negatives and 0
    // otherwise. XOR-ing with it never touches the sign bit itself.
    let flip = (((raw >> 31) as u32) >> 1) as i32;
    Self(raw ^ flip)
}

/// Decodes the stored integer key back into the `f32` it was built from.
///
/// This applies the same sign-dependent XOR as [`Distance::from_f32`]. The
/// mask never changes bit 31, so applying it a second time restores the
/// original IEEE-754 bit pattern.
#[inline(always)]
pub const fn to_f32(self) -> f32 {
    let key = self.0;
    // 0x7FFF_FFFF when the key is negative, 0 otherwise.
    let flip = (((key >> 31) as u32) >> 1) as i32;
    f32::from_bits((key ^ flip) as u32)
}

#[inline(always)]
pub const fn to_i32(self) -> i32 {
self.0
}
}

impl From<f32> for Distance {
#[inline(always)]
fn from(value: f32) -> Self {
let bits = value.to_bits() as i32;
Self(bits ^ (((bits >> 31) as u32) >> 1) as i32)
Distance::from_f32(value)
}
}

impl From<Distance> for f32 {
#[inline(always)]
fn from(Distance(bits): Distance) -> Self {
f32::from_bits((bits ^ (((bits >> 31) as u32) >> 1) as i32) as u32)
fn from(value: Distance) -> Self {
Distance::to_f32(value)
}
}

Expand Down
18 changes: 16 additions & 2 deletions crates/base/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ impl IndexOptions {
IndexingOptions::Rabitq(_) => {
if !matches!(self.vector.d, DistanceKind::L2) {
return Err(ValidationError::new(
"inverted_index is not support for distance that is not l2",
"rabitq is not support for distance that is not l2",
));
}
if !matches!(self.vector.v, VectorKind::Vecf32) {
return Err(ValidationError::new(
"inverted_index is not support for vectors that are not vector",
"rabitq is not support for vectors that are not vector",
));
}
}
Expand Down Expand Up @@ -446,18 +446,24 @@ pub struct RabitqIndexingOptions {
#[serde(default = "RabitqIndexingOptions::default_nlist")]
#[validate(range(min = 1, max = 1_000_000))]
pub nlist: u32,
// Point serde at this struct's own default (the same function the `Default`
// impl calls) instead of borrowing the one from `IvfIndexingOptions`, so a
// missing field and `Default::default()` can never disagree.
#[serde(default = "RabitqIndexingOptions::default_spherical_centroids")]
pub spherical_centroids: bool,
}

impl RabitqIndexingOptions {
    /// Fallback for `nlist`: referenced by the field's serde `default`
    /// attribute and by the `Default` impl.
    fn default_nlist() -> u32 {
        1_000
    }

    /// Fallback for `spherical_centroids` used by the `Default` impl.
    fn default_spherical_centroids() -> bool {
        false
    }
}

impl Default for RabitqIndexingOptions {
    /// Builds the options from the per-field default functions, one call per field.
    fn default() -> Self {
        let nlist = Self::default_nlist();
        let spherical_centroids = Self::default_spherical_centroids();
        Self {
            nlist,
            spherical_centroids,
        }
    }
}
Expand Down Expand Up @@ -561,6 +567,7 @@ impl Default for ProductQuantizationOptions {
}

#[derive(Debug, Clone, Serialize, Deserialize, Validate, Alter)]
#[serde(deny_unknown_fields)]
pub struct SearchOptions {
#[serde(default = "SearchOptions::default_flat_sq_rerank_size")]
#[validate(range(min = 0, max = 65535))]
Expand Down Expand Up @@ -591,6 +598,9 @@ pub struct SearchOptions {
#[serde(default = "SearchOptions::default_rabitq_nprobe")]
#[validate(range(min = 1, max = 65535))]
pub rabitq_nprobe: u32,
#[serde(default = "SearchOptions::default_rabitq_epsilon")]
#[validate(range(min = 1.0, max = 4.0))]
pub rabitq_epsilon: f32,
#[serde(default = "SearchOptions::default_rabitq_fast_scan")]
pub rabitq_fast_scan: bool,
#[serde(default = "SearchOptions::default_diskann_ef_search")]
Expand Down Expand Up @@ -632,6 +642,9 @@ impl SearchOptions {
/// Default for `rabitq_nprobe` when the option is omitted; the field is
/// validated to lie in `1..=65535`.
pub const fn default_rabitq_nprobe() -> u32 {
    10_u32
}
/// Default for `rabitq_epsilon` when the option is omitted; the field is
/// validated to lie in `1.0..=4.0`.
pub const fn default_rabitq_epsilon() -> f32 {
    1.9_f32
}
/// Default for `rabitq_fast_scan` when the option is omitted: enabled.
pub const fn default_rabitq_fast_scan() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
Expand All @@ -654,6 +667,7 @@ impl Default for SearchOptions {
ivf_nprobe: Self::default_ivf_nprobe(),
hnsw_ef_search: Self::default_hnsw_ef_search(),
rabitq_nprobe: Self::default_rabitq_nprobe(),
rabitq_epsilon: Self::default_rabitq_epsilon(),
rabitq_fast_scan: Self::default_rabitq_fast_scan(),
diskann_ef_search: Self::default_diskann_ef_search(),
}
Expand Down
2 changes: 2 additions & 0 deletions crates/base/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![feature(const_float_bits_conv)]
#![feature(avx512_target_feature)]
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))]
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512_f16))]
Expand All @@ -13,6 +14,7 @@ pub mod distance;
pub mod index;
pub mod operator;
pub mod pod;
pub mod rand;
pub mod scalar;
pub mod search;
pub mod vector;
Expand Down
9 changes: 9 additions & 0 deletions crates/common/src/rand.rs → crates/base/src/rand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,12 @@ where
_ => unreachable!(),
}
}

/// Calls [`sample_u32`] with the same arguments and returns its output sorted
/// in ascending order.
///
/// `rng`, `length` and `amount` are forwarded unchanged; see `sample_u32` for
/// how they are interpreted.
pub fn sample_u32_sorted<R>(rng: &mut R, length: u32, amount: u32) -> Vec<u32>
where
    R: Rng + ?Sized,
{
    let mut x = sample_u32(rng, length, amount);
    // Equal `u32`s are indistinguishable, so an unstable sort yields exactly
    // the same vector as `sort()` while skipping its scratch allocation.
    x.sort_unstable();
    x
}
Loading