Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deduplicate edits when quoting annotations #9140

Merged
merged 1 commit into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,10 @@ def f():

def func(value: DataFrame):
...


def f():
from pandas import DataFrame, Series

def baz() -> DataFrame | Series:
...
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::borrow::Cow;

use anyhow::Result;
use itertools::Itertools;
use rustc_hash::FxHashMap;

use ruff_diagnostics::{Diagnostic, Fix, FixAvailability, Violation};
Expand Down Expand Up @@ -262,7 +263,7 @@ pub(crate) fn runtime_import_in_type_checking_block(

/// Generate a [`Fix`] to quote runtime usages for imports in a type-checking block.
fn quote_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) -> Result<Fix> {
let mut quote_reference_edits = imports
let quote_reference_edits = imports
.iter()
.flat_map(|ImportBinding { binding, .. }| {
binding.references.iter().filter_map(|reference_id| {
Expand All @@ -280,14 +281,12 @@ fn quote_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding])
})
})
.collect::<Result<Vec<_>>>()?;
let quote_reference_edit = quote_reference_edits
.pop()
.expect("Expected at least one reference");
Ok(
Fix::unsafe_edits(quote_reference_edit, quote_reference_edits).isolate(Checker::isolation(
checker.semantic().parent_statement_id(node_id),
)),
)

let mut rest = quote_reference_edits.into_iter().dedup();
let head = rest.next().expect("Expected at least one reference");
Ok(Fix::unsafe_edits(head, rest).isolate(Checker::isolation(
checker.semantic().parent_statement_id(node_id),
)))
}

/// Generate a [`Fix`] to remove runtime imports from a type-checking block.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::borrow::Cow;

use anyhow::Result;
use itertools::Itertools;
use rustc_hash::FxHashMap;

use ruff_diagnostics::{Diagnostic, DiagnosticKind, Fix, FixAvailability, Violation};
Expand Down Expand Up @@ -506,7 +507,7 @@ fn fix_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) ->
add_import_edit
.into_edits()
.into_iter()
.chain(quote_reference_edits),
.chain(quote_reference_edits.into_iter().dedup()),
)
.isolate(Checker::isolation(
checker.semantic().parent_statement_id(node_id),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ quote.py:64:28: TCH004 [*] Quote references to `pandas.DataFrame`. Import is in
66 |- def func(value: DataFrame):
66 |+ def func(value: "DataFrame"):
67 67 | ...
68 68 |
69 69 |


Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,60 @@ quote.py:54:24: TCH002 Move third-party import `pandas.DataFrame` into a type-ch
|
= help: Move into type-checking block

quote.py:71:24: TCH002 [*] Move third-party import `pandas.DataFrame` into a type-checking block
|
70 | def f():
71 | from pandas import DataFrame, Series
| ^^^^^^^^^ TCH002
72 |
73 | def baz() -> DataFrame | Series:
|
= help: Move into type-checking block

ℹ Unsafe fix
1 |+from typing import TYPE_CHECKING
2 |+
3 |+if TYPE_CHECKING:
4 |+ from pandas import DataFrame, Series
1 5 | def f():
2 6 | from pandas import DataFrame
3 7 |
--------------------------------------------------------------------------------
68 72 |
69 73 |
70 74 | def f():
71 |- from pandas import DataFrame, Series
72 75 |
73 |- def baz() -> DataFrame | Series:
76 |+ def baz() -> "DataFrame | Series":
74 77 | ...

quote.py:71:35: TCH002 [*] Move third-party import `pandas.Series` into a type-checking block
|
70 | def f():
71 | from pandas import DataFrame, Series
| ^^^^^^ TCH002
72 |
73 | def baz() -> DataFrame | Series:
|
= help: Move into type-checking block

ℹ Unsafe fix
1 |+from typing import TYPE_CHECKING
2 |+
3 |+if TYPE_CHECKING:
4 |+ from pandas import DataFrame, Series
1 5 | def f():
2 6 | from pandas import DataFrame
3 7 |
--------------------------------------------------------------------------------
68 72 |
69 73 |
70 74 | def f():
71 |- from pandas import DataFrame, Series
72 75 |
73 |- def baz() -> DataFrame | Series:
76 |+ def baz() -> "DataFrame | Series":
74 77 | ...


Loading