Skip to content

Commit

Permalink
fix index out of bounds for large bvector
Browse files Browse the repository at this point in the history
Signed-off-by: Mingzhuo Yin <[email protected]>
  • Loading branch information
silver-ymz committed Feb 21, 2024
1 parent 16c8f2c commit 44f295c
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 19 deletions.
12 changes: 7 additions & 5 deletions crates/base/src/vector/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use super::Vector;
use crate::scalar::F32;
use serde::{Deserialize, Serialize};

pub const BVEC_WIDTH: usize = std::mem::size_of::<usize>() * 8;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryVec {
pub dims: u16,
Expand Down Expand Up @@ -73,7 +75,7 @@ impl<'a> PartialOrd for BinaryVecRef<'a> {

impl BinaryVec {
pub fn new(dims: u16) -> Self {
let size = (dims as usize).div_ceil(std::mem::size_of::<usize>() * 8);
let size = (dims as usize).div_ceil(BVEC_WIDTH);
Self {
dims,
data: vec![0; size],
Expand All @@ -83,9 +85,9 @@ impl BinaryVec {
pub fn set(&mut self, index: usize, value: bool) {
assert!(index < self.dims as usize);
if value {
self.data[index / 32] |= 1 << (index % 32);
self.data[index / BVEC_WIDTH] |= 1 << (index % BVEC_WIDTH);
} else {
self.data[index / 32] &= !(1 << (index % 32));
self.data[index / BVEC_WIDTH] &= !(1 << (index % BVEC_WIDTH));
}
}
}
Expand All @@ -95,7 +97,7 @@ impl<'a> BinaryVecRef<'a> {
let mut index = 0;
std::iter::from_fn(move || {
if index < self.dims as usize {
let result = self.data[index / 32] & (1 << (index % 32)) != 0;
let result = self.data[index / BVEC_WIDTH] & (1 << (index % BVEC_WIDTH)) != 0;
index += 1;
Some(result)
} else {
Expand All @@ -106,6 +108,6 @@ impl<'a> BinaryVecRef<'a> {

pub fn get(&self, index: usize) -> bool {
assert!(index < self.dims as usize);
self.data[index / 32] & (1 << (index % 32)) != 0
self.data[index / BVEC_WIDTH] & (1 << (index % BVEC_WIDTH)) != 0
}
}
2 changes: 1 addition & 1 deletion crates/base/src/vector/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod binary;
mod sparse_f32;

pub use binary::{BinaryVec, BinaryVecRef};
pub use binary::{BinaryVec, BinaryVecRef, BVEC_WIDTH};
pub use sparse_f32::{SparseF32, SparseF32Ref};

pub trait Vector {
Expand Down
2 changes: 1 addition & 1 deletion crates/service/src/prelude/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ pub use base::error::*;
pub use base::scalar::{F16, F32};
pub use base::search::{Element, Filter, Payload};
pub use base::sys::{Handle, Pointer};
pub use base::vector::{BinaryVec, BinaryVecRef, SparseF32, SparseF32Ref, Vector};
pub use base::vector::{BinaryVec, BinaryVecRef, SparseF32, SparseF32Ref, Vector, BVEC_WIDTH};

pub use num_traits::{Float, Zero};
2 changes: 1 addition & 1 deletion crates/service/src/prelude/storage/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl Storage for BinaryMmap {
}

fn vector(&self, i: u32) -> BinaryVecRef<'_> {
let size = (self.dims as usize).div_ceil(std::mem::size_of::<usize>() * 8);
let size = (self.dims as usize).div_ceil(BVEC_WIDTH);
let s = i as usize * size;
let e = (i + 1) as usize * size;
BinaryVecRef {
Expand Down
19 changes: 8 additions & 11 deletions src/datatype/bvector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ impl BVector {
fn layout(len: usize) -> Layout {
u16::try_from(len).expect("Vector is too large.");
let layout_alpha = Layout::new::<BVector>();
let layout_beta =
Layout::array::<usize>(len.div_ceil(std::mem::size_of::<usize>() * 8)).unwrap();
let layout_beta = Layout::array::<usize>(len.div_ceil(BVEC_WIDTH)).unwrap();
let layout = layout_alpha.extend(layout_beta).unwrap().0;
layout.pad_to_align()
}
Expand All @@ -53,7 +52,7 @@ impl BVector {
std::ptr::copy_nonoverlapping(
vector.data.as_ptr(),
(*ptr).phantom.as_mut_ptr(),
dims.div_ceil(std::mem::size_of::<usize>() * 8),
dims.div_ceil(BVEC_WIDTH),
);
BVectorOutput(NonNull::new(ptr).unwrap())
}
Expand Down Expand Up @@ -82,7 +81,7 @@ impl BVector {
data: unsafe {
std::slice::from_raw_parts(
self.phantom.as_ptr(),
self.dims.div_ceil(std::mem::size_of::<usize>() as u16 * 8) as usize,
(self.dims as usize).div_ceil(BVEC_WIDTH),
)
},
}
Expand Down Expand Up @@ -418,9 +417,9 @@ fn _vectors_bvector_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum
}
let dims = end - start;
let mut values = BinaryVec::new(dims);
if start % (std::mem::size_of::<usize>() as u16 * 8) == 0 {
let start_idx = start as usize / (std::mem::size_of::<usize>() * 8);
let end_idx = (end as usize).div_ceil(std::mem::size_of::<usize>() * 8);
if start % BVEC_WIDTH as u16 == 0 {
let start_idx = start as usize / BVEC_WIDTH;
let end_idx = (end as usize).div_ceil(BVEC_WIDTH);
values
.data
.copy_from_slice(&input.data().data[start_idx..end_idx]);
Expand Down Expand Up @@ -469,8 +468,7 @@ fn _vectors_bvector_send(vector: BVectorInput<'_>) -> Datum {
unsafe {
let mut buf = StringInfoData::default();
let len = vector.dims;
let bytes = (len as usize).div_ceil(std::mem::size_of::<usize>() * 8)
* std::mem::size_of::<usize>();
let bytes = (len as usize).div_ceil(BVEC_WIDTH) * std::mem::size_of::<usize>();
pgrx::pg_sys::pq_begintypsend(&mut buf);
pgrx::pg_sys::pq_sendbytes(&mut buf, (&len) as *const u16 as _, 2);
pgrx::pg_sys::pq_sendbytes(&mut buf, vector.phantom.as_ptr() as _, bytes as _);
Expand All @@ -489,8 +487,7 @@ fn _vectors_bvector_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> B
if len == 0 {
pgrx::error!("data corruption is detected");
}
let bytes = (len as usize).div_ceil(std::mem::size_of::<usize>() * 8)
* std::mem::size_of::<usize>();
let bytes = (len as usize).div_ceil(BVEC_WIDTH) * std::mem::size_of::<usize>();
let ptr = pgrx::pg_sys::pq_getmsgbytes(buf, bytes as _);
let mut output = BVector::new_zeroed_in_postgres(len as usize);
std::ptr::copy(ptr, output.phantom.as_mut_ptr() as _, bytes);
Expand Down
3 changes: 3 additions & 0 deletions tests/sqllogictest/bvector.slt
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0,1,0]'::bvector limit

statement ok
DROP TABLE t;

statement ok
SELECT array_agg(1)::real[]::vector::bvector FROM generate_series(1, 100);

0 comments on commit 44f295c

Please sign in to comment.