Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support distance pushdown #510

Merged
merged 12 commits into from
Jul 9, 2024
2 changes: 1 addition & 1 deletion crates/base/src/vector/bvecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
pub const BVEC_WIDTH: usize = usize::BITS as usize;

// When using binary vector, please ensure that the padding bits are always zero.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BVecf32Owned {
dims: u16,
data: Vec<usize>,
Expand Down
2 changes: 1 addition & 1 deletion crates/base/src/vector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub trait VectorBorrowed: Copy {
fn normalize(&self) -> Self::Owned;
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum OwnedVector {
Vecf32(Vecf32Owned),
Vecf16(Vecf16Owned),
Expand Down
2 changes: 1 addition & 1 deletion crates/base/src/vector/svecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::vector::{VectorBorrowed, VectorKind, VectorOwned};
use num_traits::{Float, Zero};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SVecf32Owned {
dims: u32,
indexes: Vec<u32>,
Expand Down
2 changes: 1 addition & 1 deletion crates/base/src/vector/vecf16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::scalar::{ScalarLike, F16, F32};
use num_traits::{Float, Zero};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[repr(transparent)]
pub struct Vecf16Owned(Vec<F16>);

Expand Down
2 changes: 1 addition & 1 deletion crates/base/src/vector/vecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::scalar::F32;
use num_traits::{Float, Zero};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[repr(transparent)]
pub struct Vecf32Owned(Vec<F32>);

Expand Down
2 changes: 1 addition & 1 deletion crates/base/src/vector/veci8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::scalar::{F32, I8};
use num_traits::Float;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Veci8Owned {
dims: u32,
data: Vec<I8>,
Expand Down
5 changes: 3 additions & 2 deletions crates/base/src/worker.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::index::*;
use crate::scalar::F32;
use crate::search::*;
use crate::vector::*;

Expand Down Expand Up @@ -32,15 +33,15 @@ pub trait ViewBasicOperations {
&'a self,
vector: &'a OwnedVector,
opts: &'a SearchOptions,
) -> Result<Box<dyn Iterator<Item = Pointer> + 'a>, BasicError>;
) -> Result<Box<dyn Iterator<Item = (F32, Pointer)> + 'a>, BasicError>;
}

pub trait ViewVbaseOperations {
fn vbase<'a>(
&'a self,
vector: &'a OwnedVector,
opts: &'a SearchOptions,
) -> Result<Box<dyn Iterator<Item = Pointer> + 'a>, VbaseError>;
) -> Result<Box<dyn Iterator<Item = (F32, Pointer)> + 'a>, VbaseError>;
}

pub trait ViewListOperations {
Expand Down
9 changes: 5 additions & 4 deletions crates/index/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::utils::tournament_tree::LoserTree;
use arc_swap::ArcSwap;
use base::index::*;
use base::operator::*;
use base::scalar::F32;
use base::search::*;
use base::vector::*;
use common::clean::clean;
Expand Down Expand Up @@ -334,7 +335,7 @@ impl<O: Op> IndexView<O> {
&'a self,
vector: Borrowed<'_, O>,
opts: &'a SearchOptions,
) -> Result<impl Iterator<Item = Pointer> + 'a, BasicError> {
) -> Result<impl Iterator<Item = (F32, Pointer)> + 'a, BasicError> {
if self.options.vector.dims != vector.dims() {
return Err(BasicError::InvalidVector);
}
Expand Down Expand Up @@ -371,7 +372,7 @@ impl<O: Op> IndexView<O> {
let loser = LoserTree::new(heaps);
Ok(loser.filter_map(|x| {
if self.delete.check(x.payload).is_some() {
Some(x.payload.pointer())
Some((x.distance, x.payload.pointer()))
} else {
None
}
Expand All @@ -381,7 +382,7 @@ impl<O: Op> IndexView<O> {
&'a self,
vector: Borrowed<'a, O>,
opts: &'a SearchOptions,
) -> Result<impl Iterator<Item = Pointer> + 'a, VbaseError> {
) -> Result<impl Iterator<Item = (F32, Pointer)> + 'a, VbaseError> {
if self.options.vector.dims != vector.dims() {
return Err(VbaseError::InvalidVector);
}
Expand Down Expand Up @@ -414,7 +415,7 @@ impl<O: Op> IndexView<O> {
let loser = LoserTree::new(beta);
Ok(loser.filter_map(|x| {
if self.delete.check(x.payload).is_some() {
Some(x.payload.pointer())
Some((x.distance, x.payload.pointer()))
} else {
None
}
Expand Down
9 changes: 5 additions & 4 deletions crates/service/src/instance.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use base::distance::*;
use base::index::*;
use base::operator::*;
use base::scalar::F32;
use base::search::*;
use base::vector::*;
use base::worker::*;
Expand Down Expand Up @@ -317,11 +318,11 @@ impl ViewBasicOperations for InstanceView {
&'a self,
vector: &'a OwnedVector,
opts: &'a SearchOptions,
) -> Result<Box<dyn Iterator<Item = Pointer> + 'a>, BasicError> {
) -> Result<Box<dyn Iterator<Item = (F32, Pointer)> + 'a>, BasicError> {
match (self, vector) {
(InstanceView::Vecf32Cos(x), OwnedVector::Vecf32(vector)) => {
Ok(Box::new(x.basic(vector.for_borrow(), opts)?)
as Box<dyn Iterator<Item = Pointer>>)
as Box<dyn Iterator<Item = (F32, Pointer)>>)
}
(InstanceView::Vecf32Dot(x), OwnedVector::Vecf32(vector)) => {
Ok(Box::new(x.basic(vector.for_borrow(), opts)?))
Expand Down Expand Up @@ -378,11 +379,11 @@ impl ViewVbaseOperations for InstanceView {
&'a self,
vector: &'a OwnedVector,
opts: &'a SearchOptions,
) -> Result<Box<dyn Iterator<Item = Pointer> + 'a>, VbaseError> {
) -> Result<Box<dyn Iterator<Item = (F32, Pointer)> + 'a>, VbaseError> {
match (self, vector) {
(InstanceView::Vecf32Cos(x), OwnedVector::Vecf32(vector)) => {
Ok(Box::new(x.vbase(vector.for_borrow(), opts)?)
as Box<dyn Iterator<Item = Pointer>>)
as Box<dyn Iterator<Item = (F32, Pointer)>>)
}
(InstanceView::Vecf32Dot(x), OwnedVector::Vecf32(vector)) => {
Ok(Box::new(x.vbase(vector.for_borrow(), opts)?))
Expand Down
41 changes: 41 additions & 0 deletions src/datatype/memory_bvecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use pgrx::FromDatum;
use pgrx::IntoDatum;
use pgrx::UnboxDatum;
use std::alloc::Layout;
use std::ops::Deref;
use std::ops::DerefMut;
Expand Down Expand Up @@ -152,6 +153,46 @@ impl IntoDatum for BVecf32Output {
}
}

impl FromDatum for BVecf32Output {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let p = NonNull::new(datum.cast_mut_ptr::<BVecf32Header>())?;
let q =
unsafe { NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast())? };
if p != q {
Some(BVecf32Output(q))
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).for_borrow() };
Some(BVecf32Output::new(vector))
}
}
}
}

unsafe impl UnboxDatum for BVecf32Output {
type As<'src> = BVecf32Output;
#[inline]
unsafe fn unbox<'src>(d: pgrx::Datum<'src>) -> Self::As<'src>
where
Self: 'src,
{
let p = NonNull::new(d.sans_lifetime().cast_mut_ptr::<BVecf32Header>()).unwrap();
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap()
};
if p != q {
BVecf32Output(q)
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).for_borrow() };
BVecf32Output::new(vector)
cutecutecat marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

unsafe impl SqlTranslatable for BVecf32Input<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("bvector")))
Expand Down
41 changes: 41 additions & 0 deletions src/datatype/memory_svecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use pgrx::FromDatum;
use pgrx::IntoDatum;
use pgrx::UnboxDatum;
use std::alloc::Layout;
use std::ops::Deref;
use std::ptr::NonNull;
Expand Down Expand Up @@ -168,6 +169,46 @@ impl IntoDatum for SVecf32Output {
}
}

impl FromDatum for SVecf32Output {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let p = NonNull::new(datum.cast_mut_ptr::<SVecf32Header>())?;
let q =
unsafe { NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast())? };
if p != q {
Some(SVecf32Output(q))
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).for_borrow() };
Some(SVecf32Output::new(vector))
}
}
}
}

unsafe impl UnboxDatum for SVecf32Output {
type As<'src> = SVecf32Output;
#[inline]
unsafe fn unbox<'src>(d: pgrx::Datum<'src>) -> Self::As<'src>
where
Self: 'src,
{
let p = NonNull::new(d.sans_lifetime().cast_mut_ptr::<SVecf32Header>()).unwrap();
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap()
};
if p != q {
SVecf32Output(q)
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).for_borrow() };
SVecf32Output::new(vector)
}
}
}

unsafe impl SqlTranslatable for SVecf32Input<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("svector")))
Expand Down
41 changes: 41 additions & 0 deletions src/datatype/memory_vecf16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use pgrx::FromDatum;
use pgrx::IntoDatum;
use pgrx::UnboxDatum;
use std::alloc::Layout;
use std::ops::Deref;
use std::ptr::NonNull;
Expand Down Expand Up @@ -145,6 +146,46 @@ impl IntoDatum for Vecf16Output {
}
}

impl FromDatum for Vecf16Output {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let p = NonNull::new(datum.cast_mut_ptr::<Vecf16Header>())?;
let q =
unsafe { NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast())? };
if p != q {
Some(Vecf16Output(q))
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).for_borrow() };
Some(Vecf16Output::new(vector))
}
}
}
}

unsafe impl UnboxDatum for Vecf16Output {
type As<'src> = Vecf16Output;
#[inline]
unsafe fn unbox<'src>(d: pgrx::Datum<'src>) -> Self::As<'src>
where
Self: 'src,
{
let p = NonNull::new(d.sans_lifetime().cast_mut_ptr::<Vecf16Header>()).unwrap();
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap()
};
if p != q {
Vecf16Output(q)
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).for_borrow() };
Vecf16Output::new(vector)
}
}
}

unsafe impl SqlTranslatable for Vecf16Input<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vecf16")))
Expand Down
41 changes: 41 additions & 0 deletions src/datatype/memory_vecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use pgrx::FromDatum;
use pgrx::IntoDatum;
use pgrx::UnboxDatum;
use std::alloc::Layout;
use std::ops::Deref;
use std::ptr::NonNull;
Expand Down Expand Up @@ -145,6 +146,46 @@ impl IntoDatum for Vecf32Output {
}
}

impl FromDatum for Vecf32Output {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let p = NonNull::new(datum.cast_mut_ptr::<Vecf32Header>())?;
let q =
unsafe { NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast())? };
if p != q {
Some(Vecf32Output(q))
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).for_borrow() };
Some(Vecf32Output::new(vector))
}
}
}
}

unsafe impl UnboxDatum for Vecf32Output {
type As<'src> = Vecf32Output;
#[inline]
unsafe fn unbox<'src>(d: pgrx::Datum<'src>) -> Self::As<'src>
where
Self: 'src,
{
let p = NonNull::new(d.sans_lifetime().cast_mut_ptr::<Vecf32Header>()).unwrap();
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap()
};
if p != q {
Vecf32Output(q)
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).for_borrow() };
Vecf32Output::new(vector)
}
}
}

unsafe impl SqlTranslatable for Vecf32Input<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vector")))
Expand Down
Loading