Skip to content

Commit

Permalink
Respect local subclasses in flake8-type-checking
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Nov 19, 2023
1 parent 94178a0 commit a1eba7f
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from collections.abc import Sequence # TCH003
from pandas import DataFrame
from pydantic import BaseModel


class MyBaseClass:
pass
class Parent(BaseModel):
...


class Foo(MyBaseClass):
foo: Sequence
class Child(Parent):
baz: DataFrame
32 changes: 16 additions & 16 deletions crates/ruff_linter/src/rules/flake8_pyi/rules/simple_defaults.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use rustc_hash::FxHashSet;
use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::call_path::CallPath;
use ruff_python_ast::helpers::map_subscript;
use ruff_python_ast::{
self as ast, Arguments, Expr, Operator, ParameterWithDefault, Parameters, Stmt, UnaryOp,
self as ast, Expr, Operator, ParameterWithDefault, Parameters, Stmt, UnaryOp,
};
use ruff_python_semantic::{BindingId, ScopeKind, SemanticModel};
use ruff_source_file::Locator;
Expand Down Expand Up @@ -476,26 +477,25 @@ fn is_enum(class_def: &ast::StmtClassDef, semantic: &SemanticModel) -> bool {
semantic: &SemanticModel,
seen: &mut FxHashSet<BindingId>,
) -> bool {
let Some(Arguments { args: bases, .. }) = class_def.arguments.as_deref() else {
return false;
};

bases.iter().any(|expr| {
class_def.bases().iter().any(|expr| {
// If the base class is `enum.Enum`, `enum.Flag`, etc., then this is an enum.
if semantic.resolve_call_path(expr).is_some_and(|call_path| {
matches!(
call_path.as_slice(),
[
"enum",
"Enum" | "Flag" | "IntEnum" | "IntFlag" | "StrEnum" | "ReprEnum"
]
)
}) {
if semantic
.resolve_call_path(map_subscript(expr))
.is_some_and(|call_path| {
matches!(
call_path.as_slice(),
[
"enum",
"Enum" | "Flag" | "IntEnum" | "IntFlag" | "StrEnum" | "ReprEnum"
]
)
})
{
return true;
}

// If the base class extends `enum.Enum`, `enum.Flag`, etc., then this is an enum.
if let Some(id) = semantic.lookup_attribute(expr) {
if let Some(id) = semantic.lookup_attribute(map_subscript(expr)) {
if seen.insert(id) {
let binding = semantic.binding(id);
if let Some(base_class) = binding
Expand Down
63 changes: 50 additions & 13 deletions crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use ruff_python_ast::call_path::from_qualified_name;
use ruff_python_ast::helpers::{map_callable, map_subscript};
use ruff_python_semantic::{Binding, BindingKind, ScopeKind, SemanticModel};
use ruff_python_ast::{self as ast};
use ruff_python_semantic::{Binding, BindingId, BindingKind, ScopeKind, SemanticModel};
use rustc_hash::FxHashSet;

pub(crate) fn is_valid_runtime_import(binding: &Binding, semantic: &SemanticModel) -> bool {
if matches!(
Expand Down Expand Up @@ -35,19 +37,54 @@ pub(crate) fn runtime_evaluated(
}

fn runtime_evaluated_base_class(base_classes: &[String], semantic: &SemanticModel) -> bool {
let ScopeKind::Class(class_def) = &semantic.current_scope().kind else {
return false;
};
fn inner(
class_def: &ast::StmtClassDef,
base_classes: &[String],
semantic: &SemanticModel,
seen: &mut FxHashSet<BindingId>,
) -> bool {
class_def.bases().iter().any(|expr| {
// If the base class is itself runtime-evaluated, then this is too.
// Ex) `class Foo(BaseModel): ...`
if semantic
.resolve_call_path(map_subscript(expr))
.is_some_and(|call_path| {
base_classes
.iter()
.any(|base_class| from_qualified_name(base_class) == call_path)
})
{
return true;
}

class_def.bases().iter().any(|base| {
semantic
.resolve_call_path(map_subscript(base))
.is_some_and(|call_path| {
base_classes
.iter()
.any(|base_class| from_qualified_name(base_class) == call_path)
})
})
// If the base class extends a runtime-evaluated class, then this does too.
// Ex) `class Bar(BaseModel): ...; class Foo(Bar): ...`
if let Some(id) = semantic.lookup_attribute(map_subscript(expr)) {
if seen.insert(id) {
let binding = semantic.binding(id);
if let Some(base_class) = binding
.kind
.as_class_definition()
.map(|id| &semantic.scopes[*id])
.and_then(|scope| scope.kind.as_class())
{
if inner(base_class, base_classes, semantic, seen) {
return true;
}
}
}
}
false
})
}

semantic
.current_scope()
.kind
.as_class()
.is_some_and(|class_def| {
inner(class_def, base_classes, semantic, &mut FxHashSet::default())
})
}

fn runtime_evaluated_decorators(decorators: &[String], semantic: &SemanticModel) -> bool {
Expand Down
4 changes: 4 additions & 0 deletions crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ mod tests {
Rule::TypingOnlyStandardLibraryImport,
Path::new("runtime_evaluated_base_classes_4.py")
)]
#[test_case(
Rule::TypingOnlyThirdPartyImport,
Path::new("runtime_evaluated_base_classes_5.py")
)]
fn runtime_evaluated_base_classes(rule_code: Rule, path: &Path) -> Result<()> {
let snapshot = format!("{}_{}", rule_code.as_ref(), path.to_string_lossy());
let diagnostics = test_path(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
source: crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
---

0 comments on commit a1eba7f

Please sign in to comment.