From a1eba7f3c07e18e5f54a1f9573cf352e31db92b4 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Sun, 19 Nov 2023 09:38:56 -0500 Subject: [PATCH] Respect local subclasses in flake8-type-checking --- .../runtime_evaluated_base_classes_5.py | 11 ++-- .../rules/flake8_pyi/rules/simple_defaults.rs | 32 +++++----- .../src/rules/flake8_type_checking/helpers.rs | 63 +++++++++++++++---- .../src/rules/flake8_type_checking/mod.rs | 4 ++ ...t_runtime_evaluated_base_classes_5.py.snap | 4 ++ 5 files changed, 80 insertions(+), 34 deletions(-) create mode 100644 crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__typing-only-third-party-import_runtime_evaluated_base_classes_5.py.snap diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_type_checking/runtime_evaluated_base_classes_5.py b/crates/ruff_linter/resources/test/fixtures/flake8_type_checking/runtime_evaluated_base_classes_5.py index 0f2ad5bdf76009..214d7e59d1712e 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_type_checking/runtime_evaluated_base_classes_5.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_type_checking/runtime_evaluated_base_classes_5.py @@ -1,11 +1,12 @@ from __future__ import annotations -from collections.abc import Sequence # TCH003 +from pandas import DataFrame +from pydantic import BaseModel -class MyBaseClass: - pass +class Parent(BaseModel): + ... -class Foo(MyBaseClass): - foo: Sequence +class Child(Parent): + baz: DataFrame diff --git a/crates/ruff_linter/src/rules/flake8_pyi/rules/simple_defaults.rs b/crates/ruff_linter/src/rules/flake8_pyi/rules/simple_defaults.rs index a1fa2e58250a52..392a07be8faa0a 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/rules/simple_defaults.rs +++ b/crates/ruff_linter/src/rules/flake8_pyi/rules/simple_defaults.rs @@ -3,8 +3,9 @@ use rustc_hash::FxHashSet; use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::call_path::CallPath; +use ruff_python_ast::helpers::map_subscript; use ruff_python_ast::{ - self as ast, Arguments, Expr, Operator, ParameterWithDefault, Parameters, Stmt, UnaryOp, + self as ast, Expr, Operator, ParameterWithDefault, Parameters, Stmt, UnaryOp, }; use ruff_python_semantic::{BindingId, ScopeKind, SemanticModel}; use ruff_source_file::Locator; @@ -476,26 +477,25 @@ fn is_enum(class_def: &ast::StmtClassDef, semantic: &SemanticModel) -> bool { semantic: &SemanticModel, seen: &mut FxHashSet, ) -> bool { - let Some(Arguments { args: bases, .. }) = class_def.arguments.as_deref() else { - return false; - }; - - bases.iter().any(|expr| { + class_def.bases().iter().any(|expr| { // If the base class is `enum.Enum`, `enum.Flag`, etc., then this is an enum. - if semantic.resolve_call_path(expr).is_some_and(|call_path| { - matches!( - call_path.as_slice(), - [ - "enum", - "Enum" | "Flag" | "IntEnum" | "IntFlag" | "StrEnum" | "ReprEnum" - ] - ) - }) { + if semantic + .resolve_call_path(map_subscript(expr)) + .is_some_and(|call_path| { + matches!( + call_path.as_slice(), + [ + "enum", + "Enum" | "Flag" | "IntEnum" | "IntFlag" | "StrEnum" | "ReprEnum" + ] + ) + }) + { return true; } // If the base class extends `enum.Enum`, `enum.Flag`, etc., then this is an enum. - if let Some(id) = semantic.lookup_attribute(expr) { + if let Some(id) = semantic.lookup_attribute(map_subscript(expr)) { if seen.insert(id) { let binding = semantic.binding(id); if let Some(base_class) = binding diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs b/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs index b2f0e456e38042..d07fe2d6cb1f27 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs +++ b/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs @@ -1,6 +1,8 @@ use ruff_python_ast::call_path::from_qualified_name; use ruff_python_ast::helpers::{map_callable, map_subscript}; -use ruff_python_semantic::{Binding, BindingKind, ScopeKind, SemanticModel}; +use ruff_python_ast::{self as ast}; +use ruff_python_semantic::{Binding, BindingId, BindingKind, ScopeKind, SemanticModel}; +use rustc_hash::FxHashSet; pub(crate) fn is_valid_runtime_import(binding: &Binding, semantic: &SemanticModel) -> bool { if matches!( @@ -35,19 +37,54 @@ pub(crate) fn runtime_evaluated( } fn runtime_evaluated_base_class(base_classes: &[String], semantic: &SemanticModel) -> bool { - let ScopeKind::Class(class_def) = &semantic.current_scope().kind else { - return false; - }; + fn inner( + class_def: &ast::StmtClassDef, + base_classes: &[String], + semantic: &SemanticModel, + seen: &mut FxHashSet, + ) -> bool { + class_def.bases().iter().any(|expr| { + // If the base class is itself runtime-evaluated, then this is too. + // Ex) `class Foo(BaseModel): ...` + if semantic + .resolve_call_path(map_subscript(expr)) + .is_some_and(|call_path| { + base_classes + .iter() + .any(|base_class| from_qualified_name(base_class) == call_path) + }) + { + return true; + } - class_def.bases().iter().any(|base| { - semantic - .resolve_call_path(map_subscript(base)) - .is_some_and(|call_path| { - base_classes - .iter() - .any(|base_class| from_qualified_name(base_class) == call_path) - }) - }) + // If the base class extends a runtime-evaluated class, then this does too. + // Ex) `class Bar(BaseModel): ...; class Foo(Bar): ...` + if let Some(id) = semantic.lookup_attribute(map_subscript(expr)) { + if seen.insert(id) { + let binding = semantic.binding(id); + if let Some(base_class) = binding + .kind + .as_class_definition() + .map(|id| &semantic.scopes[*id]) + .and_then(|scope| scope.kind.as_class()) + { + if inner(base_class, base_classes, semantic, seen) { + return true; + } + } + } + } + false + }) + } + + semantic + .current_scope() + .kind + .as_class() + .is_some_and(|class_def| { + inner(class_def, base_classes, semantic, &mut FxHashSet::default()) + }) } fn runtime_evaluated_decorators(decorators: &[String], semantic: &SemanticModel) -> bool { diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/mod.rs b/crates/ruff_linter/src/rules/flake8_type_checking/mod.rs index ec5d504a0fb54e..97e9ad7cd4af0a 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/mod.rs +++ b/crates/ruff_linter/src/rules/flake8_type_checking/mod.rs @@ -98,6 +98,10 @@ mod tests { Rule::TypingOnlyStandardLibraryImport, Path::new("runtime_evaluated_base_classes_4.py") )] + #[test_case( + Rule::TypingOnlyThirdPartyImport, + Path::new("runtime_evaluated_base_classes_5.py") + )] fn runtime_evaluated_base_classes(rule_code: Rule, path: &Path) -> Result<()> { let snapshot = format!("{}_{}", rule_code.as_ref(), path.to_string_lossy()); let diagnostics = test_path( diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__typing-only-third-party-import_runtime_evaluated_base_classes_5.py.snap b/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__typing-only-third-party-import_runtime_evaluated_base_classes_5.py.snap new file mode 100644 index 00000000000000..6c5ead27428cec --- /dev/null +++ b/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__typing-only-third-party-import_runtime_evaluated_base_classes_5.py.snap @@ -0,0 +1,4 @@ +--- +source: crates/ruff_linter/src/rules/flake8_type_checking/mod.rs +--- +