diff --git a/crates/ruff_python_ast/src/node.rs b/crates/ruff_python_ast/src/node.rs index 35cfb9147cd81..17725fe0d65ae 100644 --- a/crates/ruff_python_ast/src/node.rs +++ b/crates/ruff_python_ast/src/node.rs @@ -4817,7 +4817,7 @@ pub enum AnyNodeRef<'a> { ElifElseClause(&'a ast::ElifElseClause), } -impl AnyNodeRef<'_> { +impl<'a> AnyNodeRef<'a> { pub fn as_ptr(&self) -> NonNull<()> { match self { AnyNodeRef::ModModule(node) => NonNull::from(*node).cast(), @@ -5456,9 +5456,9 @@ impl AnyNodeRef<'_> { ) } - pub fn visit_preorder<'a, V>(&'a self, visitor: &mut V) + pub fn visit_preorder<'b, V>(&'b self, visitor: &mut V) where - V: PreorderVisitor<'a> + ?Sized, + V: PreorderVisitor<'b> + ?Sized, { match self { AnyNodeRef::ModModule(node) => node.visit_preorder(visitor), @@ -5544,6 +5544,66 @@ impl AnyNodeRef<'_> { AnyNodeRef::ElifElseClause(node) => node.visit_preorder(visitor), } } + + /// The last child of the last branch, if the node has multiple branches. + pub fn last_child_in_body(&self) -> Option> { + let body = match self { + AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. }) + | AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. }) + | AnyNodeRef::StmtWith(ast::StmtWith { body, .. }) + | AnyNodeRef::MatchCase(MatchCase { body, .. }) + | AnyNodeRef::ExceptHandlerExceptHandler(ast::ExceptHandlerExceptHandler { + body, + .. + }) + | AnyNodeRef::ElifElseClause(ast::ElifElseClause { body, .. }) => body, + AnyNodeRef::StmtIf(ast::StmtIf { + body, + elif_else_clauses, + .. + }) => elif_else_clauses.last().map_or(body, |clause| &clause.body), + + AnyNodeRef::StmtFor(ast::StmtFor { body, orelse, .. }) + | AnyNodeRef::StmtWhile(ast::StmtWhile { body, orelse, .. }) => { + if orelse.is_empty() { + body + } else { + orelse + } + } + + AnyNodeRef::StmtMatch(ast::StmtMatch { cases, .. }) => { + return cases.last().map(AnyNodeRef::from); + } + + AnyNodeRef::StmtTry(ast::StmtTry { + body, + handlers, + orelse, + finalbody, + .. + }) => { + if finalbody.is_empty() { + if orelse.is_empty() { + if handlers.is_empty() { + body + } else { + return handlers.last().map(AnyNodeRef::from); + } + } else { + orelse + } + } else { + finalbody + } + } + + // Not a node that contains an indented child node. + _ => return None, + }; + + body.last().map(AnyNodeRef::from) + } } impl<'a> From<&'a ast::ModModule> for AnyNodeRef<'a> { diff --git a/crates/ruff_python_formatter/resources/test/fixtures/ruff/newlines.py b/crates/ruff_python_formatter/resources/test/fixtures/ruff/newlines.py index 0d33408ecbf4d..9e355f3bcf1fa 100644 --- a/crates/ruff_python_formatter/resources/test/fixtures/ruff/newlines.py +++ b/crates/ruff_python_formatter/resources/test/fixtures/ruff/newlines.py @@ -1,6 +1,7 @@ ### # Blank lines around functions ### +import sys x = 1 @@ -159,3 +160,96 @@ def f(): # comment x = 1 + + +def f(): + if True: + + def double(s): + return s + s + print("below function") + if True: + + class A: + x = 1 + print("below class") + if True: + + def double(s): + return s + s + # + print("below comment function") + if True: + + class A: + x = 1 + # + print("below comment class") + if True: + + def double(s): + return s + s + # + print("below comment function 2") + if True: + + def double(s): + return s + s + # + def outer(): + def inner(): + pass + print("below nested functions") + +if True: + + def double(s): + return s + s +print("below function") +if True: + + class A: + x = 1 +print("below class") +def outer(): + def inner(): + pass +print("below nested functions") + + +class Path: + if sys.version_info >= (3, 11): + def joinpath(self): ... + + # The .open method comes from pathlib.pyi and should be kept in sync. + @overload + def open(self): ... + + + + +def fakehttp(): + + class FakeHTTPConnection: + if mock_close: + def close(self): + pass + FakeHTTPConnection.fakedata = fakedata + + + + + +if True: + if False: + def x(): + def y(): + pass + #comment + print() + + +# NOTE: Please keep this the last block in this file +if True: + def nested_trailing_function(): + pass \ No newline at end of file diff --git a/crates/ruff_python_formatter/src/comments/placement.rs b/crates/ruff_python_formatter/src/comments/placement.rs index 88b72cdd89430..86d63a2c8cf0f 100644 --- a/crates/ruff_python_formatter/src/comments/placement.rs +++ b/crates/ruff_python_formatter/src/comments/placement.rs @@ -347,9 +347,9 @@ fn handle_end_of_line_comment_around_body<'a>( // ``` // The first earlier branch filters out ambiguities e.g. around try-except-finally. if let Some(preceding) = comment.preceding_node() { - if let Some(last_child) = last_child_in_body(preceding) { + if let Some(last_child) = preceding.last_child_in_body() { let innermost_child = - std::iter::successors(Some(last_child), |parent| last_child_in_body(*parent)) + std::iter::successors(Some(last_child), AnyNodeRef::last_child_in_body) .last() .unwrap_or(last_child); return CommentPlacement::trailing(innermost_child, comment); @@ -670,7 +670,7 @@ fn handle_own_line_comment_after_branch<'a>( preceding: AnyNodeRef<'a>, locator: &Locator, ) -> CommentPlacement<'a> { - let Some(last_child) = last_child_in_body(preceding) else { + let Some(last_child) = preceding.last_child_in_body() else { return CommentPlacement::Default(comment); }; @@ -734,7 +734,7 @@ fn handle_own_line_comment_after_branch<'a>( return CommentPlacement::trailing(last_child_in_parent, comment); } Ordering::Greater => { - if let Some(nested_child) = last_child_in_body(last_child_in_parent) { + if let Some(nested_child) = last_child_in_parent.last_child_in_body() { // The comment belongs to the inner block. parent = Some(last_child_in_parent); last_child_in_parent = nested_child; @@ -2176,65 +2176,6 @@ where right.is_some_and(|right| left.ptr_eq(right.into())) } -/// The last child of the last branch, if the node has multiple branches. -fn last_child_in_body(node: AnyNodeRef) -> Option { - let body = match node { - AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. }) - | AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. }) - | AnyNodeRef::StmtWith(ast::StmtWith { body, .. }) - | AnyNodeRef::MatchCase(MatchCase { body, .. }) - | AnyNodeRef::ExceptHandlerExceptHandler(ast::ExceptHandlerExceptHandler { - body, .. - }) - | AnyNodeRef::ElifElseClause(ast::ElifElseClause { body, .. }) => body, - AnyNodeRef::StmtIf(ast::StmtIf { - body, - elif_else_clauses, - .. - }) => elif_else_clauses.last().map_or(body, |clause| &clause.body), - - AnyNodeRef::StmtFor(ast::StmtFor { body, orelse, .. }) - | AnyNodeRef::StmtWhile(ast::StmtWhile { body, orelse, .. }) => { - if orelse.is_empty() { - body - } else { - orelse - } - } - - AnyNodeRef::StmtMatch(ast::StmtMatch { cases, .. }) => { - return cases.last().map(AnyNodeRef::from); - } - - AnyNodeRef::StmtTry(ast::StmtTry { - body, - handlers, - orelse, - finalbody, - .. - }) => { - if finalbody.is_empty() { - if orelse.is_empty() { - if handlers.is_empty() { - body - } else { - return handlers.last().map(AnyNodeRef::from); - } - } else { - orelse - } - } else { - finalbody - } - } - - // Not a node that contains an indented child node. - _ => return None, - }; - - body.last().map(AnyNodeRef::from) -} - /// Returns `true` if `statement` is the first statement in an alternate `body` (e.g. the else of an if statement) fn is_first_statement_in_alternate_body(statement: AnyNodeRef, has_body: AnyNodeRef) -> bool { match has_body { diff --git a/crates/ruff_python_formatter/src/statement/suite.rs b/crates/ruff_python_formatter/src/statement/suite.rs index 9575f78d86a77..c6361c53bb3f7 100644 --- a/crates/ruff_python_formatter/src/statement/suite.rs +++ b/crates/ruff_python_formatter/src/statement/suite.rs @@ -155,13 +155,63 @@ impl FormatRule> for FormatSuite { while let Some(following) = iter.next() { let following_comments = comments.leading_dangling_trailing(following); + let needs_empty_lines = if is_class_or_function_definition(following) { + // Here we insert empty lines even if the preceding has a trailing own line comment + true + } else if preceding_comments.has_trailing_own_line() { + // If there is a comment between preceding and following the empty lines were + // inserted before the comment by preceding and there are no extra empty lines after + // the comment, which also includes nested class/function definitions. + // ```python + // class Test: + // def a(self): + // pass + // # trailing comment + // + // + // # two lines before, one line after + // + // c = 30 + false + } else { + // Find nested class or function definitions that need an empty line after them. + // + // ```python + // def f(): + // if True: + // + // def double(s): + // return s + s + // + // print("below function") + // ``` + // Again, a empty lines are inserted before comments + // ```python + // if True: + // + // def double(s): + // return s + s + // + // # + // print("below comment function") + // ``` + std::iter::successors(Some(AnyNodeRef::from(preceding)), |parent| { + parent + .last_child_in_body() + .filter(|last_child| !comments.has_trailing_own_line(*last_child)) + }) + .any(|last_child| { + matches!( + last_child, + AnyNodeRef::StmtFunctionDef(_) | AnyNodeRef::StmtClassDef(_) + ) + }) + }; + // Add empty lines before and after a function or class definition. If the preceding // node is a function or class, and contains trailing comments, then the statement // itself will add the requisite empty lines when formatting its comments. - if (is_class_or_function_definition(preceding) - && !preceding_comments.has_trailing_own_line()) - || is_class_or_function_definition(following) - { + if needs_empty_lines { if source_type.is_stub() { stub_file_empty_lines( self.kind, diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__function2.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__function2.py.snap index 2570ff4de030a..2d1ed7c6ad0d6 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__function2.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__function2.py.snap @@ -73,7 +73,7 @@ with hmm_but_this_should_get_two_preceding_newlines(): elif os.name == "nt": try: import msvcrt -@@ -54,12 +53,10 @@ +@@ -54,7 +53,6 @@ class IHopeYouAreHavingALovelyDay: def __call__(self): print("i_should_be_followed_by_only_one_newline") @@ -81,11 +81,6 @@ with hmm_but_this_should_get_two_preceding_newlines(): else: def foo(): - pass -- - - with hmm_but_this_should_get_two_preceding_newlines(): - pass ``` ## Ruff Output @@ -151,6 +146,7 @@ else: def foo(): pass + with hmm_but_this_should_get_two_preceding_newlines(): pass ``` diff --git a/crates/ruff_python_formatter/tests/snapshots/format@newlines.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@newlines.py.snap index 654d55dbb9789..a0873d299eafa 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@newlines.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@newlines.py.snap @@ -7,6 +7,7 @@ input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/newlines.p ### # Blank lines around functions ### +import sys x = 1 @@ -165,13 +166,106 @@ def f(): # comment x = 1 -``` + + +def f(): + if True: + + def double(s): + return s + s + print("below function") + if True: + + class A: + x = 1 + print("below class") + if True: + + def double(s): + return s + s + # + print("below comment function") + if True: + + class A: + x = 1 + # + print("below comment class") + if True: + + def double(s): + return s + s + # + print("below comment function 2") + if True: + + def double(s): + return s + s + # + def outer(): + def inner(): + pass + print("below nested functions") + +if True: + + def double(s): + return s + s +print("below function") +if True: + + class A: + x = 1 +print("below class") +def outer(): + def inner(): + pass +print("below nested functions") + + +class Path: + if sys.version_info >= (3, 11): + def joinpath(self): ... + + # The .open method comes from pathlib.pyi and should be kept in sync. + @overload + def open(self): ... + + + + +def fakehttp(): + + class FakeHTTPConnection: + if mock_close: + def close(self): + pass + FakeHTTPConnection.fakedata = fakedata + + + + + +if True: + if False: + def x(): + def y(): + pass + #comment + print() + + +# NOTE: Please keep this the last block in this file +if True: + def nested_trailing_function(): + pass``` ## Output ```py ### # Blank lines around functions ### +import sys x = 1 @@ -339,6 +433,117 @@ def f(): # comment x = 1 + + +def f(): + if True: + + def double(s): + return s + s + + print("below function") + if True: + + class A: + x = 1 + + print("below class") + if True: + + def double(s): + return s + s + + # + print("below comment function") + if True: + + class A: + x = 1 + + # + print("below comment class") + if True: + + def double(s): + return s + s + + # + print("below comment function 2") + if True: + + def double(s): + return s + s + # + + def outer(): + def inner(): + pass + + print("below nested functions") + + +if True: + + def double(s): + return s + s + + +print("below function") +if True: + + class A: + x = 1 + + +print("below class") + + +def outer(): + def inner(): + pass + + +print("below nested functions") + + +class Path: + if sys.version_info >= (3, 11): + + def joinpath(self): + ... + + # The .open method comes from pathlib.pyi and should be kept in sync. + @overload + def open(self): + ... + + +def fakehttp(): + class FakeHTTPConnection: + if mock_close: + + def close(self): + pass + + FakeHTTPConnection.fakedata = fakedata + + +if True: + if False: + + def x(): + def y(): + pass + + # comment + print() + + +# NOTE: Please keep this the last block in this file +if True: + + def nested_trailing_function(): + pass ``` diff --git a/crates/ruff_python_formatter/tests/snapshots/format@statement__if.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@statement__if.py.snap index 3795d7f03922f..c72aa5423766c 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@statement__if.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@statement__if.py.snap @@ -410,6 +410,7 @@ else: pass # 3 + if True: print("a") # 1 elif True: diff --git a/crates/ruff_python_formatter/tests/snapshots/format@statement__try.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@statement__try.py.snap index e4fe04e4348cb..b3facb5640191 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@statement__try.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@statement__try.py.snap @@ -311,6 +311,7 @@ finally: pass # d + try: pass # a except ZeroDivisionError: