Skip to content

Commit

Permalink
Add autofix for flake8-type-checking
Browse files Browse the repository at this point in the history
This is actually sort of starting to work

This is actually sort of starting to work

Another hard case
  • Loading branch information
charliermarsh committed May 31, 2023
1 parent 9efd5b2 commit 9a84037
Show file tree
Hide file tree
Showing 26 changed files with 1,137 additions and 41 deletions.
85 changes: 85 additions & 0 deletions crates/ruff/src/autofix/codemods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,88 @@ pub(crate) fn remove_imports<'a>(

Ok(Some(state.to_string()))
}

/// Given an import statement, remove any imports that are not specified in the `imports` slice.
///
/// Returns the modified import statement.
pub(crate) fn retain_imports(
imports: &[&str],
stmt: &Stmt,
locator: &Locator,
stylist: &Stylist,
) -> Result<String> {
let module_text = locator.slice(stmt.range());
let mut tree = match_statement(module_text)?;

let Statement::Simple(body) = &mut tree else {
bail!("Expected Statement::Simple");
};

let (aliases, import_module) = match body.body.first_mut() {
Some(SmallStatement::Import(import_body)) => (&mut import_body.names, None),
Some(SmallStatement::ImportFrom(import_body)) => {
if let ImportNames::Aliases(names) = &mut import_body.names {
(
names,
Some((&import_body.relative, import_body.module.as_ref())),
)
} else {
bail!("Expected: ImportNames::Aliases");
}
}
_ => bail!("Expected: SmallStatement::ImportFrom | SmallStatement::Import"),
};

// Preserve the trailing comma (or not) from the last entry.
let trailing_comma = aliases.last().and_then(|alias| alias.comma.clone());

aliases.retain(|alias| {
imports.iter().any(|import| {
let full_name = match import_module {
Some((relative, module)) => {
let module = module.map(compose_module_path);
let member = compose_module_path(&alias.name);
let mut full_name = String::with_capacity(
relative.len() + module.as_ref().map_or(0, String::len) + member.len() + 1,
);
for _ in 0..relative.len() {
full_name.push('.');
}
if let Some(module) = module {
full_name.push_str(&module);
full_name.push('.');
}
full_name.push_str(&member);
full_name
}
None => compose_module_path(&alias.name),
};
full_name == *import
})
});

// But avoid destroying any trailing comments.
if let Some(alias) = aliases.last_mut() {
let has_comment = if let Some(comma) = &alias.comma {
match &comma.whitespace_after {
ParenthesizableWhitespace::SimpleWhitespace(_) => false,
ParenthesizableWhitespace::ParenthesizedWhitespace(whitespace) => {
whitespace.first_line.comment.is_some()
}
}
} else {
false
};
if !has_comment {
alias.comma = trailing_comma;
}
}

let mut state = CodegenState {
default_newline: &stylist.line_ending(),
default_indent: stylist.indentation(),
..CodegenState::default()
};
tree.codegen(&mut state);
Ok(state.to_string())
}
2 changes: 1 addition & 1 deletion crates/ruff/src/autofix/edits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub(crate) fn delete_stmt(
}
}

/// Generate a `Fix` to remove any unused imports from an `import` statement.
/// Generate a `Fix` to remove the specified imports from an `import` statement.
pub(crate) fn remove_unused_imports<'a>(
unused_imports: impl Iterator<Item = &'a str>,
stmt: &Stmt,
Expand Down
2 changes: 0 additions & 2 deletions crates/ruff/src/importer/insertion.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
//! Insert statements into Python code.
#![allow(dead_code)]

use ruff_text_size::TextSize;
use rustpython_parser::ast::{Ranged, Stmt};
use rustpython_parser::{lexer, Mode, Tok};
Expand Down
103 changes: 95 additions & 8 deletions crates/ruff/src/importer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use libcst_native::{Codegen, CodegenState, ImportAlias, Name, NameOrAttribute};
use ruff_text_size::TextSize;
use rustpython_parser::ast::{self, Ranged, Stmt, Suite};

use crate::autofix;
use ruff_diagnostics::Edit;
use ruff_python_ast::imports::{AnyImport, Import, ImportFrom};
use ruff_python_ast::source_code::{Locator, Stylist};
use ruff_python_semantic::model::SemanticModel;
use ruff_textwrap::indent;

use crate::cst::matchers::{match_aliases, match_import_from, match_statement};
use crate::importer::insertion::Insertion;
Expand Down Expand Up @@ -73,6 +75,52 @@ impl<'a> Importer<'a> {
}
}

/// Move an existing import into a `TYPE_CHECKING` block.
///
/// If there are no existing `TYPE_CHECKING` blocks, a new one will be added at the top
/// of the file. Otherwise, it will be added after the most recent top-level
/// `TYPE_CHECKING` block.
pub(crate) fn to_typing_import(
&self,
import: &StmtImport,
at: TextSize,
semantic_model: &SemanticModel,
) -> Result<(Edit, Edit)> {
// Generate the modified import statement.
let content = autofix::codemods::retain_imports(
&[import.full_name],
import.stmt,
self.locator,
self.stylist,
)?;

// Import the `TYPE_CHECKING` symbol from the typing module.
let (type_checking_edit, type_checking) = self.get_or_import_symbol(
&ImportRequest::import_from("typing", "TYPE_CHECKING"),
at,
semantic_model,
)?;

// Add the import to a `TYPE_CHECKING` block.
let add_import_edit = if let Some(block) = self.preceding_type_checking_block(at) {
// Add the import to the `TYPE_CHECKING` block.
self.add_to_type_checking_block(&content, block.start())
} else {
// Add the import to a new `TYPE_CHECKING` block.
self.add_type_checking_block(
&format!(
"{}if {type_checking}:{}{}",
self.stylist.line_ending().as_str(),
self.stylist.line_ending().as_str(),
indent(&content, self.stylist.indentation())
),
at,
)?
};

Ok((type_checking_edit, add_import_edit))
}

/// Generate an [`Edit`] to reference the given symbol. Returns the [`Edit`] necessary to make
/// the symbol available in the current scope along with the bound name of the symbol.
///
Expand Down Expand Up @@ -204,14 +252,6 @@ impl<'a> Importer<'a> {
}
}

/// Return the import statement that precedes the given position, if any.
fn preceding_import(&self, at: TextSize) -> Option<&Stmt> {
self.runtime_imports
.partition_point(|stmt| stmt.start() < at)
.checked_sub(1)
.map(|idx| self.runtime_imports[idx])
}

/// Return the top-level [`Stmt`] that imports the given module using `Stmt::ImportFrom`
/// preceding the given position, if any.
fn find_import_from(&self, module: &str, at: TextSize) -> Option<&Stmt> {
Expand Down Expand Up @@ -258,6 +298,45 @@ impl<'a> Importer<'a> {
statement.codegen(&mut state);
Ok(Edit::range_replacement(state.to_string(), stmt.range()))
}

/// Add a `TYPE_CHECKING` block to the given module.
fn add_type_checking_block(&self, content: &str, at: TextSize) -> Result<Edit> {
let insertion = if let Some(stmt) = self.preceding_import(at) {
// Insert after the last top-level import.
Insertion::end_of_statement(stmt, self.locator, self.stylist)
} else {
// Insert at the start of the file.
Insertion::start_of_file(self.python_ast, self.locator, self.stylist)
};
if insertion.is_inline() {
Err(anyhow::anyhow!(
"Cannot insert `TYPE_CHECKING` block inline"
))
} else {
Ok(insertion.into_edit(content))
}
}

/// Add an import statement to an existing `TYPE_CHECKING` block.
fn add_to_type_checking_block(&self, content: &str, at: TextSize) -> Edit {
Insertion::start_of_block(at, self.locator, self.stylist).into_edit(content)
}

/// Return the import statement that precedes the given position, if any.
fn preceding_import(&self, at: TextSize) -> Option<&'a Stmt> {
self.runtime_imports
.partition_point(|stmt| stmt.start() < at)
.checked_sub(1)
.map(|idx| self.runtime_imports[idx])
}

/// Return the import statement that precedes the given position, if any.
fn preceding_type_checking_block(&self, at: TextSize) -> Option<&'a Stmt> {
self.type_checking_blocks
.partition_point(|stmt| stmt.start() < at)
.checked_sub(1)
.map(|idx| self.type_checking_blocks[idx])
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -301,6 +380,14 @@ impl<'a> ImportRequest<'a> {
}
}

/// An existing module or member import, located within an import statement.
pub(crate) struct StmtImport<'a> {
/// The import statement.
pub(crate) stmt: &'a Stmt,
/// The "full name" of the imported module or member.
pub(crate) full_name: &'a str,
}

/// The result of an [`Importer::get_or_import_symbol`] call.
#[derive(Debug)]
pub(crate) enum ResolutionError {
Expand Down
Loading

0 comments on commit 9a84037

Please sign in to comment.