optimize bloom filter (#620)
Co-authored-by: zhangli20 <[email protected]>
richox and zhangli20 authored Oct 22, 2024
1 parent 549a59b commit 009d904
Showing 13 changed files with 219 additions and 102 deletions.
5 changes: 3 additions & 2 deletions native-engine/blaze-serde/proto/blaze.proto
@@ -347,8 +347,9 @@ message RowNumExprNode {
 }
 
 message BloomFilterMightContainExprNode {
-  PhysicalExprNode bloom_filter_expr = 1;
-  PhysicalExprNode value_expr = 2;
+  string uuid = 1;
+  PhysicalExprNode bloom_filter_expr = 2;
+  PhysicalExprNode value_expr = 3;
 }
 
 message FilterExecNode {
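The message gains a uuid that identifies one bloom filter instance across all partitions of a plan, and the existing fields are renumbered. For reference, prost would generate roughly the following Rust type for the new message; this is a sketch based on prost's usual output for recursive message fields (boxed optionals, matching the try_parse_physical_expr_box_required calls below), not code taken from the repository:

// assumed prost output, for illustration only
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BloomFilterMightContainExprNode {
    #[prost(string, tag = "1")]
    pub uuid: ::prost::alloc::string::String,
    #[prost(message, optional, boxed, tag = "2")]
    pub bloom_filter_expr: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalExprNode>>,
    #[prost(message, optional, boxed, tag = "3")]
    pub value_expr: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalExprNode>>,
}

Renumbering tags 1 and 2 breaks wire compatibility with previously serialized plans, which is presumably acceptable here because the plan protobuf only travels between the JVM and native sides of a single query execution.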
1 change: 1 addition & 0 deletions native-engine/blaze-serde/src/from_proto.rs
@@ -1061,6 +1061,7 @@ fn try_parse_physical_expr(
        }
        ExprType::RowNumExpr(_) => Arc::new(RowNumExpr::default()),
        ExprType::BloomFilterMightContainExpr(e) => Arc::new(BloomFilterMightContainExpr::new(
+           e.uuid.clone(),
            try_parse_physical_expr_box_required(&e.bloom_filter_expr, input_schema)?,
            try_parse_physical_expr_box_required(&e.value_expr, input_schema)?,
        )),
59 changes: 59 additions & 0 deletions native-engine/datafusion-ext-commons/src/lib.rs
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#![allow(internal_features)]
+#![feature(core_intrinsics)]
 #![feature(new_uninit)]
 #![feature(slice_swap_unchecked)]
 #![feature(vec_into_raw_parts)]
@@ -21,6 +23,7 @@ use blaze_jni_bridge::{
     is_jni_bridge_inited,
 };
 use once_cell::sync::OnceCell;
+use unchecked_index::UncheckedIndex;
 
 pub mod array_size;
 pub mod bytes_arena;
@@ -127,3 +130,59 @@ fn compute_batch_size_with_target_mem_size(
     let est_sub_batch_size = target_mem_size / est_mem_size_per_row.max(16);
     est_sub_batch_size.min(batch_size).max(batch_size_min)
 }
+
+#[macro_export]
+macro_rules! unchecked {
+    ($e:expr) => {{
+        // safety: bypass bounds checking, used in performance critical path
+        #[allow(unused_unsafe)]
+        unsafe {
+            unchecked_index::unchecked_index($e)
+        }
+    }};
+}
+
+#[macro_export]
+macro_rules! assume {
+    ($e:expr) => {{
+        // safety: use assume
+        #[allow(unused_unsafe)]
+        unsafe {
+            std::intrinsics::assume($e)
+        }
+    }};
+}
+
+#[macro_export]
+macro_rules! prefetch_read_data {
+    ($e:expr) => {{
+        // safety: use prefetch
+        let locality = 3;
+        #[allow(unused_unsafe)]
+        unsafe {
+            std::intrinsics::prefetch_read_data($e, locality)
+        }
+    }};
+}
+#[macro_export]
+macro_rules! prefetch_write_data {
+    ($e:expr) => {{
+        // safety: use prefetch
+        let locality = 3;
+        #[allow(unused_unsafe)]
+        unsafe {
+            std::intrinsics::prefetch_write_data($e, locality)
+        }
+    }};
+}
+
+pub trait UncheckedIndexIntoInner<T> {
+    fn into_inner(self) -> T;
+}
+
+impl<T: Sized> UncheckedIndexIntoInner<T> for UncheckedIndex<T> {
+    fn into_inner(self) -> T {
+        let no_drop = std::mem::ManuallyDrop::new(self);
+        unsafe { std::ptr::read(&**no_drop) }
+    }
+}
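Taken together, these helpers trade safety checks for speed on hot paths. A minimal usage sketch follows; the functions are hypothetical and not part of this commit, and it assumes the calling crate also depends on the unchecked_index crate and enables the same nightly features as above:

use datafusion_ext_commons::{assume, prefetch_read_data, unchecked};

// sum a slice while prefetching ahead; unchecked!() skips bounds checks
fn sum_words(words: &[u64]) -> u64 {
    let words = unchecked!(words);
    let mut sum = 0u64;
    for i in 0..words.len() {
        // hint the CPU to pull a future element into cache; prefetching
        // past the end is harmless because it is only a hint
        prefetch_read_data!(words.as_ptr().wrapping_add(i + 8));
        sum = sum.wrapping_add(words[i]);
    }
    sum
}

// assume!() promises a fact to the optimizer (UB if the promise is false),
// letting it delete the panic path behind the indexing below
fn first_byte(bytes: &[u8]) -> u8 {
    assume!(!bytes.is_empty());
    bytes[0]
}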
13 changes: 13 additions & 0 deletions native-engine/datafusion-ext-commons/src/spark_bit_array.rs
@@ -17,6 +17,8 @@ use std::io::{Read, Write};
 use byteorder::{ReadBytesExt, WriteBytesExt, BE};
 use datafusion::common::Result;
 
+use crate::assume;
+
 // native implementation of org.apache.spark.util.sketch.BitArray
 #[derive(Default, Clone)]
 pub struct SparkBitArray {
@@ -61,12 +63,16 @@ impl SparkBitArray {
     pub fn set(&mut self, index: usize) {
         let data_offset = index >> 6;
         let bit_offset = index & 0b00111111;
+
+        assume!(data_offset < self.data.len());
         self.data[data_offset] |= 1 << bit_offset;
     }
 
     pub fn get(&self, index: usize) -> bool {
         let data_offset = index >> 6;
         let bit_offset = index & 0b00111111;
+
+        assume!(data_offset < self.data.len());
        let datum = &self.data[data_offset];
        (datum >> bit_offset) & 1 == 1
     }
@@ -95,6 +101,13 @@ impl SparkBitArray {
     pub fn bit_size(&self) -> usize {
         self.data.len() * 64
     }
+
+    pub fn true_count(&self) -> usize {
+        self.data
+            .iter()
+            .map(|&datum| datum.count_ones() as usize)
+            .sum()
+    }
 }
 
 #[cfg(test)]
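The assume! hints tell the optimizer that data_offset is always in range, so the indexing in set() and get() compiles without a bounds-check branch; callers uphold the promise because indices come from hashes already masked to the bit size. A quick worked example of the word/bit addressing and of true_count() (illustrative and runnable standalone, not from this commit):

fn main() {
    let index = 200usize;
    let data_offset = index >> 6;         // 200 / 64 = 3: the 4th u64 word
    let bit_offset = index & 0b0011_1111; // 200 % 64 = 8: the 9th bit of it
    assert_eq!(data_offset, 3);
    assert_eq!(bit_offset, 8);

    // true_count() is just a popcount over the words:
    let words = [0b1011u64, u64::MAX];
    let trues: usize = words.iter().map(|w| w.count_ones() as usize).sum();
    assert_eq!(trues, 3 + 64);
}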
79 changes: 63 additions & 16 deletions native-engine/datafusion-ext-commons/src/spark_bloom_filter.rs
@@ -17,13 +17,15 @@ use std::{
     io::Write,
 };
 
+use arrow::array::{BooleanArray, BooleanBufferBuilder};
 use byteorder::{ReadBytesExt, WriteBytesExt, BE};
 use datafusion::common::Result;
 
 use crate::{
     df_execution_err,
     hash::mur::{spark_compatible_murmur3_hash, spark_compatible_murmur3_hash_long},
     spark_bit_array::SparkBitArray,
+    unchecked,
 };
 
 #[derive(Default, Clone)]
@@ -89,10 +91,8 @@ impl SparkBloomFilter {
         for i in 1..=self.num_hash_functions as i32 {
             let mut combined_hash = h1 + i * h2;
             // flip all the bits if it's negative (guaranteed positive number)
-            if combined_hash < 0 {
-                combined_hash = !combined_hash;
-            }
-            self.bits.set((combined_hash % bit_size) as usize);
+            combined_hash = combined_hash ^ -((combined_hash < 0) as i32);
+            self.bits.set((combined_hash & (bit_size - 1)) as usize);
         }
     }
@@ -106,10 +106,8 @@
         for i in 1..=self.num_hash_functions as i32 {
             let mut combined_hash = h1 + i * h2;
             // flip all the bits if it's negative (guaranteed positive number)
-            if combined_hash < 0 {
-                combined_hash = !combined_hash;
-            }
-            self.bits.set((combined_hash % bit_size) as usize);
+            combined_hash = combined_hash ^ -((combined_hash < 0) as i32);
+            self.bits.set((combined_hash & (bit_size - 1)) as usize);
         }
     }
@@ -121,10 +119,8 @@
         for i in 1..=self.num_hash_functions as i32 {
             let mut combined_hash = h1 + i * h2;
             // flip all the bits if it's negative (guaranteed positive number)
-            if combined_hash < 0 {
-                combined_hash = !combined_hash;
-            }
-            if !self.bits.get((combined_hash % bit_size) as usize) {
+            combined_hash = combined_hash ^ -((combined_hash < 0) as i32);
+            if !self.bits.get((combined_hash & (bit_size - 1)) as usize) {
                 return false;
             }
         }
@@ -140,21 +136,72 @@
         for i in 1..=self.num_hash_functions as i32 {
             let mut combined_hash = h1 + i * h2;
             // flip all the bits if it's negative (guaranteed positive number)
-            if combined_hash < 0 {
-                combined_hash = !combined_hash;
-            }
-            if !self.bits.get((combined_hash % bit_size) as usize) {
+            combined_hash = combined_hash ^ -((combined_hash < 0) as i32);
+            if !self.bits.get((combined_hash & (bit_size - 1)) as usize) {
                 return false;
             }
         }
         true
     }
 
+    #[inline]
+    pub fn might_contain_longs(&self, values: &[i64]) -> BooleanArray {
+        let mut buffer = BooleanBufferBuilder::new(0);
+        buffer.resize(values.len());
+
+        let h1s = values
+            .iter()
+            .map(|&v| spark_compatible_murmur3_hash_long(v, 0))
+            .collect::<Vec<_>>();
+        let h2s = values
+            .iter()
+            .zip(&h1s)
+            .map(|(&v, &h1)| spark_compatible_murmur3_hash_long(v, h1))
+            .collect::<Vec<_>>();
+
+        let bit_size = self.bits.bit_size() as i32;
+
+        'next_item: for (i, (h1, h2)) in std::iter::zip(h1s, h2s).enumerate() {
+            for i in 1..=self.num_hash_functions as i32 {
+                let mut combined_hash = h1 + i * h2;
+                // flip all the bits if it's negative (guaranteed positive number)
+                combined_hash = combined_hash ^ -((combined_hash < 0) as i32);
+                if !self.bits.get((combined_hash & (bit_size - 1)) as usize) {
+                    continue 'next_item; // might not contain
+                }
+            }
+            unchecked!(buffer.as_slice_mut())[i / 8] |= 1 << (i % 8); // might contain
+        }
+        BooleanArray::from(buffer.finish())
+    }
+
     pub fn put_all(&mut self, other: &Self) {
         assert_eq!(self.num_hash_functions, other.num_hash_functions);
         self.bits.put_all(&other.bits);
     }
 
+    pub fn shrink_to_fit(&mut self) {
+        let num_bits = self.bits.bit_size();
+
+        // reduce num_bits if true count is too small
+        // so that we can reduce memory usage and improve performance of might_contain()
+        let num_trues = self.bits.true_count();
+        let shrinked_num_bits = (self.num_hash_functions * num_trues * 2)
+            .max(1)
+            .next_power_of_two();
+        if shrinked_num_bits >= num_bits {
+            return;
+        }
+
+        let mut new_bits = SparkBitArray::new_with_num_bits(shrinked_num_bits);
+        for i in 0..num_bits {
+            if self.bits.get(i) {
+                new_bits.set(i % shrinked_num_bits);
+            }
+        }
+        self.bits = new_bits;
+    }
+
     fn optimal_num_of_hash_functions(n: usize, m: usize) -> usize {
         let result = (m as f64 / n as f64 * 2.0_f64.ln()).round() as usize;
         result.max(1)
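Two equivalences underpin the rewritten probing lines. First, x ^ -((x < 0) as i32) is a branchless form of the old "if x < 0 { x = !x }": when x is negative the mask is -1 (all ones) and the XOR flips every bit, otherwise the mask is 0 and the XOR is a no-op. Second, h & (bit_size - 1) equals h % bit_size for the now-guaranteed non-negative h whenever bit_size is a power of two, replacing an integer division with a single AND; the code relies on bit sizes being powers of two (shrink_to_fit rounds with next_power_of_two, and set and get use the same masking). A small self-contained check of both identities:

fn main() {
    // branchless "flip all bits if negative"; the result is always >= 0
    for x in [i32::MIN, -7, -1, 0, 1, 42, i32::MAX] {
        let branchy = if x < 0 { !x } else { x };
        let branchless = x ^ -((x < 0) as i32);
        assert_eq!(branchy, branchless);
        assert!(branchless >= 0);
    }

    // for non-negative h and power-of-two m, h % m == h & (m - 1)
    let m: i32 = 1 << 20;
    for h in [0, 1, 12_345, i32::MAX] {
        assert_eq!(h % m, h & (m - 1));
    }
}

The same power-of-two assumption is what makes shrink_to_fit() safe against false negatives: folding bit i down to i % shrinked_num_bits places it exactly where a post-shrink probe h & (shrinked_num_bits - 1) will look, because masking by the smaller power of two is the tail of masking by the larger one. Shrinking only raises the false-positive rate, and the sizing keeps the fold from overfilling the new array: at most true_count bits end up set out of at least num_hash_functions * true_count * 2.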
@@ -14,9 +14,10 @@
 
 use std::{
     any::Any,
+    collections::HashMap,
     fmt::{Debug, Display, Formatter},
     hash::Hasher,
-    sync::Arc,
+    sync::{Arc, Weak},
 };
 
 use arrow::{
@@ -32,26 +33,37 @@ use datafusion_ext_commons::{
     cast::cast, df_execution_err, df_unimplemented_err, spark_bloom_filter::SparkBloomFilter,
 };
 use once_cell::sync::OnceCell;
+use parking_lot::Mutex;
 
 pub struct BloomFilterMightContainExpr {
+    uuid: String,
     bloom_filter_expr: Arc<dyn PhysicalExpr>,
     value_expr: Arc<dyn PhysicalExpr>,
-    bloom_filter: OnceCell<SparkBloomFilter>,
+    bloom_filter: OnceCell<Arc<SparkBloomFilter>>,
 }
 
 impl BloomFilterMightContainExpr {
     pub fn new(
+        uuid: String,
         bloom_filter_expr: Arc<dyn PhysicalExpr>,
         value_expr: Arc<dyn PhysicalExpr>,
     ) -> Self {
         Self {
+            uuid,
             bloom_filter_expr,
             value_expr,
             bloom_filter: OnceCell::new(),
         }
    }
 }
 
+impl Drop for BloomFilterMightContainExpr {
+    fn drop(&mut self) {
+        drop(self.bloom_filter.take());
+        clear_cached_bloom_filter();
+    }
+}
+
 impl Display for BloomFilterMightContainExpr {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         write!(f, "{self:?}")
@@ -93,24 +105,24 @@ impl PhysicalExpr for BloomFilterMightContainExpr {
     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
         // init bloom filter
         let bloom_filter = self.bloom_filter.get_or_try_init(|| {
-            match self.bloom_filter_expr.evaluate(batch)? {
-                ColumnarValue::Scalar(ScalarValue::Binary(Some(v))) => {
-                    Ok(SparkBloomFilter::read_from(v.as_slice())?)
-                }
-                _ => {
-                    df_execution_err!("bloom_filter_arg must be valid binary scalar value")
-                }
-            }
+            get_cached_bloom_filter(&self.uuid, || {
+                match self.bloom_filter_expr.evaluate(batch)? {
+                    ColumnarValue::Scalar(ScalarValue::Binary(Some(v))) => {
+                        Ok(SparkBloomFilter::read_from(v.as_slice())?)
+                    }
+                    _ => {
+                        df_execution_err!("bloom_filter_arg must be valid binary scalar value")
+                    }
+                }
+            })
         })?;
 
         // process with bloom filter
         let values = self.value_expr.evaluate(&batch)?.into_array(1)?;
         let might_contain = match values.data_type() {
             DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
                 let values = cast(&values, &DataType::Int64)?;
-                BooleanArray::from_unary(values.as_primitive::<Int64Type>(), |v| {
-                    bloom_filter.might_contain_long(v)
-                })
+                bloom_filter.might_contain_longs(values.as_primitive::<Int64Type>().values())
             }
             DataType::Utf8 => BooleanArray::from_unary(values.as_string::<i32>(), |v| {
                 bloom_filter.might_contain_binary(v.as_bytes())
@@ -132,6 +144,7 @@
         children: Vec<Arc<dyn PhysicalExpr>>,
     ) -> Result<Arc<dyn PhysicalExpr>> {
         Ok(Arc::new(Self::new(
+            self.uuid.clone(),
             children[0].clone(),
             children[1].clone(),
         )))
@@ -143,3 +156,36 @@
         self.value_expr.dyn_hash(state);
     }
 }
+
+type Slot = Arc<Mutex<Weak<SparkBloomFilter>>>;
+static CACHED_BLOOM_FILTER: OnceCell<Arc<Mutex<HashMap<String, Slot>>>> = OnceCell::new();
+
+fn get_cached_bloom_filter(
+    uuid: &str,
+    init: impl FnOnce() -> Result<SparkBloomFilter>,
+) -> Result<Arc<SparkBloomFilter>> {
+    // remove expired keys and insert the new key
+    let slot = {
+        let cached_bloom_filter = CACHED_BLOOM_FILTER.get_or_init(|| Arc::default());
+        let mut cached_bloom_filter = cached_bloom_filter.lock();
+        cached_bloom_filter
+            .entry(uuid.to_string())
+            .or_default()
+            .clone()
+    };
+
+    let mut slot = slot.lock();
+    if let Some(cached) = slot.upgrade() {
+        Ok(cached)
+    } else {
+        let new = Arc::new(init()?);
+        *slot = Arc::downgrade(&new);
+        Ok(new)
+    }
+}
+
+fn clear_cached_bloom_filter() {
+    let cached_bloom_filter = CACHED_BLOOM_FILTER.get_or_init(|| Arc::default());
+    let mut cached_bloom_filter = cached_bloom_filter.lock();
+    cached_bloom_filter.retain(|_, v| Arc::strong_count(v) > 0);
+}
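The cache stores Weak references behind a per-key Mutex slot: the map lock is held only long enough to fetch or insert the slot, and the per-slot lock serializes deserialization, so concurrent partitions asking for the same uuid decode the bloom filter once and share one Arc, while dropping the last expression frees the filter immediately. A standalone sketch of the same two-level pattern with toy types (std sync primitives instead of parking_lot; not from this commit):

use std::collections::HashMap;
use std::sync::{Arc, Mutex, Weak};

struct BigValue(Vec<u8>);

type Slot = Arc<Mutex<Weak<BigValue>>>;

fn get_or_build(
    cache: &Mutex<HashMap<String, Slot>>,
    key: &str,
    build: impl FnOnce() -> BigValue,
) -> Arc<BigValue> {
    // phase 1: grab the per-key slot under the short-lived map lock
    let slot = cache
        .lock()
        .unwrap()
        .entry(key.to_string())
        .or_default()
        .clone();

    // phase 2: build under the per-key lock, so two threads asking for
    // the same key never both pay for the expensive build
    let mut slot = slot.lock().unwrap();
    if let Some(cached) = slot.upgrade() {
        cached
    } else {
        let fresh = Arc::new(build());
        *slot = Arc::downgrade(&fresh);
        fresh
    }
}

One caveat worth noting: in clear_cached_bloom_filter above, Arc::strong_count(v) counts references to the slot itself, and the map always holds one, so the retain predicate appears to keep every entry; checking the Weak instead (for example v.lock().strong_count() > 0) would actually drop slots whose filter has been freed. The leak is only a Weak per uuid, so the impact is small.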