init partialord

This commit is contained in:
Yoshua Wuyts 2021-08-16 10:07:19 +02:00
parent 1cca1fa5bf
commit bc6aee51b0

View file

@ -21,6 +21,7 @@ pub(crate) fn gen_trait_fn_body(
"Default" => gen_default_impl(adt, func),
"Hash" => gen_hash_impl(adt, func),
"PartialEq" => gen_partial_eq(adt, func),
"PartialOrd" => gen_partial_ord(adt, func),
_ => None,
}
}
@ -572,6 +573,171 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
Some(())
}
fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
match expr {
Some(expr) => Some(make::expr_op(ast::BinOp::BooleanAnd, expr, cmp)),
None => Some(cmp),
}
}
fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField {
let pat = make::ext::simple_ident_pat(make::name(&pat_name));
let name_ref = make::name_ref(field_name);
make::record_pat_field(name_ref, pat.into())
}
fn gen_record_pat(record_name: ast::Path, fields: Vec<ast::RecordPatField>) -> ast::RecordPat {
let list = make::record_pat_field_list(fields);
make::record_pat_with_fields(record_name, list)
}
fn gen_variant_path(variant: &ast::Variant) -> Option<ast::Path> {
make::ext::path_from_idents(["Self", &variant.name()?.to_string()])
}
fn gen_tuple_field(field_name: &String) -> ast::Pat {
ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name)))
}
// FIXME: return `None` if the trait carries a generic type; we can only
// generate this code `Self` for the time being.
let body = match adt {
// `Hash` cannot be derived for unions, so no default impl can be provided.
ast::Adt::Union(_) => return None,
ast::Adt::Enum(enum_) => {
// => std::mem::discriminant(self) == std::mem::discriminant(other)
let lhs_name = make::expr_path(make::ext::ident_path("self"));
let lhs = make::expr_call(make_discriminant()?, make::arg_list(Some(lhs_name.clone())));
let rhs_name = make::expr_path(make::ext::ident_path("other"));
let rhs = make::expr_call(make_discriminant()?, make::arg_list(Some(rhs_name.clone())));
let eq_check = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
let mut case_count = 0;
let mut arms = vec![];
for variant in enum_.variant_list()?.variants() {
case_count += 1;
match variant.field_list() {
// => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
Some(ast::FieldList::RecordFieldList(list)) => {
let mut expr = None;
let mut l_fields = vec![];
let mut r_fields = vec![];
for field in list.fields() {
let field_name = field.name()?.to_string();
let l_name = &format!("l_{}", field_name);
l_fields.push(gen_record_pat_field(&field_name, &l_name));
let r_name = &format!("r_{}", field_name);
r_fields.push(gen_record_pat_field(&field_name, &r_name));
let lhs = make::expr_path(make::ext::ident_path(l_name));
let rhs = make::expr_path(make::ext::ident_path(r_name));
let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
expr = gen_eq_chain(expr, cmp);
}
let left = gen_record_pat(gen_variant_path(&variant)?, l_fields);
let right = gen_record_pat(gen_variant_path(&variant)?, r_fields);
let tuple = make::tuple_pat(vec![left.into(), right.into()]);
if let Some(expr) = expr {
arms.push(make::match_arm(Some(tuple.into()), None, expr));
}
}
Some(ast::FieldList::TupleFieldList(list)) => {
let mut expr = None;
let mut l_fields = vec![];
let mut r_fields = vec![];
for (i, _) in list.fields().enumerate() {
let field_name = format!("{}", i);
let l_name = format!("l{}", field_name);
l_fields.push(gen_tuple_field(&l_name));
let r_name = format!("r{}", field_name);
r_fields.push(gen_tuple_field(&r_name));
let lhs = make::expr_path(make::ext::ident_path(&l_name));
let rhs = make::expr_path(make::ext::ident_path(&r_name));
let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
expr = gen_eq_chain(expr, cmp);
}
let left = make::tuple_struct_pat(gen_variant_path(&variant)?, l_fields);
let right = make::tuple_struct_pat(gen_variant_path(&variant)?, r_fields);
let tuple = make::tuple_pat(vec![left.into(), right.into()]);
if let Some(expr) = expr {
arms.push(make::match_arm(Some(tuple.into()), None, expr));
}
}
None => continue,
}
}
let expr = match arms.len() {
0 => eq_check,
_ => {
if case_count > arms.len() {
let lhs = make::wildcard_pat().into();
arms.push(make::match_arm(Some(lhs), None, eq_check));
}
let match_target = make::expr_tuple(vec![lhs_name, rhs_name]);
let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
make::expr_match(match_target, list)
}
};
make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
}
ast::Adt::Struct(strukt) => match strukt.field_list() {
Some(ast::FieldList::RecordFieldList(field_list)) => {
let mut expr = None;
for field in field_list.fields() {
let lhs = make::expr_path(make::ext::ident_path("self"));
let lhs = make::expr_field(lhs, &field.name()?.to_string());
let rhs = make::expr_path(make::ext::ident_path("other"));
let rhs = make::expr_field(rhs, &field.name()?.to_string());
let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
expr = gen_eq_chain(expr, cmp);
}
make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
}
Some(ast::FieldList::TupleFieldList(field_list)) => {
let mut expr = None;
for (i, _) in field_list.fields().enumerate() {
let idx = format!("{}", i);
let lhs = make::expr_path(make::ext::ident_path("self"));
let lhs = make::expr_field(lhs, &idx);
let rhs = make::expr_path(make::ext::ident_path("other"));
let rhs = make::expr_field(rhs, &idx);
let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
expr = gen_eq_chain(expr, cmp);
}
make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
}
// No fields in the body means there's nothing to hash.
None => {
let expr = make::expr_literal("true").into();
make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
}
},
};
ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
Some(())
}
fn make_discriminant() -> Option<ast::Expr> {
Some(make::expr_path(make::ext::path_from_idents(["core", "mem", "discriminant"])?))
}