Skip to content

Commit

Permalink
Rename Subgrid::optimize_static_nodes to optimize_nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
cschwan committed Oct 30, 2024
1 parent 23628b3 commit 54a59f3
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 35 deletions.
4 changes: 2 additions & 2 deletions pineappl/src/empty_subgrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl Subgrid for EmptySubgridV1 {
}
}

fn optimize_static_nodes(&mut self) {}
fn optimize_nodes(&mut self) {}
}

#[cfg(test)]
Expand All @@ -64,7 +64,7 @@ mod tests {
subgrid.merge(&EmptySubgridV1.into(), None);
subgrid.scale(2.0);
subgrid.symmetrize(1, 2);
subgrid.optimize_static_nodes();
subgrid.optimize_nodes();
assert_eq!(
subgrid.stats(),
Stats {
Expand Down
26 changes: 15 additions & 11 deletions pineappl/src/grid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,11 @@ bitflags! {
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct GridOptFlags: u32 {
/// Change the [`Subgrid`] type to optimize storage efficiency.
const OPTIMIZE_SUBGRID_TYPE = 0b1;
/// Recognize whether a subgrid was filled with events with a static scale and if this is
/// the case, optimize it by undoing the interpolation in the scale. This flag requires
/// [`Self::OPTIMIZE_SUBGRID_TYPE`] to be active.
const STATIC_SCALE_DETECTION = 0b10;
/// the case, optimize it by undoing the interpolation in the scale.
const OPTIMIZE_NODES = 0b1;
/// Change the [`Subgrid`] type to optimize storage efficiency.
const OPTIMIZE_SUBGRID_TYPE = 0b10;
/// If two channels differ by transposition of the two initial states and the functions
/// this grid is convolved with are the same for both initial states, this will merge one
/// channel into the other, with the correct transpositions.
Expand Down Expand Up @@ -832,9 +831,11 @@ impl Grid {
/// Optimizes the internal datastructures for space efficiency. The parameter `flags`
/// determines which optimizations are applied, see [`GridOptFlags`].
pub fn optimize_using(&mut self, flags: GridOptFlags) {
if flags.contains(GridOptFlags::OPTIMIZE_NODES) {
self.optimize_nodes();
}
if flags.contains(GridOptFlags::OPTIMIZE_SUBGRID_TYPE) {
let ssd = flags.contains(GridOptFlags::STATIC_SCALE_DETECTION);
self.optimize_subgrid_type(ssd);
self.optimize_subgrid_type();
}
if flags.contains(GridOptFlags::SYMMETRIZE_CHANNELS) {
self.symmetrize_channels();
Expand All @@ -850,17 +851,20 @@ impl Grid {
}
}

fn optimize_subgrid_type(&mut self, optimize_static_nodes: bool) {
fn optimize_nodes(&mut self) {
for subgrid in &mut self.subgrids {
subgrid.optimize_nodes();
}
}

fn optimize_subgrid_type(&mut self) {
for subgrid in &mut self.subgrids {
match subgrid {
// replace empty subgrids of any type with `EmptySubgridV1`
_ if subgrid.is_empty() => {
*subgrid = EmptySubgridV1.into();
}
_ => {
if optimize_static_nodes {
subgrid.optimize_static_nodes();
}
// TODO: check if we should remove this
*subgrid = ImportSubgridV1::from(&*subgrid).into();
}
Expand Down
2 changes: 1 addition & 1 deletion pineappl/src/import_subgrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl Subgrid for ImportSubgridV1 {
}
}

fn optimize_static_nodes(&mut self) {}
fn optimize_nodes(&mut self) {}
}

impl From<&SubgridEnum> for ImportSubgridV1 {
Expand Down
54 changes: 44 additions & 10 deletions pineappl/src/interp_subgrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::interpolation::{self, Interp};
use super::packed_array::PackedArray;
use super::subgrid::{Stats, Subgrid, SubgridEnum, SubgridIndexedIter};
use float_cmp::approx_eq;
use itertools::izip;
use serde::{Deserialize, Serialize};
use std::mem;

Expand Down Expand Up @@ -110,39 +111,64 @@ impl Subgrid for InterpSubgridV1 {
}
}

fn optimize_static_nodes(&mut self) {
fn optimize_nodes(&mut self) {
// find the optimal ranges in which the nodes are used
let ranges: Vec<_> = self.array.indexed_iter().fold(
self.node_values()
.iter()
.map(|values| values.len()..0)
.collect(),
|mut prev, (indices, _)| {
for (i, index) in indices.iter().enumerate() {
prev[i].start = prev[i].start.min(*index);
prev[i].end = prev[i].end.max(*index + 1);
}
prev
},
);

let mut new_array = PackedArray::new(
self.array
.shape()
ranges
.iter()
.zip(&self.static_nodes)
.map(|(&dim, static_node)| if static_node.is_some() { 1 } else { dim })
.map(|(range, static_node)| {
if static_node.is_some() {
1
} else {
range.clone().count()
}
})
.collect(),
);

for (mut index, value) in self.array.indexed_iter() {
for (idx, static_node) in index.iter_mut().zip(&self.static_nodes) {
for (idx, range, static_node) in izip!(&mut index, &ranges, &self.static_nodes) {
if static_node.is_some() {
*idx = 0;
} else {
*idx -= range.start;
}
}
new_array[index.as_slice()] += value;
}

self.array = new_array;

for (static_node, interp) in self.static_nodes.iter_mut().zip(&mut self.interps) {
if let &mut Some(value) = static_node {
*interp = Interp::new(
for (interp, static_node, range) in izip!(&mut self.interps, &mut self.static_nodes, ranges)
{
*interp = if let &mut Some(value) = static_node {
Interp::new(
value,
value,
1,
0,
interp.reweight_meth(),
interp.map(),
interp.interp_meth(),
);
}
)
} else {
interp.sub_interp(range)
};
}
}
}
Expand Down Expand Up @@ -228,5 +254,13 @@ mod tests {
bytes_per_value: mem::size_of::<f64>()
}
);

subgrid.optimize_nodes();

let node_values = subgrid.node_values();

assert_eq!(node_values[0].len(), 23);
assert_eq!(node_values[1].len(), 1);
assert_eq!(node_values[2].len(), 1);
}
}
17 changes: 15 additions & 2 deletions pineappl/src/interpolation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::packed_array::PackedArray;
use arrayvec::ArrayVec;
use serde::{Deserialize, Serialize};
use std::mem;
use std::ops::Range;

const MAX_INTERP_ORDER_PLUS_ONE: usize = 8;
const MAX_DIMENSIONS: usize = 8;
Expand Down Expand Up @@ -265,6 +266,20 @@ impl Interp {
pub const fn reweight_meth(&self) -> ReweightMeth {
self.reweight
}

/// TODO
#[must_use]
pub fn sub_interp(&self, range: Range<usize>) -> Self {
Self {
min: self.gety(range.start),
max: self.gety(range.end - 1),
nodes: range.clone().count(),
order: self.order,
reweight: self.reweight,
map: self.map,
interp_meth: self.interp_meth,
}
}
}

/// TODO
Expand Down Expand Up @@ -296,8 +311,6 @@ pub fn interpolate(
return false;
};

// TODO: add static value detection

let weight = weight
/ interps
.iter()
Expand Down
2 changes: 1 addition & 1 deletion pineappl/src/subgrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub trait Subgrid {
fn stats(&self) -> Stats;

/// TODO
fn optimize_static_nodes(&mut self);
fn optimize_nodes(&mut self);
}

/// Type to iterate over the non-zero contents of a subgrid. The tuple contains the indices of the
Expand Down
15 changes: 7 additions & 8 deletions pineappl/tests/drell_yan_lo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,13 +423,13 @@ fn perform_grid_tests(
.collect();

for (result, reference_after_ssd) in bins.iter().zip(reference_after_ssd.iter()) {
assert_approx_eq!(f64, *result, *reference_after_ssd, ulps = 24);
assert_approx_eq!(f64, *result, *reference_after_ssd, ulps = 32);
}

let bins = grid.convolve(&mut convolution_cache, &[], &[], &[], &[(1.0, 1.0, 1.0)]);

for (result, reference_after_ssd) in bins.iter().zip(reference_after_ssd.iter()) {
assert_approx_eq!(f64, *result, *reference_after_ssd, ulps = 24);
assert_approx_eq!(f64, *result, *reference_after_ssd, ulps = 32);
}

// TEST 9: `set_remapper`
Expand All @@ -451,7 +451,7 @@ fn perform_grid_tests(
grid.merge_bins(0..1)?;

for (result, reference_after_ssd) in bins.iter().zip(reference_after_ssd.iter()) {
assert_approx_eq!(f64, *result, *reference_after_ssd, ulps = 24);
assert_approx_eq!(f64, *result, *reference_after_ssd, ulps = 32);
}

// merge two bins with each other
Expand Down Expand Up @@ -501,7 +501,7 @@ fn perform_grid_tests(
.skip(2)
.take(6),
) {
assert_approx_eq!(f64, *result, reference_after_ssd, ulps = 16);
assert_approx_eq!(f64, *result, reference_after_ssd, ulps = 32);
}

Ok(())
Expand Down Expand Up @@ -709,14 +709,13 @@ fn grid_optimize() {
assert_eq!(node_values[1].len(), 6);
assert_eq!(node_values[2].len(), 6);

grid.optimize_using(GridOptFlags::OPTIMIZE_SUBGRID_TYPE | GridOptFlags::STATIC_SCALE_DETECTION);
grid.optimize_using(GridOptFlags::OPTIMIZE_NODES);

assert!(matches!(
grid.subgrids()[[0, 0, 0]],
SubgridEnum::ImportSubgridV1 { .. }
SubgridEnum::InterpSubgridV1 { .. }
));
// if `STATIC_SCALE_DETECTION` is present the scale dimension is better optimized
let node_values = grid.subgrids()[[0, 0, 0]].node_values();
let node_values = dbg!(grid.subgrids()[[0, 0, 0]].node_values());
assert_eq!(node_values[0].len(), 1);
assert_eq!(node_values[1].len(), 6);
assert_eq!(node_values[2].len(), 6);
Expand Down

0 comments on commit 54a59f3

Please sign in to comment.