Make incorrect case diagnostic work inside of functions

This commit is contained in:
Igor Aleksanov 2020-10-04 07:39:35 +03:00
parent 9ec1741b65
commit b42562b5de
4 changed files with 280 additions and 33 deletions

View file

@ -95,6 +95,12 @@ impl ItemScope {
self.impls.iter().copied()
}
pub fn values(
&self,
) -> impl Iterator<Item = (ModuleDefId, Visibility)> + ExactSizeIterator + '_ {
self.values.values().copied()
}
pub fn visibility_of(&self, def: ModuleDefId) -> Option<Visibility> {
self.name_of(ItemInNs::Types(def))
.or_else(|| self.name_of(ItemInNs::Values(def)))

View file

@ -281,7 +281,7 @@ impl Diagnostic for IncorrectCase {
fn message(&self) -> String {
format!(
"{} `{}` should have a {} name, e.g. `{}`",
"{} `{}` should have {} name, e.g. `{}`",
self.ident_type,
self.ident_text,
self.expected_case.to_string(),
@ -339,6 +339,8 @@ mod tests {
let impl_data = self.impl_data(impl_id);
for item in impl_data.items.iter() {
if let AssocItemId::FunctionId(f) = item {
let mut sink = DiagnosticSinkBuilder::new().build(&mut cb);
validate_module_item(self, ModuleDefId::FunctionId(*f), &mut sink);
fns.push(*f)
}
}

View file

@ -5,23 +5,13 @@
//! - enum fields (e.g. `enum Foo { Variant { field: u8 } }`)
//! - function/method arguments (e.g. `fn foo(arg: u8)`)
// TODO: Temporary, to not see warnings until module is somewhat complete.
// If you see these lines in the pull request, feel free to call me stupid :P.
#![allow(dead_code, unused_imports, unused_variables)]
mod str_helpers;
use std::sync::Arc;
use hir_def::{
adt::VariantData,
body::Body,
db::DefDatabase,
expr::{Expr, ExprId, UnaryOp},
item_tree::ItemTreeNode,
resolver::{resolver_for_expr, ResolveValueResult, ValueNs},
expr::{Pat, PatId},
src::HasSource,
AdtId, EnumId, FunctionId, Lookup, ModuleDefId, StructId,
AdtId, ConstId, EnumId, FunctionId, Lookup, ModuleDefId, StaticId, StructId,
};
use hir_expand::{
diagnostics::DiagnosticSink,
@ -35,8 +25,6 @@ use syntax::{
use crate::{
db::HirDatabase,
diagnostics::{decl_check::str_helpers::*, CaseType, IncorrectCase},
lower::CallableDefId,
ApplicationTy, InferenceResult, Ty, TypeCtor,
};
pub(super) struct DeclValidator<'a, 'b: 'a> {
@ -64,12 +52,25 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
match self.owner {
ModuleDefId::FunctionId(func) => self.validate_func(db, func),
ModuleDefId::AdtId(adt) => self.validate_adt(db, adt),
ModuleDefId::ConstId(const_id) => self.validate_const(db, const_id),
ModuleDefId::StaticId(static_id) => self.validate_static(db, static_id),
_ => return,
}
}
fn validate_adt(&mut self, db: &dyn HirDatabase, adt: AdtId) {
match adt {
AdtId::StructId(struct_id) => self.validate_struct(db, struct_id),
AdtId::EnumId(enum_id) => self.validate_enum(db, enum_id),
AdtId::UnionId(_) => {
// Unions aren't yet supported by this validator.
}
}
}
fn validate_func(&mut self, db: &dyn HirDatabase, func: FunctionId) {
let data = db.function_data(func);
let body = db.body(func.into());
// 1. Check the function name.
let function_name = data.name.to_string();
@ -87,11 +88,18 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
// 2. Check the param names.
let mut fn_param_replacements = Vec::new();
for param_name in data.param_names.iter().cloned().filter_map(|i| i) {
for pat_id in body.params.iter().cloned() {
let pat = &body[pat_id];
let param_name = match pat {
Pat::Bind { name, .. } => name,
_ => continue,
};
let name = param_name.to_string();
if let Some(new_name) = to_lower_snake_case(&name) {
let replacement = Replacement {
current_name: param_name,
current_name: param_name.clone(),
suggested_text: new_name,
expected_case: CaseType::LowerSnakeCase,
};
@ -99,13 +107,45 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
}
}
// 3. If there is at least one element to spawn a warning on, go to the source map and generate a warning.
// 3. Check the patterns inside the function body.
let mut pats_replacements = Vec::new();
for (pat_idx, pat) in body.pats.iter() {
if body.params.contains(&pat_idx) {
// We aren't interested in function parameters, we've processed them above.
continue;
}
let bind_name = match pat {
Pat::Bind { name, .. } => name,
_ => continue,
};
let name = bind_name.to_string();
if let Some(new_name) = to_lower_snake_case(&name) {
let replacement = Replacement {
current_name: bind_name.clone(),
suggested_text: new_name,
expected_case: CaseType::LowerSnakeCase,
};
pats_replacements.push((pat_idx, replacement));
}
}
// 4. If there is at least one element to spawn a warning on, go to the source map and generate a warning.
self.create_incorrect_case_diagnostic_for_func(
func,
db,
fn_name_replacement,
fn_param_replacements,
)
);
self.create_incorrect_case_diagnostic_for_variables(func, db, pats_replacements);
// 5. Recursively validate inner scope items, such as static variables and constants.
for (item_id, _) in body.item_scope.values() {
let mut validator = DeclValidator::new(item_id, self.sink);
validator.validate_item(db);
}
}
/// Given the information about incorrect names in the function declaration, looks up into the source code
@ -125,6 +165,7 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
let fn_loc = func.lookup(db.upcast());
let fn_src = fn_loc.source(db.upcast());
// 1. Diagnostic for function name.
if let Some(replacement) = fn_name_replacement {
let ast_ptr = if let Some(name) = fn_src.value.name() {
name
@ -150,6 +191,7 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
self.sink.push(diagnostic);
}
// 2. Diagnostics for function params.
let fn_params_list = match fn_src.value.param_list() {
Some(params) => params,
None => {
@ -197,12 +239,38 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
}
}
fn validate_adt(&mut self, db: &dyn HirDatabase, adt: AdtId) {
match adt {
AdtId::StructId(struct_id) => self.validate_struct(db, struct_id),
AdtId::EnumId(enum_id) => self.validate_enum(db, enum_id),
AdtId::UnionId(_) => {
// Unions aren't yet supported by this validator.
/// Given the information about incorrect variable names, looks up into the source code
/// for exact locations and adds diagnostics into the sink.
fn create_incorrect_case_diagnostic_for_variables(
&mut self,
func: FunctionId,
db: &dyn HirDatabase,
pats_replacements: Vec<(PatId, Replacement)>,
) {
// XXX: only look at source_map if we do have missing fields
if pats_replacements.is_empty() {
return;
}
let (_, source_map) = db.body_with_source_map(func.into());
for (id, replacement) in pats_replacements {
if let Ok(source_ptr) = source_map.pat_syntax(id) {
if let Some(expr) = source_ptr.value.as_ref().left() {
let root = source_ptr.file_syntax(db.upcast());
if let ast::Pat::IdentPat(ident_pat) = expr.to_node(&root) {
let diagnostic = IncorrectCase {
file: source_ptr.file_id,
ident_type: "Variable".to_string(),
ident: AstPtr::new(&ident_pat).into(),
expected_case: replacement.expected_case,
ident_text: replacement.current_name.to_string(),
suggested_text: replacement.suggested_text,
};
self.sink.push(diagnostic);
}
}
}
}
}
@ -246,7 +314,7 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
db,
struct_name_replacement,
struct_fields_replacements,
)
);
}
/// Given the information about incorrect names in the struct declaration, looks up into the source code
@ -464,6 +532,86 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
self.sink.push(diagnostic);
}
}
fn validate_const(&mut self, db: &dyn HirDatabase, const_id: ConstId) {
let data = db.const_data(const_id);
let name = match &data.name {
Some(name) => name,
None => return,
};
let const_name = name.to_string();
let replacement = if let Some(new_name) = to_upper_snake_case(&const_name) {
Replacement {
current_name: name.clone(),
suggested_text: new_name,
expected_case: CaseType::UpperSnakeCase,
}
} else {
// Nothing to do here.
return;
};
let const_loc = const_id.lookup(db.upcast());
let const_src = const_loc.source(db.upcast());
let ast_ptr = match const_src.value.name() {
Some(name) => name,
None => return,
};
let diagnostic = IncorrectCase {
file: const_src.file_id,
ident_type: "Constant".to_string(),
ident: AstPtr::new(&ast_ptr).into(),
expected_case: replacement.expected_case,
ident_text: replacement.current_name.to_string(),
suggested_text: replacement.suggested_text,
};
self.sink.push(diagnostic);
}
fn validate_static(&mut self, db: &dyn HirDatabase, static_id: StaticId) {
let data = db.static_data(static_id);
let name = match &data.name {
Some(name) => name,
None => return,
};
let static_name = name.to_string();
let replacement = if let Some(new_name) = to_upper_snake_case(&static_name) {
Replacement {
current_name: name.clone(),
suggested_text: new_name,
expected_case: CaseType::UpperSnakeCase,
}
} else {
// Nothing to do here.
return;
};
let static_loc = static_id.lookup(db.upcast());
let static_src = static_loc.source(db.upcast());
let ast_ptr = match static_src.value.name() {
Some(name) => name,
None => return,
};
let diagnostic = IncorrectCase {
file: static_src.file_id,
ident_type: "Static variable".to_string(),
ident: AstPtr::new(&ast_ptr).into(),
expected_case: replacement.expected_case,
ident_text: replacement.current_name.to_string(),
suggested_text: replacement.suggested_text,
};
self.sink.push(diagnostic);
}
}
fn names_equal(left: Option<ast::Name>, right: &Name) -> bool {
@ -491,7 +639,7 @@ mod tests {
check_diagnostics(
r#"
fn NonSnakeCaseName() {}
// ^^^^^^^^^^^^^^^^ Function `NonSnakeCaseName` should have a snake_case name, e.g. `non_snake_case_name`
// ^^^^^^^^^^^^^^^^ Function `NonSnakeCaseName` should have snake_case name, e.g. `non_snake_case_name`
"#,
);
}
@ -501,10 +649,24 @@ fn NonSnakeCaseName() {}
check_diagnostics(
r#"
fn foo(SomeParam: u8) {}
// ^^^^^^^^^ Argument `SomeParam` should have a snake_case name, e.g. `some_param`
// ^^^^^^^^^ Argument `SomeParam` should have snake_case name, e.g. `some_param`
fn foo2(ok_param: &str, CAPS_PARAM: u8) {}
// ^^^^^^^^^^ Argument `CAPS_PARAM` should have a snake_case name, e.g. `caps_param`
// ^^^^^^^^^^ Argument `CAPS_PARAM` should have snake_case name, e.g. `caps_param`
"#,
);
}
#[test]
fn incorrect_variable_names() {
check_diagnostics(
r#"
fn foo() {
let SOME_VALUE = 10;
// ^^^^^^^^^^ Variable `SOME_VALUE` should have a snake_case name, e.g. `some_value`
let AnotherValue = 20;
// ^^^^^^^^^^^^ Variable `AnotherValue` should have snake_case name, e.g. `another_value`
}
"#,
);
}
@ -514,7 +676,7 @@ fn foo2(ok_param: &str, CAPS_PARAM: u8) {}
check_diagnostics(
r#"
struct non_camel_case_name {}
// ^^^^^^^^^^^^^^^^^^^ Structure `non_camel_case_name` should have a CamelCase name, e.g. `NonCamelCaseName`
// ^^^^^^^^^^^^^^^^^^^ Structure `non_camel_case_name` should have CamelCase name, e.g. `NonCamelCaseName`
"#,
);
}
@ -524,7 +686,7 @@ struct non_camel_case_name {}
check_diagnostics(
r#"
struct SomeStruct { SomeField: u8 }
// ^^^^^^^^^ Field `SomeField` should have a snake_case name, e.g. `some_field`
// ^^^^^^^^^ Field `SomeField` should have snake_case name, e.g. `some_field`
"#,
);
}
@ -534,7 +696,7 @@ struct SomeStruct { SomeField: u8 }
check_diagnostics(
r#"
enum some_enum { Val(u8) }
// ^^^^^^^^^ Enum `some_enum` should have a CamelCase name, e.g. `SomeEnum`
// ^^^^^^^^^ Enum `some_enum` should have CamelCase name, e.g. `SomeEnum`
"#,
);
}
@ -544,7 +706,58 @@ enum some_enum { Val(u8) }
check_diagnostics(
r#"
enum SomeEnum { SOME_VARIANT(u8) }
// ^^^^^^^^^^^^ Variant `SOME_VARIANT` should have a CamelCase name, e.g. `SomeVariant`
// ^^^^^^^^^^^^ Variant `SOME_VARIANT` should have CamelCase name, e.g. `SomeVariant`
"#,
);
}
#[test]
fn incorrect_const_name() {
check_diagnostics(
r#"
const some_weird_const: u8 = 10;
// ^^^^^^^^^^^^^^^^ Constant `some_weird_const` should have UPPER_SNAKE_CASE name, e.g. `SOME_WEIRD_CONST`
fn func() {
const someConstInFunc: &str = "hi there";
// ^^^^^^^^^^^^^^^ Constant `someConstInFunc` should have UPPER_SNAKE_CASE name, e.g. `SOME_CONST_IN_FUNC`
}
"#,
);
}
#[test]
fn incorrect_static_name() {
check_diagnostics(
r#"
static some_weird_const: u8 = 10;
// ^^^^^^^^^^^^^^^^ Static variable `some_weird_const` should have UPPER_SNAKE_CASE name, e.g. `SOME_WEIRD_CONST`
fn func() {
static someConstInFunc: &str = "hi there";
// ^^^^^^^^^^^^^^^ Static variable `someConstInFunc` should have UPPER_SNAKE_CASE name, e.g. `SOME_CONST_IN_FUNC`
}
"#,
);
}
#[test]
fn fn_inside_impl_struct() {
check_diagnostics(
r#"
struct someStruct;
// ^^^^^^^^^^ Structure `someStruct` should have CamelCase name, e.g. `SomeStruct`
impl someStruct {
fn SomeFunc(&self) {
// ^^^^^^^^ Function `SomeFunc` should have snake_case name, e.g. `some_func`
static someConstInFunc: &str = "hi there";
// ^^^^^^^^^^^^^^^ Static variable `someConstInFunc` should have UPPER_SNAKE_CASE name, e.g. `SOME_CONST_IN_FUNC`
let WHY_VAR_IS_CAPS = 10;
// ^^^^^^^^^^^^^^^ Variable `WHY_VAR_IS_CAPS` should have snake_case name, e.g. `why_var_is_caps`
}
}
"#,
);
}

View file

@ -877,6 +877,32 @@ pub fn SomeFn<|>(val: u8) -> u8 {
pub fn some_fn(val: u8) -> u8 {
if val != 0 { some_fn(val - 1) } else { val }
}
"#,
);
check_fixes(
r#"
fn some_fn() {
let whatAWeird_Formatting<|> = 10;
another_func(whatAWeird_Formatting);
}
"#,
r#"
fn some_fn() {
let what_a_weird_formatting = 10;
another_func(what_a_weird_formatting);
}
"#,
);
}
#[test]
fn test_uppercase_const_no_diagnostics() {
check_no_diagnostics(
r#"
fn foo() {
const ANOTHER_ITEM<|>: &str = "some_item";
}
"#,
);
}