From 1031bb655025524476546ab2761fac513af989ca Mon Sep 17 00:00:00 2001 From: konsti Date: Fri, 4 Aug 2023 13:52:26 +0200 Subject: [PATCH] Formatter: Add SourceType to context to enable special formatting for stub files (#6331) **Summary** This adds the information whether we're in a .py Python source file or in a .pyi stub file to enable people working on #5822 and related issues. I'm not completely happy with `Default` for something that depends on the input. **Test Plan** None; this is currently unused. I'm leaving this to the first implementation of stub-file-specific formatting. --------- Co-authored-by: Micha Reiser --- .../pycodestyle/rules/type_comparison.rs | 8 +-- crates/ruff_benchmark/benches/formatter.rs | 5 +- crates/ruff_cli/src/lib.rs | 13 ++--- crates/ruff_dev/src/format_dev.rs | 58 +++++++++---------- crates/ruff_python_ast/src/lib.rs | 34 +++++++++++ crates/ruff_python_formatter/Cargo.toml | 6 +- crates/ruff_python_formatter/src/cli.rs | 18 +++--- crates/ruff_python_formatter/src/lib.rs | 14 ++--- crates/ruff_python_formatter/src/main.rs | 6 +- crates/ruff_python_formatter/src/options.rs | 49 +++++++++++----- .../ruff_python_formatter/tests/fixtures.rs | 7 +-- crates/ruff_wasm/src/lib.rs | 16 ++--- 12 files changed, 137 insertions(+), 97 deletions(-) diff --git a/crates/ruff/src/rules/pycodestyle/rules/type_comparison.rs b/crates/ruff/src/rules/pycodestyle/rules/type_comparison.rs index c13d7a533d09c..2dc68873416be 100644 --- a/crates/ruff/src/rules/pycodestyle/rules/type_comparison.rs +++ b/crates/ruff/src/rules/pycodestyle/rules/type_comparison.rs @@ -52,15 +52,13 @@ pub(crate) fn type_comparison(checker: &mut Checker, compare: &ast::ExprCompare) } // Left-hand side must be, e.g., `type(obj)`. - let Expr::Call(ast::ExprCall { - func, .. - }) = left else { + let Expr::Call(ast::ExprCall { func, .. }) = left else { continue; }; let Expr::Name(ast::ExprName { id, .. 
}) = func.as_ref() else { - continue; - }; + continue; + }; if !(id == "type" && checker.semantic().is_builtin("type")) { continue; diff --git a/crates/ruff_benchmark/benches/formatter.rs b/crates/ruff_benchmark/benches/formatter.rs index 74688bed09314..6895bd2ac1eca 100644 --- a/crates/ruff_benchmark/benches/formatter.rs +++ b/crates/ruff_benchmark/benches/formatter.rs @@ -1,6 +1,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use ruff_benchmark::{TestCase, TestCaseSpeed, TestFile, TestFileDownloadError}; use ruff_python_formatter::{format_module, PyFormatOptions}; +use std::path::Path; use std::time::Duration; #[cfg(target_os = "windows")] @@ -51,8 +52,8 @@ fn benchmark_formatter(criterion: &mut Criterion) { &case, |b, case| { b.iter(|| { - format_module(case.code(), PyFormatOptions::default()) - .expect("Formatting to succeed") + let options = PyFormatOptions::from_extension(Path::new(case.name())); + format_module(case.code(), options).expect("Formatting to succeed") }); }, ); diff --git a/crates/ruff_cli/src/lib.rs b/crates/ruff_cli/src/lib.rs index edfb43f5de3ad..0e30a4cfdb080 100644 --- a/crates/ruff_cli/src/lib.rs +++ b/crates/ruff_cli/src/lib.rs @@ -161,25 +161,20 @@ fn format(files: &[PathBuf]) -> Result { internal use only." 
); - let format_code = |code: &str| { - // dummy, to check that the function was actually called - let contents = code.replace("# DEL", ""); - // real formatting that is currently a passthrough - format_module(&contents, PyFormatOptions::default()) - }; - match &files { // Check if we should read from stdin [path] if path == Path::new("-") => { let unformatted = read_from_stdin()?; - let formatted = format_code(&unformatted)?; + let options = PyFormatOptions::from_extension(Path::new("stdin.py")); + let formatted = format_module(&unformatted, options)?; stdout().lock().write_all(formatted.as_code().as_bytes())?; } _ => { for file in files { let unformatted = std::fs::read_to_string(file) .with_context(|| format!("Could not read {}: ", file.display()))?; - let formatted = format_code(&unformatted)?; + let options = PyFormatOptions::from_extension(file); + let formatted = format_module(&unformatted, options)?; std::fs::write(file, formatted.as_code().as_bytes()) .with_context(|| format!("Could not write to {}, exiting", file.display()))?; } diff --git a/crates/ruff_dev/src/format_dev.rs b/crates/ruff_dev/src/format_dev.rs index a669ed022c445..924ce7d2d3077 100644 --- a/crates/ruff_dev/src/format_dev.rs +++ b/crates/ruff_dev/src/format_dev.rs @@ -376,10 +376,10 @@ fn format_dev_project( // TODO(konstin): The assumptions between this script (one repo) and ruff (pass in a bunch of // files) mismatch. 
- let options = BlackOptions::from_file(&files[0])?.to_py_format_options(); + let black_options = BlackOptions::from_file(&files[0])?; debug!( parent: None, - "Options for {}: {options:?}", + "Options for {}: {black_options:?}", files[0].display() ); @@ -398,7 +398,7 @@ fn format_dev_project( paths .into_par_iter() .map(|dir_entry| { - let result = format_dir_entry(dir_entry, stability_check, write, &options); + let result = format_dir_entry(dir_entry, stability_check, write, &black_options); pb_span.pb_inc(1); result }) @@ -447,7 +447,7 @@ fn format_dir_entry( dir_entry: Result, stability_check: bool, write: bool, - options: &PyFormatOptions, + options: &BlackOptions, ) -> anyhow::Result<(Result, PathBuf), Error> { let dir_entry = match dir_entry.context("Iterating the files in the repository failed") { Ok(dir_entry) => dir_entry, @@ -460,27 +460,27 @@ fn format_dir_entry( } let file = dir_entry.path().to_path_buf(); + let options = options.to_py_format_options(&file); // Handle panics (mostly in `debug_assert!`) - let result = - match catch_unwind(|| format_dev_file(&file, stability_check, write, options.clone())) { - Ok(result) => result, - Err(panic) => { - if let Some(message) = panic.downcast_ref::() { - Err(CheckFileError::Panic { - message: message.clone(), - }) - } else if let Some(&message) = panic.downcast_ref::<&str>() { - Err(CheckFileError::Panic { - message: message.to_string(), - }) - } else { - Err(CheckFileError::Panic { - // This should not happen, but it can - message: "(Panic didn't set a string message)".to_string(), - }) - } + let result = match catch_unwind(|| format_dev_file(&file, stability_check, write, options)) { + Ok(result) => result, + Err(panic) => { + if let Some(message) = panic.downcast_ref::() { + Err(CheckFileError::Panic { + message: message.clone(), + }) + } else if let Some(&message) = panic.downcast_ref::<&str>() { + Err(CheckFileError::Panic { + message: message.to_string(), + }) + } else { + Err(CheckFileError::Panic { + 
// This should not happen, but it can + message: "(Panic didn't set a string message)".to_string(), + }) } - }; + } + }; Ok((result, file)) } @@ -833,9 +833,8 @@ impl BlackOptions { Self::from_toml(&fs::read_to_string(&path)?, repo) } - fn to_py_format_options(&self) -> PyFormatOptions { - let mut options = PyFormatOptions::default(); - options + fn to_py_format_options(&self, file: &Path) -> PyFormatOptions { + PyFormatOptions::from_extension(file) .with_line_width( LineWidth::try_from(self.line_length).expect("Invalid line length limit"), ) @@ -843,8 +842,7 @@ impl BlackOptions { MagicTrailingComma::Ignore } else { MagicTrailingComma::Respect - }); - options + }) } } @@ -868,7 +866,7 @@ mod tests { "}; let options = BlackOptions::from_toml(toml, Path::new("pyproject.toml")) .unwrap() - .to_py_format_options(); + .to_py_format_options(Path::new("code_inline.py")); assert_eq!(options.line_width(), LineWidth::try_from(119).unwrap()); assert!(matches!( options.magic_trailing_comma(), @@ -887,7 +885,7 @@ mod tests { "#}; let options = BlackOptions::from_toml(toml, Path::new("pyproject.toml")) .unwrap() - .to_py_format_options(); + .to_py_format_options(Path::new("code_inline.py")); assert_eq!(options.line_width(), LineWidth::try_from(130).unwrap()); assert!(matches!( options.magic_trailing_comma(), diff --git a/crates/ruff_python_ast/src/lib.rs b/crates/ruff_python_ast/src/lib.rs index 4b8b9a1eeb92e..a7489fb7a1321 100644 --- a/crates/ruff_python_ast/src/lib.rs +++ b/crates/ruff_python_ast/src/lib.rs @@ -1,4 +1,5 @@ use ruff_text_size::{TextRange, TextSize}; +use std::path::Path; pub mod all; pub mod call_path; @@ -49,3 +50,36 @@ where T::range(self) } } + +#[derive(Clone, Copy, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum PySourceType { + #[default] + Python, + Stub, + Jupyter, +} + +impl PySourceType { + pub const fn is_python(&self) -> bool { + matches!(self, PySourceType::Python) + } + + pub 
const fn is_stub(&self) -> bool { + matches!(self, PySourceType::Stub) + } + + pub const fn is_jupyter(&self) -> bool { + matches!(self, PySourceType::Jupyter) + } +} + +impl From<&Path> for PySourceType { + fn from(path: &Path) -> Self { + match path.extension() { + Some(ext) if ext == "pyi" => PySourceType::Stub, + Some(ext) if ext == "ipynb" => PySourceType::Jupyter, + _ => PySourceType::Python, + } + } +} diff --git a/crates/ruff_python_formatter/Cargo.toml b/crates/ruff_python_formatter/Cargo.toml index 1f722f0a79a7e..b57a90faed007 100644 --- a/crates/ruff_python_formatter/Cargo.toml +++ b/crates/ruff_python_formatter/Cargo.toml @@ -32,7 +32,7 @@ smallvec = { workspace = true } thiserror = { workspace = true } [dev-dependencies] -ruff_formatter = { path = "../ruff_formatter", features = ["serde"]} +ruff_formatter = { path = "../ruff_formatter", features = ["serde"] } insta = { workspace = true, features = ["glob"] } serde = { workspace = true } @@ -43,8 +43,8 @@ similar = { workspace = true } name = "ruff_python_formatter_fixtures" path = "tests/fixtures.rs" test = true -required-features = [ "serde" ] +required-features = ["serde"] [features] -serde = ["dep:serde", "ruff_formatter/serde", "ruff_source_file/serde"] +serde = ["dep:serde", "ruff_formatter/serde", "ruff_source_file/serde", "ruff_python_ast/serde"] default = ["serde"] diff --git a/crates/ruff_python_formatter/src/cli.rs b/crates/ruff_python_formatter/src/cli.rs index 132798486a581..fff146fdb91cd 100644 --- a/crates/ruff_python_formatter/src/cli.rs +++ b/crates/ruff_python_formatter/src/cli.rs @@ -1,14 +1,14 @@ #![allow(clippy::print_stdout)] -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use anyhow::{bail, Context, Result}; use clap::{command, Parser, ValueEnum}; -use ruff_python_parser::lexer::lex; -use ruff_python_parser::{parse_tokens, Mode}; use ruff_formatter::SourceCode; use ruff_python_index::CommentRangesBuilder; +use ruff_python_parser::lexer::lex; +use 
ruff_python_parser::{parse_tokens, Mode}; use crate::{format_node, PyFormatOptions}; @@ -37,7 +37,7 @@ pub struct Cli { pub print_comments: bool, } -pub fn format_and_debug_print(input: &str, cli: &Cli) -> Result { +pub fn format_and_debug_print(input: &str, cli: &Cli, source_type: &Path) -> Result { let mut tokens = Vec::new(); let mut comment_ranges = CommentRangesBuilder::default(); @@ -57,13 +57,9 @@ pub fn format_and_debug_print(input: &str, cli: &Cli) -> Result { let python_ast = parse_tokens(tokens, Mode::Module, "").context("Syntax error in input")?; - let formatted = format_node( - &python_ast, - &comment_ranges, - input, - PyFormatOptions::default(), - ) - .context("Failed to format node")?; + let options = PyFormatOptions::from_extension(source_type); + let formatted = format_node(&python_ast, &comment_ranges, input, options) + .context("Failed to format node")?; if cli.print_ir { println!("{}", formatted.document().display(SourceCode::new(input))); } diff --git a/crates/ruff_python_formatter/src/lib.rs b/crates/ruff_python_formatter/src/lib.rs index 390798bfc0dad..0ade2fad5461d 100644 --- a/crates/ruff_python_formatter/src/lib.rs +++ b/crates/ruff_python_formatter/src/lib.rs @@ -255,6 +255,7 @@ mod tests { use ruff_python_index::CommentRangesBuilder; use ruff_python_parser::lexer::lex; use ruff_python_parser::{parse_tokens, Mode}; + use std::path::Path; /// Very basic test intentionally kept very similar to the CLI #[test] @@ -321,15 +322,10 @@ with [ let comment_ranges = comment_ranges.finish(); // Parse the AST. 
- let python_ast = parse_tokens(tokens, Mode::Module, "").unwrap(); - - let formatted = format_node( - &python_ast, - &comment_ranges, - src, - PyFormatOptions::default(), - ) - .unwrap(); + let source_path = "code_inline.py"; + let python_ast = parse_tokens(tokens, Mode::Module, source_path).unwrap(); + let options = PyFormatOptions::from_extension(Path::new(source_path)); + let formatted = format_node(&python_ast, &comment_ranges, src, options).unwrap(); // Uncomment the `dbg` to print the IR. // Use `dbg_write!(f, []) instead of `write!(f, [])` in your formatting code to print some IR diff --git a/crates/ruff_python_formatter/src/main.rs b/crates/ruff_python_formatter/src/main.rs index 63b789888cc99..ba74dd06c3692 100644 --- a/crates/ruff_python_formatter/src/main.rs +++ b/crates/ruff_python_formatter/src/main.rs @@ -1,4 +1,5 @@ use std::io::{stdout, Read, Write}; +use std::path::Path; use std::{fs, io}; use anyhow::{bail, Context, Result}; @@ -25,7 +26,8 @@ fn main() -> Result<()> { ); } let input = read_from_stdin()?; - let formatted = format_and_debug_print(&input, &cli)?; + // It seems reasonable to give this a dummy name + let formatted = format_and_debug_print(&input, &cli, Path::new("stdin.py"))?; if cli.check { if formatted == input { return Ok(()); @@ -37,7 +39,7 @@ fn main() -> Result<()> { for file in &cli.files { let input = fs::read_to_string(file) .with_context(|| format!("Could not read {}: ", file.display()))?; - let formatted = format_and_debug_print(&input, &cli)?; + let formatted = format_and_debug_print(&input, &cli, file)?; match cli.emit { Some(Emit::Stdout) => stdout().lock().write_all(formatted.as_bytes())?, None | Some(Emit::Files) => { diff --git a/crates/ruff_python_formatter/src/options.rs b/crates/ruff_python_formatter/src/options.rs index 65bcf655527c5..e5aef3dc35501 100644 --- a/crates/ruff_python_formatter/src/options.rs +++ b/crates/ruff_python_formatter/src/options.rs @@ -1,5 +1,7 @@ use ruff_formatter::printer::{LineEnding, 
PrinterOptions}; use ruff_formatter::{FormatOptions, IndentStyle, LineWidth}; +use ruff_python_ast::PySourceType; +use std::path::Path; #[derive(Clone, Debug)] #[cfg_attr( @@ -8,6 +10,9 @@ use ruff_formatter::{FormatOptions, IndentStyle, LineWidth}; serde(default) )] pub struct PyFormatOptions { + /// Whether we're in a `.py` file or `.pyi` file, which have different rules + source_type: PySourceType, + /// Specifies the indent style: /// * Either a tab /// * or a specific amount of spaces @@ -28,7 +33,31 @@ fn default_line_width() -> LineWidth { LineWidth::try_from(88).unwrap() } +impl Default for PyFormatOptions { + fn default() -> Self { + Self { + source_type: PySourceType::default(), + indent_style: IndentStyle::Space(4), + line_width: LineWidth::try_from(88).unwrap(), + quote_style: QuoteStyle::default(), + magic_trailing_comma: MagicTrailingComma::default(), + } + } +} + impl PyFormatOptions { + /// Otherwise sets the defaults. Returns none if the extension is unknown + pub fn from_extension(path: &Path) -> Self { + Self::from_source_type(PySourceType::from(path)) + } + + pub fn from_source_type(source_type: PySourceType) -> Self { + Self { + source_type, + ..Self::default() + } + } + pub fn magic_trailing_comma(&self) -> MagicTrailingComma { self.magic_trailing_comma } @@ -42,17 +71,20 @@ impl PyFormatOptions { self } - pub fn with_magic_trailing_comma(&mut self, trailing_comma: MagicTrailingComma) -> &mut Self { + #[must_use] + pub fn with_magic_trailing_comma(mut self, trailing_comma: MagicTrailingComma) -> Self { self.magic_trailing_comma = trailing_comma; self } - pub fn with_indent_style(&mut self, indent_style: IndentStyle) -> &mut Self { + #[must_use] + pub fn with_indent_style(mut self, indent_style: IndentStyle) -> Self { self.indent_style = indent_style; self } - pub fn with_line_width(&mut self, line_width: LineWidth) -> &mut Self { + #[must_use] + pub fn with_line_width(mut self, line_width: LineWidth) -> Self { self.line_width = line_width; 
self } @@ -77,17 +109,6 @@ impl FormatOptions for PyFormatOptions { } } -impl Default for PyFormatOptions { - fn default() -> Self { - Self { - indent_style: IndentStyle::Space(4), - line_width: LineWidth::try_from(88).unwrap(), - quote_style: QuoteStyle::default(), - magic_trailing_comma: MagicTrailingComma::default(), - } - } -} - #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] #[cfg_attr( feature = "serde", diff --git a/crates/ruff_python_formatter/tests/fixtures.rs b/crates/ruff_python_formatter/tests/fixtures.rs index 5d9be66a41ab2..2ea1be082885e 100644 --- a/crates/ruff_python_formatter/tests/fixtures.rs +++ b/crates/ruff_python_formatter/tests/fixtures.rs @@ -17,7 +17,7 @@ fn black_compatibility() { let reader = BufReader::new(options_file); serde_json::from_reader(reader).expect("Options to be a valid Json file") } else { - PyFormatOptions::default() + PyFormatOptions::from_extension(input_path) }; let printed = format_module(&content, options.clone()).unwrap_or_else(|err| { @@ -106,11 +106,11 @@ fn format() { let test_file = |input_path: &Path| { let content = fs::read_to_string(input_path).unwrap(); - let options = PyFormatOptions::default(); + let options = PyFormatOptions::from_extension(input_path); let printed = format_module(&content, options.clone()).expect("Formatting to succeed"); let formatted_code = printed.as_code(); - ensure_stability_when_formatting_twice(formatted_code, options, input_path); + ensure_stability_when_formatting_twice(formatted_code, options.clone(), input_path); let mut snapshot = format!("## Input\n{}", CodeFrame::new("py", &content)); @@ -139,7 +139,6 @@ fn format() { .unwrap(); } } else { - let options = PyFormatOptions::default(); let printed = format_module(&content, options.clone()).expect("Formatting to succeed"); let formatted_code = printed.as_code(); diff --git a/crates/ruff_wasm/src/lib.rs b/crates/ruff_wasm/src/lib.rs index c3f48a9f286bd..9180a8c3983f9 100644 --- a/crates/ruff_wasm/src/lib.rs +++ 
b/crates/ruff_wasm/src/lib.rs @@ -21,6 +21,7 @@ use ruff::rules::{ use ruff::settings::configuration::Configuration; use ruff::settings::options::Options; use ruff::settings::{defaults, flags, Settings}; +use ruff_python_ast::PySourceType; use ruff_python_codegen::Stylist; use ruff_python_formatter::{format_module, format_node, PyFormatOptions}; use ruff_python_index::{CommentRangesBuilder, Indexer}; @@ -262,7 +263,9 @@ impl Workspace { } pub fn format(&self, contents: &str) -> Result { - let printed = format_module(contents, PyFormatOptions::default()).map_err(into_error)?; + // TODO(konstin): Add an options for py/pyi to the UI (1/2) + let options = PyFormatOptions::from_source_type(PySourceType::default()); + let printed = format_module(contents, options).map_err(into_error)?; Ok(printed.into_code()) } @@ -278,13 +281,10 @@ impl Workspace { let comment_ranges = comment_ranges.finish(); let module = parse_tokens(tokens, Mode::Module, ".").map_err(into_error)?; - let formatted = format_node( - &module, - &comment_ranges, - contents, - PyFormatOptions::default(), - ) - .map_err(into_error)?; + // TODO(konstin): Add an options for py/pyi to the UI (2/2) + let options = PyFormatOptions::from_source_type(PySourceType::default()); + let formatted = + format_node(&module, &comment_ranges, contents, options).map_err(into_error)?; Ok(format!("{formatted}")) }