Skip to content

Commit

Permalink
Merge pull request #1165 from geo-ant/feature/parallel-column-iterators
Browse files Browse the repository at this point in the history
Parallel Column Iterators with Rayon
  • Loading branch information
sebcrozet authored Jan 14, 2023
2 parents 9e58540 + 3a8c1bf commit 731fd0e
Show file tree
Hide file tree
Showing 7 changed files with 545 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nalgebra-ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ jobs:
steps:
- uses: actions/checkout@v2
- name: test
run: cargo test --features arbitrary,rand,serde-serialize,sparse,debug,io,compare,libm,proptest-support,slow-tests,rkyv-safe-deser;
run: cargo test --features arbitrary,rand,serde-serialize,sparse,debug,io,compare,libm,proptest-support,slow-tests,rkyv-safe-deser,rayon;
test-nalgebra-glm:
runs-on: ubuntu-latest
steps:
Expand Down
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ libm-force = [ "simba/libm_force" ]
macros = [ "nalgebra-macros" ]
cuda = [ "cust_core", "simba/cuda" ]


# Conversion
convert-mint = [ "mint" ]
convert-bytemuck = [ "bytemuck" ]
Expand Down Expand Up @@ -101,7 +102,7 @@ glam020 = { package = "glam", version = "0.20", optional = true }
glam021 = { package = "glam", version = "0.21", optional = true }
glam022 = { package = "glam", version = "0.22", optional = true }
cust_core = { version = "0.1", optional = true }

rayon = { version = "1.6", optional = true }

[dev-dependencies]
serde_json = "1.0"
Expand Down Expand Up @@ -137,3 +138,5 @@ lto = true
[package.metadata.docs.rs]
# Enable all the features when building the docs on docs.rs
all-features = true
# define the configuration attribute `docsrs`
rustdoc-args = ["--cfg", "docsrs"]
123 changes: 101 additions & 22 deletions src/base/iter.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
//! Matrix iterators.
// only enables the `doc_cfg` feature when
// the `docsrs` configuration attribute is defined
#![cfg_attr(docsrs, feature(doc_cfg))]

use core::fmt::Debug;
use core::ops::Range;
use std::iter::FusedIterator;
use std::marker::PhantomData;
use std::mem;
Expand Down Expand Up @@ -288,20 +294,40 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> ExactSizeIte
}

/*
*
* Column iterators.
*
*/
#[derive(Clone, Debug)]
/// An iterator through the columns of a matrix.
pub struct ColumnIter<'a, T, R: Dim, C: Dim, S: RawStorage<T, R, C>> {
mat: &'a Matrix<T, R, C, S>,
curr: usize,
range: Range<usize>,
}

impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> ColumnIter<'a, T, R, C, S> {
/// a new column iterator covering all columns of the matrix
pub(crate) fn new(mat: &'a Matrix<T, R, C, S>) -> Self {
ColumnIter { mat, curr: 0 }
ColumnIter {
mat,
range: 0..mat.ncols(),
}
}

pub(crate) fn split_at(self, index: usize) -> (Self, Self) {
// SAFETY: this makes sur the generated ranges are valid.
let split_pos = (self.range.start + index).min(self.range.end);

let left_iter = ColumnIter {
mat: self.mat,
range: self.range.start..split_pos,
};

let right_iter = ColumnIter {
mat: self.mat,
range: split_pos..self.range.end,
};

(left_iter, right_iter)
}
}

Expand All @@ -310,9 +336,10 @@ impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> Iterator for ColumnIter

#[inline]
fn next(&mut self) -> Option<Self::Item> {
if self.curr < self.mat.ncols() {
let res = self.mat.column(self.curr);
self.curr += 1;
debug_assert!(self.range.start <= self.range.end);
if self.range.start < self.range.end {
let res = self.mat.column(self.range.start);
self.range.start += 1;
Some(res)
} else {
None
Expand All @@ -321,15 +348,29 @@ impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> Iterator for ColumnIter

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
(
self.mat.ncols() - self.curr,
Some(self.mat.ncols() - self.curr),
)
let hint = self.range.len();
(hint, Some(hint))
}

#[inline]
fn count(self) -> usize {
self.mat.ncols() - self.curr
self.range.len()
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> DoubleEndedIterator
for ColumnIter<'a, T, R, C, S>
{
fn next_back(&mut self) -> Option<Self::Item> {
debug_assert!(self.range.start <= self.range.end);
if !self.range.is_empty() {
self.range.end -= 1;
debug_assert!(self.range.end < self.mat.ncols());
debug_assert!(self.range.end >= self.range.start);
Some(self.mat.column(self.range.end))
} else {
None
}
}
}

Expand All @@ -338,27 +379,47 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorage<T, R, C>> ExactSizeIterat
{
#[inline]
fn len(&self) -> usize {
self.mat.ncols() - self.curr
self.range.end - self.range.start
}
}

/// An iterator through the mutable columns of a matrix.
#[derive(Debug)]
pub struct ColumnIterMut<'a, T, R: Dim, C: Dim, S: RawStorageMut<T, R, C>> {
mat: *mut Matrix<T, R, C, S>,
curr: usize,
range: Range<usize>,
phantom: PhantomData<&'a mut Matrix<T, R, C, S>>,
}

impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> ColumnIterMut<'a, T, R, C, S> {
pub(crate) fn new(mat: &'a mut Matrix<T, R, C, S>) -> Self {
let range = 0..mat.ncols();
ColumnIterMut {
mat,
curr: 0,
phantom: PhantomData,
range,
phantom: Default::default(),
}
}

pub(crate) fn split_at(self, index: usize) -> (Self, Self) {
// SAFETY: this makes sur the generated ranges are valid.
let split_pos = (self.range.start + index).min(self.range.end);

let left_iter = ColumnIterMut {
mat: self.mat,
range: self.range.start..split_pos,
phantom: Default::default(),
};

let right_iter = ColumnIterMut {
mat: self.mat,
range: split_pos..self.range.end,
phantom: Default::default(),
};

(left_iter, right_iter)
}

fn ncols(&self) -> usize {
unsafe { (*self.mat).ncols() }
}
Expand All @@ -370,10 +431,11 @@ impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> Iterator
type Item = MatrixViewMut<'a, T, R, U1, S::RStride, S::CStride>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
if self.curr < self.ncols() {
let res = unsafe { (*self.mat).column_mut(self.curr) };
self.curr += 1;
fn next(&'_ mut self) -> Option<Self::Item> {
debug_assert!(self.range.start <= self.range.end);
if self.range.start < self.range.end {
let res = unsafe { (*self.mat).column_mut(self.range.start) };
self.range.start += 1;
Some(res)
} else {
None
Expand All @@ -382,12 +444,13 @@ impl<'a, T, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> Iterator

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
(self.ncols() - self.curr, Some(self.ncols() - self.curr))
let hint = self.range.len();
(hint, Some(hint))
}

#[inline]
fn count(self) -> usize {
self.ncols() - self.curr
self.range.len()
}
}

Expand All @@ -396,6 +459,22 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> ExactSizeIte
{
#[inline]
fn len(&self) -> usize {
self.ncols() - self.curr
self.range.len()
}
}

impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> DoubleEndedIterator
for ColumnIterMut<'a, T, R, C, S>
{
fn next_back(&mut self) -> Option<Self::Item> {
debug_assert!(self.range.start <= self.range.end);
if !self.range.is_empty() {
self.range.end -= 1;
debug_assert!(self.range.end < self.ncols());
debug_assert!(self.range.end >= self.range.start);
Some(unsafe { (*self.mat).column_mut(self.range.end) })
} else {
None
}
}
}
1 change: 1 addition & 0 deletions src/base/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ pub type MatrixCross<T, R1, C1, R2, C2> =
///
/// #### Iteration, map, and fold
/// - [Iteration on components, rows, and columns <span style="float:right;">`iter`, `column_iter`…</span>](#iteration-on-components-rows-and-columns)
/// - [Parallel iterators using rayon <span style="float:right;">`par_column_iter`, `par_column_iter_mut`…</span>](#parallel-iterators-using-rayon)
/// - [Elementwise mapping and folding <span style="float:right;">`map`, `fold`, `zip_map`…</span>](#elementwise-mapping-and-folding)
/// - [Folding or columns and rows <span style="float:right;">`compress_rows`, `compress_columns`…</span>](#folding-on-columns-and-rows)
///
Expand Down
3 changes: 3 additions & 0 deletions src/base/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ mod min_max;
/// Mechanisms for working with values that may not be initialized.
pub mod uninit;

#[cfg(feature = "rayon")]
pub mod par_iter;

#[cfg(feature = "rkyv-serialize-no-std")]
mod rkyv_wrappers;

Expand Down
Loading

0 comments on commit 731fd0e

Please sign in to comment.