From f1cee22274c7ad2c3d1cd7e8b17503a8c93e2959 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Fri, 29 Dec 2023 18:49:37 +0000 Subject: [PATCH 1/3] [flake8-pyi] Implement PYI058 --- .../test/fixtures/flake8_pyi/PYI058.py | 82 ++++++++ .../test/fixtures/flake8_pyi/PYI058.pyi | 58 ++++++ .../src/checkers/ast/analyze/statement.rs | 11 + crates/ruff_linter/src/codes.rs | 1 + .../ruff_linter/src/rules/flake8_pyi/mod.rs | 2 + .../rules/bad_generator_return_type.rs | 191 ++++++++++++++++++ .../src/rules/flake8_pyi/rules/mod.rs | 2 + ...__flake8_pyi__tests__PYI058_PYI058.py.snap | 57 ++++++ ..._flake8_pyi__tests__PYI058_PYI058.pyi.snap | 58 ++++++ ruff.schema.json | 1 + 10 files changed, 463 insertions(+) create mode 100644 crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.py create mode 100644 crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.pyi create mode 100644 crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs create mode 100644 crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.py.snap create mode 100644 crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.pyi.snap diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.py b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.py new file mode 100644 index 0000000000000..810a084c62d00 --- /dev/null +++ b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.py @@ -0,0 +1,82 @@ +import collections.abc +import typing +from collections.abc import AsyncGenerator, Generator +from typing import Any + +class IteratorReturningSimpleGenerator1: + def __iter__(self) -> Generator: # PYI058 (use `Iterator`) + return (x for x in range(42)) + +class IteratorReturningSimpleGenerator2: + def __iter__(self, /) -> collections.abc.Generator[str, Any, None]: # PYI058 (use `Iterator`) + """Fully documented, because I'm a runtime function!""" + yield from "abcdefg" + return None + +class IteratorReturningSimpleGenerator3: + def __iter__(self, /) -> collections.abc.Generator[str, None, typing.Any]: # PYI058 (use `Iterator`) + yield "a" + yield "b" + yield "c" + return + +class AsyncIteratorReturningSimpleAsyncGenerator1: + def __aiter__(self) -> typing.AsyncGenerator: pass # PYI058 (Use `AsyncIterator`) + +class AsyncIteratorReturningSimpleAsyncGenerator2: + def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, Any]: ... # PYI058 (Use `AsyncIterator`) + +class AsyncIteratorReturningSimpleAsyncGenerator3: + def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, None]: pass # PYI058 (Use `AsyncIterator`) + +class CorrectIterator: + def __iter__(self) -> Iterator[str]: ... # OK + +class CorrectAsyncIterator: + def __aiter__(self) -> collections.abc.AsyncIterator[int]: ... # OK + +class Fine: + def __iter__(self) -> typing.Self: ... # OK + +class StrangeButWeWontComplainHere: + def __aiter__(self) -> list[bytes]: ... # OK + +def __iter__(self) -> Generator: ... # OK (not in class scope) +def __aiter__(self) -> AsyncGenerator: ... # OK (not in class scope) + +class IteratorReturningComplexGenerator: + def __iter__(self) -> Generator[str, int, bytes]: ... # OK + +class AsyncIteratorReturningComplexAsyncGenerator: + def __aiter__(self) -> AsyncGenerator[str, int]: ... # OK + +class ClassWithInvalidAsyncAiterMethod: + async def __aiter__(self) -> AsyncGenerator: ... # OK + +class IteratorWithUnusualParameters1: + def __iter__(self, foo) -> Generator: ... # OK + +class IteratorWithUnusualParameters2: + def __iter__(self, *, bar) -> Generator: ... # OK + +class IteratorWithUnusualParameters3: + def __iter__(self, *args) -> Generator: ... # OK + +class IteratorWithUnusualParameters4: + def __iter__(self, **kwargs) -> Generator: ... # OK + +class IteratorWithIterMethodThatReturnsThings: + def __iter__(self) -> Generator: # OK + yield + return 42 + +class IteratorWithIterMethodThatReceivesThingsFromSend: + def __iter__(self) -> Generator: # OK + x = yield 42 + +class IteratorWithNonTrivialIterBody: + def __iter__(self) -> Generator: # OK + foo, bar, baz = (1, 2, 3) + yield foo + yield bar + yield baz diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.pyi b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.pyi new file mode 100644 index 0000000000000..ce6e78a68b311 --- /dev/null +++ b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.pyi @@ -0,0 +1,58 @@ +import collections.abc +import typing +from collections.abc import AsyncGenerator, Generator +from typing import Any + +class IteratorReturningSimpleGenerator1: + def __iter__(self) -> Generator: ... # PYI058 (use `Iterator`) + +class IteratorReturningSimpleGenerator2: + def __iter__(self, /) -> collections.abc.Generator[str, Any, None]: ... # PYI058 (use `Iterator`) + +class IteratorReturningSimpleGenerator3: + def __iter__(self, /) -> collections.abc.Generator[str, None, typing.Any]: ... # PYI058 (use `Iterator`) + +class AsyncIteratorReturningSimpleAsyncGenerator1: + def __aiter__(self) -> typing.AsyncGenerator: ... # PYI058 (Use `AsyncIterator`) + +class AsyncIteratorReturningSimpleAsyncGenerator2: + def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, Any]: ... # PYI058 (Use `AsyncIterator`) + +class AsyncIteratorReturningSimpleAsyncGenerator3: + def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, None]: ... # PYI058 (Use `AsyncIterator`) + +class CorrectIterator: + def __iter__(self) -> Iterator[str]: ... # OK + +class CorrectAsyncIterator: + def __aiter__(self) -> collections.abc.AsyncIterator[int]: ... # OK + +class Fine: + def __iter__(self) -> typing.Self: ... # OK + +class StrangeButWeWontComplainHere: + def __aiter__(self) -> list[bytes]: ... # OK + +def __iter__(self) -> Generator: ... # OK (not in class scope) +def __aiter__(self) -> AsyncGenerator: ... # OK (not in class scope) + +class IteratorReturningComplexGenerator: + def __iter__(self) -> Generator[str, int, bytes]: ... # OK + +class AsyncIteratorReturningComplexAsyncGenerator: + def __aiter__(self) -> AsyncGenerator[str, int]: ... # OK + +class ClassWithInvalidAsyncAiterMethod: + async def __aiter__(self) -> AsyncGenerator: ... # OK + +class IteratorWithUnusualParameters1: + def __iter__(self, foo) -> Generator: ... # OK + +class IteratorWithUnusualParameters2: + def __iter__(self, *, bar) -> Generator: ... # OK + +class IteratorWithUnusualParameters3: + def __iter__(self, *args) -> Generator: ... # OK + +class IteratorWithUnusualParameters4: + def __iter__(self, **kwargs) -> Generator: ... # OK diff --git a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs index c803562bd5bf5..b344048fa727a 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs @@ -154,6 +154,17 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { parameters, ); } + if checker.enabled(Rule::GeneratorReturnFromIterMethod) { + flake8_pyi::rules::bad_generator_return_type( + checker, + stmt, + *is_async, + name, + returns.as_ref().map(AsRef::as_ref), + parameters, + body, + ); + } if checker.enabled(Rule::CustomTypeVarReturnType) { flake8_pyi::rules::custom_type_var_return_type( checker, diff --git a/crates/ruff_linter/src/codes.rs b/crates/ruff_linter/src/codes.rs index 8614dcda65ad0..0a2dbe3f2ae5c 100644 --- a/crates/ruff_linter/src/codes.rs +++ b/crates/ruff_linter/src/codes.rs @@ -747,6 +747,7 @@ pub fn code_to_rule(linter: Linter, code: &str) -> Option<(RuleGroup, Rule)> { (Flake8Pyi, "053") => (RuleGroup::Stable, rules::flake8_pyi::rules::StringOrBytesTooLong), (Flake8Pyi, "055") => (RuleGroup::Stable, rules::flake8_pyi::rules::UnnecessaryTypeUnion), (Flake8Pyi, "056") => (RuleGroup::Stable, rules::flake8_pyi::rules::UnsupportedMethodCallOnAll), + (Flake8Pyi, "058") => (RuleGroup::Preview, rules::flake8_pyi::rules::GeneratorReturnFromIterMethod), // flake8-pytest-style (Flake8PytestStyle, "001") => (RuleGroup::Stable, rules::flake8_pytest_style::rules::PytestFixtureIncorrectParenthesesStyle), diff --git a/crates/ruff_linter/src/rules/flake8_pyi/mod.rs b/crates/ruff_linter/src/rules/flake8_pyi/mod.rs index 2fa1fc356747c..79b5e05afe5d3 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/mod.rs +++ b/crates/ruff_linter/src/rules/flake8_pyi/mod.rs @@ -39,6 +39,8 @@ mod tests { #[test_case(Rule::EllipsisInNonEmptyClassBody, Path::new("PYI013.pyi"))] #[test_case(Rule::FutureAnnotationsInStub, Path::new("PYI044.py"))] #[test_case(Rule::FutureAnnotationsInStub, Path::new("PYI044.pyi"))] + #[test_case(Rule::GeneratorReturnFromIterMethod, Path::new("PYI058.py"))] + #[test_case(Rule::GeneratorReturnFromIterMethod, Path::new("PYI058.pyi"))] #[test_case(Rule::IterMethodReturnIterable, Path::new("PYI045.py"))] #[test_case(Rule::IterMethodReturnIterable, Path::new("PYI045.pyi"))] #[test_case(Rule::NoReturnArgumentAnnotationInStub, Path::new("PYI050.py"))] diff --git a/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs b/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs new file mode 100644 index 0000000000000..d541e5d559851 --- /dev/null +++ b/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs @@ -0,0 +1,191 @@ +use ruff_diagnostics::{Diagnostic, Violation}; +use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast as ast; +use ruff_python_ast::helpers::map_subscript; +use ruff_python_ast::identifier::Identifier; +use ruff_python_semantic::{ScopeKind, SemanticModel}; + +use crate::checkers::ast::Checker; + +/// ## What it does +/// Checks for simple `__iter__` methods that return `Generator`, and for +/// simple `__aiter__` methods that return `AsyncGenerator`. +/// +/// ## Why is this bad? +/// Using `(Async)Iterator` for these methods is simpler and more elegant. More +/// importantly, it also reflects the fact that the precise kind of iterator +/// returned from an `__iter__` method is usually an implementation detail that +/// could change at any time. Type annotations help define a contract for a +/// function; implementation details should not leak into that contract. +/// +/// For example: +/// ```python +/// from collections.abc import AsyncGenerator, Generator +/// from typing import Any +/// +/// +/// class CustomIterator: +/// def __iter__(self) -> Generator: +/// yield from range(42) +/// +/// +/// class CustomIterator2: +/// def __iter__(self) -> Generator[str, Any, None]: +/// yield from "abcdefg" +/// ``` +/// +/// Use instead: +/// ```python +/// from collections.abc import Iterator +/// +/// +/// class CustomIterator: +/// def __iter__(self) -> Iterator: +/// yield from range(42) +/// +/// +/// class CustomIterator2: +/// def __iter__(self) -> Iterator[str]: +/// yield from "abdefg" +/// ``` +#[violation] +pub struct GeneratorReturnFromIterMethod { + better_return_type: String, + method_name: String, +} + +impl Violation for GeneratorReturnFromIterMethod { + #[derive_message_formats] + fn message(&self) -> String { + let GeneratorReturnFromIterMethod { + better_return_type, + method_name, + } = self; + format!("Use `{better_return_type}` as the return value for simple `{method_name}` methods") + } +} + +/// PYI058 +pub(crate) fn bad_generator_return_type( + checker: &mut Checker, + stmt: &ast::Stmt, + is_async: bool, + name: &str, + returns: Option<&ast::Expr>, + parameters: &ast::Parameters, + body: &Vec, +) { + if is_async { + return; + } + + let better_return_type = match name { + "__iter__" => "Iterator", + "__aiter__" => "AsyncIterator", + _ => return, + }; + + let semantic = checker.semantic(); + + if !matches!(checker.semantic().current_scope().kind, ScopeKind::Class(_)) { + return; + } + + if !parameters.kwonlyargs.is_empty() + || parameters.kwarg.is_some() + || parameters.vararg.is_some() + { + return; + } + + if (parameters.args.len() + parameters.posonlyargs.len()) != 1 { + return; + } + + let Some(returns) = returns else { + return; + }; + + if !semantic + .resolve_call_path(map_subscript(returns)) + .is_some_and(|call_path| { + matches!( + (name, call_path.as_slice()), + ( + "__iter__", + ["typing" | "typing_extensions", "Generator"] + | ["collections", "abc", "Generator"] + ) | ( + "__aiter__", + ["typing" | "typing_extensions", "AsyncGenerator"] + | ["collections", "abc", "AsyncGenerator"] + ) + ) + }) + { + return; + }; + + // `Generator`` allows three type parameters; `AsyncGenerator`` allows two. + // If type parameters are present, + // Check that all parameters except the first one are either `typing.Any` or `None`; + // if not, don't emit the diagnostic + if let ast::Expr::Subscript(ast::ExprSubscript { slice, .. }) = returns { + let ast::Expr::Tuple(ast::ExprTuple { elts, .. }) = slice.as_ref() else { + return; + }; + if matches!( + (name, &elts[..]), + ("__iter__", [_, _, _]) | ("__aiter__", [_, _]) + ) { + if !&elts[1..].iter().all(|elt| is_any_or_none(semantic, elt)) { + return; + } + } else { + return; + } + }; + + // For .py files (runtime Python!), + // only emit the lint if it's a simple __(a)iter__ implementation + // -- for more complex function bodies, + // it's more likely we'll be emitting a false positive here + if !checker.source_type.is_stub() { + let mut yield_encountered = false; + for stmt in body { + match stmt { + ast::Stmt::Pass(_) => continue, + ast::Stmt::Return(ast::StmtReturn { value, .. }) => { + if let Some(ret_val) = value { + if yield_encountered + && !matches!(ret_val.as_ref(), ast::Expr::NoneLiteral(_)) + { + return; + } + } + } + ast::Stmt::Expr(ast::StmtExpr { value, .. }) => match value.as_ref() { + ast::Expr::StringLiteral(_) | ast::Expr::EllipsisLiteral(_) => continue, + ast::Expr::Yield(_) | ast::Expr::YieldFrom(_) => { + yield_encountered = true; + continue; + } + _ => return, + }, + _ => return, + } + } + }; + + checker.diagnostics.push(Diagnostic::new( + GeneratorReturnFromIterMethod { + better_return_type: better_return_type.to_string(), + method_name: name.to_string(), + }, + stmt.identifier(), + )); +} + +fn is_any_or_none(semantic: &SemanticModel, expr: &ast::Expr) -> bool { + semantic.match_typing_expr(expr, "Any") || matches!(expr, ast::Expr::NoneLiteral(_)) +} diff --git a/crates/ruff_linter/src/rules/flake8_pyi/rules/mod.rs b/crates/ruff_linter/src/rules/flake8_pyi/rules/mod.rs index 3fae99e33d7b2..851b0660d113e 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/rules/mod.rs +++ b/crates/ruff_linter/src/rules/flake8_pyi/rules/mod.rs @@ -1,4 +1,5 @@ pub(crate) use any_eq_ne_annotation::*; +pub(crate) use bad_generator_return_type::*; pub(crate) use bad_version_info_comparison::*; pub(crate) use collections_named_tuple::*; pub(crate) use complex_assignment_in_stub::*; @@ -36,6 +37,7 @@ pub(crate) use unsupported_method_call_on_all::*; pub(crate) use unused_private_type_definition::*; mod any_eq_ne_annotation; +mod bad_generator_return_type; mod bad_version_info_comparison; mod collections_named_tuple; mod complex_assignment_in_stub; diff --git a/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.py.snap b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.py.snap new file mode 100644 index 0000000000000..41b5c92fca6da --- /dev/null +++ b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.py.snap @@ -0,0 +1,57 @@ +--- +source: crates/ruff_linter/src/rules/flake8_pyi/mod.rs +--- +PYI058.py:7:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | +6 | class IteratorReturningSimpleGenerator1: +7 | def __iter__(self) -> Generator: # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +8 | return (x for x in range(42)) + | + +PYI058.py:11:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | +10 | class IteratorReturningSimpleGenerator2: +11 | def __iter__(self, /) -> collections.abc.Generator[str, Any, None]: # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +12 | """Fully documented, because I'm a runtime function!""" +13 | yield from "abcdefg" + | + +PYI058.py:17:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | +16 | class IteratorReturningSimpleGenerator3: +17 | def __iter__(self, /) -> collections.abc.Generator[str, None, typing.Any]: # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +18 | yield "a" +19 | yield "b" + | + +PYI058.py:24:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +23 | class AsyncIteratorReturningSimpleAsyncGenerator1: +24 | def __aiter__(self) -> typing.AsyncGenerator: pass # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +25 | +26 | class AsyncIteratorReturningSimpleAsyncGenerator2: + | + +PYI058.py:27:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +26 | class AsyncIteratorReturningSimpleAsyncGenerator2: +27 | def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, Any]: ... # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +28 | +29 | class AsyncIteratorReturningSimpleAsyncGenerator3: + | + +PYI058.py:30:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +29 | class AsyncIteratorReturningSimpleAsyncGenerator3: +30 | def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, None]: pass # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +31 | +32 | class CorrectIterator: + | + + diff --git a/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.pyi.snap b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.pyi.snap new file mode 100644 index 0000000000000..444916e1af727 --- /dev/null +++ b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.pyi.snap @@ -0,0 +1,58 @@ +--- +source: crates/ruff_linter/src/rules/flake8_pyi/mod.rs +--- +PYI058.pyi:7:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | +6 | class IteratorReturningSimpleGenerator1: +7 | def __iter__(self) -> Generator: ... # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +8 | +9 | class IteratorReturningSimpleGenerator2: + | + +PYI058.pyi:10:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | + 9 | class IteratorReturningSimpleGenerator2: +10 | def __iter__(self, /) -> collections.abc.Generator[str, Any, None]: ... # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +11 | +12 | class IteratorReturningSimpleGenerator3: + | + +PYI058.pyi:13:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | +12 | class IteratorReturningSimpleGenerator3: +13 | def __iter__(self, /) -> collections.abc.Generator[str, None, typing.Any]: ... # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +14 | +15 | class AsyncIteratorReturningSimpleAsyncGenerator1: + | + +PYI058.pyi:16:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +15 | class AsyncIteratorReturningSimpleAsyncGenerator1: +16 | def __aiter__(self) -> typing.AsyncGenerator: ... # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +17 | +18 | class AsyncIteratorReturningSimpleAsyncGenerator2: + | + +PYI058.pyi:19:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +18 | class AsyncIteratorReturningSimpleAsyncGenerator2: +19 | def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, Any]: ... # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +20 | +21 | class AsyncIteratorReturningSimpleAsyncGenerator3: + | + +PYI058.pyi:22:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +21 | class AsyncIteratorReturningSimpleAsyncGenerator3: +22 | def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, None]: ... # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +23 | +24 | class CorrectIterator: + | + + diff --git a/ruff.schema.json b/ruff.schema.json index fb7e2af47f587..e8a986c720daf 100644 --- a/ruff.schema.json +++ b/ruff.schema.json @@ -3386,6 +3386,7 @@ "PYI054", "PYI055", "PYI056", + "PYI058", "Q", "Q0", "Q00", From 77dc447c760eaee6f534c59edb431b87388cbcc4 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Sun, 31 Dec 2023 23:21:18 +0000 Subject: [PATCH 2/3] Tackle most PR comments --- .../flake8_pyi/rules/bad_generator_return_type.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs b/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs index d541e5d559851..e53acb842019e 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs +++ b/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs @@ -3,7 +3,7 @@ use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast as ast; use ruff_python_ast::helpers::map_subscript; use ruff_python_ast::identifier::Identifier; -use ruff_python_semantic::{ScopeKind, SemanticModel}; +use ruff_python_semantic::SemanticModel; use crate::checkers::ast::Checker; @@ -73,7 +73,7 @@ pub(crate) fn bad_generator_return_type( name: &str, returns: Option<&ast::Expr>, parameters: &ast::Parameters, - body: &Vec, + body: &[ast::Stmt], ) { if is_async { return; @@ -87,7 +87,7 @@ pub(crate) fn bad_generator_return_type( let semantic = checker.semantic(); - if !matches!(checker.semantic().current_scope().kind, ScopeKind::Class(_)) { + if !semantic.current_scope().kind.is_class() { return; } @@ -126,7 +126,7 @@ pub(crate) fn bad_generator_return_type( return; }; - // `Generator`` allows three type parameters; `AsyncGenerator`` allows two. + // `Generator` allows three type parameters; `AsyncGenerator` allows two. // If type parameters are present, // Check that all parameters except the first one are either `typing.Any` or `None`; // if not, don't emit the diagnostic @@ -138,7 +138,7 @@ pub(crate) fn bad_generator_return_type( (name, &elts[..]), ("__iter__", [_, _, _]) | ("__aiter__", [_, _]) ) { - if !&elts[1..].iter().all(|elt| is_any_or_none(semantic, elt)) { + if !&elts.iter().skip(1).all(|elt| is_any_or_none(elt, semantic)) { return; } } else { @@ -186,6 +186,6 @@ pub(crate) fn bad_generator_return_type( )); } -fn is_any_or_none(semantic: &SemanticModel, expr: &ast::Expr) -> bool { +fn is_any_or_none(expr: &ast::Expr, semantic: &SemanticModel) -> bool { semantic.match_typing_expr(expr, "Any") || matches!(expr, ast::Expr::NoneLiteral(_)) } From 2eb11e078ac81ac2780d71ef3987894a7fcd6376 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Sun, 31 Dec 2023 23:51:47 +0000 Subject: [PATCH 3/3] Only pass in `function_def` --- .../src/checkers/ast/analyze/statement.rs | 10 +-------- .../rules/bad_generator_return_type.rs | 22 +++++++++---------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs index b344048fa727a..20ed78c8350f5 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs @@ -155,15 +155,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { ); } if checker.enabled(Rule::GeneratorReturnFromIterMethod) { - flake8_pyi::rules::bad_generator_return_type( - checker, - stmt, - *is_async, - name, - returns.as_ref().map(AsRef::as_ref), - parameters, - body, - ); + flake8_pyi::rules::bad_generator_return_type(function_def, checker); } if checker.enabled(Rule::CustomTypeVarReturnType) { flake8_pyi::rules::custom_type_var_return_type( diff --git a/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs b/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs index e53acb842019e..5dc0de179227f 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs +++ b/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs @@ -67,18 +67,15 @@ impl Violation for GeneratorReturnFromIterMethod { /// PYI058 pub(crate) fn bad_generator_return_type( + function_def: &ast::StmtFunctionDef, checker: &mut Checker, - stmt: &ast::Stmt, - is_async: bool, - name: &str, - returns: Option<&ast::Expr>, - parameters: &ast::Parameters, - body: &[ast::Stmt], ) { - if is_async { + if function_def.is_async { return; } + let name = function_def.name.as_str(); + let better_return_type = match name { "__iter__" => "Iterator", "__aiter__" => "AsyncIterator", @@ -91,6 +88,8 @@ pub(crate) fn bad_generator_return_type( return; } + let parameters = &function_def.parameters; + if !parameters.kwonlyargs.is_empty() || parameters.kwarg.is_some() || parameters.vararg.is_some() @@ -102,8 +101,9 @@ pub(crate) fn bad_generator_return_type( return; } - let Some(returns) = returns else { - return; + let returns = match &function_def.returns { + Some(returns) => returns.as_ref(), + _ => return, }; if !semantic @@ -152,7 +152,7 @@ pub(crate) fn bad_generator_return_type( // it's more likely we'll be emitting a false positive here if !checker.source_type.is_stub() { let mut yield_encountered = false; - for stmt in body { + for stmt in &function_def.body { match stmt { ast::Stmt::Pass(_) => continue, ast::Stmt::Return(ast::StmtReturn { value, .. }) => { @@ -182,7 +182,7 @@ pub(crate) fn bad_generator_return_type( better_return_type: better_return_type.to_string(), method_name: name.to_string(), }, - stmt.identifier(), + function_def.identifier(), )); }