Skip to content

Commit

Permalink
Merge pull request #502 from jturner314/refactor-iterators
Browse files Browse the repository at this point in the history
Refactor and improve iterators
  • Loading branch information
jturner314 authored Oct 24, 2018
2 parents 408f42b + 0948409 commit 7df47d2
Show file tree
Hide file tree
Showing 7 changed files with 370 additions and 369 deletions.
38 changes: 15 additions & 23 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,9 @@ use imp_prelude::*;

use arraytraits;
use dimension;
use iterators;
use error::{self, ShapeError, ErrorKind};
use dimension::IntoDimension;
use dimension::{abs_index, axes_of, Axes, do_slice, merge_axes, stride_offset};
use iterators::{
new_lanes,
new_lanes_mut,
exact_chunks_of,
exact_chunks_mut_of,
windows
};
use zip::Zip;

use {
Expand Down Expand Up @@ -676,7 +668,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
pub fn genrows(&self) -> Lanes<A, D::Smaller> {
let mut n = self.ndim();
if n == 0 { n += 1; }
new_lanes(self.view(), Axis(n - 1))
Lanes::new(self.view(), Axis(n - 1))
}

/// Return a producer and iterable that traverses over the *generalized*
Expand All @@ -688,7 +680,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
{
let mut n = self.ndim();
if n == 0 { n += 1; }
new_lanes_mut(self.view_mut(), Axis(n - 1))
LanesMut::new(self.view_mut(), Axis(n - 1))
}

/// Return a producer and iterable that traverses over the *generalized*
Expand Down Expand Up @@ -718,7 +710,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
/// }
/// ```
pub fn gencolumns(&self) -> Lanes<A, D::Smaller> {
new_lanes(self.view(), Axis(0))
Lanes::new(self.view(), Axis(0))
}

/// Return a producer and iterable that traverses over the *generalized*
Expand All @@ -728,7 +720,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
pub fn gencolumns_mut(&mut self) -> LanesMut<A, D::Smaller>
where S: DataMut
{
new_lanes_mut(self.view_mut(), Axis(0))
LanesMut::new(self.view_mut(), Axis(0))
}

/// Return a producer and iterable that traverses over all 1D lanes
Expand Down Expand Up @@ -760,7 +752,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
/// assert_eq!(inner2.into_iter().next().unwrap(), aview1(&[0, 1, 2]));
/// ```
pub fn lanes(&self, axis: Axis) -> Lanes<A, D::Smaller> {
new_lanes(self.view(), axis)
Lanes::new(self.view(), axis)
}

/// Return a producer and iterable that traverses over all 1D lanes
Expand All @@ -770,7 +762,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
pub fn lanes_mut(&mut self, axis: Axis) -> LanesMut<A, D::Smaller>
where S: DataMut
{
new_lanes_mut(self.view_mut(), axis)
LanesMut::new(self.view_mut(), axis)
}


Expand Down Expand Up @@ -819,7 +811,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
pub fn axis_iter(&self, axis: Axis) -> AxisIter<A, D::Smaller>
where D: RemoveAxis,
{
iterators::new_axis_iter(self.view(), axis.index())
AxisIter::new(self.view(), axis)
}


Expand All @@ -834,7 +826,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
where S: DataMut,
D: RemoveAxis,
{
iterators::new_axis_iter_mut(self.view_mut(), axis.index())
AxisIterMut::new(self.view_mut(), axis)
}


Expand Down Expand Up @@ -865,7 +857,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
/// [[26, 27]]]));
/// ```
pub fn axis_chunks_iter(&self, axis: Axis, size: usize) -> AxisChunksIter<A, D> {
iterators::new_chunk_iter(self.view(), axis.index(), size)
AxisChunksIter::new(self.view(), axis, size)
}

/// Return an iterator that traverses over `axis` by chunks of `size`,
Expand All @@ -878,7 +870,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
-> AxisChunksIterMut<A, D>
where S: DataMut
{
iterators::new_chunk_iter_mut(self.view_mut(), axis.index(), size)
AxisChunksIterMut::new(self.view_mut(), axis, size)
}

/// Return an exact chunks producer (and iterable).
Expand All @@ -895,7 +887,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
pub fn exact_chunks<E>(&self, chunk_size: E) -> ExactChunks<A, D>
where E: IntoDimension<Dim=D>,
{
exact_chunks_of(self.view(), chunk_size)
ExactChunks::new(self.view(), chunk_size)
}

/// Return an exact chunks producer (and iterable).
Expand Down Expand Up @@ -934,7 +926,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
where E: IntoDimension<Dim=D>,
S: DataMut
{
exact_chunks_mut_of(self.view_mut(), chunk_size)
ExactChunksMut::new(self.view_mut(), chunk_size)
}

/// Return a window producer and iterable.
Expand All @@ -954,7 +946,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
pub fn windows<E>(&self, window_size: E) -> Windows<A, D>
where E: IntoDimension<Dim=D>
{
windows(self.view(), window_size)
Windows::new(self.view(), window_size)
}

// Return (length, stride) for diagonal
Expand Down Expand Up @@ -1597,8 +1589,8 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
// break the arrays up into their inner rows
let n = self.ndim();
let dim = self.raw_dim();
Zip::from(new_lanes_mut(self.view_mut(), Axis(n - 1)))
.and(new_lanes(rhs.broadcast_assume(dim), Axis(n - 1)))
Zip::from(LanesMut::new(self.view_mut(), Axis(n - 1)))
.and(Lanes::new(rhs.broadcast_assume(dim), Axis(n - 1)))
.apply(move |s_row, r_row| {
Zip::from(s_row).and(r_row).apply(|a, b| f(a, b))
});
Expand Down
15 changes: 7 additions & 8 deletions src/impl_views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ use {
Baseiter,
};

use iter;
use iterators;
use iter::{self, AxisIter, AxisIterMut};

/// Methods for read-only array views.
impl<'a, A, D> ArrayView<'a, A, D>
Expand Down Expand Up @@ -469,15 +468,15 @@ impl<'a, A, D> ArrayView<'a, A, D>
}

#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<'a, A, D> {
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
unsafe {
Baseiter::new(self.ptr, self.dim, self.strides)
}
}

#[inline]
pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> {
ElementsBase { inner: self.into_base_iter() }
ElementsBase::new(self)
}

pub(crate) fn into_iter_(self) -> Iter<'a, A, D> {
Expand All @@ -490,7 +489,7 @@ impl<'a, A, D> ArrayView<'a, A, D>
pub fn into_outer_iter(self) -> iter::AxisIter<'a, A, D::Smaller>
where D: RemoveAxis,
{
iterators::new_outer_iter(self)
AxisIter::new(self, Axis(0))
}

}
Expand Down Expand Up @@ -519,15 +518,15 @@ impl<'a, A, D> ArrayViewMut<'a, A, D>
}

#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<'a, A, D> {
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
unsafe {
Baseiter::new(self.ptr, self.dim, self.strides)
}
}

#[inline]
pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> {
ElementsBaseMut { inner: self.into_base_iter() }
ElementsBaseMut::new(self)
}

pub(crate) fn into_slice_(self) -> Result<&'a mut [A], Self> {
Expand All @@ -550,7 +549,7 @@ impl<'a, A, D> ArrayViewMut<'a, A, D>
pub fn into_outer_iter(self) -> iter::AxisIterMut<'a, A, D::Smaller>
where D: RemoveAxis,
{
iterators::new_outer_iter_mut(self)
AxisIterMut::new(self, Axis(0))
}
}

85 changes: 46 additions & 39 deletions src/iterators/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,30 @@ pub struct ExactChunks<'a, A: 'a, D> {
inner_strides: D,
}

/// **Panics** if any chunk dimension is zero<br>
pub fn exact_chunks_of<A, D, E>(mut a: ArrayView<A, D>, chunk: E) -> ExactChunks<A, D>
where D: Dimension,
E: IntoDimension<Dim=D>,
{
let chunk = chunk.into_dimension();
ndassert!(a.ndim() == chunk.ndim(),
concat!("Chunk dimension {} does not match array dimension {} ",
"(with array of shape {:?})"),
chunk.ndim(), a.ndim(), a.shape());
for i in 0..a.ndim() {
a.dim[i] /= chunk[i];
}
let inner_strides = a.raw_strides();
a.strides *= &chunk;
impl<'a, A, D: Dimension> ExactChunks<'a, A, D> {
/// Creates a new exact chunks producer.
///
/// **Panics** if any chunk dimension is zero
pub(crate) fn new<E>(mut a: ArrayView<'a, A, D>, chunk: E) -> Self
where
E: IntoDimension<Dim = D>,
{
let chunk = chunk.into_dimension();
ndassert!(a.ndim() == chunk.ndim(),
concat!("Chunk dimension {} does not match array dimension {} ",
"(with array of shape {:?})"),
chunk.ndim(), a.ndim(), a.shape());
for i in 0..a.ndim() {
a.dim[i] /= chunk[i];
}
let inner_strides = a.raw_strides();
a.strides *= &chunk;

ExactChunks {
base: a,
chunk: chunk,
inner_strides: inner_strides,
ExactChunks {
base: a,
chunk: chunk,
inner_strides: inner_strides,
}
}
}

Expand Down Expand Up @@ -117,27 +121,30 @@ pub struct ExactChunksMut<'a, A: 'a, D> {
inner_strides: D,
}

/// **Panics** if any chunk dimension is zero<br>
pub fn exact_chunks_mut_of<A, D, E>(mut a: ArrayViewMut<A, D>, chunk: E)
-> ExactChunksMut<A, D>
where D: Dimension,
E: IntoDimension<Dim=D>,
{
let chunk = chunk.into_dimension();
ndassert!(a.ndim() == chunk.ndim(),
concat!("Chunk dimension {} does not match array dimension {} ",
"(with array of shape {:?})"),
chunk.ndim(), a.ndim(), a.shape());
for i in 0..a.ndim() {
a.dim[i] /= chunk[i];
}
let inner_strides = a.raw_strides();
a.strides *= &chunk;
impl<'a, A, D: Dimension> ExactChunksMut<'a, A, D> {
/// Creates a new exact chunks producer.
///
/// **Panics** if any chunk dimension is zero
pub(crate) fn new<E>(mut a: ArrayViewMut<'a, A, D>, chunk: E) -> Self
where
E: IntoDimension<Dim = D>,
{
let chunk = chunk.into_dimension();
ndassert!(a.ndim() == chunk.ndim(),
concat!("Chunk dimension {} does not match array dimension {} ",
"(with array of shape {:?})"),
chunk.ndim(), a.ndim(), a.shape());
for i in 0..a.ndim() {
a.dim[i] /= chunk[i];
}
let inner_strides = a.raw_strides();
a.strides *= &chunk;

ExactChunksMut {
base: a,
chunk: chunk,
inner_strides: inner_strides,
ExactChunksMut {
base: a,
chunk: chunk,
inner_strides: inner_strides,
}
}
}

Expand Down
Loading

0 comments on commit 7df47d2

Please sign in to comment.