From 05611d6ae735607a7c35b88ad1044340511da798 Mon Sep 17 00:00:00 2001
From: dante <45801863+alexander-camuto@users.noreply.github.com>
Date: Mon, 6 Mar 2023 00:29:07 +0000
Subject: [PATCH] refactor: fix-verifier-sol in rust (#152)
---
Cargo.lock | 1 +
Cargo.toml | 1 +
fix_verifier_sol.py | 198 ------------------------------
src/eth.rs | 288 ++++++++++++++++++++++++++++++++++++++++++++
src/execute.rs | 16 +--
5 files changed, 295 insertions(+), 209 deletions(-)
delete mode 100644 fix_verifier_sol.py
diff --git a/Cargo.lock b/Cargo.lock
index cc36ae4a5..66b0a0375 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1761,6 +1761,7 @@ dependencies = [
"mnist",
"plotters",
"rand",
+ "regex",
"reqwest",
"seq-macro",
"serde",
diff --git a/Cargo.toml b/Cargo.toml
index aa468a4a8..9c6dca53c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,6 +30,7 @@ snark-verifier = { git = "https://github.com/privacy-scaling-explorations/snark-
ethers = "1.0.2"
ethers-solc = "1.0.2"
tokio = { version = "1.22.0", features = ["macros"] }
+regex = "1"
[dev-dependencies]
criterion = {version = "0.3", features = ["html_reports"]}
diff --git a/fix_verifier_sol.py b/fix_verifier_sol.py
deleted file mode 100644
index 784b072d9..000000000
--- a/fix_verifier_sol.py
+++ /dev/null
@@ -1,198 +0,0 @@
-#!/usr/bin/env python3
-
-import sys
-import re
-
-if __name__ == "__main__":
- if len(sys.argv) < 2:
- print("Usage: fix_verifier_sol.py ")
- sys.exit(1)
-
- input_file = sys.argv[1]
- lines = open(input_file).readlines()
-
- transcript_addrs = list()
- modified_lines = list()
-
- num_pubinputs = 0
-
- # convert calldataload 0x0 to 0x40 to read from pubInputs, and the rest
- # from proof
- calldata_pattern = r"^.*(calldataload\((0x[a-f0-9]+)\)).*$"
- mstore_pattern = r"^\s*(mstore\(0x([0-9a-fA-F]+)+),.+\)"
- mstore8_pattern = r"^\s*(mstore8\((\d+)+),.+\)"
- mstoren_pattern = r"^\s*(mstore\((\d+)+),.+\)"
- mload_pattern = r"(mload\((0x[0-9a-fA-F]+))\)"
- keccak_pattern = r"(keccak256\((0x[0-9a-fA-F]+))"
- modexp_pattern = r"(staticcall\(gas\(\), 0x5, (0x[0-9a-fA-F]+), 0xc0, (0x[0-9a-fA-F]+), 0x20)"
- ecmul_pattern = r"(staticcall\(gas\(\), 0x7, (0x[0-9a-fA-F]+), 0x60, (0x[0-9a-fA-F]+), 0x40)"
- ecadd_pattern = r"(staticcall\(gas\(\), 0x6, (0x[0-9a-fA-F]+), 0x80, (0x[0-9a-fA-F]+), 0x40)"
- ecpairing_pattern = r"(staticcall\(gas\(\), 0x8, (0x[0-9a-fA-F]+), 0x180, (0x[0-9a-fA-F]+), 0x20)"
- bool_pattern = r":bool"
-
- # Count the number of pub inputs
- start = None
- end = None
- i = 0
- for line in lines:
- if line.strip().startswith("mstore(0x20"):
- start = i
-
- if line.strip().startswith("mstore(0x0"):
- end = i
- break
- i += 1
-
- if start is None:
- num_pubinputs = 0
- else:
- num_pubinputs = end - start
-
- max_pubinputs_addr = 0
- if num_pubinputs > 0:
- max_pubinputs_addr = num_pubinputs * 32 - 32
-
- for line in lines:
- m = re.search(bool_pattern, line)
- if m:
- line = line.replace(":bool", "")
-
- m = re.search(calldata_pattern, line)
- if m:
- calldata_and_addr = m.group(1)
- addr = m.group(2)
- addr_as_num = int(addr, 16)
-
- if addr_as_num <= max_pubinputs_addr:
- pub_addr = hex(addr_as_num + 32)
- line = line.replace(calldata_and_addr, "mload(add(pubInputs, " + pub_addr + "))")
- else:
- proof_addr = hex(addr_as_num - max_pubinputs_addr)
- line = line.replace(calldata_and_addr, "mload(add(proof, " + proof_addr + "))")
-
- m = re.search(mstore8_pattern, line)
- if m:
- mstore = m.group(1)
- addr = m.group(2)
- addr_as_num = int(addr)
- transcript_addr = hex(addr_as_num)
- transcript_addrs.append(addr_as_num)
- line = line.replace(mstore, "mstore8(add(transcript, " + transcript_addr + ")")
-
- m = re.search(mstoren_pattern, line)
- if m:
- mstore = m.group(1)
- addr = m.group(2)
- addr_as_num = int(addr)
- transcript_addr = hex(addr_as_num)
- transcript_addrs.append(addr_as_num)
- line = line.replace(mstore, "mstore(add(transcript, " + transcript_addr + ")")
-
- m = re.search(modexp_pattern, line)
- if m:
- modexp = m.group(1)
- start_addr = m.group(2)
- result_addr = m.group(3)
- start_addr_as_num = int(start_addr, 16)
- result_addr_as_num = int(result_addr, 16)
-
- transcript_addr = hex(start_addr_as_num)
- transcript_addrs.append(addr_as_num)
- result_addr = hex(result_addr_as_num)
- line = line.replace(modexp, "staticcall(gas(), 0x5, add(transcript, " + transcript_addr + "), 0xc0, add(transcript, " + result_addr + "), 0x20")
-
- m = re.search(ecmul_pattern, line)
- if m:
- ecmul = m.group(1)
- start_addr = m.group(2)
- result_addr = m.group(3)
- start_addr_as_num = int(start_addr, 16)
- result_addr_as_num = int(result_addr, 16)
-
- transcript_addr = hex(start_addr_as_num)
- result_addr = hex(result_addr_as_num)
- transcript_addrs.append(start_addr_as_num)
- transcript_addrs.append(result_addr_as_num)
- line = line.replace(ecmul, "staticcall(gas(), 0x7, add(transcript, " + transcript_addr + "), 0x60, add(transcript, " + result_addr + "), 0x40")
-
- m = re.search(ecadd_pattern, line)
- if m:
- ecadd = m.group(1)
- start_addr = m.group(2)
- result_addr = m.group(3)
- start_addr_as_num = int(start_addr, 16)
- result_addr_as_num = int(result_addr, 16)
-
- transcript_addr = hex(start_addr_as_num)
- result_addr = hex(result_addr_as_num)
- transcript_addrs.append(start_addr_as_num)
- transcript_addrs.append(result_addr_as_num)
- line = line.replace(ecadd, "staticcall(gas(), 0x6, add(transcript, " + transcript_addr + "), 0x80, add(transcript, " + result_addr + "), 0x40")
-
- m = re.search(ecpairing_pattern, line)
- if m:
- ecpairing = m.group(1)
- start_addr = m.group(2)
- result_addr = m.group(3)
- start_addr_as_num = int(start_addr, 16)
- result_addr_as_num = int(result_addr, 16)
-
- transcript_addr = hex(start_addr_as_num)
- result_addr = hex(result_addr_as_num)
- transcript_addrs.append(start_addr_as_num)
- transcript_addrs.append(result_addr_as_num)
- line = line.replace(ecpairing, "staticcall(gas(), 0x8, add(transcript, " + transcript_addr + "), 0x180, add(transcript, " + result_addr + "), 0x20")
-
- m = re.search(mstore_pattern, line)
- if m:
- mstore = m.group(1)
- addr = m.group(2)
- addr_as_num = int(addr, 16)
- transcript_addr = hex(addr_as_num)
- transcript_addrs.append(addr_as_num)
- line = line.replace(mstore, "mstore(add(transcript, " + transcript_addr + ")")
-
- m = re.search(keccak_pattern, line)
- if m:
- keccak = m.group(1)
- addr = m.group(2)
- addr_as_num = int(addr, 16)
- transcript_addr = hex(addr_as_num)
- transcript_addrs.append(addr_as_num)
- line = line.replace(keccak, "keccak256(add(transcript, " + transcript_addr + ")")
-
- # mload can show up multiple times per line
- while True:
- m = re.search(mload_pattern, line)
- if not m:
- break
- mload = m.group(1)
- addr = m.group(2)
- addr_as_num = int(addr, 16)
- transcript_addr = hex(addr_as_num)
- transcript_addrs.append(addr_as_num)
- line = line.replace(mload, "mload(add(transcript, " + transcript_addr + ")")
-
- # print(line, end="")
- modified_lines.append(line)
-
- # get the max transcript addr
- max_transcript_addr = int(max(transcript_addrs) / 32)
- print("""// SPDX-License-Identifier: MIT
-pragma solidity ^0.8.17;
-
-contract Verifier {{
- function verify(
- uint256[] memory pubInputs,
- bytes memory proof
- ) public view returns (bool) {{
- bool success = true;
- bytes32[{}] memory transcript;
- assembly {{
- """.strip().format(max_transcript_addr))
- for line in modified_lines[16:-7]:
- print(line, end="")
- print("""}
- return success;
- }
-}""")
diff --git a/src/eth.rs b/src/eth.rs
index 43ee3519b..578ad5f27 100644
--- a/src/eth.rs
+++ b/src/eth.rs
@@ -28,6 +28,7 @@ use halo2curves::group::ff::PrimeField;
use log::{debug, info};
use snark_verifier::loader::evm::encode_calldata;
use std::error::Error;
+use std::fmt::Write;
use std::fs::read_to_string;
use std::path::PathBuf;
use std::{convert::TryFrom, sync::Arc, time::Duration};
@@ -297,3 +298,290 @@ pub async fn send_proof(
Ok(())
}
+
+use regex::Regex;
+use std::fs::File;
+use std::io::{BufRead, BufReader};
+
+/// Reads in raw bytes code and generates equivalent .sol file
+pub fn fix_verifier_sol(input_file: PathBuf) -> Result> {
+ let file = File::open(input_file.clone())?;
+ let reader = BufReader::new(file);
+
+ let mut transcript_addrs: Vec = Vec::new();
+ let mut modified_lines: Vec = Vec::new();
+
+ // convert calldataload 0x0 to 0x40 to read from pubInputs, and the rest
+ // from proof
+ let calldata_pattern = Regex::new(r"^.*(calldataload\((0x[a-f0-9]+)\)).*$")?;
+ let mstore_pattern = Regex::new(r"^\s*(mstore\(0x([0-9a-fA-F]+)+),.+\)")?;
+ let mstore8_pattern = Regex::new(r"^\s*(mstore8\((\d+)+),.+\)")?;
+ let mstoren_pattern = Regex::new(r"^\s*(mstore\((\d+)+),.+\)")?;
+ let mload_pattern = Regex::new(r"(mload\((0x[0-9a-fA-F]+))\)")?;
+ let keccak_pattern = Regex::new(r"(keccak256\((0x[0-9a-fA-F]+))")?;
+ let modexp_pattern =
+ Regex::new(r"(staticcall\(gas\(\), 0x5, (0x[0-9a-fA-F]+), 0xc0, (0x[0-9a-fA-F]+), 0x20)")?;
+ let ecmul_pattern =
+ Regex::new(r"(staticcall\(gas\(\), 0x7, (0x[0-9a-fA-F]+), 0x60, (0x[0-9a-fA-F]+), 0x40)")?;
+ let ecadd_pattern =
+ Regex::new(r"(staticcall\(gas\(\), 0x6, (0x[0-9a-fA-F]+), 0x80, (0x[0-9a-fA-F]+), 0x40)")?;
+ let ecpairing_pattern =
+ Regex::new(r"(staticcall\(gas\(\), 0x8, (0x[0-9a-fA-F]+), 0x180, (0x[0-9a-fA-F]+), 0x20)")?;
+ let bool_pattern = Regex::new(r":bool")?;
+
+ // Count the number of pub inputs
+ let mut start = None;
+ let mut end = None;
+ let mut i = 0;
+ for line in reader.lines() {
+ let line = line?;
+ if line.trim().starts_with("mstore(0x20") {
+ start = Some(i);
+ }
+
+ if line.trim().starts_with("mstore(0x0") {
+ end = Some(i);
+ break;
+ }
+ i += 1;
+ }
+
+ let num_pubinputs = if start.is_none() {
+ 0
+ } else {
+ end.unwrap() - start.unwrap()
+ };
+
+ let mut max_pubinputs_addr = 0;
+ if num_pubinputs > 0 {
+ max_pubinputs_addr = num_pubinputs * 32 - 32;
+ }
+
+ let file = File::open(input_file)?;
+ let reader = BufReader::new(file);
+
+ for line in reader.lines() {
+ let mut line = line?;
+ let m = bool_pattern.captures(&line);
+ if m.is_some() {
+ line = line.replace(":bool", "");
+ }
+
+ let m = calldata_pattern.captures(&line);
+ if m.is_some() {
+ let calldata_and_addr = m.as_ref().unwrap().get(1).unwrap().as_str();
+ let addr = m.unwrap().get(2).unwrap().as_str();
+ let addr_as_num = u32::from_str_radix(addr.strip_prefix("0x").unwrap(), 16)?;
+
+ if addr_as_num <= max_pubinputs_addr {
+ let pub_addr = format!("{:#x}", addr_as_num + 32);
+ line = line.replace(
+ calldata_and_addr,
+ &format!("mload(add(pubInputs, {}))", pub_addr),
+ );
+ } else {
+ let proof_addr = format!("{:#x}", addr_as_num - max_pubinputs_addr);
+ line = line.replace(
+ calldata_and_addr,
+ &format!("mload(add(proof, {}))", proof_addr),
+ );
+ }
+ }
+
+ let m = mstore8_pattern.captures(&line);
+ if m.is_some() {
+ let mstore = m.as_ref().unwrap().get(1).unwrap().as_str();
+ let addr = m.unwrap().get(2).unwrap().as_str();
+ let addr_as_num = u32::from_str_radix(addr, 10)?;
+ let transcript_addr = format!("{:#x}", addr_as_num);
+ transcript_addrs.push(addr_as_num);
+ line = line.replace(
+ mstore,
+ &format!("mstore8(add(transcript, {})", transcript_addr),
+ );
+ }
+
+ let m = mstoren_pattern.captures(&line);
+ if m.is_some() {
+ let mstore = m.as_ref().unwrap().get(1).unwrap().as_str();
+ let addr = m.unwrap().get(2).unwrap().as_str();
+ let addr_as_num = u32::from_str_radix(addr, 10)?;
+ let transcript_addr = format!("{:#x}", addr_as_num);
+ transcript_addrs.push(addr_as_num);
+ line = line.replace(
+ mstore,
+ &format!("mstore(add(transcript, {})", transcript_addr),
+ );
+ }
+
+ let m = modexp_pattern.captures(&line);
+ if m.is_some() {
+ let modexp = m.as_ref().unwrap().get(1).unwrap().as_str();
+ let start_addr = m.as_ref().unwrap().get(2).unwrap().as_str();
+ let result_addr = m.unwrap().get(3).unwrap().as_str();
+ let start_addr_as_num =
+ u32::from_str_radix(start_addr.strip_prefix("0x").unwrap(), 16)?;
+ let result_addr_as_num =
+ u32::from_str_radix(result_addr.strip_prefix("0x").unwrap(), 16)?;
+
+ let transcript_addr = format!("{:#x}", start_addr_as_num);
+ transcript_addrs.push(start_addr_as_num);
+ let result_addr = format!("{:#x}", result_addr_as_num);
+ line = line.replace(
+ modexp,
+ &format!(
+ "staticcall(gas(), 0x5, add(transcript, {}), 0xc0, add(transcript, {}), 0x20",
+ transcript_addr, result_addr
+ ),
+ );
+ }
+
+ let m = ecmul_pattern.captures(&line);
+ if m.is_some() {
+ let ecmul = m.as_ref().unwrap().get(1).unwrap().as_str();
+ let start_addr = m.as_ref().as_ref().unwrap().get(2).unwrap().as_str();
+ let result_addr = m.unwrap().get(3).unwrap().as_str();
+ let start_addr_as_num =
+ u32::from_str_radix(start_addr.strip_prefix("0x").unwrap(), 16)?;
+ let result_addr_as_num =
+ u32::from_str_radix(result_addr.strip_prefix("0x").unwrap(), 16)?;
+
+ let transcript_addr = format!("{:#x}", start_addr_as_num);
+ let result_addr = format!("{:#x}", result_addr_as_num);
+ transcript_addrs.push(start_addr_as_num);
+ transcript_addrs.push(result_addr_as_num);
+ line = line.replace(
+ ecmul,
+ &format!(
+ "staticcall(gas(), 0x7, add(transcript, {}), 0x60, add(transcript, {}), 0x40",
+ transcript_addr, result_addr
+ ),
+ );
+ }
+
+ let m = ecadd_pattern.captures(&line);
+ if m.is_some() {
+ let ecadd = m.as_ref().unwrap().get(1).unwrap().as_str();
+ let start_addr = m.as_ref().unwrap().get(2).unwrap().as_str();
+ let result_addr = m.unwrap().get(3).unwrap().as_str();
+ let start_addr_as_num =
+ u32::from_str_radix(start_addr.strip_prefix("0x").unwrap(), 16)?;
+ let result_addr_as_num =
+ u32::from_str_radix(result_addr.strip_prefix("0x").unwrap(), 16)?;
+
+ let transcript_addr = format!("{:#x}", start_addr_as_num);
+ let result_addr = format!("{:#x}", result_addr_as_num);
+ transcript_addrs.push(start_addr_as_num);
+ transcript_addrs.push(result_addr_as_num);
+ line = line.replace(
+ ecadd,
+ &format!(
+ "staticcall(gas(), 0x6, add(transcript, {}), 0x80, add(transcript, {}), 0x40",
+ transcript_addr, result_addr
+ ),
+ );
+ }
+
+ let m = ecpairing_pattern.captures(&line);
+ if m.is_some() {
+ let ecpairing = m.as_ref().unwrap().get(1).unwrap().as_str();
+ let start_addr = m.as_ref().unwrap().get(2).unwrap().as_str();
+ let result_addr = m.unwrap().get(3).unwrap().as_str();
+ let start_addr_as_num =
+ u32::from_str_radix(start_addr.strip_prefix("0x").unwrap(), 16)?;
+ let result_addr_as_num =
+ u32::from_str_radix(result_addr.strip_prefix("0x").unwrap(), 16)?;
+
+ let transcript_addr = format!("{:#x}", start_addr_as_num);
+ let result_addr = format!("{:#x}", result_addr_as_num);
+ transcript_addrs.push(start_addr_as_num);
+ transcript_addrs.push(result_addr_as_num);
+ line = line.replace(
+ ecpairing,
+ &format!(
+ "staticcall(gas(), 0x8, add(transcript, {}), 0x180, add(transcript, {}), 0x20",
+ transcript_addr, result_addr
+ ),
+ );
+ }
+
+ let m = mstore_pattern.captures(&line);
+ if m.is_some() {
+ let mstore = m.as_ref().unwrap().get(1).unwrap().as_str();
+ println!("{:?}", m);
+ let addr = m.as_ref().unwrap().get(2).unwrap().as_str();
+ println!("{}", addr);
+ let addr_as_num = u32::from_str_radix(addr, 16)?;
+ let transcript_addr = format!("{:#x}", addr_as_num);
+ transcript_addrs.push(addr_as_num);
+ line = line.replace(
+ mstore,
+ &format!("mstore(add(transcript, {})", transcript_addr),
+ );
+ }
+
+ let m = keccak_pattern.captures(&line);
+ if m.is_some() {
+ let keccak = m.as_ref().unwrap().get(1).unwrap().as_str();
+ let addr = m.as_ref().unwrap().get(2).unwrap().as_str();
+ let addr_as_num = u32::from_str_radix(addr.strip_prefix("0x").unwrap(), 16)?;
+ let transcript_addr = format!("{:#x}", addr_as_num);
+ transcript_addrs.push(addr_as_num);
+ line = line.replace(
+ keccak,
+ &format!("keccak256(add(transcript, {})", transcript_addr),
+ );
+ }
+
+ // mload can show up multiple times per line
+ loop {
+ let m = mload_pattern.captures(&line);
+ if m.is_none() {
+ break;
+ }
+ let mload = m.as_ref().unwrap().get(1).unwrap().as_str();
+ let addr = m.as_ref().unwrap().get(2).unwrap().as_str();
+
+ println!("{}", addr);
+ let addr_as_num = u32::from_str_radix(addr.strip_prefix("0x").unwrap(), 16)?;
+ let transcript_addr = format!("{:#x}", addr_as_num);
+ transcript_addrs.push(addr_as_num);
+ line = line.replace(
+ mload,
+ &format!("mload(add(transcript, {})", transcript_addr),
+ );
+ }
+
+ modified_lines.push(line);
+ }
+
+ // get the max transcript addr
+ let max_transcript_addr = transcript_addrs.iter().max().unwrap() / 32;
+ let mut contract = format!(
+ "// SPDX-License-Identifier: MIT
+ pragma solidity ^0.8.17;
+
+ contract Verifier {{
+ function verify(
+ uint256[] memory pubInputs,
+ bytes memory proof
+ ) public view returns (bool) {{
+ bool success = true;
+ bytes32[{}] memory transcript;
+ assembly {{
+ ",
+ max_transcript_addr
+ )
+ .trim()
+ .to_string();
+
+ // using a boxed Write trait object here to show it works for any Struct impl'ing Write
+ // you may also use a std::fs::File here
+ let write: Box<&mut dyn Write> = Box::new(&mut contract);
+
+ for line in modified_lines[16..modified_lines.len() - 7].iter() {
+ write!(write, "{}", line).unwrap();
+ }
+ writeln!(write, "}} return success; }} }}")?;
+ return Ok(contract);
+}
diff --git a/src/execute.rs b/src/execute.rs
index c250a46ec..33433c6c9 100644
--- a/src/execute.rs
+++ b/src/execute.rs
@@ -1,6 +1,6 @@
use super::eth::verify_proof_via_solidity;
use crate::commands::{Cli, Commands, StrategyType, TranscriptType};
-use crate::eth::{deploy_verifier, send_proof};
+use crate::eth::{deploy_verifier, fix_verifier_sol, send_proof};
use crate::graph::{vector_to_quantized, Model, ModelCircuit};
use crate::pfsys::evm::aggregation::{
gen_aggregation_evm_verifier, AggregationCircuit, PoseidonTranscript,
@@ -35,7 +35,6 @@ use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::error::Error;
use std::fs::File;
use std::io::Write;
-use std::process::Command;
use std::time::Instant;
use tabled::Table;
use thiserror::Error;
@@ -250,18 +249,13 @@ pub async fn run(cli: Cli) -> Result<(), Box> {
deployment_code.save(deployment_code_path.as_ref().unwrap())?;
if sol_code_path.is_some() {
- let mut f = File::create(sol_code_path.as_ref().unwrap()).unwrap();
+ let mut f = File::create(sol_code_path.as_ref().unwrap())?;
let _ = f.write(yul_code.as_bytes());
- let cmd = Command::new("python3")
- .arg("fix_verifier_sol.py")
- .arg(sol_code_path.as_ref().unwrap())
- .output()
- .unwrap();
- let output = cmd.stdout;
+ let output = fix_verifier_sol(sol_code_path.as_ref().unwrap().clone())?;
- let mut f = File::create(sol_code_path.as_ref().unwrap()).unwrap();
- let _ = f.write(output.as_slice());
+ let mut f = File::create(sol_code_path.as_ref().unwrap())?;
+ let _ = f.write(output.as_bytes());
}
}
Commands::CreateEVMVerifierAggr {