Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions crates/hir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2824,6 +2824,58 @@ impl Function {
}
}

/// Returns the text ranges of `return` keywords in this function's body,
/// excluding those inside closures or async blocks.
pub fn return_points(self, db: &dyn HirDatabase) -> Vec<TextRange> {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function should still IMO return Vec<ast::ReturnExpr>, without upmapping them, as a general-purpose API (that maybe will be useful elsewhere as well). Do the upmapping in this assist.

let func_id: FunctionId = match self.id {
AnyFunctionId::FunctionId(id) => id,
_ => return vec![],
};
let (body, source_map) = Body::with_source_map(db, func_id.into());
let fn_file_id = func_id.loc(db).id.file_id;

fn collect_returns(
body: &Body,
source_map: &hir_def::expr_store::ExpressionStoreSourceMap,
db: &dyn HirDatabase,
fn_file_id: HirFileId,
expr_id: ExprId,
acc: &mut Vec<TextRange>,
) {
match &body[expr_id] {
Expr::Closure { .. } | Expr::Async { .. } | Expr::Const(_) => return,
Expr::Return { .. } => {
if let Ok(source) = source_map.expr_syntax(expr_id)
&& let Some(ret_expr) = source.value.cast::<ast::ReturnExpr>()
{
let root = db.parse_or_expand(source.file_id);
let node = ret_expr.to_node(&root);
if let Some(return_token) = node.return_token() {
let token_range = return_token.text_range();
if source.file_id == fn_file_id {
acc.push(token_range);
} else if let Some((file_range, _)) =
InFile::new(source.file_id, token_range)
.original_node_file_range_opt(db)
&& file_range.file_id == fn_file_id
{
acc.push(file_range.range);
}
}
}
}
_ => {}
}
body.walk_child_exprs(expr_id, |child| {
collect_returns(body, source_map, db, fn_file_id, child, acc);
});
}

let mut returns = vec![];
collect_returns(body, source_map, db, fn_file_id, body.root_expr(), &mut returns);
returns
}

pub fn as_proc_macro(self, db: &dyn HirDatabase) -> Option<Macro> {
let AnyFunctionId::FunctionId(id) = self.id else {
return None;
Expand Down
298 changes: 295 additions & 3 deletions crates/ide-assists/src/handlers/inline_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,12 @@ fn inline(
} else {
fn_body.clone_for_update()
};

// Check whether the function has return expressions that need to be transformed
// into labeled breaks. Async functions don't need this because the body gets wrapped
// in `async move { ... }` which preserves return semantics.
let is_async_fn = function.is_async(sema.db);
let has_early_returns = !is_async_fn && !function.return_points(sema.db).is_empty();
let usages_for_locals = |local| {
Definition::Local(local)
.usages(sema)
Expand Down Expand Up @@ -550,7 +556,21 @@ fn inline(
}
}

let is_async_fn = function.is_async(sema.db);
// Transform return expressions into break expressions with a labeled block.
// `return_ranges` (from HIR analysis) tells us whether the function has returns
// that need transformation. The actual replacement walks the cloned body tree,
// skipping returns inside closures, async blocks, and macro call arguments.
if has_early_returns {
let label = make::label(make::lifetime("'inline")).clone_for_update();

replace_returns_with_breaks(&body);

// Insert label as a child of BlockExpr, before the StmtList
if let Some(stmt_list) = body.stmt_list() {
ted::insert(ted::Position::before(stmt_list.syntax()), label.syntax());
}
}

if is_async_fn {
cov_mark::hit!(inline_call_async_fn);
body = make::async_move_block_expr(body.statements(), body.tail_expr()).clone_for_update();
Expand All @@ -574,12 +594,13 @@ fn inline(
};
body.reindent_to(original_indentation);

let has_label = body.label().is_some();
let no_stmts = body.statements().next().is_none();
match body.tail_expr() {
Some(expr) if matches!(expr, ast::Expr::ClosureExpr(_)) && no_stmts => {
Some(expr) if matches!(expr, ast::Expr::ClosureExpr(_)) && no_stmts && !has_label => {
make::expr_paren(expr).clone_for_update().into()
}
Some(expr) if !is_async_fn && no_stmts => expr,
Some(expr) if !is_async_fn && no_stmts && !has_label => expr,
_ => match node
.syntax()
.parent()
Expand All @@ -605,6 +626,33 @@ fn is_in_type_path(self_tok: &syntax::SyntaxToken) -> bool {
.is_some()
}

/// Replaces `return` / `return expr` with `break 'inline` / `break 'inline expr`
/// for all `ReturnExpr` nodes that belong directly to the function body
/// (i.e., not nested inside closures or async blocks).
fn replace_returns_with_breaks(body: &ast::BlockExpr) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't use return_points().

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replace_returns_with_breaks operates on the cloned body (a different syntax tree) so the ReturnExpr nodes from return_points() dont correspond to it. how do you want to use return_points()?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you map the returns on the original body, the ranges will stay the same. Also as I said use text replacement.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks

fn walk(node: &syntax::SyntaxNode) {
for child in node.children() {
if ast::ClosureExpr::can_cast(child.kind()) {
continue;
}
if let Some(block) = ast::BlockExpr::cast(child.clone())
&& (block.async_token().is_some() || block.const_token().is_some())
{
continue;
}
if let Some(return_expr) = ast::ReturnExpr::cast(child.clone()) {
let break_expr =
make::expr_break(Some(make::lifetime("'inline")), return_expr.expr())
.clone_for_update();
ted::replace(return_expr.syntax(), break_expr.syntax());
} else {
walk(&child);
}
}
}
walk(body.syntax());
}

fn path_expr_as_record_field(usage: &PathExpr) -> Option<ast::RecordExprField> {
let path = usage.path()?;
let name_ref = path.as_single_name_ref()?;
Expand Down Expand Up @@ -1883,6 +1931,250 @@ macro_rules! bar {
fn f() {
bar!(foo$0());
}
"#,
);
}

#[test]
fn inline_call_with_early_return() {
check_assist(
inline_call,
r#"
fn early_return() -> Option<()> {
return None;
}
fn main() -> Option<()> {
if early_return$0().is_none() {
return Some(());
}
None
}
"#,
r#"
fn early_return() -> Option<()> {
return None;
}
fn main() -> Option<()> {
if 'inline: {
break 'inline None;
}.is_none() {
return Some(());
}
None
}
"#,
);
}

#[test]
fn inline_call_with_early_return_and_tail_expr() {
check_assist(
inline_call,
r#"
fn foo(x: i32) -> i32 {
if x < 0 {
return -1;
}
x * 2
}
fn main() {
let result = foo$0(5);
}
"#,
r#"
fn foo(x: i32) -> i32 {
if x < 0 {
return -1;
}
x * 2
}
fn main() {
let result = 'inline: {
let x = 5;
if x < 0 {
break 'inline -1;
}
x * 2
};
}
"#,
);
}

#[test]
fn inline_call_with_return_in_closure_not_transformed() {
check_assist(
inline_call,
r#"
fn foo() -> fn() -> i32 {
|| return 42
}
fn main() {
let f = foo$0();
}
"#,
r#"
fn foo() -> fn() -> i32 {
|| return 42
}
fn main() {
let f = (|| return 42);
}
"#,
);
}

#[test]
fn inline_call_with_multiple_early_returns() {
check_assist(
inline_call,
r#"
fn classify(x: i32) -> &'static str {
if x < 0 {
return "negative";
}
if x == 0 {
return "zero";
}
"positive"
}
fn main() {
let s = classify$0(1);
}
"#,
r#"
fn classify(x: i32) -> &'static str {
if x < 0 {
return "negative";
}
if x == 0 {
return "zero";
}
"positive"
}
fn main() {
let s = 'inline: {
let x = 1;
if x < 0 {
break 'inline "negative";
}
if x == 0 {
break 'inline "zero";
}
"positive"
};
}
"#,
);
}

#[test]
fn inline_call_with_return_in_async_block_not_transformed() {
check_assist(
inline_call,
r#"
fn foo() -> i32 {
let _ = async { return 42; };
return 0;
}
fn main() {
let x = foo$0();
}
"#,
r#"
fn foo() -> i32 {
let _ = async { return 42; };
return 0;
}
fn main() {
let x = 'inline: {
let _ = async { return 42; };
break 'inline 0;
};
}
"#,
);
}

#[test]
fn inline_call_with_return_no_expr() {
check_assist(
inline_call,
r#"
fn greet(is_loud: bool) {
if !is_loud {
return;
}
println!("HELLO!");
}
fn main() {
greet$0(true);
}
"#,
r#"
fn greet(is_loud: bool) {
if !is_loud {
return;
}
println!("HELLO!");
}
fn main() {
'inline: {
if !true {
break 'inline;
}
println!("HELLO!");
};
}
"#,
);
}

#[test]
fn inline_into_callers_with_early_return() {
check_assist(
inline_into_callers,
r#"
fn ear$0ly(x: i32) -> i32 {
if x < 0 {
return -1;
}
x * 2
}
fn main() {
let a = early(1);
}
"#,
r#"

fn main() {
let a = 'inline: {
let x = 1;
if x < 0 {
break 'inline -1;
}
x * 2
};
}
"#,
);
}

#[test]
fn inline_call_with_return_as_tail_expr() {
check_assist(
inline_call,
r#"
fn foo() -> i32 { return 42 }
fn main() {
let x = foo$0();
}
"#,
r#"
fn foo() -> i32 { return 42 }
fn main() {
let x = 'inline: { break 'inline 42 };
}
"#,
);
}
Expand Down
Loading