diff --git a/crates/assists/src/handlers/inline_function.rs b/crates/assists/src/handlers/inline_function.rs new file mode 100644 index 00000000000..2edf252b7f3 --- /dev/null +++ b/crates/assists/src/handlers/inline_function.rs @@ -0,0 +1,202 @@ +use ast::make; +use hir::{HasSource, PathResolution}; +use syntax::{ + ast::{self, edit::AstNodeEdit, ArgListOwner}, + AstNode, +}; +use test_utils::mark; + +use crate::{ + assist_context::{AssistContext, Assists}, + AssistId, AssistKind, +}; + +// Assist: inline_function +// +// Inlines a function body. +// +// ``` +// fn add(a: u32, b: u32) -> u32 { a + b } +// fn main() { +// let x = add<|>(1, 2); +// } +// ``` +// -> +// ``` +// fn add(a: u32, b: u32) -> u32 { a + b } +// fn main() { +// let x = { +// let a = 1; +// let b = 2; +// a + b +// }; +// } +// ``` +pub(crate) fn inline_function(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let path_expr: ast::PathExpr = ctx.find_node_at_offset()?; + let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?; + let path = path_expr.path()?; + + let function = match ctx.sema.resolve_path(&path)? { + PathResolution::Def(hir::ModuleDef::Function(f)) => f, + _ => return None, + }; + + let function_source = function.source(ctx.db())?; + let arguments: Vec<_> = call.arg_list()?.args().collect(); + let parameters = function_parameter_patterns(&function_source.value)?; + + if arguments.len() != parameters.len() { + // Can't inline the function because they've passed the wrong number of + // arguments to this function + mark::hit!(inline_function_incorrect_number_of_arguments); + return None; + } + + let new_bindings = parameters.into_iter().zip(arguments); + + let body = function_source.value.body()?; + + acc.add( + AssistId("inline_function", AssistKind::RefactorInline), + format!("Inline `{}`", path), + call.syntax().text_range(), + |builder| { + let mut statements: Vec = Vec::new(); + + for (pattern, value) in new_bindings { + statements.push(make::let_stmt(pattern, Some(value)).into()); + } + + statements.extend(body.statements()); + + let original_indentation = call.indent_level(); + let replacement = make::block_expr(statements, body.expr()) + .reset_indent() + .indent(original_indentation); + + builder.replace_ast(ast::Expr::CallExpr(call), ast::Expr::BlockExpr(replacement)); + }, + ) +} + +fn function_parameter_patterns(value: &ast::Fn) -> Option> { + let mut patterns = Vec::new(); + + for param in value.param_list()?.params() { + let pattern = param.pat()?; + patterns.push(pattern); + } + + Some(patterns) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn no_args_or_return_value_gets_inlined_without_block() { + check_assist( + inline_function, + r#" +fn foo() { println!("Hello, World!"); } +fn main() { + fo<|>o(); +} +"#, + r#" +fn foo() { println!("Hello, World!"); } +fn main() { + { + println!("Hello, World!"); + }; +} +"#, + ); + } + + #[test] + fn args_with_side_effects() { + check_assist( + inline_function, + r#" +fn foo(name: String) { println!("Hello, {}!", name); } +fn main() { + foo<|>(String::from("Michael")); +} +"#, + r#" +fn foo(name: String) { println!("Hello, {}!", name); } +fn main() { + { + let name = String::from("Michael"); + println!("Hello, {}!", name); + }; +} +"#, + ); + } + + #[test] + fn method_inlining_isnt_supported() { + check_assist_not_applicable( + inline_function, + r" +struct Foo; +impl Foo { fn bar(&self) {} } + +fn main() { Foo.bar<|>(); } +", + ); + } + + #[test] + fn not_applicable_when_incorrect_number_of_parameters_are_provided() { + mark::check!(inline_function_incorrect_number_of_arguments); + check_assist_not_applicable( + inline_function, + r#" +fn add(a: u32, b: u32) -> u32 { a + b } +fn main() { let x = add<|>(42); } +"#, + ); + } + + #[test] + fn function_with_multiple_statements() { + check_assist( + inline_function, + r#" +fn foo(a: u32, b: u32) -> u32 { + let x = a + b; + let y = x - b; + x * y +} + +fn main() { + let x = foo<|>(1, 2); +} +"#, + r#" +fn foo(a: u32, b: u32) -> u32 { + let x = a + b; + let y = x - b; + x * y +} + +fn main() { + let x = { + let a = 1; + let b = 2; + let x = a + b; + let y = x - b; + x * y + }; +} +"#, + ); + } +} diff --git a/crates/assists/src/lib.rs b/crates/assists/src/lib.rs index 01baa65fe02..9c2a9573574 100644 --- a/crates/assists/src/lib.rs +++ b/crates/assists/src/lib.rs @@ -131,6 +131,7 @@ mod handlers { mod generate_impl; mod generate_new; mod infer_function_return_type; + mod inline_function; mod inline_local_variable; mod introduce_named_lifetime; mod invert_if; @@ -183,6 +184,7 @@ mod handlers { generate_impl::generate_impl, generate_new::generate_new, infer_function_return_type::infer_function_return_type, + inline_function::inline_function, inline_local_variable::inline_local_variable, introduce_named_lifetime::introduce_named_lifetime, invert_if::invert_if, diff --git a/crates/assists/src/tests/generated.rs b/crates/assists/src/tests/generated.rs index 85e3c674297..b15352cf32c 100644 --- a/crates/assists/src/tests/generated.rs +++ b/crates/assists/src/tests/generated.rs @@ -530,6 +530,29 @@ fn foo() -> i32 { 42i32 } ) } +#[test] +fn doctest_inline_function() { + check_doc_test( + "inline_function", + r#####" +fn add(a: u32, b: u32) -> u32 { a + b } +fn main() { + let x = add<|>(1, 2); +} +"#####, + r#####" +fn add(a: u32, b: u32) -> u32 { a + b } +fn main() { + let x = { + let a = 1; + let b = 2; + a + b + }; +} +"#####, + ) +} + #[test] fn doctest_inline_local_variable() { check_doc_test(