Skip to content

Commit

Permalink
Fix search sorted casting (#1579)
Browse files Browse the repository at this point in the history
  • Loading branch information
gatesn authored Dec 6, 2024
1 parent 84e1a6f commit bdfe74c
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 55 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/bench-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ jobs:
BENCH_VORTEX_RATIOS: '.*'
RUSTFLAGS: '-C target-cpu=native'
run: |
cargo run --bin tpch_benchmark --release -- -d gh-json -t 1 | tee tpch.json
cargo run --bin tpch_benchmark --release -- --only-vortex -d gh-json -t 1 | tee tpch.json
- name: Store benchmark result
if: '!cancelled()'
uses: benchmark-action/github-action-benchmark@v1
Expand Down Expand Up @@ -179,7 +179,7 @@ jobs:
RUSTFLAGS: '-C target-cpu=native'
HOME: /home/ci-runner
run: |
cargo run --bin clickbench --release -- -d gh-json | tee clickbench.json
cargo run --bin clickbench --release -- --only-vortex -d gh-json | tee clickbench.json
- name: Store benchmark result
if: '!cancelled()'
uses: benchmark-action/github-action-benchmark@v1
Expand Down
2 changes: 1 addition & 1 deletion bench-vortex/src/bin/clickbench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use vortex::error::vortex_panic;
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
#[arg(short, long, default_value = "8")]
#[arg(short, long, default_value = "5")]
iterations: usize,
#[arg(short, long)]
threads: Option<usize>,
Expand Down
2 changes: 1 addition & 1 deletion bench-vortex/src/bin/tpch_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct Args {
threads: Option<usize>,
#[arg(short, long, default_value_t = true, default_missing_value = "true", action = ArgAction::Set)]
warmup: bool,
#[arg(short, long, default_value = "8")]
#[arg(short, long, default_value = "5")]
iterations: usize,
#[arg(long)]
only_vortex: bool,
Expand Down
44 changes: 24 additions & 20 deletions encodings/fastlanes/src/bitpacking/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use num_traits::AsPrimitive;
use vortex_array::array::SparseArray;
use vortex_array::compute::{
search_sorted_usize, IndexOrd, Len, SearchResult, SearchSorted, SearchSortedFn,
SearchSortedSide,
SearchSortedSide, SearchSortedUsizeFn,
};
use vortex_array::stats::ArrayStatistics;
use vortex_array::validity::Validity;
Expand All @@ -32,24 +32,6 @@ impl SearchSortedFn<BitPackedArray> for BitPackedEncoding {
})
}

/// Searches the bit-packed array for a `usize` value.
///
/// The value is first converted to the primitive type selected from `array.ptype()`. If it
/// cannot be represented in that type, the result is `NotFound(array.len())`.
fn search_sorted_usize(
&self,
array: &BitPackedArray,
value: usize,
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
match_each_unsigned_integer_ptype!(array.ptype(), |$P| {
// NOTE: `num_traits::cast` is a checked conversion; it yields `None` instead of
// truncating when `value` is not representable in `$P`.
if let Some(pvalue) = num_traits::cast::<usize, $P>(value) {
search_sorted_native(array, pvalue, side)
} else {
// The provided usize is too large to fit in the array's PType, so the value must be
// off the right end of the array.
Ok(SearchResult::NotFound(array.len()))
}
})
}

fn search_sorted_many(
&self,
array: &BitPackedArray,
Expand All @@ -69,6 +51,26 @@ impl SearchSortedFn<BitPackedArray> for BitPackedEncoding {
.try_collect()
})
}
}

impl SearchSortedUsizeFn<BitPackedArray> for BitPackedEncoding {
/// Searches the bit-packed array for a `usize` value.
///
/// The value is first converted to the primitive type selected from `array.ptype()`. If it
/// cannot be represented in that type, the result is `NotFound(array.len())`.
fn search_sorted_usize(
&self,
array: &BitPackedArray,
value: usize,
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
match_each_unsigned_integer_ptype!(array.ptype(), |$P| {
// NOTE: `num_traits::cast` is a checked conversion; it yields `None` instead of
// truncating when `value` is not representable in `$P`.
if let Some(pvalue) = num_traits::cast::<usize, $P>(value) {
search_sorted_native(array, pvalue, side)
} else {
// The provided usize is too large to fit in the array's PType, so the value must be
// off the right end of the array.
Ok(SearchResult::NotFound(array.len()))
}
})
}

fn search_sorted_usize_many(
&self,
Expand Down Expand Up @@ -121,7 +123,9 @@ where
// max packed value just search the patches
let usize_value: usize = value.as_();
if usize_value > array.max_packed_value() {
search_sorted_usize(&patches_array, value.as_(), side)
// FIXME(ngates): this is broken. Patches _aren't_ sorted because they're sparse and
// interspersed with nulls...
search_sorted_usize(&patches_array, usize_value, side)
} else {
Ok(BitPackedSearch::<'_, T>::new(array).search_sorted(&value, side))
}
Expand Down
8 changes: 6 additions & 2 deletions vortex-array/src/array/primitive/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::array::PrimitiveEncoding;
use crate::compute::{
CastFn, ComputeVTable, FillForwardFn, FilterFn, ScalarAtFn, SearchSortedFn, SliceFn,
SubtractScalarFn, TakeFn,
CastFn, ComputeVTable, FillForwardFn, FilterFn, ScalarAtFn, SearchSortedFn,
SearchSortedUsizeFn, SliceFn, SubtractScalarFn, TakeFn,
};
use crate::ArrayData;

Expand Down Expand Up @@ -35,6 +35,10 @@ impl ComputeVTable for PrimitiveEncoding {
Some(self)
}

/// Exposes this encoding's own [SearchSortedUsizeFn] implementation through the compute
/// vtable, so `usize` searches can be dispatched to it.
fn search_sorted_usize_fn(&self) -> Option<&dyn SearchSortedUsizeFn<ArrayData>> {
Some(self)
}

fn slice_fn(&self) -> Option<&dyn SliceFn<ArrayData>> {
Some(self)
}
Expand Down
7 changes: 6 additions & 1 deletion vortex-array/src/array/primitive/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use vortex_scalar::Scalar;

use crate::array::primitive::PrimitiveArray;
use crate::array::PrimitiveEncoding;
use crate::compute::{IndexOrd, Len, SearchResult, SearchSorted, SearchSortedFn, SearchSortedSide};
use crate::compute::{
IndexOrd, Len, SearchResult, SearchSorted, SearchSortedFn, SearchSortedSide,
SearchSortedUsizeFn,
};
use crate::validity::Validity;
use crate::variants::PrimitiveArrayTrait;
use crate::{ArrayDType, ArrayLen};
Expand All @@ -33,7 +36,9 @@ impl SearchSortedFn<PrimitiveArray> for PrimitiveEncoding {
}
})
}
}

impl SearchSortedUsizeFn<PrimitiveArray> for PrimitiveEncoding {
#[allow(clippy::cognitive_complexity)]
fn search_sorted_usize(
&self,
Expand Down
26 changes: 24 additions & 2 deletions vortex-array/src/array/sparse/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ use crate::array::sparse::SparseArray;
use crate::array::{PrimitiveArray, SparseEncoding};
use crate::compute::{
scalar_at, search_sorted, take, ComputeVTable, FilterFn, FilterMask, InvertFn, ScalarAtFn,
SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, TakeOptions,
SearchResult, SearchSortedFn, SearchSortedSide, SearchSortedUsizeFn, SliceFn, TakeFn,
TakeOptions,
};
use crate::variants::PrimitiveArrayTrait;
use crate::{ArrayData, IntoArrayData, IntoArrayVariant};
use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant};

mod invert;
mod slice;
Expand All @@ -32,6 +33,10 @@ impl ComputeVTable for SparseEncoding {
Some(self)
}

/// Exposes this encoding's own [SearchSortedUsizeFn] implementation through the compute
/// vtable, so `usize` searches can be dispatched to it.
fn search_sorted_usize_fn(&self) -> Option<&dyn SearchSortedUsizeFn<ArrayData>> {
Some(self)
}

fn slice_fn(&self) -> Option<&dyn SliceFn<ArrayData>> {
Some(self)
}
Expand All @@ -50,6 +55,7 @@ impl ScalarAtFn<SparseArray> for SparseEncoding {
}
}

// FIXME(ngates): these are broken in a way that works for array patches, this will be fixed soon.
impl SearchSortedFn<SparseArray> for SparseEncoding {
fn search_sorted(
&self,
Expand All @@ -76,6 +82,22 @@ impl SearchSortedFn<SparseArray> for SparseEncoding {
}
}

// FIXME(ngates): these are broken in a way that works for array patches, this will be fixed soon.
impl SearchSortedUsizeFn<SparseArray> for SparseEncoding {
    /// Searches a sparse array for a `usize` value.
    ///
    /// The value is wrapped in a [`Scalar`] and cast to the array's dtype, then handed to the
    /// scalar-based [`SearchSortedFn::search_sorted`] implementation for this encoding.
    fn search_sorted_usize(
        &self,
        array: &SparseArray,
        value: usize,
        side: SearchSortedSide,
    ) -> VortexResult<SearchResult> {
        match Scalar::from(value).cast(array.dtype()) {
            Ok(target) => SearchSortedFn::search_sorted(self, array, &target, side),
            // A failed cast is taken to mean the target is too large for the dtype, so it is
            // reported as lying past the end of the array.
            Err(_) => Ok(SearchResult::NotFound(array.len())),
        }
    }
}

impl FilterFn<SparseArray> for SparseEncoding {
fn filter(&self, array: &SparseArray, mask: FilterMask) -> VortexResult<ArrayData> {
let buffer = mask.to_boolean_buffer()?;
Expand Down
7 changes: 7 additions & 0 deletions vortex-array/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ pub trait ComputeVTable {
None
}

/// Perform a search over an ordered array for a `usize` target.
///
/// Returns `None` by default; an encoding that implements the `usize` search overrides this
/// to return itself.
///
/// See: [SearchSortedUsizeFn].
fn search_sorted_usize_fn(&self) -> Option<&dyn SearchSortedUsizeFn<ArrayData>> {
None
}

/// Perform zero-copy slicing of an array.
///
/// See: [SliceFn].
Expand Down
73 changes: 47 additions & 26 deletions vortex-array/src/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,6 @@ pub trait SearchSortedFn<Array> {
side: SearchSortedSide,
) -> VortexResult<SearchResult>;

/// Searches `array` for a `usize` target by wrapping it in a [`Scalar`] and delegating to
/// [`Self::search_sorted`].
fn search_sorted_usize(
    &self,
    array: &Array,
    value: usize,
    side: SearchSortedSide,
) -> VortexResult<SearchResult> {
    self.search_sorted(array, &Scalar::from(value), side)
}

/// Bulk search for many values.
fn search_sorted_many(
&self,
Expand All @@ -128,6 +118,15 @@ pub trait SearchSortedFn<Array> {
.map(|value| self.search_sorted(array, value, side))
.try_collect()
}
}

pub trait SearchSortedUsizeFn<Array> {
/// Search `array` for a single `usize` `value`, using the given `side`.
///
/// Implementors must provide this method; it has no default body.
fn search_sorted_usize(
&self,
array: &Array,
value: usize,
side: SearchSortedSide,
) -> VortexResult<SearchResult>;

fn search_sorted_usize_many(
&self,
Expand Down Expand Up @@ -162,34 +161,40 @@ where
SearchSortedFn::search_sorted(encoding, array_ref, value, side)
}

fn search_sorted_usize(
fn search_sorted_many(
&self,
array: &ArrayData,
value: usize,
values: &[Scalar],
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
) -> VortexResult<Vec<SearchResult>> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
SearchSortedFn::search_sorted_usize(encoding, array_ref, value, side)
SearchSortedFn::search_sorted_many(encoding, array_ref, values, side)
}
}

fn search_sorted_many(
impl<E: Encoding> SearchSortedUsizeFn<ArrayData> for E
where
E: SearchSortedUsizeFn<E::Array>,
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn search_sorted_usize(
&self,
array: &ArrayData,
values: &[Scalar],
value: usize,
side: SearchSortedSide,
) -> VortexResult<Vec<SearchResult>> {
) -> VortexResult<SearchResult> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
SearchSortedFn::search_sorted_many(encoding, array_ref, values, side)
SearchSortedUsizeFn::search_sorted_usize(encoding, array_ref, value, side)
}

fn search_sorted_usize_many(
Expand All @@ -204,7 +209,7 @@ where
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
SearchSortedFn::search_sorted_usize_many(encoding, array_ref, values, side)
SearchSortedUsizeFn::search_sorted_usize_many(encoding, array_ref, values, side)
}
}

Expand All @@ -214,8 +219,8 @@ pub fn search_sorted<T: Into<Scalar>>(
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
let Ok(scalar) = target.into().cast(array.dtype()) else {
// If the cast fails, then the search value must be higher than the highest value in
// the array.
// Try to cast the target to the array's dtype. If the cast fails, we assume the target is
// too large for that dtype and therefore greater than the highest value in the array.
return Ok(SearchResult::NotFound(array.len()));
};

Expand Down Expand Up @@ -243,14 +248,28 @@ pub fn search_sorted_usize(
target: usize,
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
if let Some(f) = array.encoding().search_sorted_fn() {
if let Some(f) = array.encoding().search_sorted_usize_fn() {
return f.search_sorted_usize(array, target, side);
}

// Fallback to a generic search_sorted using scalar_at
// Otherwise, convert the target into a scalar to try the search_sorted_fn
let Ok(target) = Scalar::from(target).cast(array.dtype()) else {
return Ok(SearchResult::NotFound(array.len()));
};

// Try the non-usize search sorted
if let Some(f) = array.encoding().search_sorted_fn() {
return f.search_sorted(array, &target, side);
}

// Or fallback all the way to a generic search_sorted using scalar_at
if array.encoding().scalar_at_fn().is_some() {
let scalar = Scalar::primitive(target as u64, array.dtype().nullability());
return Ok(SearchSorted::search_sorted(array, &scalar, side));
// Try to downcast the usize to the array type, if the downcast fails, then we know the
// usize is too large and the value is greater than the highest value in the array.
let Ok(target) = target.cast(array.dtype()) else {
return Ok(SearchResult::NotFound(array.len()));
};
return Ok(SearchSorted::search_sorted(array, &target, side));
}

vortex_bail!(
Expand Down Expand Up @@ -287,7 +306,7 @@ pub fn search_sorted_usize_many(
targets: &[usize],
side: SearchSortedSide,
) -> VortexResult<Vec<SearchResult>> {
if let Some(f) = array.encoding().search_sorted_fn() {
if let Some(f) = array.encoding().search_sorted_usize_fn() {
return f.search_sorted_usize_many(array, targets, side);
}

Expand All @@ -299,6 +318,8 @@ pub fn search_sorted_usize_many(
}

pub trait IndexOrd<V> {
/// PartialOrd of the value at index `idx` with `elem`.
/// For example, if self\[idx\] > elem, return Some(Greater).
fn index_cmp(&self, idx: usize, elem: &V) -> Option<Ordering>;

fn index_lt(&self, idx: usize, elem: &V) -> bool {
Expand Down

0 comments on commit bdfe74c

Please sign in to comment.