Skip to content

Commit

Permalink
feat: generic circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
brech1 committed May 5, 2024
1 parent 73441ff commit 98cd0b9
Show file tree
Hide file tree
Showing 5 changed files with 369 additions and 0 deletions.
2 changes: 2 additions & 0 deletions crates/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ members = [
"mpz-common",
"mpz-fields",
"mpz-circuits",
"mpz-circuits-generic",
"mpz-circuits-macros",
"mpz-cointoss",
"mpz-cointoss-core",
Expand All @@ -29,6 +30,7 @@ mpz-core = { path = "mpz-core" }
mpz-common = { path = "mpz-common" }
mpz-fields = { path = "mpz-fields" }
mpz-circuits = { path = "mpz-circuits" }
mpz-circuits-generic = { path = "mpz-circuits-generic" }
mpz-circuits-macros = { path = "mpz-circuits-macros" }
mpz-cointoss = { path = "mpz-cointoss" }
mpz-cointoss-core = { path = "mpz-cointoss-core" }
Expand Down
13 changes: 13 additions & 0 deletions crates/mpz-circuits-generic/cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "mpz-circuit-generic"
version = "0.1.0"
edition = "2021"

[lib]
name = "mpz_circuit_generic"

[lints]
workspace = true

[dependencies]
thiserror = "1.0.59"
219 changes: 219 additions & 0 deletions crates/mpz-circuits-generic/src/binary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
//! Binary Module
//!
//! Test module to display an example of a binary circuit representation.
use crate::circuit::{CircuitError, Evaluate, RepresentedValue};

#[derive(Debug, Copy, Clone, PartialEq)]
/// Binary gate value.
pub enum BinaryValue {
/// Binary zero.
Zero,
/// Binary one.
One,
}

// Each gate can be performing the same operation multiple times,
// One for each bit of the represented value.
pub type BinaryGateValue = Vec<BinaryValue>;

/// Binary gates.
pub enum BinaryOperation {
/// AND Operation.
AND,
/// NOT Operation.
NOT,
/// XOR Operation.
XOR,
}

impl BinaryOperation {
pub fn input_count(&self) -> usize {
match self {
Self::AND | Self::XOR => 2,
Self::NOT => 1,
}
}

pub fn evaluate(&self, inputs: &[&BinaryGateValue]) -> Result<BinaryGateValue, CircuitError> {
match self {
Self::AND => Ok(inputs[0]
.iter()
.zip(inputs[1])
.map(|(&a, &b)| {
if a == BinaryValue::One && b == BinaryValue::One {
BinaryValue::One
} else {
BinaryValue::Zero
}
})
.collect()),
Self::NOT => Ok(inputs[0]
.iter()
.map(|&x| {
if x == BinaryValue::Zero {
BinaryValue::One
} else {
BinaryValue::Zero
}
})
.collect()),
Self::XOR => Ok(inputs[0]
.iter()
.zip(inputs[1])
.map(|(&a, &b)| {
if a != b {
BinaryValue::One
} else {
BinaryValue::Zero
}
})
.collect()),
}
}
}

/// Binary circuit representation value.
/// Used as interface for the circuit
#[derive(Debug, PartialEq)]
pub enum BinaryCircuitReprValue {
/// Bool value,
Bool(bool),
/// u8 value.
U8(u8),
}

// Implement RepresentedValue for BinaryCircuitReprValue.
impl RepresentedValue<BinaryGateValue> for BinaryCircuitReprValue {
fn from_value(value: &BinaryGateValue) -> Result<Self, CircuitError> {
match value.len() {
1 => {
let bit = value[0];
Ok(BinaryCircuitReprValue::Bool(bit == BinaryValue::One))
}
8 => {
let byte = value.iter().fold(0, |acc, &bit| {
(acc << 1) | (if bit == BinaryValue::One { 1 } else { 0 })
});
Ok(BinaryCircuitReprValue::U8(byte as u8))
}
_ => Err(CircuitError::ConversionError),
}
}

fn to_value(&self) -> Result<BinaryGateValue, CircuitError> {
match *self {
BinaryCircuitReprValue::Bool(b) => Ok(vec![if b {
BinaryValue::One
} else {
BinaryValue::Zero
}]),
BinaryCircuitReprValue::U8(byte) => {
let bits = (0..8)
.rev()
.map(|i| {
if byte & (1 << i) != 0 {
BinaryValue::One
} else {
BinaryValue::Zero
}
})
.collect();
Ok(bits)
}
}
}
}

/// Binary gate.
pub struct BinaryGate {
/// Gate inputs. Each input is a usize that represents the index of the input gate.
inputs: Vec<usize>,
/// Gate output. A usize that represents the index of the output gate.
output: usize,
/// Gate operation.
op: BinaryOperation,
}

impl Evaluate<BinaryGateValue> for BinaryGate {
fn evaluate(&self, feeds: &mut Vec<Option<BinaryGateValue>>) -> Result<(), CircuitError> {
let input_values: Vec<_> = self
.inputs
.iter()
.map(|&idx| {
feeds
.get(idx)
.and_then(|v| v.as_ref())
.ok_or(CircuitError::MissingNodeValue(idx))
})
.collect::<Result<_, _>>()?;

if input_values.len() != self.op.input_count() {
return Err(CircuitError::InvalidGateInputCount(
self.op.input_count(),
input_values.len(),
));
}

let result = self.op.evaluate(&input_values)?;

// Resize the feeds vector if the output index is out of bounds
// This is the only reason that evaluate receives a vec instead of a slice.
if feeds.get_mut(self.output).is_none() {
feeds.resize(self.output + 1, None);
}

if let Some(output) = feeds.get_mut(self.output) {
*output = Some(result);
} else {
return Err(CircuitError::OutputIndexOutOfRange(self.output));
}

Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::circuit::Circuit;

#[test]
fn test_circuit() {
// Initialize the circuit
let mut circuit = Circuit::<BinaryCircuitReprValue, BinaryGate, BinaryGateValue>::new();

// Add gates
let gate = BinaryGate {
inputs: vec![0, 1],
output: 2,
op: BinaryOperation::AND,
};
circuit.add_gate(gate);

// Prepare inputs
let input_a: u8 = 0b10101010;
let input_b: u8 = 0b00001111;

let repr_input_a = BinaryCircuitReprValue::U8(input_a);
let repr_input_b = BinaryCircuitReprValue::U8(input_b);

// Add inputs to the circuit
circuit.add_input(repr_input_a);
circuit.add_input(repr_input_b);

// Define output index
circuit.add_output(2);

// Expected output
let expected_output: u8 = 0b00001010;
let repr_expected_output = BinaryCircuitReprValue::U8(expected_output);

// Run the circuit and verify the outputs
let output_values = circuit.run().unwrap();

// Check if the number of outputs and their values are as expected
assert_eq!(output_values.len(), 1);
assert_eq!(output_values[0], repr_expected_output);
}
}
133 changes: 133 additions & 0 deletions crates/mpz-circuits-generic/src/circuit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use std::marker::PhantomData;
use thiserror::Error;

#[derive(Default)]
/// Represents the circuit interface, generic over represented values.
pub struct CircuitInterface<T> {
/// Circuit inputs.
inputs: Vec<T>,
/// Circuit outputs indices.
outputs: Vec<usize>,
}

impl<T> CircuitInterface<T> {
/// Creates a new circuit interface.
pub fn new() -> Self {
CircuitInterface {
inputs: Vec::new(),
outputs: Vec::new(),
}
}
}

/// Generic circuit implementation.
///
/// T: Circuit interface type. This is the type the input and output values use.
/// U: Generic gate type. Must implement the Evaluate<V> trait.
/// V: Gates value type. This is the type the gates perform operations on.
pub struct Circuit<T, U, V>
where
T: RepresentedValue<V>,
U: Evaluate<V>,
{
interface: CircuitInterface<T>,
gates: Vec<U>,
_phantom: PhantomData<V>,
}

impl<T, U, V> Circuit<T, U, V>
where
T: RepresentedValue<V>,
U: Evaluate<V>,
{
/// Creates a new circuit.
pub fn new() -> Self {
Circuit {
interface: CircuitInterface::<T>::new(),
gates: Vec::new(),
_phantom: PhantomData,
}
}

/// Adds a gate to the circuit.
pub fn add_gate(&mut self, gate: U) {
self.gates.push(gate);
}

/// Adds an input to the circuit.
pub fn add_input(&mut self, input: T) {
self.interface.inputs.push(input);
}

/// Adds an output to the circuit.
pub fn add_output(&mut self, output: usize) {
self.interface.outputs.push(output);
}

/// Runs the circuit and returns the output values.
pub fn run(&mut self) -> Result<Vec<T>, CircuitError> {
// Initialize node values.
let mut nodes: Vec<Option<V>> = Vec::new();
for input in &self.interface.inputs {
nodes.push(Some(input.to_value()?));
}

// Evaluate gates
for gate in &self.gates {
gate.evaluate(&mut nodes)?;
}

// Collect and convert the outputs
let mut outputs: Vec<T> = Vec::new();
for &output_index in &self.interface.outputs {
if let Some(node_value) = nodes.get(output_index).and_then(|v| v.as_ref()) {
match T::from_value(node_value) {
Ok(repr) => outputs.push(repr),
Err(e) => return Err(e),
}
} else {
return Err(CircuitError::MissingNodeValue(output_index));
}
}

Ok(outputs)
}
}

/// A trait that circuit gates should implement to perform an evaluation.
pub trait Evaluate<T> {
/// Performs an evaluation. Receives a mutable slice of optional values that represent the circuit nodes.
fn evaluate(&self, nodes: &mut Vec<Option<T>>) -> Result<(), CircuitError>;
}

/// Represented value trait.
///
/// This trait has to be implemented on the interface value type to allow its conversion to the gate value type.
pub trait RepresentedValue<T> {
/// Converts a gate value back to the represented interface value type.
fn from_value(value: &T) -> Result<Self, CircuitError>
where
Self: Sized;

/// Converts the interface value to the gate value.
fn to_value(&self) -> Result<T, CircuitError>;
}

/// Circuit errors.
#[derive(Debug, Error)]
pub enum CircuitError {
#[error("Invalid number of circuit inputs: expected {0}, got {1}")]
InvalidInputCount(usize, usize),
#[error("Invalid number of gate inputs: expected {0}, got {1}")]
InvalidGateInputCount(usize, usize),
#[error("Failed to convert external representation to internal gate value")]
ConversionError,
#[error("Output index out of range: {0}")]
OutputIndexOutOfRange(usize),
#[error("Missing node value at index {0}")]
MissingNodeValue(usize),
#[error("Gate evaluation failed: {0}")]
GateEvaluationError(String),
#[error("Generic circuit error: {0}")]
GenericCircuitError(String),
}
2 changes: 2 additions & 0 deletions crates/mpz-circuits-generic/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod binary;
pub mod circuit;

0 comments on commit 98cd0b9

Please sign in to comment.