Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flake8-pyi] Implement PYI058 #9313

Merged
merged 4 commits into from
Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import collections.abc
import typing
from collections.abc import AsyncGenerator, Generator
from typing import Any

class IteratorReturningSimpleGenerator1:
def __iter__(self) -> Generator: # PYI058 (use `Iterator`)
return (x for x in range(42))

class IteratorReturningSimpleGenerator2:
def __iter__(self, /) -> collections.abc.Generator[str, Any, None]: # PYI058 (use `Iterator`)
"""Fully documented, because I'm a runtime function!"""
yield from "abcdefg"
return None

class IteratorReturningSimpleGenerator3:
def __iter__(self, /) -> collections.abc.Generator[str, None, typing.Any]: # PYI058 (use `Iterator`)
yield "a"
yield "b"
yield "c"
return

class AsyncIteratorReturningSimpleAsyncGenerator1:
def __aiter__(self) -> typing.AsyncGenerator: pass # PYI058 (Use `AsyncIterator`)

class AsyncIteratorReturningSimpleAsyncGenerator2:
def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, Any]: ... # PYI058 (Use `AsyncIterator`)

class AsyncIteratorReturningSimpleAsyncGenerator3:
def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, None]: pass # PYI058 (Use `AsyncIterator`)

class CorrectIterator:
def __iter__(self) -> Iterator[str]: ... # OK

class CorrectAsyncIterator:
def __aiter__(self) -> collections.abc.AsyncIterator[int]: ... # OK

class Fine:
def __iter__(self) -> typing.Self: ... # OK

class StrangeButWeWontComplainHere:
def __aiter__(self) -> list[bytes]: ... # OK

def __iter__(self) -> Generator: ... # OK (not in class scope)
def __aiter__(self) -> AsyncGenerator: ... # OK (not in class scope)

class IteratorReturningComplexGenerator:
def __iter__(self) -> Generator[str, int, bytes]: ... # OK

class AsyncIteratorReturningComplexAsyncGenerator:
def __aiter__(self) -> AsyncGenerator[str, int]: ... # OK

class ClassWithInvalidAsyncAiterMethod:
async def __aiter__(self) -> AsyncGenerator: ... # OK

class IteratorWithUnusualParameters1:
def __iter__(self, foo) -> Generator: ... # OK

class IteratorWithUnusualParameters2:
def __iter__(self, *, bar) -> Generator: ... # OK

class IteratorWithUnusualParameters3:
def __iter__(self, *args) -> Generator: ... # OK

class IteratorWithUnusualParameters4:
def __iter__(self, **kwargs) -> Generator: ... # OK

class IteratorWithIterMethodThatReturnsThings:
def __iter__(self) -> Generator: # OK
yield
return 42

class IteratorWithIterMethodThatReceivesThingsFromSend:
def __iter__(self) -> Generator: # OK
x = yield 42

class IteratorWithNonTrivialIterBody:
def __iter__(self) -> Generator: # OK
foo, bar, baz = (1, 2, 3)
yield foo
yield bar
yield baz
58 changes: 58 additions & 0 deletions crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI058.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import collections.abc
import typing
from collections.abc import AsyncGenerator, Generator
from typing import Any

class IteratorReturningSimpleGenerator1:
def __iter__(self) -> Generator: ... # PYI058 (use `Iterator`)

class IteratorReturningSimpleGenerator2:
def __iter__(self, /) -> collections.abc.Generator[str, Any, None]: ... # PYI058 (use `Iterator`)

class IteratorReturningSimpleGenerator3:
def __iter__(self, /) -> collections.abc.Generator[str, None, typing.Any]: ... # PYI058 (use `Iterator`)

class AsyncIteratorReturningSimpleAsyncGenerator1:
def __aiter__(self) -> typing.AsyncGenerator: ... # PYI058 (Use `AsyncIterator`)

class AsyncIteratorReturningSimpleAsyncGenerator2:
def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, Any]: ... # PYI058 (Use `AsyncIterator`)

class AsyncIteratorReturningSimpleAsyncGenerator3:
def __aiter__(self, /) -> collections.abc.AsyncGenerator[str, None]: ... # PYI058 (Use `AsyncIterator`)

class CorrectIterator:
def __iter__(self) -> Iterator[str]: ... # OK

class CorrectAsyncIterator:
def __aiter__(self) -> collections.abc.AsyncIterator[int]: ... # OK

class Fine:
def __iter__(self) -> typing.Self: ... # OK

class StrangeButWeWontComplainHere:
def __aiter__(self) -> list[bytes]: ... # OK

def __iter__(self) -> Generator: ... # OK (not in class scope)
def __aiter__(self) -> AsyncGenerator: ... # OK (not in class scope)

class IteratorReturningComplexGenerator:
def __iter__(self) -> Generator[str, int, bytes]: ... # OK

class AsyncIteratorReturningComplexAsyncGenerator:
def __aiter__(self) -> AsyncGenerator[str, int]: ... # OK

class ClassWithInvalidAsyncAiterMethod:
async def __aiter__(self) -> AsyncGenerator: ... # OK

class IteratorWithUnusualParameters1:
def __iter__(self, foo) -> Generator: ... # OK

class IteratorWithUnusualParameters2:
def __iter__(self, *, bar) -> Generator: ... # OK

class IteratorWithUnusualParameters3:
def __iter__(self, *args) -> Generator: ... # OK

class IteratorWithUnusualParameters4:
def __iter__(self, **kwargs) -> Generator: ... # OK
3 changes: 3 additions & 0 deletions crates/ruff_linter/src/checkers/ast/analyze/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) {
parameters,
);
}
if checker.enabled(Rule::GeneratorReturnFromIterMethod) {
flake8_pyi::rules::bad_generator_return_type(function_def, checker);
}
if checker.enabled(Rule::CustomTypeVarReturnType) {
flake8_pyi::rules::custom_type_var_return_type(
checker,
Expand Down
1 change: 1 addition & 0 deletions crates/ruff_linter/src/codes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ pub fn code_to_rule(linter: Linter, code: &str) -> Option<(RuleGroup, Rule)> {
(Flake8Pyi, "053") => (RuleGroup::Stable, rules::flake8_pyi::rules::StringOrBytesTooLong),
(Flake8Pyi, "055") => (RuleGroup::Stable, rules::flake8_pyi::rules::UnnecessaryTypeUnion),
(Flake8Pyi, "056") => (RuleGroup::Stable, rules::flake8_pyi::rules::UnsupportedMethodCallOnAll),
(Flake8Pyi, "058") => (RuleGroup::Preview, rules::flake8_pyi::rules::GeneratorReturnFromIterMethod),

// flake8-pytest-style
(Flake8PytestStyle, "001") => (RuleGroup::Stable, rules::flake8_pytest_style::rules::PytestFixtureIncorrectParenthesesStyle),
Expand Down
2 changes: 2 additions & 0 deletions crates/ruff_linter/src/rules/flake8_pyi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ mod tests {
#[test_case(Rule::EllipsisInNonEmptyClassBody, Path::new("PYI013.pyi"))]
#[test_case(Rule::FutureAnnotationsInStub, Path::new("PYI044.py"))]
#[test_case(Rule::FutureAnnotationsInStub, Path::new("PYI044.pyi"))]
#[test_case(Rule::GeneratorReturnFromIterMethod, Path::new("PYI058.py"))]
#[test_case(Rule::GeneratorReturnFromIterMethod, Path::new("PYI058.pyi"))]
#[test_case(Rule::IterMethodReturnIterable, Path::new("PYI045.py"))]
#[test_case(Rule::IterMethodReturnIterable, Path::new("PYI045.pyi"))]
#[test_case(Rule::NoReturnArgumentAnnotationInStub, Path::new("PYI050.py"))]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast as ast;
use ruff_python_ast::helpers::map_subscript;
use ruff_python_ast::identifier::Identifier;
use ruff_python_semantic::SemanticModel;

use crate::checkers::ast::Checker;

/// ## What it does
/// Checks for simple `__iter__` methods that return `Generator`, and for
/// simple `__aiter__` methods that return `AsyncGenerator`.
///
/// ## Why is this bad?
/// Using `(Async)Iterator` for these methods is simpler and more elegant. More
/// importantly, it also reflects the fact that the precise kind of iterator
/// returned from an `__iter__` method is usually an implementation detail that
/// could change at any time. Type annotations help define a contract for a
/// function; implementation details should not leak into that contract.
///
/// For example:
/// ```python
/// from collections.abc import AsyncGenerator, Generator
/// from typing import Any
///
///
/// class CustomIterator:
/// def __iter__(self) -> Generator:
/// yield from range(42)
///
///
/// class CustomIterator2:
/// def __iter__(self) -> Generator[str, Any, None]:
/// yield from "abcdefg"
/// ```
///
/// Use instead:
/// ```python
/// from collections.abc import Iterator
///
///
/// class CustomIterator:
/// def __iter__(self) -> Iterator:
/// yield from range(42)
///
///
/// class CustomIterator2:
/// def __iter__(self) -> Iterator[str]:
/// yield from "abdefg"
/// ```
#[violation]
pub struct GeneratorReturnFromIterMethod {
better_return_type: String,
method_name: String,
}

impl Violation for GeneratorReturnFromIterMethod {
#[derive_message_formats]
fn message(&self) -> String {
let GeneratorReturnFromIterMethod {
better_return_type,
method_name,
} = self;
format!("Use `{better_return_type}` as the return value for simple `{method_name}` methods")
}
}

/// PYI058
pub(crate) fn bad_generator_return_type(
function_def: &ast::StmtFunctionDef,
checker: &mut Checker,
) {
if function_def.is_async {
return;
}

let name = function_def.name.as_str();

let better_return_type = match name {
"__iter__" => "Iterator",
"__aiter__" => "AsyncIterator",
_ => return,
};

let semantic = checker.semantic();

if !semantic.current_scope().kind.is_class() {
return;
}

let parameters = &function_def.parameters;

if !parameters.kwonlyargs.is_empty()
|| parameters.kwarg.is_some()
|| parameters.vararg.is_some()
{
return;
}

if (parameters.args.len() + parameters.posonlyargs.len()) != 1 {
return;
}

let returns = match &function_def.returns {
Some(returns) => returns.as_ref(),
_ => return,
};

if !semantic
.resolve_call_path(map_subscript(returns))
.is_some_and(|call_path| {
matches!(
(name, call_path.as_slice()),
(
"__iter__",
["typing" | "typing_extensions", "Generator"]
| ["collections", "abc", "Generator"]
) | (
"__aiter__",
["typing" | "typing_extensions", "AsyncGenerator"]
| ["collections", "abc", "AsyncGenerator"]
)
)
})
{
return;
};

// `Generator` allows three type parameters; `AsyncGenerator` allows two.
// If type parameters are present,
// Check that all parameters except the first one are either `typing.Any` or `None`;
// if not, don't emit the diagnostic
if let ast::Expr::Subscript(ast::ExprSubscript { slice, .. }) = returns {
let ast::Expr::Tuple(ast::ExprTuple { elts, .. }) = slice.as_ref() else {
return;
};
if matches!(
(name, &elts[..]),
("__iter__", [_, _, _]) | ("__aiter__", [_, _])
) {
if !&elts.iter().skip(1).all(|elt| is_any_or_none(elt, semantic)) {
return;
}
} else {
return;
}
};

// For .py files (runtime Python!),
// only emit the lint if it's a simple __(a)iter__ implementation
// -- for more complex function bodies,
// it's more likely we'll be emitting a false positive here
if !checker.source_type.is_stub() {
let mut yield_encountered = false;
for stmt in &function_def.body {
match stmt {
ast::Stmt::Pass(_) => continue,
ast::Stmt::Return(ast::StmtReturn { value, .. }) => {
if let Some(ret_val) = value {
if yield_encountered
&& !matches!(ret_val.as_ref(), ast::Expr::NoneLiteral(_))
{
return;
}
}
}
ast::Stmt::Expr(ast::StmtExpr { value, .. }) => match value.as_ref() {
ast::Expr::StringLiteral(_) | ast::Expr::EllipsisLiteral(_) => continue,
ast::Expr::Yield(_) | ast::Expr::YieldFrom(_) => {
yield_encountered = true;
continue;
}
_ => return,
},
_ => return,
}
}
};

checker.diagnostics.push(Diagnostic::new(
GeneratorReturnFromIterMethod {
better_return_type: better_return_type.to_string(),
method_name: name.to_string(),
},
function_def.identifier(),
));
}

fn is_any_or_none(expr: &ast::Expr, semantic: &SemanticModel) -> bool {
semantic.match_typing_expr(expr, "Any") || matches!(expr, ast::Expr::NoneLiteral(_))
}
2 changes: 2 additions & 0 deletions crates/ruff_linter/src/rules/flake8_pyi/rules/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub(crate) use any_eq_ne_annotation::*;
pub(crate) use bad_generator_return_type::*;
pub(crate) use bad_version_info_comparison::*;
pub(crate) use collections_named_tuple::*;
pub(crate) use complex_assignment_in_stub::*;
Expand Down Expand Up @@ -36,6 +37,7 @@ pub(crate) use unsupported_method_call_on_all::*;
pub(crate) use unused_private_type_definition::*;

mod any_eq_ne_annotation;
mod bad_generator_return_type;
mod bad_version_info_comparison;
mod collections_named_tuple;
mod complex_assignment_in_stub;
Expand Down
Loading
Loading