diff --git a/crates/ruff_python_ast/src/comparable.rs b/crates/ruff_python_ast/src/comparable.rs index f2a9a5a02fd39..4e08e98d1b286 100644 --- a/crates/ruff_python_ast/src/comparable.rs +++ b/crates/ruff_python_ast/src/comparable.rs @@ -1480,3 +1480,44 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> { } } } + +#[derive(Debug, PartialEq, Eq, Hash)] +pub enum ComparableMod<'a> { + Module(ComparableModModule<'a>), + Expression(ComparableModExpression<'a>), +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ComparableModModule<'a> { + body: Vec>, +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ComparableModExpression<'a> { + body: Box>, +} + +impl<'a> From<&'a ast::Mod> for ComparableMod<'a> { + fn from(mod_: &'a ast::Mod) -> Self { + match mod_ { + ast::Mod::Module(module) => Self::Module(module.into()), + ast::Mod::Expression(expr) => Self::Expression(expr.into()), + } + } +} + +impl<'a> From<&'a ast::ModModule> for ComparableModModule<'a> { + fn from(module: &'a ast::ModModule) -> Self { + Self { + body: module.body.iter().map(Into::into).collect(), + } + } +} + +impl<'a> From<&'a ast::ModExpression> for ComparableModExpression<'a> { + fn from(expr: &'a ast::ModExpression) -> Self { + Self { + body: (&expr.body).into(), + } + } +} diff --git a/crates/ruff_python_ast/src/visitor.rs b/crates/ruff_python_ast/src/visitor.rs index 8084f030c8f55..f740044fb5ffc 100644 --- a/crates/ruff_python_ast/src/visitor.rs +++ b/crates/ruff_python_ast/src/visitor.rs @@ -1,6 +1,7 @@ //! AST visitor trait and walk functions. pub mod preorder; +pub mod transformer; use crate::{ self as ast, Alias, Arguments, BoolOp, CmpOp, Comprehension, Decorator, ElifElseClause, @@ -14,8 +15,10 @@ use crate::{ /// Prefer [`crate::statement_visitor::StatementVisitor`] for visitors that only need to visit /// statements. /// -/// Use the [`PreorderVisitor`](self::preorder::PreorderVisitor) if you want to visit the nodes +/// Use the [`PreorderVisitor`](preorder::PreorderVisitor) if you want to visit the nodes /// in pre-order rather than evaluation order. +/// +/// Use the [`Transformer`](transformer::Transformer) if you want to modify the nodes. pub trait Visitor<'a> { fn visit_stmt(&mut self, stmt: &'a Stmt) { walk_stmt(self, stmt); diff --git a/crates/ruff_python_ast/src/visitor/transformer.rs b/crates/ruff_python_ast/src/visitor/transformer.rs new file mode 100644 index 0000000000000..b90ab0a1b61e9 --- /dev/null +++ b/crates/ruff_python_ast/src/visitor/transformer.rs @@ -0,0 +1,732 @@ +use crate::{ + self as ast, Alias, Arguments, BoolOp, CmpOp, Comprehension, Decorator, ElifElseClause, + ExceptHandler, Expr, ExprContext, Keyword, MatchCase, Operator, Parameter, Parameters, Pattern, + PatternArguments, PatternKeyword, Stmt, TypeParam, TypeParamTypeVar, TypeParams, UnaryOp, + WithItem, +}; + +/// A trait for transforming ASTs. Visits all nodes in the AST recursively in evaluation-order. +pub trait Transformer { + fn visit_stmt(&self, stmt: &mut Stmt) { + walk_stmt(self, stmt); + } + fn visit_annotation(&self, expr: &mut Expr) { + walk_annotation(self, expr); + } + fn visit_decorator(&self, decorator: &mut Decorator) { + walk_decorator(self, decorator); + } + fn visit_expr(&self, expr: &mut Expr) { + walk_expr(self, expr); + } + fn visit_expr_context(&self, expr_context: &mut ExprContext) { + walk_expr_context(self, expr_context); + } + fn visit_bool_op(&self, bool_op: &mut BoolOp) { + walk_bool_op(self, bool_op); + } + fn visit_operator(&self, operator: &mut Operator) { + walk_operator(self, operator); + } + fn visit_unary_op(&self, unary_op: &mut UnaryOp) { + walk_unary_op(self, unary_op); + } + fn visit_cmp_op(&self, cmp_op: &mut CmpOp) { + walk_cmp_op(self, cmp_op); + } + fn visit_comprehension(&self, comprehension: &mut Comprehension) { + walk_comprehension(self, comprehension); + } + fn visit_except_handler(&self, except_handler: &mut ExceptHandler) { + walk_except_handler(self, except_handler); + } + fn visit_format_spec(&self, format_spec: &mut Expr) { + walk_format_spec(self, format_spec); + } + fn visit_arguments(&self, arguments: &mut Arguments) { + walk_arguments(self, arguments); + } + fn visit_parameters(&self, parameters: &mut Parameters) { + walk_parameters(self, parameters); + } + fn visit_parameter(&self, parameter: &mut Parameter) { + walk_parameter(self, parameter); + } + fn visit_keyword(&self, keyword: &mut Keyword) { + walk_keyword(self, keyword); + } + fn visit_alias(&self, alias: &mut Alias) { + walk_alias(self, alias); + } + fn visit_with_item(&self, with_item: &mut WithItem) { + walk_with_item(self, with_item); + } + fn visit_type_params(&self, type_params: &mut TypeParams) { + walk_type_params(self, type_params); + } + fn visit_type_param(&self, type_param: &mut TypeParam) { + walk_type_param(self, type_param); + } + fn visit_match_case(&self, match_case: &mut MatchCase) { + walk_match_case(self, match_case); + } + fn visit_pattern(&self, pattern: &mut Pattern) { + walk_pattern(self, pattern); + } + fn visit_pattern_arguments(&self, pattern_arguments: &mut PatternArguments) { + walk_pattern_arguments(self, pattern_arguments); + } + fn visit_pattern_keyword(&self, pattern_keyword: &mut PatternKeyword) { + walk_pattern_keyword(self, pattern_keyword); + } + fn visit_body(&self, body: &mut [Stmt]) { + walk_body(self, body); + } + fn visit_elif_else_clause(&self, elif_else_clause: &mut ElifElseClause) { + walk_elif_else_clause(self, elif_else_clause); + } +} + +pub fn walk_body(visitor: &V, body: &mut [Stmt]) { + for stmt in body { + visitor.visit_stmt(stmt); + } +} + +pub fn walk_elif_else_clause( + visitor: &V, + elif_else_clause: &mut ElifElseClause, +) { + if let Some(test) = &mut elif_else_clause.test { + visitor.visit_expr(test); + } + visitor.visit_body(&mut elif_else_clause.body); +} + +pub fn walk_stmt(visitor: &V, stmt: &mut Stmt) { + match stmt { + Stmt::FunctionDef(ast::StmtFunctionDef { + parameters, + body, + decorator_list, + returns, + type_params, + .. + }) => { + for decorator in decorator_list { + visitor.visit_decorator(decorator); + } + if let Some(type_params) = type_params { + visitor.visit_type_params(type_params); + } + visitor.visit_parameters(parameters); + for expr in returns { + visitor.visit_annotation(expr); + } + visitor.visit_body(body); + } + Stmt::ClassDef(ast::StmtClassDef { + arguments, + body, + decorator_list, + type_params, + .. + }) => { + for decorator in decorator_list { + visitor.visit_decorator(decorator); + } + if let Some(type_params) = type_params { + visitor.visit_type_params(type_params); + } + if let Some(arguments) = arguments { + visitor.visit_arguments(arguments); + } + visitor.visit_body(body); + } + Stmt::Return(ast::StmtReturn { value, range: _ }) => { + if let Some(expr) = value { + visitor.visit_expr(expr); + } + } + Stmt::Delete(ast::StmtDelete { targets, range: _ }) => { + for expr in targets { + visitor.visit_expr(expr); + } + } + Stmt::TypeAlias(ast::StmtTypeAlias { + range: _, + name, + type_params, + value, + }) => { + visitor.visit_expr(value); + if let Some(type_params) = type_params { + visitor.visit_type_params(type_params); + } + visitor.visit_expr(name); + } + Stmt::Assign(ast::StmtAssign { targets, value, .. }) => { + visitor.visit_expr(value); + for expr in targets { + visitor.visit_expr(expr); + } + } + Stmt::AugAssign(ast::StmtAugAssign { + target, + op, + value, + range: _, + }) => { + visitor.visit_expr(value); + visitor.visit_operator(op); + visitor.visit_expr(target); + } + Stmt::AnnAssign(ast::StmtAnnAssign { + target, + annotation, + value, + .. + }) => { + if let Some(expr) = value { + visitor.visit_expr(expr); + } + visitor.visit_annotation(annotation); + visitor.visit_expr(target); + } + Stmt::For(ast::StmtFor { + target, + iter, + body, + orelse, + .. + }) => { + visitor.visit_expr(iter); + visitor.visit_expr(target); + visitor.visit_body(body); + visitor.visit_body(orelse); + } + Stmt::While(ast::StmtWhile { + test, + body, + orelse, + range: _, + }) => { + visitor.visit_expr(test); + visitor.visit_body(body); + visitor.visit_body(orelse); + } + Stmt::If(ast::StmtIf { + test, + body, + elif_else_clauses, + range: _, + }) => { + visitor.visit_expr(test); + visitor.visit_body(body); + for clause in elif_else_clauses { + if let Some(test) = &mut clause.test { + visitor.visit_expr(test); + } + walk_elif_else_clause(visitor, clause); + } + } + Stmt::With(ast::StmtWith { items, body, .. }) => { + for with_item in items { + visitor.visit_with_item(with_item); + } + visitor.visit_body(body); + } + Stmt::Match(ast::StmtMatch { + subject, + cases, + range: _, + }) => { + visitor.visit_expr(subject); + for match_case in cases { + visitor.visit_match_case(match_case); + } + } + Stmt::Raise(ast::StmtRaise { + exc, + cause, + range: _, + }) => { + if let Some(expr) = exc { + visitor.visit_expr(expr); + }; + if let Some(expr) = cause { + visitor.visit_expr(expr); + }; + } + Stmt::Try(ast::StmtTry { + body, + handlers, + orelse, + finalbody, + is_star: _, + range: _, + }) => { + visitor.visit_body(body); + for except_handler in handlers { + visitor.visit_except_handler(except_handler); + } + visitor.visit_body(orelse); + visitor.visit_body(finalbody); + } + Stmt::Assert(ast::StmtAssert { + test, + msg, + range: _, + }) => { + visitor.visit_expr(test); + if let Some(expr) = msg { + visitor.visit_expr(expr); + } + } + Stmt::Import(ast::StmtImport { names, range: _ }) => { + for alias in names { + visitor.visit_alias(alias); + } + } + Stmt::ImportFrom(ast::StmtImportFrom { names, .. }) => { + for alias in names { + visitor.visit_alias(alias); + } + } + Stmt::Global(_) => {} + Stmt::Nonlocal(_) => {} + Stmt::Expr(ast::StmtExpr { value, range: _ }) => visitor.visit_expr(value), + Stmt::Pass(_) | Stmt::Break(_) | Stmt::Continue(_) | Stmt::IpyEscapeCommand(_) => {} + } +} + +pub fn walk_annotation(visitor: &V, expr: &mut Expr) { + visitor.visit_expr(expr); +} + +pub fn walk_decorator(visitor: &V, decorator: &mut Decorator) { + visitor.visit_expr(&mut decorator.expression); +} + +pub fn walk_expr(visitor: &V, expr: &mut Expr) { + match expr { + Expr::BoolOp(ast::ExprBoolOp { + op, + values, + range: _, + }) => { + visitor.visit_bool_op(op); + for expr in values { + visitor.visit_expr(expr); + } + } + Expr::NamedExpr(ast::ExprNamedExpr { + target, + value, + range: _, + }) => { + visitor.visit_expr(value); + visitor.visit_expr(target); + } + Expr::BinOp(ast::ExprBinOp { + left, + op, + right, + range: _, + }) => { + visitor.visit_expr(left); + visitor.visit_operator(op); + visitor.visit_expr(right); + } + Expr::UnaryOp(ast::ExprUnaryOp { + op, + operand, + range: _, + }) => { + visitor.visit_unary_op(op); + visitor.visit_expr(operand); + } + Expr::Lambda(ast::ExprLambda { + parameters, + body, + range: _, + }) => { + if let Some(parameters) = parameters { + visitor.visit_parameters(parameters); + } + visitor.visit_expr(body); + } + Expr::IfExp(ast::ExprIfExp { + test, + body, + orelse, + range: _, + }) => { + visitor.visit_expr(test); + visitor.visit_expr(body); + visitor.visit_expr(orelse); + } + Expr::Dict(ast::ExprDict { + keys, + values, + range: _, + }) => { + for expr in keys.iter_mut().flatten() { + visitor.visit_expr(expr); + } + for expr in values { + visitor.visit_expr(expr); + } + } + Expr::Set(ast::ExprSet { elts, range: _ }) => { + for expr in elts { + visitor.visit_expr(expr); + } + } + Expr::ListComp(ast::ExprListComp { + elt, + generators, + range: _, + }) => { + for comprehension in generators { + visitor.visit_comprehension(comprehension); + } + visitor.visit_expr(elt); + } + Expr::SetComp(ast::ExprSetComp { + elt, + generators, + range: _, + }) => { + for comprehension in generators { + visitor.visit_comprehension(comprehension); + } + visitor.visit_expr(elt); + } + Expr::DictComp(ast::ExprDictComp { + key, + value, + generators, + range: _, + }) => { + for comprehension in generators { + visitor.visit_comprehension(comprehension); + } + visitor.visit_expr(key); + visitor.visit_expr(value); + } + Expr::GeneratorExp(ast::ExprGeneratorExp { + elt, + generators, + range: _, + }) => { + for comprehension in generators { + visitor.visit_comprehension(comprehension); + } + visitor.visit_expr(elt); + } + Expr::Await(ast::ExprAwait { value, range: _ }) => visitor.visit_expr(value), + Expr::Yield(ast::ExprYield { value, range: _ }) => { + if let Some(expr) = value { + visitor.visit_expr(expr); + } + } + Expr::YieldFrom(ast::ExprYieldFrom { value, range: _ }) => visitor.visit_expr(value), + Expr::Compare(ast::ExprCompare { + left, + ops, + comparators, + range: _, + }) => { + visitor.visit_expr(left); + for cmp_op in ops { + visitor.visit_cmp_op(cmp_op); + } + for expr in comparators { + visitor.visit_expr(expr); + } + } + Expr::Call(ast::ExprCall { + func, + arguments, + range: _, + }) => { + visitor.visit_expr(func); + visitor.visit_arguments(arguments); + } + Expr::FormattedValue(ast::ExprFormattedValue { + value, format_spec, .. + }) => { + visitor.visit_expr(value); + if let Some(expr) = format_spec { + visitor.visit_format_spec(expr); + } + } + Expr::FString(ast::ExprFString { values, .. }) => { + for expr in values { + visitor.visit_expr(expr); + } + } + Expr::StringLiteral(_) + | Expr::BytesLiteral(_) + | Expr::NumberLiteral(_) + | Expr::BooleanLiteral(_) + | Expr::NoneLiteral(_) + | Expr::EllipsisLiteral(_) => {} + Expr::Attribute(ast::ExprAttribute { value, ctx, .. }) => { + visitor.visit_expr(value); + visitor.visit_expr_context(ctx); + } + Expr::Subscript(ast::ExprSubscript { + value, + slice, + ctx, + range: _, + }) => { + visitor.visit_expr(value); + visitor.visit_expr(slice); + visitor.visit_expr_context(ctx); + } + Expr::Starred(ast::ExprStarred { + value, + ctx, + range: _, + }) => { + visitor.visit_expr(value); + visitor.visit_expr_context(ctx); + } + Expr::Name(ast::ExprName { ctx, .. }) => { + visitor.visit_expr_context(ctx); + } + Expr::List(ast::ExprList { + elts, + ctx, + range: _, + }) => { + for expr in elts { + visitor.visit_expr(expr); + } + visitor.visit_expr_context(ctx); + } + Expr::Tuple(ast::ExprTuple { + elts, + ctx, + range: _, + }) => { + for expr in elts { + visitor.visit_expr(expr); + } + visitor.visit_expr_context(ctx); + } + Expr::Slice(ast::ExprSlice { + lower, + upper, + step, + range: _, + }) => { + if let Some(expr) = lower { + visitor.visit_expr(expr); + } + if let Some(expr) = upper { + visitor.visit_expr(expr); + } + if let Some(expr) = step { + visitor.visit_expr(expr); + } + } + Expr::IpyEscapeCommand(_) => {} + } +} + +pub fn walk_comprehension(visitor: &V, comprehension: &mut Comprehension) { + visitor.visit_expr(&mut comprehension.iter); + visitor.visit_expr(&mut comprehension.target); + for expr in &mut comprehension.ifs { + visitor.visit_expr(expr); + } +} + +pub fn walk_except_handler( + visitor: &V, + except_handler: &mut ExceptHandler, +) { + match except_handler { + ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { type_, body, .. }) => { + if let Some(expr) = type_ { + visitor.visit_expr(expr); + } + visitor.visit_body(body); + } + } +} + +pub fn walk_format_spec(visitor: &V, format_spec: &mut Expr) { + visitor.visit_expr(format_spec); +} + +pub fn walk_arguments(visitor: &V, arguments: &mut Arguments) { + // Note that the there might be keywords before the last arg, e.g. in + // f(*args, a=2, *args2, **kwargs)`, but we follow Python in evaluating first `args` and then + // `keywords`. See also [Arguments::arguments_source_order`]. + for arg in &mut arguments.args { + visitor.visit_expr(arg); + } + for keyword in &mut arguments.keywords { + visitor.visit_keyword(keyword); + } +} + +pub fn walk_parameters(visitor: &V, parameters: &mut Parameters) { + // Defaults are evaluated before annotations. + for arg in &mut parameters.posonlyargs { + if let Some(default) = &mut arg.default { + visitor.visit_expr(default); + } + } + for arg in &mut parameters.args { + if let Some(default) = &mut arg.default { + visitor.visit_expr(default); + } + } + for arg in &mut parameters.kwonlyargs { + if let Some(default) = &mut arg.default { + visitor.visit_expr(default); + } + } + + for arg in &mut parameters.posonlyargs { + visitor.visit_parameter(&mut arg.parameter); + } + for arg in &mut parameters.args { + visitor.visit_parameter(&mut arg.parameter); + } + if let Some(arg) = &mut parameters.vararg { + visitor.visit_parameter(arg); + } + for arg in &mut parameters.kwonlyargs { + visitor.visit_parameter(&mut arg.parameter); + } + if let Some(arg) = &mut parameters.kwarg { + visitor.visit_parameter(arg); + } +} + +pub fn walk_parameter(visitor: &V, parameter: &mut Parameter) { + if let Some(expr) = &mut parameter.annotation { + visitor.visit_annotation(expr); + } +} + +pub fn walk_keyword(visitor: &V, keyword: &mut Keyword) { + visitor.visit_expr(&mut keyword.value); +} + +pub fn walk_with_item(visitor: &V, with_item: &mut WithItem) { + visitor.visit_expr(&mut with_item.context_expr); + if let Some(expr) = &mut with_item.optional_vars { + visitor.visit_expr(expr); + } +} + +pub fn walk_type_params(visitor: &V, type_params: &mut TypeParams) { + for type_param in &mut type_params.type_params { + visitor.visit_type_param(type_param); + } +} + +pub fn walk_type_param(visitor: &V, type_param: &mut TypeParam) { + match type_param { + TypeParam::TypeVar(TypeParamTypeVar { + bound, + name: _, + range: _, + }) => { + if let Some(expr) = bound { + visitor.visit_expr(expr); + } + } + TypeParam::TypeVarTuple(_) | TypeParam::ParamSpec(_) => {} + } +} + +pub fn walk_match_case(visitor: &V, match_case: &mut MatchCase) { + visitor.visit_pattern(&mut match_case.pattern); + if let Some(expr) = &mut match_case.guard { + visitor.visit_expr(expr); + } + visitor.visit_body(&mut match_case.body); +} + +pub fn walk_pattern(visitor: &V, pattern: &mut Pattern) { + match pattern { + Pattern::MatchValue(ast::PatternMatchValue { value, .. }) => { + visitor.visit_expr(value); + } + Pattern::MatchSingleton(_) => {} + Pattern::MatchSequence(ast::PatternMatchSequence { patterns, .. }) => { + for pattern in patterns { + visitor.visit_pattern(pattern); + } + } + Pattern::MatchMapping(ast::PatternMatchMapping { keys, patterns, .. }) => { + for expr in keys { + visitor.visit_expr(expr); + } + for pattern in patterns { + visitor.visit_pattern(pattern); + } + } + Pattern::MatchClass(ast::PatternMatchClass { cls, arguments, .. }) => { + visitor.visit_expr(cls); + visitor.visit_pattern_arguments(arguments); + } + Pattern::MatchStar(_) => {} + Pattern::MatchAs(ast::PatternMatchAs { pattern, .. }) => { + if let Some(pattern) = pattern { + visitor.visit_pattern(pattern); + } + } + Pattern::MatchOr(ast::PatternMatchOr { patterns, .. }) => { + for pattern in patterns { + visitor.visit_pattern(pattern); + } + } + } +} + +pub fn walk_pattern_arguments( + visitor: &V, + pattern_arguments: &mut PatternArguments, +) { + for pattern in &mut pattern_arguments.patterns { + visitor.visit_pattern(pattern); + } + for keyword in &mut pattern_arguments.keywords { + visitor.visit_pattern_keyword(keyword); + } +} + +pub fn walk_pattern_keyword( + visitor: &V, + pattern_keyword: &mut PatternKeyword, +) { + visitor.visit_pattern(&mut pattern_keyword.pattern); +} + +#[allow(unused_variables)] +pub fn walk_expr_context(visitor: &V, expr_context: &mut ExprContext) {} + +#[allow(unused_variables)] +pub fn walk_bool_op(visitor: &V, bool_op: &mut BoolOp) {} + +#[allow(unused_variables)] +pub fn walk_operator(visitor: &V, operator: &mut Operator) {} + +#[allow(unused_variables)] +pub fn walk_unary_op(visitor: &V, unary_op: &mut UnaryOp) {} + +#[allow(unused_variables)] +pub fn walk_cmp_op(visitor: &V, cmp_op: &mut CmpOp) {} + +#[allow(unused_variables)] +pub fn walk_alias(visitor: &V, alias: &mut Alias) {} diff --git a/crates/ruff_python_formatter/tests/fixtures.rs b/crates/ruff_python_formatter/tests/fixtures.rs index 7c11a1beb7ccc..c3fd2b20707e8 100644 --- a/crates/ruff_python_formatter/tests/fixtures.rs +++ b/crates/ruff_python_formatter/tests/fixtures.rs @@ -1,11 +1,18 @@ -use ruff_formatter::FormatOptions; -use ruff_python_formatter::{format_module_source, PreviewMode, PyFormatOptions}; -use similar::TextDiff; use std::fmt::{Formatter, Write}; use std::io::BufReader; use std::path::Path; use std::{fmt, fs}; +use similar::TextDiff; + +use crate::normalizer::Normalizer; +use ruff_formatter::FormatOptions; +use ruff_python_ast::comparable::ComparableMod; +use ruff_python_formatter::{format_module_source, PreviewMode, PyFormatOptions}; +use ruff_python_parser::{parse, AsMode}; + +mod normalizer; + #[test] fn black_compatibility() { let test_file = |input_path: &Path| { @@ -33,6 +40,7 @@ fn black_compatibility() { let formatted_code = printed.as_code(); + ensure_unchanged_ast(&content, formatted_code, &options, input_path); ensure_stability_when_formatting_twice(formatted_code, options, input_path); if formatted_code == expected_output { @@ -111,6 +119,7 @@ fn format() { format_module_source(&content, options.clone()).expect("Formatting to succeed"); let formatted_code = printed.as_code(); + ensure_unchanged_ast(&content, formatted_code, &options, input_path); ensure_stability_when_formatting_twice(formatted_code, options.clone(), input_path); let mut snapshot = format!("## Input\n{}", CodeFrame::new("python", &content)); @@ -128,6 +137,7 @@ fn format() { format_module_source(&content, options.clone()).expect("Formatting to succeed"); let formatted_code = printed.as_code(); + ensure_unchanged_ast(&content, formatted_code, &options, input_path); ensure_stability_when_formatting_twice(formatted_code, options.clone(), input_path); writeln!( @@ -140,29 +150,20 @@ fn format() { .unwrap(); } } else { - let printed = - format_module_source(&content, options.clone()).expect("Formatting to succeed"); - let formatted = printed.as_code(); - - ensure_stability_when_formatting_twice(formatted, options.clone(), input_path); - // We want to capture the differences in the preview style in our fixtures let options_preview = options.with_preview(PreviewMode::Enabled); let printed_preview = format_module_source(&content, options_preview.clone()) .expect("Formatting to succeed"); let formatted_preview = printed_preview.as_code(); - ensure_stability_when_formatting_twice( - formatted_preview, - options_preview.clone(), - input_path, - ); + ensure_unchanged_ast(&content, formatted_preview, &options_preview, input_path); + ensure_stability_when_formatting_twice(formatted_preview, options_preview, input_path); - if formatted == formatted_preview { + if formatted_code == formatted_preview { writeln!( snapshot, "## Output\n{}", - CodeFrame::new("python", &formatted) + CodeFrame::new("python", &formatted_code) ) .unwrap(); } else { @@ -171,10 +172,10 @@ fn format() { writeln!( snapshot, "## Output\n{}\n## Preview changes\n{}", - CodeFrame::new("python", &formatted), + CodeFrame::new("python", &formatted_code), CodeFrame::new( "diff", - TextDiff::from_lines(formatted, formatted_preview) + TextDiff::from_lines(formatted_code, formatted_preview) .unified_diff() .header("Stable", "Preview") ) @@ -239,6 +240,57 @@ Formatted twice: } } +/// Ensure that formatting doesn't change the AST. +/// +/// Like Black, there are a few exceptions to this "invariant" which are encoded in +/// [`NormalizedMod`] and related structs. Namely, formatting can change indentation within strings, +/// and can also flatten tuples within `del` statements. +fn ensure_unchanged_ast( + unformatted_code: &str, + formatted_code: &str, + options: &PyFormatOptions, + input_path: &Path, +) { + let source_type = options.source_type(); + + // Parse the unformatted code. + let mut unformatted_ast = parse( + unformatted_code, + source_type.as_mode(), + &input_path.to_string_lossy(), + ) + .expect("Unformatted code to be valid syntax"); + Normalizer.visit_module(&mut unformatted_ast); + let unformatted_ast = ComparableMod::from(&unformatted_ast); + + // Parse the formatted code. + let mut formatted_ast = parse( + formatted_code, + source_type.as_mode(), + &input_path.to_string_lossy(), + ) + .expect("Formatted code to be valid syntax"); + Normalizer.visit_module(&mut formatted_ast); + let formatted_ast = ComparableMod::from(&formatted_ast); + + if formatted_ast != unformatted_ast { + let diff = TextDiff::from_lines( + &format!("{unformatted_ast:#?}"), + &format!("{formatted_ast:#?}"), + ) + .unified_diff() + .header("Unformatted", "Formatted") + .to_string(); + panic!( + r#"Reformatting the unformatted code of {} resulted in AST changes. +--- +{diff} +"#, + input_path.display(), + ); + } +} + struct Header<'a> { title: &'a str, } diff --git a/crates/ruff_python_formatter/tests/normalizer.rs b/crates/ruff_python_formatter/tests/normalizer.rs new file mode 100644 index 0000000000000..5aab798d69333 --- /dev/null +++ b/crates/ruff_python_formatter/tests/normalizer.rs @@ -0,0 +1,83 @@ +use itertools::Either::{Left, Right}; + +use ruff_python_ast::visitor::transformer; +use ruff_python_ast::visitor::transformer::Transformer; +use ruff_python_ast::{self as ast, Expr, Stmt}; + +/// A struct to normalize AST nodes for the purpose of comparing formatted representations for +/// semantic equivalence. +/// +/// Vis-à-vis comparing ASTs, comparing these normalized representations does the following: +/// - Ignores non-abstraction information that we've encoded into the AST, e.g., the difference +/// between `class C: ...` and `class C(): ...`, which is part of our AST but not `CPython`'s. +/// - Normalize strings. The formatter can re-indent docstrings, so we need to compare string +/// contents ignoring whitespace. (Black does the same.) +/// - Ignores nested tuples in deletions. (Black does the same.) +pub(crate) struct Normalizer; + +impl Normalizer { + /// Transform an AST module into a normalized representation. + #[allow(dead_code)] + pub(crate) fn visit_module(&self, module: &mut ast::Mod) { + match module { + ast::Mod::Module(module) => { + self.visit_body(&mut module.body); + } + ast::Mod::Expression(expression) => { + self.visit_expr(&mut expression.body); + } + } + } +} + +impl Transformer for Normalizer { + fn visit_stmt(&self, stmt: &mut Stmt) { + match stmt { + Stmt::ClassDef(class_def) => { + // Treat `class C: ...` and `class C(): ...` equivalently. + if class_def + .arguments + .as_ref() + .is_some_and(|arguments| arguments.is_empty()) + { + class_def.arguments = None; + } + } + Stmt::Delete(delete) => { + // Treat `del a, b` and `del (a, b)` equivalently. + delete.targets = delete + .targets + .clone() + .into_iter() + .flat_map(|target| { + if let Expr::Tuple(tuple) = target { + Left(tuple.elts.into_iter()) + } else { + Right(std::iter::once(target)) + } + }) + .collect(); + } + _ => {} + } + + transformer::walk_stmt(self, stmt); + } + + fn visit_expr(&self, expr: &mut Expr) { + if let Expr::StringLiteral(string_literal) = expr { + // Normalize a string by (1) stripping any leading and trailing space from each + // line, and (2) removing any blank lines from the start and end of the string. + string_literal.value = string_literal + .value + .lines() + .map(str::trim) + .collect::>() + .join("\n") + .trim() + .to_owned(); + } + + transformer::walk_expr(self, expr); + } +}