Skip to content

Commit

Permalink
Auto merge of rust-lang#13681 - lowr:fix/extract-function-tail-expr, …
Browse files Browse the repository at this point in the history
…r=Veykril

fix: check tail expressions more precisely in `extract_function`

Fixes rust-lang#13620

When extracting expressions with control flows into a function, we can avoid wrapping tail expressions in `Option` or `Result` when they are also tail expressions of the container we're extracting from (see rust-lang#7840, rust-lang#9773). This is controlled by `ContainerInfo::is_in_tail`, but we've been computing it by checking if the tail expression of the range to extract is contained in the container's syntactically last expression, which may be a block that contains both tail and non-tail expressions (e.g. in rust-lang#13620, the range to be extracted is not a tail expression but we set the flag to true).

This PR tries to compute the flag as precise as possible by utilizing `for_each_tail_expr()` (and also moves the flag to `Function` struct as it's more of a property of the function to be extracted than of the container).
  • Loading branch information
bors committed Nov 27, 2022
2 parents 34e2bc6 + 8e03f18 commit 6d61be8
Showing 1 changed file with 186 additions and 24 deletions.
210 changes: 186 additions & 24 deletions crates/ide-assists/src/handlers/extract_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use ide_db::{
helpers::mod_path_to_ast,
imports::insert_use::{insert_use, ImportScope},
search::{FileReference, ReferenceCategory, SearchScope},
syntax_helpers::node_ext::{preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr},
syntax_helpers::node_ext::{
for_each_tail_expr, preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr,
},
FxIndexSet, RootDatabase,
};
use itertools::Itertools;
Expand Down Expand Up @@ -78,7 +80,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
};

let body = extraction_target(&node, range)?;
let container_info = body.analyze_container(&ctx.sema)?;
let (container_info, contains_tail_expr) = body.analyze_container(&ctx.sema)?;

let (locals_used, self_param) = body.analyze(&ctx.sema);

Expand Down Expand Up @@ -119,6 +121,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
ret_ty,
body,
outliving_locals,
contains_tail_expr,
mods: container_info,
};

Expand Down Expand Up @@ -245,6 +248,8 @@ struct Function {
ret_ty: RetType,
body: FunctionBody,
outliving_locals: Vec<OutlivedLocal>,
/// Whether at least one of the container's tail expr is contained in the range we're extracting.
contains_tail_expr: bool,
mods: ContainerInfo,
}

Expand All @@ -265,7 +270,7 @@ enum ParamKind {
MutRef,
}

#[derive(Debug, Eq, PartialEq)]
#[derive(Debug)]
enum FunType {
Unit,
Single(hir::Type),
Expand Down Expand Up @@ -294,7 +299,6 @@ struct ControlFlow {
#[derive(Clone, Debug)]
struct ContainerInfo {
is_const: bool,
is_in_tail: bool,
parent_loop: Option<SyntaxNode>,
/// The function's return type, const's type etc.
ret_type: Option<hir::Type>,
Expand Down Expand Up @@ -743,7 +747,10 @@ impl FunctionBody {
(res, self_param)
}

fn analyze_container(&self, sema: &Semantics<'_, RootDatabase>) -> Option<ContainerInfo> {
fn analyze_container(
&self,
sema: &Semantics<'_, RootDatabase>,
) -> Option<(ContainerInfo, bool)> {
let mut ancestors = self.parent()?.ancestors();
let infer_expr_opt = |expr| sema.type_of_expr(&expr?).map(TypeInfo::adjusted);
let mut parent_loop = None;
Expand Down Expand Up @@ -815,28 +822,36 @@ impl FunctionBody {
}
};
};
let container_tail = match expr? {
ast::Expr::BlockExpr(block) => block.tail_expr(),
expr => Some(expr),
};
let is_in_tail =
container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())

let expr = expr?;
let contains_tail_expr = if let Some(body_tail) = self.tail_expr() {
let mut contains_tail_expr = false;
let tail_expr_range = body_tail.syntax().text_range();
for_each_tail_expr(&expr, &mut |e| {
if tail_expr_range.contains_range(e.syntax().text_range()) {
contains_tail_expr = true;
}
});
contains_tail_expr
} else {
false
};

let parent = self.parent()?;
let parents = generic_parents(&parent);
let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect();
let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect();

Some(ContainerInfo {
is_in_tail,
is_const,
parent_loop,
ret_type: ty,
generic_param_lists,
where_clauses,
})
Some((
ContainerInfo {
is_const,
parent_loop,
ret_type: ty,
generic_param_lists,
where_clauses,
},
contains_tail_expr,
))
}

fn return_ty(&self, ctx: &AssistContext<'_>) -> Option<RetType> {
Expand Down Expand Up @@ -1368,7 +1383,7 @@ impl FlowHandler {
None => FlowHandler::None,
Some(flow_kind) => {
let action = flow_kind.clone();
if *ret_ty == FunType::Unit {
if let FunType::Unit = ret_ty {
match flow_kind {
FlowKind::Return(None)
| FlowKind::Break(_, None)
Expand Down Expand Up @@ -1633,7 +1648,7 @@ impl Function {

fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option<ast::RetType> {
let fun_ty = self.return_type(ctx);
let handler = if self.mods.is_in_tail {
let handler = if self.contains_tail_expr {
FlowHandler::None
} else {
FlowHandler::from_ret_ty(self, &fun_ty)
Expand Down Expand Up @@ -1707,7 +1722,7 @@ fn make_body(
fun: &Function,
) -> ast::BlockExpr {
let ret_ty = fun.return_type(ctx);
let handler = if fun.mods.is_in_tail {
let handler = if fun.contains_tail_expr {
FlowHandler::None
} else {
FlowHandler::from_ret_ty(fun, &ret_ty)
Expand Down Expand Up @@ -1946,7 +1961,7 @@ fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) {
if nested_scope.is_none() {
if let Some(expr) = ast::Expr::cast(e.clone()) {
match expr {
ast::Expr::ReturnExpr(return_expr) if nested_scope.is_none() => {
ast::Expr::ReturnExpr(return_expr) => {
let expr = return_expr.expr();
if let Some(replacement) = make_rewritten_flow(handler, expr) {
ted::replace(return_expr.syntax(), replacement.syntax())
Expand Down Expand Up @@ -5582,6 +5597,153 @@ impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
t.into() + v.into()
}
"#,
);
}

#[test]
fn non_tail_expr_of_tail_expr_loop() {
check_assist(
extract_function,
r#"
pub fn f() {
loop {
$0if true {
continue;
}$0
if false {
break;
}
}
}
"#,
r#"
pub fn f() {
loop {
if let ControlFlow::Break(_) = fun_name() {
continue;
}
if false {
break;
}
}
}
fn $0fun_name() -> ControlFlow<()> {
if true {
return ControlFlow::Break(());
}
ControlFlow::Continue(())
}
"#,
);
}

#[test]
fn non_tail_expr_of_tail_if_block() {
// FIXME: double semicolon
check_assist(
extract_function,
r#"
//- minicore: option, try
impl<T> core::ops::Try for Option<T> {
type Output = T;
type Residual = Option<!>;
}
impl<T> core::ops::FromResidual for Option<T> {}
fn f() -> Option<()> {
if true {
let a = $0if true {
Some(())?
} else {
()
}$0;
Some(a)
} else {
None
}
}
"#,
r#"
impl<T> core::ops::Try for Option<T> {
type Output = T;
type Residual = Option<!>;
}
impl<T> core::ops::FromResidual for Option<T> {}
fn f() -> Option<()> {
if true {
let a = fun_name()?;;
Some(a)
} else {
None
}
}
fn $0fun_name() -> Option<()> {
Some(if true {
Some(())?
} else {
()
})
}
"#,
);
}

#[test]
fn tail_expr_of_tail_block_nested() {
check_assist(
extract_function,
r#"
//- minicore: option, try
impl<T> core::ops::Try for Option<T> {
type Output = T;
type Residual = Option<!>;
}
impl<T> core::ops::FromResidual for Option<T> {}
fn f() -> Option<()> {
if true {
$0{
let a = if true {
Some(())?
} else {
()
};
Some(a)
}$0
} else {
None
}
}
"#,
r#"
impl<T> core::ops::Try for Option<T> {
type Output = T;
type Residual = Option<!>;
}
impl<T> core::ops::FromResidual for Option<T> {}
fn f() -> Option<()> {
if true {
fun_name()?
} else {
None
}
}
fn $0fun_name() -> Option<()> {
let a = if true {
Some(())?
} else {
()
};
Some(a)
}
"#,
);
}
Expand Down

0 comments on commit 6d61be8

Please sign in to comment.