From b93083bf7958f4bf3129237a11af6a149f14b9ff Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 18 Sep 2022 09:05:56 +0200 Subject: [PATCH] feat[rust, python]: allow setting by multiple values for boolean and utf8 physical types (#4889) --- polars/polars-ops/src/chunked_array/set.rs | 142 ++++++++++++++++++-- py-polars/polars/internals/series/series.py | 2 + py-polars/src/set.rs | 10 +- py-polars/tests/unit/test_series.py | 6 + 4 files changed, 140 insertions(+), 20 deletions(-) diff --git a/polars/polars-ops/src/chunked_array/set.rs b/polars/polars-ops/src/chunked_array/set.rs index 68ae5f920e5..57a5b41968f 100644 --- a/polars/polars-ops/src/chunked_array/set.rs +++ b/polars/polars-ops/src/chunked_array/set.rs @@ -5,12 +5,64 @@ use polars_core::utils::arrow::bitmap::MutableBitmap; use polars_core::utils::arrow::types::NativeType; pub trait ChunkedSet { - fn set_at_idx2(self, idx: &[IdxSize], values: V) -> Series + fn set_at_idx2(self, idx: &[IdxSize], values: V) -> PolarsResult where V: IntoIterator>; } +fn check_sorted(idx: &[IdxSize]) -> PolarsResult<()> { + if idx.is_empty() { + return Ok(()); + } + let mut sorted = true; + let mut previous = idx[0]; + for &i in &idx[1..] { + if i < previous { + // we will not break here as that prevents SIMD + sorted = false; + } + previous = i; + } + if sorted { + Ok(()) + } else { + Err(PolarsError::ComputeError( + "set indices must be sorted".into(), + )) + } +} + +fn check_bounds(idx: &[IdxSize], len: IdxSize) -> PolarsResult<()> { + let mut inbounds = true; + + for &i in idx { + if i >= len { + // we will not break here as that prevents SIMD + inbounds = false; + } + } + if inbounds { + Ok(()) + } else { + Err(PolarsError::ComputeError( + "set indices are out of bounds".into(), + )) + } +} + +trait PolarsOpsNumericType: PolarsNumericType {} -fn set_at_idx_impl( +impl PolarsOpsNumericType for UInt8Type {} +impl PolarsOpsNumericType for UInt16Type {} +impl PolarsOpsNumericType for UInt32Type {} +impl PolarsOpsNumericType for UInt64Type {} +impl PolarsOpsNumericType for Int8Type {} +impl PolarsOpsNumericType for Int16Type {} +impl PolarsOpsNumericType for Int32Type {} +impl PolarsOpsNumericType for Int64Type {} +impl PolarsOpsNumericType for Float32Type {} +impl PolarsOpsNumericType for Float64Type {} + +unsafe fn set_at_idx_impl( new_values_slice: &mut [T], set_values: V, arr: &mut PrimitiveArray, @@ -28,10 +80,10 @@ fn set_at_idx_impl( for (idx, val) in idx.iter().zip(&mut values_iter) { match val { Some(value) => { - mut_validity.set(*idx as usize, true); - new_values_slice[*idx as usize] = value + mut_validity.set_unchecked(*idx as usize, true); + *new_values_slice.get_unchecked_mut(*idx as usize) = value } - None => mut_validity.set(*idx as usize, false), + None => mut_validity.set_unchecked(*idx as usize, false), } } mut_validity.into() @@ -44,14 +96,14 @@ fn set_at_idx_impl( if validity.is_empty() { validity.extend_constant(len, true); } - validity.set(*idx as usize, true); - new_values_slice[*idx as usize] = value + validity.set_unchecked(*idx as usize, true); + *new_values_slice.get_unchecked_mut(*idx as usize) = value } None => { if validity.is_empty() { validity.extend_constant(len, true); } - validity.set(*idx as usize, false) + validity.set_unchecked(*idx as usize, false) } } } @@ -61,14 +113,15 @@ fn set_at_idx_impl( } } -impl ChunkedSet for ChunkedArray +impl ChunkedSet for ChunkedArray where ChunkedArray: IntoSeries, { - fn set_at_idx2(self, idx: &[IdxSize], values: V) -> Series + fn set_at_idx2(self, idx: &[IdxSize], values: V) -> PolarsResult where V: IntoIterator>, { + check_bounds(idx, self.len() as IdxSize)?; let mut ca = self.rechunk(); drop(self); @@ -83,14 +136,77 @@ where // reborrow because the bck does not allow it let current_values = unsafe { &mut *std::slice::from_raw_parts_mut(ptr, len) }; - set_at_idx_impl(current_values, values, arr, idx, len) + // Safety: + // we checked bounds + unsafe { set_at_idx_impl(current_values, values, arr, idx, len) }; } None => { let mut new_values = arr.values().as_slice().to_vec(); - set_at_idx_impl(&mut new_values, values, arr, idx, len); + // Safety: + // we checked bounds + unsafe { set_at_idx_impl(&mut new_values, values, arr, idx, len) }; arr.set_values(new_values.into()); } }; - ca.into_series() + Ok(ca.into_series()) + } +} + +impl<'a> ChunkedSet<&'a str> for &'a Utf8Chunked { + fn set_at_idx2(self, idx: &[IdxSize], values: V) -> PolarsResult + where + V: IntoIterator>, + { + check_bounds(idx, self.len() as IdxSize)?; + check_sorted(idx)?; + let mut ca_iter = self.into_iter().enumerate(); + let mut builder = Utf8ChunkedBuilder::new(self.name(), self.len(), self.get_values_size()); + + for (current_idx, current_value) in idx.iter().zip(values) { + for (cnt_idx, opt_val_self) in &mut ca_iter { + if cnt_idx == *current_idx as usize { + builder.append_option(current_value); + break; + } else { + builder.append_option(opt_val_self); + } + } + } + // the last idx is probably not the last value so we finish the iterator + for (_, opt_val_self) in ca_iter { + builder.append_option(opt_val_self); + } + + let ca = builder.finish(); + Ok(ca.into_series()) + } +} +impl ChunkedSet for &BooleanChunked { + fn set_at_idx2(self, idx: &[IdxSize], values: V) -> PolarsResult + where + V: IntoIterator>, + { + check_bounds(idx, self.len() as IdxSize)?; + check_sorted(idx)?; + let mut ca_iter = self.into_iter().enumerate(); + let mut builder = BooleanChunkedBuilder::new(self.name(), self.len()); + + for (current_idx, current_value) in idx.iter().zip(values) { + for (cnt_idx, opt_val_self) in &mut ca_iter { + if cnt_idx == *current_idx as usize { + builder.append_option(current_value); + break; + } else { + builder.append_option(opt_val_self); + } + } + } + // the last idx is probably not the last value so we finish the iterator + for (_, opt_val_self) in ca_iter { + builder.append_option(opt_val_self); + } + + let ca = builder.finish(); + Ok(ca.into_series()) } } diff --git a/py-polars/polars/internals/series/series.py b/py-polars/polars/internals/series/series.py index 09ea3c20e45..820bac8dd58 100644 --- a/py-polars/polars/internals/series/series.py +++ b/py-polars/polars/internals/series/series.py @@ -2490,6 +2490,8 @@ def set_at_idx( | bool | Sequence[int] | Sequence[float] + | Sequence[bool] + | Sequence[str] | Sequence[date] | Sequence[datetime] | date diff --git a/py-polars/src/set.rs b/py-polars/src/set.rs index 3e1d4917803..fcc081afa73 100644 --- a/py-polars/src/set.rs +++ b/py-polars/src/set.rs @@ -77,19 +77,15 @@ pub(crate) fn set_at_idx(mut s: Series, idx: &Series, values: &Series) -> Polars DataType::Boolean => { let ca = s.bool()?; let values = values.bool()?; - let value = values.get(0); - ca.set_at_idx(idx.iter().copied(), value) - .map(|ca| ca.into_series())? + ca.set_at_idx2(idx, values) } DataType::Utf8 => { let ca = s.utf8()?; let values = values.utf8()?; - let value = values.get(0); - ca.set_at_idx(idx.iter().copied(), value) - .map(|ca| ca.into_series())? + ca.set_at_idx2(idx, values) } _ => panic!("not yet implemented for dtype: {}", logical_dtype), }; - s.cast(&logical_dtype) + s.and_then(|s| s.cast(&logical_dtype)) } diff --git a/py-polars/tests/unit/test_series.py b/py-polars/tests/unit/test_series.py index 125bef666ff..6c9d3a69f62 100644 --- a/py-polars/tests/unit/test_series.py +++ b/py-polars/tests/unit/test_series.py @@ -1996,6 +1996,12 @@ def test_set_at_idx() -> None: assert s.to_list() == ["a", "x", "x"] assert s.set_at_idx([0, 2], 0.12345).to_list() == ["0.12345", "x", "0.12345"] + # set multiple values values + s = pl.Series(["z", "z", "z"]) + assert s.set_at_idx([0, 1], ["a", "b"]).to_list() == ["a", "b", "z"] + s = pl.Series([True, False, True]) + assert s.set_at_idx([0, 1], [False, True]).to_list() == [False, True, True] + def test_repr() -> None: s = pl.Series("ints", [1001, 2002, 3003])