diff --git a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs index 82e0970cc4b..aeecac3c74f 100644 --- a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs +++ b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs @@ -16,7 +16,7 @@ use rustc_hash::FxHashSet; use syntax::{ ast::{ self, edit::IndentLevel, edit_in_place::Indent, make, AstNode, HasAttrs, HasGenericParams, - HasName, HasTypeBounds, HasVisibility, + HasName, HasVisibility, }, match_ast, ted::{self, Position}, @@ -106,7 +106,26 @@ pub(crate) fn extract_struct_from_enum_variant( } let indent = enum_ast.indent_level(); - let def = create_struct_def(variant_name.clone(), &variant, &field_list, &enum_ast); + let generic_params = enum_ast + .generic_param_list() + .map(|known_generics| extract_generic_params(&known_generics, &field_list)); + let generics = + generic_params.as_ref().filter(|generics| !generics.all_empty()).map(|generics| { + make::generic_param_list( + generics + .lifetimes + .iter() + .cloned() + .map(ast::GenericParam::LifetimeParam) + .chain(generics.types.iter().cloned().map(ast::GenericParam::TypeParam)) + .chain( + generics.consts.iter().cloned().map(ast::GenericParam::ConstParam), + ), + ) + .clone_for_update() + }); + let def = + create_struct_def(variant_name.clone(), &variant, &field_list, generics, &enum_ast); def.reindent_to(indent); let start_offset = &variant.parent_enum().syntax().clone(); @@ -118,7 +137,7 @@ pub(crate) fn extract_struct_from_enum_variant( ], ); - update_variant(&variant, enum_ast.generic_param_list()); + update_variant(&variant, generic_params); }, ) } @@ -159,10 +178,95 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Va .any(|(name, _)| name.to_string() == variant_name.to_string()) } +struct ExtractedGenerics { + lifetimes: Vec, + types: Vec, + consts: Vec, +} + +impl ExtractedGenerics { + fn all_empty(&self) -> bool { + self.lifetimes.is_empty() && self.types.is_empty() && self.consts.is_empty() + } +} + +fn extract_generic_params( + known_generics: &ast::GenericParamList, + field_list: &Either, +) -> ExtractedGenerics { + let mut lifetimes = known_generics.lifetime_params().map(|x| (x, false)).collect_vec(); + let mut types = known_generics.type_params().map(|x| (x, false)).collect_vec(); + let mut consts = known_generics.const_params().map(|x| (x, false)).collect_vec(); + + match field_list { + Either::Left(field_list) => field_list + .fields() + .filter_map(|f| f.ty()) + .for_each(|ty| tag_generics_in_variant(&ty, &mut lifetimes, &mut types, &mut consts)), + Either::Right(field_list) => field_list + .fields() + .filter_map(|f| f.ty()) + .for_each(|ty| tag_generics_in_variant(&ty, &mut lifetimes, &mut types, &mut consts)), + } + + let lifetimes = lifetimes.into_iter().filter_map(|(x, present)| present.then(|| x)).collect(); + let types = types.into_iter().filter_map(|(x, present)| present.then(|| x)).collect(); + let consts = consts.into_iter().filter_map(|(x, present)| present.then(|| x)).collect(); + + ExtractedGenerics { lifetimes, types, consts } +} + +fn tag_generics_in_variant( + ty: &ast::Type, + lifetimes: &mut [(ast::LifetimeParam, bool)], + types: &mut [(ast::TypeParam, bool)], + consts: &mut [(ast::ConstParam, bool)], +) { + for token in + ty.syntax().preorder_with_tokens().filter_map(|node_or_token| match node_or_token { + syntax::WalkEvent::Enter(syntax::NodeOrToken::Token(token)) => Some(token), + _ => None, + }) + { + match token.kind() { + T![lifetime_ident] => { + for (lt, present) in lifetimes.iter_mut() { + if let Some(lt) = lt.lifetime() { + if lt.text().as_str() == token.text() { + *present = true; + break; + } + } + } + } + T![ident] => { + for (ty, present) in types.iter_mut() { + if let Some(name) = ty.name() { + if name.text().as_str() == token.text() { + *present = true; + break; + } + } + } + for (cnst, present) in consts.iter_mut() { + if let Some(name) = cnst.name() { + if name.text().as_str() == token.text() { + *present = true; + break; + } + } + } + } + _ => (), + } + } +} + fn create_struct_def( variant_name: ast::Name, variant: &ast::Variant, field_list: &Either, + generics: Option, enum_: &ast::Enum, ) -> ast::Struct { let enum_vis = enum_.visibility(); @@ -204,9 +308,7 @@ fn create_struct_def( field_list.reindent_to(IndentLevel::single()); - // FIXME: This uses all the generic params of the enum, but the variant might not use all of them. - let strukt = make::struct_(enum_vis, variant_name, enum_.generic_param_list(), field_list) - .clone_for_update(); + let strukt = make::struct_(enum_vis, variant_name, generics, field_list).clone_for_update(); // FIXME: Consider making this an actual function somewhere (like in `AttrsOwnerEdit`) after some deliberation let attrs_and_docs = |node: &SyntaxNode| { @@ -243,26 +345,24 @@ fn create_struct_def( strukt } -fn update_variant(variant: &ast::Variant, generic: Option) -> Option<()> { +fn update_variant(variant: &ast::Variant, generics: Option) -> Option<()> { let name = variant.name()?; - let ty = match generic { - // FIXME: This uses all the generic params of the enum, but the variant might not use all of them. - Some(gpl) => { - let gpl = gpl.clone_for_update(); - gpl.generic_params().for_each(|gp| { - let tbl = match gp { - ast::GenericParam::LifetimeParam(it) => it.type_bound_list(), - ast::GenericParam::TypeParam(it) => it.type_bound_list(), - ast::GenericParam::ConstParam(_) => return, - }; - if let Some(tbl) = tbl { - tbl.remove(); - } - }); - make::ty(&format!("{}<{}>", name.text(), gpl.generic_params().join(", "))) - } - None => make::ty(&name.text()), - }; + let ty = generics + .filter(|generics| !generics.all_empty()) + .map(|generics| { + let generics_str = [ + generics.lifetimes.iter().filter_map(|lt| lt.lifetime()).join(", "), + generics.types.iter().filter_map(|ty| ty.name()).join(", "), + generics.consts.iter().filter_map(|cnst| cnst.name()).join(", "), + ] + .iter() + .filter(|s| !s.is_empty()) + .join(", "); + + make::ty(&format!("{}<{}>", &name.text(), &generics_str)) + }) + .unwrap_or_else(|| make::ty(&name.text())); + let tuple_field = make::tuple_field(None, ty); let replacement = make::variant( name, @@ -902,4 +1002,92 @@ enum A { $0One(u8, u32) } fn test_extract_not_applicable_no_field_named() { check_assist_not_applicable(extract_struct_from_enum_variant, r"enum A { $0None {} }"); } + + #[test] + fn test_extract_struct_only_copies_needed_generics() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum X<'a, 'b, 'x> { + $0A { a: &'a &'x mut () }, + B { b: &'b () }, + C { c: () }, +} +"#, + r#" +struct A<'a, 'x>{ a: &'a &'x mut () } + +enum X<'a, 'b, 'x> { + A(A<'a, 'x>), + B { b: &'b () }, + C { c: () }, +} +"#, + ); + } + + #[test] + fn test_extract_struct_with_liftime_type_const() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum X<'b, T, V, const C: usize> { + $0A { a: T, b: X<'b>, c: [u8; C] }, + D { d: V }, +} +"#, + r#" +struct A<'b, T, const C: usize>{ a: T, b: X<'b>, c: [u8; C] } + +enum X<'b, T, V, const C: usize> { + A(A<'b, T, C>), + D { d: V }, +} +"#, + ); + } + + #[test] + fn test_extract_struct_without_generics() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum X<'a, 'b> { + A { a: &'a () }, + B { b: &'b () }, + $0C { c: () }, +} +"#, + r#" +struct C{ c: () } + +enum X<'a, 'b> { + A { a: &'a () }, + B { b: &'b () }, + C(C), +} +"#, + ); + } + + #[test] + fn test_extract_struct_keeps_trait_bounds() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum En { + $0A { a: T }, + B { b: V }, +} +"#, + r#" +struct A{ a: T } + +enum En { + A(A), + B { b: V }, +} +"#, + ); + } }