diff --git a/Cargo.toml b/Cargo.toml
index a5eea2a..3b7c730 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,3 +6,12 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
+clap = { version = "4.3.17", features = ["derive"] }
+derive_builder = "0.12.0"
+regex = "1.9.1"
+tree-sitter = "0.20.10"
+tree-sitter-grep = { git = "https://github.com/helixbass/tree-sitter-grep", rev = "3d4682c" }
+tree-sitter-rust = "0.20.3"
+
+[[bin]]
+name = "tree-sitter-lint"
diff --git a/src/args.rs b/src/args.rs
new file mode 100644
index 0000000..41acdde
--- /dev/null
+++ b/src/args.rs
@@ -0,0 +1,4 @@
+use clap::Parser;
+
+#[derive(Parser)]
+pub struct Args {}
diff --git a/src/bin/tree-sitter-lint.rs b/src/bin/tree-sitter-lint.rs
new file mode 100644
index 0000000..030f32f
--- /dev/null
+++ b/src/bin/tree-sitter-lint.rs
@@ -0,0 +1,7 @@
+use clap::Parser;
+use tree_sitter_lint::{run, Args};
+
+fn main() {
+    let args = Args::parse();
+    run(args);
+}
diff --git a/src/context.rs b/src/context.rs
new file mode 100644
index 0000000..32a0e0e
--- /dev/null
+++ b/src/context.rs
@@ -0,0 +1,61 @@
+use std::{
+    path::Path,
+    sync::atomic::{AtomicBool, Ordering},
+};
+
+use tree_sitter::Language;
+
+use crate::{rule::ResolvedRule, violation::Violation};
+
+/// Shared state handed to rules when they are created and resolved.
+pub struct Context {
+    pub language: Language,
+}
+
+impl Context {
+    pub fn new(language: Language) -> Self {
+        Self { language }
+    }
+}
+
+/// Per-match state passed to a listener's `on_query_match` callback.
+pub struct QueryMatchContext<'a> {
+    pub path: &'a Path,
+    pub file_contents: &'a [u8],
+    pub rule: &'a ResolvedRule<'a>,
+    reported_any_violations: &'a AtomicBool,
+}
+
+impl<'a> QueryMatchContext<'a> {
+    pub fn new(
+        path: &'a Path,
+        file_contents: &'a [u8],
+        rule: &'a ResolvedRule,
+        reported_any_violations: &'a AtomicBool,
+    ) -> Self {
+        Self {
+            path,
+            file_contents,
+            rule,
+            reported_any_violations,
+        }
+    }
+
+    /// Sets the shared "a violation was reported" flag and prints the violation to stdout.
+    pub fn report(&self, violation: Violation) {
+        self.reported_any_violations.store(true, Ordering::Relaxed);
+        print_violation(&violation, self);
+    }
+}
+
+/// Prints `path:line:column message rule-name`, with 1-based line and column.
+fn print_violation(violation: &Violation, query_match_context: &QueryMatchContext) {
+    println!(
+        "{:?}:{}:{} {} {}",
+        query_match_context.path,
+        violation.node.range().start_point.row + 1,
+        violation.node.range().start_point.column + 1,
+        violation.message,
+        query_match_context.rule.name,
+    );
+}
diff --git a/src/lib.rs b/src/lib.rs
index 7d12d9a..f68228b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,14 +1,196 @@
-pub fn add(left: usize, right: usize) -> usize {
-    left + right
+mod args;
+mod context;
+mod rule;
+mod violation;
+
+use std::{
+    borrow::Cow,
+    process,
+    sync::atomic::{AtomicBool, Ordering},
+};
+
+pub use args::Args;
+use clap::Parser;
+use context::QueryMatchContext;
+use rule::{ResolvedRule, Rule, RuleBuilder, RuleListenerBuilder};
+use tree_sitter::Query;
+use violation::ViolationBuilder;
+
+use crate::context::Context;
+
+/// Compiles a regex on first use and caches it in a static local to the call site.
+///
+/// Only pass a pattern that is identical on every evaluation: whichever pattern
+/// reaches a call site first is cached, and later, different patterns are ignored.
+#[macro_export]
+macro_rules! regex {
+    ($re:expr $(,)?) => {{
+        static RE: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
+        RE.get_or_init(|| regex::Regex::new($re).unwrap())
+    }};
+}
+
+const CAPTURE_NAME_FOR_TREE_SITTER_GREP: &str = "_tree_sitter_lint_capture";
+const CAPTURE_NAME_FOR_TREE_SITTER_GREP_WITH_LEADING_AT: &str = "@_tree_sitter_lint_capture";
+
+/// Resolves the built-in rules, runs their combined query through `tree_sitter_grep`
+/// for Rust, and exits the process: status 1 if any violation was reported, else 0.
+pub fn run(_args: Args) {
+    let language = tree_sitter_rust::language();
+    let context = Context::new(language);
+    let resolved_rules = get_rules()
+        .into_iter()
+        .map(|rule| rule.resolve(&context))
+        .collect::<Vec<_>>();
+    let aggregated_queries = AggregatedQueries::new(&resolved_rules, &context);
+    let tree_sitter_grep_args = tree_sitter_grep::Args::parse_from([
+        "tree_sitter_grep",
+        "-q",
+        &aggregated_queries.query_text,
+        "-l",
+        "rust",
+        "--capture",
+        CAPTURE_NAME_FOR_TREE_SITTER_GREP,
+    ]);
+    let reported_any_violations = AtomicBool::new(false);
+    tree_sitter_grep::run_with_callback(
+        tree_sitter_grep_args,
+        |capture_info, file_contents, path| {
+            let (rule_index, rule_listener_index) =
+                aggregated_queries.pattern_index_lookup[capture_info.pattern_index];
+            let rule = &resolved_rules[rule_index];
+            let listener = &rule.listeners[rule_listener_index];
+            (listener.on_query_match)(
+                &capture_info.node,
+                &QueryMatchContext::new(path, file_contents, rule, &reported_any_violations),
+            );
+        },
+    )
+    .unwrap();
+    if reported_any_violations.load(Ordering::Relaxed) {
+        process::exit(1);
+    } else {
+        process::exit(0);
+    }
 }
 
-#[cfg(test)]
-mod tests {
-    use super::*;
+type RuleIndex = usize;
+type RuleListenerIndex = usize;
 
-    #[test]
-    fn it_works() {
-        let result = add(2, 2);
-        assert_eq!(result, 4);
+/// Every listener's query text concatenated into one query, plus a lookup from each
+/// pattern index of that query to the (rule, listener) pair that contributed it.
+struct AggregatedQueries {
+    pattern_index_lookup: Vec<(RuleIndex, RuleListenerIndex)>,
+    #[allow(dead_code)]
+    query: Query,
+    query_text: String,
+}
+
+impl AggregatedQueries {
+    pub fn new(resolved_rules: &[ResolvedRule], context: &Context) -> Self {
+        let mut pattern_index_lookup: Vec<(RuleIndex, RuleListenerIndex)> = Default::default();
+        let mut aggregated_query_text = String::new();
+        for (rule_index, resolved_rule) in resolved_rules.iter().enumerate() {
+            for (rule_listener_index, rule_listener) in resolved_rule.listeners.iter().enumerate() {
+                for _ in 0..rule_listener.query.pattern_count() {
+                    pattern_index_lookup.push((rule_index, rule_listener_index));
+                }
+                let use_capture_name =
+                    &rule_listener.query.capture_names()[rule_listener.capture_index as usize];
+                // Built per listener rather than through `regex!`: that macro caches the
+                // first pattern seen at a call site, so every later listener would be
+                // rewritten with the first listener's capture name.
+                let capture_name_pattern =
+                    regex::Regex::new(&format!(r#"@{}\b"#, regex::escape(use_capture_name)))
+                        .unwrap();
+                let query_text_with_unified_capture_name = capture_name_pattern.replace_all(
+                    &rule_listener.query_text,
+                    CAPTURE_NAME_FOR_TREE_SITTER_GREP_WITH_LEADING_AT,
+                );
+                assert!(
+                    matches!(query_text_with_unified_capture_name, Cow::Owned(_),),
+                    "Didn't find any instances of the capture name to replace"
+                );
+                aggregated_query_text.push_str(&query_text_with_unified_capture_name);
+                aggregated_query_text.push_str("\n\n");
+            }
+        }
+        let query = Query::new(context.language, &aggregated_query_text).unwrap();
+        assert_eq!(query.pattern_count(), pattern_index_lookup.len());
+        Self {
+            pattern_index_lookup,
+            query,
+            query_text: aggregated_query_text,
+        }
     }
 }
+
+/// The built-in rule set.
+fn get_rules() -> Vec<Rule> {
+    vec![no_default_default_rule(), no_lazy_static_rule()]
+}
+
+/// Rule that reports `Default::default()` call expressions.
+fn no_default_default_rule() -> Rule {
+    RuleBuilder::default()
+        .name("no_default_default")
+        .create(|_context| {
+            vec![RuleListenerBuilder::default()
+                .query(
+                    r#"(
+                      (call_expression
+                        function:
+                          (scoped_identifier
+                            path:
+                              (identifier) @first (#eq? @first "Default")
+                            name:
+                              (identifier) @second (#eq? @second "default")
+                          )
+                      ) @c
+                    )"#,
+                )
+                .capture_name("c")
+                .on_query_match(|node, query_match_context| {
+                    query_match_context.report(
+                        ViolationBuilder::default()
+                            .message(r#"Use '_d()' instead of 'Default::default()'"#)
+                            .node(node)
+                            .build()
+                            .unwrap(),
+                    );
+                })
+                .build()
+                .unwrap()]
+        })
+        .build()
+        .unwrap()
+}
+
+/// Rule that reports `lazy_static` macro invocations.
+fn no_lazy_static_rule() -> Rule {
+    RuleBuilder::default()
+        .name("no_lazy_static")
+        .create(|_context| {
+            vec![RuleListenerBuilder::default()
+                .query(
+                    r#"(
+                      (macro_invocation
+                        macro: (identifier) @c (#eq? @c "lazy_static")
+                      )
+                    )"#,
+                )
+                .on_query_match(|node, query_match_context| {
+                    query_match_context.report(
+                        ViolationBuilder::default()
+                            .message(r#"Prefer 'OnceCell::*::Lazy' to 'lazy_static!()'"#)
+                            .node(node)
+                            .build()
+                            .unwrap(),
+                    );
+                })
+                .build()
+                .unwrap()]
+        })
+        .build()
+        .unwrap()
+}
diff --git a/src/rule.rs b/src/rule.rs
new file mode 100644
index 0000000..dcf5aed
--- /dev/null
+++ b/src/rule.rs
@@ -0,0 +1,108 @@
+use std::sync::Arc;
+
+use derive_builder::Builder;
+use tree_sitter::{Node, Query};
+
+use crate::context::{Context, QueryMatchContext};
+
+/// A lint rule: a name plus a factory that produces its listeners.
+#[derive(Builder)]
+#[builder(setter(into))]
+pub struct Rule {
+    pub name: String,
+    #[builder(setter(custom))]
+    pub create: Arc<dyn Fn(&Context) -> Vec<RuleListener>>,
+}
+
+impl Rule {
+    /// Calls `create` and resolves every listener it returns against `context`.
+    pub fn resolve(self, context: &Context) -> ResolvedRule<'_> {
+        let Rule { name, create } = self;
+
+        ResolvedRule::new(
+            name,
+            create(context)
+                .into_iter()
+                .map(|rule_listener| rule_listener.resolve(context))
+                .collect(),
+        )
+    }
+}
+
+impl RuleBuilder {
+    pub fn create(
+        &mut self,
+        callback: impl Fn(&Context) -> Vec<RuleListener> + 'static,
+    ) -> &mut Self {
+        self.create = Some(Arc::new(callback));
+        self
+    }
+}
+
+/// A rule's name together with its resolved listeners.
+pub struct ResolvedRule<'context> {
+    pub name: String,
+    pub listeners: Vec<ResolvedRuleListener<'context>>,
+}
+
+impl<'context> ResolvedRule<'context> {
+    pub fn new(name: String, listeners: Vec<ResolvedRuleListener<'context>>) -> Self {
+        Self { name, listeners }
+    }
+}
+
+/// An unresolved listener: query text, optional capture name, and match callback.
+#[derive(Builder)]
+#[builder(setter(into, strip_option))]
+pub struct RuleListener<'on_query_match> {
+    pub query: String,
+    #[builder(default)]
+    pub capture_name: Option<String>,
+    #[builder(setter(custom))]
+    pub on_query_match: Arc<dyn Fn(&Node, &QueryMatchContext) + 'on_query_match + Send + Sync>,
+}
+
+impl<'on_query_match> RuleListener<'on_query_match> {
+    /// Compiles the query and selects the capture index: the capture named by
+    /// `capture_name` when set, otherwise index 0. Panics if the query does not
+    /// compile, has no captures, or does not contain the named capture.
+    pub fn resolve(self, context: &Context) -> ResolvedRuleListener<'on_query_match> {
+        let RuleListener {
+            query: query_text,
+            capture_name,
+            on_query_match,
+        } = self;
+        let query = Query::new(context.language, &query_text).unwrap();
+        let capture_index = match capture_name {
+            None => match query.capture_names().len() {
+                0 => panic!("Expected capture"),
+                _ => 0,
+            },
+            Some(capture_name) => query.capture_index_for_name(&capture_name).unwrap(),
+        };
+        ResolvedRuleListener {
+            query,
+            query_text,
+            capture_index,
+            on_query_match,
+        }
+    }
+}
+
+impl<'on_query_match> RuleListenerBuilder<'on_query_match> {
+    pub fn on_query_match(
+        &mut self,
+        callback: impl Fn(&Node, &QueryMatchContext) + 'on_query_match + Send + Sync,
+    ) -> &mut Self {
+        self.on_query_match = Some(Arc::new(callback));
+        self
+    }
+}
+
+/// A listener with its compiled query and the capture index chosen for it.
+pub struct ResolvedRuleListener<'on_query_match> {
+    pub query: Query,
+    pub query_text: String,
+    pub capture_index: u32,
+    pub on_query_match: Arc<dyn Fn(&Node, &QueryMatchContext) + 'on_query_match + Send + Sync>,
+}
diff --git a/src/violation.rs b/src/violation.rs
new file mode 100644
index 0000000..0627a7f
--- /dev/null
+++ b/src/violation.rs
@@ -0,0 +1,10 @@
+use derive_builder::Builder;
+use tree_sitter::Node;
+
+/// A reported problem: the message to print and the node it refers to.
+#[derive(Builder)]
+#[builder(setter(into))]
+pub struct Violation<'a> {
+    pub message: String,
+    pub node: &'a Node<'a>,
+}