diff --git a/crates/stark-backend/src/air_builders/symbolic/dag.rs b/crates/stark-backend/src/air_builders/symbolic/dag.rs new file mode 100644 index 0000000000..a069f4ca2a --- /dev/null +++ b/crates/stark-backend/src/air_builders/symbolic/dag.rs @@ -0,0 +1,290 @@ +use std::sync::Arc; + +use p3_field::Field; +use rustc_hash::FxHashMap; +use serde::{Deserialize, Serialize}; + +use crate::air_builders::symbolic::{ + symbolic_expression::SymbolicExpression, symbolic_variable::SymbolicVariable, +}; + +/// A node in symbolic expression DAG. +/// Basically replace `Arc`s in `SymbolicExpression` with node IDs. +/// Intended to be serializable and deserializable. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(bound = "F: Field")] +#[repr(C)] +pub enum SymbolicExpressionNode { + Variable(SymbolicVariable), + IsFirstRow, + IsLastRow, + IsTransition, + Constant(F), + Add { + left_idx: usize, + right_idx: usize, + degree_multiple: usize, + }, + Sub { + left_idx: usize, + right_idx: usize, + degree_multiple: usize, + }, + Neg { + idx: usize, + degree_multiple: usize, + }, + Mul { + left_idx: usize, + right_idx: usize, + degree_multiple: usize, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(bound = "F: Field")] +pub struct SymbolicExpressionDag { + /// Nodes in **topological** order. + pub(crate) nodes: Vec>, + /// Node indices of expressions to assert equal zero. 
+ pub(crate) constraint_idx: Vec, +} + +pub(crate) fn build_symbolic_expr_dag( + exprs: &[SymbolicExpression], +) -> SymbolicExpressionDag { + let mut expr_to_idx = FxHashMap::default(); + let mut nodes = Vec::new(); + let constraint_idx = exprs + .iter() + .map(|expr| topological_sort_symbolic_expr(expr, &mut expr_to_idx, &mut nodes)) + .collect(); + SymbolicExpressionDag { + nodes, + constraint_idx, + } +} + +/// `expr_to_idx` is a cache so that the `Arc<_>` references within symbolic expressions get +/// mapped to the same node ID if their underlying references are the same. +fn topological_sort_symbolic_expr<'a, F: Field>( + expr: &'a SymbolicExpression, + expr_to_idx: &mut FxHashMap<&'a SymbolicExpression, usize>, + nodes: &mut Vec>, +) -> usize { + if let Some(&idx) = expr_to_idx.get(expr) { + return idx; + } + let node = match expr { + SymbolicExpression::Variable(var) => SymbolicExpressionNode::Variable(*var), + SymbolicExpression::IsFirstRow => SymbolicExpressionNode::IsFirstRow, + SymbolicExpression::IsLastRow => SymbolicExpressionNode::IsLastRow, + SymbolicExpression::IsTransition => SymbolicExpressionNode::IsTransition, + SymbolicExpression::Constant(cons) => SymbolicExpressionNode::Constant(*cons), + SymbolicExpression::Add { + x, + y, + degree_multiple, + } => { + let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes); + let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes); + SymbolicExpressionNode::Add { + left_idx, + right_idx, + degree_multiple: *degree_multiple, + } + } + SymbolicExpression::Sub { + x, + y, + degree_multiple, + } => { + let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes); + let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes); + SymbolicExpressionNode::Sub { + left_idx, + right_idx, + degree_multiple: *degree_multiple, + } + } + SymbolicExpression::Neg { x, degree_multiple } => { + let idx = 
topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes); + SymbolicExpressionNode::Neg { + idx, + degree_multiple: *degree_multiple, + } + } + SymbolicExpression::Mul { + x, + y, + degree_multiple, + } => { + // An important case to remember: square will have Arc::as_ptr(&x) == Arc::as_ptr(&y) + // The `expr_to_idx` will ensure only one topological sort is done to prevent exponential + // behavior. + let left_idx = topological_sort_symbolic_expr(x.as_ref(), expr_to_idx, nodes); + let right_idx = topological_sort_symbolic_expr(y.as_ref(), expr_to_idx, nodes); + SymbolicExpressionNode::Mul { + left_idx, + right_idx, + degree_multiple: *degree_multiple, + } + } + }; + + let idx = nodes.len(); + nodes.push(node); + expr_to_idx.insert(expr, idx); + idx +} + +impl SymbolicExpressionDag { + /// Returns symbolic expressions for each constraint + pub fn to_symbolic_expressions(&self) -> Vec> { + let mut exprs: Vec>> = Vec::with_capacity(self.nodes.len()); + for node in &self.nodes { + let expr = match *node { + SymbolicExpressionNode::Variable(var) => SymbolicExpression::Variable(var), + SymbolicExpressionNode::IsFirstRow => SymbolicExpression::IsFirstRow, + SymbolicExpressionNode::IsLastRow => SymbolicExpression::IsLastRow, + SymbolicExpressionNode::IsTransition => SymbolicExpression::IsTransition, + SymbolicExpressionNode::Constant(f) => SymbolicExpression::Constant(f), + SymbolicExpressionNode::Add { + left_idx, + right_idx, + degree_multiple, + } => SymbolicExpression::Add { + x: exprs[left_idx].clone(), + y: exprs[right_idx].clone(), + degree_multiple, + }, + SymbolicExpressionNode::Sub { + left_idx, + right_idx, + degree_multiple, + } => SymbolicExpression::Sub { + x: exprs[left_idx].clone(), + y: exprs[right_idx].clone(), + degree_multiple, + }, + SymbolicExpressionNode::Neg { + idx, + degree_multiple, + } => SymbolicExpression::Neg { + x: exprs[idx].clone(), + degree_multiple, + }, + SymbolicExpressionNode::Mul { + left_idx, + right_idx, + degree_multiple, + } 
=> SymbolicExpression::Mul { + x: exprs[left_idx].clone(), + y: exprs[right_idx].clone(), + degree_multiple, + }, + }; + exprs.push(Arc::new(expr)); + } + self.constraint_idx + .iter() + .map(|&idx| exprs[idx].as_ref().clone()) + .collect() + } +} + +#[cfg(test)] +mod tests { + use p3_baby_bear::BabyBear; + use p3_field::AbstractField; + + use crate::air_builders::symbolic::{ + dag::{build_symbolic_expr_dag, SymbolicExpressionDag, SymbolicExpressionNode}, + symbolic_expression::SymbolicExpression, + symbolic_variable::{Entry, SymbolicVariable}, + SymbolicConstraints, + }; + + type F = BabyBear; + + #[test] + fn test_symbolic_expressions_dag() { + let expr = SymbolicExpression::Constant(F::ONE) + * SymbolicVariable::new( + Entry::Main { + part_index: 1, + offset: 2, + }, + 3, + ); + let exprs = vec![ + SymbolicExpression::IsFirstRow * SymbolicExpression::IsLastRow + + SymbolicExpression::Constant(F::ONE) + + SymbolicExpression::IsFirstRow * SymbolicExpression::IsLastRow + + expr.clone(), + expr.clone() * expr.clone(), + ]; + let expr_list = build_symbolic_expr_dag(&exprs); + assert_eq!( + expr_list, + SymbolicExpressionDag:: { + nodes: vec![ + SymbolicExpressionNode::IsFirstRow, + SymbolicExpressionNode::IsLastRow, + SymbolicExpressionNode::Mul { + left_idx: 0, + right_idx: 1, + degree_multiple: 2 + }, + SymbolicExpressionNode::Constant(F::ONE), + SymbolicExpressionNode::Add { + left_idx: 2, + right_idx: 3, + degree_multiple: 2 + }, + // Currently topological sort does not detect all subgraph isomorphisms. For example each IsFirstRow and IsLastRow is a new reference so ptr::hash is distinct. 
+ SymbolicExpressionNode::Mul { + left_idx: 0, + right_idx: 1, + degree_multiple: 2 + }, + SymbolicExpressionNode::Add { + left_idx: 4, + right_idx: 5, + degree_multiple: 2 + }, + SymbolicExpressionNode::Variable(SymbolicVariable::new( + Entry::Main { + part_index: 1, + offset: 2 + }, + 3 + )), + SymbolicExpressionNode::Mul { + left_idx: 3, + right_idx: 7, + degree_multiple: 1 + }, + SymbolicExpressionNode::Add { + left_idx: 6, + right_idx: 8, + degree_multiple: 2 + }, + SymbolicExpressionNode::Mul { + left_idx: 8, + right_idx: 8, + degree_multiple: 2 + } + ], + constraint_idx: vec![9, 10], + } + ); + let sc = SymbolicConstraints { + constraints: exprs, + interactions: vec![], + }; + let ser_str = serde_json::to_string(&sc).unwrap(); + let new_sc: SymbolicConstraints<_> = serde_json::from_str(&ser_str).unwrap(); + assert_eq!(sc.constraints, new_sc.constraints); + } +} diff --git a/crates/stark-backend/src/air_builders/symbolic/mod.rs b/crates/stark-backend/src/air_builders/symbolic/mod.rs index 3dc01bf46f..7b1fd4fe3f 100644 --- a/crates/stark-backend/src/air_builders/symbolic/mod.rs +++ b/crates/stark-backend/src/air_builders/symbolic/mod.rs @@ -7,7 +7,7 @@ use p3_air::{ use p3_field::Field; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use p3_util::log2_ceil_usize; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tracing::instrument; use self::{ @@ -16,6 +16,7 @@ use self::{ }; use super::PartitionedAirBuilder; use crate::{ + air_builders::symbolic::dag::{build_symbolic_expr_dag, SymbolicExpressionDag}, interaction::{ rap::InteractionPhaseAirBuilder, Interaction, InteractionBuilder, InteractionType, RapPhaseSeqKind, SymbolicInteraction, @@ -24,6 +25,7 @@ use crate::{ rap::{BaseAirWithPublicValues, PermutationAirBuilderWithExposedValues, Rap}, }; +pub mod dag; pub mod symbolic_expression; pub mod symbolic_variable; @@ -465,3 +467,27 @@ fn gen_main_trace( .collect_vec(); RowMajorMatrix::new(mat_values, 
width) } + +#[allow(dead_code)] +fn serialize_symbolic_exprs( + data: &[SymbolicExpression], + serializer: S, +) -> Result +where + S: Serializer, +{ + // Convert the expressions into a DAG of node indices before serializing + let dag = build_symbolic_expr_dag(data); + dag.serialize(serializer) +} + +#[allow(dead_code)] +fn deserialize_symbolic_exprs<'de, F: Field, D>( + deserializer: D, +) -> Result>, D::Error> +where + D: Deserializer<'de>, +{ + let dag = SymbolicExpressionDag::deserialize(deserializer)?; + Ok(dag.to_symbolic_expressions()) +} diff --git a/crates/stark-backend/src/air_builders/symbolic/symbolic_expression.rs b/crates/stark-backend/src/air_builders/symbolic/symbolic_expression.rs index c7b09259d8..95a1090b40 100644 --- a/crates/stark-backend/src/air_builders/symbolic/symbolic_expression.rs +++ b/crates/stark-backend/src/air_builders/symbolic/symbolic_expression.rs @@ -5,16 +5,20 @@ use core::{ iter::{Product, Sum}, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; -use std::{hash::Hash, sync::Arc}; +use std::{ + hash::{Hash, Hasher}, + ptr, + sync::Arc, +}; use p3_field::{AbstractField, Field}; -use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use super::symbolic_variable::SymbolicVariable; /// An expression over `SymbolicVariable`s. 
-#[derive(Clone, Debug, Hash, Eq, PartialEq, Serialize, Deserialize)] +// Note: avoid deriving Hash because it will hash the entire sub-tree +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] #[serde(bound = "F: Field")] pub enum SymbolicExpression { Variable(SymbolicVariable), @@ -43,6 +47,36 @@ pub enum SymbolicExpression { }, } +impl Hash for SymbolicExpression { + fn hash(&self, state: &mut H) { + // First hash the discriminant of the enum + std::mem::discriminant(self).hash(state); + // Degree multiple is not necessary + match self { + Self::Variable(v) => v.hash(state), + Self::IsFirstRow => {} // discriminant is enough + Self::IsLastRow => {} // discriminant is enough + Self::IsTransition => {} // discriminant is enough + Self::Constant(f) => f.hash(state), + Self::Add { x, y, .. } => { + ptr::hash(&**x, state); + ptr::hash(&**y, state); + } + Self::Sub { x, y, .. } => { + ptr::hash(&**x, state); + ptr::hash(&**y, state); + } + Self::Neg { x, .. } => { + ptr::hash(&**x, state); + } + Self::Mul { x, y, .. } => { + ptr::hash(&**x, state); + ptr::hash(&**y, state); + } + } + } +} + impl SymbolicExpression { /// Returns the multiple of `n` (the trace length) in this expression's degree. 
pub const fn degree_multiple(&self) -> usize { @@ -307,41 +341,18 @@ impl Product for SymbolicExpression { } } -pub trait SymbolicEvaluator> -where - SymbolicVariable: Hash + PartialEq + Eq, -{ +pub trait SymbolicEvaluator> { fn eval_var(&self, symbolic_var: SymbolicVariable) -> E; - #[allow(clippy::needless_option_as_deref)] - fn eval_expr( - &self, - symbolic_expr: &SymbolicExpression, - mut cache: Option<&mut FxHashMap, E>>, - ) -> E { - if let Some(ref mut cache) = cache { - if let Some(e) = cache.get(symbolic_expr) { - return e.clone(); - } - } - let e = match symbolic_expr { + fn eval_expr(&self, symbolic_expr: &SymbolicExpression) -> E { + match symbolic_expr { SymbolicExpression::Variable(var) => self.eval_var(*var), SymbolicExpression::Constant(c) => (*c).into(), - SymbolicExpression::Add { x, y, .. } => { - self.eval_expr(x, cache.as_deref_mut()) + self.eval_expr(y, cache.as_deref_mut()) - } - SymbolicExpression::Sub { x, y, .. } => { - self.eval_expr(x, cache.as_deref_mut()) - self.eval_expr(y, cache.as_deref_mut()) - } - SymbolicExpression::Neg { x, .. } => -self.eval_expr(x, cache.as_deref_mut()), - SymbolicExpression::Mul { x, y, .. } => { - self.eval_expr(x, cache.as_deref_mut()) * self.eval_expr(y, cache.as_deref_mut()) - } + SymbolicExpression::Add { x, y, .. } => self.eval_expr(x) + self.eval_expr(y), + SymbolicExpression::Sub { x, y, .. } => self.eval_expr(x) - self.eval_expr(y), + SymbolicExpression::Neg { x, .. } => -self.eval_expr(x), + SymbolicExpression::Mul { x, y, .. 
} => self.eval_expr(x) * self.eval_expr(y), _ => unreachable!("Expression cannot be evaluated"), - }; - if let Some(ref mut cache) = cache { - cache.insert(symbolic_expr.clone(), e.clone()); } - e } } diff --git a/crates/stark-backend/src/air_builders/symbolic/symbolic_variable.rs b/crates/stark-backend/src/air_builders/symbolic/symbolic_variable.rs index 2463a922f4..a64c8adc4e 100644 --- a/crates/stark-backend/src/air_builders/symbolic/symbolic_variable.rs +++ b/crates/stark-backend/src/air_builders/symbolic/symbolic_variable.rs @@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize}; use super::symbolic_expression::SymbolicExpression; #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[repr(C)] pub enum Entry { Preprocessed { offset: usize, diff --git a/crates/stark-backend/src/air_builders/verifier.rs b/crates/stark-backend/src/air_builders/verifier.rs index e5b299d52f..469d5852eb 100644 --- a/crates/stark-backend/src/air_builders/verifier.rs +++ b/crates/stark-backend/src/air_builders/verifier.rs @@ -5,10 +5,10 @@ use std::{ use p3_field::{AbstractField, ExtensionField, Field}; use p3_matrix::Matrix; -use rustc_hash::FxHashMap; use super::{ symbolic::{ + dag::{build_symbolic_expr_dag, SymbolicExpressionNode}, symbolic_expression::{SymbolicEvaluator, SymbolicExpression}, symbolic_variable::{Entry, SymbolicVariable}, }, @@ -52,10 +52,39 @@ where PubVar: Into + Copy + Send + Sync, { pub fn eval_constraints(&mut self, constraints: &[SymbolicExpression]) { - let mut cache = FxHashMap::default(); - for constraint in constraints { - let x = self.eval_expr(constraint, Some(&mut cache)); - self.assert_zero(x); + let dag = build_symbolic_expr_dag(constraints); + // node_idx -> evaluation + // We do a simple serial evaluation in topological order. + // This can be parallelized if necessary. 
+ let mut exprs: Vec = Vec::with_capacity(dag.nodes.len()); + for node in &dag.nodes { + let expr = match *node { + SymbolicExpressionNode::Variable(var) => self.eval_var(var), + SymbolicExpressionNode::Constant(f) => Expr::from(f), + SymbolicExpressionNode::Add { + left_idx, + right_idx, + .. + } => exprs[left_idx].clone() + exprs[right_idx].clone(), + SymbolicExpressionNode::Sub { + left_idx, + right_idx, + .. + } => exprs[left_idx].clone() - exprs[right_idx].clone(), + SymbolicExpressionNode::Neg { idx, .. } => -exprs[idx].clone(), + SymbolicExpressionNode::Mul { + left_idx, + right_idx, + .. + } => exprs[left_idx].clone() * exprs[right_idx].clone(), + SymbolicExpressionNode::IsFirstRow => self.is_first_row.into(), + SymbolicExpressionNode::IsLastRow => self.is_last_row.into(), + SymbolicExpressionNode::IsTransition => self.is_transition.into(), + }; + exprs.push(expr); + } + for idx in dag.constraint_idx { + self.assert_zero(exprs[idx].clone()); } } @@ -101,37 +130,6 @@ where .into(), } } - #[allow(clippy::needless_option_as_deref)] - fn eval_expr( - &self, - symbolic_expr: &SymbolicExpression, - mut cache: Option<&mut FxHashMap, Expr>>, - ) -> Expr { - if let Some(ref mut cache) = cache { - if let Some(e) = cache.get(symbolic_expr) { - return e.clone(); - } - } - let e = match symbolic_expr { - SymbolicExpression::Variable(var) => self.eval_var(*var), - SymbolicExpression::Constant(c) => (*c).into(), - SymbolicExpression::Add { x, y, .. } => { - self.eval_expr(x, cache.as_deref_mut()) + self.eval_expr(y, cache.as_deref_mut()) - } - SymbolicExpression::Sub { x, y, .. } => { - self.eval_expr(x, cache.as_deref_mut()) - self.eval_expr(y, cache.as_deref_mut()) - } - SymbolicExpression::Neg { x, .. } => -self.eval_expr(x, cache.as_deref_mut()), - SymbolicExpression::Mul { x, y, .. 
} => { - self.eval_expr(x, cache.as_deref_mut()) * self.eval_expr(y, cache.as_deref_mut()) - } - SymbolicExpression::IsFirstRow => self.is_first_row.into(), - SymbolicExpression::IsLastRow => self.is_last_row.into(), - SymbolicExpression::IsTransition => self.is_transition.into(), - }; - if let Some(ref mut cache) = cache { - cache.insert(symbolic_expr.clone(), e.clone()); - } - e - } + // NOTE: do not use the eval_expr function as it can have exponential complexity! + // Instead use the `SymbolicExpressionDag` } diff --git a/crates/stark-backend/src/interaction/debug.rs b/crates/stark-backend/src/interaction/debug.rs index b42dd75526..8bdf27fca1 100644 --- a/crates/stark-backend/src/interaction/debug.rs +++ b/crates/stark-backend/src/interaction/debug.rs @@ -42,9 +42,9 @@ pub fn generate_logical_interactions( let fields = interaction .fields .iter() - .map(|expr| evaluator.eval_expr(expr, None)) + .map(|expr| evaluator.eval_expr(expr)) .collect_vec(); - let count = evaluator.eval_expr(&interaction.count, None); + let count = evaluator.eval_expr(&interaction.count); if count.is_zero() { continue; } diff --git a/crates/stark-backend/src/interaction/stark_log_up.rs b/crates/stark-backend/src/interaction/stark_log_up.rs index d044e4e71e..e841aa16b7 100644 --- a/crates/stark-backend/src/interaction/stark_log_up.rs +++ b/crates/stark-backend/src/interaction/stark_log_up.rs @@ -342,12 +342,10 @@ where debug_assert!(interaction.fields.len() <= betas.len()); let mut fields = interaction.fields.iter(); *denom = alpha - + evaluator.eval_expr( - fields.next().expect("fields should not be empty"), - None, - ); + + evaluator + .eval_expr(fields.next().expect("fields should not be empty")); for (expr, &beta) in fields.zip(betas.iter().skip(1)) { - *denom += beta * evaluator.eval_expr(expr, None); + *denom += beta * evaluator.eval_expr(expr); } } } @@ -384,7 +382,7 @@ where izip!(reciprocal_chunk, interaction_chunk) { let mut interaction_val = - *reciprocal * 
evaluator.eval_expr(&interaction.count, None); + *reciprocal * evaluator.eval_expr(&interaction.count); if interaction.interaction_type == InteractionType::Receive { interaction_val = -interaction_val; }