diff --git a/crates/ide_assists/src/handlers/fill_match_arms.rs b/crates/ide_assists/src/handlers/fill_match_arms.rs index 58b00105057..97435f02113 100644 --- a/crates/ide_assists/src/handlers/fill_match_arms.rs +++ b/crates/ide_assists/src/handlers/fill_match_arms.rs @@ -1,4 +1,4 @@ -use std::iter; +use std::iter::{self, Peekable}; use either::Either; use hir::{Adt, HasSource, ModuleDef, Semantics}; @@ -63,50 +63,61 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option< let module = ctx.sema.scope(expr.syntax()).module()?; - let missing_arms: Vec = if let Some(enum_def) = resolve_enum_def(&ctx.sema, &expr) { + let mut missing_pats: Peekable>> = if let Some(enum_def) = + resolve_enum_def(&ctx.sema, &expr) + { let variants = enum_def.variants(ctx.db()); - let mut variants = variants + let missing_pats = variants .into_iter() .filter_map(|variant| build_pat(ctx.db(), module, variant)) - .filter(|variant_pat| is_variant_missing(&top_lvl_pats, variant_pat)) - .map(|pat| make::match_arm(iter::once(pat), make::expr_empty_block())) - .map(|it| it.clone_for_update()) - .collect::>(); - if Some(enum_def) - == FamousDefs(&ctx.sema, Some(module.krate())) - .core_option_Option() - .map(|x| lift_enum(x)) + .filter(|variant_pat| is_variant_missing(&top_lvl_pats, variant_pat)); + + let missing_pats: Box> = if Some(enum_def) + == FamousDefs(&ctx.sema, Some(module.krate())).core_option_Option().map(lift_enum) { // Match `Some` variant first. cov_mark::hit!(option_order); - variants.reverse() - } - variants + Box::new(missing_pats.rev()) + } else { + Box::new(missing_pats) + }; + missing_pats.peekable() } else if let Some(enum_defs) = resolve_tuple_of_enum_def(&ctx.sema, &expr) { + let mut n_arms = 1; + let variants_of_enums: Vec> = enum_defs + .into_iter() + .map(|enum_def| enum_def.variants(ctx.db())) + .inspect(|variants| n_arms *= variants.len()) + .collect(); + // When calculating the match arms for a tuple of enums, we want // to create a match arm for each possible combination of enum // values. The `multi_cartesian_product` method transforms // Vec> into Vec<(EnumVariant, .., EnumVariant)> // where each tuple represents a proposed match arm. - enum_defs + + // A number of arms grows very fast on even a small tuple of large enums. + // We skip the assist beyond an arbitrary threshold. + if n_arms > 256 { + return None; + } + let missing_pats = variants_of_enums .into_iter() - .map(|enum_def| enum_def.variants(ctx.db())) .multi_cartesian_product() + .inspect(|_| cov_mark::hit!(fill_match_arms_lazy_computation)) .map(|variants| { let patterns = variants.into_iter().filter_map(|variant| build_pat(ctx.db(), module, variant)); ast::Pat::from(make::tuple_pat(patterns)) }) - .filter(|variant_pat| is_variant_missing(&top_lvl_pats, variant_pat)) - .map(|pat| make::match_arm(iter::once(pat), make::expr_empty_block())) - .map(|it| it.clone_for_update()) - .collect() + .filter(|variant_pat| is_variant_missing(&top_lvl_pats, variant_pat)); + (Box::new(missing_pats) as Box>).peekable() } else { return None; }; - if missing_arms.is_empty() { + if missing_pats.peek().is_none() { return None; } @@ -117,6 +128,9 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option< target, |builder| { let new_match_arm_list = match_arm_list.clone_for_update(); + let missing_arms = missing_pats + .map(|pat| make::match_arm(iter::once(pat), make::expr_empty_block())) + .map(|it| it.clone_for_update()); let catch_all_arm = new_match_arm_list .arms() @@ -167,13 +181,13 @@ fn does_pat_match_variant(pat: &Pat, var: &Pat) -> bool { } } -#[derive(Eq, PartialEq, Clone)] +#[derive(Eq, PartialEq, Clone, Copy)] enum ExtendedEnum { Bool, Enum(hir::Enum), } -#[derive(Eq, PartialEq, Clone)] +#[derive(Eq, PartialEq, Clone, Copy)] enum ExtendedVariant { True, False, @@ -185,7 +199,7 @@ fn lift_enum(e: hir::Enum) -> ExtendedEnum { } impl ExtendedEnum { - fn variants(&self, db: &RootDatabase) -> Vec { + fn variants(self, db: &RootDatabase) -> Vec { match self { ExtendedEnum::Enum(e) => { e.variants(db).into_iter().map(|x| ExtendedVariant::Variant(x)).collect::>() @@ -266,7 +280,9 @@ fn build_pat(db: &RootDatabase, module: hir::Module, var: ExtendedVariant) -> Op mod tests { use ide_db::helpers::FamousDefs; - use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target}; + use crate::tests::{ + check_assist, check_assist_not_applicable, check_assist_target, check_assist_unresolved, + }; use super::fill_match_arms; @@ -1045,4 +1061,19 @@ fn foo(t: Test) { }"#, ); } + + #[test] + fn lazy_computation() { + // Computing a single missing arm is enough to determine applicability of the assist. + cov_mark::check_count!(fill_match_arms_lazy_computation, 1); + check_assist_unresolved( + fill_match_arms, + r#" +enum A { One, Two, } +fn foo(tuple: (A, A)) { + match $0tuple {}; +} +"#, + ); + } } diff --git a/crates/ide_assists/src/tests.rs b/crates/ide_assists/src/tests.rs index 1739302bf09..6a9231e0787 100644 --- a/crates/ide_assists/src/tests.rs +++ b/crates/ide_assists/src/tests.rs @@ -65,6 +65,12 @@ pub(crate) fn check_assist_not_applicable(assist: Handler, ra_fixture: &str) { check(assist, ra_fixture, ExpectedResult::NotApplicable, None); } +/// Check assist in unresolved state. Useful to check assists for lazy computation. +#[track_caller] +pub(crate) fn check_assist_unresolved(assist: Handler, ra_fixture: &str) { + check(assist, ra_fixture, ExpectedResult::Unresolved, None); +} + #[track_caller] fn check_doc_test(assist_id: &str, before: &str, after: &str) { let after = trim_indent(after); @@ -101,6 +107,7 @@ fn check_doc_test(assist_id: &str, before: &str, after: &str) { enum ExpectedResult<'a> { NotApplicable, + Unresolved, After(&'a str), Target(&'a str), } @@ -115,7 +122,11 @@ fn check(handler: Handler, before: &str, expected: ExpectedResult, assist_label: let sema = Semantics::new(&db); let config = TEST_CONFIG; let ctx = AssistContext::new(sema, &config, frange); - let mut acc = Assists::new(&ctx, AssistResolveStrategy::All); + let resolve = match expected { + ExpectedResult::Unresolved => AssistResolveStrategy::None, + _ => AssistResolveStrategy::All, + }; + let mut acc = Assists::new(&ctx, resolve); handler(&mut acc, &ctx); let mut res = acc.finish(); @@ -163,8 +174,14 @@ fn check(handler: Handler, before: &str, expected: ExpectedResult, assist_label: let range = assist.target; assert_eq_text!(&text_without_caret[range], target); } + (Some(assist), ExpectedResult::Unresolved) => assert!( + assist.source_change.is_none(), + "unresolved assist should not contain source changes" + ), (Some(_), ExpectedResult::NotApplicable) => panic!("assist should not be applicable!"), - (None, ExpectedResult::After(_)) | (None, ExpectedResult::Target(_)) => { + (None, ExpectedResult::After(_)) + | (None, ExpectedResult::Target(_)) + | (None, ExpectedResult::Unresolved) => { panic!("code action is not applicable") } (None, ExpectedResult::NotApplicable) => (),