diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.py b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.py new file mode 100644 index 0000000000000..810a084c62d00 --- /dev/null +++ b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.py @@ -0,0 +1,82 @@ +import collections.abc +import typing +from collections.abc import AsyncGenerator, Generator +from typing import Any + +class IteratorReturningSimpleGenerator1: + def __iter__(self) -> Generator: # PYI058 (use `Iterator`) + return (x for x in range(42)) + +class IteratorReturningSimpleGenerator2: + def __iter__(self, /) -> collections.abc.Generator[str, Any, None]: # PYI058 (use `Iterator`) + """Fully documented, because I'm a runtime function!""" + yield from "abcdefg" + return None + +class IteratorReturningSimpleGenerator3: + def __iter__(self, /) -> collections.abc.Generator[str, None, typing.Any]: # PYI058 (use `Iterator`) + yield "a" + yield "b" + yield "c" + return + +class AsyncIteratorReturningSimpleAsyncGenerator1: + def __aiter__(self) -> typing.AsyncGenerator: pass # PYI058 (Use `AsyncIterator`) + +class AsyncIteratorReturningSimpleAsyncGenerator2: + def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, Any]: ... # PYI058 (Use `AsyncIterator`) + +class AsyncIteratorReturningSimpleAsyncGenerator3: + def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, None]: pass # PYI058 (Use `AsyncIterator`) + +class CorrectIterator: + def __iter__(self) -> Iterator[str]: ... # OK + +class CorrectAsyncIterator: + def __aiter__(self) -> collections.abc.AsyncIterator[int]: ... # OK + +class Fine: + def __iter__(self) -> typing.Self: ... # OK + +class StrangeButWeWontComplainHere: + def __aiter__(self) -> list[bytes]: ... # OK + +def __iter__(self) -> Generator: ... # OK (not in class scope) +def __aiter__(self) -> AsyncGenerator: ... # OK (not in class scope) + +class IteratorReturningComplexGenerator: + def __iter__(self) -> Generator[str, int, bytes]: ... # OK + +class AsyncIteratorReturningComplexAsyncGenerator: + def __aiter__(self) -> AsyncGenerator[str, int]: ... # OK + +class ClassWithInvalidAsyncAiterMethod: + async def __aiter__(self) -> AsyncGenerator: ... # OK + +class IteratorWithUnusualParameters1: + def __iter__(self, foo) -> Generator: ... # OK + +class IteratorWithUnusualParameters2: + def __iter__(self, *, bar) -> Generator: ... # OK + +class IteratorWithUnusualParameters3: + def __iter__(self, *args) -> Generator: ... # OK + +class IteratorWithUnusualParameters4: + def __iter__(self, **kwargs) -> Generator: ... # OK + +class IteratorWithIterMethodThatReturnsThings: + def __iter__(self) -> Generator: # OK + yield + return 42 + +class IteratorWithIterMethodThatReceivesThingsFromSend: + def __iter__(self) -> Generator: # OK + x = yield 42 + +class IteratorWithNonTrivialIterBody: + def __iter__(self) -> Generator: # OK + foo, bar, baz = (1, 2, 3) + yield foo + yield bar + yield baz diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.pyi b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.pyi new file mode 100644 index 0000000000000..ce6e78a68b311 --- /dev/null +++ b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.pyi @@ -0,0 +1,58 @@ +import collections.abc +import typing +from collections.abc import AsyncGenerator, Generator +from typing import Any + +class IteratorReturningSimpleGenerator1: + def __iter__(self) -> Generator: ... # PYI058 (use `Iterator`) + +class IteratorReturningSimpleGenerator2: + def __iter__(self, /) -> collections.abc.Generator[str, Any, None]: ... # PYI058 (use `Iterator`) + +class IteratorReturningSimpleGenerator3: + def __iter__(self, /) -> collections.abc.Generator[str, None, typing.Any]: ... # PYI058 (use `Iterator`) + +class AsyncIteratorReturningSimpleAsyncGenerator1: + def __aiter__(self) -> typing.AsyncGenerator: ... # PYI058 (Use `AsyncIterator`) + +class AsyncIteratorReturningSimpleAsyncGenerator2: + def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, Any]: ... # PYI058 (Use `AsyncIterator`) + +class AsyncIteratorReturningSimpleAsyncGenerator3: + def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, None]: ... # PYI058 (Use `AsyncIterator`) + +class CorrectIterator: + def __iter__(self) -> Iterator[str]: ... # OK + +class CorrectAsyncIterator: + def __aiter__(self) -> collections.abc.AsyncIterator[int]: ... # OK + +class Fine: + def __iter__(self) -> typing.Self: ... # OK + +class StrangeButWeWontComplainHere: + def __aiter__(self) -> list[bytes]: ... # OK + +def __iter__(self) -> Generator: ... # OK (not in class scope) +def __aiter__(self) -> AsyncGenerator: ... # OK (not in class scope) + +class IteratorReturningComplexGenerator: + def __iter__(self) -> Generator[str, int, bytes]: ... # OK + +class AsyncIteratorReturningComplexAsyncGenerator: + def __aiter__(self) -> AsyncGenerator[str, int]: ... # OK + +class ClassWithInvalidAsyncAiterMethod: + async def __aiter__(self) -> AsyncGenerator: ... # OK + +class IteratorWithUnusualParameters1: + def __iter__(self, foo) -> Generator: ... # OK + +class IteratorWithUnusualParameters2: + def __iter__(self, *, bar) -> Generator: ... # OK + +class IteratorWithUnusualParameters3: + def __iter__(self, *args) -> Generator: ... # OK + +class IteratorWithUnusualParameters4: + def __iter__(self, **kwargs) -> Generator: ... # OK diff --git a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs index c803562bd5bf5..20ed78c8350f5 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs @@ -154,6 +154,9 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { parameters, ); } + if checker.enabled(Rule::GeneratorReturnFromIterMethod) { + flake8_pyi::rules::bad_generator_return_type(function_def, checker); + } if checker.enabled(Rule::CustomTypeVarReturnType) { flake8_pyi::rules::custom_type_var_return_type( checker, diff --git a/crates/ruff_linter/src/codes.rs b/crates/ruff_linter/src/codes.rs index 321b485e85e0d..75f4d29673462 100644 --- a/crates/ruff_linter/src/codes.rs +++ b/crates/ruff_linter/src/codes.rs @@ -748,6 +748,7 @@ pub fn code_to_rule(linter: Linter, code: &str) -> Option<(RuleGroup, Rule)> { (Flake8Pyi, "053") => (RuleGroup::Stable, rules::flake8_pyi::rules::StringOrBytesTooLong), (Flake8Pyi, "055") => (RuleGroup::Stable, rules::flake8_pyi::rules::UnnecessaryTypeUnion), (Flake8Pyi, "056") => (RuleGroup::Stable, rules::flake8_pyi::rules::UnsupportedMethodCallOnAll), + (Flake8Pyi, "058") => (RuleGroup::Preview, rules::flake8_pyi::rules::GeneratorReturnFromIterMethod), // flake8-pytest-style (Flake8PytestStyle, "001") => (RuleGroup::Stable, rules::flake8_pytest_style::rules::PytestFixtureIncorrectParenthesesStyle), diff --git a/crates/ruff_linter/src/rules/flake8_pyi/mod.rs b/crates/ruff_linter/src/rules/flake8_pyi/mod.rs index 2fa1fc356747c..79b5e05afe5d3 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/mod.rs +++ b/crates/ruff_linter/src/rules/flake8_pyi/mod.rs @@ -39,6 +39,8 @@ mod tests { #[test_case(Rule::EllipsisInNonEmptyClassBody, Path::new("PYI013.pyi"))] #[test_case(Rule::FutureAnnotationsInStub, Path::new("PYI044.py"))] #[test_case(Rule::FutureAnnotationsInStub, Path::new("PYI044.pyi"))] + #[test_case(Rule::GeneratorReturnFromIterMethod, Path::new("PYI058.py"))] + #[test_case(Rule::GeneratorReturnFromIterMethod, Path::new("PYI058.pyi"))] #[test_case(Rule::IterMethodReturnIterable, Path::new("PYI045.py"))] #[test_case(Rule::IterMethodReturnIterable, Path::new("PYI045.pyi"))] #[test_case(Rule::NoReturnArgumentAnnotationInStub, Path::new("PYI050.py"))] diff --git a/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs b/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs new file mode 100644 index 0000000000000..5dc0de179227f --- /dev/null +++ b/crates/ruff_linter/src/rules/flake8_pyi/rules/bad_generator_return_type.rs @@ -0,0 +1,191 @@ +use ruff_diagnostics::{Diagnostic, Violation}; +use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast as ast; +use ruff_python_ast::helpers::map_subscript; +use ruff_python_ast::identifier::Identifier; +use ruff_python_semantic::SemanticModel; + +use crate::checkers::ast::Checker; + +/// ## What it does +/// Checks for simple `__iter__` methods that return `Generator`, and for +/// simple `__aiter__` methods that return `AsyncGenerator`. +/// +/// ## Why is this bad? +/// Using `(Async)Iterator` for these methods is simpler and more elegant. More +/// importantly, it also reflects the fact that the precise kind of iterator +/// returned from an `__iter__` method is usually an implementation detail that +/// could change at any time. Type annotations help define a contract for a +/// function; implementation details should not leak into that contract. +/// +/// For example: +/// ```python +/// from collections.abc import AsyncGenerator, Generator +/// from typing import Any +/// +/// +/// class CustomIterator: +/// def __iter__(self) -> Generator: +/// yield from range(42) +/// +/// +/// class CustomIterator2: +/// def __iter__(self) -> Generator[str, Any, None]: +/// yield from "abcdefg" +/// ``` +/// +/// Use instead: +/// ```python +/// from collections.abc import Iterator +/// +/// +/// class CustomIterator: +/// def __iter__(self) -> Iterator: +/// yield from range(42) +/// +/// +/// class CustomIterator2: +/// def __iter__(self) -> Iterator[str]: +/// yield from "abdefg" +/// ``` +#[violation] +pub struct GeneratorReturnFromIterMethod { + better_return_type: String, + method_name: String, +} + +impl Violation for GeneratorReturnFromIterMethod { + #[derive_message_formats] + fn message(&self) -> String { + let GeneratorReturnFromIterMethod { + better_return_type, + method_name, + } = self; + format!("Use `{better_return_type}` as the return value for simple `{method_name}` methods") + } +} + +/// PYI058 +pub(crate) fn bad_generator_return_type( + function_def: &ast::StmtFunctionDef, + checker: &mut Checker, +) { + if function_def.is_async { + return; + } + + let name = function_def.name.as_str(); + + let better_return_type = match name { + "__iter__" => "Iterator", + "__aiter__" => "AsyncIterator", + _ => return, + }; + + let semantic = checker.semantic(); + + if !semantic.current_scope().kind.is_class() { + return; + } + + let parameters = &function_def.parameters; + + if !parameters.kwonlyargs.is_empty() + || parameters.kwarg.is_some() + || parameters.vararg.is_some() + { + return; + } + + if (parameters.args.len() + parameters.posonlyargs.len()) != 1 { + return; + } + + let returns = match &function_def.returns { + Some(returns) => returns.as_ref(), + _ => return, + }; + + if !semantic + .resolve_call_path(map_subscript(returns)) + .is_some_and(|call_path| { + matches!( + (name, call_path.as_slice()), + ( + "__iter__", + ["typing" | "typing_extensions", "Generator"] + | ["collections", "abc", "Generator"] + ) | ( + "__aiter__", + ["typing" | "typing_extensions", "AsyncGenerator"] + | ["collections", "abc", "AsyncGenerator"] + ) + ) + }) + { + return; + }; + + // `Generator` allows three type parameters; `AsyncGenerator` allows two. + // If type parameters are present, + // Check that all parameters except the first one are either `typing.Any` or `None`; + // if not, don't emit the diagnostic + if let ast::Expr::Subscript(ast::ExprSubscript { slice, .. }) = returns { + let ast::Expr::Tuple(ast::ExprTuple { elts, .. }) = slice.as_ref() else { + return; + }; + if matches!( + (name, &elts[..]), + ("__iter__", [_, _, _]) | ("__aiter__", [_, _]) + ) { + if !&elts.iter().skip(1).all(|elt| is_any_or_none(elt, semantic)) { + return; + } + } else { + return; + } + }; + + // For .py files (runtime Python!), + // only emit the lint if it's a simple __(a)iter__ implementation + // -- for more complex function bodies, + // it's more likely we'll be emitting a false positive here + if !checker.source_type.is_stub() { + let mut yield_encountered = false; + for stmt in &function_def.body { + match stmt { + ast::Stmt::Pass(_) => continue, + ast::Stmt::Return(ast::StmtReturn { value, .. }) => { + if let Some(ret_val) = value { + if yield_encountered + && !matches!(ret_val.as_ref(), ast::Expr::NoneLiteral(_)) + { + return; + } + } + } + ast::Stmt::Expr(ast::StmtExpr { value, .. }) => match value.as_ref() { + ast::Expr::StringLiteral(_) | ast::Expr::EllipsisLiteral(_) => continue, + ast::Expr::Yield(_) | ast::Expr::YieldFrom(_) => { + yield_encountered = true; + continue; + } + _ => return, + }, + _ => return, + } + } + }; + + checker.diagnostics.push(Diagnostic::new( + GeneratorReturnFromIterMethod { + better_return_type: better_return_type.to_string(), + method_name: name.to_string(), + }, + function_def.identifier(), + )); +} + +fn is_any_or_none(expr: &ast::Expr, semantic: &SemanticModel) -> bool { + semantic.match_typing_expr(expr, "Any") || matches!(expr, ast::Expr::NoneLiteral(_)) +} diff --git a/crates/ruff_linter/src/rules/flake8_pyi/rules/mod.rs b/crates/ruff_linter/src/rules/flake8_pyi/rules/mod.rs index 3fae99e33d7b2..851b0660d113e 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/rules/mod.rs +++ b/crates/ruff_linter/src/rules/flake8_pyi/rules/mod.rs @@ -1,4 +1,5 @@ pub(crate) use any_eq_ne_annotation::*; +pub(crate) use bad_generator_return_type::*; pub(crate) use bad_version_info_comparison::*; pub(crate) use collections_named_tuple::*; pub(crate) use complex_assignment_in_stub::*; @@ -36,6 +37,7 @@ pub(crate) use unsupported_method_call_on_all::*; pub(crate) use unused_private_type_definition::*; mod any_eq_ne_annotation; +mod bad_generator_return_type; mod bad_version_info_comparison; mod collections_named_tuple; mod complex_assignment_in_stub; diff --git a/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.py.snap b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.py.snap new file mode 100644 index 0000000000000..41b5c92fca6da --- /dev/null +++ b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.py.snap @@ -0,0 +1,57 @@ +--- +source: crates/ruff_linter/src/rules/flake8_pyi/mod.rs +--- +PYI058.py:7:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | +6 | class IteratorReturningSimpleGenerator1: +7 | def __iter__(self) -> Generator: # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +8 | return (x for x in range(42)) + | + +PYI058.py:11:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | +10 | class IteratorReturningSimpleGenerator2: +11 | def __iter__(self, /) -> collections.abc.Generator[str, Any, None]: # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +12 | """Fully documented, because I'm a runtime function!""" +13 | yield from "abcdefg" + | + +PYI058.py:17:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | +16 | class IteratorReturningSimpleGenerator3: +17 | def __iter__(self, /) -> collections.abc.Generator[str, None, typing.Any]: # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +18 | yield "a" +19 | yield "b" + | + +PYI058.py:24:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +23 | class AsyncIteratorReturningSimpleAsyncGenerator1: +24 | def __aiter__(self) -> typing.AsyncGenerator: pass # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +25 | +26 | class AsyncIteratorReturningSimpleAsyncGenerator2: + | + +PYI058.py:27:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +26 | class AsyncIteratorReturningSimpleAsyncGenerator2: +27 | def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, Any]: ... # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +28 | +29 | class AsyncIteratorReturningSimpleAsyncGenerator3: + | + +PYI058.py:30:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +29 | class AsyncIteratorReturningSimpleAsyncGenerator3: +30 | def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, None]: pass # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +31 | +32 | class CorrectIterator: + | + + diff --git a/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.pyi.snap b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.pyi.snap new file mode 100644 index 0000000000000..444916e1af727 --- /dev/null +++ b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI058_PYI058.pyi.snap @@ -0,0 +1,58 @@ +--- +source: crates/ruff_linter/src/rules/flake8_pyi/mod.rs +--- +PYI058.pyi:7:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | +6 | class IteratorReturningSimpleGenerator1: +7 | def __iter__(self) -> Generator: ... # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +8 | +9 | class IteratorReturningSimpleGenerator2: + | + +PYI058.pyi:10:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | + 9 | class IteratorReturningSimpleGenerator2: +10 | def __iter__(self, /) -> collections.abc.Generator[str, Any, None]: ... # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +11 | +12 | class IteratorReturningSimpleGenerator3: + | + +PYI058.pyi:13:9: PYI058 Use `Iterator` as the return value for simple `__iter__` methods + | +12 | class IteratorReturningSimpleGenerator3: +13 | def __iter__(self, /) -> collections.abc.Generator[str, None, typing.Any]: ... # PYI058 (use `Iterator`) + | ^^^^^^^^ PYI058 +14 | +15 | class AsyncIteratorReturningSimpleAsyncGenerator1: + | + +PYI058.pyi:16:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +15 | class AsyncIteratorReturningSimpleAsyncGenerator1: +16 | def __aiter__(self) -> typing.AsyncGenerator: ... # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +17 | +18 | class AsyncIteratorReturningSimpleAsyncGenerator2: + | + +PYI058.pyi:19:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +18 | class AsyncIteratorReturningSimpleAsyncGenerator2: +19 | def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, Any]: ... # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +20 | +21 | class AsyncIteratorReturningSimpleAsyncGenerator3: + | + +PYI058.pyi:22:9: PYI058 Use `AsyncIterator` as the return value for simple `__aiter__` methods + | +21 | class AsyncIteratorReturningSimpleAsyncGenerator3: +22 | def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, None]: ... # PYI058 (Use `AsyncIterator`) + | ^^^^^^^^^ PYI058 +23 | +24 | class CorrectIterator: + | + + diff --git a/ruff.schema.json b/ruff.schema.json index fdf63ac6d7613..c50d2e69df672 100644 --- a/ruff.schema.json +++ b/ruff.schema.json @@ -3388,6 +3388,7 @@ "PYI054", "PYI055", "PYI056", + "PYI058", "Q", "Q0", "Q00",