From 2895e7d1266c661b36bbeeea31c293bdf5fdf0e7 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Fri, 29 Dec 2023 12:46:37 -0400 Subject: [PATCH] Respect mixed `return` and `raise` cases in return-type analysis (#9310) ## Summary Given: ```python from somewhere import get_cfg def lookup_cfg(cfg_description): cfg = get_cfg(cfg_description) if cfg is not None: return cfg raise AttributeError(f"No cfg found matching {cfg_description}") ``` We were analyzing the method from last-to-first statement. So we saw the `raise`, then assumed the method _always_ raised. In reality, though, it _might_ return. This PR improves the branch analysis to respect these mixed cases. Closes https://github.com/astral-sh/ruff/issues/9269. Closes https://github.com/astral-sh/ruff/issues/9304. --- .../flake8_annotations/auto_return_type.py | 35 +++ .../src/rules/flake8_annotations/helpers.rs | 7 +- ..._annotations__tests__auto_return_type.snap | 95 +++++++ ...tations__tests__auto_return_type_py38.snap | 113 +++++++++ crates/ruff_python_ast/src/helpers.rs | 200 --------------- .../ruff_python_semantic/src/analyze/mod.rs | 1 + .../src/analyze/terminal.rs | 234 ++++++++++++++++++ 7 files changed, 482 insertions(+), 203 deletions(-) create mode 100644 crates/ruff_python_semantic/src/analyze/terminal.rs diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_annotations/auto_return_type.py b/crates/ruff_linter/resources/test/fixtures/flake8_annotations/auto_return_type.py index 267a4ac373f39..cc6ca87fc189a 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_annotations/auto_return_type.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_annotations/auto_return_type.py @@ -229,3 +229,38 @@ def overloaded(i: "str") -> "str": def overloaded(i): return i + + +def func(x: int): + if not x: + return 1 + raise ValueError + + +def func(x: int): + if not x: + return 1 + else: + return 2 + raise ValueError + + +def func(): + try: + raise ValueError + except: + return 2 + + +def func(): + try: + return 1 + except: + pass + + +def func(x: int): + for _ in range(3): + if x > 0: + return 1 + raise ValueError diff --git a/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs b/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs index 0ef1bcf262c0f..ad1c6d615bbb6 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs +++ b/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs @@ -3,10 +3,11 @@ use rustc_hash::FxHashSet; use ruff_diagnostics::Edit; use ruff_python_ast::helpers::{ - pep_604_union, typing_optional, typing_union, ReturnStatementVisitor, Terminal, + pep_604_union, typing_optional, typing_union, ReturnStatementVisitor, }; use ruff_python_ast::visitor::Visitor; use ruff_python_ast::{self as ast, Expr, ExprContext}; +use ruff_python_semantic::analyze::terminal::Terminal; use ruff_python_semantic::analyze::type_inference::{NumberLike, PythonType, ResolvedPythonType}; use ruff_python_semantic::analyze::visibility; use ruff_python_semantic::{Definition, SemanticModel}; @@ -61,7 +62,7 @@ pub(crate) fn auto_return_type(function: &ast::StmtFunctionDef) -> Option Option 0: // return 1 // ``` - if terminal.is_none() { + if terminal == Terminal::ConditionalReturn || terminal == Terminal::None { return_type = return_type.union(ResolvedPythonType::Atom(PythonType::None)); } diff --git a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type.snap b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type.snap index d374776949219..e2574509aa676 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type.snap +++ b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type.snap @@ -579,4 +579,99 @@ auto_return_type.py:210:5: ANN201 [*] Missing return type annotation for public 212 212 | raise ValueError 213 213 | else: +auto_return_type.py:234:5: ANN201 [*] Missing return type annotation for public function `func` + | +234 | def func(x: int): + | ^^^^ ANN201 +235 | if not x: +236 | return 1 + | + = help: Add return type annotation: `int` + +ℹ Unsafe fix +231 231 | return i +232 232 | +233 233 | +234 |-def func(x: int): + 234 |+def func(x: int) -> int: +235 235 | if not x: +236 236 | return 1 +237 237 | raise ValueError + +auto_return_type.py:240:5: ANN201 [*] Missing return type annotation for public function `func` + | +240 | def func(x: int): + | ^^^^ ANN201 +241 | if not x: +242 | return 1 + | + = help: Add return type annotation: `int` + +ℹ Unsafe fix +237 237 | raise ValueError +238 238 | +239 239 | +240 |-def func(x: int): + 240 |+def func(x: int) -> int: +241 241 | if not x: +242 242 | return 1 +243 243 | else: + +auto_return_type.py:248:5: ANN201 [*] Missing return type annotation for public function `func` + | +248 | def func(): + | ^^^^ ANN201 +249 | try: +250 | raise ValueError + | + = help: Add return type annotation: `int | None` + +ℹ Unsafe fix +245 245 | raise ValueError +246 246 | +247 247 | +248 |-def func(): + 248 |+def func() -> int | None: +249 249 | try: +250 250 | raise ValueError +251 251 | except: + +auto_return_type.py:255:5: ANN201 [*] Missing return type annotation for public function `func` + | +255 | def func(): + | ^^^^ ANN201 +256 | try: +257 | return 1 + | + = help: Add return type annotation: `int | None` + +ℹ Unsafe fix +252 252 | return 2 +253 253 | +254 254 | +255 |-def func(): + 255 |+def func() -> int | None: +256 256 | try: +257 257 | return 1 +258 258 | except: + +auto_return_type.py:262:5: ANN201 [*] Missing return type annotation for public function `func` + | +262 | def func(x: int): + | ^^^^ ANN201 +263 | for _ in range(3): +264 | if x > 0: + | + = help: Add return type annotation: `int` + +ℹ Unsafe fix +259 259 | pass +260 260 | +261 261 | +262 |-def func(x: int): + 262 |+def func(x: int) -> int: +263 263 | for _ in range(3): +264 264 | if x > 0: +265 265 | return 1 + diff --git a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type_py38.snap b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type_py38.snap index a2fb6448f7cdc..4dfd66c8d4b37 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type_py38.snap +++ b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type_py38.snap @@ -642,4 +642,117 @@ auto_return_type.py:210:5: ANN201 [*] Missing return type annotation for public 212 212 | raise ValueError 213 213 | else: +auto_return_type.py:234:5: ANN201 [*] Missing return type annotation for public function `func` + | +234 | def func(x: int): + | ^^^^ ANN201 +235 | if not x: +236 | return 1 + | + = help: Add return type annotation: `int` + +ℹ Unsafe fix +231 231 | return i +232 232 | +233 233 | +234 |-def func(x: int): + 234 |+def func(x: int) -> int: +235 235 | if not x: +236 236 | return 1 +237 237 | raise ValueError + +auto_return_type.py:240:5: ANN201 [*] Missing return type annotation for public function `func` + | +240 | def func(x: int): + | ^^^^ ANN201 +241 | if not x: +242 | return 1 + | + = help: Add return type annotation: `int` + +ℹ Unsafe fix +237 237 | raise ValueError +238 238 | +239 239 | +240 |-def func(x: int): + 240 |+def func(x: int) -> int: +241 241 | if not x: +242 242 | return 1 +243 243 | else: + +auto_return_type.py:248:5: ANN201 [*] Missing return type annotation for public function `func` + | +248 | def func(): + | ^^^^ ANN201 +249 | try: +250 | raise ValueError + | + = help: Add return type annotation: `Optional[int]` + +ℹ Unsafe fix +214 214 | return 1 +215 215 | +216 216 | +217 |-from typing import overload + 217 |+from typing import overload, Optional +218 218 | +219 219 | +220 220 | @overload +-------------------------------------------------------------------------------- +245 245 | raise ValueError +246 246 | +247 247 | +248 |-def func(): + 248 |+def func() -> Optional[int]: +249 249 | try: +250 250 | raise ValueError +251 251 | except: + +auto_return_type.py:255:5: ANN201 [*] Missing return type annotation for public function `func` + | +255 | def func(): + | ^^^^ ANN201 +256 | try: +257 | return 1 + | + = help: Add return type annotation: `Optional[int]` + +ℹ Unsafe fix +214 214 | return 1 +215 215 | +216 216 | +217 |-from typing import overload + 217 |+from typing import overload, Optional +218 218 | +219 219 | +220 220 | @overload +-------------------------------------------------------------------------------- +252 252 | return 2 +253 253 | +254 254 | +255 |-def func(): + 255 |+def func() -> Optional[int]: +256 256 | try: +257 257 | return 1 +258 258 | except: + +auto_return_type.py:262:5: ANN201 [*] Missing return type annotation for public function `func` + | +262 | def func(x: int): + | ^^^^ ANN201 +263 | for _ in range(3): +264 | if x > 0: + | + = help: Add return type annotation: `int` + +ℹ Unsafe fix +259 259 | pass +260 260 | +261 261 | +262 |-def func(x: int): + 262 |+def func(x: int) -> int: +263 263 | for _ in range(3): +264 264 | if x > 0: +265 265 | return 1 + diff --git a/crates/ruff_python_ast/src/helpers.rs b/crates/ruff_python_ast/src/helpers.rs index 94cbd009c594b..4efe25e3469a2 100644 --- a/crates/ruff_python_ast/src/helpers.rs +++ b/crates/ruff_python_ast/src/helpers.rs @@ -921,206 +921,6 @@ where } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Terminal { - /// Every path through the function ends with a `raise` statement. - Raise, - /// Every path through the function ends with a `return` (or `raise`) statement. - Return, -} - -impl Terminal { - /// Returns the [`Terminal`] behavior of the function, if it can be determined, or `None` if the - /// function contains at least one control flow path that does not end with a `return` or `raise` - /// statement. - pub fn from_function(function: &ast::StmtFunctionDef) -> Option { - /// Returns `true` if the body may break via a `break` statement. - fn sometimes_breaks(stmts: &[Stmt]) -> bool { - for stmt in stmts { - match stmt { - Stmt::For(ast::StmtFor { body, orelse, .. }) => { - if returns(body).is_some() { - return false; - } - if sometimes_breaks(orelse) { - return true; - } - } - Stmt::While(ast::StmtWhile { body, orelse, .. }) => { - if returns(body).is_some() { - return false; - } - if sometimes_breaks(orelse) { - return true; - } - } - Stmt::If(ast::StmtIf { - body, - elif_else_clauses, - .. - }) => { - if std::iter::once(body) - .chain(elif_else_clauses.iter().map(|clause| &clause.body)) - .any(|body| sometimes_breaks(body)) - { - return true; - } - } - Stmt::Match(ast::StmtMatch { cases, .. }) => { - if cases.iter().any(|case| sometimes_breaks(&case.body)) { - return true; - } - } - Stmt::Try(ast::StmtTry { - body, - handlers, - orelse, - finalbody, - .. - }) => { - if sometimes_breaks(body) - || handlers.iter().any(|handler| { - let ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { - body, - .. - }) = handler; - sometimes_breaks(body) - }) - || sometimes_breaks(orelse) - || sometimes_breaks(finalbody) - { - return true; - } - } - Stmt::With(ast::StmtWith { body, .. }) => { - if sometimes_breaks(body) { - return true; - } - } - Stmt::Break(_) => return true, - Stmt::Return(_) => return false, - Stmt::Raise(_) => return false, - _ => {} - } - } - false - } - - /// Returns `true` if the body may break via a `break` statement. - fn always_breaks(stmts: &[Stmt]) -> bool { - for stmt in stmts { - match stmt { - Stmt::Break(_) => return true, - Stmt::Return(_) => return false, - Stmt::Raise(_) => return false, - _ => {} - } - } - false - } - - /// Returns `true` if the body contains a branch that ends without an explicit `return` or - /// `raise` statement. - fn returns(stmts: &[Stmt]) -> Option { - for stmt in stmts.iter().rev() { - match stmt { - Stmt::For(ast::StmtFor { body, orelse, .. }) - | Stmt::While(ast::StmtWhile { body, orelse, .. }) => { - if always_breaks(body) { - return None; - } - if let Some(terminal) = returns(body) { - return Some(terminal); - } - if !sometimes_breaks(body) { - if let Some(terminal) = returns(orelse) { - return Some(terminal); - } - } - } - Stmt::If(ast::StmtIf { - body, - elif_else_clauses, - .. - }) => { - if elif_else_clauses.iter().any(|clause| clause.test.is_none()) { - match Terminal::combine(std::iter::once(returns(body)).chain( - elif_else_clauses.iter().map(|clause| returns(&clause.body)), - )) { - Some(Terminal::Raise) => return Some(Terminal::Raise), - Some(Terminal::Return) => return Some(Terminal::Return), - _ => {} - } - } - } - Stmt::Match(ast::StmtMatch { cases, .. }) => { - // Note: we assume the `match` is exhaustive. - match Terminal::combine(cases.iter().map(|case| returns(&case.body))) { - Some(Terminal::Raise) => return Some(Terminal::Raise), - Some(Terminal::Return) => return Some(Terminal::Return), - _ => {} - } - } - Stmt::Try(ast::StmtTry { - body, - handlers, - orelse, - finalbody, - .. - }) => { - // If the `finally` block returns, the `try` block must also return. - if let Some(terminal) = returns(finalbody) { - return Some(terminal); - } - - // If the body returns, the `try` block must also return. - if returns(body) == Some(Terminal::Return) { - return Some(Terminal::Return); - } - - // If the else block and all the handlers return, the `try` block must also - // return. - if let Some(terminal) = - Terminal::combine(std::iter::once(returns(orelse)).chain( - handlers.iter().map(|handler| { - let ExceptHandler::ExceptHandler( - ast::ExceptHandlerExceptHandler { body, .. }, - ) = handler; - returns(body) - }), - )) - { - return Some(terminal); - } - } - Stmt::With(ast::StmtWith { body, .. }) => { - if let Some(terminal) = returns(body) { - return Some(terminal); - } - } - Stmt::Return(_) => return Some(Terminal::Return), - Stmt::Raise(_) => return Some(Terminal::Raise), - _ => {} - } - } - None - } - - returns(&function.body) - } - - /// Combine a series of [`Terminal`] operators. - fn combine(iter: impl Iterator>) -> Option { - iter.reduce(|acc, terminal| match (acc, terminal) { - (Some(Self::Raise), Some(Self::Raise)) => Some(Self::Raise), - (Some(_), Some(Self::Return)) => Some(Self::Return), - (Some(Self::Return), Some(_)) => Some(Self::Return), - _ => None, - }) - .flatten() - } -} - /// A [`StatementVisitor`] that collects all `raise` statements in a function or method. #[derive(Default)] pub struct RaiseStatementVisitor<'a> { diff --git a/crates/ruff_python_semantic/src/analyze/mod.rs b/crates/ruff_python_semantic/src/analyze/mod.rs index 0376f63c39f43..832da1b481755 100644 --- a/crates/ruff_python_semantic/src/analyze/mod.rs +++ b/crates/ruff_python_semantic/src/analyze/mod.rs @@ -2,6 +2,7 @@ pub mod class; pub mod function_type; pub mod imports; pub mod logging; +pub mod terminal; pub mod type_inference; pub mod typing; pub mod visibility; diff --git a/crates/ruff_python_semantic/src/analyze/terminal.rs b/crates/ruff_python_semantic/src/analyze/terminal.rs new file mode 100644 index 0000000000000..f5642876d2208 --- /dev/null +++ b/crates/ruff_python_semantic/src/analyze/terminal.rs @@ -0,0 +1,234 @@ +use ruff_python_ast::{self as ast, ExceptHandler, Stmt}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Terminal { + /// There is no known terminal (e.g., an implicit return). + None, + /// Every path through the function ends with a `raise` statement. + Raise, + /// No path through the function ends with a `return` statement. + Return, + /// Every path through the function ends with a `return` or `raise` statement. + Explicit, + /// At least one path through the function ends with a `return` statement. + ConditionalReturn, +} + +impl Terminal { + /// Returns the [`Terminal`] behavior of the function, if it can be determined. + pub fn from_function(function: &ast::StmtFunctionDef) -> Terminal { + Self::from_body(&function.body) + } + + /// Returns the [`Terminal`] behavior of the body, if it can be determined. + fn from_body(stmts: &[Stmt]) -> Terminal { + let mut terminal = Terminal::None; + + for stmt in stmts { + match stmt { + Stmt::For(ast::StmtFor { body, orelse, .. }) + | Stmt::While(ast::StmtWhile { body, orelse, .. }) => { + if always_breaks(body) { + continue; + } + + terminal = terminal.union(Self::from_body(body)); + + if !sometimes_breaks(body) { + terminal = terminal.union(Self::from_body(orelse)); + } + } + Stmt::If(ast::StmtIf { + body, + elif_else_clauses, + .. + }) => { + let branch_terminal = Terminal::combine( + std::iter::once(Self::from_body(body)).chain( + elif_else_clauses + .iter() + .map(|clause| Self::from_body(&clause.body)), + ), + ); + + // If the `if` statement is known to be exhaustive (by way of including an + // `else`)... + if elif_else_clauses.iter().any(|clause| clause.test.is_none()) { + // And all branches return, then the `if` statement returns. + terminal = terminal.union(branch_terminal); + } else if branch_terminal.has_return() { + // Otherwise, if any branch returns, we know this can't be a + // non-returning function. + terminal = terminal.union(Terminal::ConditionalReturn); + } + } + Stmt::Match(ast::StmtMatch { cases, .. }) => { + // Note: we assume the `match` is exhaustive. + terminal = terminal.union(Terminal::combine( + cases.iter().map(|case| Self::from_body(&case.body)), + )); + } + Stmt::Try(ast::StmtTry { + handlers, + orelse, + finalbody, + .. + }) => { + // If the `finally` block returns, the `try` block must also return. + terminal = terminal.union(Self::from_body(finalbody)); + + // If the else block and all the handlers return, the `try` block must also + // return. + let branch_terminal = + Terminal::combine(std::iter::once(Self::from_body(orelse)).chain( + handlers.iter().map(|handler| { + let ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { + body, + .. + }) = handler; + Self::from_body(body) + }), + )); + + if orelse.is_empty() { + // If there's no `else`, we may fall through. + if branch_terminal.has_return() { + terminal = terminal.union(Terminal::ConditionalReturn); + } + } else { + // If there's an `else`, we may not fall through. + terminal = terminal.union(branch_terminal); + } + } + Stmt::With(ast::StmtWith { body, .. }) => { + terminal = terminal.union(Self::from_body(body)); + } + Stmt::Return(_) => { + terminal = terminal.union(Terminal::Explicit); + } + Stmt::Raise(_) => { + terminal = terminal.union(Terminal::Raise); + } + _ => {} + } + } + terminal + } + + /// Returns `true` if the [`Terminal`] behavior includes at least one `return` path. + fn has_return(self) -> bool { + matches!( + self, + Self::Return | Self::Explicit | Self::ConditionalReturn + ) + } + + /// Combine two [`Terminal`] operators. + fn union(self, other: Self) -> Self { + match (self, other) { + (Self::None, other) => other, + (other, Self::None) => other, + (Self::Explicit, _) => Self::Explicit, + (_, Self::Explicit) => Self::Explicit, + (Self::ConditionalReturn, Self::ConditionalReturn) => Self::ConditionalReturn, + (Self::Raise, Self::ConditionalReturn) => Self::Explicit, + (Self::ConditionalReturn, Self::Raise) => Self::Explicit, + (Self::Return, Self::ConditionalReturn) => Self::Return, + (Self::ConditionalReturn, Self::Return) => Self::Return, + (Self::Raise, Self::Raise) => Self::Raise, + (Self::Return, Self::Return) => Self::Return, + (Self::Raise, Self::Return) => Self::Explicit, + (Self::Return, Self::Raise) => Self::Explicit, + } + } + + /// Combine a series of [`Terminal`] operators. + fn combine(iter: impl Iterator) -> Terminal { + iter.fold(Terminal::None, Self::union) + } +} + +/// Returns `true` if the body may break via a `break` statement. +fn sometimes_breaks(stmts: &[Stmt]) -> bool { + for stmt in stmts { + match stmt { + Stmt::For(ast::StmtFor { body, orelse, .. }) => { + if Terminal::from_body(body).has_return() { + return false; + } + if sometimes_breaks(orelse) { + return true; + } + } + Stmt::While(ast::StmtWhile { body, orelse, .. }) => { + if Terminal::from_body(body).has_return() { + return false; + } + if sometimes_breaks(orelse) { + return true; + } + } + Stmt::If(ast::StmtIf { + body, + elif_else_clauses, + .. + }) => { + if std::iter::once(body) + .chain(elif_else_clauses.iter().map(|clause| &clause.body)) + .any(|body| sometimes_breaks(body)) + { + return true; + } + } + Stmt::Match(ast::StmtMatch { cases, .. }) => { + if cases.iter().any(|case| sometimes_breaks(&case.body)) { + return true; + } + } + Stmt::Try(ast::StmtTry { + body, + handlers, + orelse, + finalbody, + .. + }) => { + if sometimes_breaks(body) + || handlers.iter().any(|handler| { + let ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { + body, + .. + }) = handler; + sometimes_breaks(body) + }) + || sometimes_breaks(orelse) + || sometimes_breaks(finalbody) + { + return true; + } + } + Stmt::With(ast::StmtWith { body, .. }) => { + if sometimes_breaks(body) { + return true; + } + } + Stmt::Break(_) => return true, + Stmt::Return(_) => return false, + Stmt::Raise(_) => return false, + _ => {} + } + } + false +} + +/// Returns `true` if the body may break via a `break` statement. +fn always_breaks(stmts: &[Stmt]) -> bool { + for stmt in stmts { + match stmt { + Stmt::Break(_) => return true, + Stmt::Return(_) => return false, + Stmt::Raise(_) => return false, + _ => {} + } + } + false +}