diff --git a/crates/ruff/resources/test/fixtures/flake8_pytest_style/PT018.py b/crates/ruff/resources/test/fixtures/flake8_pytest_style/PT018.py index 4b70bae72e36ad..0eafac016e6817 100644 --- a/crates/ruff/resources/test/fixtures/flake8_pytest_style/PT018.py +++ b/crates/ruff/resources/test/fixtures/flake8_pytest_style/PT018.py @@ -52,3 +52,21 @@ def test_multiline(): x = 1; \ assert something and something_else + + +# Regression test for: https://github.com/astral-sh/ruff/issues/7143 +def test_parenthesized_not(): + assert not ( + self.find_graph_output(node.output[0]) + or self.find_graph_input(node.input[0]) + or self.find_graph_output(node.input[0]) + ) + + assert (not ( + self.find_graph_output(node.output[0]) + or self.find_graph_input(node.input[0]) + or self.find_graph_output(node.input[0]) + )) + + assert (not self.find_graph_output(node.output[0]) or + self.find_graph_input(node.input[0])) diff --git a/crates/ruff/src/rules/flake8_pytest_style/rules/assertion.rs b/crates/ruff/src/rules/flake8_pytest_style/rules/assertion.rs index 816348f18774c4..498f55e244fccb 100644 --- a/crates/ruff/src/rules/flake8_pytest_style/rules/assertion.rs +++ b/crates/ruff/src/rules/flake8_pytest_style/rules/assertion.rs @@ -598,7 +598,7 @@ fn negate<'a>(expression: &Expression<'a>) -> Expression<'a> { /// assert (b == /// "") /// ``` -fn parenthesize<'a>(expression: Expression<'a>, parent: &Expression<'a>) -> Expression<'a> { +fn parenthesize<'a>(expression: &Expression<'a>, parent: &Expression<'a>) -> Expression<'a> { if matches!( expression, Expression::Comparison(_) @@ -626,10 +626,10 @@ fn parenthesize<'a>(expression: Expression<'a>, parent: &Expression<'a>) -> Expr | Expression::NamedExpr(_) ) { if let (Some(left), Some(right)) = (parent.lpar().first(), parent.rpar().first()) { - return expression.with_parens(left.clone(), right.clone()); + return expression.clone().with_parens(left.clone(), right.clone()); } } - expression + expression.clone() } /// Replace composite condition `assert a == "hello" and b == "world"` with two statements @@ -685,10 +685,16 @@ fn fix_composite_condition(stmt: &Stmt, locator: &Locator, stylist: &Stylist) -> match &assert_statement.test { Expression::UnaryOperation(op) => { if matches!(op.operator, libcst_native::UnaryOp::Not { .. }) { - if let Expression::BooleanOperation(op) = &*op.expression { - if matches!(op.operator, BooleanOp::Or { .. }) { - conditions.push(parenthesize(negate(&op.left), &assert_statement.test)); - conditions.push(parenthesize(negate(&op.right), &assert_statement.test)); + if let Expression::BooleanOperation(boolean_operation) = &*op.expression { + if matches!(boolean_operation.operator, BooleanOp::Or { .. }) { + conditions.push(negate(&parenthesize( + &boolean_operation.left, + &op.expression, + ))); + conditions.push(negate(&parenthesize( + &*boolean_operation.right, + &op.expression, + ))); } else { bail!("Expected assert statement to be a composite condition"); } @@ -699,8 +705,8 @@ fn fix_composite_condition(stmt: &Stmt, locator: &Locator, stylist: &Stylist) -> } Expression::BooleanOperation(op) => { if matches!(op.operator, BooleanOp::And { .. }) { - conditions.push(parenthesize(*op.left.clone(), &assert_statement.test)); - conditions.push(parenthesize(*op.right.clone(), &assert_statement.test)); + conditions.push(parenthesize(&op.left, &assert_statement.test)); + conditions.push(parenthesize(&op.right, &assert_statement.test)); } else { bail!("Expected assert statement to be a composite condition"); } diff --git a/crates/ruff/src/rules/flake8_pytest_style/snapshots/ruff__rules__flake8_pytest_style__tests__PT018.snap b/crates/ruff/src/rules/flake8_pytest_style/snapshots/ruff__rules__flake8_pytest_style__tests__PT018.snap index d9c9771a7d5d8f..1d24f27b3a6be9 100644 --- a/crates/ruff/src/rules/flake8_pytest_style/snapshots/ruff__rules__flake8_pytest_style__tests__PT018.snap +++ b/crates/ruff/src/rules/flake8_pytest_style/snapshots/ruff__rules__flake8_pytest_style__tests__PT018.snap @@ -148,7 +148,7 @@ PT018.py:20:5: PT018 [*] Assertion should be broken down into multiple parts 18 18 | assert not something and something_else 19 19 | assert not (something or something_else) 20 |- assert not (something or something_else or something_third) - 20 |+ assert not something or something_else + 20 |+ assert not (something or something_else) 21 |+ assert not something_third 21 22 | assert something and something_else == """error 22 23 | message @@ -351,4 +351,67 @@ PT018.py:54:9: PT018 Assertion should be broken down into multiple parts | = help: Break down assertion into multiple parts +PT018.py:59:5: PT018 [*] Assertion should be broken down into multiple parts + | +57 | # Regression test for: https://github.com/astral-sh/ruff/issues/7143 +58 | def test_parenthesized_not(): +59 | assert not ( + | _____^ +60 | | self.find_graph_output(node.output[0]) +61 | | or self.find_graph_input(node.input[0]) +62 | | or self.find_graph_output(node.input[0]) +63 | | ) + | |_____^ PT018 +64 | +65 | assert (not ( + | + = help: Break down assertion into multiple parts + +ℹ Suggested fix +59 59 | assert not ( +60 60 | self.find_graph_output(node.output[0]) +61 61 | or self.find_graph_input(node.input[0]) +62 |- or self.find_graph_output(node.input[0]) +63 62 | ) + 63 |+ assert not ( + 64 |+ self.find_graph_output(node.input[0]) + 65 |+ ) +64 66 | +65 67 | assert (not ( +66 68 | self.find_graph_output(node.output[0]) + +PT018.py:65:5: PT018 [*] Assertion should be broken down into multiple parts + | +63 | ) +64 | +65 | assert (not ( + | _____^ +66 | | self.find_graph_output(node.output[0]) +67 | | or self.find_graph_input(node.input[0]) +68 | | or self.find_graph_output(node.input[0]) +69 | | )) + | |______^ PT018 +70 | +71 | assert (not self.find_graph_output(node.output[0]) or + | + = help: Break down assertion into multiple parts + +ℹ Suggested fix +62 62 | or self.find_graph_output(node.input[0]) +63 63 | ) +64 64 | +65 |- assert (not ( + 65 |+ assert not ( +66 66 | self.find_graph_output(node.output[0]) +67 67 | or self.find_graph_input(node.input[0]) +68 |- or self.find_graph_output(node.input[0]) +69 |- )) + 68 |+ ) + 69 |+ assert not ( + 70 |+ self.find_graph_output(node.input[0]) + 71 |+ ) +70 72 | +71 73 | assert (not self.find_graph_output(node.output[0]) or +72 74 | self.find_graph_input(node.input[0])) +