feat: add multiple getters mode in generate_getter

This commit adds two modes to generate_getter action.
First, the plain old working on single fields.
Second, working on a selected range of fields.
This commit is contained in:
feniljain 2022-10-08 00:54:57 +05:30
parent 97b357e41b
commit 5bff6c55de
8 changed files with 290 additions and 67 deletions

View file

@ -51,14 +51,14 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<'
Some(field) => {
let field_name = field.name()?;
let field_ty = field.ty()?;
(format!("{field_name}"), field_ty, field.syntax().text_range())
(field_name.to_string(), field_ty, field.syntax().text_range())
}
None => {
let field = ctx.find_node_at_offset::<ast::TupleField>()?;
let field_list = ctx.find_node_at_offset::<ast::TupleFieldList>()?;
let field_list_index = field_list.fields().position(|it| it == field)?;
let field_ty = field.ty()?;
(format!("{field_list_index}"), field_ty, field.syntax().text_range())
(field_list_index.to_string(), field_ty, field.syntax().text_range())
}
};
@ -77,13 +77,11 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<'
for method in methods {
let adt = ast::Adt::Struct(strukt.clone());
let name = method.name(ctx.db()).to_string();
let impl_def = find_struct_impl(ctx, &adt, &name).flatten();
let method_name = method.name(ctx.db());
let impl_def = find_struct_impl(ctx, &adt, &[name]).flatten();
acc.add_group(
&GroupLabel("Generate delegate methods…".to_owned()),
AssistId("generate_delegate_methods", AssistKind::Generate),
format!("Generate delegate for `{field_name}.{method_name}()`"),
format!("Generate delegate for `{}.{}()`", field_name, method.name(ctx.db())),
target,
|builder| {
// Create the function
@ -158,7 +156,7 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<'
}
None => {
let offset = strukt.syntax().text_range().end();
let snippet = format!("\n\n{impl_def}");
let snippet = format!("\n\n{}", impl_def.syntax());
builder.insert(offset, snippet);
}
}

View file

@ -52,7 +52,7 @@ pub(crate) fn generate_enum_is_method(acc: &mut Assists, ctx: &AssistContext<'_>
let fn_name = format!("is_{}", &to_lower_snake_case(&variant_name.text()));
// Return early if we've found an existing new fn
let impl_def = find_struct_impl(ctx, &parent_enum, &fn_name)?;
let impl_def = find_struct_impl(ctx, &parent_enum, &[fn_name.clone()])?;
let target = variant.syntax().text_range();
acc.add_group(

View file

@ -147,7 +147,7 @@ fn generate_enum_projection_method(
let fn_name = format!("{}_{}", fn_name_prefix, &to_lower_snake_case(&variant_name.text()));
// Return early if we've found an existing new fn
let impl_def = find_struct_impl(ctx, &parent_enum, &fn_name)?;
let impl_def = find_struct_impl(ctx, &parent_enum, &[fn_name.clone()])?;
let target = variant.syntax().text_range();
acc.add_group(

View file

@ -198,7 +198,7 @@ fn get_adt_source(
let file = ctx.sema.parse(range.file_id);
let adt_source =
ctx.sema.find_node_at_offset_with_macros(file.syntax(), range.range.start())?;
find_struct_impl(ctx, &adt_source, fn_name).map(|impl_| (impl_, range.file_id))
find_struct_impl(ctx, &adt_source, &[fn_name.to_string()]).map(|impl_| (impl_, range.file_id))
}
struct FunctionTemplate {

View file

@ -1,6 +1,9 @@
use ide_db::famous_defs::FamousDefs;
use stdx::{format_to, to_lower_snake_case};
use syntax::ast::{self, AstNode, HasName, HasVisibility};
use syntax::{
ast::{self, AstNode, HasName, HasVisibility},
TextRange,
};
use crate::{
utils::{convert_reference_type, find_impl_block_end, find_struct_impl, generate_impl_text},
@ -72,86 +75,257 @@ pub(crate) fn generate_getter_mut(acc: &mut Assists, ctx: &AssistContext<'_>) ->
generate_getter_impl(acc, ctx, true)
}
#[derive(Clone, Debug)]
struct RecordFieldInfo {
field_name: syntax::ast::Name,
field_ty: syntax::ast::Type,
fn_name: String,
target: TextRange,
}
struct GetterInfo {
impl_def: Option<ast::Impl>,
strukt: ast::Struct,
mutable: bool,
}
pub(crate) fn generate_getter_impl(
acc: &mut Assists,
ctx: &AssistContext<'_>,
mutable: bool,
) -> Option<()> {
// This if condition denotes two modes this assist can work in:
// - First is acting upon selection of record fields
// - Next is acting upon a single record field
//
// This is the only part where implementation diverges a bit,
// subsequent code is generic for both of these modes
let (strukt, info_of_record_fields, fn_names) = if !ctx.has_empty_selection() {
// Selection Mode
let node = ctx.covering_element();
let node = match node {
syntax::NodeOrToken::Node(n) => n,
syntax::NodeOrToken::Token(t) => t.parent()?,
};
let parent_struct = node.ancestors().find_map(ast::Struct::cast)?;
let (info_of_record_fields, field_names) =
extract_and_parse_record_fields(&parent_struct, ctx.selection_trimmed(), mutable)?;
(parent_struct, info_of_record_fields, field_names)
} else {
// Single Record Field mode
let strukt = ctx.find_node_at_offset::<ast::Struct>()?;
let field = ctx.find_node_at_offset::<ast::RecordField>()?;
let field_name = field.name()?;
let field_ty = field.ty()?;
let record_field_info = parse_record_field(field, mutable)?;
// Return early if we've found an existing fn
let mut fn_name = to_lower_snake_case(&field_name.to_string());
if mutable {
format_to!(fn_name, "_mut");
let fn_name = record_field_info.fn_name.clone();
(strukt, vec![record_field_info], vec![fn_name])
};
// No record fields to do work on :(
if info_of_record_fields.len() == 0 {
return None;
}
let impl_def = find_struct_impl(ctx, &ast::Adt::Struct(strukt.clone()), fn_name.as_str())?;
let impl_def = find_struct_impl(ctx, &ast::Adt::Struct(strukt.clone()), &fn_names)?;
let (id, label) = if mutable {
("generate_getter_mut", "Generate a mut getter method")
} else {
("generate_getter", "Generate a getter method")
};
let target = field.syntax().text_range();
// Computing collective text range of all record fields in selected region
let target: TextRange = info_of_record_fields
.iter()
.map(|record_field_info| record_field_info.target)
.reduce(|acc, target| acc.cover(target))?;
let getter_info = GetterInfo { impl_def, strukt, mutable };
acc.add_group(
&GroupLabel("Generate getter/setter".to_owned()),
AssistId(id, AssistKind::Generate),
label,
target,
|builder| {
let record_fields_count = info_of_record_fields.len();
let mut buf = String::with_capacity(512);
if impl_def.is_some() {
// Check if an impl exists
if let Some(impl_def) = &getter_info.impl_def {
// Check if impl is empty
if let Some(assoc_item_list) = impl_def.assoc_item_list() {
if assoc_item_list.assoc_items().next().is_some() {
// If not empty then only insert a new line
buf.push('\n');
}
}
}
let vis = strukt.visibility().map_or(String::new(), |v| format!("{v} "));
let (ty, body) = if mutable {
(format!("&mut {field_ty}"), format!("&mut self.{field_name}"))
for (i, record_field_info) in info_of_record_fields.iter().enumerate() {
// this buf inserts a newline at the end of a getter
// automatically, if one wants to add one more newline
// for separating it from other assoc items, that needs
// to be handled spearately
let mut getter_buf =
generate_getter_from_info(ctx, &getter_info, &record_field_info);
// Insert `$0` only for last getter we generate
if i == record_fields_count - 1 {
getter_buf = getter_buf.replacen("fn ", "fn $0", 1);
}
// For first element we do not merge with '\n', as
// that can be inserted by impl_def check defined
// above, for other cases which are:
//
// - impl exists but it empty, here we would ideally
// not want to keep newline between impl <struct> {
// and fn <fn-name>() { line
//
// - next if impl itself does not exist, in this
// case we ourselves generate a new impl and that
// again ends up with the same reasoning as above
// for not keeping newline
if i == 0 {
buf = buf + &getter_buf;
} else {
buf = buf + "\n" + &getter_buf;
}
// We don't insert a new line at the end of
// last getter as it will end up in the end
// of an impl where we would not like to keep
// getter and end of impl ( i.e. `}` ) with an
// extra line for no reason
if i < record_fields_count - 1 {
buf = buf + "\n";
}
}
let start_offset = getter_info
.impl_def
.as_ref()
.and_then(|impl_def| find_impl_block_end(impl_def.to_owned(), &mut buf))
.unwrap_or_else(|| {
buf = generate_impl_text(&ast::Adt::Struct(getter_info.strukt.clone()), &buf);
getter_info.strukt.syntax().text_range().end()
});
match ctx.config.snippet_cap {
Some(cap) => builder.insert_snippet(cap, start_offset, buf),
None => builder.insert(start_offset, buf),
}
},
)
}
fn generate_getter_from_info(
ctx: &AssistContext<'_>,
info: &GetterInfo,
record_field_info: &RecordFieldInfo,
) -> String {
let mut buf = String::with_capacity(512);
let vis = info.strukt.visibility().map_or(String::new(), |v| format!("{} ", v));
let (ty, body) = if info.mutable {
(
format!("&mut {}", record_field_info.field_ty),
format!("&mut self.{}", record_field_info.field_name),
)
} else {
(|| {
let krate = ctx.sema.scope(field_ty.syntax())?.krate();
let krate = ctx.sema.scope(record_field_info.field_ty.syntax())?.krate();
let famous_defs = &FamousDefs(&ctx.sema, krate);
ctx.sema
.resolve_type(&field_ty)
.resolve_type(&record_field_info.field_ty)
.and_then(|ty| convert_reference_type(ty, ctx.db(), famous_defs))
.map(|conversion| {
cov_mark::hit!(convert_reference_type);
(
conversion.convert_type(ctx.db()),
conversion.getter(field_name.to_string()),
conversion.getter(record_field_info.field_name.to_string()),
)
})
})()
.unwrap_or_else(|| (format!("&{field_ty}"), format!("&self.{field_name}")))
.unwrap_or_else(|| {
(
format!("&{}", record_field_info.field_ty),
format!("&self.{}", record_field_info.field_name),
)
})
};
let mut_ = mutable.then(|| "mut ").unwrap_or_default();
format_to!(
buf,
" {vis}fn {fn_name}(&{mut_}self) -> {ty} {{
{body}
}}"
" {}fn {}(&{}self) -> {} {{
{}
}}",
vis,
record_field_info.fn_name,
info.mutable.then(|| "mut ").unwrap_or_default(),
ty,
body,
);
let start_offset = impl_def
.and_then(|impl_def| find_impl_block_end(impl_def, &mut buf))
.unwrap_or_else(|| {
buf = generate_impl_text(&ast::Adt::Struct(strukt.clone()), &buf);
strukt.syntax().text_range().end()
});
buf
}
match ctx.config.snippet_cap {
Some(cap) => {
builder.insert_snippet(cap, start_offset, buf.replacen("fn ", "fn $0", 1))
fn extract_and_parse_record_fields(
node: &ast::Struct,
selection_range: TextRange,
mutable: bool,
) -> Option<(Vec<RecordFieldInfo>, Vec<String>)> {
let mut field_names: Vec<String> = vec![];
let field_list = node.field_list()?;
match field_list {
ast::FieldList::RecordFieldList(ele) => {
let info_of_record_fields_in_selection = ele
.fields()
.filter_map(|record_field| {
if selection_range.contains_range(record_field.syntax().text_range()) {
let record_field_info = parse_record_field(record_field, mutable)?;
field_names.push(record_field_info.fn_name.clone());
return Some(record_field_info);
}
None => builder.insert(start_offset, buf),
None
})
.collect::<Vec<RecordFieldInfo>>();
if info_of_record_fields_in_selection.len() == 0 {
return None;
}
},
)
Some((info_of_record_fields_in_selection, field_names))
}
ast::FieldList::TupleFieldList(_) => {
return None;
}
}
}
fn parse_record_field(record_field: ast::RecordField, mutable: bool) -> Option<RecordFieldInfo> {
let field_name = record_field.name()?;
let field_ty = record_field.ty()?;
let mut fn_name = to_lower_snake_case(&field_name.to_string());
if mutable {
format_to!(fn_name, "_mut");
}
let target = record_field.syntax().text_range();
Some(RecordFieldInfo { field_name, field_ty, fn_name, target })
}
#[cfg(test)]
@ -481,6 +655,55 @@ impl Context {
fn $0data(&self) -> Result<&bool, &i32> {
self.data.as_ref()
}
}
"#,
);
}
#[test]
fn test_generate_multiple_getters_from_selection() {
check_assist(
generate_getter,
r#"
struct Context {
$0data: Data,
count: usize,$0
}
"#,
r#"
struct Context {
data: Data,
count: usize,
}
impl Context {
fn data(&self) -> &Data {
&self.data
}
fn $0count(&self) -> &usize {
&self.count
}
}
"#,
);
}
#[test]
fn test_generate_multiple_getters_from_selection_one_already_exists() {
// As impl for one of the fields already exist, skip it
check_assist_not_applicable(
generate_getter,
r#"
struct Context {
$0data: Data,
count: usize,$0
}
impl Context {
fn data(&self) -> &Data {
&self.data
}
}
"#,
);

View file

@ -39,7 +39,8 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
};
// Return early if we've found an existing new fn
let impl_def = find_struct_impl(ctx, &ast::Adt::Struct(strukt.clone()), "new")?;
let impl_def =
find_struct_impl(ctx, &ast::Adt::Struct(strukt.clone()), &[String::from("new")])?;
let current_module = ctx.sema.scope(strukt.syntax())?.module();

View file

@ -36,11 +36,8 @@ pub(crate) fn generate_setter(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt
// Return early if we've found an existing fn
let fn_name = to_lower_snake_case(&field_name.to_string());
let impl_def = find_struct_impl(
ctx,
&ast::Adt::Struct(strukt.clone()),
format!("set_{fn_name}").as_str(),
)?;
let impl_def =
find_struct_impl(ctx, &ast::Adt::Struct(strukt.clone()), &[format!("set_{fn_name}")])?;
let target = field.syntax().text_range();
acc.add_group(

View file

@ -331,10 +331,14 @@ fn calc_depth(pat: &ast::Pat, depth: usize) -> usize {
// FIXME: change the new fn checking to a more semantic approach when that's more
// viable (e.g. we process proc macros, etc)
// FIXME: this partially overlaps with `find_impl_block_*`
/// `find_struct_impl` looks for impl of a struct, but this also has additional feature
/// where it takes a list of function names and check if they exist inside impl_, if
/// even one match is found, it returns None
pub(crate) fn find_struct_impl(
ctx: &AssistContext<'_>,
adt: &ast::Adt,
name: &str,
names: &[String],
) -> Option<Option<ast::Impl>> {
let db = ctx.db();
let module = adt.syntax().parent()?;
@ -362,7 +366,7 @@ pub(crate) fn find_struct_impl(
});
if let Some(ref impl_blk) = block {
if has_fn(impl_blk, name) {
if has_any_fn(impl_blk, names) {
return None;
}
}
@ -370,12 +374,12 @@ pub(crate) fn find_struct_impl(
Some(block)
}
fn has_fn(imp: &ast::Impl, rhs_name: &str) -> bool {
fn has_any_fn(imp: &ast::Impl, names: &[String]) -> bool {
if let Some(il) = imp.assoc_item_list() {
for item in il.assoc_items() {
if let ast::AssocItem::Fn(f) = item {
if let Some(name) = f.name() {
if name.text().eq_ignore_ascii_case(rhs_name) {
if names.iter().any(|n| n.eq_ignore_ascii_case(&name.text())) {
return true;
}
}