Skip to content

Commit

Permalink
get_clone impl on writer
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Nov 11, 2024
1 parent 199268d commit 552bc99
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 83 deletions.
9 changes: 9 additions & 0 deletions rust/blockstore/src/arrow/block/delta/data_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ impl DataRecordStorage {
inner.size_tracker.get_key_size()
}

/// Returns an owned copy of the entry stored under `(prefix, key)`, or `None`
/// when the storage has no entry for that composite key.
///
/// The guard returned by `inner.read()` is held only for the duration of the lookup.
pub fn get_owned_value(&self, prefix: &str, key: KeyWrapper) -> Option<DataRecordStorageEntry> {
    self.inner
        .read()
        .storage
        .get(&CompositeKey {
            prefix: prefix.to_string(),
            key,
        })
        .cloned()
}

pub fn add(&self, prefix: &str, key: KeyWrapper, value: &DataRecord<'_>) {
let mut inner = self.inner.write();
let composite_key = CompositeKey {
Expand Down
11 changes: 11 additions & 0 deletions rust/blockstore/src/arrow/block/delta/single_column_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,15 @@ impl<V: ArrowWriteableValue<SizeTracker = SingleColumnSizeTracker>> SingleColumn

(schema.into(), vec![prefix_arr, key_arr, value_arr])
}

/// Returns a clone of the value stored under `(prefix, key)`, or `None` when
/// the storage has no entry for that composite key.
///
/// NOTE(review): the return type is fixed to `RoaringBitmap`. Confirm this
/// method sits in the `RoaringBitmap` specialization of the storage rather
/// than in an `impl` that is generic over the value type.
pub fn get_owned_value(&self, prefix: &str, key: KeyWrapper) -> Option<RoaringBitmap> {
    let guard = self.inner.read();
    let lookup = CompositeKey {
        prefix: prefix.to_string(),
        key,
    };
    guard.storage.get(&lookup).cloned()
}
}
13 changes: 13 additions & 0 deletions rust/blockstore/src/arrow/block/delta/spann_posting_list_delta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ impl SpannPostingListDelta {
self.inner.read().size_tracker.get_key_size()
}

/// Returns an owned copy of the posting-list entry stored under
/// `(prefix, key)`, or `None` when no such entry exists in the storage.
pub fn get_owned_value(
    &self,
    prefix: &str,
    key: KeyWrapper,
) -> Option<SpannPostingListDeltaEntry> {
    let guard = self.inner.read();
    guard
        .storage
        .get(&CompositeKey {
            prefix: prefix.to_string(),
            key,
        })
        .cloned()
}

pub fn add(&self, prefix: &str, key: KeyWrapper, value: &SpannPostingList<'_>) {
let mut lock_guard = self.inner.write();
let composite_key = CompositeKey {
Expand Down
31 changes: 12 additions & 19 deletions rust/blockstore/src/arrow/block/value/data_record_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ impl ArrowWriteableValue for &DataRecord<'_> {
type ArrowBuilder = ValueBuilderWrapper;
type SizeTracker = DataRecordSizeTracker;
type PreparedValue = (String, Vec<f32>, Option<Vec<u8>>, Option<String>);
type OwnedReadableValue = DataRecordStorageEntry;

fn offset_size(item_count: usize) -> usize {
let id_offset = bit_util::round_upto_multiple_of_64((item_count + 1) * 4);
Expand Down Expand Up @@ -182,11 +183,20 @@ impl ArrowWriteableValue for &DataRecord<'_> {

(struct_field, value_arr)
}

/// Fetches an owned copy of the value stored under `(prefix, key)` in `delta`.
///
/// # Panics
///
/// Panics if `delta.builder` is not the `BlockStorage::DataRecord` variant.
fn get_owned_value_from_delta(
    prefix: &str,
    key: KeyWrapper,
    delta: &BlockDelta,
) -> Option<Self::OwnedReadableValue> {
    if let BlockStorage::DataRecord(storage) = &delta.builder {
        storage.get_owned_value(prefix, key)
    } else {
        panic!("Invalid builder type")
    }
}
}

impl<'referred_data> ArrowReadableValue<'referred_data> for DataRecord<'referred_data> {
type OwnedReadableValue = DataRecordStorageEntry;

fn get(array: &'referred_data Arc<dyn Array>, index: usize) -> Self {
let as_struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();

Expand Down Expand Up @@ -257,21 +267,4 @@ impl<'referred_data> ArrowReadableValue<'referred_data> for DataRecord<'referred
) {
<&DataRecord>::add(prefix, key.into(), &value, storage);
}

/// Copies this record's fields into an owned `Self::OwnedReadableValue` tuple
/// of `(id, embedding, metadata bytes, document)`.
///
/// Metadata, when present, is converted into an `UpdateMetadata` and
/// serialized with `encode_to_vec`; absent metadata stays `None`.
fn to_owned(self) -> Self::OwnedReadableValue {
    let metadata_bytes = self.metadata.as_ref().map(|metadata| {
        let proto: UpdateMetadata = metadata.clone().into();
        proto.encode_to_vec()
    });
    let id = self.id.to_string();
    let embedding = self.embedding.to_vec();
    let document = self.document.map(|text| text.to_string());
    (id, embedding, metadata_bytes, document)
}
}
18 changes: 12 additions & 6 deletions rust/blockstore/src/arrow/block/value/roaring_bitmap_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ impl ArrowWriteableValue for RoaringBitmap {
type ArrowBuilder = BinaryBuilder;
type SizeTracker = SingleColumnSizeTracker;
type PreparedValue = Vec<u8>;
type OwnedReadableValue = RoaringBitmap;

fn offset_size(item_count: usize) -> usize {
bit_util::round_upto_multiple_of_64((item_count + 1) * 4)
Expand Down Expand Up @@ -78,11 +79,20 @@ impl ArrowWriteableValue for RoaringBitmap {
let value_arr = (&value_arr as &dyn Array).slice(0, value_arr.len());
(value_field, value_arr)
}

/// Fetches an owned copy of the bitmap stored under `(prefix, key)` in `delta`.
///
/// # Panics
///
/// Panics if `delta.builder` is not the `BlockStorage::RoaringBitmap` variant.
fn get_owned_value_from_delta(
    prefix: &str,
    key: KeyWrapper,
    delta: &BlockDelta,
) -> Option<Self::OwnedReadableValue> {
    if let BlockStorage::RoaringBitmap(storage) = &delta.builder {
        storage.get_owned_value(prefix, key)
    } else {
        panic!("Invalid builder type")
    }
}
}

impl ArrowReadableValue<'_> for RoaringBitmap {
type OwnedReadableValue = RoaringBitmap;

fn get(array: &std::sync::Arc<dyn Array>, index: usize) -> Self {
let arr = array.as_any().downcast_ref::<BinaryArray>().unwrap();
let bytes = arr.value(index);
Expand All @@ -98,8 +108,4 @@ impl ArrowReadableValue<'_> for RoaringBitmap {
) {
RoaringBitmap::add(prefix, key.into(), value, storage);
}

/// Converts this bitmap into its owned form.
///
/// `self` is already an owned `RoaringBitmap` and the owned type of this impl
/// is `RoaringBitmap`, so the value is handed back as-is instead of being
/// deep-copied and then dropped.
fn to_owned(self) -> Self::OwnedReadableValue {
    self
}
}
22 changes: 12 additions & 10 deletions rust/blockstore/src/arrow/block/value/spann_posting_list_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ impl ArrowWriteableValue for &SpannPostingList<'_> {
type PreparedValue = SpannPostingListDeltaEntry;
type SizeTracker = SpannPostingListSizeTracker;
type ArrowBuilder = SpannPostingListBuilderWrapper;
type OwnedReadableValue = SpannPostingListDeltaEntry;

// This method is only called for SingleColumnStorage.
fn offset_size(_: usize) -> usize {
Expand Down Expand Up @@ -180,11 +181,20 @@ impl ArrowWriteableValue for &SpannPostingList<'_> {

(value_field, value_arr)
}

/// Fetches an owned copy of the posting list stored under `(prefix, key)` in `delta`.
///
/// # Panics
///
/// Panics if `delta.builder` is not the `BlockStorage::SpannPostingListDelta` variant.
fn get_owned_value_from_delta(
    prefix: &str,
    key: KeyWrapper,
    delta: &BlockDelta,
) -> Option<Self::OwnedReadableValue> {
    if let BlockStorage::SpannPostingListDelta(storage) = &delta.builder {
        storage.get_owned_value(prefix, key)
    } else {
        panic!("Invalid builder type")
    }
}
}

impl<'referred_data> ArrowReadableValue<'referred_data> for SpannPostingList<'referred_data> {
type OwnedReadableValue = SpannPostingListDeltaEntry;

fn get(array: &'referred_data Arc<dyn Array>, index: usize) -> Self {
let as_struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();

Expand Down Expand Up @@ -255,12 +265,4 @@ impl<'referred_data> ArrowReadableValue<'referred_data> for SpannPostingList<'re
) {
<&SpannPostingList>::add(prefix, key.into(), &value, storage);
}

/// Copies the three borrowed slices of this posting list into owned vectors,
/// returned as `(doc_offset_ids, doc_versions, doc_embeddings)`.
fn to_owned(self) -> Self::OwnedReadableValue {
    let offset_ids = self.doc_offset_ids.to_vec();
    let versions = self.doc_versions.to_vec();
    let embeddings = self.doc_embeddings.to_vec();
    (offset_ids, versions, embeddings)
}
}
18 changes: 12 additions & 6 deletions rust/blockstore/src/arrow/block/value/str_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl ArrowWriteableValue for String {
type ArrowBuilder = StringBuilder;
type SizeTracker = SingleColumnSizeTracker;
type PreparedValue = String;
type OwnedReadableValue = String;

fn offset_size(item_count: usize) -> usize {
bit_util::round_upto_multiple_of_64((item_count + 1) * 4)
Expand Down Expand Up @@ -69,11 +70,20 @@ impl ArrowWriteableValue for String {
let value_arr = (&value_arr as &dyn Array).slice(0, value_arr.len());
(value_field, value_arr)
}

/// Fetches an owned copy of the string stored under `(prefix, key)` in `delta`.
///
/// # Panics
///
/// Panics if `delta.builder` is not the `BlockStorage::String` variant.
fn get_owned_value_from_delta(
    prefix: &str,
    key: KeyWrapper,
    delta: &BlockDelta,
) -> Option<Self::OwnedReadableValue> {
    if let BlockStorage::String(storage) = &delta.builder {
        storage.get_owned_value(prefix, key)
    } else {
        panic!("Invalid builder type")
    }
}
}

impl<'referred_data> ArrowReadableValue<'referred_data> for &'referred_data str {
type OwnedReadableValue = String;

fn get(array: &'referred_data Arc<dyn Array>, index: usize) -> &'referred_data str {
let array = array.as_any().downcast_ref::<StringArray>().unwrap();
array.value(index)
Expand All @@ -86,8 +96,4 @@ impl<'referred_data> ArrowReadableValue<'referred_data> for &'referred_data str
) {
String::add(prefix, key.into(), value.to_string(), storage);
}

/// Copies the borrowed string slice into a newly allocated `String`.
fn to_owned(self) -> Self::OwnedReadableValue {
    String::from(self)
}
}
16 changes: 12 additions & 4 deletions rust/blockstore/src/arrow/block/value/u32_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl ArrowWriteableValue for u32 {
type ArrowBuilder = UInt32Builder;
type SizeTracker = SingleColumnSizeTracker;
type PreparedValue = u32;
type OwnedReadableValue = u32;

fn offset_size(_item_count: usize) -> usize {
0
Expand Down Expand Up @@ -68,6 +69,17 @@ impl ArrowWriteableValue for u32 {
let value_arr = (&value_arr as &dyn Array).slice(0, value_arr.len());
(value_field, value_arr)
}

/// Fetches a copy of the `u32` stored under `(prefix, key)` in `delta`.
///
/// # Panics
///
/// Panics if `delta.builder` is not the `BlockStorage::UInt32` variant; the
/// panic message includes the `Debug` rendering of the builder that was found.
fn get_owned_value_from_delta(
    prefix: &str,
    key: KeyWrapper,
    delta: &BlockDelta,
) -> Option<Self::OwnedReadableValue> {
    if let BlockStorage::UInt32(storage) = &delta.builder {
        storage.get_owned_value(prefix, key)
    } else {
        panic!("Invalid builder type: {:?}", &delta.builder)
    }
}
}

impl<'a> ArrowReadableValue<'a> for u32 {
Expand All @@ -85,8 +97,4 @@ impl<'a> ArrowReadableValue<'a> for u32 {
) {
u32::add(prefix, key.into(), value, storage);
}

/// Returns `self` unchanged: the value is taken by value and handed straight
/// back as the owned representation, with no conversion.
fn to_owned(self) -> Self::OwnedReadableValue {
self
}
}
18 changes: 12 additions & 6 deletions rust/blockstore/src/arrow/block/value/uint32array_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl ArrowWriteableValue for Vec<u32> {
type ArrowBuilder = ListBuilder<UInt32Builder>;
type SizeTracker = SingleColumnSizeTracker;
type PreparedValue = Vec<u32>;
type OwnedReadableValue = Vec<u32>;

fn offset_size(item_count: usize) -> usize {
bit_util::round_upto_multiple_of_64((item_count + 1) * size_of::<u32>())
Expand Down Expand Up @@ -86,11 +87,20 @@ impl ArrowWriteableValue for Vec<u32> {

(value_field, Arc::new(value_arr))
}

/// Fetches an owned copy of the `Vec<u32>` stored under `(prefix, key)` in `delta`.
///
/// # Panics
///
/// Panics if `delta.builder` is not the `BlockStorage::VecUInt32` variant.
fn get_owned_value_from_delta(
    prefix: &str,
    key: KeyWrapper,
    delta: &BlockDelta,
) -> Option<Self::OwnedReadableValue> {
    if let BlockStorage::VecUInt32(storage) = &delta.builder {
        storage.get_owned_value(prefix, key)
    } else {
        panic!("Invalid builder type")
    }
}
}

impl<'referred_data> ArrowReadableValue<'referred_data> for &'referred_data [u32] {
type OwnedReadableValue = Vec<u32>;

fn get(array: &'referred_data Arc<dyn Array>, index: usize) -> Self {
let list_array = array.as_any().downcast_ref::<ListArray>().unwrap();
let start = list_array.value_offsets()[index] as usize;
Expand Down Expand Up @@ -135,8 +145,4 @@ impl<'referred_data> ArrowReadableValue<'referred_data> for &'referred_data [u32
) {
<Vec<u32>>::add(prefix, key.into(), value.to_vec(), storage);
}

/// Copies the borrowed `u32` slice into a newly allocated `Vec<u32>`.
fn to_owned(self) -> Self::OwnedReadableValue {
    Vec::from(self)
}
}
51 changes: 22 additions & 29 deletions rust/blockstore/src/arrow/blockfile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,11 @@ impl ArrowUnorderedBlockfileWriter {
Ok(())
}

fn block_lifetime_scope<'new, K: ArrowReadableKey<'new>, V: ArrowReadableValue<'new>>(
pub(crate) async fn get_clone<K: ArrowWriteableKey, V: ArrowWriteableValue>(
&self,
block: &'new Block,
prefix: &str,
key: K,
) -> V::OwnedReadableValue {
let value = block.get::<K, V>(prefix, key).unwrap();
value.to_owned()
}

pub(crate) async fn get_clone<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
&self,
prefix: &str,
key: K,
) -> Result<V::OwnedReadableValue, Box<dyn ChromaError>> {
) -> Result<Option<V::OwnedReadableValue>, Box<dyn ChromaError>> {
// TODO: for now the BF writer locks the entire write operation
let _guard = self.write_mutex.lock().await;

Expand All @@ -257,7 +247,7 @@ impl ArrowUnorderedBlockfileWriter {
deltas.get(&target_block_id).cloned()
};

let res: Result<V::OwnedReadableValue, Box<dyn ChromaError>> = match delta {
let delta = match delta {
None => {
let block = match self.block_manager.get(&target_block_id).await {
Ok(Some(block)) => block,
Expand All @@ -268,24 +258,27 @@ impl ArrowUnorderedBlockfileWriter {
return Err(Box::new(e));
}
};
// // Lie to the compiler that the block is static even though it is not and will be
// // dropped after this method call. The only way this block is being used is
// // to read a value from it and deep copy it so even if the block is dropped,
// // the value will still be valid.
// let block: &'static Block = unsafe { transmute::<&Block, &'static Block>(&block) };
// let value = match block.get::<K, V>(prefix, key) {
// Some(value) => value.to_owned(),
// None => {
// return Err(Box::new(ArrowBlockfileError::BlockNotFound));
// }
// };
// Ok(value)
let val = self.block_lifetime_scope::<K, V>(&block, prefix, key);
Ok(val)
let new_delta = match self.block_manager.fork::<K, V>(&block.id).await {
Ok(delta) => delta,
Err(e) => {
return Err(Box::new(e));
}
};
let new_id = new_delta.id;
// Blocks can be empty.
self.root
.sparse_index
.replace_block(target_block_id, new_delta.id);
{
let mut deltas = self.block_deltas.lock();
deltas.insert(new_id, new_delta.clone());
}
new_delta
}
Some(_) => Err(Box::new(ArrowBlockfileError::BlockNotFound)),
Some(delta) => delta,
};
res

Ok(V::get_owned_value_from_delta(prefix, key.into(), &delta))
}

pub(crate) async fn delete<K: ArrowWriteableKey, V: ArrowWriteableValue>(
Expand Down
Loading

0 comments on commit 552bc99

Please sign in to comment.