diff --git a/.github/workflows/psql.yml b/.github/workflows/psql.yml index d3728767a..3f39ec428 100644 --- a/.github/workflows/psql.yml +++ b/.github/workflows/psql.yml @@ -80,6 +80,7 @@ jobs: wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - sudo apt-get update sudo apt-get install -y clang-16 + sudo update-alternatives --install /usr/bin/clang clang /usr/bin/clang-16 128 - name: Set up Pgrx run: | # pg_config diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 93faa3277..df7070ff8 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -77,6 +77,7 @@ jobs: wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - sudo apt-get update sudo apt-get install -y clang-16 + sudo update-alternatives --install /usr/bin/clang clang /usr/bin/clang-16 128 - name: Set up Pgrx run: | # pg_config diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 2e8eb7aee..b46f1bed0 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -84,6 +84,7 @@ jobs: wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - sudo apt-get update sudo apt-get install -y clang-16 + sudo update-alternatives --install /usr/bin/clang clang /usr/bin/clang-16 128 - name: Set up Pgrx run: | # pg_config @@ -149,6 +150,7 @@ jobs: wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - sudo apt-get update sudo apt-get install -y clang-16 + sudo update-alternatives --install /usr/bin/clang clang /usr/bin/clang-16 128 - name: Set up Pgrx run: | # pg_config diff --git a/Cargo.lock b/Cargo.lock index 991898ca2..bdf2c536d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1462,6 +1462,7 @@ dependencies = [ "common", "num-traits", "rand", + "smawk", "stoppable_rayon", ] @@ -2449,6 +2450,12 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "smawk" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c" + [[package]] name = "socket2" version = "0.4.10" diff --git a/crates/base/src/scalar/mod.rs b/crates/base/src/scalar/mod.rs index ce466c9d6..af881929a 100644 --- a/crates/base/src/scalar/mod.rs +++ b/crates/base/src/scalar/mod.rs @@ -2,6 +2,8 @@ mod f32; mod half_f16; mod i8; +use std::iter::Sum; + pub use f32::F32; pub use half_f16::F16; pub use i8::I8; @@ -19,7 +21,9 @@ pub trait ScalarLike: + num_traits::Zero + num_traits::NumOps + num_traits::NumAssignOps + + Default + crate::pod::Pod + + Sum { fn from_f32(x: f32) -> Self; fn to_f32(self) -> f32; diff --git a/crates/c/build.rs b/crates/c/build.rs index c3104f333..0e948e751 100644 --- a/crates/c/build.rs +++ b/crates/c/build.rs @@ -2,7 +2,7 @@ fn main() { println!("cargo:rerun-if-changed=src/f16.h"); println!("cargo:rerun-if-changed=src/f16.c"); cc::Build::new() - .compiler("clang-16") + .compiler("clang") .file("./src/f16.c") .opt_level(3) .flag("-fassociative-math") diff --git a/crates/c/src/f16.c b/crates/c/src/f16.c index c0b42d4a9..4c0571a1f 100644 --- a/crates/c/src/f16.c +++ b/crates/c/src/f16.c @@ -1,3 +1,7 @@ +#if !(__clang_major__ >= 16) +#error "clang version must be >= 16" +#endif + #include "f16.h" #include diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 581fc680e..4917bc514 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -9,4 +9,3 @@ pub mod remap; pub mod sample; pub mod variants; pub mod vec2; -pub mod vec3; diff --git a/crates/common/src/sample.rs b/crates/common/src/sample.rs index 8d6e7039f..5b6b0cd31 100644 --- a/crates/common/src/sample.rs +++ b/crates/common/src/sample.rs @@ -10,10 +10,10 @@ pub fn sample(vectors: &impl Vectors) -> Vec2> { let n = vectors.len(); let m = std::cmp::min(SAMPLES as u32, n); let f = super::rand::sample_u32(&mut rand::thread_rng(), n, m); - let mut samples = Vec2::new(vectors.dims(), m as usize); + let mut samples = Vec2::zeros((m as usize, vectors.dims() as usize)); for i in 0..m { let v = vectors.vector(f[i as usize] as u32).to_vec(); - samples[i as usize].copy_from_slice(&v); + samples[(i as usize,)].copy_from_slice(&v); } samples } @@ -27,12 +27,12 @@ pub fn sample_subvector_transform( let n = vectors.len(); let m = std::cmp::min(SAMPLES as u32, n); let f = super::rand::sample_u32(&mut rand::thread_rng(), n, m); - let mut samples = Vec2::new((e - s) as u32, m as usize); + let mut samples = Vec2::zeros((m as usize, e - s)); for i in 0..m { let v = transform(vectors.vector(f[i as usize] as u32)) .as_borrowed() .to_vec(); - samples[i as usize].copy_from_slice(&v[s..e]); + samples[(i as usize,)].copy_from_slice(&v[s..e]); } samples } diff --git a/crates/common/src/vec2.rs b/crates/common/src/vec2.rs index 3925f0eb5..60d245310 100644 --- a/crates/common/src/vec2.rs +++ b/crates/common/src/vec2.rs @@ -1,70 +1,75 @@ -use base::pod::Pod; use serde::{Deserialize, Serialize}; -use std::ops::{Deref, DerefMut, Index, IndexMut}; +use std::ops::{Index, IndexMut}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Vec2 { - dims: u32, - v: Vec, + shape: (usize, usize), + base: Vec, } -impl Vec2 { - pub fn new(dims: u32, n: usize) -> Self { +impl Vec2 { + pub fn zeros(shape: (usize, usize)) -> Self { Self { - dims, - v: base::pod::zeroed_vec(dims as usize * n), + shape, + base: vec![T::default(); shape.0 * shape.1], } } - pub fn dims(&self) -> u32 { - self.dims + pub fn from_vec(shape: (usize, usize), base: Vec) -> Self { + assert_eq!(shape.0 * shape.1, base.len()); + Self { shape, base } } - pub fn len(&self) -> usize { - self.v.len() / self.dims as usize +} + +impl Vec2 { + pub fn copy_within(&mut self, (l_i,): (usize,), (r_i,): (usize,)) { + assert!(l_i < self.shape.0); + assert!(r_i < self.shape.0); + let src_from = l_i * self.shape.1; + let src_to = src_from + self.shape.1; + let dest = r_i * self.shape.1; + self.base.copy_within(src_from..src_to, dest); + } +} + +impl Vec2 { + pub fn shape_0(&self) -> usize { + self.shape.0 } - pub fn is_empty(&self) -> bool { - self.len() == 0 + pub fn shape_1(&self) -> usize { + self.shape.1 } - pub fn argsort(&self) -> Vec { - let mut index: Vec = (0..self.len()).collect(); - index.sort_by_key(|i| &self[*i]); - index + pub fn as_slice(&self) -> &[T] { + self.base.as_slice() } - pub fn copy_within(&mut self, i: usize, j: usize) { - assert!(i < self.len() && j < self.len()); - unsafe { - if i != j { - let src = self.v.as_ptr().add(self.dims as usize * i); - let dst = self.v.as_mut_ptr().add(self.dims as usize * j); - std::ptr::copy_nonoverlapping(src, dst, self.dims as usize); - } - } + pub fn as_mut_slice(&mut self) -> &mut [T] { + self.base.as_mut_slice() } } -impl Index for Vec2 { +impl Index<(usize,)> for Vec2 { type Output = [T]; - fn index(&self, index: usize) -> &Self::Output { - &self.v[self.dims as usize * index..][..self.dims as usize] + fn index(&self, (i,): (usize,)) -> &Self::Output { + &self.base[i * self.shape.1..][..self.shape.1] } } -impl IndexMut for Vec2 { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - &mut self.v[self.dims as usize * index..][..self.dims as usize] +impl IndexMut<(usize,)> for Vec2 { + fn index_mut(&mut self, (i,): (usize,)) -> &mut Self::Output { + &mut self.base[i * self.shape.1..][..self.shape.1] } } -impl Deref for Vec2 { - type Target = [T]; +impl Index<(usize, usize)> for Vec2 { + type Output = T; - fn deref(&self) -> &Self::Target { - self.v.deref() + fn index(&self, (i, j): (usize, usize)) -> &Self::Output { + &self.base[i * self.shape.1..][j] } } -impl DerefMut for Vec2 { - fn deref_mut(&mut self) -> &mut Self::Target { - self.v.deref_mut() +impl IndexMut<(usize, usize)> for Vec2 { + fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output { + &mut self.base[i * self.shape.1..][j] } } diff --git a/crates/common/src/vec3.rs b/crates/common/src/vec3.rs deleted file mode 100644 index b4ff53041..000000000 --- a/crates/common/src/vec3.rs +++ /dev/null @@ -1,101 +0,0 @@ -use base::pod::Pod; -use serde::{Deserialize, Serialize}; -use std::ops::{Deref, DerefMut, Index, IndexMut}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Vec3 { - x: usize, - y: usize, - z: usize, - v: Vec, -} - -impl Vec3 { - pub fn new(x: usize, y: usize, z: usize) -> Self { - Self { - x, - y, - z, - v: base::pod::zeroed_vec(x * y * z), - } - } - pub fn x(&self) -> usize { - self.x - } - pub fn y(&self) -> usize { - self.y - } - pub fn z(&self) -> usize { - self.z - } -} - -impl Index<()> for Vec3 { - type Output = [T]; - - fn index(&self, (): ()) -> &Self::Output { - &self.v[..] - } -} - -impl IndexMut<()> for Vec3 { - fn index_mut(&mut self, (): ()) -> &mut Self::Output { - &mut self.v[..] - } -} - -impl Index<(usize,)> for Vec3 { - type Output = [T]; - - fn index(&self, (x,): (usize,)) -> &Self::Output { - &self.v[x * self.y * self.z..][..self.y * self.z] - } -} - -impl IndexMut<(usize,)> for Vec3 { - fn index_mut(&mut self, (x,): (usize,)) -> &mut Self::Output { - &mut self.v[x * self.y * self.z..][..self.y * self.z] - } -} - -impl Index<(usize, usize)> for Vec3 { - type Output = [T]; - - fn index(&self, (x, y): (usize, usize)) -> &Self::Output { - &self.v[x * self.y * self.z + y * self.z..][..self.z] - } -} - -impl IndexMut<(usize, usize)> for Vec3 { - fn index_mut(&mut self, (x, y): (usize, usize)) -> &mut Self::Output { - &mut self.v[x * self.y * self.z + y * self.z..][..self.z] - } -} - -impl Index<(usize, usize, usize)> for Vec3 { - type Output = T; - - fn index(&self, (x, y, z): (usize, usize, usize)) -> &Self::Output { - &self.v[x * self.y * self.z + y * self.z + z] - } -} - -impl IndexMut<(usize, usize, usize)> for Vec3 { - fn index_mut(&mut self, (x, y, z): (usize, usize, usize)) -> &mut Self::Output { - &mut self.v[x * self.y * self.z + y * self.z + z] - } -} - -impl Deref for Vec3 { - type Target = [T]; - - fn deref(&self) -> &Self::Target { - self.v.deref() - } -} - -impl DerefMut for Vec3 { - fn deref_mut(&mut self) -> &mut Self::Target { - self.v.deref_mut() - } -} diff --git a/crates/ivf/src/ivf_naive.rs b/crates/ivf/src/ivf_naive.rs index 9599a8f06..3793631ac 100644 --- a/crates/ivf/src/ivf_naive.rs +++ b/crates/ivf/src/ivf_naive.rs @@ -98,8 +98,8 @@ fn from_nothing( rayon::check(); let centroids = { let mut samples = samples; - for i in 0..samples.len() { - O::elkan_k_means_normalize(&mut samples[i]); + for i in 0..samples.shape_0() { + O::elkan_k_means_normalize(&mut samples[(i,)]); } k_means(nlist as usize, samples) }; diff --git a/crates/ivf/src/ivf_residual.rs b/crates/ivf/src/ivf_residual.rs index 1e9657c23..b878d1fcd 100644 --- a/crates/ivf/src/ivf_residual.rs +++ b/crates/ivf/src/ivf_residual.rs @@ -61,7 +61,7 @@ impl IvfResidual { ); let vectors = lists .iter() - .map(|&(_, i)| O::vector_sub(vector, &self.centroids[i])) + .map(|&(_, i)| O::vector_sub(vector, &self.centroids[(i,)])) .collect::>(); let mut reranker = self .quantization @@ -104,8 +104,8 @@ fn from_nothing( rayon::check(); let centroids = { let mut samples = samples; - for i in 0..samples.len() { - O::elkan_k_means_normalize(&mut samples[i]); + for i in 0..samples.shape_0() { + O::elkan_k_means_normalize(&mut samples[(i,)]); } k_means(nlist as usize, samples) }; @@ -141,7 +141,7 @@ fn from_nothing( O::elkan_k_means_normalize(&mut vector); k_means_lookup(&vector, ¢roids) }; - O::vector_sub(vector, ¢roids[target]) + O::vector_sub(vector, ¢roids[(target,)]) }, ); let payloads = MmapArray::create( diff --git a/crates/k_means/Cargo.toml b/crates/k_means/Cargo.toml index b3d2d14d5..b3fd862e3 100644 --- a/crates/k_means/Cargo.toml +++ b/crates/k_means/Cargo.toml @@ -9,6 +9,7 @@ rand.workspace = true base = { path = "../base" } common = { path = "../common" } +smawk = "0.3.2" stoppable_rayon = { path = "../stoppable_rayon" } [lints] diff --git a/crates/k_means/src/elkan.rs b/crates/k_means/src/elkan.rs index e9fa85889..161d3b0b3 100644 --- a/crates/k_means/src/elkan.rs +++ b/crates/k_means/src/elkan.rs @@ -6,7 +6,7 @@ use rand::{Rng, SeedableRng}; use std::ops::{Index, IndexMut}; pub struct ElkanKMeans { - dims: u32, + dims: usize, c: usize, centroids: Vec2, lowerbound: Square, @@ -21,23 +21,23 @@ const DELTA: f32 = 1.0 / 1024.0; impl ElkanKMeans { pub fn new(c: usize, samples: Vec2) -> Self { - let n = samples.len(); - let dims = samples.dims(); + let n = samples.shape_0(); + let dims = samples.shape_1(); let mut rand = StdRng::from_entropy(); - let mut centroids = Vec2::new(dims, c); + let mut centroids = Vec2::zeros((c, dims)); let mut lowerbound = Square::new(n, c); let mut upperbound = vec![F32::zero(); n]; let mut assign = vec![0usize; n]; - centroids[0].copy_from_slice(&samples[rand.gen_range(0..n)]); + centroids[(0,)].copy_from_slice(&samples[(rand.gen_range(0..n),)]); let mut weight = vec![F32::infinity(); n]; let mut dis = vec![F32::zero(); n]; for i in 0..c { let mut sum = F32::zero(); for j in 0..n { - dis[j] = S::euclid_distance(&samples[j], ¢roids[i]); + dis[j] = S::euclid_distance(&samples[(j,)], ¢roids[(i,)]); } for j in 0..n { lowerbound[(j, i)] = dis[j]; @@ -59,7 +59,7 @@ impl ElkanKMeans { } n - 1 }; - centroids[i + 1].copy_from_slice(&samples[index]); + centroids[(i + 1,)].copy_from_slice(&samples[(index,)]); } for i in 0..n { @@ -99,42 +99,13 @@ impl ElkanKMeans { let lowerbound = &mut self.lowerbound; let upperbound = &mut self.upperbound; let mut change = 0; - let n = samples.len(); - if n <= c { - let c = self.c; - let samples = &self.samples; - let rand = &mut self.rand; - let centroids = &mut self.centroids; - let n = samples.len(); - let dims = samples.dims(); - let sorted_index = samples.argsort(); - for i in 0..n { - let index = sorted_index.get(i).unwrap(); - let last = sorted_index.get(std::cmp::max(i, 1) - 1).unwrap(); - if *index == 0 || samples[*last] != samples[*index] { - centroids[i].copy_from_slice(&samples[*index]); - } else { - let rand_centroids: Vec<_> = (0..dims) - .map(|_| S::from_f32(rand.gen_range(0.0..1.0f32))) - .collect(); - centroids[i].copy_from_slice(rand_centroids.as_slice()); - } - } - for i in n..c { - let rand_centroids: Vec<_> = (0..dims) - .map(|_| S::from_f32(rand.gen_range(0.0..1.0f32))) - .collect(); - centroids[i].copy_from_slice(rand_centroids.as_slice()); - } - return true; - } - + let n = samples.shape_0(); // Step 1 let mut dist0 = Square::new(c, c); let mut sp = vec![F32::zero(); c]; for i in 0..c { for j in 0..c { - dist0[(i, j)] = S::euclid_distance(¢roids[i], ¢roids[j]) * 0.5; + dist0[(i, j)] = S::euclid_distance(¢roids[(i,)], ¢roids[(j,)]) * 0.5; } } for i in 0..c { @@ -153,7 +124,7 @@ impl ElkanKMeans { let mut dis = vec![F32::zero(); n]; for i in 0..n { if upperbound[i] > sp[assign[i]] { - dis[i] = S::euclid_distance(&samples[i], ¢roids[assign[i]]); + dis[i] = S::euclid_distance(&samples[(i,)], ¢roids[(assign[i],)]); } } for i in 0..n { @@ -176,7 +147,7 @@ impl ElkanKMeans { continue; } if minimal > lowerbound[(i, j)] || minimal > dist0[(assign[i], j)] { - let dis = S::euclid_distance(&samples[i], ¢roids[j]); + let dis = S::euclid_distance(&samples[(i,)], ¢roids[(j,)]); lowerbound[(i, j)] = dis; if dis < minimal { minimal = dis; @@ -189,12 +160,11 @@ impl ElkanKMeans { } // Step 4, 7 - let old = std::mem::replace(centroids, Vec2::new(dims, c)); + let old_centroids = std::mem::replace(centroids, Vec2::zeros((c, dims))); let mut count = vec![F32::zero(); c]; - centroids.fill(S::zero()); for i in 0..n { - for j in 0..dims as usize { - centroids[self.assign[i]][j] += samples[i][j]; + for j in 0..dims { + centroids[(self.assign[i], j)] += samples[(i, j)]; } count[self.assign[i]] += 1.0; } @@ -202,8 +172,8 @@ impl ElkanKMeans { if count[i] == F32::zero() { continue; } - for dim in 0..dims as usize { - centroids[i][dim] /= S::from_f32(count[i].into()); + for dim in 0..dims { + centroids[(i, dim)] /= S::from_f32(count[i].into()); } } for i in 0..c { @@ -219,27 +189,27 @@ impl ElkanKMeans { } o = (o + 1) % c; } - centroids.copy_within(o, i); - for dim in 0..dims as usize { + centroids.copy_within((o,), (i,)); + for dim in 0..dims { if dim % 2 == 0 { - centroids[i][dim] *= S::from_f32(1.0 + DELTA); - centroids[o][dim] *= S::from_f32(1.0 - DELTA); + centroids[(i, dim)] *= S::from_f32(1.0 + DELTA); + centroids[(o, dim)] *= S::from_f32(1.0 - DELTA); } else { - centroids[i][dim] *= S::from_f32(1.0 - DELTA); - centroids[o][dim] *= S::from_f32(1.0 + DELTA); + centroids[(i, dim)] *= S::from_f32(1.0 - DELTA); + centroids[(o, dim)] *= S::from_f32(1.0 + DELTA); } } count[i] = count[o] / 2.0; count[o] = count[o] - count[i]; } for i in 0..c { - spherical_normalize(&mut centroids[i]); + spherical_normalize(&mut centroids[(i,)]); } // Step 5, 6 let mut dist1 = vec![F32::zero(); c]; for i in 0..c { - dist1[i] = S::euclid_distance(&old[i], ¢roids[i]); + dist1[i] = S::euclid_distance(&old_centroids[(i,)], ¢roids[(i,)]); } for i in 0..n { for j in 0..c { diff --git a/crates/k_means/src/kmeans1d.rs b/crates/k_means/src/kmeans1d.rs new file mode 100644 index 000000000..df9e48d1f --- /dev/null +++ b/crates/k_means/src/kmeans1d.rs @@ -0,0 +1,100 @@ +use base::scalar::ScalarLike; +use common::vec2::Vec2; + +pub fn kmeans1d(c: usize, a: &[S]) -> Vec { + assert!(0 < c && c < a.len()); + let a = { + let mut x = a.to_vec(); + x.sort(); + x + }; + let n = a.len(); + // h(i, j), i <= j is cost of grouping [i, j] points into a cluster + let h = { + let mut sum_y = 0.0f64; + let mut sum_y2 = 0.0f64; + let mut prefix_y = vec![0.0f64]; + let mut prefix_y2 = vec![0.0f64]; + for y in a.iter().map(|y| y.to_f().to_f32() as f64) { + sum_y += y; + sum_y2 += y * y; + prefix_y.push(sum_y); + prefix_y2.push(sum_y2); + } + move |i, j| { + let sum_y = prefix_y[j + 1] - prefix_y[i]; + let sum_y2 = prefix_y2[j + 1] - prefix_y2[i]; + let mu = sum_y / (j + 1 - i) as f64; + let result = sum_y2 + (j + 1 - i) as f64 * mu * mu - 2.0 * mu * sum_y; + S::from_f32(result as f32) + } + }; + // f_i(j) is cost of grouping points with IDs [0, j] into clusters with IDs [0, i]. + // f_i(j) = min { f_{i - 1}(k) + h(k + 1, j) | 0 <= k < j } + let mut f = Vec2::<(S, usize)>::zeros((c, n)); + for j in 0..n { + f[(0, j)] = (h(0, j), usize::MAX); + } + for i in 1..c { + struct Question { + n: usize, + f: F, + } + impl S> smawk::Matrix for Question { + fn nrows(&self) -> usize { + self.n + } + fn ncols(&self) -> usize { + self.n + } + fn index(&self, i: usize, j: usize) -> S { + if i < j { + (self.f)(std::cmp::min(i, j), std::cmp::max(i, j)) + } else { + S::nan() + } + } + } + let minima = smawk::column_minima(&Question { + n, + f: |k, j| f[(i - 1, k)].0 + h(k + 1, j), + }); + f[(i, 0)] = (S::nan(), usize::MAX); + for j in 1..n { + let k = minima[j - 1]; + f[(i, j)] = (f[(i - 1, k)].0 + h(k + 1, j), k); + } + } + let mut centroids = vec![S::nan(); c]; + let mut i = c - 1; + let mut j = n - 1; + loop { + let k = f[(i, j)].1; + let l = if k == usize::MAX { 0 } else { k + 1 }; + centroids[i] = a[l..=j].iter().copied().sum::() / S::from_f32((j + 1 - l) as f32); + if k == usize::MAX { + break; + } + i -= 1; + j = k; + } + centroids +} + +#[cfg(test)] +mod test { + use super::*; + use base::scalar::F32; + + #[test] + fn sample_0() { + let clusters = kmeans1d( + 4, + &[ + -50.0, 4.0, 4.1, 4.2, 200.2, 200.4, 200.9, 80.0, 100.0, 102.0, + ] + .map(F32), + ); + assert_eq!(clusters, [-50.0, 4.1, 94.0, 200.5].map(F32)); + } +} diff --git a/crates/k_means/src/lib.rs b/crates/k_means/src/lib.rs index 1c277427e..c158b168a 100644 --- a/crates/k_means/src/lib.rs +++ b/crates/k_means/src/lib.rs @@ -1,9 +1,12 @@ #![allow(clippy::needless_range_loop)] pub mod elkan; +pub mod kmeans1d; +pub mod quick_centers; use base::scalar::*; use common::vec2::Vec2; +use kmeans1d::kmeans1d; use num_traits::Float; use stoppable_rayon as rayon; @@ -11,6 +14,15 @@ const ITERATIONS: usize = 400; pub fn k_means(c: usize, samples: Vec2) -> Vec2 { assert!(c > 0); + let n = samples.shape_0(); + let dims = samples.shape_1(); + assert!(dims > 0); + if n <= c { + return quick_centers::quick_centers(c, samples); + } + if dims == 1 { + return Vec2::from_vec((c, 1), kmeans1d(c, samples.as_slice())); + } let mut elkan_k_means = elkan::ElkanKMeans::::new(c, samples); for _ in 0..ITERATIONS { rayon::check(); @@ -22,20 +34,20 @@ pub fn k_means(c: usize, samples: Vec2) -> Vec2 { } pub fn k_means_lookup(vector: &[S], centroids: &Vec2) -> usize { - assert!(!centroids.is_empty()); + assert_ne!(centroids.shape_0(), 0); let mut result = (F32::infinity(), 0); - for i in 0..centroids.len() { - let dis = S::euclid_distance(vector, ¢roids[i]); + for i in 0..centroids.shape_0() { + let dis = S::euclid_distance(vector, ¢roids[(i,)]); result = std::cmp::min(result, (dis, i)); } result.1 } pub fn k_means_lookup_many(vector: &[S], centroids: &Vec2) -> Vec<(F32, usize)> { - assert!(!centroids.is_empty()); + assert_ne!(centroids.shape_0(), 0); let mut seq = Vec::new(); - for i in 0..centroids.len() { - let dis = S::euclid_distance(vector, ¢roids[i]); + for i in 0..centroids.shape_0() { + let dis = S::euclid_distance(vector, ¢roids[(i,)]); seq.push((dis, i)); } seq diff --git a/crates/k_means/src/quick_centers.rs b/crates/k_means/src/quick_centers.rs new file mode 100644 index 000000000..43cd2b912 --- /dev/null +++ b/crates/k_means/src/quick_centers.rs @@ -0,0 +1,25 @@ +use base::scalar::ScalarLike; +use common::vec2::Vec2; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +pub fn quick_centers(c: usize, samples: Vec2) -> Vec2 { + let n = samples.shape_0(); + let dims = samples.shape_1(); + assert!(c >= n); + let mut rand = StdRng::from_entropy(); + let mut centroids = Vec2::zeros((c, dims)); + centroids + .as_mut_slice() + .fill_with(|| S::from_f32(rand.gen_range(0.0..1.0f32))); + let index = { + let mut index = (0..n).collect::>(); + index.sort_by_key(|&i| &samples[(i,)]); + index.dedup_by_key(|&mut i| &samples[(i,)]); + index + }; + for i in index { + centroids[(i,)].copy_from_slice(&samples[(i,)]); + } + centroids +} diff --git a/crates/quantization/src/product/mod.rs b/crates/quantization/src/product/mod.rs index 74ddb8ed8..95fef86de 100644 --- a/crates/quantization/src/product/mod.rs +++ b/crates/quantization/src/product/mod.rs @@ -47,12 +47,12 @@ impl ProductQuantizer { k_means(256, subsamples) }) .collect::>(); - let mut centroids = Vec2::new(dims, 256); + let mut centroids = Vec2::zeros((256, dims as usize)); for i in 0..w { let subdims = std::cmp::min(ratio, dims - ratio * i); for j in 0u8..=255 { - centroids[j as usize][(i * ratio) as usize..][..subdims as usize] - .copy_from_slice(&originals[i as usize][j as usize]); + centroids[(j as usize,)][(i * ratio) as usize..][..subdims as usize] + .copy_from_slice(&originals[i as usize][(j as usize,)]); } } Self { @@ -88,7 +88,13 @@ impl ProductQuantizer { } pub fn preprocess(&self, lhs: Borrowed<'_, O>) -> O::ProductQuantizationPreprocessed { - O::product_quantization_preprocess(self.dims, self.ratio, self.bits, &self.centroids, lhs) + O::product_quantization_preprocess( + self.dims, + self.ratio, + self.bits, + self.centroids.as_slice(), + lhs, + ) } pub fn process(&self, preprocessed: &O::ProductQuantizationPreprocessed, rhs: &[u8]) -> F32 { @@ -106,7 +112,7 @@ impl ProductQuantizer { self.dims, self.ratio, self.bits, - &self.centroids, + self.centroids.as_slice(), vector, ); if opts.flat_pq_rerank_size == 0 { @@ -138,7 +144,7 @@ impl ProductQuantizer { self.dims, self.ratio, self.bits, - &self.centroids, + self.centroids.as_slice(), vector, ); if opts.ivf_pq_rerank_size == 0 { @@ -173,7 +179,7 @@ impl ProductQuantizer { self.dims, self.ratio, self.bits, - &self.centroids, + self.centroids.as_slice(), vector.as_borrowed(), ) }) @@ -207,7 +213,7 @@ impl ProductQuantizer { self.dims, self.ratio, self.bits, - &self.centroids, + self.centroids.as_slice(), vector, ); Box::new(Window0Reranker::new(