From a6f7100b55fae748a7f1185ae3f768744c49722d Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Mon, 29 Jan 2024 09:39:01 -0800 Subject: [PATCH] [`pycodestyle`] Allow `dtype` comparisons in `type-comparison` (#9676) ## Summary Per https://github.com/astral-sh/ruff/issues/9570: > `dtype` are a bit of a strange beast, but definitely best thought of as instances, not classes, and they are meant to be comparable not just to their own class, but also to the corresponding scalar types (e.g., `x.dtype == np.float32`) and strings (e.g., `x.dtype == ['i1,i4']`; basically, `__eq__` always tries to do `dtype(other)`. This PR thus allows comparisons to `dtype` in preview. Closes https://github.com/astral-sh/ruff/issues/9570. ## Test Plan `cargo test` --- .../test/fixtures/pycodestyle/E721.py | 12 +++++++++ .../pycodestyle/rules/type_comparison.rs | 27 +++++++++++++++++++ ...destyle__tests__preview__E721_E721.py.snap | 7 +++++ 3 files changed, 46 insertions(+) diff --git a/crates/ruff_linter/resources/test/fixtures/pycodestyle/E721.py b/crates/ruff_linter/resources/test/fixtures/pycodestyle/E721.py index 872fa7042aded..9ce183b6adad8 100644 --- a/crates/ruff_linter/resources/test/fixtures/pycodestyle/E721.py +++ b/crates/ruff_linter/resources/test/fixtures/pycodestyle/E721.py @@ -126,3 +126,15 @@ def type(): # Okay if type(value) is str: ... + + +import numpy as np + +#: Okay +x.dtype == float + +#: Okay +np.dtype(int) == float + +#: E721 +dtype == float diff --git a/crates/ruff_linter/src/rules/pycodestyle/rules/type_comparison.rs b/crates/ruff_linter/src/rules/pycodestyle/rules/type_comparison.rs index 598b9a9c11972..57de972c23177 100644 --- a/crates/ruff_linter/src/rules/pycodestyle/rules/type_comparison.rs +++ b/crates/ruff_linter/src/rules/pycodestyle/rules/type_comparison.rs @@ -162,7 +162,14 @@ pub(crate) fn preview_type_comparison(checker: &mut Checker, compare: &ast::Expr .filter(|(_, op)| matches!(op, CmpOp::Eq | CmpOp::NotEq)) .map(|((left, right), _)| (left, right)) { + // If either expression is a type... if is_type(left, checker.semantic()) || is_type(right, checker.semantic()) { + // And neither is a `dtype`... + if is_dtype(left, checker.semantic()) || is_dtype(right, checker.semantic()) { + continue; + } + + // Disallow the comparison. checker.diagnostics.push(Diagnostic::new( TypeComparison { preview: PreviewMode::Enabled, @@ -295,3 +302,23 @@ fn is_type(expr: &Expr, semantic: &SemanticModel) -> bool { _ => false, } } + +/// Returns `true` if the [`Expr`] appears to be a reference to a NumPy dtype, since: +/// > `dtype` are a bit of a strange beast, but definitely best thought of as instances, not +/// > classes, and they are meant to be comparable not just to their own class, but also to the +/// corresponding scalar types (e.g., `x.dtype == np.float32`) and strings (e.g., +/// `x.dtype == ['i1,i4']`; basically, __eq__ always tries to do `dtype(other)`). +fn is_dtype(expr: &Expr, semantic: &SemanticModel) -> bool { + match expr { + // Ex) `np.dtype(obj)` + Expr::Call(ast::ExprCall { func, .. }) => semantic + .resolve_call_path(func) + .is_some_and(|call_path| matches!(call_path.as_slice(), ["numpy", "dtype"])), + // Ex) `obj.dtype` + Expr::Attribute(ast::ExprAttribute { attr, .. }) => { + // Ex) `obj.dtype` + attr.as_str() == "dtype" + } + _ => false, + } +} diff --git a/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__preview__E721_E721.py.snap b/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__preview__E721_E721.py.snap index 8971d3f3ccf06..fc1d13521fdae 100644 --- a/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__preview__E721_E721.py.snap +++ b/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__preview__E721_E721.py.snap @@ -129,4 +129,11 @@ E721.py:59:4: E721 Use `is` and `is not` for type comparisons, or `isinstance()` 61 | #: Okay | +E721.py:140:1: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks + | +139 | #: E721 +140 | dtype == float + | ^^^^^^^^^^^^^^ E721 + | +