Simplify ast_transform

This commit is contained in:
Aleksey Kladov 2020-10-02 20:52:48 +02:00
parent 673e1ddb9a
commit 3290bb4112

View file

@ -5,12 +5,13 @@ use hir::{HirDisplay, PathResolution, SemanticsScope};
use syntax::{
algo::SyntaxRewriter,
ast::{self, AstNode},
SyntaxNode,
};
pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: N) -> N {
SyntaxRewriter::from_fn(|element| match element {
syntax::SyntaxElement::Node(n) => {
let replacement = transformer.get_substitution(&n)?;
let replacement = transformer.get_substitution(&n, transformer)?;
Some(replacement.into())
}
_ => None,
@ -47,32 +48,35 @@ pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: N) -> N {
/// We'd want to somehow express this concept simpler, but so far nobody got to
/// simplifying this!
pub trait AstTransform<'a> {
fn get_substitution(&self, node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode>;
fn get_substitution(
&self,
node: &SyntaxNode,
recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode>;
fn chain_before(self, other: Box<dyn AstTransform<'a> + 'a>) -> Box<dyn AstTransform<'a> + 'a>;
fn or<T: AstTransform<'a> + 'a>(self, other: T) -> Box<dyn AstTransform<'a> + 'a>
where
Self: Sized + 'a,
{
self.chain_before(Box::new(other))
Box::new(Or(Box::new(self), Box::new(other)))
}
}
struct NullTransformer;
struct Or<'a>(Box<dyn AstTransform<'a> + 'a>, Box<dyn AstTransform<'a> + 'a>);
impl<'a> AstTransform<'a> for NullTransformer {
fn get_substitution(&self, _node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode> {
None
}
fn chain_before(self, other: Box<dyn AstTransform<'a> + 'a>) -> Box<dyn AstTransform<'a> + 'a> {
other
impl<'a> AstTransform<'a> for Or<'a> {
fn get_substitution(
&self,
node: &SyntaxNode,
recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode> {
self.0.get_substitution(node, recur).or_else(|| self.1.get_substitution(node, recur))
}
}
pub struct SubstituteTypeParams<'a> {
source_scope: &'a SemanticsScope<'a>,
substs: FxHashMap<hir::TypeParam, ast::Type>,
previous: Box<dyn AstTransform<'a> + 'a>,
}
impl<'a> SubstituteTypeParams<'a> {
@ -111,11 +115,7 @@ impl<'a> SubstituteTypeParams<'a> {
}
})
.collect();
return SubstituteTypeParams {
source_scope,
substs: substs_by_param,
previous: Box::new(NullTransformer),
};
return SubstituteTypeParams { source_scope, substs: substs_by_param };
// FIXME: It would probably be nicer if we could get this via HIR (i.e. get the
// trait ref, and then go from the types in the substs back to the syntax).
@ -140,7 +140,14 @@ impl<'a> SubstituteTypeParams<'a> {
Some(result)
}
}
fn get_substitution_inner(&self, node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode> {
}
impl<'a> AstTransform<'a> for SubstituteTypeParams<'a> {
fn get_substitution(
&self,
node: &SyntaxNode,
_recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode> {
let type_ref = ast::Type::cast(node.clone())?;
let path = match &type_ref {
ast::Type::PathType(path_type) => path_type.path()?,
@ -154,27 +161,23 @@ impl<'a> SubstituteTypeParams<'a> {
}
}
impl<'a> AstTransform<'a> for SubstituteTypeParams<'a> {
fn get_substitution(&self, node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode> {
self.get_substitution_inner(node).or_else(|| self.previous.get_substitution(node))
}
fn chain_before(self, other: Box<dyn AstTransform<'a> + 'a>) -> Box<dyn AstTransform<'a> + 'a> {
Box::new(SubstituteTypeParams { previous: other, ..self })
}
}
pub struct QualifyPaths<'a> {
target_scope: &'a SemanticsScope<'a>,
source_scope: &'a SemanticsScope<'a>,
previous: Box<dyn AstTransform<'a> + 'a>,
}
impl<'a> QualifyPaths<'a> {
pub fn new(target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>) -> Self {
Self { target_scope, source_scope, previous: Box::new(NullTransformer) }
Self { target_scope, source_scope }
}
}
fn get_substitution_inner(&self, node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode> {
impl<'a> AstTransform<'a> for QualifyPaths<'a> {
fn get_substitution(
&self,
node: &SyntaxNode,
recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode> {
// FIXME handle value ns?
let from = self.target_scope.module()?;
let p = ast::Path::cast(node.clone())?;
@ -191,7 +194,7 @@ impl<'a> QualifyPaths<'a> {
let type_args = p
.segment()
.and_then(|s| s.generic_arg_list())
.map(|arg_list| apply(self, arg_list));
.map(|arg_list| apply(recur, arg_list));
if let Some(type_args) = type_args {
let last_segment = path.segment().unwrap();
path = path.with_segment(last_segment.with_generic_args(type_args))
@ -208,15 +211,6 @@ impl<'a> QualifyPaths<'a> {
}
}
impl<'a> AstTransform<'a> for QualifyPaths<'a> {
fn get_substitution(&self, node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode> {
self.get_substitution_inner(node).or_else(|| self.previous.get_substitution(node))
}
fn chain_before(self, other: Box<dyn AstTransform<'a> + 'a>) -> Box<dyn AstTransform<'a> + 'a> {
Box::new(QualifyPaths { previous: other, ..self })
}
}
pub(crate) fn path_to_ast(path: hir::ModPath) -> ast::Path {
let parse = ast::SourceFile::parse(&path.to_string());
parse