Skip to content

Commit

Permalink
Merge pull request #3 from LukeMathWalker/refactoring-smarter-debugging
Browse files Browse the repository at this point in the history
A different approach
  • Loading branch information
Andrew authored Apr 23, 2019
2 parents b535f28 + 83c9f08 commit 72e05d7
Showing 1 changed file with 96 additions and 118 deletions.
214 changes: 96 additions & 118 deletions src/arrayformat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,141 +8,116 @@
use std::fmt;
use super::{
ArrayBase,
Axis,
Data,
Dimension,
NdProducer,
Ix
};
use crate::dimension::IntoDimension;
use crate::aliases::Ix1;

const PRINT_ELEMENTS_LIMIT: Ix = 3;

fn get_overflow_axes(shape: &[Ix], limit: usize) -> Vec<usize> {
shape.iter()
.enumerate()
.rev()
.filter(|(_, axis_size)| **axis_size > 2 * limit)
.map(|(axis, _)| axis)
.collect()
fn format_1d_array<A, S, F>(
view: &ArrayBase<S, Ix1>,
f: &mut fmt::Formatter,
mut format: F,
limit: Ix) -> fmt::Result
where
F: FnMut(&A, &mut fmt::Formatter) -> fmt::Result,
S: Data<Elem=A>,
{
let to_be_printed = to_be_printed(view.len(), limit);

let n_to_be_printed = to_be_printed.len();

write!(f, "[")?;
for (j, index) in to_be_printed.into_iter().enumerate() {
match index {
PrintableCell::ElementIndex(i) => {
format(&view[i], f)?;
if j != n_to_be_printed - 1 {
write!(f, ", ")?;
}
},
PrintableCell::Ellipses => write!(f, "..., ")?,
}
}
write!(f, "]")?;
Ok(())
}

fn get_highest_axis_to_skip(overflow_axes: &Vec<usize>,
shape: &[Ix],
index: &[Ix],
limit: &usize) -> Option<usize> {
overflow_axes.iter()
.filter(|axis| {
if **axis == shape.len() - 1 {
return false
};
let sa_idx_max = shape.iter().skip(**axis).next().unwrap();
let sa_idx_val = index.iter().skip(**axis).next().unwrap();
sa_idx_val >= limit && sa_idx_val < &(sa_idx_max - limit)
})
.min()
.map(|v| *v)
enum PrintableCell {
ElementIndex(usize),
Ellipses,
}

fn get_highest_changed_axis(index: &[Ix], prev_index: &[Ix]) -> Option<usize> {
index.iter()
.take(index.len() - 1)
.zip(prev_index.iter())
.enumerate()
.filter(|(_, (a, b))| a != b)
.map(|(i, _)| i)
.next()
// Returns what indexes should be printed for a certain axis.
// If the axis is longer than 2 * limit, a `Ellipses` is inserted
// where indexes are being omitted.
fn to_be_printed(length: usize, limit: usize) -> Vec<PrintableCell> {
if length <= 2 * limit {
(0..length).map(|x| PrintableCell::ElementIndex(x)).collect()
} else {
let mut v: Vec<PrintableCell> = (0..limit).map(|x| PrintableCell::ElementIndex(x)).collect();
v.push(PrintableCell::Ellipses);
v.extend((length-limit..length).map(|x| PrintableCell::ElementIndex(x)));
v
}
}

fn format_array<A, S, D, F>(view: &ArrayBase<S, D>,
f: &mut fmt::Formatter,
mut format: F,
limit: Ix) -> fmt::Result
where F: FnMut(&A, &mut fmt::Formatter) -> fmt::Result,
D: Dimension,
S: Data<Elem=A>,
fn format_array<A, S, D, F>(
view: &ArrayBase<S, D>,
f: &mut fmt::Formatter,
mut format: F,
limit: Ix) -> fmt::Result
where
F: FnMut(&A, &mut fmt::Formatter) -> fmt::Result + Clone,
D: Dimension,
S: Data<Elem=A>,
{
if view.shape().is_empty() {
// Handle 0-dimensional array case first
return format(view.iter().next().unwrap(), f)
// If any of the axes has 0 length, we return the same empty array representation
// e.g. [[]] for 2-d arrays
if view.shape().iter().any(|&x| x == 0) {
write!(f, "{}{}", "[".repeat(view.ndim()), "]".repeat(view.ndim()))?;
return Ok(())
}

let overflow_axes: Vec<Ix> = get_overflow_axes(view.shape(), limit);

let ndim = view.ndim();
let nth_idx_max = view.shape()[ndim-1];

// None will be an empty iter.
let mut last_index = match view.dim().into_dimension().first_index() {
None => view.dim().into_dimension().clone(),
Some(ix) => ix,
};
write!(f, "{}", "[".repeat(ndim))?;
// Shows if ellipses for horizontal split were printed.
let mut printed_ellipses_h = vec![false; ndim];
// Shows if the row was printed for the first time after horizontal split.
let mut no_rows_after_skip_yet = false;

// Simply use the indexed iterator, and take the index wraparounds
// as cues for when to add []'s and how many to add.
for (index, elt) in view.indexed_iter() {
let index = index.into_dimension();

let skip_row_for_axis = get_highest_axis_to_skip(
&overflow_axes,
view.shape(),
index.slice(),
&limit
);
if skip_row_for_axis.is_some() {
no_rows_after_skip_yet = true;
}

let max_changed_idx = get_highest_changed_axis(index.slice(), last_index.slice());
if let Some(i) = max_changed_idx {
printed_ellipses_h.iter_mut().skip(i + 1).for_each(|e| { *e = false; });

if skip_row_for_axis.is_none() {
// New row.
// # of ['s needed
let n = ndim - i - 1;
if !no_rows_after_skip_yet {
write!(f, "{}", "]".repeat(n))?;
writeln!(f, ",")?;
match view.shape() {
// If it's 0 dimensional, we just print out the scalar
[] => format(view.iter().next().unwrap(), f)?,
// We delegate 1-dimensional arrays to a specialized function
[_] => format_1d_array(&view.view().into_dimensionality::<Ix1>().unwrap(), f, format, limit)?,
// For n-dimensional arrays, we proceed recursively
shape => {
// Cast into a dynamically dimensioned view
// This is required to be able to use `index_axis`
let view = view.view().into_dyn();
// We start by checking what indexes from the first axis should be printed
// We put a `None` in the middle if we are omitting elements
let to_be_printed = to_be_printed(shape[0], limit);

let n_to_be_printed = to_be_printed.len();

write!(f, "[")?;
for (j, index) in to_be_printed.into_iter().enumerate() {
match index {
PrintableCell::ElementIndex(i) => {
// Proceed recursively with the (n-1)-dimensional slice
format_array(
&view.index_axis(Axis(0), i), f, format.clone(), limit
)?;
// We need to add a separator after each slice,
// apart from the last one
if j != n_to_be_printed - 1 {
write!(f, ",\n ")?
}
},
PrintableCell::Ellipses => write!(f, "...,\n ")?
}
no_rows_after_skip_yet = false;
write!(f, "{}", " ".repeat(ndim - n))?;
write!(f, "{}", "[".repeat(n))?;
} else if !printed_ellipses_h[skip_row_for_axis.unwrap()] {
let ax = skip_row_for_axis.unwrap();
let n = ndim - i - 1;
write!(f, "{}", "]".repeat(n))?;
writeln!(f, ",")?;
write!(f, "{}", " ".repeat(ax + 1))?;
writeln!(f, "...,")?;
printed_ellipses_h[ax] = true;
}
last_index = index.clone();
}

if skip_row_for_axis.is_none() {
let nth_idx_op = index.slice().iter().last();
if overflow_axes.contains(&(ndim - 1)) {
let nth_idx_val = nth_idx_op.unwrap();
if nth_idx_val >= &limit && nth_idx_val < &(nth_idx_max - &limit) {
if nth_idx_val == &limit {
write!(f, ", ...")?;
}
continue;
}
}

if max_changed_idx.is_none() && !index.slice().iter().all(|x| *x == 0) {
write!(f, ", ")?;
}
format(elt, f)?;
write!(f, "]")?;
}
}
write!(f, "{}", "]".repeat(ndim))?;
Ok(())
}

Expand Down Expand Up @@ -240,15 +215,17 @@ mod formatting_with_omit {
let a: Array2<u32> = arr2(&[[], []]);
let actual_output = format!("{}", a);
let expected_output = String::from("[[]]");
assert_eq!(actual_output, expected_output);
print_output_diff(&expected_output, &actual_output);
assert_eq!(expected_output, actual_output);
}

#[test]
fn zero_length_axes() {
let a = Array3::<f32>::zeros((3, 0, 4));
let actual_output = format!("{}", a);
let expected_output = String::from("[[[]]]");
assert_eq!(actual_output, expected_output);
print_output_diff(&expected_output, &actual_output);
assert_eq!(expected_output, actual_output);
}

#[test]
Expand All @@ -257,7 +234,8 @@ mod formatting_with_omit {
let a = arr0(element);
let actual_output = format!("{}", a);
let expected_output = format!("{}", element);
assert_eq!(actual_output, expected_output);
print_output_diff(&expected_output, &actual_output);
assert_eq!(expected_output, actual_output);
}

#[test]
Expand Down

0 comments on commit 72e05d7

Please sign in to comment.