Skip to content

Commit

Permalink
perf: add optimized BinaryViewArray comparison kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp committed Jan 19, 2024
1 parent df718d1 commit 5185091
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 108 deletions.
1 change: 1 addition & 0 deletions crates/polars-compute/src/comparisons/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ impl NotSimdPrimitive for u128 {}
impl NotSimdPrimitive for i128 {}

mod scalar;
mod view;

#[cfg(feature = "simd")]
mod simd;
Expand Down
109 changes: 1 addition & 108 deletions crates/polars-compute/src/comparisons/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use arrow::array::{
BinaryArray, BinaryViewArray, BooleanArray, PrimitiveArray, Utf8Array, Utf8ViewArray,
};
use arrow::array::{BinaryArray, BooleanArray, PrimitiveArray, Utf8Array};
use arrow::bitmap::{self, Bitmap};
use arrow::types::NativeType;
use polars_utils::total_ord::{TotalEq, TotalOrd};
Expand Down Expand Up @@ -71,111 +69,6 @@ impl<T: NativeType + NotSimdPrimitive + TotalOrd> TotalOrdKernel for PrimitiveAr
}
}

impl TotalOrdKernel for BinaryViewArray {
type Scalar = [u8];

fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());
// TODO! speed-up by first comparing views
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_eq(&r))
.collect()
}

fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_ne(&r))
.collect()
}

fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_lt(&r))
.collect()
}

fn tot_le_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_le(&r))
.collect()
}

fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_eq(&other)).collect()
}

fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_ne(&other)).collect()
}

fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_lt(&other)).collect()
}

fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_le(&other)).collect()
}

fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_gt(&other)).collect()
}

fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_ge(&other)).collect()
}
}

impl TotalOrdKernel for Utf8ViewArray {
type Scalar = str;

fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_eq_kernel(&other.to_binview())
}

fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_ne_kernel(&other.to_binview())
}

fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_lt_kernel(&other.to_binview())
}

fn tot_le_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_le_kernel(&other.to_binview())
}

fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_eq_kernel_broadcast(other.as_bytes())
}

fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_ne_kernel_broadcast(other.as_bytes())
}

fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_lt_kernel_broadcast(other.as_bytes())
}

fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_le_kernel_broadcast(other.as_bytes())
}

fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_gt_kernel_broadcast(other.as_bytes())
}

fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_ge_kernel_broadcast(other.as_bytes())
}
}

impl TotalOrdKernel for BinaryArray<i64> {
type Scalar = [u8];

Expand Down
244 changes: 244 additions & 0 deletions crates/polars-compute/src/comparisons/view.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
use arrow::array::{BinaryViewArray, Utf8ViewArray};
use arrow::bitmap::Bitmap;

use crate::comparisons::TotalOrdKernel;

// If s fits in 12 bytes, returns the view encoding it would have in a
// BinaryViewArray.
fn small_view_encoding(s: &[u8]) -> Option<u128> {
if s.len() > 12 {
return None;
}

let mut tmp = [0u8; 16];
tmp[0] = s.len() as u8;
tmp[4..4 + s.len()].copy_from_slice(s);
Some(u128::from_le_bytes(tmp))
}

// Loads (up to) the first 4 bytes of s as little-endian, padded with zeros.
fn load_prefix(s: &[u8]) -> u32 {
let start = &s[..s.len().min(4)];
let mut tmp = [0u8; 4];
tmp[..start.len()].copy_from_slice(start);
u32::from_le_bytes(tmp)
}

fn broadcast_inequality(
arr: &BinaryViewArray,
scalar: &[u8],
cmp_prefix: impl Fn(u32, u32) -> bool,
cmp_str: impl Fn(&[u8], &[u8]) -> bool,
) -> Bitmap {
let views = arr.views().as_slice();
let prefix = load_prefix(scalar);
let be_prefix = prefix.to_be();
Bitmap::from_trusted_len_iter((0..arr.len()).map(|i| unsafe {
let v_prefix = (*views.get_unchecked(i) >> 32) as u32;
if v_prefix != prefix {
cmp_prefix(v_prefix.to_be(), be_prefix)
} else {
cmp_str(arr.value_unchecked(i), scalar)
}
}))
}

impl TotalOrdKernel for BinaryViewArray {
type Scalar = [u8];

fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());

let slf_views = self.views().as_slice();
let other_views = other.views().as_slice();

Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let av = *slf_views.get_unchecked(i);
let bv = *other_views.get_unchecked(i);

// First 64 bits contain length and prefix.
let a_len_prefix = av as u64;
let b_len_prefix = bv as u64;
if a_len_prefix != b_len_prefix {
return false;
}

let alen = av as u32;
if alen <= 12 {
// String is fully inlined, compare top 64 bits. Bottom bits were
// tested equal before, which also ensures the lengths are equal.
(av >> 64) as u64 == (bv >> 64) as u64
} else {
self.value_unchecked(i) == other.value_unchecked(i)
}
}))
}

fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());

let slf_views = self.views().as_slice();
let other_views = other.views().as_slice();

Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let av = *slf_views.get_unchecked(i);
let bv = *other_views.get_unchecked(i);

// First 64 bits contain length and prefix.
let a_len_prefix = av as u64;
let b_len_prefix = bv as u64;
if a_len_prefix != b_len_prefix {
return true;
}

let alen = av as u32;
if alen <= 12 {
// String is fully inlined, compare top 64 bits. Bottom bits were
// tested equal before, which also ensures the lengths are equal.
(av >> 64) as u64 != (bv >> 64) as u64
} else {
self.value_unchecked(i) != other.value_unchecked(i)
}
}))
}

fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());

let slf_views = self.views().as_slice();
let other_views = other.views().as_slice();

Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let av = *slf_views.get_unchecked(i);
let bv = *other_views.get_unchecked(i);

// First 64 bits contain length and prefix.
// Only check prefix.
let a_prefix = (av >> 32) as u32;
let b_prefix = (bv >> 32) as u32;
if a_prefix != b_prefix {
a_prefix.to_be() < b_prefix.to_be()
} else {
self.value_unchecked(i) < other.value_unchecked(i)
}
}))
}

fn tot_le_kernel(&self, other: &Self) -> Bitmap {
debug_assert!(self.len() == other.len());

let slf_views = self.views().as_slice();
let other_views = other.views().as_slice();

Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let av = *slf_views.get_unchecked(i);
let bv = *other_views.get_unchecked(i);

// First 64 bits contain length and prefix.
// Only check prefix.
let a_prefix = (av >> 32) as u32;
let b_prefix = (bv >> 32) as u32;
if a_prefix != b_prefix {
a_prefix.to_be() < b_prefix.to_be()
} else {
self.value_unchecked(i) <= other.value_unchecked(i)
}
}))
}

fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
if let Some(val) = small_view_encoding(other) {
Bitmap::from_trusted_len_iter(self.views().iter().map(|v| *v == val))
} else {
let slf_views = self.views().as_slice();
let prefix = u32::from_le_bytes(other[..4].try_into().unwrap());
let prefix_len = ((prefix as u64) << 32) | other.len() as u64;
Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let v_prefix_len = *slf_views.get_unchecked(i) as u64;
if v_prefix_len != prefix_len {
false
} else {
self.value_unchecked(i) == other
}
}))
}
}

fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
if let Some(val) = small_view_encoding(other) {
Bitmap::from_trusted_len_iter(self.views().iter().map(|v| *v != val))
} else {
let slf_views = self.views().as_slice();
let prefix = u32::from_le_bytes(other[..4].try_into().unwrap());
let prefix_len = ((prefix as u64) << 32) | other.len() as u64;
Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe {
let v_prefix_len = *slf_views.get_unchecked(i) as u64;
if v_prefix_len != prefix_len {
true
} else {
self.value_unchecked(i) != other
}
}))
}
}

fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
broadcast_inequality(self, other, |a, b| a < b, |a, b| a < b)
}

fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
broadcast_inequality(self, other, |a, b| a <= b, |a, b| a <= b)
}

fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
broadcast_inequality(self, other, |a, b| a > b, |a, b| a > b)
}

fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
broadcast_inequality(self, other, |a, b| a >= b, |a, b| a >= b)
}
}

impl TotalOrdKernel for Utf8ViewArray {
type Scalar = str;

fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_eq_kernel(&other.to_binview())
}

fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_ne_kernel(&other.to_binview())
}

fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_lt_kernel(&other.to_binview())
}

fn tot_le_kernel(&self, other: &Self) -> Bitmap {
self.to_binview().tot_le_kernel(&other.to_binview())
}

fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_eq_kernel_broadcast(other.as_bytes())
}

fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_ne_kernel_broadcast(other.as_bytes())
}

fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_lt_kernel_broadcast(other.as_bytes())
}

fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_le_kernel_broadcast(other.as_bytes())
}

fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_gt_kernel_broadcast(other.as_bytes())
}

fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.to_binview().tot_ge_kernel_broadcast(other.as_bytes())
}
}
3 changes: 3 additions & 0 deletions py-polars/tests/unit/operations/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,9 @@ def test_total_ordering_float_series(lhs: float | None, rhs: float | None) -> No
"",
"foo",
"bar",
"fooo",
"fooooooooooo",
"foooooooooooo",
"fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooom",
"foooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo",
"fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooop",
Expand Down

0 comments on commit 5185091

Please sign in to comment.