Skip to content

Commit

Permalink
chore: add constrain formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
kek kek kek committed Oct 24, 2023
1 parent 247b7ce commit 290f121
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 32 deletions.
9 changes: 8 additions & 1 deletion compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,14 @@ pub enum LValue {
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ConstrainStatement(pub Expression, pub Option<String>);
pub struct ConstrainStatement(pub Expression, pub Option<String>, pub ConstrainKind);

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ConstrainKind {
Assert,
AssertEq,
Constrain,
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Pattern {
Expand Down
24 changes: 16 additions & 8 deletions compiler/noirc_frontend/src/parser/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ use crate::lexer::Lexer;
use crate::parser::{force, ignore_then_commit, statement_recovery};
use crate::token::{Attribute, Attributes, Keyword, SecondaryAttribute, Token, TokenKind};
use crate::{
BinaryOp, BinaryOpKind, BlockExpression, ConstrainStatement, Distinctness, FunctionDefinition,
FunctionReturnType, Ident, IfExpression, InfixExpression, LValue, Lambda, Literal,
NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Path, PathKind, Pattern,
Recoverable, Statement, TraitBound, TraitImplItem, TraitItem, TypeImpl, UnaryOp,
BinaryOp, BinaryOpKind, BlockExpression, ConstrainKind, ConstrainStatement, Distinctness,
FunctionDefinition, FunctionReturnType, Ident, IfExpression, InfixExpression, LValue, Lambda,
Literal, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Path, PathKind,
Pattern, Recoverable, Statement, TraitBound, TraitImplItem, TraitItem, TypeImpl, UnaryOp,
UnresolvedTraitConstraint, UnresolvedTypeExpression, UseTree, UseTreeKind, Visibility,
};

Expand Down Expand Up @@ -800,7 +800,7 @@ where
keyword(Keyword::Constrain).labelled(ParsingRuleLabel::Statement),
expr_parser,
)
.map(|expr| StatementKind::Constrain(ConstrainStatement(expr, None)))
.map(|expr| StatementKind::Constrain(ConstrainStatement(expr, None, ConstrainKind::Constrain)))
.validate(|expr, span, emit| {
emit(ParserError::with_reason(ParserErrorReason::ConstrainDeprecated, span));
expr
Expand Down Expand Up @@ -828,7 +828,11 @@ where
}
}

StatementKind::Constrain(ConstrainStatement(condition, message_str))
StatementKind::Constrain(ConstrainStatement(
condition,
message_str,
ConstrainKind::Assert,
))
})
}

Expand Down Expand Up @@ -859,7 +863,11 @@ where
emit(ParserError::with_reason(ParserErrorReason::AssertMessageNotString, span));
}
}
StatementKind::Constrain(ConstrainStatement(predicate, message_str))
StatementKind::Constrain(ConstrainStatement(
predicate,
message_str,
ConstrainKind::AssertEq,
))
})
}

Expand Down Expand Up @@ -2017,7 +2025,7 @@ mod test {
match parse_with(assertion_eq(expression()), "assert_eq(x, y, \"assertion message\")")
.unwrap()
{
StatementKind::Constrain(ConstrainStatement(_, message)) => {
StatementKind::Constrain(ConstrainStatement(_, message, _)) => {
assert_eq!(message, Some("assertion message".to_owned()));
}
_ => unreachable!(),
Expand Down
4 changes: 2 additions & 2 deletions tooling/nargo_fmt/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ impl Item for Expression {
}

fn format(self, visitor: &FmtVisitor) -> String {
visitor.format_subexpr(self)
visitor.format_sub_expr(self)
}
}

Expand All @@ -232,7 +232,7 @@ impl Item for (Ident, Expression) {
let (name, expr) = self;

let name = name.0.contents;
let expr = visitor.format_subexpr(expr);
let expr = visitor.format_sub_expr(expr);

if name == expr {
name
Expand Down
37 changes: 19 additions & 18 deletions tooling/nargo_fmt/src/visitor/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl FmtVisitor<'_> {
self.last_position = span.end();
}

pub(crate) fn format_subexpr(&self, expression: Expression) -> String {
pub(crate) fn format_sub_expr(&self, expression: Expression) -> String {
self.format_expr(expression, ExpressionType::SubExpression)
}

Expand Down Expand Up @@ -50,24 +50,24 @@ impl FmtVisitor<'_> {
}
};

format!("{op}{}", self.format_subexpr(prefix.rhs))
format!("{op}{}", self.format_sub_expr(prefix.rhs))
}
ExpressionKind::Cast(cast) => {
format!("{} as {}", self.format_subexpr(cast.lhs), cast.r#type)
format!("{} as {}", self.format_sub_expr(cast.lhs), cast.r#type)
}
ExpressionKind::Infix(infix) => {
format!(
"{} {} {}",
self.format_subexpr(infix.lhs),
self.format_sub_expr(infix.lhs),
infix.operator.contents.as_string(),
self.format_subexpr(infix.rhs)
self.format_sub_expr(infix.rhs)
)
}
ExpressionKind::Call(call_expr) => {
let args_span =
self.span_before(call_expr.func.span.end()..span.end(), Token::LeftParen);

let callee = self.format_subexpr(*call_expr.func);
let callee = self.format_sub_expr(*call_expr.func);
let args = format_parens(self.fork(), false, call_expr.arguments, args_span);

format!("{callee}{args}")
Expand All @@ -78,21 +78,21 @@ impl FmtVisitor<'_> {
Token::LeftParen,
);

let object = self.format_subexpr(method_call_expr.object);
let object = self.format_sub_expr(method_call_expr.object);
let method = method_call_expr.method_name.to_string();
let args = format_parens(self.fork(), false, method_call_expr.arguments, args_span);

format!("{object}.{method}{args}")
}
ExpressionKind::MemberAccess(member_access_expr) => {
let lhs_str = self.format_subexpr(member_access_expr.lhs);
let lhs_str = self.format_sub_expr(member_access_expr.lhs);
format!("{}.{}", lhs_str, member_access_expr.rhs)
}
ExpressionKind::Index(index_expr) => {
let index_span = self
.span_before(index_expr.collection.span.end()..span.end(), Token::LeftBracket);

let collection = self.format_subexpr(index_expr.collection);
let collection = self.format_sub_expr(index_expr.collection);
let index = format_brackets(self.fork(), false, vec![index_expr.index], index_span);

format!("{collection}{index}")
Expand All @@ -105,8 +105,8 @@ impl FmtVisitor<'_> {
self.slice(span).to_string()
}
Literal::Array(ArrayLiteral::Repeated { repeated_element, length }) => {
let repeated = self.format_subexpr(*repeated_element);
let length = self.format_subexpr(*length);
let repeated = self.format_sub_expr(*repeated_element);
let length = self.format_sub_expr(*length);

format!("[{repeated}; {length}]")
}
Expand Down Expand Up @@ -140,7 +140,7 @@ impl FmtVisitor<'_> {
}

if !leading.contains("//") && !trailing.contains("//") {
let sub_expr = self.format_subexpr(*sub_expr);
let sub_expr = self.format_sub_expr(*sub_expr);
format!("({leading}{sub_expr}{trailing})")
} else {
let mut visitor = self.fork();
Expand All @@ -149,7 +149,7 @@ impl FmtVisitor<'_> {
visitor.indent.block_indent(self.config);
let nested_indent = visitor.indent.to_string_with_newline();

let sub_expr = visitor.format_subexpr(*sub_expr);
let sub_expr = visitor.format_sub_expr(*sub_expr);

let mut result = String::new();
result.push('(');
Expand Down Expand Up @@ -193,13 +193,14 @@ impl FmtVisitor<'_> {

self.format_if(*if_expr)
}
_ => self.slice(span).to_string(),
ExpressionKind::Variable(_) | ExpressionKind::Lambda(_) => self.slice(span).to_string(),
ExpressionKind::Error => unreachable!(),
}
}

fn format_if(&self, if_expr: IfExpression) -> String {
let condition_str = self.format_subexpr(if_expr.condition);
let consequence_str = self.format_subexpr(if_expr.consequence);
let condition_str = self.format_sub_expr(if_expr.condition);
let consequence_str = self.format_sub_expr(if_expr.consequence);

let mut result = format!("if {condition_str} {consequence_str}");

Expand All @@ -220,8 +221,8 @@ impl FmtVisitor<'_> {
}

fn format_if_single_line(&self, if_expr: IfExpression) -> Option<String> {
let condition_str = self.format_subexpr(if_expr.condition);
let consequence_str = self.format_subexpr(extract_simple_expr(if_expr.consequence)?);
let condition_str = self.format_sub_expr(if_expr.condition);
let consequence_str = self.format_sub_expr(extract_simple_expr(if_expr.consequence)?);

let if_str = if let Some(alternative) = if_expr.alternative {
let alternative_str = if let Some(ExpressionKind::If(_)) =
Expand Down
32 changes: 30 additions & 2 deletions tooling/nargo_fmt/src/visitor/stmt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::iter::zip;

use noirc_frontend::{Statement, StatementKind};
use noirc_frontend::{ConstrainKind, ConstrainStatement, ExpressionKind, Statement, StatementKind};

use super::ExpressionType;

Expand Down Expand Up @@ -28,8 +28,36 @@ impl super::FmtVisitor<'_> {

self.push_rewrite(format!("{let_str} {expr_str};"), span);
}
StatementKind::Constrain(constrain) => {
let ConstrainStatement(expr, message, kind) = constrain;
let message =
message.map_or(String::new(), |message| format!(", \"{message}\""));
let constrain = match kind {
ConstrainKind::Assert => {
let assertion = self.format_sub_expr(expr);

format!("assert({assertion}{message});")
}
ConstrainKind::AssertEq => {
if let ExpressionKind::Infix(infix) = expr.kind {
let lhs = self.format_sub_expr(infix.lhs);
let rhs = self.format_sub_expr(infix.rhs);

format!("assert_eq({lhs}, {rhs}{message});")
} else {
unreachable!()
}
}
ConstrainKind::Constrain => {
let expr = self.format_sub_expr(expr);
format!("constrain {expr};")
}
};

self.push_rewrite(constrain, span);
}
StatementKind::Assign(_) | StatementKind::For(_) => self.format_missing(span.end()),
StatementKind::Error => unreachable!(),
_ => self.format_missing(span.end()),
}

self.last_position = span.end();
Expand Down
2 changes: 1 addition & 1 deletion tooling/nargo_fmt/tests/expected/add.nr
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ fn main(mut x: u32, y: u32, z: u32) {
assert(x == z);

x *= 8;
assert(x>9);
assert(x > 9);
}
10 changes: 10 additions & 0 deletions tooling/nargo_fmt/tests/expected/call.nr
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,14 @@ fn foo() {
my_function(some_function(10, "arg1", another_function()), another_func(20, some_function(), 30));

outer_function(some_function(), another_function(some_function(), some_value));

assert_eq(x, y);

assert_eq(x, y, "message");

assert(x);

assert(x, "message");

assert(x == y);
}
10 changes: 10 additions & 0 deletions tooling/nargo_fmt/tests/input/call.nr
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,14 @@ fn foo() {
another_function(
some_function(), some_value)
);

assert_eq( x, y );

assert_eq( x, y, "message" );

assert( x );

assert( x, "message" );

assert( x == y );
}

0 comments on commit 290f121

Please sign in to comment.