From e4c35fe554cf4a5ef0bf4e0df96f6c49e5e59fd5 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 29 Feb 2024 15:41:35 +0100 Subject: [PATCH 01/62] Experimental pil to rust compiler. --- executor/src/constant_evaluator/compiler.rs | 251 ++++++++++++++++++++ executor/src/constant_evaluator/mod.rs | 14 +- 2 files changed, 263 insertions(+), 2 deletions(-) create mode 100644 executor/src/constant_evaluator/compiler.rs diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs new file mode 100644 index 0000000000..41a5b7d11d --- /dev/null +++ b/executor/src/constant_evaluator/compiler.rs @@ -0,0 +1,251 @@ +use std::{collections::HashMap, io::Write}; + +use itertools::Itertools; +use powdr_ast::{ + analyzed::{ + Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, + Reference, SymbolKind, + }, + parsed::{ + types::{ArrayType, FunctionType, Type, TypeScheme}, + ArrayLiteral, BinaryOperation, BinaryOperator, FunctionCall, IfExpression, IndexAccess, + LambdaExpression, Number, UnaryOperation, + }, +}; +use powdr_number::FieldElement; + +use super::VariablySizedColumn; + +const PREAMBLE: &str = r#" +#![allow(unused_parens)] +use ark_ff::{BigInt, BigInteger, Fp64, MontBackend, MontConfig, PrimeField}; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use std::io::{BufWriter, Write}; +use std::fs::File; + +#[derive(MontConfig)] +#[modulus = "18446744069414584321"] +#[generator = "7"] +pub struct GoldilocksBaseFieldConfig; +pub type FieldElement = Fp64>; +"#; + +pub fn generate_fixed_cols( + analyzed: &Analyzed, +) -> HashMap)> { + let definitions = process_definitions(analyzed); + let degree = analyzed.degree(); + // TODO also eval other cols + let main_func = format!( + " +fn main() {{ + let data = (0..{degree}) + .into_par_iter() + .map(|i| {{ + main_inv(num_bigint::BigInt::from(i)) + }}) + .collect::>(); + let mut writer = BufWriter::new(File::create(\"./constants.bin\").unwrap()); + for i in 0..{degree} {{ + writer + .write_all(&BigInt::from(data[i]).to_bytes_le()) + .unwrap(); + }} +}} +" + ); + let result = format!("{PREAMBLE}\n{definitions}\n{main_func}\n"); + // write result to a temp file + let mut file = std::fs::File::create("/tmp/te/src/main.rs").unwrap(); + file.write_all(result.as_bytes()).unwrap(); + Default::default() +} + +pub fn process_definitions(analyzed: &Analyzed) -> String { + let mut result = String::new(); + for (name, (sym, value)) in &analyzed.definitions { + if name == "std::check::panic" { + result.push_str("fn std_check_panic(s: &str) -> ! { panic!(\"{s}\"); }"); + } else if name == "std::field::modulus" { + result.push_str("fn std_field_modulus() -> num_bigint::BigInt { num_bigint::BigInt::from(18446744069414584321_u64) }"); + } else if name == "std::convert::fe" { + result.push_str("fn std_convert_fe(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}"); + } else if let Some(FunctionValueDefinition::Expression(value)) = value { + println!("Processing {name} = {}", value.e); + let type_scheme = if sym.kind == SymbolKind::Poly(PolynomialType::Constant) { + TypeScheme { + vars: Default::default(), + ty: Type::Function(FunctionType { + params: vec![Type::Int], + value: Box::new(Type::Fe), + }), + } + } else { + value.type_scheme.clone().unwrap() + }; + match &type_scheme { + TypeScheme { + vars, + ty: + Type::Function(FunctionType { + params: param_types, + value: return_type, + }), + } => { + let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = + &value.e + else { + todo!("value of fun: {}", value.e) + }; + result.push_str(&format!( + "fn {}<{}>({}) -> {} {{ {} }}\n", + escape(name), + vars, + params + .iter() + .zip(param_types) + .map(|(p, t)| format!("{}: {}", p, map_type(t))) + .format(", "), + map_type(return_type), + format_expr(body) + )); + } + _ => { + result.push_str(&format!( + "const {}: {} = {};\n", + escape(name), + map_type(&value.type_scheme.as_ref().unwrap().ty), + format_expr(&value.e) + )); + } + } + } + } + + result +} + +fn format_expr(e: &Expression) -> String { + match e { + Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), + Expression::Reference( + _, + Reference::Poly(PolynomialReference { + name, + poly_id: _, + type_args: _, + }), + ) => escape(name), // TOOD use type args if needed. + Expression::Number( + _, + Number { + value, + type_: Some(type_), + }, + ) => match type_ { + Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), + Type::Fe => format!("FieldElement::from({value}_u64)"), + Type::Expr => format!("Expr::from({value}_u64)"), + Type::TypeVar(t) => format!("{t}::from({value}_u64)"), + _ => unreachable!(), + }, + Expression::FunctionCall( + _, + FunctionCall { + function, + arguments, + }, + ) => { + format!( + "({})({})", + format_expr(function), + arguments + .iter() + .map(format_expr) + .map(|x| format!("{x}.clone()")) + .collect::>() + .join(", ") + ) + } + Expression::BinaryOperation(_, BinaryOperation { left, op, right }) => { + let left = format_expr(left); + let right = format_expr(right); + match op { + BinaryOperator::ShiftLeft => { + format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") + } + _ => format!("(({left}).clone() {op} ({right}).clone())"), + } + } + Expression::UnaryOperation(_, UnaryOperation { op, expr }) => { + format!("({op} ({}).clone())", format_expr(expr)) + } + Expression::IndexAccess(_, IndexAccess { array, index }) => { + format!( + "{}[usize::try_from({}).unwrap()].clone()", + format_expr(array), + format_expr(index) + ) + } + Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { + // let params = if *params == vec!["r".to_string()] { + // // Hack because rust needs the type + // vec!["r: Vec".to_string()] + // } else { + // params.clone() + // }; + format!( + "|{}| {{ {} }}", + params.iter().format(", "), + format_expr(body) + ) + } + Expression::IfExpression( + _, + IfExpression { + condition, + body, + else_body, + }, + ) => { + format!( + "if {} {{ {} }} else {{ {} }}", + format_expr(condition), + format_expr(body), + format_expr(else_body) + ) + } + Expression::ArrayLiteral(_, ArrayLiteral { items }) => { + format!( + "vec![{}]", + items.iter().map(format_expr).collect::>().join(", ") + ) + } + Expression::String(_, s) => format!("{s:?}"), // TODO does this quote properly? + Expression::Tuple(_, items) => format!( + "({})", + items.iter().map(format_expr).collect::>().join(", ") + ), + _ => panic!("Implement {e}"), + } +} + +fn escape(s: &str) -> String { + s.replace('.', "_").replace("::", "_") +} + +fn map_type(ty: &Type) -> String { + match ty { + Type::Bottom | Type::Bool => format!("{ty}"), + Type::Int => "num_bigint::BigInt".to_string(), + Type::Fe => "FieldElement".to_string(), + Type::String => "String".to_string(), + Type::Col => unreachable!(), + Type::Expr => "Expr".to_string(), + Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), + Type::Tuple(_) => todo!(), + Type::Function(ft) => todo!("Type {ft}"), + Type::TypeVar(tv) => tv.to_string(), + Type::NamedType(_path, _type_args) => todo!(), + } +} diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index 40256195ac..f1bb9a2c93 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -6,7 +6,7 @@ use std::{ pub use data_structures::{get_uniquely_sized, get_uniquely_sized_cloned, VariablySizedColumn}; use itertools::Itertools; use powdr_ast::{ - analyzed::{Analyzed, Expression, FunctionValueDefinition, Symbol, TypedExpression}, + analyzed::{Analyzed, Expression, FunctionValueDefinition, PolyID, Symbol, TypedExpression}, parsed::{ types::{ArrayType, Type}, IndexAccess, @@ -16,6 +16,9 @@ use powdr_number::{BigInt, BigUint, DegreeType, FieldElement}; use powdr_pil_analyzer::evaluator::{self, Definitions, SymbolLookup, Value}; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +// TODO this is probabyl not the right place. +mod compiler; + mod data_structures; /// Generates the fixed column values for all fixed columns that are defined @@ -24,12 +27,19 @@ mod data_structures; /// Arrays of columns are flattened, the name of the `i`th array element /// is `name[i]`. pub fn generate(analyzed: &Analyzed) -> Vec<(String, VariablySizedColumn)> { - let mut fixed_cols = HashMap::new(); + // TODO to do this properly, we should try to compile as much as possible + // and only evaulato if it fails. Still, compilation should be done in one run. + + let mut fixed_cols: HashMap)> = + compiler::generate_fixed_cols(analyzed); for (poly, value) in analyzed.constant_polys_in_source_order() { if let Some(value) = value { // For arrays, generate values for each index, // for non-arrays, set index to None. for (index, (name, id)) in poly.array_elements().enumerate() { + if fixed_cols.contains_key(&name) { + continue; + } let index = poly.is_array().then_some(index as u64); let range = poly.degree.unwrap(); let values = range From 686bdc6bfaf6ed19376a374f0d53ee9ce13f6c61 Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 3 Aug 2024 12:41:00 +0000 Subject: [PATCH 02/62] Compile and dlopen. --- executor/Cargo.toml | 2 + executor/src/constant_evaluator/compiler.rs | 475 ++++++++++++-------- 2 files changed, 292 insertions(+), 185 deletions(-) diff --git a/executor/Cargo.toml b/executor/Cargo.toml index 0079cce9a9..3303097918 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -14,7 +14,9 @@ powdr-parser-util.workspace = true powdr-pil-analyzer.workspace = true itertools = "0.13" +libc = "0.2.0" log = { version = "0.4.17" } +mktemp = "0.5.0" rayon = "1.7.0" bit-vec = "0.6.3" num-traits = "0.2.15" diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs index 41a5b7d11d..f5ee7ca46e 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/executor/src/constant_evaluator/compiler.rs @@ -1,4 +1,12 @@ -use std::{collections::HashMap, io::Write}; +use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; +use std::{ + collections::{HashMap, HashSet}, + ffi::CString, + fs::File, + io::Write, + process::Command, + sync::Arc, +}; use itertools::Itertools; use powdr_ast::{ @@ -7,6 +15,7 @@ use powdr_ast::{ Reference, SymbolKind, }, parsed::{ + display::{format_type_args, quote}, types::{ArrayType, FunctionType, Type, TypeScheme}, ArrayLiteral, BinaryOperation, BinaryOperator, FunctionCall, IfExpression, IndexAccess, LambdaExpression, Number, UnaryOperation, @@ -19,9 +28,6 @@ use super::VariablySizedColumn; const PREAMBLE: &str = r#" #![allow(unused_parens)] use ark_ff::{BigInt, BigInteger, Fp64, MontBackend, MontConfig, PrimeField}; -use rayon::prelude::{IntoParallelIterator, ParallelIterator}; -use std::io::{BufWriter, Write}; -use std::fs::File; #[derive(MontConfig)] #[modulus = "18446744069414584321"] @@ -30,203 +36,302 @@ pub struct GoldilocksBaseFieldConfig; pub type FieldElement = Fp64>; "#; +// TODO this is the old impl of goldilocks + +const CARGO_TOML: &str = r#" +[package] +name = "powdr_constants" +version = "0.1.0" +edition = "2021" + +[dependencies] +ark-ff = "0.4.2" +"#; + +// TODO crate type dylib? + pub fn generate_fixed_cols( analyzed: &Analyzed, ) -> HashMap)> { - let definitions = process_definitions(analyzed); - let degree = analyzed.degree(); - // TODO also eval other cols - let main_func = format!( - " -fn main() {{ - let data = (0..{degree}) - .into_par_iter() - .map(|i| {{ - main_inv(num_bigint::BigInt::from(i)) - }}) - .collect::>(); - let mut writer = BufWriter::new(File::create(\"./constants.bin\").unwrap()); - for i in 0..{degree} {{ - writer - .write_all(&BigInt::from(data[i]).to_bytes_le()) - .unwrap(); - }} -}} -" - ); - let result = format!("{PREAMBLE}\n{definitions}\n{main_func}\n"); - // write result to a temp file - let mut file = std::fs::File::create("/tmp/te/src/main.rs").unwrap(); - file.write_all(result.as_bytes()).unwrap(); - Default::default() -} + let mut compiler = Compiler::new(analyzed); + for (sym, _) in &analyzed.constant_polys_in_source_order() { + compiler.request_symbol(&sym.absolute_name); + } + let code = format!("{PREAMBLE}\n{}\n", compiler.compiled_symbols()); -pub fn process_definitions(analyzed: &Analyzed) -> String { - let mut result = String::new(); - for (name, (sym, value)) in &analyzed.definitions { - if name == "std::check::panic" { - result.push_str("fn std_check_panic(s: &str) -> ! { panic!(\"{s}\"); }"); - } else if name == "std::field::modulus" { - result.push_str("fn std_field_modulus() -> num_bigint::BigInt { num_bigint::BigInt::from(18446744069414584321_u64) }"); - } else if name == "std::convert::fe" { - result.push_str("fn std_convert_fe(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}"); - } else if let Some(FunctionValueDefinition::Expression(value)) = value { - println!("Processing {name} = {}", value.e); - let type_scheme = if sym.kind == SymbolKind::Poly(PolynomialType::Constant) { - TypeScheme { - vars: Default::default(), - ty: Type::Function(FunctionType { - params: vec![Type::Int], - value: Box::new(Type::Fe), - }), - } - } else { - value.type_scheme.clone().unwrap() - }; - match &type_scheme { - TypeScheme { - vars, - ty: - Type::Function(FunctionType { - params: param_types, - value: return_type, - }), - } => { - let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = - &value.e - else { - todo!("value of fun: {}", value.e) - }; - result.push_str(&format!( - "fn {}<{}>({}) -> {} {{ {} }}\n", - escape(name), - vars, - params - .iter() - .zip(param_types) - .map(|(p, t)| format!("{}: {}", p, map_type(t))) - .format(", "), - map_type(return_type), - format_expr(body) - )); - } - _ => { - result.push_str(&format!( - "const {}: {} = {};\n", - escape(name), - map_type(&value.type_scheme.as_ref().unwrap().ty), - format_expr(&value.e) - )); - } + let dir = mktemp::Temp::new_dir().unwrap(); + std::fs::create_dir(dir.as_path().join("src")).unwrap(); + std::fs::write(dir.as_path().join("src").join("lib.rs"), code).unwrap(); + Command::new("cargo") + .arg("build") + .arg("--release") + .current_dir(dir.as_path()) + .output() + .unwrap(); + + unsafe { + let lib_path = CString::new( + dir.as_path() + .join("target") + .join("release") + .join("libpowdr_constants.so") + .to_str() + .unwrap(), + ) + .unwrap(); + let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); + if lib.is_null() { + panic!("Failed to load library: {:?}", lib_path); + } + for (sym, poly_id) in analyzed.constant_polys_in_source_order() { + let sym = escape(&sym.absolute_name); + let sym = CString::new(sym).unwrap(); + let sym = dlsym(lib, sym.as_ptr()); + if sym.is_null() { + println!("Failed to load symbol: {:?}", sym); + continue; } + println!("Loaded symbol: {:?}", sym); + // let sym = sym as *const VariablySizedColumn; + // cols.insert(sym.absolute_name.clone(), (poly_id, (*sym).clone())); } } + todo!() +} - result +struct Compiler<'a, T> { + analyzed: &'a Analyzed, + queue: Vec, + requested: HashSet, + failed: HashMap, + symbols: HashMap, } -fn format_expr(e: &Expression) -> String { - match e { - Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), - Expression::Reference( - _, - Reference::Poly(PolynomialReference { - name, - poly_id: _, - type_args: _, - }), - ) => escape(name), // TOOD use type args if needed. - Expression::Number( - _, - Number { - value, - type_: Some(type_), - }, - ) => match type_ { - Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), - Type::Fe => format!("FieldElement::from({value}_u64)"), - Type::Expr => format!("Expr::from({value}_u64)"), - Type::TypeVar(t) => format!("{t}::from({value}_u64)"), - _ => unreachable!(), - }, - Expression::FunctionCall( - _, - FunctionCall { - function, - arguments, - }, - ) => { - format!( - "({})({})", - format_expr(function), - arguments - .iter() - .map(format_expr) - .map(|x| format!("{x}.clone()")) - .collect::>() - .join(", ") - ) +impl<'a, T> Compiler<'a, T> { + pub fn new(analyzed: &'a Analyzed) -> Self { + Self { + analyzed, + queue: Default::default(), + requested: Default::default(), + failed: Default::default(), + symbols: Default::default(), } - Expression::BinaryOperation(_, BinaryOperation { left, op, right }) => { - let left = format_expr(left); - let right = format_expr(right); - match op { - BinaryOperator::ShiftLeft => { - format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") - } - _ => format!("(({left}).clone() {op} ({right}).clone())"), - } + } + + pub fn request_symbol(&mut self, name: &str) -> Result<(), String> { + if let Some(err) = self.failed.get(name) { + return Err(err.clone()); } - Expression::UnaryOperation(_, UnaryOperation { op, expr }) => { - format!("({op} ({}).clone())", format_expr(expr)) + if self.requested.contains(name) { + return Ok(()); } - Expression::IndexAccess(_, IndexAccess { array, index }) => { - format!( - "{}[usize::try_from({}).unwrap()].clone()", - format_expr(array), - format_expr(index) - ) + self.requested.insert(name.to_string()); + match self.generate_code(name) { + Ok(code) => { + self.symbols.insert(name.to_string(), code); + Ok(()) + } + Err(err) => { + let err = format!("Failed to compile {name}: {err}"); + self.failed.insert(name.to_string(), err.clone()); + Err(err) + } } - Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { - // let params = if *params == vec!["r".to_string()] { - // // Hack because rust needs the type - // vec!["r: Vec".to_string()] - // } else { - // params.clone() - // }; - format!( - "|{}| {{ {} }}", - params.iter().format(", "), - format_expr(body) - ) + } + + pub fn compiled_symbols(self) -> String { + self.symbols + .into_iter() + .map(|(name, code)| code) + .format("\n\n") + .to_string() + } + + fn generate_code(&mut self, symbol: &str) -> Result { + if symbol == "std::check::panic" { + // TODO should this really panic? + return Ok("fn std_check_panic(s: &str) -> ! { panic!(\"{s}\"); }".to_string()); + } else if symbol == "std::field::modulus" { + // TODO depends on T + return Ok("fn std_field_modulus() -> num_bigint::BigInt { num_bigint::BigInt::from(18446744069414584321_u64) }" + .to_string()); + } else if symbol == "std::convert::fe" { + return Ok("fn std_convert_fe(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" + .to_string()); } - Expression::IfExpression( - _, - IfExpression { - condition, - body, - else_body, + + let Some((sym, Some(FunctionValueDefinition::Expression(value)))) = + self.analyzed.definitions.get(symbol) + else { + return Err(format!( + "No definition for {symbol}, or not a generic symbol" + )); + }; + println!("Processing {symbol} = {}", value.e); + Ok(match &value.type_scheme.as_ref().unwrap() { + TypeScheme { + vars, + ty: + Type::Function(FunctionType { + params: param_types, + value: return_type, + }), + } => { + let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = + &value.e + else { + return Err(format!( + "Expected lambda expression for {symbol}, got {}", + value.e + )); + }; + format!( + "fn {}<{}>({}) -> {} {{ {} }}\n", + escape(symbol), + vars, + params + .iter() + .zip(param_types) + .map(|(p, t)| format!("{}: {}", p, map_type(t))) + .format(", "), + map_type(return_type), + self.format_expr(body)? + ) + } + _ => format!( + "const {}: {} = {};\n", + escape(symbol), + map_type(&value.type_scheme.as_ref().unwrap().ty), + self.format_expr(&value.e)? + ), + }) + } + + fn format_expr(&mut self, e: &Expression) -> Result { + Ok(match e { + Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), + Expression::Reference( + _, + Reference::Poly(PolynomialReference { + name, + poly_id: _, + type_args, + }), + ) => { + self.request_symbol(name)?; + format!( + "{}{}", + escape(name), + // TODO do all type args work here? + type_args + .as_ref() + .map(|ta| format!("::{}", format_type_args(&ta))) + .unwrap_or_default() + ) + } + Expression::Number( + _, + Number { + value, + type_: Some(type_), + }, + ) => match type_ { + Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), + Type::Fe => format!("FieldElement::from({value}_u64)"), + Type::Expr => format!("Expr::from({value}_u64)"), + Type::TypeVar(t) => format!("{t}::from({value}_u64)"), + _ => unreachable!(), }, - ) => { - format!( - "if {} {{ {} }} else {{ {} }}", - format_expr(condition), - format_expr(body), - format_expr(else_body) - ) - } - Expression::ArrayLiteral(_, ArrayLiteral { items }) => { - format!( - "vec![{}]", - items.iter().map(format_expr).collect::>().join(", ") - ) - } - Expression::String(_, s) => format!("{s:?}"), // TODO does this quote properly? - Expression::Tuple(_, items) => format!( - "({})", - items.iter().map(format_expr).collect::>().join(", ") - ), - _ => panic!("Implement {e}"), + Expression::FunctionCall( + _, + FunctionCall { + function, + arguments, + }, + ) => { + format!( + "({})({})", + self.format_expr(function)?, + arguments + .iter() + .map(|a| self.format_expr(a)) + .collect::, _>>()? + .into_iter() + // TODO these should all be refs + .map(|x| format!("{x}.clone()")) + .collect::>() + .join(", ") + ) + } + Expression::BinaryOperation(_, BinaryOperation { left, op, right }) => { + let left = self.format_expr(left)?; + let right = self.format_expr(right)?; + match op { + BinaryOperator::ShiftLeft => { + format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") + } + _ => format!("(({left}).clone() {op} ({right}).clone())"), + } + } + Expression::UnaryOperation(_, UnaryOperation { op, expr }) => { + format!("({op} ({}).clone())", self.format_expr(expr)?) + } + Expression::IndexAccess(_, IndexAccess { array, index }) => { + format!( + "{}[usize::try_from({}).unwrap()].clone()", + self.format_expr(array)?, + self.format_expr(index)? + ) + } + Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { + // let params = if *params == vec!["r".to_string()] { + // // Hack because rust needs the type + // vec!["r: Vec".to_string()] + // } else { + // params.clone() + // }; + format!( + "|{}| {{ {} }}", + params.iter().format(", "), + self.format_expr(body)? + ) + } + Expression::IfExpression( + _, + IfExpression { + condition, + body, + else_body, + }, + ) => { + format!( + "if {} {{ {} }} else {{ {} }}", + self.format_expr(condition)?, + self.format_expr(body)?, + self.format_expr(else_body)? + ) + } + Expression::ArrayLiteral(_, ArrayLiteral { items }) => { + format!( + "vec![{}]", + items + .iter() + .map(|i| self.format_expr(i)) + .collect::, _>>()? + .join(", ") + ) + } + Expression::String(_, s) => quote(s), + Expression::Tuple(_, items) => format!( + "({})", + items + .iter() + .map(|i| self.format_expr(i)) + .collect::, _>>()? + .join(", ") + ), + _ => return Err(format!("Implement {e}")), + }) } } From b2f8d8845bc42e9ef7d5ba6b31744e0d94dfd75c Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 3 Aug 2024 13:01:51 +0000 Subject: [PATCH 03/62] oeu --- executor/src/constant_evaluator/compiler.rs | 68 +++++++++++++++++---- 1 file changed, 55 insertions(+), 13 deletions(-) diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs index f5ee7ca46e..19369983bc 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/executor/src/constant_evaluator/compiler.rs @@ -2,7 +2,7 @@ use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; use std::{ collections::{HashMap, HashSet}, ffi::CString, - fs::File, + fs::{self, File}, io::Write, process::Command, sync::Arc, @@ -17,8 +17,8 @@ use powdr_ast::{ parsed::{ display::{format_type_args, quote}, types::{ArrayType, FunctionType, Type, TypeScheme}, - ArrayLiteral, BinaryOperation, BinaryOperator, FunctionCall, IfExpression, IndexAccess, - LambdaExpression, Number, UnaryOperation, + ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, + IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, }, }; use powdr_number::FieldElement; @@ -44,8 +44,13 @@ name = "powdr_constants" version = "0.1.0" edition = "2021" +[lib] +crate-type = ["dylib"] + [dependencies] ark-ff = "0.4.2" +num-bigint = { version = "0.4.3", features = ["serde"] } +num-traits = "0.2.15" "#; // TODO crate type dylib? @@ -55,19 +60,28 @@ pub fn generate_fixed_cols( ) -> HashMap)> { let mut compiler = Compiler::new(analyzed); for (sym, _) in &analyzed.constant_polys_in_source_order() { - compiler.request_symbol(&sym.absolute_name); + // ignore err + if let Err(e) = compiler.request_symbol(&sym.absolute_name) { + println!("Failed to compile {}: {e}", &sym.absolute_name); + } } let code = format!("{PREAMBLE}\n{}\n", compiler.compiled_symbols()); + println!("Compiled code:\n{code}"); let dir = mktemp::Temp::new_dir().unwrap(); - std::fs::create_dir(dir.as_path().join("src")).unwrap(); - std::fs::write(dir.as_path().join("src").join("lib.rs"), code).unwrap(); - Command::new("cargo") + fs::write(dir.as_path().join("Cargo.toml"), CARGO_TOML).unwrap(); + fs::create_dir(dir.as_path().join("src")).unwrap(); + fs::write(dir.as_path().join("src").join("lib.rs"), code).unwrap(); + let out = Command::new("cargo") .arg("build") .arg("--release") .current_dir(dir.as_path()) .output() .unwrap(); + out.stderr.iter().for_each(|b| print!("{}", *b as char)); + if !out.status.success() { + panic!("Failed to compile."); + } unsafe { let lib_path = CString::new( @@ -96,12 +110,11 @@ pub fn generate_fixed_cols( // cols.insert(sym.absolute_name.clone(), (poly_id, (*sym).clone())); } } - todo!() + Default::default() } struct Compiler<'a, T> { analyzed: &'a Analyzed, - queue: Vec, requested: HashSet, failed: HashMap, symbols: HashMap, @@ -111,7 +124,6 @@ impl<'a, T> Compiler<'a, T> { pub fn new(analyzed: &'a Analyzed) -> Self { Self { analyzed, - queue: Default::default(), requested: Default::default(), failed: Default::default(), symbols: Default::default(), @@ -129,6 +141,7 @@ impl<'a, T> Compiler<'a, T> { match self.generate_code(name) { Ok(code) => { self.symbols.insert(name.to_string(), code); + println!("Generated code for {name}"); Ok(()) } Err(err) => { @@ -168,7 +181,18 @@ impl<'a, T> Compiler<'a, T> { )); }; println!("Processing {symbol} = {}", value.e); - Ok(match &value.type_scheme.as_ref().unwrap() { + let type_scheme = if sym.kind == SymbolKind::Poly(PolynomialType::Constant) { + TypeScheme { + vars: Default::default(), + ty: Type::Function(FunctionType { + params: vec![Type::Int], + value: Box::new(Type::Fe), + }), + } + } else { + value.type_scheme.clone().unwrap() + }; + Ok(match type_scheme { TypeScheme { vars, ty: @@ -192,9 +216,9 @@ impl<'a, T> Compiler<'a, T> { params .iter() .zip(param_types) - .map(|(p, t)| format!("{}: {}", p, map_type(t))) + .map(|(p, t)| format!("{}: {}", p, map_type(&t))) .format(", "), - map_type(return_type), + map_type(return_type.as_ref()), self.format_expr(body)? ) } @@ -330,9 +354,27 @@ impl<'a, T> Compiler<'a, T> { .collect::, _>>()? .join(", ") ), + Expression::BlockExpression(_, BlockExpression { statements, expr }) => { + format!( + "{{\n{}\n{}\n}}", + statements + .iter() + .map(|s| self.format_statement(s)) + .collect::, _>>()? + .join("\n"), + expr.as_ref() + .map(|e| self.format_expr(e.as_ref())) + .transpose()? + .unwrap_or_default() + ) + } _ => return Err(format!("Implement {e}")), }) } + + fn format_statement(&mut self, s: &StatementInsideBlock) -> Result { + Err(format!("Implement {s}")) + } } fn escape(s: &str) -> String { From 392afc1e496b3afc287d0e590e2ef09f6837b219 Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 3 Aug 2024 13:02:02 +0000 Subject: [PATCH 04/62] oeu --- executor/src/constant_evaluator/compiler.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs index 19369983bc..3b452fadbb 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/executor/src/constant_evaluator/compiler.rs @@ -110,6 +110,7 @@ pub fn generate_fixed_cols( // cols.insert(sym.absolute_name.clone(), (poly_id, (*sym).clone())); } } + panic!(); Default::default() } From 6b2029a9eca606e6b6c0bbca0c25d050fa36074f Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 3 Aug 2024 13:25:29 +0000 Subject: [PATCH 05/62] loaded sym --- executor/src/constant_evaluator/compiler.rs | 50 ++++++++++++++++----- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs index 3b452fadbb..87ec77c2d5 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/executor/src/constant_evaluator/compiler.rs @@ -2,8 +2,9 @@ use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; use std::{ collections::{HashMap, HashSet}, ffi::CString, - fs::{self, File}, + fs::{self, create_dir, File}, io::Write, + path, process::Command, sync::Arc, }; @@ -32,8 +33,8 @@ use ark_ff::{BigInt, BigInteger, Fp64, MontBackend, MontConfig, PrimeField}; #[derive(MontConfig)] #[modulus = "18446744069414584321"] #[generator = "7"] -pub struct GoldilocksBaseFieldConfig; -pub type FieldElement = Fp64>; +struct GoldilocksBaseFieldConfig; +type FieldElement = Fp64>; "#; // TODO this is the old impl of goldilocks @@ -59,23 +60,45 @@ pub fn generate_fixed_cols( analyzed: &Analyzed, ) -> HashMap)> { let mut compiler = Compiler::new(analyzed); + let mut glue = String::new(); for (sym, _) in &analyzed.constant_polys_in_source_order() { // ignore err if let Err(e) = compiler.request_symbol(&sym.absolute_name) { println!("Failed to compile {}: {e}", &sym.absolute_name); } } - let code = format!("{PREAMBLE}\n{}\n", compiler.compiled_symbols()); + for (sym, _) in &analyzed.constant_polys_in_source_order() { + // TODO escape? + if compiler.is_compiled(&sym.absolute_name) { + // TODO it is a rust function, can we use a more complex type as well? + // TODO only works for goldilocks + glue.push_str(&format!( + r#" + #[no_mangle] + pub extern fn extern_{}(i: u64) -> u64 {{ + {}(num_bigint::BigInt::from(i)).into_bigint().0[0] + }} + "#, + escape(&sym.absolute_name), + escape(&sym.absolute_name), + )); + } + } + + let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); println!("Compiled code:\n{code}"); - let dir = mktemp::Temp::new_dir().unwrap(); - fs::write(dir.as_path().join("Cargo.toml"), CARGO_TOML).unwrap(); - fs::create_dir(dir.as_path().join("src")).unwrap(); - fs::write(dir.as_path().join("src").join("lib.rs"), code).unwrap(); + //let dir = mktemp::Temp::new_dir().unwrap(); + let _ = fs::remove_dir_all("/tmp/powdr_constants"); + fs::create_dir("/tmp/powdr_constants").unwrap(); + let dir = path::Path::new("/tmp/powdr_constants"); + fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); + fs::create_dir(dir.join("src")).unwrap(); + fs::write(dir.join("src").join("lib.rs"), code).unwrap(); let out = Command::new("cargo") .arg("build") .arg("--release") - .current_dir(dir.as_path()) + .current_dir(dir) .output() .unwrap(); out.stderr.iter().for_each(|b| print!("{}", *b as char)); @@ -85,8 +108,7 @@ pub fn generate_fixed_cols( unsafe { let lib_path = CString::new( - dir.as_path() - .join("target") + dir.join("target") .join("release") .join("libpowdr_constants.so") .to_str() @@ -98,7 +120,7 @@ pub fn generate_fixed_cols( panic!("Failed to load library: {:?}", lib_path); } for (sym, poly_id) in analyzed.constant_polys_in_source_order() { - let sym = escape(&sym.absolute_name); + let sym = format!("extern_{}", escape(&sym.absolute_name)); let sym = CString::new(sym).unwrap(); let sym = dlsym(lib, sym.as_ptr()); if sym.is_null() { @@ -153,6 +175,10 @@ impl<'a, T> Compiler<'a, T> { } } + pub fn is_compiled(&self, name: &str) -> bool { + self.symbols.contains_key(name) + } + pub fn compiled_symbols(self) -> String { self.symbols .into_iter() From 7ee9afc1824d604c560dd507996b73db0265a94f Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 3 Aug 2024 13:34:35 +0000 Subject: [PATCH 06/62] log time --- executor/src/constant_evaluator/compiler.rs | 39 +++++++++++++++++---- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs index 87ec77c2d5..d50977922c 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/executor/src/constant_evaluator/compiler.rs @@ -1,4 +1,5 @@ use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use std::{ collections::{HashMap, HashSet}, ffi::CString, @@ -7,6 +8,7 @@ use std::{ path, process::Command, sync::Arc, + time::Instant, }; use itertools::Itertools; @@ -24,6 +26,8 @@ use powdr_ast::{ }; use powdr_number::FieldElement; +use crate::constant_evaluator::{MAX_DEGREE_LOG, MIN_DEGREE_LOG}; + use super::VariablySizedColumn; const PREAMBLE: &str = r#" @@ -106,6 +110,7 @@ pub fn generate_fixed_cols( panic!("Failed to compile."); } + let mut columns = HashMap::new(); unsafe { let lib_path = CString::new( dir.join("target") @@ -119,8 +124,9 @@ pub fn generate_fixed_cols( if lib.is_null() { panic!("Failed to load library: {:?}", lib_path); } - for (sym, poly_id) in analyzed.constant_polys_in_source_order() { - let sym = format!("extern_{}", escape(&sym.absolute_name)); + let start = Instant::now(); + for (poly, value) in analyzed.constant_polys_in_source_order() { + let sym = format!("extern_{}", escape(&poly.absolute_name)); let sym = CString::new(sym).unwrap(); let sym = dlsym(lib, sym.as_ptr()); if sym.is_null() { @@ -128,12 +134,33 @@ pub fn generate_fixed_cols( continue; } println!("Loaded symbol: {:?}", sym); - // let sym = sym as *const VariablySizedColumn; - // cols.insert(sym.absolute_name.clone(), (poly_id, (*sym).clone())); + let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); + let degrees = if let Some(degree) = poly.degree { + vec![degree] + } else { + (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) + .map(|degree_log| 1 << degree_log) + .collect::>() + }; + + let col_values = degrees + .into_iter() + .map(|degree| { + (0..degree) + .into_par_iter() + .map(|i| T::from(fun(i as u64))) + .collect::>() + }) + .collect::>() + .into(); + columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); } + log::info!( + "Fixed column generation (without compilation and loading time) took {}s", + start.elapsed().as_secs_f32() + ); } - panic!(); - Default::default() + columns } struct Compiler<'a, T> { From 44e80b8c6d6272c586c1de5971ac31661e736631 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 08:48:46 +0000 Subject: [PATCH 07/62] new crate. --- Cargo.toml | 2 + executor/Cargo.toml | 3 +- executor/src/constant_evaluator/mod.rs | 9 +- jit-compiler/Cargo.toml | 20 ++ .../src}/compiler.rs | 238 ++++++++---------- jit-compiler/src/lib.rs | 2 + jit-compiler/src/loader.rs | 130 ++++++++++ 7 files changed, 267 insertions(+), 137 deletions(-) create mode 100644 jit-compiler/Cargo.toml rename {executor/src/constant_evaluator => jit-compiler/src}/compiler.rs (69%) create mode 100644 jit-compiler/src/lib.rs create mode 100644 jit-compiler/src/loader.rs diff --git a/Cargo.toml b/Cargo.toml index 93a492e8ff..d76db81f3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "cli", "cli-rs", "executor", + "jit-compiler", "riscv", "parser-util", "pil-analyzer", @@ -48,6 +49,7 @@ powdr-analysis = { path = "./analysis", version = "0.1.0-alpha.2" } powdr-backend = { path = "./backend", version = "0.1.0-alpha.2" } powdr-executor = { path = "./executor", version = "0.1.0-alpha.2" } powdr-importer = { path = "./importer", version = "0.1.0-alpha.2" } +powdr-jit-compiler = { path = "./jit-compiler", version = "0.1.0-alpha.2" } powdr-linker = { path = "./linker", version = "0.1.0-alpha.2" } powdr-number = { path = "./number", version = "0.1.0-alpha.2" } powdr-parser = { path = "./parser", version = "0.1.0-alpha.2" } diff --git a/executor/Cargo.toml b/executor/Cargo.toml index 3303097918..c0a89ef694 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -12,11 +12,10 @@ powdr-ast.workspace = true powdr-number.workspace = true powdr-parser-util.workspace = true powdr-pil-analyzer.workspace = true +powdr-jit-compiler.workspace = true itertools = "0.13" -libc = "0.2.0" log = { version = "0.4.17" } -mktemp = "0.5.0" rayon = "1.7.0" bit-vec = "0.6.3" num-traits = "0.2.15" diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index f1bb9a2c93..9d79f1d504 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -16,9 +16,6 @@ use powdr_number::{BigInt, BigUint, DegreeType, FieldElement}; use powdr_pil_analyzer::evaluator::{self, Definitions, SymbolLookup, Value}; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; -// TODO this is probabyl not the right place. -mod compiler; - mod data_structures; /// Generates the fixed column values for all fixed columns that are defined @@ -30,16 +27,12 @@ pub fn generate(analyzed: &Analyzed) -> Vec<(String, Variabl // TODO to do this properly, we should try to compile as much as possible // and only evaulato if it fails. Still, compilation should be done in one run. - let mut fixed_cols: HashMap)> = - compiler::generate_fixed_cols(analyzed); + let mut fixed_cols = HashMap::new(); for (poly, value) in analyzed.constant_polys_in_source_order() { if let Some(value) = value { // For arrays, generate values for each index, // for non-arrays, set index to None. for (index, (name, id)) in poly.array_elements().enumerate() { - if fixed_cols.contains_key(&name) { - continue; - } let index = poly.is_array().then_some(index as u64); let range = poly.degree.unwrap(); let values = range diff --git a/jit-compiler/Cargo.toml b/jit-compiler/Cargo.toml new file mode 100644 index 0000000000..8650abad45 --- /dev/null +++ b/jit-compiler/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "powdr-jit-compiler" +description = "powdr just-in-time compiler" +version = { workspace = true } +edition = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[dependencies] +powdr-ast.workspace = true +powdr-number.workspace = true +powdr-parser.workspace = true + +libc = "0.2.0" +mktemp = "0.5.0" +itertools = "0.13" + +[lints.clippy] +uninlined_format_args = "deny" diff --git a/executor/src/constant_evaluator/compiler.rs b/jit-compiler/src/compiler.rs similarity index 69% rename from executor/src/constant_evaluator/compiler.rs rename to jit-compiler/src/compiler.rs index d50977922c..7f6c51b79c 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1,5 +1,4 @@ use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; -use rayon::iter::{IntoParallelIterator, ParallelIterator}; use std::{ collections::{HashMap, HashSet}, ffi::CString, @@ -24,28 +23,18 @@ use powdr_ast::{ IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, }, }; -use powdr_number::FieldElement; - -use crate::constant_evaluator::{MAX_DEGREE_LOG, MIN_DEGREE_LOG}; - -use super::VariablySizedColumn; +use powdr_number::{FieldElement, LargeInt}; const PREAMBLE: &str = r#" #![allow(unused_parens)] -use ark_ff::{BigInt, BigInteger, Fp64, MontBackend, MontConfig, PrimeField}; -#[derive(MontConfig)] -#[modulus = "18446744069414584321"] -#[generator = "7"] -struct GoldilocksBaseFieldConfig; -type FieldElement = Fp64>; "#; // TODO this is the old impl of goldilocks const CARGO_TOML: &str = r#" [package] -name = "powdr_constants" +name = "powdr_jit_compiled" version = "0.1.0" edition = "2021" @@ -53,115 +42,116 @@ edition = "2021" crate-type = ["dylib"] [dependencies] -ark-ff = "0.4.2" -num-bigint = { version = "0.4.3", features = ["serde"] } +// TODO version? +powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } +num-bigint = { version = "0.4.3" } num-traits = "0.2.15" "#; // TODO crate type dylib? -pub fn generate_fixed_cols( - analyzed: &Analyzed, -) -> HashMap)> { - let mut compiler = Compiler::new(analyzed); - let mut glue = String::new(); - for (sym, _) in &analyzed.constant_polys_in_source_order() { - // ignore err - if let Err(e) = compiler.request_symbol(&sym.absolute_name) { - println!("Failed to compile {}: {e}", &sym.absolute_name); - } - } - for (sym, _) in &analyzed.constant_polys_in_source_order() { - // TODO escape? - if compiler.is_compiled(&sym.absolute_name) { - // TODO it is a rust function, can we use a more complex type as well? - // TODO only works for goldilocks - glue.push_str(&format!( - r#" - #[no_mangle] - pub extern fn extern_{}(i: u64) -> u64 {{ - {}(num_bigint::BigInt::from(i)).into_bigint().0[0] - }} - "#, - escape(&sym.absolute_name), - escape(&sym.absolute_name), - )); - } - } +// pub fn generate_fixed_cols( +// analyzed: &Analyzed, +// ) -> HashMap)> { +// let mut compiler = Compiler::new(analyzed); +// let mut glue = String::new(); +// for (sym, _) in &analyzed.constant_polys_in_source_order() { +// // ignore err +// if let Err(e) = compiler.request_symbol(&sym.absolute_name) { +// println!("Failed to compile {}: {e}", &sym.absolute_name); +// } +// } +// for (sym, _) in &analyzed.constant_polys_in_source_order() { +// // TODO escape? +// if compiler.is_compiled(&sym.absolute_name) { +// // TODO it is a rust function, can we use a more complex type as well? +// // TODO only works for goldilocks +// glue.push_str(&format!( +// r#" +// #[no_mangle] +// pub extern fn extern_{}(i: u64) -> u64 {{ +// {}(num_bigint::BigInt::from(i)).into_bigint().0[0] +// }} +// "#, +// escape(&sym.absolute_name), +// escape(&sym.absolute_name), +// )); +// } +// } - let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); - println!("Compiled code:\n{code}"); +// let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); +// println!("Compiled code:\n{code}"); - //let dir = mktemp::Temp::new_dir().unwrap(); - let _ = fs::remove_dir_all("/tmp/powdr_constants"); - fs::create_dir("/tmp/powdr_constants").unwrap(); - let dir = path::Path::new("/tmp/powdr_constants"); - fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); - fs::create_dir(dir.join("src")).unwrap(); - fs::write(dir.join("src").join("lib.rs"), code).unwrap(); - let out = Command::new("cargo") - .arg("build") - .arg("--release") - .current_dir(dir) - .output() - .unwrap(); - out.stderr.iter().for_each(|b| print!("{}", *b as char)); - if !out.status.success() { - panic!("Failed to compile."); - } +// //let dir = mktemp::Temp::new_dir().unwrap(); +// let _ = fs::remove_dir_all("/tmp/powdr_constants"); +// fs::create_dir("/tmp/powdr_constants").unwrap(); +// let dir = path::Path::new("/tmp/powdr_constants"); +// fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); +// fs::create_dir(dir.join("src")).unwrap(); +// fs::write(dir.join("src").join("lib.rs"), code).unwrap(); +// let out = Command::new("cargo") +// .arg("build") +// .arg("--release") +// .current_dir(dir) +// .output() +// .unwrap(); +// out.stderr.iter().for_each(|b| print!("{}", *b as char)); +// if !out.status.success() { +// panic!("Failed to compile."); +// } - let mut columns = HashMap::new(); - unsafe { - let lib_path = CString::new( - dir.join("target") - .join("release") - .join("libpowdr_constants.so") - .to_str() - .unwrap(), - ) - .unwrap(); - let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); - if lib.is_null() { - panic!("Failed to load library: {:?}", lib_path); - } - let start = Instant::now(); - for (poly, value) in analyzed.constant_polys_in_source_order() { - let sym = format!("extern_{}", escape(&poly.absolute_name)); - let sym = CString::new(sym).unwrap(); - let sym = dlsym(lib, sym.as_ptr()); - if sym.is_null() { - println!("Failed to load symbol: {:?}", sym); - continue; - } - println!("Loaded symbol: {:?}", sym); - let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); - let degrees = if let Some(degree) = poly.degree { - vec![degree] - } else { - (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) - .map(|degree_log| 1 << degree_log) - .collect::>() - }; +// let mut columns = HashMap::new(); +// unsafe { +// let lib_path = CString::new( +// dir.join("target") +// .join("release") +// .join("libpowdr_constants.so") +// .to_str() +// .unwrap(), +// ) +// .unwrap(); +// let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); +// if lib.is_null() { +// panic!("Failed to load library: {:?}", lib_path); +// } +// let start = Instant::now(); +// for (poly, value) in analyzed.constant_polys_in_source_order() { +// let sym = format!("extern_{}", escape(&poly.absolute_name)); +// let sym = CString::new(sym).unwrap(); +// let sym = dlsym(lib, sym.as_ptr()); +// if sym.is_null() { +// println!("Failed to load symbol: {:?}", sym); +// continue; +// } +// println!("Loaded symbol: {:?}", sym); +// let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); +// let degrees = if let Some(degree) = poly.degree { +// vec![degree] +// } else { +// (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) +// .map(|degree_log| 1 << degree_log) +// .collect::>() +// }; - let col_values = degrees - .into_iter() - .map(|degree| { - (0..degree) - .into_par_iter() - .map(|i| T::from(fun(i as u64))) - .collect::>() - }) - .collect::>() - .into(); - columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); - } - log::info!( - "Fixed column generation (without compilation and loading time) took {}s", - start.elapsed().as_secs_f32() - ); - } - columns -} +// let col_values = degrees +// .into_iter() +// .map(|degree| { +// (0..degree) +// .into_par_iter() +// .map(|i| T::from(fun(i as u64))) +// .collect::>() +// }) +// .collect::>() +// .into(); +// columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); +// } +// log::info!( +// "Fixed column generation (without compilation and loading time) took {}s", +// start.elapsed().as_secs_f32() +// ); +// } +// columns +// } struct Compiler<'a, T> { analyzed: &'a Analyzed, @@ -170,7 +160,7 @@ struct Compiler<'a, T> { symbols: HashMap, } -impl<'a, T> Compiler<'a, T> { +impl<'a, T: FieldElement> Compiler<'a, T> { pub fn new(analyzed: &'a Analyzed) -> Self { Self { analyzed, @@ -219,9 +209,8 @@ impl<'a, T> Compiler<'a, T> { // TODO should this really panic? return Ok("fn std_check_panic(s: &str) -> ! { panic!(\"{s}\"); }".to_string()); } else if symbol == "std::field::modulus" { - // TODO depends on T - return Ok("fn std_field_modulus() -> num_bigint::BigInt { num_bigint::BigInt::from(18446744069414584321_u64) }" - .to_string()); + let modulus = T::modulus(); + return Ok(format!("fn std_field_modulus() -> num_bigint::BigInt {{ num_bigint::BigInt::from(\"{modulus}\") }}")); } else if symbol == "std::convert::fe" { return Ok("fn std_convert_fe(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" .to_string()); @@ -235,6 +224,7 @@ impl<'a, T> Compiler<'a, T> { )); }; println!("Processing {symbol} = {}", value.e); + // TODO assert type scheme is there? let type_scheme = if sym.kind == SymbolKind::Poly(PolynomialType::Constant) { TypeScheme { vars: Default::default(), @@ -288,14 +278,7 @@ impl<'a, T> Compiler<'a, T> { fn format_expr(&mut self, e: &Expression) -> Result { Ok(match e { Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), - Expression::Reference( - _, - Reference::Poly(PolynomialReference { - name, - poly_id: _, - type_args, - }), - ) => { + Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => { self.request_symbol(name)?; format!( "{}{}", @@ -314,6 +297,7 @@ impl<'a, T> Compiler<'a, T> { type_: Some(type_), }, ) => match type_ { + // TODO value does not need to be u64 Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), Type::Fe => format!("FieldElement::from({value}_u64)"), Type::Expr => format!("Expr::from({value}_u64)"), @@ -335,7 +319,7 @@ impl<'a, T> Compiler<'a, T> { .map(|a| self.format_expr(a)) .collect::, _>>()? .into_iter() - // TODO these should all be refs + // TODO these should all be refs -> turn all types to arc .map(|x| format!("{x}.clone()")) .collect::>() .join(", ") @@ -441,12 +425,12 @@ fn map_type(ty: &Type) -> String { Type::Int => "num_bigint::BigInt".to_string(), Type::Fe => "FieldElement".to_string(), Type::String => "String".to_string(), - Type::Col => unreachable!(), Type::Expr => "Expr".to_string(), Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), Type::Tuple(_) => todo!(), Type::Function(ft) => todo!("Type {ft}"), Type::TypeVar(tv) => tv.to_string(), Type::NamedType(_path, _type_args) => todo!(), + Type::Col | Type::Inter => unreachable!(), } } diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs new file mode 100644 index 0000000000..7dd2003c50 --- /dev/null +++ b/jit-compiler/src/lib.rs @@ -0,0 +1,2 @@ +mod compiler; +mod loader; diff --git a/jit-compiler/src/loader.rs b/jit-compiler/src/loader.rs new file mode 100644 index 0000000000..690bc5467d --- /dev/null +++ b/jit-compiler/src/loader.rs @@ -0,0 +1,130 @@ +// use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; +// use rayon::iter::{IntoParallelIterator, ParallelIterator}; +// use std::{ +// collections::{HashMap, HashSet}, +// ffi::CString, +// fs::{self, create_dir, File}, +// io::Write, +// path, +// process::Command, +// sync::Arc, +// time::Instant, +// }; + +// use itertools::Itertools; +// use powdr_ast::{ +// analyzed::{ +// Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, +// Reference, SymbolKind, +// }, +// parsed::{ +// display::{format_type_args, quote}, +// types::{ArrayType, FunctionType, Type, TypeScheme}, +// ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, +// IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, +// }, +// }; +// use powdr_number::FieldElement; + +// // pub fn generate_fixed_cols( +// // analyzed: &Analyzed, +// // ) -> HashMap)> { +// // let mut compiler = Compiler::new(analyzed); +// // let mut glue = String::new(); +// // for (sym, _) in &analyzed.constant_polys_in_source_order() { +// // // ignore err +// // if let Err(e) = compiler.request_symbol(&sym.absolute_name) { +// // println!("Failed to compile {}: {e}", &sym.absolute_name); +// // } +// // } +// // for (sym, _) in &analyzed.constant_polys_in_source_order() { +// // // TODO escape? +// // if compiler.is_compiled(&sym.absolute_name) { +// // // TODO it is a rust function, can we use a more complex type as well? +// // // TODO only works for goldilocks +// // glue.push_str(&format!( +// // r#" +// #[no_mangle] +// pub extern fn extern_{}(i: u64) -> u64 {{ +// {}(num_bigint::BigInt::from(i)).into_bigint().0[0] +// }} +// "#, +// escape(&sym.absolute_name), +// escape(&sym.absolute_name), +// )); +// } +// } + +// let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); +// println!("Compiled code:\n{code}"); + +// //let dir = mktemp::Temp::new_dir().unwrap(); +// let _ = fs::remove_dir_all("/tmp/powdr_constants"); +// fs::create_dir("/tmp/powdr_constants").unwrap(); +// let dir = path::Path::new("/tmp/powdr_constants"); +// fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); +// fs::create_dir(dir.join("src")).unwrap(); +// fs::write(dir.join("src").join("lib.rs"), code).unwrap(); +// let out = Command::new("cargo") +// .arg("build") +// .arg("--release") +// .current_dir(dir) +// .output() +// .unwrap(); +// out.stderr.iter().for_each(|b| print!("{}", *b as char)); +// if !out.status.success() { +// panic!("Failed to compile."); +// } + +// let mut columns = HashMap::new(); +// unsafe { +// let lib_path = CString::new( +// dir.join("target") +// .join("release") +// .join("libpowdr_constants.so") +// .to_str() +// .unwrap(), +// ) +// .unwrap(); +// let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); +// if lib.is_null() { +// panic!("Failed to load library: {:?}", lib_path); +// } +// let start = Instant::now(); +// for (poly, value) in analyzed.constant_polys_in_source_order() { +// let sym = format!("extern_{}", escape(&poly.absolute_name)); +// let sym = CString::new(sym).unwrap(); +// let sym = dlsym(lib, sym.as_ptr()); +// if sym.is_null() { +// println!("Failed to load symbol: {:?}", sym); +// continue; +// } +// println!("Loaded symbol: {:?}", sym); +// let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); +// let degrees = if let Some(degree) = poly.degree { +// vec![degree] +// } else { +// (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) +// .map(|degree_log| 1 << degree_log) +// .collect::>() +// }; + +// let col_values = degrees +// .into_iter() +// .map(|degree| { +// (0..degree) +// .into_par_iter() +// .map(|i| T::from(fun(i as u64))) +// .collect::>() +// }) +// .collect::>() +// .into(); +// columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); +// } +// log::info!( +// "Fixed column generation (without compilation and loading time) took {}s", +// start.elapsed().as_secs_f32() +// ); +// } +// columns +// } From 65e5f44b3ad1ae917c5f985166a8482e347e2b0b Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 08:49:47 +0000 Subject: [PATCH 08/62] fix --- executor/src/constant_evaluator/mod.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index 9d79f1d504..40256195ac 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -6,7 +6,7 @@ use std::{ pub use data_structures::{get_uniquely_sized, get_uniquely_sized_cloned, VariablySizedColumn}; use itertools::Itertools; use powdr_ast::{ - analyzed::{Analyzed, Expression, FunctionValueDefinition, PolyID, Symbol, TypedExpression}, + analyzed::{Analyzed, Expression, FunctionValueDefinition, Symbol, TypedExpression}, parsed::{ types::{ArrayType, Type}, IndexAccess, @@ -24,9 +24,6 @@ mod data_structures; /// Arrays of columns are flattened, the name of the `i`th array element /// is `name[i]`. pub fn generate(analyzed: &Analyzed) -> Vec<(String, VariablySizedColumn)> { - // TODO to do this properly, we should try to compile as much as possible - // and only evaulato if it fails. Still, compilation should be done in one run. - let mut fixed_cols = HashMap::new(); for (poly, value) in analyzed.constant_polys_in_source_order() { if let Some(value) = value { From ec40700d61d8a2cb9b6169e79319e6b4faf60a83 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 09:35:54 +0000 Subject: [PATCH 09/62] work --- executor/Cargo.toml | 1 - jit-compiler/Cargo.toml | 5 + jit-compiler/src/compiler.rs | 437 +--------------------------------- jit-compiler/src/lib.rs | 5 +- jit-compiler/tests/codegen.rs | 38 +++ 5 files changed, 47 insertions(+), 439 deletions(-) create mode 100644 jit-compiler/tests/codegen.rs diff --git a/executor/Cargo.toml b/executor/Cargo.toml index c0a89ef694..0079cce9a9 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -12,7 +12,6 @@ powdr-ast.workspace = true powdr-number.workspace = true powdr-parser-util.workspace = true powdr-pil-analyzer.workspace = true -powdr-jit-compiler.workspace = true itertools = "0.13" log = { version = "0.4.17" } diff --git a/jit-compiler/Cargo.toml b/jit-compiler/Cargo.toml index 8650abad45..94da05758b 100644 --- a/jit-compiler/Cargo.toml +++ b/jit-compiler/Cargo.toml @@ -16,5 +16,10 @@ libc = "0.2.0" mktemp = "0.5.0" itertools = "0.13" +[dev-dependencies] +powdr-pil-analyzer.workspace = true +pretty_assertions = "1.4.0" + + [lints.clippy] uninlined_format_args = "deny" diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 7f6c51b79c..b9f5e2def9 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1,436 +1 @@ -use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; -use std::{ - collections::{HashMap, HashSet}, - ffi::CString, - fs::{self, create_dir, File}, - io::Write, - path, - process::Command, - sync::Arc, - time::Instant, -}; - -use itertools::Itertools; -use powdr_ast::{ - analyzed::{ - Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, - Reference, SymbolKind, - }, - parsed::{ - display::{format_type_args, quote}, - types::{ArrayType, FunctionType, Type, TypeScheme}, - ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, - IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, - }, -}; -use powdr_number::{FieldElement, LargeInt}; - -const PREAMBLE: &str = r#" -#![allow(unused_parens)] - -"#; - -// TODO this is the old impl of goldilocks - -const CARGO_TOML: &str = r#" -[package] -name = "powdr_jit_compiled" -version = "0.1.0" -edition = "2021" - -[lib] -crate-type = ["dylib"] - -[dependencies] -// TODO version? -powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } -num-bigint = { version = "0.4.3" } -num-traits = "0.2.15" -"#; - -// TODO crate type dylib? - -// pub fn generate_fixed_cols( -// analyzed: &Analyzed, -// ) -> HashMap)> { -// let mut compiler = Compiler::new(analyzed); -// let mut glue = String::new(); -// for (sym, _) in &analyzed.constant_polys_in_source_order() { -// // ignore err -// if let Err(e) = compiler.request_symbol(&sym.absolute_name) { -// println!("Failed to compile {}: {e}", &sym.absolute_name); -// } -// } -// for (sym, _) in &analyzed.constant_polys_in_source_order() { -// // TODO escape? -// if compiler.is_compiled(&sym.absolute_name) { -// // TODO it is a rust function, can we use a more complex type as well? -// // TODO only works for goldilocks -// glue.push_str(&format!( -// r#" -// #[no_mangle] -// pub extern fn extern_{}(i: u64) -> u64 {{ -// {}(num_bigint::BigInt::from(i)).into_bigint().0[0] -// }} -// "#, -// escape(&sym.absolute_name), -// escape(&sym.absolute_name), -// )); -// } -// } - -// let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); -// println!("Compiled code:\n{code}"); - -// //let dir = mktemp::Temp::new_dir().unwrap(); -// let _ = fs::remove_dir_all("/tmp/powdr_constants"); -// fs::create_dir("/tmp/powdr_constants").unwrap(); -// let dir = path::Path::new("/tmp/powdr_constants"); -// fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); -// fs::create_dir(dir.join("src")).unwrap(); -// fs::write(dir.join("src").join("lib.rs"), code).unwrap(); -// let out = Command::new("cargo") -// .arg("build") -// .arg("--release") -// .current_dir(dir) -// .output() -// .unwrap(); -// out.stderr.iter().for_each(|b| print!("{}", *b as char)); -// if !out.status.success() { -// panic!("Failed to compile."); -// } - -// let mut columns = HashMap::new(); -// unsafe { -// let lib_path = CString::new( -// dir.join("target") -// .join("release") -// .join("libpowdr_constants.so") -// .to_str() -// .unwrap(), -// ) -// .unwrap(); -// let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); -// if lib.is_null() { -// panic!("Failed to load library: {:?}", lib_path); -// } -// let start = Instant::now(); -// for (poly, value) in analyzed.constant_polys_in_source_order() { -// let sym = format!("extern_{}", escape(&poly.absolute_name)); -// let sym = CString::new(sym).unwrap(); -// let sym = dlsym(lib, sym.as_ptr()); -// if sym.is_null() { -// println!("Failed to load symbol: {:?}", sym); -// continue; -// } -// println!("Loaded symbol: {:?}", sym); -// let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); -// let degrees = if let Some(degree) = poly.degree { -// vec![degree] -// } else { -// (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) -// .map(|degree_log| 1 << degree_log) -// .collect::>() -// }; - -// let col_values = degrees -// .into_iter() -// .map(|degree| { -// (0..degree) -// .into_par_iter() -// .map(|i| T::from(fun(i as u64))) -// .collect::>() -// }) -// .collect::>() -// .into(); -// columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); -// } -// log::info!( -// "Fixed column generation (without compilation and loading time) took {}s", -// start.elapsed().as_secs_f32() -// ); -// } -// columns -// } - -struct Compiler<'a, T> { - analyzed: &'a Analyzed, - requested: HashSet, - failed: HashMap, - symbols: HashMap, -} - -impl<'a, T: FieldElement> Compiler<'a, T> { - pub fn new(analyzed: &'a Analyzed) -> Self { - Self { - analyzed, - requested: Default::default(), - failed: Default::default(), - symbols: Default::default(), - } - } - - pub fn request_symbol(&mut self, name: &str) -> Result<(), String> { - if let Some(err) = self.failed.get(name) { - return Err(err.clone()); - } - if self.requested.contains(name) { - return Ok(()); - } - self.requested.insert(name.to_string()); - match self.generate_code(name) { - Ok(code) => { - self.symbols.insert(name.to_string(), code); - println!("Generated code for {name}"); - Ok(()) - } - Err(err) => { - let err = format!("Failed to compile {name}: {err}"); - self.failed.insert(name.to_string(), err.clone()); - Err(err) - } - } - } - - pub fn is_compiled(&self, name: &str) -> bool { - self.symbols.contains_key(name) - } - - pub fn compiled_symbols(self) -> String { - self.symbols - .into_iter() - .map(|(name, code)| code) - .format("\n\n") - .to_string() - } - - fn generate_code(&mut self, symbol: &str) -> Result { - if symbol == "std::check::panic" { - // TODO should this really panic? - return Ok("fn std_check_panic(s: &str) -> ! { panic!(\"{s}\"); }".to_string()); - } else if symbol == "std::field::modulus" { - let modulus = T::modulus(); - return Ok(format!("fn std_field_modulus() -> num_bigint::BigInt {{ num_bigint::BigInt::from(\"{modulus}\") }}")); - } else if symbol == "std::convert::fe" { - return Ok("fn std_convert_fe(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" - .to_string()); - } - - let Some((sym, Some(FunctionValueDefinition::Expression(value)))) = - self.analyzed.definitions.get(symbol) - else { - return Err(format!( - "No definition for {symbol}, or not a generic symbol" - )); - }; - println!("Processing {symbol} = {}", value.e); - // TODO assert type scheme is there? - let type_scheme = if sym.kind == SymbolKind::Poly(PolynomialType::Constant) { - TypeScheme { - vars: Default::default(), - ty: Type::Function(FunctionType { - params: vec![Type::Int], - value: Box::new(Type::Fe), - }), - } - } else { - value.type_scheme.clone().unwrap() - }; - Ok(match type_scheme { - TypeScheme { - vars, - ty: - Type::Function(FunctionType { - params: param_types, - value: return_type, - }), - } => { - let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = - &value.e - else { - return Err(format!( - "Expected lambda expression for {symbol}, got {}", - value.e - )); - }; - format!( - "fn {}<{}>({}) -> {} {{ {} }}\n", - escape(symbol), - vars, - params - .iter() - .zip(param_types) - .map(|(p, t)| format!("{}: {}", p, map_type(&t))) - .format(", "), - map_type(return_type.as_ref()), - self.format_expr(body)? - ) - } - _ => format!( - "const {}: {} = {};\n", - escape(symbol), - map_type(&value.type_scheme.as_ref().unwrap().ty), - self.format_expr(&value.e)? - ), - }) - } - - fn format_expr(&mut self, e: &Expression) -> Result { - Ok(match e { - Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), - Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => { - self.request_symbol(name)?; - format!( - "{}{}", - escape(name), - // TODO do all type args work here? - type_args - .as_ref() - .map(|ta| format!("::{}", format_type_args(&ta))) - .unwrap_or_default() - ) - } - Expression::Number( - _, - Number { - value, - type_: Some(type_), - }, - ) => match type_ { - // TODO value does not need to be u64 - Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), - Type::Fe => format!("FieldElement::from({value}_u64)"), - Type::Expr => format!("Expr::from({value}_u64)"), - Type::TypeVar(t) => format!("{t}::from({value}_u64)"), - _ => unreachable!(), - }, - Expression::FunctionCall( - _, - FunctionCall { - function, - arguments, - }, - ) => { - format!( - "({})({})", - self.format_expr(function)?, - arguments - .iter() - .map(|a| self.format_expr(a)) - .collect::, _>>()? - .into_iter() - // TODO these should all be refs -> turn all types to arc - .map(|x| format!("{x}.clone()")) - .collect::>() - .join(", ") - ) - } - Expression::BinaryOperation(_, BinaryOperation { left, op, right }) => { - let left = self.format_expr(left)?; - let right = self.format_expr(right)?; - match op { - BinaryOperator::ShiftLeft => { - format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") - } - _ => format!("(({left}).clone() {op} ({right}).clone())"), - } - } - Expression::UnaryOperation(_, UnaryOperation { op, expr }) => { - format!("({op} ({}).clone())", self.format_expr(expr)?) - } - Expression::IndexAccess(_, IndexAccess { array, index }) => { - format!( - "{}[usize::try_from({}).unwrap()].clone()", - self.format_expr(array)?, - self.format_expr(index)? - ) - } - Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { - // let params = if *params == vec!["r".to_string()] { - // // Hack because rust needs the type - // vec!["r: Vec".to_string()] - // } else { - // params.clone() - // }; - format!( - "|{}| {{ {} }}", - params.iter().format(", "), - self.format_expr(body)? - ) - } - Expression::IfExpression( - _, - IfExpression { - condition, - body, - else_body, - }, - ) => { - format!( - "if {} {{ {} }} else {{ {} }}", - self.format_expr(condition)?, - self.format_expr(body)?, - self.format_expr(else_body)? - ) - } - Expression::ArrayLiteral(_, ArrayLiteral { items }) => { - format!( - "vec![{}]", - items - .iter() - .map(|i| self.format_expr(i)) - .collect::, _>>()? - .join(", ") - ) - } - Expression::String(_, s) => quote(s), - Expression::Tuple(_, items) => format!( - "({})", - items - .iter() - .map(|i| self.format_expr(i)) - .collect::, _>>()? - .join(", ") - ), - Expression::BlockExpression(_, BlockExpression { statements, expr }) => { - format!( - "{{\n{}\n{}\n}}", - statements - .iter() - .map(|s| self.format_statement(s)) - .collect::, _>>()? - .join("\n"), - expr.as_ref() - .map(|e| self.format_expr(e.as_ref())) - .transpose()? - .unwrap_or_default() - ) - } - _ => return Err(format!("Implement {e}")), - }) - } - - fn format_statement(&mut self, s: &StatementInsideBlock) -> Result { - Err(format!("Implement {s}")) - } -} - -fn escape(s: &str) -> String { - s.replace('.', "_").replace("::", "_") -} - -fn map_type(ty: &Type) -> String { - match ty { - Type::Bottom | Type::Bool => format!("{ty}"), - Type::Int => "num_bigint::BigInt".to_string(), - Type::Fe => "FieldElement".to_string(), - Type::String => "String".to_string(), - Type::Expr => "Expr".to_string(), - Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), - Type::Tuple(_) => todo!(), - Type::Function(ft) => todo!("Type {ft}"), - Type::TypeVar(tv) => tv.to_string(), - Type::NamedType(_path, _type_args) => todo!(), - Type::Col | Type::Inter => unreachable!(), - } -} +// TODO run cargo and stuff diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 7dd2003c50..1ca0c66546 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,2 +1,3 @@ -mod compiler; -mod loader; +pub mod codegen; +pub mod compiler; +pub mod loader; diff --git a/jit-compiler/tests/codegen.rs b/jit-compiler/tests/codegen.rs new file mode 100644 index 0000000000..9505ef5f62 --- /dev/null +++ b/jit-compiler/tests/codegen.rs @@ -0,0 +1,38 @@ +use powdr_jit_compiler::codegen::Compiler; +use powdr_number::GoldilocksField; +use powdr_pil_analyzer::analyze_string; + +use pretty_assertions::assert_eq; + +fn compile(input: &str, syms: &[&str]) -> String { + let analyzed = analyze_string::(input); + let mut compiler = Compiler::new(&analyzed); + for s in syms { + compiler.request_symbol(s).unwrap(); + } + compiler.compiled_symbols() +} + +#[test] +fn empty_code() { + let result = compile("", &[]); + assert_eq!(result, ""); +} + +#[test] +fn simple_fun() { + let result = compile("let c: int -> int = |i| i;", &["c"]); + assert_eq!( + result, + "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" + ); +} + +#[test] +fn constant() { + let result = compile("let c: int -> int = |i| i; let d = c(20);", &["c", "d"]); + assert_eq!( + result, + "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" + ); +} From 0be7ac0685e40a3942d49c42e67d96f9af723ac5 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 09:36:09 +0000 Subject: [PATCH 10/62] work --- jit-compiler/src/codegen.rs | 304 ++++++++++++++++++++++++++++++++++++ 1 file changed, 304 insertions(+) create mode 100644 jit-compiler/src/codegen.rs diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs new file mode 100644 index 0000000000..24727e724f --- /dev/null +++ b/jit-compiler/src/codegen.rs @@ -0,0 +1,304 @@ +use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; +use std::{ + collections::{HashMap, HashSet}, + ffi::CString, + fs::{self, create_dir, File}, + io::Write, + path, + process::Command, + sync::Arc, + time::Instant, +}; + +use itertools::Itertools; +use powdr_ast::{ + analyzed::{ + Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, + Reference, SymbolKind, + }, + parsed::{ + display::{format_type_args, quote}, + types::{ArrayType, FunctionType, Type, TypeScheme}, + ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, + IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, + }, +}; +use powdr_number::{FieldElement, LargeInt}; + +pub struct Compiler<'a, T> { + analyzed: &'a Analyzed, + requested: HashSet, + failed: HashMap, + symbols: HashMap, +} + +impl<'a, T: FieldElement> Compiler<'a, T> { + pub fn new(analyzed: &'a Analyzed) -> Self { + Self { + analyzed, + requested: Default::default(), + failed: Default::default(), + symbols: Default::default(), + } + } + + pub fn request_symbol(&mut self, name: &str) -> Result<(), String> { + if let Some(err) = self.failed.get(name) { + return Err(err.clone()); + } + if self.requested.contains(name) { + return Ok(()); + } + self.requested.insert(name.to_string()); + match self.generate_code(name) { + Ok(code) => { + self.symbols.insert(name.to_string(), code); + Ok(()) + } + Err(err) => { + let err = format!("Failed to compile {name}: {err}"); + self.failed.insert(name.to_string(), err.clone()); + Err(err) + } + } + } + + pub fn is_compiled(&self, name: &str) -> bool { + self.symbols.contains_key(name) + } + + pub fn compiled_symbols(self) -> String { + self.symbols + .into_iter() + .map(|(_, code)| code) + .format("\n\n") + .to_string() + } + + fn generate_code(&mut self, symbol: &str) -> Result { + if let Some(code) = self.try_generate_builtin(symbol) { + return Ok(code); + } + + let Some((_, Some(FunctionValueDefinition::Expression(value)))) = + self.analyzed.definitions.get(symbol) + else { + return Err(format!( + "No definition for {symbol}, or not a generic symbol" + )); + }; + + let type_scheme = value.type_scheme.clone().unwrap(); + + Ok(match type_scheme { + TypeScheme { + vars, + ty: + Type::Function(FunctionType { + params: param_types, + value: return_type, + }), + } => { + let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = + &value.e + else { + return Err(format!( + "Expected lambda expression for {symbol}, got {}", + value.e + )); + }; + assert!(vars.is_empty()); + format!( + "fn {}({}) -> {} {{ {} }}\n", + escape_symbol(symbol), + params + .iter() + .zip(param_types) + .map(|(p, t)| format!("{}: {}", p, map_type(&t))) + .format(", "), + map_type(return_type.as_ref()), + self.format_expr(body)? + ) + } + _ => format!( + "const {}: {} = {};\n", + escape_symbol(symbol), + map_type(&value.type_scheme.as_ref().unwrap().ty), + self.format_expr(&value.e)? + ), + }) + } + + fn try_generate_builtin(&self, symbol: &str) -> Option { + let code = match symbol { + "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), + "std::field::modulus" => { + let modulus = T::modulus(); + Some(format!("() -> num_bigint::BigInt {{ num_bigint::BigInt::from(\"{modulus}\") }}")) + } + "std::convert::fe" => Some("(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" + .to_string()), + _ => None, + }?; + Some(format!("fn {}{code}", escape_symbol(symbol))) + } + + fn format_expr(&mut self, e: &Expression) -> Result { + Ok(match e { + Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), + Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => { + self.request_symbol(name)?; + format!( + "{}{}", + escape_symbol(name), + // TODO do all type args work here? + type_args + .as_ref() + .map(|ta| format!("::{}", format_type_args(&ta))) + .unwrap_or_default() + ) + } + Expression::Number( + _, + Number { + value, + type_: Some(type_), + }, + ) => match type_ { + // TODO value does not need to be u64 + Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), + Type::Fe => format!("FieldElement::from({value}_u64)"), + Type::Expr => format!("Expr::from({value}_u64)"), + Type::TypeVar(t) => format!("{t}::from({value}_u64)"), + _ => unreachable!(), + }, + Expression::FunctionCall( + _, + FunctionCall { + function, + arguments, + }, + ) => { + format!( + "({})({})", + self.format_expr(function)?, + arguments + .iter() + .map(|a| self.format_expr(a)) + .collect::, _>>()? + .into_iter() + // TODO these should all be refs -> turn all types to arc + .map(|x| format!("{x}.clone()")) + .collect::>() + .join(", ") + ) + } + Expression::BinaryOperation(_, BinaryOperation { left, op, right }) => { + let left = self.format_expr(left)?; + let right = self.format_expr(right)?; + match op { + BinaryOperator::ShiftLeft => { + format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") + } + _ => format!("(({left}).clone() {op} ({right}).clone())"), + } + } + Expression::UnaryOperation(_, UnaryOperation { op, expr }) => { + format!("({op} ({}).clone())", self.format_expr(expr)?) + } + Expression::IndexAccess(_, IndexAccess { array, index }) => { + format!( + "{}[usize::try_from({}).unwrap()].clone()", + self.format_expr(array)?, + self.format_expr(index)? + ) + } + Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { + // let params = if *params == vec!["r".to_string()] { + // // Hack because rust needs the type + // vec!["r: Vec".to_string()] + // } else { + // params.clone() + // }; + format!( + "|{}| {{ {} }}", + params.iter().format(", "), + self.format_expr(body)? + ) + } + Expression::IfExpression( + _, + IfExpression { + condition, + body, + else_body, + }, + ) => { + format!( + "if {} {{ {} }} else {{ {} }}", + self.format_expr(condition)?, + self.format_expr(body)?, + self.format_expr(else_body)? + ) + } + Expression::ArrayLiteral(_, ArrayLiteral { items }) => { + format!( + "vec![{}]", + items + .iter() + .map(|i| self.format_expr(i)) + .collect::, _>>()? + .join(", ") + ) + } + Expression::String(_, s) => quote(s), + Expression::Tuple(_, items) => format!( + "({})", + items + .iter() + .map(|i| self.format_expr(i)) + .collect::, _>>()? + .join(", ") + ), + Expression::BlockExpression(_, BlockExpression { statements, expr }) => { + format!( + "{{\n{}\n{}\n}}", + statements + .iter() + .map(|s| self.format_statement(s)) + .collect::, _>>()? + .join("\n"), + expr.as_ref() + .map(|e| self.format_expr(e.as_ref())) + .transpose()? + .unwrap_or_default() + ) + } + _ => return Err(format!("Implement {e}")), + }) + } + + fn format_statement(&mut self, s: &StatementInsideBlock) -> Result { + Err(format!("Implement {s}")) + } +} + +fn escape_symbol(s: &str) -> String { + s.replace('.', "_").replace("::", "_") +} + +fn map_type(ty: &Type) -> String { + match ty { + Type::Bottom | Type::Bool => format!("{ty}"), + Type::Int => "num_bigint::BigInt".to_string(), + Type::Fe => "FieldElement".to_string(), + Type::String => "String".to_string(), + Type::Expr => "Expr".to_string(), + Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), + Type::Tuple(_) => todo!(), + Type::Function(ft) => todo!("Type {ft}"), + Type::TypeVar(tv) => tv.to_string(), + Type::NamedType(_path, _type_args) => todo!(), + Type::Col | Type::Inter => unreachable!(), + } +} From 6aa9f4509558b42a247ff72314900f7c062dced9 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 09:58:26 +0000 Subject: [PATCH 11/62] wor --- jit-compiler/src/codegen.rs | 53 ++++++++++++----------------------- jit-compiler/tests/codegen.rs | 12 ++++++-- 2 files changed, 27 insertions(+), 38 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 24727e724f..0e2ea33eba 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -1,21 +1,8 @@ -use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; -use std::{ - collections::{HashMap, HashSet}, - ffi::CString, - fs::{self, create_dir, File}, - io::Write, - path, - process::Command, - sync::Arc, - time::Instant, -}; +use std::collections::{HashMap, HashSet}; use itertools::Itertools; use powdr_ast::{ - analyzed::{ - Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, - Reference, SymbolKind, - }, + analyzed::{Analyzed, Expression, FunctionValueDefinition, PolynomialReference, Reference}, parsed::{ display::{format_type_args, quote}, types::{ArrayType, FunctionType, Type, TypeScheme}, @@ -23,7 +10,7 @@ use powdr_ast::{ IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, }, }; -use powdr_number::{FieldElement, LargeInt}; +use powdr_number::FieldElement; pub struct Compiler<'a, T> { analyzed: &'a Analyzed, @@ -70,8 +57,9 @@ impl<'a, T: FieldElement> Compiler<'a, T> { pub fn compiled_symbols(self) -> String { self.symbols .into_iter() + .sorted() .map(|(_, code)| code) - .format("\n\n") + .format("\n") .to_string() } @@ -148,13 +136,12 @@ impl<'a, T: FieldElement> Compiler<'a, T> { Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => { self.request_symbol(name)?; + let ta = type_args.as_ref().unwrap(); format!( "{}{}", escape_symbol(name), - // TODO do all type args work here? - type_args - .as_ref() - .map(|ta| format!("::{}", format_type_args(&ta))) + (!ta.is_empty()) + .then(|| format!("::{}", format_type_args(ta))) .unwrap_or_default() ) } @@ -164,14 +151,15 @@ impl<'a, T: FieldElement> Compiler<'a, T> { value, type_: Some(type_), }, - ) => match type_ { - // TODO value does not need to be u64 - Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), - Type::Fe => format!("FieldElement::from({value}_u64)"), - Type::Expr => format!("Expr::from({value}_u64)"), - Type::TypeVar(t) => format!("{t}::from({value}_u64)"), - _ => unreachable!(), - }, + ) => { + let value = u64::try_from(value).unwrap_or_else(|_| unimplemented!()); + match type_ { + Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), + Type::Fe => format!("FieldElement::from({value}_u64)"), + Type::Expr => format!("Expr::from({value}_u64)"), + _ => unreachable!(), + } + } Expression::FunctionCall( _, FunctionCall { @@ -214,12 +202,6 @@ impl<'a, T: FieldElement> Compiler<'a, T> { ) } Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { - // let params = if *params == vec!["r".to_string()] { - // // Hack because rust needs the type - // vec!["r: Vec".to_string()] - // } else { - // params.clone() - // }; format!( "|{}| {{ {} }}", params.iter().format(", "), @@ -284,6 +266,7 @@ impl<'a, T: FieldElement> Compiler<'a, T> { } fn escape_symbol(s: &str) -> String { + // TODO better escaping s.replace('.', "_").replace("::", "_") } diff --git a/jit-compiler/tests/codegen.rs b/jit-compiler/tests/codegen.rs index 9505ef5f62..8187ef102f 100644 --- a/jit-compiler/tests/codegen.rs +++ b/jit-compiler/tests/codegen.rs @@ -29,10 +29,16 @@ fn simple_fun() { } #[test] -fn constant() { - let result = compile("let c: int -> int = |i| i; let d = c(20);", &["c", "d"]); +fn fun_calls() { + let result = compile( + "let c: int -> int = |i| i + 20; let d = |k| c(k * 20);", + &["c", "d"], + ); assert_eq!( result, - "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" + "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { ((i).clone() + (num_bigint::BigInt::from(20_u64)).clone()) } + +fn d(k: num_bigint::BigInt) -> num_bigint::BigInt { (c)(((k).clone() * (num_bigint::BigInt::from(20_u64)).clone()).clone()) } +" ); } From 01272e89818d74c6856749c82a001c8663386521 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 12:46:19 +0000 Subject: [PATCH 12/62] loading --- jit-compiler/src/codegen.rs | 14 ++--- jit-compiler/src/compiler.rs | 112 +++++++++++++++++++++++++++++++++- jit-compiler/src/lib.rs | 2 + jit-compiler/src/loader.rs | 2 +- jit-compiler/tests/codegen.rs | 4 +- 5 files changed, 123 insertions(+), 11 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 0e2ea33eba..96f2d15beb 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -12,14 +12,14 @@ use powdr_ast::{ }; use powdr_number::FieldElement; -pub struct Compiler<'a, T> { +pub struct CodeGenerator<'a, T> { analyzed: &'a Analyzed, requested: HashSet, failed: HashMap, symbols: HashMap, } -impl<'a, T: FieldElement> Compiler<'a, T> { +impl<'a, T: FieldElement> CodeGenerator<'a, T> { pub fn new(analyzed: &'a Analyzed) -> Self { Self { analyzed, @@ -122,9 +122,9 @@ impl<'a, T: FieldElement> Compiler<'a, T> { "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), "std::field::modulus" => { let modulus = T::modulus(); - Some(format!("() -> num_bigint::BigInt {{ num_bigint::BigInt::from(\"{modulus}\") }}")) + Some(format!("() -> powdr_number::BigInt {{ powdr_number::BigInt::from(\"{modulus}\") }}")) } - "std::convert::fe" => Some("(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" + "std::convert::fe" => Some("(n: powdr_number::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" .to_string()), _ => None, }?; @@ -154,7 +154,7 @@ impl<'a, T: FieldElement> Compiler<'a, T> { ) => { let value = u64::try_from(value).unwrap_or_else(|_| unimplemented!()); match type_ { - Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), + Type::Int => format!("powdr_number::BigInt::from({value}_u64)"), Type::Fe => format!("FieldElement::from({value}_u64)"), Type::Expr => format!("Expr::from({value}_u64)"), _ => unreachable!(), @@ -265,7 +265,7 @@ impl<'a, T: FieldElement> Compiler<'a, T> { } } -fn escape_symbol(s: &str) -> String { +pub fn escape_symbol(s: &str) -> String { // TODO better escaping s.replace('.', "_").replace("::", "_") } @@ -273,7 +273,7 @@ fn escape_symbol(s: &str) -> String { fn map_type(ty: &Type) -> String { match ty { Type::Bottom | Type::Bool => format!("{ty}"), - Type::Int => "num_bigint::BigInt".to_string(), + Type::Int => "powdr_number::BigInt".to_string(), Type::Fe => "FieldElement".to_string(), Type::String => "String".to_string(), Type::Expr => "Expr".to_string(), diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index b9f5e2def9..f6c182683f 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1 +1,111 @@ -// TODO run cargo and stuff +use libc::{c_void, dlopen, dlsym, RTLD_NOW}; +use std::{ + collections::{HashMap}, + ffi::CString, + fs::{self}, + path, + process::Command, +}; + +use itertools::Itertools; +use powdr_ast::{ + analyzed::{ + Analyzed, + }, +}; +use powdr_number::FieldElement; + +use crate::codegen::{escape_symbol, CodeGenerator}; + +// TODO make this depend on T + +const PREAMBLE: &str = r#" +#![allow(unused_parens)] +type FieldElement = powdr_number::goldilocks::GoldilocksField; +"#; + +// TODO this is the old impl of goldilocks + +const CARGO_TOML: &str = r#" +[package] +name = "powdr_jit_compiled" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["dylib"] + +[dependencies] +powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } +"#; + +pub fn compile( + analyzed: &Analyzed, + symbols: &[&str], +) -> Result u64>, String> { + let mut codegen = CodeGenerator::new(analyzed); + let mut glue = String::new(); + for sym in symbols { + codegen.request_symbol(sym)?; + // TODO verify that the type is `int -> int`. + // TODO we should use big int instead of u64 + let name = escape_symbol(sym); + glue.push_str(&format!( + r#" + #[no_mangle] + pub extern fn extern_{name}(i: u64) -> u64 {{ + {name}(powdr_number::BigInt::from(i)).into_bigint().0[0] + }} + "# + )); + } + + let code = format!("{PREAMBLE}\n{}\n{glue}\n", codegen.compiled_symbols()); + println!("Compiled code:\n{code}"); + + // TODO for testing, keep the dir the same + //let dir = mktemp::Temp::new_dir().unwrap(); + let _ = fs::remove_dir_all("/tmp/powdr_constants"); + fs::create_dir("/tmp/powdr_constants").unwrap(); + let dir = path::Path::new("/tmp/powdr_constants"); + fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); + fs::create_dir(dir.join("src")).unwrap(); + fs::write(dir.join("src").join("lib.rs"), code).unwrap(); + let out = Command::new("cargo") + .arg("build") + .arg("--release") + .current_dir(dir) + .output() + .unwrap(); + out.stderr.iter().for_each(|b| print!("{}", *b as char)); + if !out.status.success() { + panic!("Failed to compile."); + } + + let lib_path = CString::new( + dir.join("target") + .join("release") + .join("libpowdr_constants.so") + .to_str() + .unwrap(), + ) + .unwrap(); + + let lib = unsafe { dlopen(lib_path.as_ptr(), RTLD_NOW) }; + if lib.is_null() { + panic!("Failed to load library: {lib_path:?}"); + } + let mut result = HashMap::new(); + for sym in symbols { + let sym = format!("extern_{}", escape_symbol(sym)); + let sym_cstr = CString::new(sym.clone()).unwrap(); + let fun_ptr = unsafe { dlsym(lib, sym_cstr.as_ptr()) }; + if fun_ptr.is_null() { + return Err(format!("Failed to load symbol: {fun_ptr:?}")); + } + println!("Loaded symbol: {fun_ptr:?}"); + let fun = unsafe { std::mem::transmute::<*mut c_void, fn(u64) -> u64>(fun_ptr) }; + result.insert(sym, fun); + } + Ok(result) +} diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 1ca0c66546..3c347edfa2 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,3 +1,5 @@ pub mod codegen; pub mod compiler; pub mod loader; + +//let n = num_bigint::BigUint::from_bytes_le(&n.to_le_bytes()); diff --git a/jit-compiler/src/loader.rs b/jit-compiler/src/loader.rs index 690bc5467d..ac0b0042b5 100644 --- a/jit-compiler/src/loader.rs +++ b/jit-compiler/src/loader.rs @@ -101,7 +101,7 @@ // } // println!("Loaded symbol: {:?}", sym); // let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); -// let degrees = if let Some(degree) = poly.degree { +// let degrees = if let Some(degree) = poly.degraee { // vec![degree] // } else { // (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) diff --git a/jit-compiler/tests/codegen.rs b/jit-compiler/tests/codegen.rs index 8187ef102f..2344cb07b6 100644 --- a/jit-compiler/tests/codegen.rs +++ b/jit-compiler/tests/codegen.rs @@ -1,4 +1,4 @@ -use powdr_jit_compiler::codegen::Compiler; +use powdr_jit_compiler::codegen::CodeGenerator; use powdr_number::GoldilocksField; use powdr_pil_analyzer::analyze_string; @@ -6,7 +6,7 @@ use pretty_assertions::assert_eq; fn compile(input: &str, syms: &[&str]) -> String { let analyzed = analyze_string::(input); - let mut compiler = Compiler::new(&analyzed); + let mut compiler = CodeGenerator::new(&analyzed); for s in syms { compiler.request_symbol(s).unwrap(); } From 249f3c879b12ce9c412a87fb44612ec5a987f2a6 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 12:50:35 +0000 Subject: [PATCH 13/62] fix --- jit-compiler/src/compiler.rs | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index f6c182683f..4efc93afa5 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1,6 +1,6 @@ use libc::{c_void, dlopen, dlsym, RTLD_NOW}; use std::{ - collections::{HashMap}, + collections::HashMap, ffi::CString, fs::{self}, path, @@ -8,11 +8,7 @@ use std::{ }; use itertools::Itertools; -use powdr_ast::{ - analyzed::{ - Analyzed, - }, -}; +use powdr_ast::analyzed::Analyzed; use powdr_number::FieldElement; use crate::codegen::{escape_symbol, CodeGenerator}; @@ -21,7 +17,7 @@ use crate::codegen::{escape_symbol, CodeGenerator}; const PREAMBLE: &str = r#" #![allow(unused_parens)] -type FieldElement = powdr_number::goldilocks::GoldilocksField; +//type FieldElement = powdr_number::GoldilocksField; "#; // TODO this is the old impl of goldilocks @@ -54,7 +50,7 @@ pub fn compile( r#" #[no_mangle] pub extern fn extern_{name}(i: u64) -> u64 {{ - {name}(powdr_number::BigInt::from(i)).into_bigint().0[0] + u64::try_from({name}(powdr_number::BigInt::from(i))).unwrap() }} "# )); @@ -85,7 +81,7 @@ pub fn compile( let lib_path = CString::new( dir.join("target") .join("release") - .join("libpowdr_constants.so") + .join("libpowdr_jit_compiled.so") .to_str() .unwrap(), ) @@ -97,15 +93,15 @@ pub fn compile( } let mut result = HashMap::new(); for sym in symbols { - let sym = format!("extern_{}", escape_symbol(sym)); - let sym_cstr = CString::new(sym.clone()).unwrap(); + let extern_sym = format!("extern_{}", escape_symbol(sym)); + let sym_cstr = CString::new(extern_sym).unwrap(); let fun_ptr = unsafe { dlsym(lib, sym_cstr.as_ptr()) }; if fun_ptr.is_null() { return Err(format!("Failed to load symbol: {fun_ptr:?}")); } println!("Loaded symbol: {fun_ptr:?}"); let fun = unsafe { std::mem::transmute::<*mut c_void, fn(u64) -> u64>(fun_ptr) }; - result.insert(sym, fun); + result.insert(sym.to_string(), fun); } Ok(result) } From 1570fb08ef297d8486474ba2b701eababcf6260b Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 12:53:16 +0000 Subject: [PATCH 14/62] sqrt --- jit-compiler/src/compiler.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 4efc93afa5..c996b749bc 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -7,7 +7,6 @@ use std::{ process::Command, }; -use itertools::Itertools; use powdr_ast::analyzed::Analyzed; use powdr_number::FieldElement; From 81dca2bf2383f37a7591e7ac7e6aad78306308a2 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 16:13:47 +0000 Subject: [PATCH 15/62] add benchmark --- jit-compiler/src/lib.rs | 3 +++ pil-analyzer/tests/types.rs | 2 +- pipeline/Cargo.toml | 1 + pipeline/benches/evaluator_benchmark.rs | 34 ++++++++++++++++++++++++- 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 3c347edfa2..8eb4560543 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,5 +1,8 @@ +// TODO make them non-pub? pub mod codegen; pub mod compiler; pub mod loader; //let n = num_bigint::BigUint::from_bytes_le(&n.to_le_bytes()); + +pub use compiler::compile; diff --git a/pil-analyzer/tests/types.rs b/pil-analyzer/tests/types.rs index 253ed4e34c..a149de289b 100644 --- a/pil-analyzer/tests/types.rs +++ b/pil-analyzer/tests/types.rs @@ -704,7 +704,7 @@ fn trait_user_defined_enum_wrong_type() { } let n: int = 7; - let r1 = Convert::convert(n); + let r1: int = Convert::convert(n); "; type_check(input, &[]); } diff --git a/pipeline/Cargo.toml b/pipeline/Cargo.toml index 2e1f3bc28e..67a89ec506 100644 --- a/pipeline/Cargo.toml +++ b/pipeline/Cargo.toml @@ -41,6 +41,7 @@ num-traits = "0.2.15" test-log = "0.2.12" env_logger = "0.10.0" criterion = { version = "0.4", features = ["html_reports"] } +powdr-jit-compiler.workspace = true [package.metadata.cargo-udeps.ignore] development = ["env_logger"] diff --git a/pipeline/benches/evaluator_benchmark.rs b/pipeline/benches/evaluator_benchmark.rs index f5d2eb483a..76d027cc23 100644 --- a/pipeline/benches/evaluator_benchmark.rs +++ b/pipeline/benches/evaluator_benchmark.rs @@ -114,5 +114,37 @@ fn evaluator_benchmark(c: &mut Criterion) { group.finish(); } -criterion_group!(benches_pil, evaluator_benchmark); +fn jit_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("jit-benchmark"); + + let sqrt_analyzed: Analyzed = { + let code = " + let sqrt: int -> int = |x| sqrt_rec(x, x); + let sqrt_rec: int, int -> int = |y, x| + if y * y <= x && (y + 1) * (y + 1) > x { + y + } else { + sqrt_rec((y + x / y) / 2, x) + }; + " + .to_string(); + let mut pipeline = Pipeline::default().from_asm_string(code, None); + pipeline.compute_analyzed_pil().unwrap().clone() + }; + + let sqrt_fun = powdr_jit_compiler::compile(&sqrt_analyzed, &["sqrt"]).unwrap()["sqrt"]; + + for x in [879882356, 1882356, 1187956, 56] { + group.bench_with_input(format!("sqrt_{x}"), &x, |b, &x| { + b.iter(|| { + let y = (x as u64) * 112655675; + sqrt_fun(y); + }); + }); + } + + group.finish(); +} + +criterion_group!(benches_pil, evaluator_benchmark, jit_benchmark); criterion_main!(benches_pil); From 641f1f2432375a718209752751eafa768e22de05 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 16:33:27 +0000 Subject: [PATCH 16/62] forgot test file --- jit-compiler/tests/execution.rs | 41 +++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 jit-compiler/tests/execution.rs diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs new file mode 100644 index 0000000000..dbb1b0f4ee --- /dev/null +++ b/jit-compiler/tests/execution.rs @@ -0,0 +1,41 @@ +use powdr_jit_compiler::compiler; +use powdr_number::GoldilocksField; +use powdr_pil_analyzer::analyze_string; + +fn compile(input: &str, symbol: &str) -> fn(u64) -> u64 { + let analyzed = analyze_string::(input); + compiler::compile(&analyzed, &[symbol]).unwrap()[symbol] +} + +#[test] +fn identity_function() { + let f = compile("let c: int -> int = |i| i;", "c"); + + assert_eq!(f(10), 10); +} + +#[test] +fn sqrt() { + let f = compile( + " + let sqrt_rec: int, int -> int = |y, x| + if y * y <= x && (y + 1) * (y + 1) > x { + y + } else { + sqrt_rec((y + x / y) / 2, x) + }; + + let sqrt: int -> int = |x| sqrt_rec(x, x);", + "sqrt", + ); + + for i in 0..100000 { + f(879882356 * 112655675); + // assert_eq!(f(9), 3); + // assert_eq!(f(100), 10); + // assert_eq!(f(8), 2); + // assert_eq!(f(101), 10); + // assert_eq!(f(99), 9); + // assert_eq!(f(0), 0); + } +} From 65c6c049bcd9c6d97b27cdf32e1e2a5ec4c32e18 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 16:16:00 +0000 Subject: [PATCH 17/62] clean. --- jit-compiler/src/codegen.rs | 53 ++++++++++++- jit-compiler/src/compiler.rs | 113 ++++++++++++++++----------- jit-compiler/src/lib.rs | 31 ++++++-- jit-compiler/src/loader.rs | 130 -------------------------------- jit-compiler/tests/codegen.rs | 44 ----------- jit-compiler/tests/execution.rs | 18 ++--- 6 files changed, 150 insertions(+), 239 deletions(-) delete mode 100644 jit-compiler/src/loader.rs delete mode 100644 jit-compiler/tests/codegen.rs diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 96f2d15beb..841c6c557d 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -50,10 +50,6 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } } - pub fn is_compiled(&self, name: &str) -> bool { - self.symbols.contains_key(name) - } - pub fn compiled_symbols(self) -> String { self.symbols .into_iter() @@ -285,3 +281,52 @@ fn map_type(ty: &Type) -> String { Type::Col | Type::Inter => unreachable!(), } } + +#[cfg(test)] +mod test { + use powdr_number::GoldilocksField; + use powdr_pil_analyzer::analyze_string; + + use pretty_assertions::assert_eq; + + use super::CodeGenerator; + + fn compile(input: &str, syms: &[&str]) -> String { + let analyzed = analyze_string::(input); + let mut compiler = CodeGenerator::new(&analyzed); + for s in syms { + compiler.request_symbol(s).unwrap(); + } + compiler.compiled_symbols() + } + + #[test] + fn empty_code() { + let result = compile("", &[]); + assert_eq!(result, ""); + } + + #[test] + fn simple_fun() { + let result = compile("let c: int -> int = |i| i;", &["c"]); + assert_eq!( + result, + "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" + ); + } + + #[test] + fn fun_calls() { + let result = compile( + "let c: int -> int = |i| i + 20; let d = |k| c(k * 20);", + &["c", "d"], + ); + assert_eq!( + result, + "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { ((i).clone() + (num_bigint::BigInt::from(20_u64)).clone()) } + +fn d(k: num_bigint::BigInt) -> num_bigint::BigInt { (c)(((k).clone() * (num_bigint::BigInt::from(20_u64)).clone()).clone()) } +" + ); + } +} diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index c996b749bc..60ae6e706f 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1,16 +1,25 @@ use libc::{c_void, dlopen, dlsym, RTLD_NOW}; +use mktemp::Temp; use std::{ collections::HashMap, ffi::CString, fs::{self}, - path, process::Command, }; -use powdr_ast::analyzed::Analyzed; +use powdr_ast::{ + analyzed::Analyzed, + parsed::{ + display::format_type_scheme_around_name, + types::{FunctionType, Type, TypeScheme}, + }, +}; use powdr_number::FieldElement; -use crate::codegen::{escape_symbol, CodeGenerator}; +use crate::{ + codegen::{escape_symbol, CodeGenerator}, + SymbolMap, +}; // TODO make this depend on T @@ -19,88 +28,104 @@ const PREAMBLE: &str = r#" //type FieldElement = powdr_number::GoldilocksField; "#; -// TODO this is the old impl of goldilocks - -const CARGO_TOML: &str = r#" -[package] -name = "powdr_jit_compiled" -version = "0.1.0" -edition = "2021" - -[lib] -crate-type = ["dylib"] - -[dependencies] -powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } -"#; - -pub fn compile( +pub fn create_full_code( analyzed: &Analyzed, symbols: &[&str], -) -> Result u64>, String> { +) -> Result { let mut codegen = CodeGenerator::new(analyzed); let mut glue = String::new(); + let int_int_fun: TypeScheme = Type::Function(FunctionType { + params: vec![Type::Int], + value: Box::new(Type::Int), + }) + .into(); for sym in symbols { + let ty = analyzed.type_of_symbol(sym); + if &ty != &int_int_fun { + return Err(format!( + "Only (int -> int) functions are supported, but requested {}", + format_type_scheme_around_name(sym, &Some(ty)), + )); + } codegen.request_symbol(sym)?; - // TODO verify that the type is `int -> int`. // TODO we should use big int instead of u64 let name = escape_symbol(sym); glue.push_str(&format!( r#" #[no_mangle] - pub extern fn extern_{name}(i: u64) -> u64 {{ + pub extern fn {}(i: u64) -> u64 {{ u64::try_from({name}(powdr_number::BigInt::from(i))).unwrap() }} - "# + "#, + extern_symbol_name(sym) )); } - let code = format!("{PREAMBLE}\n{}\n{glue}\n", codegen.compiled_symbols()); - println!("Compiled code:\n{code}"); + Ok(format!( + "{PREAMBLE}\n{}\n{glue}\n", + codegen.compiled_symbols() + )) +} - // TODO for testing, keep the dir the same - //let dir = mktemp::Temp::new_dir().unwrap(); - let _ = fs::remove_dir_all("/tmp/powdr_constants"); - fs::create_dir("/tmp/powdr_constants").unwrap(); - let dir = path::Path::new("/tmp/powdr_constants"); +const CARGO_TOML: &str = r#" +[package] +name = "powdr_jit_compiled" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["dylib"] + +[dependencies] +powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } +"#; + +/// Compiles the given code and returns the path to the +/// temporary directory containing the compiled library +/// and the path to the compiled library. +pub fn call_cargo(code: &str) -> (Temp, String) { + let dir = mktemp::Temp::new_dir().unwrap(); fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); fs::create_dir(dir.join("src")).unwrap(); fs::write(dir.join("src").join("lib.rs"), code).unwrap(); let out = Command::new("cargo") .arg("build") .arg("--release") - .current_dir(dir) + .current_dir(dir.clone()) .output() .unwrap(); out.stderr.iter().for_each(|b| print!("{}", *b as char)); if !out.status.success() { panic!("Failed to compile."); } + let lib_path = dir + .join("target") + .join("release") + .join("libpowdr_jit_compiled.so"); + (dir, lib_path.to_str().unwrap().to_string()) +} - let lib_path = CString::new( - dir.join("target") - .join("release") - .join("libpowdr_jit_compiled.so") - .to_str() - .unwrap(), - ) - .unwrap(); - - let lib = unsafe { dlopen(lib_path.as_ptr(), RTLD_NOW) }; +/// Loads the given library and creates funtion pointers for the given symbols. +pub fn load_library(path: &str, symbols: &[&str]) -> Result { + let c_path = CString::new(path).unwrap(); + let lib = unsafe { dlopen(c_path.as_ptr(), RTLD_NOW) }; if lib.is_null() { - panic!("Failed to load library: {lib_path:?}"); + return Err(format!("Failed to load library: {path:?}")); } let mut result = HashMap::new(); for sym in symbols { - let extern_sym = format!("extern_{}", escape_symbol(sym)); + let extern_sym = extern_symbol_name(sym); let sym_cstr = CString::new(extern_sym).unwrap(); let fun_ptr = unsafe { dlsym(lib, sym_cstr.as_ptr()) }; if fun_ptr.is_null() { return Err(format!("Failed to load symbol: {fun_ptr:?}")); } - println!("Loaded symbol: {fun_ptr:?}"); let fun = unsafe { std::mem::transmute::<*mut c_void, fn(u64) -> u64>(fun_ptr) }; result.insert(sym.to_string(), fun); } Ok(result) } + +fn extern_symbol_name(sym: &str) -> String { + format!("extern_{}", escape_symbol(sym)) +} diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 8eb4560543..cb49f31332 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,8 +1,27 @@ -// TODO make them non-pub? -pub mod codegen; -pub mod compiler; -pub mod loader; +mod codegen; +mod compiler; -//let n = num_bigint::BigUint::from_bytes_le(&n.to_le_bytes()); +use std::collections::HashMap; -pub use compiler::compile; +use compiler::{call_cargo, create_full_code, load_library}; +use powdr_ast::analyzed::Analyzed; +use powdr_number::FieldElement; + +pub type SymbolMap = HashMap u64>; + +/// Compiles the given symbols (and their dependencies) and returns them as a map +/// from symbol name to function pointer. +/// Only functions of type (int -> int) are supported for now. +pub fn compile( + analyzed: &Analyzed, + symbols: &[&str], +) -> Result { + let code = create_full_code(analyzed, symbols)?; + + let (dir, lib_path) = call_cargo(&code); + + let result = load_library(&lib_path, symbols); + + drop(dir); + result +} diff --git a/jit-compiler/src/loader.rs b/jit-compiler/src/loader.rs deleted file mode 100644 index ac0b0042b5..0000000000 --- a/jit-compiler/src/loader.rs +++ /dev/null @@ -1,130 +0,0 @@ -// use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; -// use rayon::iter::{IntoParallelIterator, ParallelIterator}; -// use std::{ -// collections::{HashMap, HashSet}, -// ffi::CString, -// fs::{self, create_dir, File}, -// io::Write, -// path, -// process::Command, -// sync::Arc, -// time::Instant, -// }; - -// use itertools::Itertools; -// use powdr_ast::{ -// analyzed::{ -// Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, -// Reference, SymbolKind, -// }, -// parsed::{ -// display::{format_type_args, quote}, -// types::{ArrayType, FunctionType, Type, TypeScheme}, -// ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, -// IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, -// }, -// }; -// use powdr_number::FieldElement; - -// // pub fn generate_fixed_cols( -// // analyzed: &Analyzed, -// // ) -> HashMap)> { -// // let mut compiler = Compiler::new(analyzed); -// // let mut glue = String::new(); -// // for (sym, _) in &analyzed.constant_polys_in_source_order() { -// // // ignore err -// // if let Err(e) = compiler.request_symbol(&sym.absolute_name) { -// // println!("Failed to compile {}: {e}", &sym.absolute_name); -// // } -// // } -// // for (sym, _) in &analyzed.constant_polys_in_source_order() { -// // // TODO escape? -// // if compiler.is_compiled(&sym.absolute_name) { -// // // TODO it is a rust function, can we use a more complex type as well? -// // // TODO only works for goldilocks -// // glue.push_str(&format!( -// // r#" -// #[no_mangle] -// pub extern fn extern_{}(i: u64) -> u64 {{ -// {}(num_bigint::BigInt::from(i)).into_bigint().0[0] -// }} -// "#, -// escape(&sym.absolute_name), -// escape(&sym.absolute_name), -// )); -// } -// } - -// let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); -// println!("Compiled code:\n{code}"); - -// //let dir = mktemp::Temp::new_dir().unwrap(); -// let _ = fs::remove_dir_all("/tmp/powdr_constants"); -// fs::create_dir("/tmp/powdr_constants").unwrap(); -// let dir = path::Path::new("/tmp/powdr_constants"); -// fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); -// fs::create_dir(dir.join("src")).unwrap(); -// fs::write(dir.join("src").join("lib.rs"), code).unwrap(); -// let out = Command::new("cargo") -// .arg("build") -// .arg("--release") -// .current_dir(dir) -// .output() -// .unwrap(); -// out.stderr.iter().for_each(|b| print!("{}", *b as char)); -// if !out.status.success() { -// panic!("Failed to compile."); -// } - -// let mut columns = HashMap::new(); -// unsafe { -// let lib_path = CString::new( -// dir.join("target") -// .join("release") -// .join("libpowdr_constants.so") -// .to_str() -// .unwrap(), -// ) -// .unwrap(); -// let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); -// if lib.is_null() { -// panic!("Failed to load library: {:?}", lib_path); -// } -// let start = Instant::now(); -// for (poly, value) in analyzed.constant_polys_in_source_order() { -// let sym = format!("extern_{}", escape(&poly.absolute_name)); -// let sym = CString::new(sym).unwrap(); -// let sym = dlsym(lib, sym.as_ptr()); -// if sym.is_null() { -// println!("Failed to load symbol: {:?}", sym); -// continue; -// } -// println!("Loaded symbol: {:?}", sym); -// let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); -// let degrees = if let Some(degree) = poly.degraee { -// vec![degree] -// } else { -// (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) -// .map(|degree_log| 1 << degree_log) -// .collect::>() -// }; - -// let col_values = degrees -// .into_iter() -// .map(|degree| { -// (0..degree) -// .into_par_iter() -// .map(|i| T::from(fun(i as u64))) -// .collect::>() -// }) -// .collect::>() -// .into(); -// columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); -// } -// log::info!( -// "Fixed column generation (without compilation and loading time) took {}s", -// start.elapsed().as_secs_f32() -// ); -// } -// columns -// } diff --git a/jit-compiler/tests/codegen.rs b/jit-compiler/tests/codegen.rs deleted file mode 100644 index 2344cb07b6..0000000000 --- a/jit-compiler/tests/codegen.rs +++ /dev/null @@ -1,44 +0,0 @@ -use powdr_jit_compiler::codegen::CodeGenerator; -use powdr_number::GoldilocksField; -use powdr_pil_analyzer::analyze_string; - -use pretty_assertions::assert_eq; - -fn compile(input: &str, syms: &[&str]) -> String { - let analyzed = analyze_string::(input); - let mut compiler = CodeGenerator::new(&analyzed); - for s in syms { - compiler.request_symbol(s).unwrap(); - } - compiler.compiled_symbols() -} - -#[test] -fn empty_code() { - let result = compile("", &[]); - assert_eq!(result, ""); -} - -#[test] -fn simple_fun() { - let result = compile("let c: int -> int = |i| i;", &["c"]); - assert_eq!( - result, - "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" - ); -} - -#[test] -fn fun_calls() { - let result = compile( - "let c: int -> int = |i| i + 20; let d = |k| c(k * 20);", - &["c", "d"], - ); - assert_eq!( - result, - "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { ((i).clone() + (num_bigint::BigInt::from(20_u64)).clone()) } - -fn d(k: num_bigint::BigInt) -> num_bigint::BigInt { (c)(((k).clone() * (num_bigint::BigInt::from(20_u64)).clone()).clone()) } -" - ); -} diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index dbb1b0f4ee..697cb03919 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -1,10 +1,9 @@ -use powdr_jit_compiler::compiler; use powdr_number::GoldilocksField; use powdr_pil_analyzer::analyze_string; fn compile(input: &str, symbol: &str) -> fn(u64) -> u64 { let analyzed = analyze_string::(input); - compiler::compile(&analyzed, &[symbol]).unwrap()[symbol] + powdr_jit_compiler::compile(&analyzed, &[symbol]).unwrap()[symbol] } #[test] @@ -29,13 +28,10 @@ fn sqrt() { "sqrt", ); - for i in 0..100000 { - f(879882356 * 112655675); - // assert_eq!(f(9), 3); - // assert_eq!(f(100), 10); - // assert_eq!(f(8), 2); - // assert_eq!(f(101), 10); - // assert_eq!(f(99), 9); - // assert_eq!(f(0), 0); - } + assert_eq!(f(9), 3); + assert_eq!(f(100), 10); + assert_eq!(f(8), 2); + assert_eq!(f(101), 10); + assert_eq!(f(99), 9); + assert_eq!(f(0), 0); } From 1578b1e5d32c8c064f337c814dca41ac466b6dc6 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 16:30:38 +0000 Subject: [PATCH 18/62] fix --- jit-compiler/src/codegen.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 841c6c557d..aaa2a533dd 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -311,7 +311,7 @@ mod test { let result = compile("let c: int -> int = |i| i;", &["c"]); assert_eq!( result, - "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" + "fn c(i: powdr_number::BigInt) -> powdr_number::BigInt { i }\n" ); } @@ -323,9 +323,9 @@ mod test { ); assert_eq!( result, - "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { ((i).clone() + (num_bigint::BigInt::from(20_u64)).clone()) } + "fn c(i: powdr_number::BigInt) -> powdr_number::BigInt { ((i).clone() + (powdr_number::BigInt::from(20_u64)).clone()) } -fn d(k: num_bigint::BigInt) -> num_bigint::BigInt { (c)(((k).clone() * (num_bigint::BigInt::from(20_u64)).clone()).clone()) } +fn d(k: powdr_number::BigInt) -> powdr_number::BigInt { (c)(((k).clone() * (powdr_number::BigInt::from(20_u64)).clone()).clone()) } " ); } From ffa6187715458ef401ae139b35495ec50bea4354 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 16:40:21 +0000 Subject: [PATCH 19/62] Some logging. --- jit-compiler/Cargo.toml | 4 ++++ jit-compiler/src/lib.rs | 7 ++++++- jit-compiler/tests/execution.rs | 2 ++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/jit-compiler/Cargo.toml b/jit-compiler/Cargo.toml index 94da05758b..f3330c1d9d 100644 --- a/jit-compiler/Cargo.toml +++ b/jit-compiler/Cargo.toml @@ -13,12 +13,16 @@ powdr-number.workspace = true powdr-parser.workspace = true libc = "0.2.0" +log = "0.4.18" mktemp = "0.5.0" itertools = "0.13" [dev-dependencies] powdr-pil-analyzer.workspace = true pretty_assertions = "1.4.0" +test-log = "0.2.12" +env_logger = "0.10.0" + [lints.clippy] diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index cb49f31332..2102aea549 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,7 +1,7 @@ mod codegen; mod compiler; -use std::collections::HashMap; +use std::{collections::HashMap, fs}; use compiler::{call_cargo, create_full_code, load_library}; use powdr_ast::analyzed::Analyzed; @@ -16,11 +16,16 @@ pub fn compile( analyzed: &Analyzed, symbols: &[&str], ) -> Result { + log::info!("JIT-compiling {} symbols...", symbols.len()); let code = create_full_code(analyzed, symbols)?; let (dir, lib_path) = call_cargo(&code); + let metadata = fs::metadata(&lib_path).unwrap(); + + log::info!("Loading library with size {}...", metadata.len()); let result = load_library(&lib_path, symbols); + log::info!("Done."); drop(dir); result diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index 697cb03919..ea705635c7 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -1,3 +1,5 @@ +use test_log::test; + use powdr_number::GoldilocksField; use powdr_pil_analyzer::analyze_string; From 51cae14c50f4e8da4863789528ec062738db500f Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 16:45:28 +0000 Subject: [PATCH 20/62] size in mb. --- jit-compiler/src/lib.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 2102aea549..b88cc02e4c 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -22,7 +22,10 @@ pub fn compile( let (dir, lib_path) = call_cargo(&code); let metadata = fs::metadata(&lib_path).unwrap(); - log::info!("Loading library with size {}...", metadata.len()); + log::info!( + "Loading library with size {} MB...", + metadata.len() as f64 / 1000000.0 + ); let result = load_library(&lib_path, symbols); log::info!("Done."); From 22de29a751afd97ed2f538fcc6d3100d424694bc Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 18:48:44 +0200 Subject: [PATCH 21/62] Update pil-analyzer/tests/types.rs --- pil-analyzer/tests/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pil-analyzer/tests/types.rs b/pil-analyzer/tests/types.rs index a149de289b..253ed4e34c 100644 --- a/pil-analyzer/tests/types.rs +++ b/pil-analyzer/tests/types.rs @@ -704,7 +704,7 @@ fn trait_user_defined_enum_wrong_type() { } let n: int = 7; - let r1: int = Convert::convert(n); + let r1 = Convert::convert(n); "; type_check(input, &[]); } From 2682d279e3cdaa9b2e5f83947ded06e133cf81e3 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 16:52:40 +0000 Subject: [PATCH 22/62] Use ibig. --- jit-compiler/src/codegen.rs | 14 +++++++------- jit-compiler/src/compiler.rs | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index aaa2a533dd..f05edd4140 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -118,9 +118,9 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), "std::field::modulus" => { let modulus = T::modulus(); - Some(format!("() -> powdr_number::BigInt {{ powdr_number::BigInt::from(\"{modulus}\") }}")) + Some(format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}")) } - "std::convert::fe" => Some("(n: powdr_number::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" + "std::convert::fe" => Some("(n: ibig::IBig) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" .to_string()), _ => None, }?; @@ -150,7 +150,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { ) => { let value = u64::try_from(value).unwrap_or_else(|_| unimplemented!()); match type_ { - Type::Int => format!("powdr_number::BigInt::from({value}_u64)"), + Type::Int => format!("ibig::IBig::from({value}_u64)"), Type::Fe => format!("FieldElement::from({value}_u64)"), Type::Expr => format!("Expr::from({value}_u64)"), _ => unreachable!(), @@ -269,7 +269,7 @@ pub fn escape_symbol(s: &str) -> String { fn map_type(ty: &Type) -> String { match ty { Type::Bottom | Type::Bool => format!("{ty}"), - Type::Int => "powdr_number::BigInt".to_string(), + Type::Int => "ibig::IBig".to_string(), Type::Fe => "FieldElement".to_string(), Type::String => "String".to_string(), Type::Expr => "Expr".to_string(), @@ -311,7 +311,7 @@ mod test { let result = compile("let c: int -> int = |i| i;", &["c"]); assert_eq!( result, - "fn c(i: powdr_number::BigInt) -> powdr_number::BigInt { i }\n" + "fn c(i: ibig::IBig) -> ibig::IBig { i }\n" ); } @@ -323,9 +323,9 @@ mod test { ); assert_eq!( result, - "fn c(i: powdr_number::BigInt) -> powdr_number::BigInt { ((i).clone() + (powdr_number::BigInt::from(20_u64)).clone()) } + "fn c(i: ibig::IBig) -> ibig::IBig { ((i).clone() + (ibig::IBig::from(20_u64)).clone()) } -fn d(k: powdr_number::BigInt) -> powdr_number::BigInt { (c)(((k).clone() * (powdr_number::BigInt::from(20_u64)).clone()).clone()) } +fn d(k: ibig::IBig) -> ibig::IBig { (c)(((k).clone() * (ibig::IBig::from(20_u64)).clone()).clone()) } " ); } diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 60ae6e706f..321d39eb87 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -54,7 +54,7 @@ pub fn create_full_code( r#" #[no_mangle] pub extern fn {}(i: u64) -> u64 {{ - u64::try_from({name}(powdr_number::BigInt::from(i))).unwrap() + u64::try_from({name}(ibig::IBig::from(i))).unwrap() }} "#, extern_symbol_name(sym) @@ -77,7 +77,7 @@ edition = "2021" crate-type = ["dylib"] [dependencies] -powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } +ibig = { version = "0.3.6", features = [] } "#; /// Compiles the given code and returns the path to the From e9d291096af77b529a71bf97a22bb3e0d9a49b95 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 17:09:22 +0000 Subject: [PATCH 23/62] Use native cpu. --- jit-compiler/src/compiler.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 321d39eb87..fef3e5e89a 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -89,6 +89,7 @@ pub fn call_cargo(code: &str) -> (Temp, String) { fs::create_dir(dir.join("src")).unwrap(); fs::write(dir.join("src").join("lib.rs"), code).unwrap(); let out = Command::new("cargo") + .env("RUSTFLAGS", "-C target-cpu=native") .arg("build") .arg("--release") .current_dir(dir.clone()) From 541d8dc9c6fd1f24e88c3698370605f1ae3b4724 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 19 Sep 2024 13:56:55 +0000 Subject: [PATCH 24/62] clippy --- jit-compiler/src/compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index fef3e5e89a..c987661605 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -41,7 +41,7 @@ pub fn create_full_code( .into(); for sym in symbols { let ty = analyzed.type_of_symbol(sym); - if &ty != &int_int_fun { + if ty != int_int_fun { return Err(format!( "Only (int -> int) functions are supported, but requested {}", format_type_scheme_around_name(sym, &Some(ty)), From 5aea6b2f394618c50bd70bc8944a72773c4dc0de Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 23 Sep 2024 13:51:09 +0000 Subject: [PATCH 25/62] merge fix. --- jit-compiler/src/codegen.rs | 7 ++----- jit-compiler/tests/execution.rs | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index f05edd4140..dc49df5626 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -292,7 +292,7 @@ mod test { use super::CodeGenerator; fn compile(input: &str, syms: &[&str]) -> String { - let analyzed = analyze_string::(input); + let analyzed = analyze_string::(input).unwrap(); let mut compiler = CodeGenerator::new(&analyzed); for s in syms { compiler.request_symbol(s).unwrap(); @@ -309,10 +309,7 @@ mod test { #[test] fn simple_fun() { let result = compile("let c: int -> int = |i| i;", &["c"]); - assert_eq!( - result, - "fn c(i: ibig::IBig) -> ibig::IBig { i }\n" - ); + assert_eq!(result, "fn c(i: ibig::IBig) -> ibig::IBig { i }\n"); } #[test] diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index ea705635c7..b20b00db24 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -4,7 +4,7 @@ use powdr_number::GoldilocksField; use powdr_pil_analyzer::analyze_string; fn compile(input: &str, symbol: &str) -> fn(u64) -> u64 { - let analyzed = analyze_string::(input); + let analyzed = analyze_string::(input).unwrap(); powdr_jit_compiler::compile(&analyzed, &[symbol]).unwrap()[symbol] } From 1a6db99908ac40e3ec0e2258b41698a9bb587431 Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 23 Sep 2024 15:05:16 +0000 Subject: [PATCH 26/62] Remove ibig features. --- jit-compiler/src/compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index c987661605..d1f48f5fff 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -77,7 +77,7 @@ edition = "2021" crate-type = ["dylib"] [dependencies] -ibig = { version = "0.3.6", features = [] } +ibig = { version = "0.3.6", features = [], default-features = false } "#; /// Compiles the given code and returns the path to the From cdf3ba99aa80d355934989b443082087f51a60c1 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 00:33:56 +0000 Subject: [PATCH 27/62] Nicer error messages. --- jit-compiler/src/codegen.rs | 54 +++++++++++++++++++++++------------- jit-compiler/src/compiler.rs | 13 +++++---- jit-compiler/src/lib.rs | 2 +- 3 files changed, 43 insertions(+), 26 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index dc49df5626..cad73cc92d 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -83,26 +83,19 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { value: return_type, }), } => { - let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = - &value.e - else { - return Err(format!( - "Expected lambda expression for {symbol}, got {}", - value.e - )); - }; assert!(vars.is_empty()); - format!( - "fn {}({}) -> {} {{ {} }}\n", - escape_symbol(symbol), - params - .iter() - .zip(param_types) - .map(|(p, t)| format!("{}: {}", p, map_type(&t))) - .format(", "), - map_type(return_type.as_ref()), - self.format_expr(body)? - ) + self.try_format_function(symbol, ¶m_types, return_type.as_ref(), &value.e)? + } + TypeScheme { + vars, + ty: Type::Col, + } => { + assert!(vars.is_empty()); + // TODO we assume it is an int -> int function. + // The type inference algorithm should store the derived type. + // Alternatively, we insert a trait conversion function and store the type + // in the trait vars. + self.try_format_function(symbol, &[Type::Int], &Type::Int, &value.e)? } _ => format!( "const {}: {} = {};\n", @@ -113,6 +106,29 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { }) } + fn try_format_function( + &mut self, + name: &str, + param_types: &[Type], + return_type: &Type, + expr: &Expression, + ) -> Result { + let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = expr else { + return Err(format!("Expected lambda expression for {name}, got {expr}",)); + }; + Ok(format!( + "fn {}({}) -> {} {{ {} }}\n", + escape_symbol(name), + params + .iter() + .zip(param_types) + .map(|(p, t)| format!("{}: {}", p, map_type(&t))) + .format(", "), + map_type(return_type), + self.format_expr(body)? + )) + } + fn try_generate_builtin(&self, symbol: &str) -> Option { let code = match symbol { "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index d1f48f5fff..d9ce985502 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -5,6 +5,7 @@ use std::{ ffi::CString, fs::{self}, process::Command, + str::from_utf8, }; use powdr_ast::{ @@ -41,9 +42,9 @@ pub fn create_full_code( .into(); for sym in symbols { let ty = analyzed.type_of_symbol(sym); - if ty != int_int_fun { + if ty != int_int_fun && ty.ty != Type::Col { return Err(format!( - "Only (int -> int) functions are supported, but requested {}", + "Only (int -> int) functions and columns are supported, but requested {}", format_type_scheme_around_name(sym, &Some(ty)), )); } @@ -83,7 +84,7 @@ ibig = { version = "0.3.6", features = [], default-features = false } /// Compiles the given code and returns the path to the /// temporary directory containing the compiled library /// and the path to the compiled library. -pub fn call_cargo(code: &str) -> (Temp, String) { +pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { let dir = mktemp::Temp::new_dir().unwrap(); fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); fs::create_dir(dir.join("src")).unwrap(); @@ -95,15 +96,15 @@ pub fn call_cargo(code: &str) -> (Temp, String) { .current_dir(dir.clone()) .output() .unwrap(); - out.stderr.iter().for_each(|b| print!("{}", *b as char)); if !out.status.success() { - panic!("Failed to compile."); + let stderr = from_utf8(&out.stderr).unwrap_or("UTF-8 error in error message."); + return Err(format!("Failed to compile: {stderr}.")); } let lib_path = dir .join("target") .join("release") .join("libpowdr_jit_compiled.so"); - (dir, lib_path.to_str().unwrap().to_string()) + Ok((dir, lib_path.to_str().unwrap().to_string())) } /// Loads the given library and creates funtion pointers for the given symbols. diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index b88cc02e4c..8d107fafda 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -19,7 +19,7 @@ pub fn compile( log::info!("JIT-compiling {} symbols...", symbols.len()); let code = create_full_code(analyzed, symbols)?; - let (dir, lib_path) = call_cargo(&code); + let (dir, lib_path) = call_cargo(&code)?; let metadata = fs::metadata(&lib_path).unwrap(); log::info!( From 7b0dcebf4c64d0f7c248d61af9b8abb0c0993292 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 12:23:03 +0000 Subject: [PATCH 28/62] Partial compile. --- jit-compiler/src/codegen.rs | 19 ++++++++++++------- jit-compiler/src/compiler.rs | 12 ++++-------- jit-compiler/src/lib.rs | 33 ++++++++++++++++++++++++++------- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index cad73cc92d..e8f2066fe4 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -293,7 +293,12 @@ fn map_type(ty: &Type) -> String { Type::Tuple(_) => todo!(), Type::Function(ft) => todo!("Type {ft}"), Type::TypeVar(tv) => tv.to_string(), - Type::NamedType(_path, _type_args) => todo!(), + Type::NamedType(path, type_args) => { + if type_args.is_some() { + unimplemented!() + } + escape_symbol(&path.to_string()) + } Type::Col | Type::Inter => unreachable!(), } } @@ -335,11 +340,11 @@ mod test { &["c", "d"], ); assert_eq!( - result, - "fn c(i: ibig::IBig) -> ibig::IBig { ((i).clone() + (ibig::IBig::from(20_u64)).clone()) } - -fn d(k: ibig::IBig) -> ibig::IBig { (c)(((k).clone() * (ibig::IBig::from(20_u64)).clone()).clone()) } -" - ); + result, + "fn c(i: ibig::IBig) -> ibig::IBig { ((i).clone() + (ibig::IBig::from(20_u64)).clone()) }\n\ + \n\ + fn d(k: ibig::IBig) -> ibig::IBig { (c)(((k).clone() * (ibig::IBig::from(20_u64)).clone()).clone()) }\n\ + " + ); } } diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index d9ce985502..4ab9a30ba9 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -29,11 +29,10 @@ const PREAMBLE: &str = r#" //type FieldElement = powdr_number::GoldilocksField; "#; -pub fn create_full_code( - analyzed: &Analyzed, +pub fn generate_glue_code( symbols: &[&str], + analyzed: &Analyzed, ) -> Result { - let mut codegen = CodeGenerator::new(analyzed); let mut glue = String::new(); let int_int_fun: TypeScheme = Type::Function(FunctionType { params: vec![Type::Int], @@ -48,7 +47,7 @@ pub fn create_full_code( format_type_scheme_around_name(sym, &Some(ty)), )); } - codegen.request_symbol(sym)?; + // TODO we should use big int instead of u64 let name = escape_symbol(sym); glue.push_str(&format!( @@ -62,10 +61,7 @@ pub fn create_full_code( )); } - Ok(format!( - "{PREAMBLE}\n{}\n{glue}\n", - codegen.compiled_symbols() - )) + Ok(format!("{PREAMBLE}\n{glue}\n",)) } const CARGO_TOML: &str = r#" diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 8d107fafda..8c05ba67b8 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -3,7 +3,8 @@ mod compiler; use std::{collections::HashMap, fs}; -use compiler::{call_cargo, create_full_code, load_library}; +use codegen::CodeGenerator; +use compiler::{call_cargo, generate_glue_code, load_library}; use powdr_ast::analyzed::Analyzed; use powdr_number::FieldElement; @@ -14,12 +15,30 @@ pub type SymbolMap = HashMap u64>; /// Only functions of type (int -> int) are supported for now. pub fn compile( analyzed: &Analyzed, - symbols: &[&str], + requested_symbols: &[&str], ) -> Result { - log::info!("JIT-compiling {} symbols...", symbols.len()); - let code = create_full_code(analyzed, symbols)?; - - let (dir, lib_path) = call_cargo(&code)?; + log::info!("JIT-compiling {} symbols...", requested_symbols.len()); + + let mut codegen = CodeGenerator::new(analyzed); + let successful_symbols = requested_symbols + .into_iter() + .filter_map(|&sym| { + if let Err(e) = codegen.request_symbol(sym) { + log::warn!("Unable to generate code for symbol {sym}: {e}"); + None + } else { + Some(sym) + } + }) + .collect::>(); + + if successful_symbols.is_empty() { + return Ok(Default::default()); + }; + + let glue_code = generate_glue_code(&successful_symbols, analyzed)?; + + let (dir, lib_path) = call_cargo(&format!("{glue_code}\n{}\n", codegen.compiled_symbols()))?; let metadata = fs::metadata(&lib_path).unwrap(); log::info!( @@ -27,7 +46,7 @@ pub fn compile( metadata.len() as f64 / 1000000.0 ); - let result = load_library(&lib_path, symbols); + let result = load_library(&lib_path, &successful_symbols); log::info!("Done."); drop(dir); From 76e63890536c1966ffe7bf1a280446bde23dbf63 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 12:38:20 +0000 Subject: [PATCH 29/62] clippy --- jit-compiler/Cargo.toml | 2 -- jit-compiler/src/codegen.rs | 2 +- jit-compiler/src/compiler.rs | 2 +- jit-compiler/src/lib.rs | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/jit-compiler/Cargo.toml b/jit-compiler/Cargo.toml index f3330c1d9d..2ae4e3c8ac 100644 --- a/jit-compiler/Cargo.toml +++ b/jit-compiler/Cargo.toml @@ -23,7 +23,5 @@ pretty_assertions = "1.4.0" test-log = "0.2.12" env_logger = "0.10.0" - - [lints.clippy] uninlined_format_args = "deny" diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index e8f2066fe4..368ea1c26d 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -122,7 +122,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { params .iter() .zip(param_types) - .map(|(p, t)| format!("{}: {}", p, map_type(&t))) + .map(|(p, t)| format!("{p}: {}", map_type(t))) .format(", "), map_type(return_type), self.format_expr(body)? diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 4ab9a30ba9..fd5504f627 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -18,7 +18,7 @@ use powdr_ast::{ use powdr_number::FieldElement; use crate::{ - codegen::{escape_symbol, CodeGenerator}, + codegen::{escape_symbol}, SymbolMap, }; diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 8c05ba67b8..27d7f7ad37 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -21,7 +21,7 @@ pub fn compile( let mut codegen = CodeGenerator::new(analyzed); let successful_symbols = requested_symbols - .into_iter() + .iter() .filter_map(|&sym| { if let Err(e) = codegen.request_symbol(sym) { log::warn!("Unable to generate code for symbol {sym}: {e}"); From 00b56e6ae45594304bf85724e383b0b17f84176b Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 12:44:43 +0000 Subject: [PATCH 30/62] fmt --- jit-compiler/src/compiler.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index fd5504f627..bafa3dea9a 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -17,10 +17,7 @@ use powdr_ast::{ }; use powdr_number::FieldElement; -use crate::{ - codegen::{escape_symbol}, - SymbolMap, -}; +use crate::{codegen::escape_symbol, SymbolMap}; // TODO make this depend on T From 3544d576a37dec02b572752023bb67cec82dd845 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 15:39:09 +0200 Subject: [PATCH 31/62] Update jit-compiler/src/lib.rs --- jit-compiler/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 27d7f7ad37..4656bb6f28 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -42,7 +42,7 @@ pub fn compile( let metadata = fs::metadata(&lib_path).unwrap(); log::info!( - "Loading library with size {} MB...", + "Loading library of size {} MB...", metadata.len() as f64 / 1000000.0 ); From 5cf259cc20ad77bee86682b7c54ec4e072f62ca4 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 16:19:31 +0000 Subject: [PATCH 32/62] Portability. --- jit-compiler/src/compiler.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index bafa3dea9a..1ecfd885b1 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -68,7 +68,7 @@ version = "0.1.0" edition = "2021" [lib] -crate-type = ["dylib"] +crate-type = ["cdylib"] [dependencies] ibig = { version = "0.3.6", features = [], default-features = false } @@ -93,10 +93,17 @@ pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { let stderr = from_utf8(&out.stderr).unwrap_or("UTF-8 error in error message."); return Err(format!("Failed to compile: {stderr}.")); } + let extension = if cfg!(target_os = "windows") { + "dll" + } else if cfg!(target_os = "macos") { + "dylib" + } else { + "so" + }; let lib_path = dir .join("target") .join("release") - .join("libpowdr_jit_compiled.so"); + .join(&format!("libpowdr_jit_compiled.{extension}")); Ok((dir, lib_path.to_str().unwrap().to_string())) } From e928a376cb8818425d5e459bef6eb638deccde88 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 18:21:00 +0200 Subject: [PATCH 33/62] Update jit-compiler/tests/execution.rs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Gastón Zanitti --- jit-compiler/tests/execution.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index b20b00db24..89d4dcd753 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -14,7 +14,11 @@ fn identity_function() { assert_eq!(f(10), 10); } - +#[test] +#[should_panic = "Only (int -> int) functions and columns are supported, but requested c: int -> bool"] +fn invalid_function() { + let _ = compile("let c: int -> bool = |i| true;", "c"); +} #[test] fn sqrt() { let f = compile( From 86b0a2df7e3b1da550d7cd8e98e8dacf4803e685 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 16:22:54 +0000 Subject: [PATCH 34/62] fix error message. --- jit-compiler/src/compiler.rs | 2 +- jit-compiler/tests/execution.rs | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 1ecfd885b1..27127d5e05 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -40,7 +40,7 @@ pub fn generate_glue_code( let ty = analyzed.type_of_symbol(sym); if ty != int_int_fun && ty.ty != Type::Col { return Err(format!( - "Only (int -> int) functions and columns are supported, but requested {}", + "Only (int -> int) functions and columns are supported, but requested{}", format_type_scheme_around_name(sym, &Some(ty)), )); } diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index 89d4dcd753..e85c4e8fe5 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -14,11 +14,7 @@ fn identity_function() { assert_eq!(f(10), 10); } -#[test] -#[should_panic = "Only (int -> int) functions and columns are supported, but requested c: int -> bool"] -fn invalid_function() { - let _ = compile("let c: int -> bool = |i| true;", "c"); -} + #[test] fn sqrt() { let f = compile( @@ -41,3 +37,9 @@ fn sqrt() { assert_eq!(f(99), 9); assert_eq!(f(0), 0); } + +#[test] +#[should_panic = "Only (int -> int) functions and columns are supported, but requested c: int -> bool"] +fn invalid_function() { + let _ = compile("let c: int -> bool = |i| true;", "c"); +} From 3327e520862b10d7d10239233e2369b51f4dd683 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 16:42:02 +0000 Subject: [PATCH 35/62] clippy --- jit-compiler/src/compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 27127d5e05..1e9f3ebf9a 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -103,7 +103,7 @@ pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { let lib_path = dir .join("target") .join("release") - .join(&format!("libpowdr_jit_compiled.{extension}")); + .join(format!("libpowdr_jit_compiled.{extension}")); Ok((dir, lib_path.to_str().unwrap().to_string())) } From 2a456531b23a15320e34b2439527a7d0ee980d05 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 10:44:36 +0000 Subject: [PATCH 36/62] Use extern c. --- jit-compiler/src/compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 1e9f3ebf9a..fa61efc2df 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -50,7 +50,7 @@ pub fn generate_glue_code( glue.push_str(&format!( r#" #[no_mangle] - pub extern fn {}(i: u64) -> u64 {{ + pub extern "C" fn {}(i: u64) -> u64 {{ u64::try_from({name}(ibig::IBig::from(i))).unwrap() }} "#, From 5647eb6553ca050cbcbe73c5ef053c2be8ea0af7 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 13:40:49 +0000 Subject: [PATCH 37/62] Use libloading. --- jit-compiler/Cargo.toml | 2 +- jit-compiler/src/compiler.rs | 44 +++++++++++++------------ jit-compiler/src/lib.rs | 22 ++++++++++--- jit-compiler/tests/execution.rs | 19 ++++++----- pipeline/benches/evaluator_benchmark.rs | 4 +-- 5 files changed, 54 insertions(+), 37 deletions(-) diff --git a/jit-compiler/Cargo.toml b/jit-compiler/Cargo.toml index 2ae4e3c8ac..5be80611ba 100644 --- a/jit-compiler/Cargo.toml +++ b/jit-compiler/Cargo.toml @@ -12,10 +12,10 @@ powdr-ast.workspace = true powdr-number.workspace = true powdr-parser.workspace = true -libc = "0.2.0" log = "0.4.18" mktemp = "0.5.0" itertools = "0.13" +libloading = "0.8" [dev-dependencies] powdr-pil-analyzer.workspace = true diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index fa61efc2df..51faccd461 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1,11 +1,10 @@ -use libc::{c_void, dlopen, dlsym, RTLD_NOW}; use mktemp::Temp; use std::{ collections::HashMap, - ffi::CString, fs::{self}, process::Command, str::from_utf8, + sync::Arc, }; use powdr_ast::{ @@ -17,7 +16,7 @@ use powdr_ast::{ }; use powdr_number::FieldElement; -use crate::{codegen::escape_symbol, SymbolMap}; +use crate::{codegen::escape_symbol, LoadedFunction}; // TODO make this depend on T @@ -108,24 +107,27 @@ pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { } /// Loads the given library and creates funtion pointers for the given symbols. -pub fn load_library(path: &str, symbols: &[&str]) -> Result { - let c_path = CString::new(path).unwrap(); - let lib = unsafe { dlopen(c_path.as_ptr(), RTLD_NOW) }; - if lib.is_null() { - return Err(format!("Failed to load library: {path:?}")); - } - let mut result = HashMap::new(); - for sym in symbols { - let extern_sym = extern_symbol_name(sym); - let sym_cstr = CString::new(extern_sym).unwrap(); - let fun_ptr = unsafe { dlsym(lib, sym_cstr.as_ptr()) }; - if fun_ptr.is_null() { - return Err(format!("Failed to load symbol: {fun_ptr:?}")); - } - let fun = unsafe { std::mem::transmute::<*mut c_void, fn(u64) -> u64>(fun_ptr) }; - result.insert(sym.to_string(), fun); - } - Ok(result) +pub fn load_library( + path: &str, + symbols: &[&str], +) -> Result, String> { + let library = Arc::new( + unsafe { libloading::Library::new(path) } + .map_err(|e| format!("Error loading library at {path}: {e}"))?, + ); + symbols + .iter() + .map(|&sym| { + let extern_sym = extern_symbol_name(sym); + let function = *unsafe { library.get:: u64>(extern_sym.as_bytes()) } + .map_err(|e| format!("Error accessing symbol {sym}: {e}"))?; + let fun = LoadedFunction { + library: library.clone(), + function, + }; + Ok((sym.to_string(), fun)) + }) + .collect::>() } fn extern_symbol_name(sym: &str) -> String { diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 4656bb6f28..b6c12b60a9 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,22 +1,36 @@ mod codegen; mod compiler; -use std::{collections::HashMap, fs}; +use std::{collections::HashMap, fs, sync::Arc}; use codegen::CodeGenerator; use compiler::{call_cargo, generate_glue_code, load_library}; + use powdr_ast::analyzed::Analyzed; use powdr_number::FieldElement; -pub type SymbolMap = HashMap u64>; +/// Wrapper around a dynamically loaded function. +/// Prevents the dynamically loaded library to be unloaded while the function is still in use. +#[derive(Clone)] +pub struct LoadedFunction { + #[allow(dead_code)] + library: Arc, + function: fn(u64) -> u64, +} + +impl LoadedFunction { + pub fn call(&self, arg: u64) -> u64 { + (self.function)(arg) + } +} /// Compiles the given symbols (and their dependencies) and returns them as a map -/// from symbol name to function pointer. +/// from symbol name to function. /// Only functions of type (int -> int) are supported for now. pub fn compile( analyzed: &Analyzed, requested_symbols: &[&str], -) -> Result { +) -> Result, String> { log::info!("JIT-compiling {} symbols...", requested_symbols.len()); let mut codegen = CodeGenerator::new(analyzed); diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index e85c4e8fe5..f6d80c7bed 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -1,18 +1,19 @@ +use powdr_jit_compiler::LoadedFunction; use test_log::test; use powdr_number::GoldilocksField; use powdr_pil_analyzer::analyze_string; -fn compile(input: &str, symbol: &str) -> fn(u64) -> u64 { +fn compile(input: &str, symbol: &str) -> LoadedFunction { let analyzed = analyze_string::(input).unwrap(); - powdr_jit_compiler::compile(&analyzed, &[symbol]).unwrap()[symbol] + powdr_jit_compiler::compile(&analyzed, &[symbol]).unwrap()[symbol].clone() } #[test] fn identity_function() { let f = compile("let c: int -> int = |i| i;", "c"); - assert_eq!(f(10), 10); + assert_eq!(f.call(10), 10); } #[test] @@ -30,12 +31,12 @@ fn sqrt() { "sqrt", ); - assert_eq!(f(9), 3); - assert_eq!(f(100), 10); - assert_eq!(f(8), 2); - assert_eq!(f(101), 10); - assert_eq!(f(99), 9); - assert_eq!(f(0), 0); + assert_eq!(f.call(9), 3); + assert_eq!(f.call(100), 10); + assert_eq!(f.call(8), 2); + assert_eq!(f.call(101), 10); + assert_eq!(f.call(99), 9); + assert_eq!(f.call(0), 0); } #[test] diff --git a/pipeline/benches/evaluator_benchmark.rs b/pipeline/benches/evaluator_benchmark.rs index 76d027cc23..b0d67fee16 100644 --- a/pipeline/benches/evaluator_benchmark.rs +++ b/pipeline/benches/evaluator_benchmark.rs @@ -132,13 +132,13 @@ fn jit_benchmark(c: &mut Criterion) { pipeline.compute_analyzed_pil().unwrap().clone() }; - let sqrt_fun = powdr_jit_compiler::compile(&sqrt_analyzed, &["sqrt"]).unwrap()["sqrt"]; + let sqrt_fun = &powdr_jit_compiler::compile(&sqrt_analyzed, &["sqrt"]).unwrap()["sqrt"]; for x in [879882356, 1882356, 1187956, 56] { group.bench_with_input(format!("sqrt_{x}"), &x, |b, &x| { b.iter(|| { let y = (x as u64) * 112655675; - sqrt_fun(y); + sqrt_fun.call(y); }); }); } From 540671d8889cf8183a5a88a41ae1a08aa977bc9c Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 15:43:08 +0200 Subject: [PATCH 38/62] Update jit-compiler/src/compiler.rs Co-authored-by: Georg Wiese --- jit-compiler/src/compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 51faccd461..4c506876f4 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -106,7 +106,7 @@ pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { Ok((dir, lib_path.to_str().unwrap().to_string())) } -/// Loads the given library and creates funtion pointers for the given symbols. +/// Loads the given library and creates function pointers for the given symbols. pub fn load_library( path: &str, symbols: &[&str], From 2937860c3fe2abcd09970655ade71c9ee7d489a0 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 13:55:50 +0000 Subject: [PATCH 39/62] Extract sqrt code. --- pipeline/benches/evaluator_benchmark.rs | 34 +++++++++---------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/pipeline/benches/evaluator_benchmark.rs b/pipeline/benches/evaluator_benchmark.rs index b0d67fee16..a55522dc64 100644 --- a/pipeline/benches/evaluator_benchmark.rs +++ b/pipeline/benches/evaluator_benchmark.rs @@ -9,6 +9,16 @@ use powdr_pipeline::test_util::{evaluate_function, evaluate_integer_function, st use criterion::{criterion_group, criterion_main, Criterion}; +const SQRT_CODE: &str = " + let sqrt: int -> int = |x| sqrt_rec(x, x); + let sqrt_rec: int, int -> int = |y, x| + if y * y <= x && (y + 1) * (y + 1) > x { + y + } else { + sqrt_rec((y + x / y) / 2, x) + }; +"; + fn evaluator_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("evaluator-benchmark"); @@ -67,17 +77,7 @@ fn evaluator_benchmark(c: &mut Criterion) { }); let sqrt_analyzed: Analyzed = { - let code = " - let sqrt: int -> int = |x| sqrt_rec(x, x); - let sqrt_rec: int, int -> int = |y, x| - if y * y <= x && (y + 1) * (y + 1) > x { - y - } else { - sqrt_rec((y + x / y) / 2, x) - }; - " - .to_string(); - let mut pipeline = Pipeline::default().from_asm_string(code, None); + let mut pipeline = Pipeline::default().from_asm_string(SQRT_CODE.to_string(), None); pipeline.compute_analyzed_pil().unwrap().clone() }; @@ -118,17 +118,7 @@ fn jit_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("jit-benchmark"); let sqrt_analyzed: Analyzed = { - let code = " - let sqrt: int -> int = |x| sqrt_rec(x, x); - let sqrt_rec: int, int -> int = |y, x| - if y * y <= x && (y + 1) * (y + 1) > x { - y - } else { - sqrt_rec((y + x / y) / 2, x) - }; - " - .to_string(); - let mut pipeline = Pipeline::default().from_asm_string(code, None); + let mut pipeline = Pipeline::default().from_asm_string(SQRT_CODE.to_string(), None); pipeline.compute_analyzed_pil().unwrap().clone() }; From 681937aa094104e5190cac281139d67e62034d96 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 15:03:51 +0000 Subject: [PATCH 40/62] Remove drop. --- jit-compiler/src/lib.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index b6c12b60a9..77216bb84b 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -62,7 +62,5 @@ pub fn compile( let result = load_library(&lib_path, &successful_symbols); log::info!("Done."); - - drop(dir); result } From 6a96b72831d6c6588a40922f8067039519e3e019 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 15:05:11 +0000 Subject: [PATCH 41/62] Add release - we need the variable. --- jit-compiler/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 77216bb84b..d530345bab 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -62,5 +62,7 @@ pub fn compile( let result = load_library(&lib_path, &successful_symbols); log::info!("Done."); + + dir.release(); result } From 8db0eba442031659f1b6353605f25097d7e909bb Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 15:08:55 +0000 Subject: [PATCH 42/62] Encapsulate temp dir in struct. --- jit-compiler/src/compiler.rs | 14 ++++++++++++-- jit-compiler/src/lib.rs | 8 +++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 4c506876f4..c864d3ba37 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -73,10 +73,17 @@ crate-type = ["cdylib"] ibig = { version = "0.3.6", features = [], default-features = false } "#; +pub struct PathInTempDir { + #[allow(dead_code)] + dir: Temp, + /// The absolute path + pub path: String, +} + /// Compiles the given code and returns the path to the /// temporary directory containing the compiled library /// and the path to the compiled library. -pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { +pub fn call_cargo(code: &str) -> Result { let dir = mktemp::Temp::new_dir().unwrap(); fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); fs::create_dir(dir.join("src")).unwrap(); @@ -103,7 +110,10 @@ pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { .join("target") .join("release") .join(format!("libpowdr_jit_compiled.{extension}")); - Ok((dir, lib_path.to_str().unwrap().to_string())) + Ok(PathInTempDir { + dir, + path: lib_path.to_str().unwrap().to_string(), + }) } /// Loads the given library and creates function pointers for the given symbols. diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index d530345bab..a8679a962c 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -52,17 +52,15 @@ pub fn compile( let glue_code = generate_glue_code(&successful_symbols, analyzed)?; - let (dir, lib_path) = call_cargo(&format!("{glue_code}\n{}\n", codegen.compiled_symbols()))?; - let metadata = fs::metadata(&lib_path).unwrap(); + let lib_file = call_cargo(&format!("{glue_code}\n{}\n", codegen.compiled_symbols()))?; + let metadata = fs::metadata(&lib_file.path).unwrap(); log::info!( "Loading library of size {} MB...", metadata.len() as f64 / 1000000.0 ); - let result = load_library(&lib_path, &successful_symbols); + let result = load_library(&lib_file.path, &successful_symbols); log::info!("Done."); - - dir.release(); result } From b47df417bff1a16faa105191b8e7489b1b7184de Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 16:03:25 +0000 Subject: [PATCH 43/62] Simplify compiler state. --- jit-compiler/src/codegen.rs | 55 ++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 368ea1c26d..df0164d5b0 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use itertools::Itertools; use powdr_ast::{ @@ -14,45 +14,56 @@ use powdr_number::FieldElement; pub struct CodeGenerator<'a, T> { analyzed: &'a Analyzed, - requested: HashSet, - failed: HashMap, - symbols: HashMap, + /// Symbols mapping to either their code or an error message explaining + /// why they could not be compiled. + /// While the code is still being generated, this contains `None`. + symbols: HashMap, String>>, } impl<'a, T: FieldElement> CodeGenerator<'a, T> { pub fn new(analyzed: &'a Analyzed) -> Self { Self { analyzed, - requested: Default::default(), - failed: Default::default(), symbols: Default::default(), } } + /// Request a symbol to be compiled. The code can later be retrieved + /// via `compiled_symbols`. + /// In the error case, `self` can still be used to compile other symbols. pub fn request_symbol(&mut self, name: &str) -> Result<(), String> { - if let Some(err) = self.failed.get(name) { - return Err(err.clone()); - } - if self.requested.contains(name) { - return Ok(()); - } - self.requested.insert(name.to_string()); - match self.generate_code(name) { - Ok(code) => { - self.symbols.insert(name.to_string(), code); - Ok(()) - } - Err(err) => { - let err = format!("Failed to compile {name}: {err}"); - self.failed.insert(name.to_string(), err.clone()); - Err(err) + match self.symbols.get(name) { + Some(Ok(_)) => Ok(()), + Some(Err(e)) => Err(e.clone()), + None => { + let name = name.to_string(); + self.symbols.insert(name.clone(), Ok(None)); + let to_insert; + let to_return; + match self.generate_code(&name) { + Ok(code) => { + to_insert = Ok(Some(code)); + to_return = Ok(()); + } + Err(err) => { + to_insert = Err(err.clone()); + to_return = Err(err); + } + } + self.symbols.insert(name, to_insert); + to_return } } } + /// Returns the concatenation of all successfully compiled symbols. pub fn compiled_symbols(self) -> String { self.symbols .into_iter() + .filter_map(|(s, r)| match r { + Ok(Some(code)) => Some((s, code)), + _ => None, + }) .sorted() .map(|(_, code)| code) .format("\n") From bdc2d38b83e973bbf513369f73a7662347c91eb9 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 16:04:22 +0000 Subject: [PATCH 44/62] Error message. --- jit-compiler/src/codegen.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index df0164d5b0..99b9367e76 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -284,7 +284,9 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } fn format_statement(&mut self, s: &StatementInsideBlock) -> Result { - Err(format!("Implement {s}")) + Err(format!( + "Compiling statements inside blocks is not yet implemented: {s}" + )) } } From 8c23119a0ee6626726ce12fe93309ac83de4c638 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 19:32:42 +0000 Subject: [PATCH 45/62] use unsafe extern C fn --- jit-compiler/src/compiler.rs | 5 +++-- jit-compiler/src/lib.rs | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index c864d3ba37..a92476b61f 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -129,8 +129,9 @@ pub fn load_library( .iter() .map(|&sym| { let extern_sym = extern_symbol_name(sym); - let function = *unsafe { library.get:: u64>(extern_sym.as_bytes()) } - .map_err(|e| format!("Error accessing symbol {sym}: {e}"))?; + let function = + *unsafe { library.get:: u64>(extern_sym.as_bytes()) } + .map_err(|e| format!("Error accessing symbol {sym}: {e}"))?; let fun = LoadedFunction { library: library.clone(), function, diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index a8679a962c..465663e46b 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -15,12 +15,12 @@ use powdr_number::FieldElement; pub struct LoadedFunction { #[allow(dead_code)] library: Arc, - function: fn(u64) -> u64, + function: unsafe extern "C" fn(u64) -> u64, } impl LoadedFunction { pub fn call(&self, arg: u64) -> u64 { - (self.function)(arg) + unsafe { (self.function)(arg) } } } From 9501ef9b9525765eae70911c23738de89779602e Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 19:33:34 +0000 Subject: [PATCH 46/62] use mebibytes. --- jit-compiler/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 465663e46b..fc72d3789e 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -57,7 +57,7 @@ pub fn compile( log::info!( "Loading library of size {} MB...", - metadata.len() as f64 / 1000000.0 + metadata.len() as f64 / (1024.0 * 1024.0) ); let result = load_library(&lib_file.path, &successful_symbols); From 1b44c1c83d4eeac41a8d945447df5ca8fdeedf40 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 16:39:52 +0000 Subject: [PATCH 47/62] Match expressions. --- jit-compiler/src/codegen.rs | 52 ++++++++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 99b9367e76..66b6c050fd 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -7,7 +7,8 @@ use powdr_ast::{ display::{format_type_args, quote}, types::{ArrayType, FunctionType, Type, TypeScheme}, ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, - IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, + IndexAccess, LambdaExpression, MatchArm, MatchExpression, Number, Pattern, + StatementInsideBlock, UnaryOperation, }, }; use powdr_number::FieldElement; @@ -279,6 +280,22 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { .unwrap_or_default() ) } + Expression::MatchExpression(_, MatchExpression { scrutinee, arms }) => { + format!( + "match {} {{\n{}\n}}", + self.format_expr(scrutinee)?, + arms.iter() + .map(|MatchArm { pattern, value }| { + Ok(format!( + "{} => {},", + format_pattern(pattern), + self.format_expr(value)?, + )) + }) + .collect::, String>>()? + .join("\n") + ) + } _ => return Err(format!("Implement {e}")), }) } @@ -290,6 +307,31 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } } +fn format_pattern(pattern: &Pattern) -> String { + match pattern { + Pattern::CatchAll(_) => "_".to_string(), + Pattern::Ellipsis(_) => "..".to_string(), + Pattern::Number(_, n) => { + // TODO this should probably fail if the number is too large. + n.to_string() + } + Pattern::String(_, s) => quote(s), + Pattern::Tuple(_, items) => { + format!("({})", items.iter().map(format_pattern).join(", ")) + } + Pattern::Array(_, items) => { + format!("[{}]", items.iter().map(format_pattern).join(", ")) + } + Pattern::Variable(_, var) => var.clone(), + Pattern::Enum(_, name, None) => escape_symbol(&name.to_string()), + Pattern::Enum(_, name, Some(fields)) => format!( + "{}({})", + escape_symbol(&name.to_string()), + fields.iter().map(format_pattern).join(", ") + ), + } +} + pub fn escape_symbol(s: &str) -> String { // TODO better escaping s.replace('.', "_").replace("::", "_") @@ -307,10 +349,12 @@ fn map_type(ty: &Type) -> String { Type::Function(ft) => todo!("Type {ft}"), Type::TypeVar(tv) => tv.to_string(), Type::NamedType(path, type_args) => { - if type_args.is_some() { - unimplemented!() + let name = escape_symbol(&path.to_string()); + if let Some(type_args) = type_args { + format!("{name}::<{}>", type_args.iter().map(map_type).join(", ")) + } else { + name } - escape_symbol(&path.to_string()) } Type::Col | Type::Inter => unreachable!(), } From 3cf303ba0f5ce746931c7a24150e56b3c87beea7 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 09:41:58 +0000 Subject: [PATCH 48/62] test. --- jit-compiler/tests/execution.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index f6d80c7bed..964aa52185 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -39,6 +39,24 @@ fn sqrt() { assert_eq!(f.call(0), 0); } +#[test] +fn match_expr() { + let f = compile( + "let f: int -> int = |x| match (x, [1, x, 3]) { + (0, [.., 3]) => 1, + (1, [1, ..]) => 2, + (2, [1, 2, 3]) => 3, + (_, [1, _, 3]) => 0, + };", + "f", + ); + + assert_eq!(f.call(0), 1); + assert_eq!(f.call(1), 2); + assert_eq!(f.call(2), 3); + assert_eq!(f.call(3), 0); +} + #[test] #[should_panic = "Only (int -> int) functions and columns are supported, but requested c: int -> bool"] fn invalid_function() { From 0d7df728099da86339bae06b954bd3c8b8318936 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 11:28:26 +0000 Subject: [PATCH 49/62] better match exprs --- jit-compiler/src/codegen.rs | 111 +++++++++++++++++++++++++------- jit-compiler/tests/execution.rs | 20 ++++-- 2 files changed, 101 insertions(+), 30 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 66b6c050fd..59010c780e 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use itertools::Itertools; +use itertools::{multiunzip, Itertools}; use powdr_ast::{ analyzed::{Analyzed, Expression, FunctionValueDefinition, PolynomialReference, Reference}, parsed::{ @@ -281,19 +281,21 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { ) } Expression::MatchExpression(_, MatchExpression { scrutinee, arms }) => { + // TODO try to find a solution where we do not introduce a variable + // or at least make it unique. format!( - "match {} {{\n{}\n}}", + "{{\nlet scrutinee__ = {};\n{}\n}}\n", self.format_expr(scrutinee)?, arms.iter() .map(|MatchArm { pattern, value }| { + let (vars, code) = check_pattern(pattern)?; Ok(format!( - "{} => {},", - format_pattern(pattern), + "if let Some({vars}) = ({code})(scrutinee__.clone()) {{\n{}\n}}", self.format_expr(value)?, )) }) .collect::, String>>()? - .join("\n") + .join(" else ") ) } _ => return Err(format!("Implement {e}")), @@ -307,29 +309,55 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } } -fn format_pattern(pattern: &Pattern) -> String { - match pattern { - Pattern::CatchAll(_) => "_".to_string(), - Pattern::Ellipsis(_) => "..".to_string(), +/// Returns string of tuples with var names (capturing) and code. +/// TODO +/// the ellipsis represents code that tries to match the given pattern. +/// This function is used when generating code for match expressions. +fn check_pattern(pattern: &Pattern) -> Result<(String, String), String> { + Ok(match pattern { + Pattern::CatchAll(_) => ("()".to_string(), "|_| Some(())".to_string()), Pattern::Number(_, n) => { - // TODO this should probably fail if the number is too large. - n.to_string() + // TODO format large n properly. + ( + "_".to_string(), + format!("|s| (s == ibig::IBig::from({n})).then_some(())"), + ) } - Pattern::String(_, s) => quote(s), + Pattern::String(_, s) => ( + "_".to_string(), + format!("|s| (&s == {}).then_some(())", quote(s)), + ), Pattern::Tuple(_, items) => { - format!("({})", items.iter().map(format_pattern).join(", ")) + let mut vars = vec![]; + let inner_code = items + .iter() + .enumerate() + .map(|(i, item)| { + let (v, code) = check_pattern(item)?; + vars.push(v.clone()); + Ok(format!("let r_{i} = ({code})(s.clone())?;")) + }) + .collect::, String>>()? + .join("\n"); + let code = format!( + "|s| {{\n{inner_code}\nSome(({}))\n}}", + items + .iter() + .enumerate() + .map(|(i, _)| format!("r_{i}")) + .format(", ") + ); + (format!("({})", vars.join(", ")), code) } - Pattern::Array(_, items) => { - format!("[{}]", items.iter().map(format_pattern).join(", ")) + Pattern::Array(..) => { + return Err(format!("Arrays as patterns not yet implemented: {pattern}")); } - Pattern::Variable(_, var) => var.clone(), - Pattern::Enum(_, name, None) => escape_symbol(&name.to_string()), - Pattern::Enum(_, name, Some(fields)) => format!( - "{}({})", - escape_symbol(&name.to_string()), - fields.iter().map(format_pattern).join(", ") - ), - } + Pattern::Variable(_, var) => (format!("{var}"), "|s| Some(s)".to_string()), + Pattern::Enum(..) => { + return Err(format!("Enums as patterns not yet implemented: {pattern}")); + } + Pattern::Ellipsis(_) => unreachable!(), + }) } pub fn escape_symbol(s: &str) -> String { @@ -404,4 +432,41 @@ mod test { " ); } + + #[test] + fn match_exprs() { + let result = compile( + r#"let c: int -> int = |i| match (i, "abc") { (_, "") => 1, (8, v) => 2, (x, _) => x, _ => 5 };"#, + &["c"], + ); + assert_eq!( + result, + r#"fn c(i: ibig::IBig) -> ibig::IBig { { +let scrutinee__ = (i, "abc"); +if let Some(((), _)) = (|s| { +let r_0 = (|_| Some(()))(s.clone())?; +let r_1 = (|s| (&s == "").then_some(()))(s.clone())?; +Some((r_0, r_1)) +})(scrutinee__.clone()) { +ibig::IBig::from(1_u64) +} else if let Some((_, v)) = (|s| { +let r_0 = (|s| s == ibig::IBig::from(8).then_some(()))(s.clone())?; +let r_1 = (|s| Some(s))(s.clone())?; +Some((r_0, r_1)) +})(scrutinee__.clone()) { +ibig::IBig::from(2_u64) +} else if let Some((x, ())) = (|s| { +let r_0 = (|s| Some(s))(s.clone())?; +let r_1 = (|_| Some(()))(s.clone())?; +Some((r_0, r_1)) +})(scrutinee__.clone()) { +x +} else if let Some(()) = (|_| Some(()))(scrutinee__.clone()) { +ibig::IBig::from(5_u64) +} +} + } +"# + ); + } } diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index 964aa52185..d340506105 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -6,7 +6,13 @@ use powdr_pil_analyzer::analyze_string; fn compile(input: &str, symbol: &str) -> LoadedFunction { let analyzed = analyze_string::(input).unwrap(); - powdr_jit_compiler::compile(&analyzed, &[symbol]).unwrap()[symbol].clone() + powdr_jit_compiler::compile(&analyzed, &[symbol]) + .map_err(|e| { + eprintln!("{e}"); + e + }) + .unwrap()[symbol] + .clone() } #[test] @@ -42,12 +48,12 @@ fn sqrt() { #[test] fn match_expr() { let f = compile( - "let f: int -> int = |x| match (x, [1, x, 3]) { - (0, [.., 3]) => 1, - (1, [1, ..]) => 2, - (2, [1, 2, 3]) => 3, - (_, [1, _, 3]) => 0, - };", + r#"let f: int -> int = |x| match (x, ("abc", x + 3)) { + (0, _) => 1, + (1, ("ab", _)) => 2, + (1, ("abc", t)) => t, + (a, (_, b)) => a + b, + };"#, "f", ); From f22c2bbed791fe3a2223c00805d8037f95330b02 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 11:56:42 +0000 Subject: [PATCH 50/62] comment --- jit-compiler/src/codegen.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 59010c780e..32bd800176 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -329,6 +329,7 @@ fn check_pattern(pattern: &Pattern) -> Result<(String, String), String> { ), Pattern::Tuple(_, items) => { let mut vars = vec![]; + // TODO we need to de-structure s! let inner_code = items .iter() .enumerate() From ed0c9859a5741dc9f57e0006d1771d21dd30d5fd Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 13:09:18 +0000 Subject: [PATCH 51/62] simple match test. --- jit-compiler/src/codegen.rs | 1 + jit-compiler/tests/execution.rs | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 32bd800176..1651a954ac 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -294,6 +294,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { self.format_expr(value)?, )) }) + .chain(std::iter::once(Ok("{ panic!(\"No match\"); }".to_string()))) .collect::, String>>()? .join(" else ") ) diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index d340506105..a8e9076914 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -45,6 +45,24 @@ fn sqrt() { assert_eq!(f.call(0), 0); } +#[test] +fn simple_match() { + let f = compile( + r#"let f: int -> int = |x| match x { + 0 => 1, + 1 => 2, + 2 => 3, + _ => 0, + };"#, + "f", + ); + + assert_eq!(f.call(0), 1); + assert_eq!(f.call(1), 2); + assert_eq!(f.call(2), 3); + assert_eq!(f.call(3), 0); +} + #[test] fn match_expr() { let f = compile( From 01dcb4b03b0af5f6aa2ff8ed800d85d58e360520 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 13:30:41 +0000 Subject: [PATCH 52/62] match expressions for tuples. --- jit-compiler/src/codegen.rs | 41 +++++++++++++++------------------ jit-compiler/src/compiler.rs | 5 ++++ jit-compiler/tests/execution.rs | 23 ++++++++++++++---- 3 files changed, 42 insertions(+), 27 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 1651a954ac..b730a66f85 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -262,8 +262,8 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { "({})", items .iter() - .map(|i| self.format_expr(i)) - .collect::, _>>()? + .map(|i| Ok(format!("{}.clone()", self.format_expr(i)?))) + .collect::, String>>()? .join(", ") ), Expression::BlockExpression(_, BlockExpression { statements, expr }) => { @@ -283,14 +283,15 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { Expression::MatchExpression(_, MatchExpression { scrutinee, arms }) => { // TODO try to find a solution where we do not introduce a variable // or at least make it unique. + let var_name = "scrutinee__"; format!( - "{{\nlet scrutinee__ = {};\n{}\n}}\n", + "{{\nlet {var_name} = {};\n{}\n}}\n", self.format_expr(scrutinee)?, arms.iter() .map(|MatchArm { pattern, value }| { - let (vars, code) = check_pattern(pattern)?; + let (vars, code) = check_pattern(var_name, pattern)?; Ok(format!( - "if let Some({vars}) = ({code})(scrutinee__.clone()) {{\n{}\n}}", + "if let Some({vars}) = ({code}) {{\n{}\n}}", self.format_expr(value)?, )) }) @@ -314,47 +315,41 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { /// TODO /// the ellipsis represents code that tries to match the given pattern. /// This function is used when generating code for match expressions. -fn check_pattern(pattern: &Pattern) -> Result<(String, String), String> { +fn check_pattern(value_name: &str, pattern: &Pattern) -> Result<(String, String), String> { Ok(match pattern { - Pattern::CatchAll(_) => ("()".to_string(), "|_| Some(())".to_string()), + Pattern::CatchAll(_) => ("()".to_string(), "Some(())".to_string()), Pattern::Number(_, n) => { // TODO format large n properly. ( "_".to_string(), - format!("|s| (s == ibig::IBig::from({n})).then_some(())"), + format!("({value_name} == ibig::IBig::from({n})).then_some(())"), ) } Pattern::String(_, s) => ( "_".to_string(), - format!("|s| (&s == {}).then_some(())", quote(s)), + format!("({value_name} == {}).then_some(())", quote(s)), ), Pattern::Tuple(_, items) => { let mut vars = vec![]; - // TODO we need to de-structure s! let inner_code = items .iter() .enumerate() .map(|(i, item)| { - let (v, code) = check_pattern(item)?; + let (v, code) = check_pattern(&format!("{value_name}.{i}"), item)?; vars.push(v.clone()); - Ok(format!("let r_{i} = ({code})(s.clone())?;")) + Ok(format!("({code})?")) }) .collect::, String>>()? - .join("\n"); - let code = format!( - "|s| {{\n{inner_code}\nSome(({}))\n}}", - items - .iter() - .enumerate() - .map(|(i, _)| format!("r_{i}")) - .format(", ") - ); - (format!("({})", vars.join(", ")), code) + .join(", "); + ( + format!("({})", vars.join(", ")), + format!("(|| Some(({inner_code})))()"), + ) } Pattern::Array(..) => { return Err(format!("Arrays as patterns not yet implemented: {pattern}")); } - Pattern::Variable(_, var) => (format!("{var}"), "|s| Some(s)".to_string()), + Pattern::Variable(_, var) => (format!("{var}"), format!("Some({value_name}.clone())")), Pattern::Enum(..) => { return Err(format!("Enums as patterns not yet implemented: {pattern}")); } diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index a92476b61f..2cf5aa11fd 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -88,6 +88,11 @@ pub fn call_cargo(code: &str) -> Result { fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); fs::create_dir(dir.join("src")).unwrap(); fs::write(dir.join("src").join("lib.rs"), code).unwrap(); + Command::new("cargo") + .arg("fmt") + .current_dir(dir.clone()) + .output() + .unwrap(); let out = Command::new("cargo") .env("RUSTFLAGS", "-C target-cpu=native") .arg("build") diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index a8e9076914..3c8b8dcad7 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -46,7 +46,7 @@ fn sqrt() { } #[test] -fn simple_match() { +fn match_number() { let f = compile( r#"let f: int -> int = |x| match x { 0 => 1, @@ -63,6 +63,21 @@ fn simple_match() { assert_eq!(f.call(3), 0); } +#[test] +fn match_string() { + let f = compile( + r#"let f: int -> int = |x| match "abc" { + "ab" => 1, + "abc" => 2, + _ => 0, + };"#, + "f", + ); + + assert_eq!(f.call(0), 2); + assert_eq!(f.call(1), 2); +} + #[test] fn match_expr() { let f = compile( @@ -76,9 +91,9 @@ fn match_expr() { ); assert_eq!(f.call(0), 1); - assert_eq!(f.call(1), 2); - assert_eq!(f.call(2), 3); - assert_eq!(f.call(3), 0); + assert_eq!(f.call(1), 4); + assert_eq!(f.call(2), 7); + assert_eq!(f.call(3), 9); } #[test] From 48458fe700182f07800b849dbe78d3975ad7cb49 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 14:18:02 +0000 Subject: [PATCH 53/62] proper numbers --- jit-compiler/src/codegen.rs | 55 +++++++++++++++++++++++---------- jit-compiler/tests/execution.rs | 29 ++++++++++++++++- 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index b730a66f85..568228f18d 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use itertools::{multiunzip, Itertools}; +use itertools::Itertools; use powdr_ast::{ analyzed::{Analyzed, Expression, FunctionValueDefinition, PolynomialReference, Reference}, parsed::{ @@ -11,7 +11,7 @@ use powdr_ast::{ StatementInsideBlock, UnaryOperation, }, }; -use powdr_number::FieldElement; +use powdr_number::{BigInt, BigUint, FieldElement, LargeInt}; pub struct CodeGenerator<'a, T> { analyzed: &'a Analyzed, @@ -145,8 +145,8 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { let code = match symbol { "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), "std::field::modulus" => { - let modulus = T::modulus(); - Some(format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}")) + let modulus = T::modulus().to_arbitrary_integer(); + Some(format!("() -> ibig::IBig {{ {} }}", format_number(&modulus))) } "std::convert::fe" => Some("(n: ibig::IBig) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" .to_string()), @@ -176,12 +176,15 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { type_: Some(type_), }, ) => { - let value = u64::try_from(value).unwrap_or_else(|_| unimplemented!()); - match type_ { - Type::Int => format!("ibig::IBig::from({value}_u64)"), - Type::Fe => format!("FieldElement::from({value}_u64)"), - Type::Expr => format!("Expr::from({value}_u64)"), - _ => unreachable!(), + if type_ == &Type::Int { + format_number(&value) + } else { + let value = u64::try_from(value).unwrap_or_else(|_| unimplemented!()); + match type_ { + Type::Fe => format!("FieldElement::from({value}_u64)"), + Type::Expr => format!("Expr::from({value}_u64)"), + _ => unreachable!(), + } } } Expression::FunctionCall( @@ -318,13 +321,10 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { fn check_pattern(value_name: &str, pattern: &Pattern) -> Result<(String, String), String> { Ok(match pattern { Pattern::CatchAll(_) => ("()".to_string(), "Some(())".to_string()), - Pattern::Number(_, n) => { - // TODO format large n properly. - ( - "_".to_string(), - format!("({value_name} == ibig::IBig::from({n})).then_some(())"), - ) - } + Pattern::Number(_, n) => ( + "_".to_string(), + format!("({value_name} == {})).then_some(())", format_number(&n)), + ), Pattern::String(_, s) => ( "_".to_string(), format!("({value_name} == {}).then_some(())", quote(s)), @@ -362,6 +362,27 @@ pub fn escape_symbol(s: &str) -> String { s.replace('.', "_").replace("::", "_") } +fn format_number(n: &BigUint) -> String { + if let Ok(n) = u64::try_from(n) { + format!("ibig::IBig::from({n}_u64)") + } else { + let bytes = n + .to_le_bytes() + .iter() + .map(|b| format!("{b}_u8")) + .format(", "); + format!("ibig::IBig::from_le_bytes(&[{bytes}])") + } +} + +fn format_signed_number(n: &BigInt) -> String { + if let Ok(n) = BigUint::try_from(n) { + format_number(&n) + } else { + format!("-{}", format_signed_number(n)) + } +} + fn map_type(ty: &Type) -> String { match ty { Type::Bottom | Type::Bool => format!("{ty}"), diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index 3c8b8dcad7..f48bfb8cc1 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -79,7 +79,7 @@ fn match_string() { } #[test] -fn match_expr() { +fn match_tuples() { let f = compile( r#"let f: int -> int = |x| match (x, ("abc", x + 3)) { (0, _) => 1, @@ -96,6 +96,33 @@ fn match_expr() { assert_eq!(f.call(3), 9); } +#[test] +fn match_array() { + let f = compile( + r#"let f: int -> int = |y| match (y, [1, 3, 3, 4]) { + (0, _) => 1, + (1, [.., 2, 4]) => 20, + (1, [.., x, 4]) => x - 1, + (2, [x, .., 0]) => 22, + (2, [x, .., 4]) => x + 2, + (3, [1, 3, 3, 4, ..]) => 4, + (4, [1, 3, 3, 4]) => 5, + (5, [..]) => 6, + _ => 7 + };"#, + "f", + ); + + assert_eq!(f.call(0), 1); + assert_eq!(f.call(1), 2); + assert_eq!(f.call(2), 3); + assert_eq!(f.call(3), 4); + assert_eq!(f.call(4), 5); + assert_eq!(f.call(5), 6); + assert_eq!(f.call(6), 7); + assert_eq!(f.call(7), 8); +} + #[test] #[should_panic = "Only (int -> int) functions and columns are supported, but requested c: int -> bool"] fn invalid_function() { From 7f88008836583f68297c214e5189f2cb2951cc96 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 14:45:20 +0000 Subject: [PATCH 54/62] more pat --- jit-compiler/src/codegen.rs | 58 ++++++++++++++++++++++++++------- jit-compiler/tests/execution.rs | 1 + 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 568228f18d..ff251794f3 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -323,11 +323,14 @@ fn check_pattern(value_name: &str, pattern: &Pattern) -> Result<(String, String) Pattern::CatchAll(_) => ("()".to_string(), "Some(())".to_string()), Pattern::Number(_, n) => ( "_".to_string(), - format!("({value_name} == {})).then_some(())", format_number(&n)), + format!( + "({value_name}.clone() == {}).then_some(())", + format_signed_number(n) + ), ), Pattern::String(_, s) => ( "_".to_string(), - format!("({value_name} == {}).then_some(())", quote(s)), + format!("({value_name}.clone() == {}).then_some(())", quote(s)), ), Pattern::Tuple(_, items) => { let mut vars = vec![]; @@ -336,7 +339,7 @@ fn check_pattern(value_name: &str, pattern: &Pattern) -> Result<(String, String) .enumerate() .map(|(i, item)| { let (v, code) = check_pattern(&format!("{value_name}.{i}"), item)?; - vars.push(v.clone()); + vars.push(v); Ok(format!("({code})?")) }) .collect::, String>>()? @@ -346,8 +349,40 @@ fn check_pattern(value_name: &str, pattern: &Pattern) -> Result<(String, String) format!("(|| Some(({inner_code})))()"), ) } - Pattern::Array(..) => { - return Err(format!("Arrays as patterns not yet implemented: {pattern}")); + Pattern::Array(_, items) => { + let mut vars = vec![]; + let mut ellipsis_seen = false; + let inner_code = items + .iter() + .enumerate() + .filter_map(|(i, item)| { + if matches!(item, Pattern::Ellipsis(_)) { + ellipsis_seen = true; + return None; + } + Some(if ellipsis_seen { + let i_rev = items.len() - i; + (format!("({value_name}[{value_name}.len() - {i_rev}]"), item) + } else { + (format!("{value_name}[{i}]"), item) + }) + }) + .map(|(access_name, item)| { + let (v, code) = check_pattern(&access_name, item)?; + vars.push(v); + Ok(format!("({code})?")) + }) + .collect::, String>>()? + .join(", "); + let length_check = if ellipsis_seen { + format!("{value_name}.len() >= {}", items.len() - 1) + } else { + format!("{value_name}.len() == {}", items.len()) + }; + ( + format!("({})", items.iter().map(|_| "_").join(", ")), + format!("({length_check}).then(|| Some(({inner_code})))"), + ) } Pattern::Variable(_, var) => (format!("{var}"), format!("Some({value_name}.clone())")), Pattern::Enum(..) => { @@ -366,12 +401,13 @@ fn format_number(n: &BigUint) -> String { if let Ok(n) = u64::try_from(n) { format!("ibig::IBig::from({n}_u64)") } else { - let bytes = n - .to_le_bytes() - .iter() - .map(|b| format!("{b}_u8")) - .format(", "); - format!("ibig::IBig::from_le_bytes(&[{bytes}])") + format!( + "ibig::IBig::from_le_bytes(&[{}])", + n.to_le_bytes() + .iter() + .map(|b| format!("{b}_u8")) + .format(", ") + ) } } diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index f48bfb8cc1..dce97127c6 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -101,6 +101,7 @@ fn match_array() { let f = compile( r#"let f: int -> int = |y| match (y, [1, 3, 3, 4]) { (0, _) => 1, + (1, [1, 3]) => 20, (1, [.., 2, 4]) => 20, (1, [.., x, 4]) => x - 1, (2, [x, .., 0]) => 22, From e48c256a09214e078103967751841e316605df0d Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 14:46:11 +0000 Subject: [PATCH 55/62] more pat --- jit-compiler/src/codegen.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index ff251794f3..10155ed255 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -362,7 +362,7 @@ fn check_pattern(value_name: &str, pattern: &Pattern) -> Result<(String, String) } Some(if ellipsis_seen { let i_rev = items.len() - i; - (format!("({value_name}[{value_name}.len() - {i_rev}]"), item) + (format!("{value_name}[{value_name}.len() - {i_rev}]"), item) } else { (format!("{value_name}[{i}]"), item) }) From 9a3242c206193dc53ad728f6731d29ae3d1fdf7d Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 14:55:47 +0000 Subject: [PATCH 56/62] arrays --- jit-compiler/src/codegen.rs | 41 ++------------------------------- jit-compiler/tests/execution.rs | 1 - 2 files changed, 2 insertions(+), 40 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 10155ed255..2cf2bbcb25 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -380,8 +380,8 @@ fn check_pattern(value_name: &str, pattern: &Pattern) -> Result<(String, String) format!("{value_name}.len() == {}", items.len()) }; ( - format!("({})", items.iter().map(|_| "_").join(", ")), - format!("({length_check}).then(|| Some(({inner_code})))"), + format!("({})", vars.join(", ")), + format!("if {length_check} {{ (|| Some(({inner_code})))() }} else {{ None }}"), ) } Pattern::Variable(_, var) => (format!("{var}"), format!("Some({value_name}.clone())")), @@ -486,41 +486,4 @@ mod test { " ); } - - #[test] - fn match_exprs() { - let result = compile( - r#"let c: int -> int = |i| match (i, "abc") { (_, "") => 1, (8, v) => 2, (x, _) => x, _ => 5 };"#, - &["c"], - ); - assert_eq!( - result, - r#"fn c(i: ibig::IBig) -> ibig::IBig { { -let scrutinee__ = (i, "abc"); -if let Some(((), _)) = (|s| { -let r_0 = (|_| Some(()))(s.clone())?; -let r_1 = (|s| (&s == "").then_some(()))(s.clone())?; -Some((r_0, r_1)) -})(scrutinee__.clone()) { -ibig::IBig::from(1_u64) -} else if let Some((_, v)) = (|s| { -let r_0 = (|s| s == ibig::IBig::from(8).then_some(()))(s.clone())?; -let r_1 = (|s| Some(s))(s.clone())?; -Some((r_0, r_1)) -})(scrutinee__.clone()) { -ibig::IBig::from(2_u64) -} else if let Some((x, ())) = (|s| { -let r_0 = (|s| Some(s))(s.clone())?; -let r_1 = (|_| Some(()))(s.clone())?; -Some((r_0, r_1)) -})(scrutinee__.clone()) { -x -} else if let Some(()) = (|_| Some(()))(scrutinee__.clone()) { -ibig::IBig::from(5_u64) -} -} - } -"# - ); - } } diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index dce97127c6..ac11b40319 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -121,7 +121,6 @@ fn match_array() { assert_eq!(f.call(4), 5); assert_eq!(f.call(5), 6); assert_eq!(f.call(6), 7); - assert_eq!(f.call(7), 8); } #[test] From e79c6101590e97cf7f46d3f10a1b795277345cb0 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 15:02:47 +0000 Subject: [PATCH 57/62] docstring. --- jit-compiler/src/codegen.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 2cf2bbcb25..212a69b53c 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -314,10 +314,14 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } } -/// Returns string of tuples with var names (capturing) and code. -/// TODO -/// the ellipsis represents code that tries to match the given pattern. -/// This function is used when generating code for match expressions. +/// Used for patterns in match and let statements: +/// `value_name` is an expression string that is to be matched against `pattern`. +/// Returns a rust pattern string (tuple of new variables, might be nested) and a code string +/// that, when executed, returns an Option with the values for the new variables if the pattern +/// matched `value_name` and `None` otherwise. +/// +/// So if `let (vars, code) = check_pattern("x", pattern)?;`, then the return value +/// can be used like this: `if let Some({vars}) = ({code}) {{ .. }}` fn check_pattern(value_name: &str, pattern: &Pattern) -> Result<(String, String), String> { Ok(match pattern { Pattern::CatchAll(_) => ("()".to_string(), "Some(())".to_string()), From 6a8462d682835535bf665fb34c18b9b67c61bd94 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 15:10:54 +0000 Subject: [PATCH 58/62] trigger change request --- jit-compiler/tests/execution.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index ac11b40319..a2b69ca792 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -14,7 +14,7 @@ fn compile(input: &str, symbol: &str) -> LoadedFunction { .unwrap()[symbol] .clone() } - +add test #[test] fn identity_function() { let f = compile("let c: int -> int = |i| i;", "c"); From ca0ae6fd29fe9d2ad9c46e6792177b64a02dcd42 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 15:11:04 +0000 Subject: [PATCH 59/62] undo change. --- jit-compiler/tests/execution.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index a2b69ca792..ac11b40319 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -14,7 +14,7 @@ fn compile(input: &str, symbol: &str) -> LoadedFunction { .unwrap()[symbol] .clone() } -add test + #[test] fn identity_function() { let f = compile("let c: int -> int = |i| i;", "c"); From a8b140604d6bca8bfd0c7168f6f3f9c1c62b50ab Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 30 Sep 2024 14:15:18 +0000 Subject: [PATCH 60/62] fix --- jit-compiler/src/codegen.rs | 4 ++-- jit-compiler/tests/execution.rs | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 9a96d916e8..318f55a327 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -266,8 +266,8 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { "({})", items .iter() - .map(|i| self.format_expr(i)) - .collect::, _>>()? + .map(|i| Ok(format!("({}.clone())", self.format_expr(i)?))) + .collect::, String>>()? .join(", ") ), Expression::BlockExpression(_, BlockExpression { statements, expr }) => { diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index 5c0a638ab9..a4c7fd7616 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -6,7 +6,13 @@ use powdr_pil_analyzer::analyze_string; fn compile(input: &str, symbol: &str) -> LoadedFunction { let analyzed = analyze_string::(input).unwrap(); - powdr_jit_compiler::compile(&analyzed, &[symbol]).unwrap()[symbol].clone() + powdr_jit_compiler::compile(&analyzed, &[symbol]) + .map_err(|e| { + eprintln!("Error jit-compiling:\n{e}"); + e + }) + .unwrap()[symbol] + .clone() } #[test] From b24ee2733aec280f56f0769feb21d34857781fc0 Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 30 Sep 2024 14:47:57 +0000 Subject: [PATCH 61/62] Let statements --- jit-compiler/src/codegen.rs | 34 +++++++++++++++++++++++++-------- jit-compiler/tests/execution.rs | 31 ++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 318f55a327..48c47c026f 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -5,10 +5,10 @@ use powdr_ast::{ analyzed::{Analyzed, Expression, FunctionValueDefinition, PolynomialReference, Reference}, parsed::{ display::quote, - types::{ArrayType, FunctionType, Type, TypeScheme}, + types::{ArrayType, FunctionType, TupleType, Type, TypeScheme}, ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, - IndexAccess, LambdaExpression, MatchArm, MatchExpression, Number, Pattern, - StatementInsideBlock, UnaryOperation, + IndexAccess, LambdaExpression, LetStatementInsideBlock, MatchArm, MatchExpression, Number, + Pattern, StatementInsideBlock, UnaryOperation, }, }; use powdr_number::{BigInt, BigUint, FieldElement, LargeInt}; @@ -261,7 +261,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { .join(", ") ) } - Expression::String(_, s) => quote(s), + Expression::String(_, s) => format!("{}.to_string()", quote(s)), Expression::Tuple(_, items) => format!( "({})", items @@ -309,9 +309,27 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } fn format_statement(&mut self, s: &StatementInsideBlock) -> Result { - Err(format!( - "Compiling statements inside blocks is not yet implemented: {s}" - )) + Ok(match s { + StatementInsideBlock::LetStatement(LetStatementInsideBlock { pattern, ty, value }) => { + let Some(value) = value else { + return Err(format!( + "Column creating 'let'-statements not yet supported: {s}" + )); + }; + let value = self.format_expr(value)?; + let var_name = "scrutinee__"; + let ty = ty + .as_ref() + .map(|ty| format!(": {}", map_type(ty))) + .unwrap_or_default(); + + let (vars, code) = check_pattern(var_name, pattern)?; + // TODO if we want to explicitly specify the type, we need to exchange the non-captured + // parts by `()`. + format!("let {vars} = (|{var_name}{ty}| {code})({value}).unwrap();\n",) + } + StatementInsideBlock::Expression(e) => format!("{};\n", self.format_expr(e)?), + }) } /// Returns a string expression evaluating to the value of the symbol. @@ -448,7 +466,7 @@ fn map_type(ty: &Type) -> String { Type::String => "String".to_string(), Type::Expr => "Expr".to_string(), Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), - Type::Tuple(_) => todo!(), + Type::Tuple(TupleType { items }) => format!("({})", items.iter().map(map_type).join(", ")), Type::Function(ft) => format!( "fn({}) -> {}", ft.params.iter().map(map_type).join(", "), diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index a4c7fd7616..281bc0c0f5 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -158,3 +158,34 @@ fn match_array() { assert_eq!(f.call(5), 6); assert_eq!(f.call(6), 7); } + +#[test] +fn let_simple() { + let f = compile( + r#"let f: int -> int = |x| { + let a = 1; + let b = a + 9; + b - 9 + x + };"#, + "f", + ); + + assert_eq!(f.call(0), 1); + assert_eq!(f.call(1), 2); + assert_eq!(f.call(2), 3); + assert_eq!(f.call(3), 4); +} + +#[test] +fn let_complex() { + let f = compile( + r#"let f: int -> int = |x| { + let (a, b, (_, d)) = (1, 2, ("abc", [x, 5])); + a + b + d[0] + d[1] + };"#, + "f", + ); + + assert_eq!(f.call(0), 8); + assert_eq!(f.call(1), 9); +} From 98c2ff6da1b3acf0d23e816dfb57c8649600dfa8 Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 30 Sep 2024 15:54:20 +0000 Subject: [PATCH 62/62] Formatting. --- jit-compiler/src/codegen.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 48c47c026f..44914f3de4 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -326,9 +326,9 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { let (vars, code) = check_pattern(var_name, pattern)?; // TODO if we want to explicitly specify the type, we need to exchange the non-captured // parts by `()`. - format!("let {vars} = (|{var_name}{ty}| {code})({value}).unwrap();\n",) + format!("let {vars} = (|{var_name}{ty}| {code})({value}).unwrap();",) } - StatementInsideBlock::Expression(e) => format!("{};\n", self.format_expr(e)?), + StatementInsideBlock::Expression(e) => format!("{};", self.format_expr(e)?), }) }