Skip to content

Commit

Permalink
Implement isort custom sections and ordering (astral-sh#2419)
Browse files Browse the repository at this point in the history
  • Loading branch information
hackedd committed Apr 12, 2023
1 parent 8608414 commit 80f4d50
Show file tree
Hide file tree
Showing 9 changed files with 337 additions and 67 deletions.
7 changes: 7 additions & 0 deletions crates/ruff/resources/test/fixtures/isort/sections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from __future__ import annotations
import os
import sys
import pytz
import django.settings
from library import foo
from . import local
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use ruff_python_semantic::binding::{
Binding, BindingKind, ExecutionContext, FromImportation, Importation, SubmoduleImportation,
};

use crate::rules::isort::{categorize, ImportType};
use crate::rules::isort::{categorize, ImportSection, ImportType};
use crate::settings::Settings;

/// ## What it does
Expand Down Expand Up @@ -294,25 +294,31 @@ pub fn typing_only_runtime_import(
&settings.isort.known_modules,
settings.target_version,
) {
ImportType::LocalFolder | ImportType::FirstParty => Some(Diagnostic::new(
TypingOnlyFirstPartyImport {
full_name: full_name.to_string(),
},
binding.range,
)),
ImportType::ThirdParty => Some(Diagnostic::new(
TypingOnlyThirdPartyImport {
full_name: full_name.to_string(),
},
binding.range,
)),
ImportType::StandardLibrary => Some(Diagnostic::new(
ImportSection::Known(ImportType::LocalFolder | ImportType::FirstParty) => {
Some(Diagnostic::new(
TypingOnlyFirstPartyImport {
full_name: full_name.to_string(),
},
binding.range,
))
}
ImportSection::Known(ImportType::ThirdParty) | ImportSection::UserDefined(_) => {
Some(Diagnostic::new(
TypingOnlyThirdPartyImport {
full_name: full_name.to_string(),
},
binding.range,
))
}
ImportSection::Known(ImportType::StandardLibrary) => Some(Diagnostic::new(
TypingOnlyStandardLibraryImport {
full_name: full_name.to_string(),
},
binding.range,
)),
ImportType::Future => unreachable!("`__future__` imports should be marked as used"),
ImportSection::Known(ImportType::Future) => {
unreachable!("`__future__` imports should be marked as used")
}
}
} else {
None
Expand Down
113 changes: 80 additions & 33 deletions crates/ruff/src/rules/isort/categorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::path::{Path, PathBuf};
use std::{fs, iter};

use log::debug;
use rustc_hash::FxHashMap;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use strum_macros::EnumIter;
Expand Down Expand Up @@ -37,6 +38,15 @@ pub enum ImportType {
LocalFolder,
}

#[derive(
Debug, PartialOrd, Ord, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, CacheKey,
)]
#[serde(untagged)]
pub enum ImportSection {
Known(ImportType),
UserDefined(String),
}

#[derive(Debug)]
enum Reason<'a> {
NonZeroLevel,
Expand All @@ -49,6 +59,7 @@ enum Reason<'a> {
SamePackage,
SourceMatch(&'a Path),
NoMatch,
UserDefinedSection,
}

#[allow(clippy::too_many_arguments)]
Expand All @@ -59,27 +70,42 @@ pub fn categorize(
package: Option<&Path>,
known_modules: &KnownModules,
target_version: PythonVersion,
) -> ImportType {
) -> ImportSection {
let module_base = module_name.split('.').next().unwrap();
let (import_type, reason) = {
if level.map_or(false, |level| level > 0) {
(ImportType::LocalFolder, Reason::NonZeroLevel)
(
ImportSection::Known(ImportType::LocalFolder),
Reason::NonZeroLevel,
)
} else if module_base == "__future__" {
(ImportType::Future, Reason::Future)
(ImportSection::Known(ImportType::Future), Reason::Future)
} else if let Some((import_type, reason)) = known_modules.categorize(module_name) {
(import_type, reason)
} else if KNOWN_STANDARD_LIBRARY
.get(&target_version.as_tuple())
.unwrap()
.contains(module_base)
{
(ImportType::StandardLibrary, Reason::KnownStandardLibrary)
(
ImportSection::Known(ImportType::StandardLibrary),
Reason::KnownStandardLibrary,
)
} else if same_package(package, module_base) {
(ImportType::FirstParty, Reason::SamePackage)
(
ImportSection::Known(ImportType::FirstParty),
Reason::SamePackage,
)
} else if let Some(src) = match_sources(src, module_base) {
(ImportType::FirstParty, Reason::SourceMatch(src))
(
ImportSection::Known(ImportType::FirstParty),
Reason::SourceMatch(src),
)
} else {
(ImportType::ThirdParty, Reason::NoMatch)
(
ImportSection::Known(ImportType::ThirdParty),
Reason::NoMatch,
)
}
};
debug!(
Expand Down Expand Up @@ -116,8 +142,8 @@ pub fn categorize_imports<'a>(
package: Option<&Path>,
known_modules: &KnownModules,
target_version: PythonVersion,
) -> BTreeMap<ImportType, ImportBlock<'a>> {
let mut block_by_type: BTreeMap<ImportType, ImportBlock> = BTreeMap::default();
) -> BTreeMap<ImportSection, ImportBlock<'a>> {
let mut block_by_type: BTreeMap<ImportSection, ImportBlock> = BTreeMap::default();
// Categorize `StmtKind::Import`.
for (alias, comments) in block.import {
let import_type = categorize(
Expand Down Expand Up @@ -195,6 +221,8 @@ pub struct KnownModules {
pub local_folder: BTreeSet<String>,
/// A set of user-provided standard library modules.
pub standard_library: BTreeSet<String>,
/// A map of additional user-provided sections.
pub sections: FxHashMap<String, BTreeSet<String>>,
/// Whether any of the known modules are submodules (e.g., `foo.bar`, as opposed to `foo`).
has_submodules: bool,
}
Expand All @@ -205,29 +233,36 @@ impl KnownModules {
third_party: Vec<String>,
local_folder: Vec<String>,
standard_library: Vec<String>,
sections: FxHashMap<String, Vec<String>>,
) -> Self {
let first_party = BTreeSet::from_iter(first_party);
let third_party = BTreeSet::from_iter(third_party);
let local_folder = BTreeSet::from_iter(local_folder);
let standard_library = BTreeSet::from_iter(standard_library);
let sections: FxHashMap<_, _> = sections
.into_iter()
.map(|(k, v)| (k, BTreeSet::from_iter(v)))
.collect();
let has_submodules = first_party
.iter()
.chain(third_party.iter())
.chain(local_folder.iter())
.chain(standard_library.iter())
.chain(sections.values().flatten())
.any(|m| m.contains('.'));
Self {
first_party,
third_party,
local_folder,
standard_library,
sections,
has_submodules,
}
}

/// Return the [`ImportType`] for a given module, if it's been categorized as a known module
/// Return the [`ImportSection`] for a given module, if it's been categorized as a known module
/// by the user.
fn categorize(&self, module_name: &str) -> Option<(ImportType, Reason)> {
fn categorize(&self, module_name: &str) -> Option<(ImportSection, Reason)> {
if self.has_submodules {
// Check all module prefixes from the longest to the shortest (e.g., given
// `foo.bar.baz`, check `foo.bar.baz`, then `foo.bar`, then `foo`, taking the first,
Expand All @@ -239,34 +274,46 @@ impl KnownModules {
.rev()
{
let submodule = &module_name[0..i];
if self.first_party.contains(submodule) {
return Some((ImportType::FirstParty, Reason::KnownFirstParty));
}
if self.third_party.contains(submodule) {
return Some((ImportType::ThirdParty, Reason::KnownThirdParty));
}
if self.local_folder.contains(submodule) {
return Some((ImportType::LocalFolder, Reason::KnownLocalFolder));
}
if self.standard_library.contains(submodule) {
return Some((ImportType::StandardLibrary, Reason::ExtraStandardLibrary));
if let Some(result) = self.categorize_submodule(submodule) {
return Some(result);
}
}
None
} else {
// Happy path: no submodules, so we can check the module base and be done.
let module_base = module_name.split('.').next().unwrap();
if self.first_party.contains(module_base) {
Some((ImportType::FirstParty, Reason::KnownFirstParty))
} else if self.third_party.contains(module_base) {
Some((ImportType::ThirdParty, Reason::KnownThirdParty))
} else if self.local_folder.contains(module_base) {
Some((ImportType::LocalFolder, Reason::KnownLocalFolder))
} else if self.standard_library.contains(module_base) {
Some((ImportType::StandardLibrary, Reason::ExtraStandardLibrary))
} else {
None
}
self.categorize_submodule(module_base)
}
}

fn categorize_submodule(&self, submodule: &str) -> Option<(ImportSection, Reason)> {
if let Some((section, _)) = self.sections.iter().find(|(_, v)| v.contains(submodule)) {
Some((
ImportSection::UserDefined(section.clone()),
Reason::UserDefinedSection,
))
} else if self.first_party.contains(submodule) {
Some((
ImportSection::Known(ImportType::FirstParty),
Reason::KnownFirstParty,
))
} else if self.third_party.contains(submodule) {
Some((
ImportSection::Known(ImportType::ThirdParty),
Reason::KnownThirdParty,
))
} else if self.local_folder.contains(submodule) {
Some((
ImportSection::Known(ImportType::LocalFolder),
Reason::KnownLocalFolder,
))
} else if self.standard_library.contains(submodule) {
Some((
ImportSection::Known(ImportType::StandardLibrary),
Reason::ExtraStandardLibrary,
))
} else {
None
}
}
}
Loading

0 comments on commit 80f4d50

Please sign in to comment.