Skip to content

Commit

Permalink
sharded funcs WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-barrett committed Nov 7, 2024
1 parent 5dca8c7 commit 3970a17
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 33 deletions.
4 changes: 2 additions & 2 deletions src/lair/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ impl<F: PrimeField32> QueryRecord<F> {
let inv_func_queries = toplevel
.func_map
.iter()
.map(|(_, func)| {
if func.invertible {
.map(|(_, funcs)| {
if funcs.full_func.invertible {
Some(FxHashMap::default())
} else {
None
Expand Down
2 changes: 1 addition & 1 deletion src/lair/func_chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl<'a, F, C1: Chipset<F>, C2: Chipset<F>> FuncChip<'a, F, C1, C2> {
toplevel
.func_map
.values()
.map(|func| FuncChip::from_func(func, toplevel))
.map(|funcs| FuncChip::from_func(&funcs.full_func, toplevel))
.collect()
}

Expand Down
105 changes: 75 additions & 30 deletions src/lair/toplevel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,32 @@
use either::Either;
use p3_field::Field;
use rustc_hash::FxHashMap;
use rustc_hash::{FxHashMap, FxHashSet};

use super::{bytecode::*, chipset::Chipset, expr::*, map::Map, FxIndexMap, List, Name};

/// This struct holds the complete Lair function and the sharded functions, containing
/// only the paths belonging to the same group
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FuncStruct<F> {
pub(crate) full_func: Func<F>,
pub(crate) sharded_funcs: FxHashMap<ReturnGroup, Func<F>>,
}

impl<F: Clone> FuncStruct<F> {
pub fn from_func(full_func: Func<F>, groups: &FxHashSet<ReturnGroup>) -> Self {
let sharded_funcs = full_func.shard(groups);
Self {
full_func,
sharded_funcs,
}
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Toplevel<F, C1: Chipset<F>, C2: Chipset<F>> {
/// Lair functions reachable by the `Call` operator
pub(crate) func_map: FxIndexMap<Name, Func<F>>,
pub(crate) func_map: FxIndexMap<Name, FuncStruct<F>>,
/// Extern chips reachable by the `ExternCall` operator. The two different
/// chipset types can be used to encode native and custom chips.
pub(crate) chip_map: FxIndexMap<Name, Either<C1, C2>>,
Expand Down Expand Up @@ -42,8 +60,9 @@ impl<F: Field + Ord, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
.enumerate()
.map(|(i, func)| {
func.check(&info_map, &chip_map);
let cfunc = func.expand().compile(i, &info_map, &chip_map);
(func.name, cfunc)
let (cfunc, groups) = func.expand().compile(i, &info_map, &chip_map);
let func_struct = FuncStruct::from_func(cfunc, &groups);
(func.name, func_struct)
})
.collect();
Toplevel { func_map, chip_map }
Expand All @@ -58,17 +77,21 @@ impl<F: Field + Ord, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
impl<F, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
#[inline]
pub fn func_by_index(&self, i: usize) -> &Func<F> {
self.func_map
&self
.func_map
.get_index(i)
.unwrap_or_else(|| panic!("Func index {i} out of bounds"))
.1
.full_func
}

#[inline]
pub fn func_by_name(&self, name: &'static str) -> &Func<F> {
self.func_map
&self
.func_map
.get(&Name(name))
.unwrap_or_else(|| panic!("Func {name} not found"))
.full_func
}

#[inline]
Expand All @@ -85,7 +108,7 @@ impl<F, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
}
}

/// A map from `Var` its block identifier. Variables in this map are always bound
/// A map from `Var` to its block identifier. Variables in this map are always bound
type BindMap = FxHashMap<Var, usize>;

/// A map that tells whether a `Var`, from a certain block, has been used or not
Expand Down Expand Up @@ -171,8 +194,7 @@ struct ExpandCtx {
/// Context struct of `compile`
struct LinkCtx<'a, C1, C2> {
var_index: usize,
return_ident: usize,
return_idents: Vec<usize>,
return_groups: FxHashSet<ReturnGroup>,
link_map: LinkMap,
info_map: &'a FxIndexMap<Name, FuncInfo>,
chip_map: &'a FxIndexMap<Name, Either<C1, C2>>,
Expand Down Expand Up @@ -258,28 +280,28 @@ impl<F: Field + Ord> FuncE<F> {
func_index: usize,
info_map: &FxIndexMap<Name, FuncInfo>,
chip_map: &FxIndexMap<Name, Either<C1, C2>>,
) -> Func<F> {
let ctx = &mut LinkCtx {
) -> (Func<F>, FxHashSet<ReturnGroup>) {
let mut ctx = LinkCtx {
var_index: 0,
return_ident: 0,
return_idents: vec![],
return_groups: FxHashSet::default(),
link_map: FxHashMap::default(),
info_map,
chip_map,
};
self.input_params.iter().for_each(|var| {
link_new(var, ctx);
link_new(var, &mut ctx);
});
let body = self.body.compile(ctx);
Func {
let body = self.body.compile(&mut ctx);
let func = Func {
name: self.name,
invertible: self.invertible,
partial: self.partial,
index: func_index,
body,
input_size: self.input_params.total_size(),
output_size: self.output_size,
}
};
(func, ctx.return_groups)
}
}

Expand All @@ -304,20 +326,11 @@ impl<F: Field + Ord> BlockE<F> {
let mut ops = vec![];
self.ops.iter().for_each(|op| op.compile(&mut ops, ctx));
let ops = ops.into();
let saved_return_idents = std::mem::take(&mut ctx.return_idents);
let ctrl = self.ctrl.compile(ctx);
let block_return_idents = std::mem::take(&mut ctx.return_idents);
// sanity check
assert!(
!block_return_idents.is_empty(),
"A block must have at least one return ident"
);
ctx.return_idents = saved_return_idents;
ctx.return_idents.extend(&block_return_idents);
Block {
ops,
ctrl,
return_idents: block_return_idents.into(),
return_idents: [].into(),
}
}
}
Expand Down Expand Up @@ -533,9 +546,8 @@ impl<F: Field + Ord> CtrlE<F> {
.iter()
.flat_map(|arg| get_var(arg, ctx).to_vec())
.collect();
let ctrl = Ctrl::Return(ctx.return_ident, return_vec, *group);
ctx.return_idents.push(ctx.return_ident);
ctx.return_ident += 1;
let ctrl = Ctrl::Return(0, return_vec, *group);
ctx.return_groups.insert(*group);
ctrl
}
CtrlE::Choose(v, cases) => {
Expand Down Expand Up @@ -879,3 +891,36 @@ impl<F: Field + Ord> OpE<F> {
}
}
}

impl<F: Clone> Func<F> {
fn shard(&self, groups: &FxHashSet<ReturnGroup>) -> FxHashMap<ReturnGroup, Func<F>> {
let mut map = FxHashMap::default();
for group in groups.iter() {
let body = self
.body
.shard(*group, &mut 0)
.expect("Group {group} does not exist");
let func = Func { body, ..*self };
map.insert(*group, func);
}
map
}
}

impl<F: Clone> Block<F> {
fn shard(&self, group: ReturnGroup, return_ident: &mut usize) -> Option<Self> {
let (ctrl, return_idents) = self.ctrl.shard(group, return_ident)?;
let block = Block {
ctrl,
ops: self.ops.clone(),
return_idents,
};
Some(block)
}
}

impl<F: Clone> Ctrl<F> {
fn shard(&self, group: ReturnGroup, return_ident: &mut usize) -> Option<(Self, List<usize>)> {
todo!()
}
}

0 comments on commit 3970a17

Please sign in to comment.