Skip to content

Commit

Permalink
[flake8-comprehensions] Set comprehensions not a violation for `sum…
Browse files Browse the repository at this point in the history
…` in `unnecessary-comprehension-in-call` (`C419`) (astral-sh#12691)

## Summary

Removes set comprehension as a violation for `sum` when checking `C419`,
because set comprehension may de-duplicate entries in a generator,
thereby modifying the value of the sum.

Closes astral-sh#12690.
  • Loading branch information
dylwil3 committed Aug 7, 2024
1 parent 7637edd commit 7ff3db9
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,17 @@ async def f() -> bool:
i.bit_count() for i in range(5) # rbracket comment
] # rpar comment
)

## Set comprehensions should only be linted
## when function is invariant under duplication of inputs

# should be linted...
any({x.id for x in bar})
all({x.id for x in bar})

# should be linted in preview...
min({x.id for x in bar})
max({x.id for x in bar})

# should not be linted...
sum({x.id for x in bar})
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use ruff_python_ast::{self as ast, Expr, Keyword};

use ruff_diagnostics::{Diagnostic, FixAvailability};
use ruff_diagnostics::{Edit, Fix, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::any_over_expr;
use ruff_python_ast::{self as ast, Expr, Keyword};
use ruff_text_size::{Ranged, TextSize};

use crate::checkers::ast::Checker;

use crate::rules::flake8_comprehensions::fixes;

/// ## What it does
/// Checks for unnecessary list comprehensions passed to builtin functions that take an iterable.
/// Checks for unnecessary list or set comprehensions passed to builtin functions that take an iterable.
///
/// Set comprehensions are only a violation in the case where the builtin function does not care about
/// duplication of elements in the passed iterable.
///
/// ## Why is this bad?
/// Many builtin functions (this rule currently covers `any` and `all` in stable, along with `min`,
Expand Down Expand Up @@ -65,18 +66,23 @@ use crate::rules::flake8_comprehensions::fixes;
///
/// [preview]: https://docs.astral.sh/ruff/preview/
#[violation]
pub struct UnnecessaryComprehensionInCall;
pub struct UnnecessaryComprehensionInCall {
comprehension_kind: ComprehensionKind,
}

impl Violation for UnnecessaryComprehensionInCall {
const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes;

#[derive_message_formats]
fn message(&self) -> String {
format!("Unnecessary list comprehension")
match self.comprehension_kind {
ComprehensionKind::List => format!("Unnecessary list comprehension"),
ComprehensionKind::Set => format!("Unnecessary set comprehension"),
}
}

fn fix_title(&self) -> Option<String> {
Some("Remove unnecessary list comprehension".to_string())
Some("Remove unnecessary comprehension".to_string())
}
}

Expand All @@ -102,18 +108,42 @@ pub(crate) fn unnecessary_comprehension_in_call(
if contains_await(elt) {
return;
}
let Some(builtin_function) = checker.semantic().resolve_builtin_symbol(func) else {
let Some(Ok(builtin_function)) = checker
.semantic()
.resolve_builtin_symbol(func)
.map(SupportedBuiltins::try_from)
else {
return;
};
if !(matches!(builtin_function, "any" | "all")
|| (checker.settings.preview.is_enabled()
&& matches!(builtin_function, "sum" | "min" | "max")))
if !(matches!(
builtin_function,
SupportedBuiltins::Any | SupportedBuiltins::All
) || (checker.settings.preview.is_enabled()
&& matches!(
builtin_function,
SupportedBuiltins::Sum | SupportedBuiltins::Min | SupportedBuiltins::Max
)))
{
return;
}

let mut diagnostic = Diagnostic::new(UnnecessaryComprehensionInCall, arg.range());

let mut diagnostic = match (arg, builtin_function.duplication_variance()) {
(Expr::ListComp(_), _) => Diagnostic::new(
UnnecessaryComprehensionInCall {
comprehension_kind: ComprehensionKind::List,
},
arg.range(),
),
(Expr::SetComp(_), DuplicationVariance::Invariant) => Diagnostic::new(
UnnecessaryComprehensionInCall {
comprehension_kind: ComprehensionKind::Set,
},
arg.range(),
),
_ => {
return;
}
};
if args.len() == 1 {
// If there's only one argument, remove the list or set brackets.
diagnostic.try_set_fix(|| {
Expand Down Expand Up @@ -144,3 +174,51 @@ pub(crate) fn unnecessary_comprehension_in_call(
fn contains_await(expr: &Expr) -> bool {
any_over_expr(expr, &Expr::is_await_expr)
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum DuplicationVariance {
Invariant,
Variant,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ComprehensionKind {
List,
Set,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum SupportedBuiltins {
All,
Any,
Sum,
Min,
Max,
}

impl TryFrom<&str> for SupportedBuiltins {
type Error = &'static str;

fn try_from(value: &str) -> Result<Self, Self::Error> {
match value {
"all" => Ok(Self::All),
"any" => Ok(Self::Any),
"sum" => Ok(Self::Sum),
"min" => Ok(Self::Min),
"max" => Ok(Self::Max),
_ => Err("Unsupported builtin for `unnecessary-comprehension-in-call`"),
}
}
}

impl SupportedBuiltins {
fn duplication_variance(self) -> DuplicationVariance {
match self {
SupportedBuiltins::All
| SupportedBuiltins::Any
| SupportedBuiltins::Min
| SupportedBuiltins::Max => DuplicationVariance::Invariant,
SupportedBuiltins::Sum => DuplicationVariance::Variant,
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ C419.py:1:5: C419 [*] Unnecessary list comprehension
2 | all([x.id for x in bar])
3 | any( # first comment
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

Unsafe fix
1 |-any([x.id for x in bar])
Expand All @@ -25,7 +25,7 @@ C419.py:2:5: C419 [*] Unnecessary list comprehension
3 | any( # first comment
4 | [x.id for x in bar], # second comment
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

Unsafe fix
1 1 | any([x.id for x in bar])
Expand All @@ -44,7 +44,7 @@ C419.py:4:5: C419 [*] Unnecessary list comprehension
5 | ) # third comment
6 | all( # first comment
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

Unsafe fix
1 1 | any([x.id for x in bar])
Expand All @@ -65,7 +65,7 @@ C419.py:7:5: C419 [*] Unnecessary list comprehension
8 | ) # third comment
9 | any({x.id for x in bar})
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

Unsafe fix
4 4 | [x.id for x in bar], # second comment
Expand All @@ -77,7 +77,7 @@ C419.py:7:5: C419 [*] Unnecessary list comprehension
9 9 | any({x.id for x in bar})
10 10 |

C419.py:9:5: C419 [*] Unnecessary list comprehension
C419.py:9:5: C419 [*] Unnecessary set comprehension
|
7 | [x.id for x in bar], # second comment
8 | ) # third comment
Expand All @@ -86,7 +86,7 @@ C419.py:9:5: C419 [*] Unnecessary list comprehension
10 |
11 | # OK
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

Unsafe fix
6 6 | all( # first comment
Expand All @@ -113,7 +113,7 @@ C419.py:28:5: C419 [*] Unnecessary list comprehension
34 | # trailing comment
35 | )
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

Unsafe fix
25 25 |
Expand Down Expand Up @@ -145,7 +145,7 @@ C419.py:39:5: C419 [*] Unnecessary list comprehension
| |_____^ C419
43 | )
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

ℹ Unsafe fix
36 36 |
Expand All @@ -160,3 +160,45 @@ C419.py:39:5: C419 [*] Unnecessary list comprehension
41 |+# second line comment
42 |+i.bit_count() for i in range(5) # rbracket comment # rpar comment
43 43 | )
44 44 |
45 45 | ## Set comprehensions should only be linted

C419.py:49:5: C419 [*] Unnecessary set comprehension
|
48 | # should be linted...
49 | any({x.id for x in bar})
| ^^^^^^^^^^^^^^^^^^^ C419
50 | all({x.id for x in bar})
|
= help: Remove unnecessary comprehension

ℹ Unsafe fix
46 46 | ## when function is invariant under duplication of inputs
47 47 |
48 48 | # should be linted...
49 |-any({x.id for x in bar})
49 |+any(x.id for x in bar)
50 50 | all({x.id for x in bar})
51 51 |
52 52 | # should be linted in preview...

C419.py:50:5: C419 [*] Unnecessary set comprehension
|
48 | # should be linted...
49 | any({x.id for x in bar})
50 | all({x.id for x in bar})
| ^^^^^^^^^^^^^^^^^^^ C419
51 |
52 | # should be linted in preview...
|
= help: Remove unnecessary comprehension

ℹ Unsafe fix
47 47 |
48 48 | # should be linted...
49 49 | any({x.id for x in bar})
50 |-all({x.id for x in bar})
50 |+all(x.id for x in bar)
51 51 |
52 52 | # should be linted in preview...
53 53 | min({x.id for x in bar})
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ C419_1.py:1:5: C419 [*] Unnecessary list comprehension
2 | min([x.val for x in bar])
3 | max([x.val for x in bar])
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

Unsafe fix
1 |-sum([x.val for x in bar])
Expand All @@ -25,7 +25,7 @@ C419_1.py:2:5: C419 [*] Unnecessary list comprehension
3 | max([x.val for x in bar])
4 | sum([x.val for x in bar], 0)
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

Unsafe fix
1 1 | sum([x.val for x in bar])
Expand All @@ -43,7 +43,7 @@ C419_1.py:3:5: C419 [*] Unnecessary list comprehension
| ^^^^^^^^^^^^^^^^^^^^ C419
4 | sum([x.val for x in bar], 0)
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

Unsafe fix
1 1 | sum([x.val for x in bar])
Expand All @@ -63,7 +63,7 @@ C419_1.py:4:5: C419 [*] Unnecessary list comprehension
5 |
6 | # OK
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

Unsafe fix
1 1 | sum([x.val for x in bar])
Expand All @@ -89,7 +89,7 @@ C419_1.py:14:5: C419 [*] Unnecessary list comprehension
19 | dt.timedelta(),
20 | )
|
= help: Remove unnecessary list comprehension
= help: Remove unnecessary comprehension

Unsafe fix
11 11 |
Expand Down

0 comments on commit 7ff3db9

Please sign in to comment.