Skip to content

Commit

Permalink
feat[rust, python]: allow setting by multiple values for boolean and …
Browse files Browse the repository at this point in the history
…utf8 physical types (#4889)
  • Loading branch information
ritchie46 authored Sep 18, 2022
1 parent 4ec099d commit b93083b
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 20 deletions.
142 changes: 129 additions & 13 deletions polars/polars-ops/src/chunked_array/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,64 @@ use polars_core::utils::arrow::bitmap::MutableBitmap;
use polars_core::utils::arrow::types::NativeType;

pub trait ChunkedSet<T: Copy> {
fn set_at_idx2<V>(self, idx: &[IdxSize], values: V) -> Series
fn set_at_idx2<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
where
V: IntoIterator<Item = Option<T>>;
}
fn check_sorted(idx: &[IdxSize]) -> PolarsResult<()> {
if idx.is_empty() {
return Ok(());
}
let mut sorted = true;
let mut previous = idx[0];
for &i in &idx[1..] {
if i < previous {
// we will not break here as that prevents SIMD
sorted = false;
}
previous = i;
}
if sorted {
Ok(())
} else {
Err(PolarsError::ComputeError(
"set indices must be sorted".into(),
))
}
}

fn check_bounds(idx: &[IdxSize], len: IdxSize) -> PolarsResult<()> {
let mut inbounds = true;

for &i in idx {
if i >= len {
// we will not break here as that prevents SIMD
inbounds = false;
}
}
if inbounds {
Ok(())
} else {
Err(PolarsError::ComputeError(
"set indices are out of bounds".into(),
))
}
}

trait PolarsOpsNumericType: PolarsNumericType {}

fn set_at_idx_impl<V, T: NativeType>(
impl PolarsOpsNumericType for UInt8Type {}
impl PolarsOpsNumericType for UInt16Type {}
impl PolarsOpsNumericType for UInt32Type {}
impl PolarsOpsNumericType for UInt64Type {}
impl PolarsOpsNumericType for Int8Type {}
impl PolarsOpsNumericType for Int16Type {}
impl PolarsOpsNumericType for Int32Type {}
impl PolarsOpsNumericType for Int64Type {}
impl PolarsOpsNumericType for Float32Type {}
impl PolarsOpsNumericType for Float64Type {}

unsafe fn set_at_idx_impl<V, T: NativeType>(
new_values_slice: &mut [T],
set_values: V,
arr: &mut PrimitiveArray<T>,
Expand All @@ -28,10 +80,10 @@ fn set_at_idx_impl<V, T: NativeType>(
for (idx, val) in idx.iter().zip(&mut values_iter) {
match val {
Some(value) => {
mut_validity.set(*idx as usize, true);
new_values_slice[*idx as usize] = value
mut_validity.set_unchecked(*idx as usize, true);
*new_values_slice.get_unchecked_mut(*idx as usize) = value
}
None => mut_validity.set(*idx as usize, false),
None => mut_validity.set_unchecked(*idx as usize, false),
}
}
mut_validity.into()
Expand All @@ -44,14 +96,14 @@ fn set_at_idx_impl<V, T: NativeType>(
if validity.is_empty() {
validity.extend_constant(len, true);
}
validity.set(*idx as usize, true);
new_values_slice[*idx as usize] = value
validity.set_unchecked(*idx as usize, true);
*new_values_slice.get_unchecked_mut(*idx as usize) = value
}
None => {
if validity.is_empty() {
validity.extend_constant(len, true);
}
validity.set(*idx as usize, false)
validity.set_unchecked(*idx as usize, false)
}
}
}
Expand All @@ -61,14 +113,15 @@ fn set_at_idx_impl<V, T: NativeType>(
}
}

impl<T: PolarsNumericType> ChunkedSet<T::Native> for ChunkedArray<T>
impl<T: PolarsOpsNumericType> ChunkedSet<T::Native> for ChunkedArray<T>
where
ChunkedArray<T>: IntoSeries,
{
fn set_at_idx2<V>(self, idx: &[IdxSize], values: V) -> Series
fn set_at_idx2<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
where
V: IntoIterator<Item = Option<T::Native>>,
{
check_bounds(idx, self.len() as IdxSize)?;
let mut ca = self.rechunk();
drop(self);

Expand All @@ -83,14 +136,77 @@ where

// reborrow because the bck does not allow it
let current_values = unsafe { &mut *std::slice::from_raw_parts_mut(ptr, len) };
set_at_idx_impl(current_values, values, arr, idx, len)
// Safety:
// we checked bounds
unsafe { set_at_idx_impl(current_values, values, arr, idx, len) };
}
None => {
let mut new_values = arr.values().as_slice().to_vec();
set_at_idx_impl(&mut new_values, values, arr, idx, len);
// Safety:
// we checked bounds
unsafe { set_at_idx_impl(&mut new_values, values, arr, idx, len) };
arr.set_values(new_values.into());
}
};
ca.into_series()
Ok(ca.into_series())
}
}

impl<'a> ChunkedSet<&'a str> for &'a Utf8Chunked {
fn set_at_idx2<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
where
V: IntoIterator<Item = Option<&'a str>>,
{
check_bounds(idx, self.len() as IdxSize)?;
check_sorted(idx)?;
let mut ca_iter = self.into_iter().enumerate();
let mut builder = Utf8ChunkedBuilder::new(self.name(), self.len(), self.get_values_size());

for (current_idx, current_value) in idx.iter().zip(values) {
for (cnt_idx, opt_val_self) in &mut ca_iter {
if cnt_idx == *current_idx as usize {
builder.append_option(current_value);
break;
} else {
builder.append_option(opt_val_self);
}
}
}
// the last idx is probably not the last value so we finish the iterator
for (_, opt_val_self) in ca_iter {
builder.append_option(opt_val_self);
}

let ca = builder.finish();
Ok(ca.into_series())
}
}
impl ChunkedSet<bool> for &BooleanChunked {
fn set_at_idx2<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
where
V: IntoIterator<Item = Option<bool>>,
{
check_bounds(idx, self.len() as IdxSize)?;
check_sorted(idx)?;
let mut ca_iter = self.into_iter().enumerate();
let mut builder = BooleanChunkedBuilder::new(self.name(), self.len());

for (current_idx, current_value) in idx.iter().zip(values) {
for (cnt_idx, opt_val_self) in &mut ca_iter {
if cnt_idx == *current_idx as usize {
builder.append_option(current_value);
break;
} else {
builder.append_option(opt_val_self);
}
}
}
// the last idx is probably not the last value so we finish the iterator
for (_, opt_val_self) in ca_iter {
builder.append_option(opt_val_self);
}

let ca = builder.finish();
Ok(ca.into_series())
}
}
2 changes: 2 additions & 0 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2490,6 +2490,8 @@ def set_at_idx(
| bool
| Sequence[int]
| Sequence[float]
| Sequence[bool]
| Sequence[str]
| Sequence[date]
| Sequence[datetime]
| date
Expand Down
10 changes: 3 additions & 7 deletions py-polars/src/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,15 @@ pub(crate) fn set_at_idx(mut s: Series, idx: &Series, values: &Series) -> Polars
DataType::Boolean => {
let ca = s.bool()?;
let values = values.bool()?;
let value = values.get(0);
ca.set_at_idx(idx.iter().copied(), value)
.map(|ca| ca.into_series())?
ca.set_at_idx2(idx, values)
}
DataType::Utf8 => {
let ca = s.utf8()?;
let values = values.utf8()?;
let value = values.get(0);
ca.set_at_idx(idx.iter().copied(), value)
.map(|ca| ca.into_series())?
ca.set_at_idx2(idx, values)
}
_ => panic!("not yet implemented for dtype: {}", logical_dtype),
};

s.cast(&logical_dtype)
s.and_then(|s| s.cast(&logical_dtype))
}
6 changes: 6 additions & 0 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,6 +1996,12 @@ def test_set_at_idx() -> None:
assert s.to_list() == ["a", "x", "x"]
assert s.set_at_idx([0, 2], 0.12345).to_list() == ["0.12345", "x", "0.12345"]

# set multiple values values
s = pl.Series(["z", "z", "z"])
assert s.set_at_idx([0, 1], ["a", "b"]).to_list() == ["a", "b", "z"]
s = pl.Series([True, False, True])
assert s.set_at_idx([0, 1], [False, True]).to_list() == [False, True, True]


def test_repr() -> None:
s = pl.Series("ints", [1001, 2002, 3003])
Expand Down

0 comments on commit b93083b

Please sign in to comment.