From 05611d6ae735607a7c35b88ad1044340511da798 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Mon, 6 Mar 2023 00:29:07 +0000 Subject: [PATCH] refactor: fix-verifier-sol in rust (#152) --- Cargo.lock | 1 + Cargo.toml | 1 + fix_verifier_sol.py | 198 ------------------------------ src/eth.rs | 288 ++++++++++++++++++++++++++++++++++++++++++++ src/execute.rs | 16 +-- 5 files changed, 295 insertions(+), 209 deletions(-) delete mode 100644 fix_verifier_sol.py diff --git a/Cargo.lock b/Cargo.lock index cc36ae4a5..66b0a0375 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1761,6 +1761,7 @@ dependencies = [ "mnist", "plotters", "rand", + "regex", "reqwest", "seq-macro", "serde", diff --git a/Cargo.toml b/Cargo.toml index aa468a4a8..9c6dca53c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ snark-verifier = { git = "https://github.com/privacy-scaling-explorations/snark- ethers = "1.0.2" ethers-solc = "1.0.2" tokio = { version = "1.22.0", features = ["macros"] } +regex = "1" [dev-dependencies] criterion = {version = "0.3", features = ["html_reports"]} diff --git a/fix_verifier_sol.py b/fix_verifier_sol.py deleted file mode 100644 index 784b072d9..000000000 --- a/fix_verifier_sol.py +++ /dev/null @@ -1,198 +0,0 @@ -#!/usr/bin/env python3 - -import sys -import re - -if __name__ == "__main__": - if len(sys.argv) < 2: - print("Usage: fix_verifier_sol.py ") - sys.exit(1) - - input_file = sys.argv[1] - lines = open(input_file).readlines() - - transcript_addrs = list() - modified_lines = list() - - num_pubinputs = 0 - - # convert calldataload 0x0 to 0x40 to read from pubInputs, and the rest - # from proof - calldata_pattern = r"^.*(calldataload\((0x[a-f0-9]+)\)).*$" - mstore_pattern = r"^\s*(mstore\(0x([0-9a-fA-F]+)+),.+\)" - mstore8_pattern = r"^\s*(mstore8\((\d+)+),.+\)" - mstoren_pattern = r"^\s*(mstore\((\d+)+),.+\)" - mload_pattern = r"(mload\((0x[0-9a-fA-F]+))\)" - keccak_pattern = r"(keccak256\((0x[0-9a-fA-F]+))" - modexp_pattern = r"(staticcall\(gas\(\), 0x5, (0x[0-9a-fA-F]+), 0xc0, (0x[0-9a-fA-F]+), 0x20)" - ecmul_pattern = r"(staticcall\(gas\(\), 0x7, (0x[0-9a-fA-F]+), 0x60, (0x[0-9a-fA-F]+), 0x40)" - ecadd_pattern = r"(staticcall\(gas\(\), 0x6, (0x[0-9a-fA-F]+), 0x80, (0x[0-9a-fA-F]+), 0x40)" - ecpairing_pattern = r"(staticcall\(gas\(\), 0x8, (0x[0-9a-fA-F]+), 0x180, (0x[0-9a-fA-F]+), 0x20)" - bool_pattern = r":bool" - - # Count the number of pub inputs - start = None - end = None - i = 0 - for line in lines: - if line.strip().startswith("mstore(0x20"): - start = i - - if line.strip().startswith("mstore(0x0"): - end = i - break - i += 1 - - if start is None: - num_pubinputs = 0 - else: - num_pubinputs = end - start - - max_pubinputs_addr = 0 - if num_pubinputs > 0: - max_pubinputs_addr = num_pubinputs * 32 - 32 - - for line in lines: - m = re.search(bool_pattern, line) - if m: - line = line.replace(":bool", "") - - m = re.search(calldata_pattern, line) - if m: - calldata_and_addr = m.group(1) - addr = m.group(2) - addr_as_num = int(addr, 16) - - if addr_as_num <= max_pubinputs_addr: - pub_addr = hex(addr_as_num + 32) - line = line.replace(calldata_and_addr, "mload(add(pubInputs, " + pub_addr + "))") - else: - proof_addr = hex(addr_as_num - max_pubinputs_addr) - line = line.replace(calldata_and_addr, "mload(add(proof, " + proof_addr + "))") - - m = re.search(mstore8_pattern, line) - if m: - mstore = m.group(1) - addr = m.group(2) - addr_as_num = int(addr) - transcript_addr = hex(addr_as_num) - transcript_addrs.append(addr_as_num) - line = line.replace(mstore, "mstore8(add(transcript, " + transcript_addr + ")") - - m = re.search(mstoren_pattern, line) - if m: - mstore = m.group(1) - addr = m.group(2) - addr_as_num = int(addr) - transcript_addr = hex(addr_as_num) - transcript_addrs.append(addr_as_num) - line = line.replace(mstore, "mstore(add(transcript, " + transcript_addr + ")") - - m = re.search(modexp_pattern, line) - if m: - modexp = m.group(1) - start_addr = m.group(2) - result_addr = m.group(3) - start_addr_as_num = int(start_addr, 16) - result_addr_as_num = int(result_addr, 16) - - transcript_addr = hex(start_addr_as_num) - transcript_addrs.append(addr_as_num) - result_addr = hex(result_addr_as_num) - line = line.replace(modexp, "staticcall(gas(), 0x5, add(transcript, " + transcript_addr + "), 0xc0, add(transcript, " + result_addr + "), 0x20") - - m = re.search(ecmul_pattern, line) - if m: - ecmul = m.group(1) - start_addr = m.group(2) - result_addr = m.group(3) - start_addr_as_num = int(start_addr, 16) - result_addr_as_num = int(result_addr, 16) - - transcript_addr = hex(start_addr_as_num) - result_addr = hex(result_addr_as_num) - transcript_addrs.append(start_addr_as_num) - transcript_addrs.append(result_addr_as_num) - line = line.replace(ecmul, "staticcall(gas(), 0x7, add(transcript, " + transcript_addr + "), 0x60, add(transcript, " + result_addr + "), 0x40") - - m = re.search(ecadd_pattern, line) - if m: - ecadd = m.group(1) - start_addr = m.group(2) - result_addr = m.group(3) - start_addr_as_num = int(start_addr, 16) - result_addr_as_num = int(result_addr, 16) - - transcript_addr = hex(start_addr_as_num) - result_addr = hex(result_addr_as_num) - transcript_addrs.append(start_addr_as_num) - transcript_addrs.append(result_addr_as_num) - line = line.replace(ecadd, "staticcall(gas(), 0x6, add(transcript, " + transcript_addr + "), 0x80, add(transcript, " + result_addr + "), 0x40") - - m = re.search(ecpairing_pattern, line) - if m: - ecpairing = m.group(1) - start_addr = m.group(2) - result_addr = m.group(3) - start_addr_as_num = int(start_addr, 16) - result_addr_as_num = int(result_addr, 16) - - transcript_addr = hex(start_addr_as_num) - result_addr = hex(result_addr_as_num) - transcript_addrs.append(start_addr_as_num) - transcript_addrs.append(result_addr_as_num) - line = line.replace(ecpairing, "staticcall(gas(), 0x8, add(transcript, " + transcript_addr + "), 0x180, add(transcript, " + result_addr + "), 0x20") - - m = re.search(mstore_pattern, line) - if m: - mstore = m.group(1) - addr = m.group(2) - addr_as_num = int(addr, 16) - transcript_addr = hex(addr_as_num) - transcript_addrs.append(addr_as_num) - line = line.replace(mstore, "mstore(add(transcript, " + transcript_addr + ")") - - m = re.search(keccak_pattern, line) - if m: - keccak = m.group(1) - addr = m.group(2) - addr_as_num = int(addr, 16) - transcript_addr = hex(addr_as_num) - transcript_addrs.append(addr_as_num) - line = line.replace(keccak, "keccak256(add(transcript, " + transcript_addr + ")") - - # mload can show up multiple times per line - while True: - m = re.search(mload_pattern, line) - if not m: - break - mload = m.group(1) - addr = m.group(2) - addr_as_num = int(addr, 16) - transcript_addr = hex(addr_as_num) - transcript_addrs.append(addr_as_num) - line = line.replace(mload, "mload(add(transcript, " + transcript_addr + ")") - - # print(line, end="") - modified_lines.append(line) - - # get the max transcript addr - max_transcript_addr = int(max(transcript_addrs) / 32) - print("""// SPDX-License-Identifier: MIT -pragma solidity ^0.8.17; - -contract Verifier {{ - function verify( - uint256[] memory pubInputs, - bytes memory proof - ) public view returns (bool) {{ - bool success = true; - bytes32[{}] memory transcript; - assembly {{ - """.strip().format(max_transcript_addr)) - for line in modified_lines[16:-7]: - print(line, end="") - print("""} - return success; - } -}""") diff --git a/src/eth.rs b/src/eth.rs index 43ee3519b..578ad5f27 100644 --- a/src/eth.rs +++ b/src/eth.rs @@ -28,6 +28,7 @@ use halo2curves::group::ff::PrimeField; use log::{debug, info}; use snark_verifier::loader::evm::encode_calldata; use std::error::Error; +use std::fmt::Write; use std::fs::read_to_string; use std::path::PathBuf; use std::{convert::TryFrom, sync::Arc, time::Duration}; @@ -297,3 +298,290 @@ pub async fn send_proof( Ok(()) } + +use regex::Regex; +use std::fs::File; +use std::io::{BufRead, BufReader}; + +/// Reads in raw bytes code and generates equivalent .sol file +pub fn fix_verifier_sol(input_file: PathBuf) -> Result> { + let file = File::open(input_file.clone())?; + let reader = BufReader::new(file); + + let mut transcript_addrs: Vec = Vec::new(); + let mut modified_lines: Vec = Vec::new(); + + // convert calldataload 0x0 to 0x40 to read from pubInputs, and the rest + // from proof + let calldata_pattern = Regex::new(r"^.*(calldataload\((0x[a-f0-9]+)\)).*$")?; + let mstore_pattern = Regex::new(r"^\s*(mstore\(0x([0-9a-fA-F]+)+),.+\)")?; + let mstore8_pattern = Regex::new(r"^\s*(mstore8\((\d+)+),.+\)")?; + let mstoren_pattern = Regex::new(r"^\s*(mstore\((\d+)+),.+\)")?; + let mload_pattern = Regex::new(r"(mload\((0x[0-9a-fA-F]+))\)")?; + let keccak_pattern = Regex::new(r"(keccak256\((0x[0-9a-fA-F]+))")?; + let modexp_pattern = + Regex::new(r"(staticcall\(gas\(\), 0x5, (0x[0-9a-fA-F]+), 0xc0, (0x[0-9a-fA-F]+), 0x20)")?; + let ecmul_pattern = + Regex::new(r"(staticcall\(gas\(\), 0x7, (0x[0-9a-fA-F]+), 0x60, (0x[0-9a-fA-F]+), 0x40)")?; + let ecadd_pattern = + Regex::new(r"(staticcall\(gas\(\), 0x6, (0x[0-9a-fA-F]+), 0x80, (0x[0-9a-fA-F]+), 0x40)")?; + let ecpairing_pattern = + Regex::new(r"(staticcall\(gas\(\), 0x8, (0x[0-9a-fA-F]+), 0x180, (0x[0-9a-fA-F]+), 0x20)")?; + let bool_pattern = Regex::new(r":bool")?; + + // Count the number of pub inputs + let mut start = None; + let mut end = None; + let mut i = 0; + for line in reader.lines() { + let line = line?; + if line.trim().starts_with("mstore(0x20") { + start = Some(i); + } + + if line.trim().starts_with("mstore(0x0") { + end = Some(i); + break; + } + i += 1; + } + + let num_pubinputs = if start.is_none() { + 0 + } else { + end.unwrap() - start.unwrap() + }; + + let mut max_pubinputs_addr = 0; + if num_pubinputs > 0 { + max_pubinputs_addr = num_pubinputs * 32 - 32; + } + + let file = File::open(input_file)?; + let reader = BufReader::new(file); + + for line in reader.lines() { + let mut line = line?; + let m = bool_pattern.captures(&line); + if m.is_some() { + line = line.replace(":bool", ""); + } + + let m = calldata_pattern.captures(&line); + if m.is_some() { + let calldata_and_addr = m.as_ref().unwrap().get(1).unwrap().as_str(); + let addr = m.unwrap().get(2).unwrap().as_str(); + let addr_as_num = u32::from_str_radix(addr.strip_prefix("0x").unwrap(), 16)?; + + if addr_as_num <= max_pubinputs_addr { + let pub_addr = format!("{:#x}", addr_as_num + 32); + line = line.replace( + calldata_and_addr, + &format!("mload(add(pubInputs, {}))", pub_addr), + ); + } else { + let proof_addr = format!("{:#x}", addr_as_num - max_pubinputs_addr); + line = line.replace( + calldata_and_addr, + &format!("mload(add(proof, {}))", proof_addr), + ); + } + } + + let m = mstore8_pattern.captures(&line); + if m.is_some() { + let mstore = m.as_ref().unwrap().get(1).unwrap().as_str(); + let addr = m.unwrap().get(2).unwrap().as_str(); + let addr_as_num = u32::from_str_radix(addr, 10)?; + let transcript_addr = format!("{:#x}", addr_as_num); + transcript_addrs.push(addr_as_num); + line = line.replace( + mstore, + &format!("mstore8(add(transcript, {})", transcript_addr), + ); + } + + let m = mstoren_pattern.captures(&line); + if m.is_some() { + let mstore = m.as_ref().unwrap().get(1).unwrap().as_str(); + let addr = m.unwrap().get(2).unwrap().as_str(); + let addr_as_num = u32::from_str_radix(addr, 10)?; + let transcript_addr = format!("{:#x}", addr_as_num); + transcript_addrs.push(addr_as_num); + line = line.replace( + mstore, + &format!("mstore(add(transcript, {})", transcript_addr), + ); + } + + let m = modexp_pattern.captures(&line); + if m.is_some() { + let modexp = m.as_ref().unwrap().get(1).unwrap().as_str(); + let start_addr = m.as_ref().unwrap().get(2).unwrap().as_str(); + let result_addr = m.unwrap().get(3).unwrap().as_str(); + let start_addr_as_num = + u32::from_str_radix(start_addr.strip_prefix("0x").unwrap(), 16)?; + let result_addr_as_num = + u32::from_str_radix(result_addr.strip_prefix("0x").unwrap(), 16)?; + + let transcript_addr = format!("{:#x}", start_addr_as_num); + transcript_addrs.push(start_addr_as_num); + let result_addr = format!("{:#x}", result_addr_as_num); + line = line.replace( + modexp, + &format!( + "staticcall(gas(), 0x5, add(transcript, {}), 0xc0, add(transcript, {}), 0x20", + transcript_addr, result_addr + ), + ); + } + + let m = ecmul_pattern.captures(&line); + if m.is_some() { + let ecmul = m.as_ref().unwrap().get(1).unwrap().as_str(); + let start_addr = m.as_ref().as_ref().unwrap().get(2).unwrap().as_str(); + let result_addr = m.unwrap().get(3).unwrap().as_str(); + let start_addr_as_num = + u32::from_str_radix(start_addr.strip_prefix("0x").unwrap(), 16)?; + let result_addr_as_num = + u32::from_str_radix(result_addr.strip_prefix("0x").unwrap(), 16)?; + + let transcript_addr = format!("{:#x}", start_addr_as_num); + let result_addr = format!("{:#x}", result_addr_as_num); + transcript_addrs.push(start_addr_as_num); + transcript_addrs.push(result_addr_as_num); + line = line.replace( + ecmul, + &format!( + "staticcall(gas(), 0x7, add(transcript, {}), 0x60, add(transcript, {}), 0x40", + transcript_addr, result_addr + ), + ); + } + + let m = ecadd_pattern.captures(&line); + if m.is_some() { + let ecadd = m.as_ref().unwrap().get(1).unwrap().as_str(); + let start_addr = m.as_ref().unwrap().get(2).unwrap().as_str(); + let result_addr = m.unwrap().get(3).unwrap().as_str(); + let start_addr_as_num = + u32::from_str_radix(start_addr.strip_prefix("0x").unwrap(), 16)?; + let result_addr_as_num = + u32::from_str_radix(result_addr.strip_prefix("0x").unwrap(), 16)?; + + let transcript_addr = format!("{:#x}", start_addr_as_num); + let result_addr = format!("{:#x}", result_addr_as_num); + transcript_addrs.push(start_addr_as_num); + transcript_addrs.push(result_addr_as_num); + line = line.replace( + ecadd, + &format!( + "staticcall(gas(), 0x6, add(transcript, {}), 0x80, add(transcript, {}), 0x40", + transcript_addr, result_addr + ), + ); + } + + let m = ecpairing_pattern.captures(&line); + if m.is_some() { + let ecpairing = m.as_ref().unwrap().get(1).unwrap().as_str(); + let start_addr = m.as_ref().unwrap().get(2).unwrap().as_str(); + let result_addr = m.unwrap().get(3).unwrap().as_str(); + let start_addr_as_num = + u32::from_str_radix(start_addr.strip_prefix("0x").unwrap(), 16)?; + let result_addr_as_num = + u32::from_str_radix(result_addr.strip_prefix("0x").unwrap(), 16)?; + + let transcript_addr = format!("{:#x}", start_addr_as_num); + let result_addr = format!("{:#x}", result_addr_as_num); + transcript_addrs.push(start_addr_as_num); + transcript_addrs.push(result_addr_as_num); + line = line.replace( + ecpairing, + &format!( + "staticcall(gas(), 0x8, add(transcript, {}), 0x180, add(transcript, {}), 0x20", + transcript_addr, result_addr + ), + ); + } + + let m = mstore_pattern.captures(&line); + if m.is_some() { + let mstore = m.as_ref().unwrap().get(1).unwrap().as_str(); + println!("{:?}", m); + let addr = m.as_ref().unwrap().get(2).unwrap().as_str(); + println!("{}", addr); + let addr_as_num = u32::from_str_radix(addr, 16)?; + let transcript_addr = format!("{:#x}", addr_as_num); + transcript_addrs.push(addr_as_num); + line = line.replace( + mstore, + &format!("mstore(add(transcript, {})", transcript_addr), + ); + } + + let m = keccak_pattern.captures(&line); + if m.is_some() { + let keccak = m.as_ref().unwrap().get(1).unwrap().as_str(); + let addr = m.as_ref().unwrap().get(2).unwrap().as_str(); + let addr_as_num = u32::from_str_radix(addr.strip_prefix("0x").unwrap(), 16)?; + let transcript_addr = format!("{:#x}", addr_as_num); + transcript_addrs.push(addr_as_num); + line = line.replace( + keccak, + &format!("keccak256(add(transcript, {})", transcript_addr), + ); + } + + // mload can show up multiple times per line + loop { + let m = mload_pattern.captures(&line); + if m.is_none() { + break; + } + let mload = m.as_ref().unwrap().get(1).unwrap().as_str(); + let addr = m.as_ref().unwrap().get(2).unwrap().as_str(); + + println!("{}", addr); + let addr_as_num = u32::from_str_radix(addr.strip_prefix("0x").unwrap(), 16)?; + let transcript_addr = format!("{:#x}", addr_as_num); + transcript_addrs.push(addr_as_num); + line = line.replace( + mload, + &format!("mload(add(transcript, {})", transcript_addr), + ); + } + + modified_lines.push(line); + } + + // get the max transcript addr + let max_transcript_addr = transcript_addrs.iter().max().unwrap() / 32; + let mut contract = format!( + "// SPDX-License-Identifier: MIT + pragma solidity ^0.8.17; + + contract Verifier {{ + function verify( + uint256[] memory pubInputs, + bytes memory proof + ) public view returns (bool) {{ + bool success = true; + bytes32[{}] memory transcript; + assembly {{ + ", + max_transcript_addr + ) + .trim() + .to_string(); + + // using a boxed Write trait object here to show it works for any Struct impl'ing Write + // you may also use a std::fs::File here + let write: Box<&mut dyn Write> = Box::new(&mut contract); + + for line in modified_lines[16..modified_lines.len() - 7].iter() { + write!(write, "{}", line).unwrap(); + } + writeln!(write, "}} return success; }} }}")?; + return Ok(contract); +} diff --git a/src/execute.rs b/src/execute.rs index c250a46ec..33433c6c9 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -1,6 +1,6 @@ use super::eth::verify_proof_via_solidity; use crate::commands::{Cli, Commands, StrategyType, TranscriptType}; -use crate::eth::{deploy_verifier, send_proof}; +use crate::eth::{deploy_verifier, fix_verifier_sol, send_proof}; use crate::graph::{vector_to_quantized, Model, ModelCircuit}; use crate::pfsys::evm::aggregation::{ gen_aggregation_evm_verifier, AggregationCircuit, PoseidonTranscript, @@ -35,7 +35,6 @@ use snark_verifier::system::halo2::transcript::evm::EvmTranscript; use std::error::Error; use std::fs::File; use std::io::Write; -use std::process::Command; use std::time::Instant; use tabled::Table; use thiserror::Error; @@ -250,18 +249,13 @@ pub async fn run(cli: Cli) -> Result<(), Box> { deployment_code.save(deployment_code_path.as_ref().unwrap())?; if sol_code_path.is_some() { - let mut f = File::create(sol_code_path.as_ref().unwrap()).unwrap(); + let mut f = File::create(sol_code_path.as_ref().unwrap())?; let _ = f.write(yul_code.as_bytes()); - let cmd = Command::new("python3") - .arg("fix_verifier_sol.py") - .arg(sol_code_path.as_ref().unwrap()) - .output() - .unwrap(); - let output = cmd.stdout; + let output = fix_verifier_sol(sol_code_path.as_ref().unwrap().clone())?; - let mut f = File::create(sol_code_path.as_ref().unwrap()).unwrap(); - let _ = f.write(output.as_slice()); + let mut f = File::create(sol_code_path.as_ref().unwrap())?; + let _ = f.write(output.as_bytes()); } } Commands::CreateEVMVerifierAggr {