diff --git a/crates/ruff/resources/test/fixtures/pycodestyle/E721.py b/crates/ruff/resources/test/fixtures/pycodestyle/E721.py index acb92fab29c9f..a220f58c3d8e4 100644 --- a/crates/ruff/resources/test/fixtures/pycodestyle/E721.py +++ b/crates/ruff/resources/test/fixtures/pycodestyle/E721.py @@ -54,3 +54,7 @@ pass assert type(res) == type(None) + +types = StrEnum +if x == types.X: + pass diff --git a/crates/ruff/src/checkers/ast/mod.rs b/crates/ruff/src/checkers/ast/mod.rs index 72b5d101bcbf8..5c4bff9c642fb 100644 --- a/crates/ruff/src/checkers/ast/mod.rs +++ b/crates/ruff/src/checkers/ast/mod.rs @@ -3186,11 +3186,7 @@ where } if self.settings.rules.enabled(Rule::TypeComparison) { - self.diagnostics.extend(pycodestyle::rules::type_comparison( - ops, - comparators, - Range::from(expr), - )); + pycodestyle::rules::type_comparison(self, expr, ops, comparators); } if self.settings.rules.enabled(Rule::SysVersionCmpStr3) diff --git a/crates/ruff/src/rules/pycodestyle/rules/type_comparison.rs b/crates/ruff/src/rules/pycodestyle/rules/type_comparison.rs index fde3b00735cdf..3109617e4f3f9 100644 --- a/crates/ruff/src/rules/pycodestyle/rules/type_comparison.rs +++ b/crates/ruff/src/rules/pycodestyle/rules/type_comparison.rs @@ -1,6 +1,7 @@ use itertools::izip; use rustpython_parser::ast::{Cmpop, Constant, Expr, ExprKind}; +use crate::checkers::ast::Checker; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::types::Range; @@ -34,9 +35,7 @@ impl Violation for TypeComparison { } /// E721 -pub fn type_comparison(ops: &[Cmpop], comparators: &[Expr], location: Range) -> Vec { - let mut diagnostics: Vec = vec![]; - +pub fn type_comparison(checker: &mut Checker, expr: &Expr, ops: &[Cmpop], comparators: &[Expr]) { for (op, right) in izip!(ops, comparators) { if !matches!(op, Cmpop::Is | Cmpop::IsNot | Cmpop::Eq | Cmpop::NotEq) { continue; @@ -44,8 +43,8 @@ pub fn type_comparison(ops: &[Cmpop], comparators: &[Expr], location: Range) -> match &right.node { ExprKind::Call { func, args, .. } => { if let ExprKind::Name { id, .. } = &func.node { - // Ex) type(False) - if id == "type" { + // Ex) `type(False)` + if id == "type" && checker.ctx.is_builtin("type") { if let Some(arg) = args.first() { // Allow comparison for types which are not obvious. if !matches!( @@ -56,7 +55,9 @@ pub fn type_comparison(ops: &[Cmpop], comparators: &[Expr], location: Range) -> kind: None } ) { - diagnostics.push(Diagnostic::new(TypeComparison, location)); + checker + .diagnostics + .push(Diagnostic::new(TypeComparison, Range::from(expr))); } } } @@ -64,15 +65,22 @@ pub fn type_comparison(ops: &[Cmpop], comparators: &[Expr], location: Range) -> } ExprKind::Attribute { value, .. } => { if let ExprKind::Name { id, .. } = &value.node { - // Ex) types.IntType - if id == "types" { - diagnostics.push(Diagnostic::new(TypeComparison, location)); + // Ex) `types.NoneType` + if id == "types" + && checker + .ctx + .resolve_call_path(value) + .map_or(false, |call_path| { + call_path.first().map_or(false, |module| *module == "types") + }) + { + checker + .diagnostics + .push(Diagnostic::new(TypeComparison, Range::from(expr))); } } } _ => {} } } - - diagnostics }