From 822c61f559dc522dbd28f2886d20989a55613fc0 Mon Sep 17 00:00:00 2001 From: Ryo Yoshida Date: Sat, 26 Nov 2022 23:51:57 +0900 Subject: [PATCH 1/2] refactor: remove unnecessary stuff --- crates/ide-assists/src/handlers/extract_function.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index c1e2f19ab18b2..10a3a33226bbb 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -265,7 +265,7 @@ enum ParamKind { MutRef, } -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug)] enum FunType { Unit, Single(hir::Type), @@ -1368,7 +1368,7 @@ impl FlowHandler { None => FlowHandler::None, Some(flow_kind) => { let action = flow_kind.clone(); - if *ret_ty == FunType::Unit { + if let FunType::Unit = ret_ty { match flow_kind { FlowKind::Return(None) | FlowKind::Break(_, None) @@ -1946,7 +1946,7 @@ fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) { if nested_scope.is_none() { if let Some(expr) = ast::Expr::cast(e.clone()) { match expr { - ast::Expr::ReturnExpr(return_expr) if nested_scope.is_none() => { + ast::Expr::ReturnExpr(return_expr) => { let expr = return_expr.expr(); if let Some(replacement) = make_rewritten_flow(handler, expr) { ted::replace(return_expr.syntax(), replacement.syntax()) From 8e03f18e37d2782189391955bc56d3aebead81f5 Mon Sep 17 00:00:00 2001 From: Ryo Yoshida Date: Sat, 26 Nov 2022 23:51:22 +0900 Subject: [PATCH 2/2] fix: check if range contains tail expression --- .../src/handlers/extract_function.rs | 204 ++++++++++++++++-- 1 file changed, 183 insertions(+), 21 deletions(-) diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index 10a3a33226bbb..0483cfdc64667 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -11,7 +11,9 @@ use ide_db::{ helpers::mod_path_to_ast, imports::insert_use::{insert_use, ImportScope}, search::{FileReference, ReferenceCategory, SearchScope}, - syntax_helpers::node_ext::{preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr}, + syntax_helpers::node_ext::{ + for_each_tail_expr, preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr, + }, FxIndexSet, RootDatabase, }; use itertools::Itertools; @@ -78,7 +80,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op }; let body = extraction_target(&node, range)?; - let container_info = body.analyze_container(&ctx.sema)?; + let (container_info, contains_tail_expr) = body.analyze_container(&ctx.sema)?; let (locals_used, self_param) = body.analyze(&ctx.sema); @@ -119,6 +121,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op ret_ty, body, outliving_locals, + contains_tail_expr, mods: container_info, }; @@ -245,6 +248,8 @@ struct Function { ret_ty: RetType, body: FunctionBody, outliving_locals: Vec, + /// Whether at least one of the container's tail expr is contained in the range we're extracting. + contains_tail_expr: bool, mods: ContainerInfo, } @@ -294,7 +299,6 @@ struct ControlFlow { #[derive(Clone, Debug)] struct ContainerInfo { is_const: bool, - is_in_tail: bool, parent_loop: Option, /// The function's return type, const's type etc. ret_type: Option, @@ -743,7 +747,10 @@ impl FunctionBody { (res, self_param) } - fn analyze_container(&self, sema: &Semantics<'_, RootDatabase>) -> Option { + fn analyze_container( + &self, + sema: &Semantics<'_, RootDatabase>, + ) -> Option<(ContainerInfo, bool)> { let mut ancestors = self.parent()?.ancestors(); let infer_expr_opt = |expr| sema.type_of_expr(&expr?).map(TypeInfo::adjusted); let mut parent_loop = None; @@ -815,28 +822,36 @@ impl FunctionBody { } }; }; - let container_tail = match expr? { - ast::Expr::BlockExpr(block) => block.tail_expr(), - expr => Some(expr), - }; - let is_in_tail = - container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| { - container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range()) + + let expr = expr?; + let contains_tail_expr = if let Some(body_tail) = self.tail_expr() { + let mut contains_tail_expr = false; + let tail_expr_range = body_tail.syntax().text_range(); + for_each_tail_expr(&expr, &mut |e| { + if tail_expr_range.contains_range(e.syntax().text_range()) { + contains_tail_expr = true; + } }); + contains_tail_expr + } else { + false + }; let parent = self.parent()?; let parents = generic_parents(&parent); let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect(); let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect(); - Some(ContainerInfo { - is_in_tail, - is_const, - parent_loop, - ret_type: ty, - generic_param_lists, - where_clauses, - }) + Some(( + ContainerInfo { + is_const, + parent_loop, + ret_type: ty, + generic_param_lists, + where_clauses, + }, + contains_tail_expr, + )) } fn return_ty(&self, ctx: &AssistContext<'_>) -> Option { @@ -1633,7 +1648,7 @@ impl Function { fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option { let fun_ty = self.return_type(ctx); - let handler = if self.mods.is_in_tail { + let handler = if self.contains_tail_expr { FlowHandler::None } else { FlowHandler::from_ret_ty(self, &fun_ty) @@ -1707,7 +1722,7 @@ fn make_body( fun: &Function, ) -> ast::BlockExpr { let ret_ty = fun.return_type(ctx); - let handler = if fun.mods.is_in_tail { + let handler = if fun.contains_tail_expr { FlowHandler::None } else { FlowHandler::from_ret_ty(fun, &ret_ty) @@ -5582,6 +5597,153 @@ impl Struct where T: Into + Copy, U: Debug { fn $0fun_name(t: T, v: V) -> i32 where T: Into + Copy, V: Into { t.into() + v.into() } +"#, + ); + } + + #[test] + fn non_tail_expr_of_tail_expr_loop() { + check_assist( + extract_function, + r#" +pub fn f() { + loop { + $0if true { + continue; + }$0 + + if false { + break; + } + } +} +"#, + r#" +pub fn f() { + loop { + if let ControlFlow::Break(_) = fun_name() { + continue; + } + + if false { + break; + } + } +} + +fn $0fun_name() -> ControlFlow<()> { + if true { + return ControlFlow::Break(()); + } + ControlFlow::Continue(()) +} +"#, + ); + } + + #[test] + fn non_tail_expr_of_tail_if_block() { + // FIXME: double semicolon + check_assist( + extract_function, + r#" +//- minicore: option, try +impl core::ops::Try for Option { + type Output = T; + type Residual = Option; +} +impl core::ops::FromResidual for Option {} + +fn f() -> Option<()> { + if true { + let a = $0if true { + Some(())? + } else { + () + }$0; + Some(a) + } else { + None + } +} +"#, + r#" +impl core::ops::Try for Option { + type Output = T; + type Residual = Option; +} +impl core::ops::FromResidual for Option {} + +fn f() -> Option<()> { + if true { + let a = fun_name()?;; + Some(a) + } else { + None + } +} + +fn $0fun_name() -> Option<()> { + Some(if true { + Some(())? + } else { + () + }) +} +"#, + ); + } + + #[test] + fn tail_expr_of_tail_block_nested() { + check_assist( + extract_function, + r#" +//- minicore: option, try +impl core::ops::Try for Option { + type Output = T; + type Residual = Option; +} +impl core::ops::FromResidual for Option {} + +fn f() -> Option<()> { + if true { + $0{ + let a = if true { + Some(())? + } else { + () + }; + Some(a) + }$0 + } else { + None + } +} +"#, + r#" +impl core::ops::Try for Option { + type Output = T; + type Residual = Option; +} +impl core::ops::FromResidual for Option {} + +fn f() -> Option<()> { + if true { + fun_name()? + } else { + None + } +} + +fn $0fun_name() -> Option<()> { + let a = if true { + Some(())? + } else { + () + }; + Some(a) +} "#, ); }