SSR: Allow function calls to match method calls

This differs from how this used to work before I removed it in that:
a) It's only one direction. Function calls in the pattern can match
method calls in the code, but not the other way around.
b) We now check that the function call in the pattern resolves to the
same function as the method call in the code.

The lack of (b) was the reason I felt the need to remove the feature
before.
This commit is contained in:
David Lattimore 2020-07-24 20:53:48 +10:00
parent 8d09ab86ed
commit 3dac31fe80
6 changed files with 169 additions and 25 deletions

View file

@ -21,6 +21,9 @@ use ra_ssr::{MatchFinder, SsrError, SsrRule};
// replacement occurs. For example if our replacement template is `foo::Bar` and we match some
// code in the `foo` module, we'll insert just `Bar`.
//
// Method calls should generally be written in UFCS form. e.g. `foo::Bar::baz($s, $a)` will match
// `$s.baz($a)`, provided the method call `baz` resolves to the method `foo::Bar::baz`.
//
// Placeholders may be given constraints by writing them as `${<name>:<constraint1>:<constraint2>...}`.
//
// Supported constraints:

View file

@ -202,8 +202,12 @@ impl<'db> MatchFinder<'db> {
// For now we ignore rules that have a different kind than our node, otherwise
// we get lots of noise. If at some point we add support for restricting rules
// to a particular kind of thing (e.g. only match type references), then we can
// relax this.
if rule.pattern.node.kind() != node.kind() {
// relax this. We special-case expressions, since function calls can match
// method calls.
if rule.pattern.node.kind() != node.kind()
&& !(ast::Expr::can_cast(rule.pattern.node.kind())
&& ast::Expr::can_cast(node.kind()))
{
continue;
}
out.push(MatchDebugInfo {

View file

@ -189,10 +189,17 @@ impl<'db, 'sema> Matcher<'db, 'sema> {
}
return Ok(());
}
// Non-placeholders.
// We allow a UFCS call to match a method call, provided they resolve to the same function.
if let Some(pattern_function) = self.rule.pattern.ufcs_function_calls.get(pattern) {
if let (Some(pattern), Some(code)) =
(ast::CallExpr::cast(pattern.clone()), ast::MethodCallExpr::cast(code.clone()))
{
return self.attempt_match_ufcs(phase, &pattern, &code, *pattern_function);
}
}
if pattern.kind() != code.kind() {
fail_match!(
"Pattern had a `{}` ({:?}), code had `{}` ({:?})",
"Pattern had `{}` ({:?}), code had `{}` ({:?})",
pattern.text(),
pattern.kind(),
code.text(),
@ -514,6 +521,37 @@ impl<'db, 'sema> Matcher<'db, 'sema> {
Ok(())
}
fn attempt_match_ufcs(
&self,
phase: &mut Phase,
pattern: &ast::CallExpr,
code: &ast::MethodCallExpr,
pattern_function: hir::Function,
) -> Result<(), MatchFailed> {
use ast::ArgListOwner;
let code_resolved_function = self
.sema
.resolve_method_call(code)
.ok_or_else(|| match_error!("Failed to resolve method call"))?;
if pattern_function != code_resolved_function {
fail_match!("Method call resolved to a different function");
}
// Check arguments.
let mut pattern_args = pattern
.arg_list()
.ok_or_else(|| match_error!("Pattern function call has no args"))?
.args();
self.attempt_match_opt(phase, pattern_args.next(), code.expr())?;
let mut code_args =
code.arg_list().ok_or_else(|| match_error!("Code method call has no args"))?.args();
loop {
match (pattern_args.next(), code_args.next()) {
(None, None) => return Ok(()),
(p, c) => self.attempt_match_opt(phase, p, c)?,
}
}
}
fn get_placeholder(&self, element: &SyntaxElement) -> Option<&Placeholder> {
only_ident(element.clone()).and_then(|ident| self.rule.get_placeholder(&ident))
}

View file

@ -18,10 +18,12 @@ pub(crate) struct ResolvedPattern {
pub(crate) node: SyntaxNode,
// Paths in `node` that we've resolved.
pub(crate) resolved_paths: FxHashMap<SyntaxNode, ResolvedPath>,
pub(crate) ufcs_function_calls: FxHashMap<SyntaxNode, hir::Function>,
}
pub(crate) struct ResolvedPath {
pub(crate) resolution: hir::PathResolution,
/// The depth of the ast::Path that was resolved within the pattern.
pub(crate) depth: u32,
}
@ -64,10 +66,26 @@ impl Resolver<'_, '_> {
fn resolve_pattern_tree(&self, pattern: SyntaxNode) -> Result<ResolvedPattern, SsrError> {
let mut resolved_paths = FxHashMap::default();
self.resolve(pattern.clone(), 0, &mut resolved_paths)?;
let ufcs_function_calls = resolved_paths
.iter()
.filter_map(|(path_node, resolved)| {
if let Some(grandparent) = path_node.parent().and_then(|parent| parent.parent()) {
if grandparent.kind() == SyntaxKind::CALL_EXPR {
if let hir::PathResolution::AssocItem(hir::AssocItem::Function(function)) =
&resolved.resolution
{
return Some((grandparent, *function));
}
}
}
None
})
.collect();
Ok(ResolvedPattern {
node: pattern,
resolved_paths,
placeholders_by_stand_in: self.placeholders_by_stand_in.clone(),
ufcs_function_calls,
})
}

View file

@ -46,35 +46,58 @@ impl<'db> MatchFinder<'db> {
usage_cache: &mut UsageCache,
matches_out: &mut Vec<Match>,
) {
if let Some(first_path) = pick_path_for_usages(pattern) {
let definition: Definition = first_path.resolution.clone().into();
if let Some(resolved_path) = pick_path_for_usages(pattern) {
let definition: Definition = resolved_path.resolution.clone().into();
for reference in self.find_usages(usage_cache, definition) {
let file = self.sema.parse(reference.file_range.file_id);
if let Some(path) = self.sema.find_node_at_offset_with_descend::<ast::Path>(
file.syntax(),
reference.file_range.range.start(),
) {
if let Some(node_to_match) = self
.sema
.ancestors_with_macros(path.syntax().clone())
.skip(first_path.depth as usize)
.next()
if let Some(node_to_match) = self.find_node_to_match(resolved_path, reference) {
if !is_search_permitted_ancestors(&node_to_match) {
mark::hit!(use_declaration_with_braces);
continue;
}
if let Ok(m) =
matching::get_match(false, rule, &node_to_match, &None, &self.sema)
{
if !is_search_permitted_ancestors(&node_to_match) {
mark::hit!(use_declaration_with_braces);
continue;
}
if let Ok(m) =
matching::get_match(false, rule, &node_to_match, &None, &self.sema)
{
matches_out.push(m);
}
matches_out.push(m);
}
}
}
}
}
fn find_node_to_match(
&self,
resolved_path: &ResolvedPath,
reference: &Reference,
) -> Option<SyntaxNode> {
let file = self.sema.parse(reference.file_range.file_id);
let depth = resolved_path.depth as usize;
let offset = reference.file_range.range.start();
if let Some(path) =
self.sema.find_node_at_offset_with_descend::<ast::Path>(file.syntax(), offset)
{
self.sema.ancestors_with_macros(path.syntax().clone()).skip(depth).next()
} else if let Some(path) =
self.sema.find_node_at_offset_with_descend::<ast::MethodCallExpr>(file.syntax(), offset)
{
// If the pattern contained a path and we found a reference to that path that wasn't
// itself a path, but was a method call, then we need to adjust how far up to try
// matching by how deep the path was within a CallExpr. The structure would have been
// CallExpr, PathExpr, Path - i.e. a depth offset of 2. We don't need to check if the
// path was part of a CallExpr because if it wasn't then all that will happen is we'll
// fail to match, which is the desired behavior.
const PATH_DEPTH_IN_CALL_EXPR: usize = 2;
if depth < PATH_DEPTH_IN_CALL_EXPR {
return None;
}
self.sema
.ancestors_with_macros(path.syntax().clone())
.skip(depth - PATH_DEPTH_IN_CALL_EXPR)
.next()
} else {
None
}
}
fn find_usages<'a>(
&self,
usage_cache: &'a mut UsageCache,

View file

@ -827,3 +827,61 @@ fn use_declaration_with_braces() {
"]],
)
}
#[test]
fn ufcs_matches_method_call() {
let code = r#"
struct Foo {}
impl Foo {
fn new(_: i32) -> Foo { Foo {} }
fn do_stuff(&self, _: i32) {}
}
struct Bar {}
impl Bar {
fn new(_: i32) -> Bar { Bar {} }
fn do_stuff(&self, v: i32) {}
}
fn main() {
let b = Bar {};
let f = Foo {};
b.do_stuff(1);
f.do_stuff(2);
Foo::new(4).do_stuff(3);
// Too many / too few args - should never match
f.do_stuff(2, 10);
f.do_stuff();
}
"#;
assert_matches("Foo::do_stuff($a, $b)", code, &["f.do_stuff(2)", "Foo::new(4).do_stuff(3)"]);
// The arguments needs special handling in the case of a function call matching a method call
// and the first argument is different.
assert_matches("Foo::do_stuff($a, 2)", code, &["f.do_stuff(2)"]);
assert_matches("Foo::do_stuff(Foo::new(4), $b)", code, &["Foo::new(4).do_stuff(3)"]);
assert_ssr_transform(
"Foo::do_stuff(Foo::new($a), $b) ==>> Bar::new($b).do_stuff($a)",
code,
expect![[r#"
struct Foo {}
impl Foo {
fn new(_: i32) -> Foo { Foo {} }
fn do_stuff(&self, _: i32) {}
}
struct Bar {}
impl Bar {
fn new(_: i32) -> Bar { Bar {} }
fn do_stuff(&self, v: i32) {}
}
fn main() {
let b = Bar {};
let f = Foo {};
b.do_stuff(1);
f.do_stuff(2);
Bar::new(3).do_stuff(4);
// Too many / too few args - should never match
f.do_stuff(2, 10);
f.do_stuff();
}
"#]],
);
}