Skip to content

Commit

Permalink
Fix co_broadcast in operator overloading
Browse files Browse the repository at this point in the history
  • Loading branch information
SparrowLii committed Jan 20, 2021
1 parent 2af780f commit 979d6df
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 27 deletions.
105 changes: 105 additions & 0 deletions src/dimension/broadcast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use crate::error::*;
use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};

/// Calculate the co_broadcast shape of two dimensions. Return error if shapes are
/// not compatible.
fn broadcast_shape<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, ShapeError>
where
D1: Dimension,
D2: Dimension,
Output: Dimension,
{
let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim());
// Swap the order if d2 is longer.
if overflow {
return broadcast_shape::<D2, D1, Output>(shape2, shape1);
}
// The output should be the same length as shape1.
let mut out = Output::zeros(shape1.ndim());
let out_slice = out.slice_mut();
let s1 = shape1.slice();
let s2 = shape2.slice();
// Uses the [NumPy broadcasting rules]
// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
//
// Zero dimension element is not in the original rules of broadcasting.
// We currently treat it as the same as 1. Especially, when one side is
// zero with one side is empty, or both sides are zero, the result will
// remain zero.
for i in 0..shape1.ndim() {
out_slice[i] = s1[i];
}
for i in 0..shape2.ndim() {
if out_slice[i + k] != s2[i] && s2[i] != 0 {
if out_slice[i + k] <= 1 {
out_slice[i + k] = s2[i]
} else if s2[i] != 1 {
return Err(from_kind(ErrorKind::IncompatibleShape));
}
}
}
Ok(out)
}

pub trait BroadcastShape<Other: Dimension>: Dimension {
/// The resulting dimension type after broadcasting.
type BroadcastOutput: Dimension;

/// Determines the shape after broadcasting the dimensions together.
///
/// If the dimensions are not compatible, returns `Err`.
///
/// Uses the [NumPy broadcasting rules]
/// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
fn broadcast_shape(&self, other: &Other) -> Result<Self::BroadcastOutput, ShapeError> {
broadcast_shape::<Self, Other, Self::BroadcastOutput>(self, other)
}
}

/// Dimensions of the same type remain unchanged when co_broadcast.
/// So you can directly use D as the resulting type.
/// (Instead of <D as BroadcastShape<D>>::BroadcastOutput)
impl<D: Dimension> BroadcastShape<D> for D {
type BroadcastOutput = D;
}

macro_rules! impl_broadcast_distinct_fixed {
($smaller:ty, $larger:ty) => {
impl BroadcastShape<$larger> for $smaller {
type BroadcastOutput = $larger;
}

impl BroadcastShape<$smaller> for $larger {
type BroadcastOutput = $larger;
}
};
}

impl_broadcast_distinct_fixed!(Ix0, Ix1);
impl_broadcast_distinct_fixed!(Ix0, Ix2);
impl_broadcast_distinct_fixed!(Ix0, Ix3);
impl_broadcast_distinct_fixed!(Ix0, Ix4);
impl_broadcast_distinct_fixed!(Ix0, Ix5);
impl_broadcast_distinct_fixed!(Ix0, Ix6);
impl_broadcast_distinct_fixed!(Ix1, Ix2);
impl_broadcast_distinct_fixed!(Ix1, Ix3);
impl_broadcast_distinct_fixed!(Ix1, Ix4);
impl_broadcast_distinct_fixed!(Ix1, Ix5);
impl_broadcast_distinct_fixed!(Ix1, Ix6);
impl_broadcast_distinct_fixed!(Ix2, Ix3);
impl_broadcast_distinct_fixed!(Ix2, Ix4);
impl_broadcast_distinct_fixed!(Ix2, Ix5);
impl_broadcast_distinct_fixed!(Ix2, Ix6);
impl_broadcast_distinct_fixed!(Ix3, Ix4);
impl_broadcast_distinct_fixed!(Ix3, Ix5);
impl_broadcast_distinct_fixed!(Ix3, Ix6);
impl_broadcast_distinct_fixed!(Ix4, Ix5);
impl_broadcast_distinct_fixed!(Ix4, Ix6);
impl_broadcast_distinct_fixed!(Ix5, Ix6);
impl_broadcast_distinct_fixed!(Ix0, IxDyn);
impl_broadcast_distinct_fixed!(Ix1, IxDyn);
impl_broadcast_distinct_fixed!(Ix2, IxDyn);
impl_broadcast_distinct_fixed!(Ix3, IxDyn);
impl_broadcast_distinct_fixed!(Ix4, IxDyn);
impl_broadcast_distinct_fixed!(Ix5, IxDyn);
impl_broadcast_distinct_fixed!(Ix6, IxDyn);
2 changes: 2 additions & 0 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use num_integer::div_floor;

pub use self::axes::{axes_of, Axes, AxisDescription};
pub use self::axis::Axis;
pub use self::broadcast::BroadcastShape;
pub use self::conversion::IntoDimension;
pub use self::dim::*;
pub use self::dimension_trait::Dimension;
Expand All @@ -28,6 +29,7 @@ use std::mem;
mod macros;
mod axes;
mod axis;
mod broadcast;
mod conversion;
pub mod dim;
mod dimension_trait;
Expand Down
85 changes: 63 additions & 22 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::dimension::BroadcastShape;
use num_complex::Complex;

/// Elements that can be used as direct operands in arithmetic with arrays.
Expand Down Expand Up @@ -53,24 +54,48 @@ macro_rules! impl_binary_op(
/// Perform elementwise
#[doc=$doc]
/// between `self` and `rhs`,
/// and return the result (based on `self`).
///
/// `self` must be an `Array` or `ArcArray`.
/// and return the result.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension,
D: Dimension + BroadcastShape<E>,
E: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> ArrayBase<S, D>
type Output = Array<A, <D as BroadcastShape<E>>::BroadcastOutput>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
{
self.$mth(&rhs)
}
}

/// Perform elementwise
#[doc=$doc]
/// between reference `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension + BroadcastShape<E>,
E: Dimension,
{
type Output = Array<A, <D as BroadcastShape<E>>::BroadcastOutput>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
{
self.$mth(&rhs)
}
Expand All @@ -79,27 +104,34 @@ where
/// Perform elementwise
#[doc=$doc]
/// between `self` and reference `rhs`,
/// and return the result (based on `self`).
/// and return the result.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension,
D: Dimension + BroadcastShape<E>,
E: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(mut self, rhs: &ArrayBase<S2, E>) -> ArrayBase<S, D>
type Output = Array<A, <D as BroadcastShape<E>>::BroadcastOutput>;
fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
{
self.zip_mut_with(rhs, |x, y| {
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
let mut self_ = if shape.slice() == self.dim.slice() {
self.into_owned().into_dimensionality::<<D as BroadcastShape<E>>::BroadcastOutput>().unwrap()
} else {
self.broadcast(shape).unwrap().to_owned()
};
self_.zip_mut_with(rhs, |x, y| {
*x = x.clone() $operator y.clone();
});
self
self_
}
}

Expand All @@ -108,7 +140,8 @@ where
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
Expand All @@ -117,13 +150,21 @@ where
B: Clone,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension,
D: Dimension + BroadcastShape<E>,
E: Dimension,
{
type Output = Array<A, D>;
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Array<A, D> {
// FIXME: Can we co-broadcast arrays here? And how?
self.to_owned().$mth(rhs)
type Output = Array<A, <D as BroadcastShape<E>>::BroadcastOutput>;
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
let mut self_ = if shape.slice() == self.dim.slice() {
self.to_owned().into_dimensionality::<<D as BroadcastShape<E>>::BroadcastOutput>().unwrap()
} else {
self.broadcast(shape).unwrap().to_owned()
};
self_.zip_mut_with(rhs, |x, y| {
*x = x.clone() $operator y.clone();
});
self_
}
}

Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ use std::marker::PhantomData;
use alloc::sync::Arc;

pub use crate::dimension::dim::*;
pub use crate::dimension::BroadcastShape;
pub use crate::dimension::{Axis, AxisDescription, Dimension, IntoDimension, RemoveAxis};

pub use crate::dimension::IxDynImpl;
Expand Down
5 changes: 3 additions & 2 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::ops::{Add, Div, Mul};

use crate::imp_prelude::*;
use crate::itertools::enumerate;
use crate::numeric_util;
use crate::{numeric_util, BroadcastShape};

/// # Numerical Methods for Arrays
impl<A, S, D> ArrayBase<S, D>
Expand Down Expand Up @@ -283,10 +283,11 @@ where
/// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
/// );
/// ```
pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, <D::Smaller as BroadcastShape<Ix0>>::BroadcastOutput>>
where
A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
D: RemoveAxis,
D::Smaller: BroadcastShape<Ix0>,
{
let axis_length = self.len_of(axis);
if axis_length == 0 {
Expand Down
Loading

0 comments on commit 979d6df

Please sign in to comment.