diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index b563bd1c1d..92527e5dd3 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -5,10 +5,10 @@ use powdr_ast::{ analyzed::{Analyzed, Expression, FunctionValueDefinition, PolynomialReference, Reference}, parsed::{ display::quote, - types::{ArrayType, FunctionType, Type, TypeScheme}, + types::{ArrayType, FunctionType, TupleType, Type, TypeScheme}, ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, - IndexAccess, LambdaExpression, MatchArm, MatchExpression, Number, Pattern, - StatementInsideBlock, UnaryOperation, + IndexAccess, LambdaExpression, LetStatementInsideBlock, MatchArm, MatchExpression, Number, + Pattern, StatementInsideBlock, UnaryOperation, }, }; use powdr_number::{BigInt, BigUint, FieldElement, LargeInt}; @@ -261,7 +261,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { .join(", ") ) } - Expression::String(_, s) => quote(s), + Expression::String(_, s) => format!("{}.to_string()", quote(s)), Expression::Tuple(_, items) => format!( "({})", items @@ -312,9 +312,27 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } fn format_statement(&mut self, s: &StatementInsideBlock) -> Result { - Err(format!( - "Compiling statements inside blocks is not yet implemented: {s}" - )) + Ok(match s { + StatementInsideBlock::LetStatement(LetStatementInsideBlock { pattern, ty, value }) => { + let Some(value) = value else { + return Err(format!( + "Column creating 'let'-statements not yet supported: {s}" + )); + }; + let value = self.format_expr(value)?; + let var_name = "scrutinee__"; + let ty = ty + .as_ref() + .map(|ty| format!(": {}", map_type(ty))) + .unwrap_or_default(); + + let (vars, code) = check_pattern(var_name, pattern)?; + // TODO if we want to explicitly specify the type, we need to exchange the non-captured + // parts by `()`. + format!("let {vars} = (|{var_name}{ty}| {code})({value}).unwrap();",) + } + StatementInsideBlock::Expression(e) => format!("{};", self.format_expr(e)?), + }) } /// Returns a string expression evaluating to the value of the symbol. @@ -456,7 +474,7 @@ fn map_type(ty: &Type) -> String { Type::String => "String".to_string(), Type::Expr => "Expr".to_string(), Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), - Type::Tuple(_) => todo!(), + Type::Tuple(TupleType { items }) => format!("({})", items.iter().map(map_type).join(", ")), Type::Function(ft) => format!( "fn({}) -> {}", ft.params.iter().map(map_type).join(", "), diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index 114b33fb5b..691da1c4b9 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -204,3 +204,34 @@ fn match_array() { assert_eq!(f.call(5), 6); assert_eq!(f.call(6), 7); } + +#[test] +fn let_simple() { + let f = compile( + r#"let f: int -> int = |x| { + let a = 1; + let b = a + 9; + b - 9 + x + };"#, + "f", + ); + + assert_eq!(f.call(0), 1); + assert_eq!(f.call(1), 2); + assert_eq!(f.call(2), 3); + assert_eq!(f.call(3), 4); +} + +#[test] +fn let_complex() { + let f = compile( + r#"let f: int -> int = |x| { + let (a, b, (_, d)) = (1, 2, ("abc", [x, 5])); + a + b + d[0] + d[1] + };"#, + "f", + ); + + assert_eq!(f.call(0), 8); + assert_eq!(f.call(1), 9); +}