diff --git a/Cargo.toml b/Cargo.toml index 56dbf82..35ba177 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "tree-sitter-grep" +name = "tree_sitter_lint_tree-sitter-grep" version = "0.1.0" edition = "2021" license = "Unlicense OR MIT" @@ -8,6 +8,7 @@ authors = [ "Peter Stuart " ] description = """ +(not-yet-landed version used by tree-sitter-lint) tree-sitter-grep is a grep-like search tool that recursively searches the current directory for a tree-sitter query pattern. Like ripgrep, it respects @@ -25,6 +26,7 @@ rust-version = "1.70" bstr = "1.1.0" bytecount = "0.6" clap = { version = "4.3.0", features = ["derive", "wrap_help"] } +derive_builder = "0.12.0" encoding_rs = "0.8.14" encoding_rs_io = "0.1.6" ignore = { package = "tree_sitter_grep_ignore", git = "https://github.com/helixbass/ripgrep", rev = "669ebd3", version = "0.4.20-dev.0" } @@ -34,10 +36,13 @@ log = "0.4.5" memchr = "2.1" memmap = { package = "memmap2", version = "0.5.3" } once_cell = "1.18.0" +ouroboros = "0.17.2" proc_macros = { package = "tree_sitter_grep_proc_macros", path = "proc_macros", version = "0.1.0" } rayon = "1.7.0" regex = "1.8.2" +ropey = "1.6.0" serde = { version = "1.0.77", features = ["derive"] } +streaming-iterator = "0.1.9" strum_macros = "0.25.1" termcolor = "1.2.0" thiserror = "1.0.43" @@ -52,7 +57,8 @@ tree-sitter-elm = "5.6.4" tree-sitter-go = "0.19.1" tree-sitter-html = "0.19.0" tree-sitter-java = "0.20.0" -tree-sitter-javascript = "0.20.0" +# tree-sitter-javascript = "0.20.0" +tree-sitter-javascript = { git = "https://github.com/tree-sitter/tree-sitter-javascript", rev = "f1e5a09b", version = "0.20.1" } tree-sitter-json = "0.19.0" tree-sitter-kotlin = "0.2.11" tree-sitter-lua = "0.0.18" @@ -60,11 +66,14 @@ tree-sitter-objc = "1.1.0" tree-sitter-python = "0.20.2" tree-sitter-query = "0.1.0" tree-sitter-ruby = "0.20.0" -tree-sitter-rust = { package = "tree_sitter_grep_tree-sitter-rust", git = "https://github.com/helixbass/tree-sitter-rust", rev = "781a8d9", version = "0.20.3-dev.0" } +tree-sitter-rust = { package = "tree_sitter_grep_tree-sitter-rust", git = "https://github.com/helixbass/tree-sitter-rust", rev = "6146443", version = "0.20.3-dev.0" } tree-sitter-swift = "0.3.6" tree-sitter-toml = "0.20.0" tree-sitter-typescript = "0.20.2" +[patch.crates-io] +tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "c16b90d" } + [[bin]] name = "tree-sitter-grep" @@ -73,6 +82,7 @@ assert_cmd = "2.0.11" escargot = "0.5.7" predicates = "3.0.3" shlex = "1.1.0" +speculoos = "0.11.0" [features] default = ["bytecount/runtime-dispatch-simd"] diff --git a/examples/filter_before_line_number.rs b/examples/filter_before_line_number.rs index ae05948..3eae9f5 100644 --- a/examples/filter_before_line_number.rs +++ b/examples/filter_before_line_number.rs @@ -5,7 +5,7 @@ use std::{ use libc::c_char; use tree_sitter::Node; -use tree_sitter_grep::PluginInitializeReturn; +use tree_sitter_lint_tree_sitter_grep::PluginInitializeReturn; static ROW_NUMBER: AtomicUsize = AtomicUsize::new(0); diff --git a/examples/print_match_text.rs b/examples/print_match_text.rs index 3806033..08c5ed5 100644 --- a/examples/print_match_text.rs +++ b/examples/print_match_text.rs @@ -1,12 +1,19 @@ use clap::Parser; -use tree_sitter_grep::{run_with_callback, Args}; +use tree_sitter_lint_tree_sitter_grep::{run_with_callback, Args}; fn main() { let args = Args::parse_from(["tree_sitter_grep", "-q", "(function_item) @f"]); - run_with_callback(args, |node, file_contents, path| { + run_with_callback(args, |query_match, file_contents, path| { println!( "Found match in {path:?}: {}", - std::str::from_utf8(&file_contents[node.byte_range()]).unwrap(), + std::str::from_utf8( + &file_contents[query_match + .nodes_for_capture_index(0) + .next() + .unwrap() + .byte_range()] + ) + .unwrap(), ); }) .unwrap(); diff --git a/proc_macros/src/lib.rs b/proc_macros/src/lib.rs index 40fce98..1854090 100644 --- a/proc_macros/src/lib.rs +++ b/proc_macros/src/lib.rs @@ -133,7 +133,7 @@ fn get_token_enum_definition( variants_with_attributes: &[ExprPath], ) -> proc_macro2::TokenStream { quote! { - #[derive(Copy, Clone, Debug, Eq, PartialEq, clap::ValueEnum, strum_macros::Display)] + #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, clap::ValueEnum, strum_macros::Display)] pub enum #name { #(#variants_with_attributes),* } @@ -335,8 +335,8 @@ fn get_all_variants_collection_definition( ) -> proc_macro2::TokenStream { quote! { pub static #all_variants_collection_name: #collection_type_name<#name> = { - use SupportedLanguage::*; - BySupportedLanguage([ + use #name::*; + #collection_type_name([ #(#variants),* ]) }; diff --git a/rustfmt.toml b/rustfmt.toml index cd90377..5528c07 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -3,4 +3,5 @@ format_macro_bodies = true format_macro_matchers = true group_imports = "StdExternalCrate" imports_granularity = "Crate" -wrap_comments = true +edition = "2021" +# wrap_comments = true diff --git a/src/args.rs b/src/args.rs index b625291..2064bc8 100644 --- a/src/args.rs +++ b/src/args.rs @@ -1,13 +1,16 @@ use std::{ + collections::HashMap, fs, path::{Path, PathBuf}, sync::{Arc, Mutex}, }; use clap::{ArgGroup, Parser}; +use derive_builder::Builder; use ignore::{types::Types, WalkBuilder, WalkParallel}; use rayon::iter::IterBridge; use termcolor::BufferWriter; +use tree_sitter::Query; use crate::{ language::SupportedLanguage, @@ -18,12 +21,13 @@ use crate::{ }, searcher::{Searcher, SearcherBuilder}, use_printer::Printer, - Error, NonFatalError, + Error, NonFatalError, SupportedLanguageLanguage, }; const ALL_NODES_QUERY: &str = "(_) @node"; -#[derive(Parser)] +#[derive(Builder, Clone, Default, Parser)] +#[builder(default, setter(strip_option, into))] #[clap(group( ArgGroup::new("query_or_filter") .multiple(true) @@ -37,13 +41,16 @@ pub struct Args { /// /// This conflicts with the --query option. #[arg(short = 'Q', long = "query-file", conflicts_with = "query_text")] - pub path_to_query_file: Option, + path_to_query_file: Option, /// The source text of a tree-sitter query. /// /// This conflicts with the --query-file option. #[arg(short, long = "query", conflicts_with = "path_to_query_file")] - pub query_text: Option, + query_text: Option, + + #[clap(skip)] + query_per_language: Option, /// The name of the tree-sitter query capture (without leading "@") whose /// matching nodes will be output. @@ -174,7 +181,16 @@ impl Args { } pub(crate) fn get_project_file_walker_types(&self) -> Types { - get_project_file_walker_types(self.language) + get_project_file_walker_types(self.language.map(|language| vec![language]).or_else(|| { + self.query_per_language.as_ref().map(|query_per_language| { + query_per_language + .keys() + .map(|supported_language_language| { + supported_language_language.supported_language() + }) + .collect() + }) + })) } pub(crate) fn get_project_file_walker(&self) -> WalkParallel { @@ -199,18 +215,83 @@ impl Args { Ok(get_loaded_filter(self.filter.as_deref(), self.filter_arg.as_deref())?.map(Arc::new)) } - pub(crate) fn get_loaded_query_text(&self) -> Result { + pub(crate) fn get_loaded_query_text_per_language( + &self, + ) -> Result { Ok( - match (self.path_to_query_file.as_ref(), self.query_text.as_ref()) { - (Some(path_to_query_file), None) => fs::read_to_string(path_to_query_file) + match ( + self.path_to_query_file.as_ref(), + self.query_text.as_ref(), + self.query_per_language.as_ref(), + ) { + (Some(path_to_query_file), None, None) => fs::read_to_string(path_to_query_file) .map_err(|source| Error::QueryFileReadError { source, path_to_query_file: path_to_query_file.clone(), - })?, - (None, Some(query_text)) => query_text.clone(), - (None, None) => ALL_NODES_QUERY.to_owned(), + })? + .into(), + (None, Some(query_text), None) => query_text.clone().into(), + (None, None, Some(query_per_language)) => query_per_language.clone().into(), + (None, None, None) => ALL_NODES_QUERY.to_owned().into(), _ => unreachable!(), }, ) } } + +impl ArgsBuilder { + pub fn maybe_language(&mut self, language: Option) -> &mut Self { + self.language = Some(language); + self + } +} + +pub type QueryPerLanguage = HashMap>; + +pub enum QueryOrQueryTextPerLanguage { + SingleQueryText(String), + PerLanguage(QueryPerLanguage), +} + +impl QueryOrQueryTextPerLanguage { + pub fn get_query_or_query_text_for_language( + &self, + language: SupportedLanguageLanguage, + ) -> QueryOrQueryText { + match self { + QueryOrQueryTextPerLanguage::SingleQueryText(query_text) => (&**query_text).into(), + QueryOrQueryTextPerLanguage::PerLanguage(per_language) => { + per_language.get(&language).unwrap().clone().into() + } + } + } +} + +impl From for QueryOrQueryTextPerLanguage { + fn from(value: String) -> Self { + Self::SingleQueryText(value) + } +} + +impl From for QueryOrQueryTextPerLanguage { + fn from(value: QueryPerLanguage) -> Self { + Self::PerLanguage(value) + } +} + +pub enum QueryOrQueryText<'a> { + QueryText(&'a str), + Query(Arc), +} + +impl<'a> From<&'a str> for QueryOrQueryText<'a> { + fn from(value: &'a str) -> Self { + Self::QueryText(value) + } +} + +impl<'a> From> for QueryOrQueryText<'a> { + fn from(value: Arc) -> Self { + Self::Query(value) + } +} diff --git a/src/bin/tree-sitter-grep.rs b/src/bin/tree-sitter-grep.rs index b250691..87bed28 100644 --- a/src/bin/tree-sitter-grep.rs +++ b/src/bin/tree-sitter-grep.rs @@ -1,7 +1,7 @@ use std::process; use clap::Parser; -use tree_sitter_grep::{run_print, Args, RunStatus}; +use tree_sitter_lint_tree_sitter_grep::{run_print, Args, RunStatus}; pub fn main() { let args = Args::parse(); diff --git a/src/language.rs b/src/language.rs index d4cf908..3d2a187 100644 --- a/src/language.rs +++ b/src/language.rs @@ -1,6 +1,7 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, ops::{Deref, Index}, + path::Path, }; use once_cell::sync::Lazy; @@ -42,42 +43,220 @@ fixed_map! { ], } +fixed_map! { + name => SupportedLanguageLanguage, + variants => [ + C, + #[value(name = "c++")] + #[strum(serialize = "C++")] + Cpp, + #[strum(serialize = "C#")] + CSharp, + #[strum(serialize = "CSS")] + Css, + Dockerfile, + Elisp, + Elm, + Go, + #[strum(serialize = "HTML")] + Html, + Java, + Javascript, + #[strum(serialize = "JSON")] + Json, + Kotlin, + Lua, + #[strum(serialize = "Objective-C")] + ObjectiveC, + Python, + Ruby, + Rust, + Swift, + Toml, + TreeSitterQuery, + Tsx, + Typescript, + ], +} + impl SupportedLanguage { - pub fn language(&self) -> Language { - SUPPORTED_LANGUAGE_LANGUAGES[*self] + pub fn language(&self, path: Option<&Path>) -> Language { + self.supported_language_language(path).language() + } + + pub fn supported_language_language(&self, path: Option<&Path>) -> SupportedLanguageLanguage { + match &SUPPORTED_LANGUAGE_LANGUAGES[*self] { + SingleLanguageOrLanguageFromPath::SingleLanguage(language) => *language, + SingleLanguageOrLanguageFromPath::LanguageFromPath(language_from_path) => { + language_from_path.from_path(path) + } + } + } + + pub fn all_supported_language_languages(&self) -> &'static [SupportedLanguageLanguage] { + &SUPPORTED_LANGUAGE_ALL_LANGUAGES[*self] } pub fn name_for_ignore_select(&self) -> &'static str { SUPPORTED_LANGUAGE_NAMES_FOR_IGNORE_SELECT[*self] } + + pub fn comment_kinds(&self) -> &'static HashSet<&'static str> { + &SUPPORTED_LANGUAGE_COMMENT_KINDS[*self] + } } -static SUPPORTED_LANGUAGE_LANGUAGES: Lazy> = Lazy::new(|| { - by_supported_language!( - Rust => tree_sitter_rust::language(), - Typescript => tree_sitter_typescript::language_tsx(), - Javascript => tree_sitter_javascript::language(), - Swift => tree_sitter_swift::language(), - ObjectiveC => tree_sitter_objc::language(), - Toml => tree_sitter_toml::language(), - Python => tree_sitter_python::language(), - Ruby => tree_sitter_ruby::language(), - C => tree_sitter_c::language(), - Cpp => tree_sitter_cpp::language(), - Go => tree_sitter_go::language(), - Java => tree_sitter_java::language(), - CSharp => tree_sitter_c_sharp::language(), - Kotlin => tree_sitter_kotlin::language(), - Elisp => tree_sitter_elisp::language(), - Elm => tree_sitter_elm::language(), - Dockerfile => tree_sitter_dockerfile::language(), - Html => tree_sitter_html::language(), - TreeSitterQuery => tree_sitter_query::language(), - Json => tree_sitter_json::language(), - Css => tree_sitter_css::language(), - Lua => tree_sitter_lua::language(), - ) -}); +impl SupportedLanguageLanguage { + pub fn language(&self) -> Language { + SUPPORTED_LANGUAGE_LANGUAGE_LANGUAGES[*self] + } + + pub fn supported_language(&self) -> SupportedLanguage { + SUPPORTED_LANGUAGE_LANGUAGE_SUPPORTED_LANGUAGES[*self] + } +} + +enum SingleLanguageOrLanguageFromPath { + SingleLanguage(SupportedLanguageLanguage), + LanguageFromPath(Box), +} + +impl From for SingleLanguageOrLanguageFromPath { + fn from(value: SupportedLanguageLanguage) -> Self { + Self::SingleLanguage(value) + } +} + +trait LanguageFromPath: Send + Sync { + #[allow(clippy::wrong_self_convention)] + fn from_path(&self, path: Option<&Path>) -> SupportedLanguageLanguage; +} + +struct TypescriptLanguageFromPath; + +impl LanguageFromPath for TypescriptLanguageFromPath { + fn from_path(&self, path: Option<&Path>) -> SupportedLanguageLanguage { + match path.and_then(|path| path.extension()) { + Some(extension) if "tsx" == extension => SupportedLanguageLanguage::Tsx, + _ => SupportedLanguageLanguage::Typescript, + } + } +} + +static SUPPORTED_LANGUAGE_LANGUAGES: Lazy> = + Lazy::new(|| { + by_supported_language!( + Rust => SupportedLanguageLanguage::Rust.into(), + Typescript => SingleLanguageOrLanguageFromPath::LanguageFromPath(Box::new(TypescriptLanguageFromPath)), + Javascript => SupportedLanguageLanguage::Javascript.into(), + Swift => SupportedLanguageLanguage::Swift.into(), + ObjectiveC => SupportedLanguageLanguage::ObjectiveC.into(), + Toml => SupportedLanguageLanguage::Toml.into(), + Python => SupportedLanguageLanguage::Python.into(), + Ruby => SupportedLanguageLanguage::Ruby.into(), + C => SupportedLanguageLanguage::C.into(), + Cpp => SupportedLanguageLanguage::Cpp.into(), + Go => SupportedLanguageLanguage::Go.into(), + Java => SupportedLanguageLanguage::Java.into(), + CSharp => SupportedLanguageLanguage::CSharp.into(), + Kotlin => SupportedLanguageLanguage::Kotlin.into(), + Elisp => SupportedLanguageLanguage::Elisp.into(), + Elm => SupportedLanguageLanguage::Elm.into(), + Dockerfile => SupportedLanguageLanguage::Dockerfile.into(), + Html => SupportedLanguageLanguage::Html.into(), + TreeSitterQuery => SupportedLanguageLanguage::TreeSitterQuery.into(), + Json => SupportedLanguageLanguage::Json.into(), + Css => SupportedLanguageLanguage::Css.into(), + Lua => SupportedLanguageLanguage::Lua.into(), + ) + }); + +static SUPPORTED_LANGUAGE_ALL_LANGUAGES: Lazy>> = + Lazy::new(|| { + by_supported_language!( + Rust => vec![SupportedLanguageLanguage::Rust], + Typescript => vec![SupportedLanguageLanguage::Tsx, SupportedLanguageLanguage::Typescript], + Javascript => vec![SupportedLanguageLanguage::Javascript], + Swift => vec![SupportedLanguageLanguage::Swift], + ObjectiveC => vec![SupportedLanguageLanguage::ObjectiveC], + Toml => vec![SupportedLanguageLanguage::Toml], + Python => vec![SupportedLanguageLanguage::Python], + Ruby => vec![SupportedLanguageLanguage::Ruby], + C => vec![SupportedLanguageLanguage::C], + Cpp => vec![SupportedLanguageLanguage::Cpp], + Go => vec![SupportedLanguageLanguage::Go], + Java => vec![SupportedLanguageLanguage::Java], + CSharp => vec![SupportedLanguageLanguage::CSharp], + Kotlin => vec![SupportedLanguageLanguage::Kotlin], + Elisp => vec![SupportedLanguageLanguage::Elisp], + Elm => vec![SupportedLanguageLanguage::Elm], + Dockerfile => vec![SupportedLanguageLanguage::Dockerfile], + Html => vec![SupportedLanguageLanguage::Html], + TreeSitterQuery => vec![SupportedLanguageLanguage::TreeSitterQuery], + Json => vec![SupportedLanguageLanguage::Json], + Css => vec![SupportedLanguageLanguage::Css], + Lua => vec![SupportedLanguageLanguage::Lua], + ) + }); + +static SUPPORTED_LANGUAGE_LANGUAGE_LANGUAGES: Lazy> = + Lazy::new(|| { + by_supported_language_language!( + Rust => tree_sitter_rust::language(), + Typescript => tree_sitter_typescript::language_typescript(), + Tsx => tree_sitter_typescript::language_tsx(), + Javascript => tree_sitter_javascript::language(), + Swift => tree_sitter_swift::language(), + ObjectiveC => tree_sitter_objc::language(), + Toml => tree_sitter_toml::language(), + Python => tree_sitter_python::language(), + Ruby => tree_sitter_ruby::language(), + C => tree_sitter_c::language(), + Cpp => tree_sitter_cpp::language(), + Go => tree_sitter_go::language(), + Java => tree_sitter_java::language(), + CSharp => tree_sitter_c_sharp::language(), + Kotlin => tree_sitter_kotlin::language(), + Elisp => tree_sitter_elisp::language(), + Elm => tree_sitter_elm::language(), + Dockerfile => tree_sitter_dockerfile::language(), + Html => tree_sitter_html::language(), + TreeSitterQuery => tree_sitter_query::language(), + Json => tree_sitter_json::language(), + Css => tree_sitter_css::language(), + Lua => tree_sitter_lua::language(), + ) + }); + + +static SUPPORTED_LANGUAGE_LANGUAGE_SUPPORTED_LANGUAGES: Lazy> = + Lazy::new(|| { + by_supported_language_language!( + Rust => SupportedLanguage::Rust, + Typescript => SupportedLanguage::Typescript, + Tsx => SupportedLanguage::Typescript, + Javascript => SupportedLanguage::Javascript, + Swift => SupportedLanguage::Swift, + ObjectiveC => SupportedLanguage::ObjectiveC, + Toml => SupportedLanguage::Toml, + Python => SupportedLanguage::Python, + Ruby => SupportedLanguage::Ruby, + C => SupportedLanguage::C, + Cpp => SupportedLanguage::Cpp, + Go => SupportedLanguage::Go, + Java => SupportedLanguage::Java, + CSharp => SupportedLanguage::CSharp, + Kotlin => SupportedLanguage::Kotlin, + Elisp => SupportedLanguage::Elisp, + Elm => SupportedLanguage::Elm, + Dockerfile => SupportedLanguage::Dockerfile, + Html => SupportedLanguage::Html, + TreeSitterQuery => SupportedLanguage::TreeSitterQuery, + Json => SupportedLanguage::Json, + Css => SupportedLanguage::Css, + Lua => SupportedLanguage::Lua, + ) + }); static SUPPORTED_LANGUAGE_NAMES_FOR_IGNORE_SELECT: BySupportedLanguage<&'static str> = by_supported_language!( Rust => "rust", @@ -117,3 +296,56 @@ pub static ALL_SUPPORTED_LANGUAGES_BY_NAME_FOR_IGNORE_SELECT: Lazy< }) .collect() }); + +static SUPPORTED_LANGUAGE_COMMENT_KINDS: Lazy>> = + Lazy::new(|| { + by_supported_language!( + Rust => ["line_comment", "block_comment"].into(), + Typescript => ["comment"].into(), + Javascript => ["comment"].into(), + Swift => ["comment"].into(), + ObjectiveC => ["comment"].into(), + Toml => ["comment"].into(), + Python => ["comment"].into(), + Ruby => ["comment"].into(), + C => ["comment"].into(), + Cpp => ["comment"].into(), + Go => ["comment"].into(), + Java => ["comment"].into(), + CSharp => ["comment"].into(), + Kotlin => ["comment"].into(), + Elisp => ["comment"].into(), + Elm => ["comment"].into(), + Dockerfile => ["comment"].into(), + Html => ["comment"].into(), + TreeSitterQuery => ["comment"].into(), + Json => ["comment"].into(), + Css => ["comment"].into(), + Lua => ["comment"].into(), + ) + }); + +#[cfg(test)] +mod tests { + use speculoos::prelude::*; + + use super::*; + + #[test] + fn test_supported_language_language_simple() { + assert_that!(&SupportedLanguage::Rust.language(Some("foo.rs".as_ref()))) + .is_equal_to(tree_sitter_rust::language()); + assert_that!(&SupportedLanguage::Rust.language(None)) + .is_equal_to(tree_sitter_rust::language()); + } + + #[test] + fn test_supported_language_language_typescript() { + assert_that!(&SupportedLanguage::Typescript.language(Some("foo.tsx".as_ref()))) + .is_equal_to(tree_sitter_typescript::language_tsx()); + assert_that!(&SupportedLanguage::Typescript.language(Some("foo.ts".as_ref()))) + .is_equal_to(tree_sitter_typescript::language_typescript()); + assert_that!(&SupportedLanguage::Typescript.language(None)) + .is_equal_to(tree_sitter_typescript::language_typescript()); + } +} diff --git a/src/lib.rs b/src/lib.rs index a374ff0..2a2f781 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,11 +9,12 @@ use std::{ }, }; +use args::QueryOrQueryText; use ignore::DirEntry; use rayon::prelude::*; use termcolor::{BufferWriter, ColorChoice}; use thiserror::Error; -use tree_sitter::{Node, Query, QueryError}; +use tree_sitter::{Query, QueryError, QueryMatch, Tree}; mod args; mod language; @@ -31,14 +32,23 @@ mod treesitter; mod use_printer; mod use_searcher; -pub use args::Args; -use language::{BySupportedLanguage, SupportedLanguage}; +pub use args::{Args, ArgsBuilder}; +use language::BySupportedLanguageLanguage; +pub use language::{SupportedLanguage, SupportedLanguageLanguage}; pub use plugin::PluginInitializeReturn; use query_context::QueryContext; use treesitter::maybe_get_query; +pub use treesitter::{ + get_captures, get_captures_for_enclosing_node, get_matches, get_parser, CaptureInfo, Parseable, + RopeOrSlice, +}; use use_printer::get_printer; use use_searcher::get_searcher; +pub extern crate ropey; +pub extern crate streaming_iterator; +pub extern crate tree_sitter; + #[derive(Debug, Error)] pub enum Error { #[error("couldn't read query file {path_to_query_file:?}")] @@ -65,7 +75,7 @@ pub enum Error { } } )] - NoSuccessfulQueryParsing(Vec<(SupportedLanguage, QueryError)>), + NoSuccessfulQueryParsing(Vec<(SupportedLanguageLanguage, QueryError)>), #[error("query must include at least one capture (\"@whatever\")")] NoCaptureInQuery, #[error("invalid capture name '{capture_name}'")] @@ -74,6 +84,8 @@ pub enum Error { FilterPluginExpectedArgument, #[error("plugin couldn't parse argument {filter_arg:?}")] FilterPluginCouldntParseArgument { filter_arg: String }, + #[error("language is required when passing a slice")] + LanguageMissingForSlice, } #[derive(Clone, Debug, Error)] @@ -96,7 +108,7 @@ pub enum NonFatalError { )] AmbiguousLanguageForFile { path: PathBuf, - languages: Vec, + languages: Vec, }, #[error("No files were searched")] NothingSearched, @@ -107,7 +119,7 @@ pub enum NonFatalError { }, } -#[derive(Clone)] +#[derive(Clone, Debug)] enum CaptureIndexError { NoCaptureInQuery, InvalidCaptureName { capture_name: String }, @@ -124,31 +136,6 @@ impl From for Error { } } -#[derive(Default)] -struct CaptureIndex(OnceLock>); - -impl CaptureIndex { - pub fn get_or_init( - &self, - query: &Query, - capture_name: Option<&str>, - ) -> Result { - self.0 - .get_or_init(|| match capture_name { - None => match query.capture_names().len() { - 0 => Err(CaptureIndexError::NoCaptureInQuery), - _ => Ok(0), - }, - Some(capture_name) => query.capture_index_for_name(capture_name).ok_or_else(|| { - CaptureIndexError::InvalidCaptureName { - capture_name: capture_name.to_owned(), - } - }), - }) - .clone() - } -} - fn join_with_or(list: &[TItem]) -> String { let mut ret: String = Default::default(); for (index, item) in list.iter().enumerate() { @@ -162,17 +149,72 @@ fn join_with_or(list: &[TItem]) -> String { ret } +type CaptureIndex = u32; + +#[derive(Debug)] +enum QueryOrCaptureIndexError { + QueryError(QueryError), + CaptureIndexError(CaptureIndexError), +} + +impl From for QueryOrCaptureIndexError { + fn from(value: QueryError) -> Self { + Self::QueryError(value) + } +} + +impl From for QueryOrCaptureIndexError { + fn from(value: CaptureIndexError) -> Self { + Self::CaptureIndexError(value) + } +} + +#[allow(clippy::type_complexity)] #[derive(Default)] -struct CachedQueries(BySupportedLanguage, QueryError>>>); +struct CachedQueries( + BySupportedLanguageLanguage< + OnceLock, CaptureIndex), QueryOrCaptureIndexError>>, + >, +); impl CachedQueries { - fn get_and_cache_query_for_language( + fn get_and_cache_query_for_language<'a>( &self, - query_text: &str, - language: SupportedLanguage, - ) -> Option> { - self.0[language] - .get_or_init(|| maybe_get_query(query_text, language.language()).map(Arc::new)) + query_or_query_text: impl Into>, + supported_language_language: SupportedLanguageLanguage, + capture_name: Option<&str>, + ) -> Option<(Arc, CaptureIndex)> { + let query_or_query_text = query_or_query_text.into(); + self.0[supported_language_language] + .get_or_init(|| { + match query_or_query_text { + QueryOrQueryText::QueryText(query_text) => { + maybe_get_query(query_text, supported_language_language.language()) + .map(Arc::new) + .map_err(Into::into) + } + QueryOrQueryText::Query(query) => Ok(query), + } + .and_then( + |query| -> Result<(Arc, CaptureIndex), QueryOrCaptureIndexError> { + match capture_name { + None => match query.capture_names().len() { + 0 => Err(CaptureIndexError::NoCaptureInQuery.into()), + _ => Ok(0), + }, + Some(capture_name) => { + query.capture_index_for_name(capture_name).ok_or_else(|| { + CaptureIndexError::InvalidCaptureName { + capture_name: capture_name.to_owned(), + } + .into() + }) + } + } + .map(|capture_index| (query, capture_index)) + }, + ) + }) .as_ref() .ok() .cloned() @@ -200,13 +242,44 @@ impl CachedQueries { !attempted_parsings.is_empty(), "Should've tried to parse in at least one language or else should've already failed on no candidate files" ); - return Err(Error::NoSuccessfulQueryParsing(attempted_parsings)); + if let Some((_, capture_index_error)) = + attempted_parsings + .iter() + .find(|(_, query_or_capture_index_error)| { + matches!( + query_or_capture_index_error, + QueryOrCaptureIndexError::CaptureIndexError(_) + ) + }) + { + match capture_index_error { + QueryOrCaptureIndexError::CaptureIndexError(capture_index_error) => { + return Err(capture_index_error.clone().into()) + } + _ => unreachable!(), + } + } + return Err(Error::NoSuccessfulQueryParsing( + attempted_parsings + .into_iter() + .map(|(language, query_or_capture_index_error)| { + ( + language, + match query_or_capture_index_error { + QueryOrCaptureIndexError::QueryError(query_error) => query_error, + _ => unreachable!(), + }, + ) + }) + .collect(), + )); } Ok(()) } } +#[derive(Debug)] pub struct RunStatus { pub matched: bool, pub non_fatal_errors: Vec, @@ -292,7 +365,7 @@ pub fn run_print(args: Args) -> Result { pub fn run_with_callback( args: Args, - callback: impl Fn(Node, &[u8], &Path) + Sync, + callback: impl Fn(&QueryMatch, &[u8], &Path) + Sync, ) -> Result { run_for_context( args, @@ -307,8 +380,8 @@ pub fn run_with_callback( .search_path_callback::<_, io::Error>( query_context, path, - |node: Node, file_contents: &[u8], path: &Path| { - callback(node, file_contents, path); + |query_match: &QueryMatch, file_contents: &[u8], path: &Path| { + callback(query_match, file_contents, path); matched.store(true, Ordering::SeqCst); }, ) @@ -322,10 +395,9 @@ fn run_for_context( context: TContext, search_file: impl Fn(&TContext, &Args, &Path, QueryContext, &AtomicBool) + Sync, ) -> Result { - let query_text = args.get_loaded_query_text()?; + let query_text_per_language = args.get_loaded_query_text_per_language()?; let filter = args.get_loaded_filter()?; let cached_queries: CachedQueries = Default::default(); - let capture_index = CaptureIndex::default(); let matched = AtomicBool::new(false); let searched = AtomicBool::new(false); let non_fatal_errors: Arc>> = Default::default(); @@ -335,32 +407,42 @@ fn run_for_context( non_fatal_errors.clone(), |project_file_dir_entry, matched_languages| { searched.store(true, Ordering::SeqCst); - let language = match args.language { + let path = project_file_dir_entry.path(); + let supported_language_language = match args.language { Some(specified_language) => { if !matched_languages.contains(&specified_language) { return NonFatalError::ExplicitPathArgumentNotOfSpecifiedType { - path: project_file_dir_entry.path().to_owned(), + path: path.to_owned(), specified_language, } .into(); } - specified_language + specified_language.supported_language_language(Some(path)) } None => match matched_languages.len() { 0 => { return NonFatalError::ExplicitPathArgumentNotOfKnownType { - path: project_file_dir_entry.path().to_owned(), + path: path.to_owned(), } .into(); } - 1 => matched_languages[0], + 1 => matched_languages[0].supported_language_language(Some(path)), _ => { let successfully_parsed_query_languages = matched_languages .iter() .filter_map(|&matched_language| { + let matched_supported_language_language = + matched_language.supported_language_language(Some(path)); cached_queries - .get_and_cache_query_for_language(&query_text, matched_language) - .map(|_| matched_language) + .get_and_cache_query_for_language( + query_text_per_language + .get_query_or_query_text_for_language( + matched_supported_language_language, + ), + matched_supported_language_language, + args.capture_name.as_deref(), + ) + .map(|_| matched_supported_language_language) }) .collect::>(); match successfully_parsed_query_languages.len() { @@ -370,7 +452,7 @@ fn run_for_context( 1 => successfully_parsed_query_languages[0], _ => { return NonFatalError::AmbiguousLanguageForFile { - path: project_file_dir_entry.path().to_owned(), + path: path.to_owned(), languages: successfully_parsed_query_languages, } .into(); @@ -379,19 +461,24 @@ fn run_for_context( } }, }; - let query = match cached_queries.get_and_cache_query_for_language(&query_text, language) - { + let (query, capture_index) = match cached_queries.get_and_cache_query_for_language( + query_text_per_language.get_query_or_query_text_for_language(supported_language_language), + supported_language_language, + args.capture_name.as_deref(), + ) { Some(query) => query, None => return Ok(SingleFileSearchNonFailure::QueryNotParseableForFile), }; - let capture_index = capture_index.get_or_init(&query, args.capture_name.as_deref())?; - let path = - format_relative_path(project_file_dir_entry.path(), args.is_using_default_paths()); + let relative_path = format_relative_path(path, args.is_using_default_paths()); - let query_context = - QueryContext::new(query, capture_index, language.language(), filter.clone()); + let query_context = QueryContext::new( + query, + capture_index, + supported_language_language.language(), + filter.clone(), + ); - search_file(&context, &args, path, query_context, &matched); + search_file(&context, &args, relative_path, query_context, &matched); Ok(SingleFileSearchNonFailure::RanQuery) }, @@ -412,6 +499,118 @@ fn run_for_context( }) } +pub fn run_with_single_per_file_callback( + args: Args, + per_file_callback: impl Fn(&DirEntry, SupportedLanguageLanguage, &[u8], &Tree, &Arc) + Sync, +) -> Result { + let query_text_per_language = args.get_loaded_query_text_per_language()?; + let filter = args.get_loaded_filter()?; + let cached_queries: CachedQueries = Default::default(); + let non_fatal_errors: Arc>> = Default::default(); + + for_each_project_file( + &args, + non_fatal_errors.clone(), + |project_file_dir_entry, matched_languages| { + let path = project_file_dir_entry.path(); + let supported_language_language = match args.language { + Some(specified_language) => { + if !matched_languages.contains(&specified_language) { + return NonFatalError::ExplicitPathArgumentNotOfSpecifiedType { + path: path.to_owned(), + specified_language, + } + .into(); + } + specified_language.supported_language_language(Some(path)) + } + None => match matched_languages.len() { + 0 => { + return NonFatalError::ExplicitPathArgumentNotOfKnownType { + path: path.to_owned(), + } + .into(); + } + 1 => matched_languages[0].supported_language_language(Some(path)), + _ => { + let successfully_parsed_query_languages = matched_languages + .iter() + .filter_map(|&matched_language| { + let matched_supported_language_language = matched_language.supported_language_language(Some(path)); + cached_queries + .get_and_cache_query_for_language( + query_text_per_language + .get_query_or_query_text_for_language(matched_supported_language_language), + matched_supported_language_language, + args.capture_name.as_deref(), + ) + .map(|_| matched_supported_language_language) + }) + .collect::>(); + match successfully_parsed_query_languages.len() { + 0 => { + return Ok(SingleFileSearchNonFailure::QueryNotParseableForFile); + } + 1 => successfully_parsed_query_languages[0], + _ => { + return NonFatalError::AmbiguousLanguageForFile { + path: path.to_owned(), + languages: successfully_parsed_query_languages, + } + .into(); + } + } + } + }, + }; + let (query, capture_index) = match cached_queries.get_and_cache_query_for_language( + query_text_per_language.get_query_or_query_text_for_language(supported_language_language), + supported_language_language, + args.capture_name.as_deref(), + ) { + Some(query) => query, + None => return Ok(SingleFileSearchNonFailure::QueryNotParseableForFile), + }; + let relative_path = format_relative_path(path, args.is_using_default_paths()); + + let query_context = QueryContext::new( + query, + capture_index, + supported_language_language.language(), + filter.clone(), + ); + + let searcher = get_searcher(&args); + let mut searcher = searcher.borrow_mut(); + let file_contents = searcher + .load_file_contents::<_, io::Error>(relative_path) + .unwrap(); + let tree = (&*file_contents) + .parse(&mut get_parser(supported_language_language.language()), None) + .unwrap(); + per_file_callback( + &project_file_dir_entry, + supported_language_language, + &file_contents, + &tree, + &query_context.query, + ); + + Ok(SingleFileSearchNonFailure::RanQuery) + }, + )?; + + let non_fatal_errors = non_fatal_errors.lock().unwrap().clone(); + if non_fatal_errors.is_empty() { + cached_queries.error_if_no_successful_query_parsing()?; + } + + Ok(RunStatus { + matched: false, + non_fatal_errors, + }) +} + fn for_each_project_file( args: &Args, non_fatal_errors: Arc>>, diff --git a/src/project_file_walker.rs b/src/project_file_walker.rs index 1294f10..79289fa 100644 --- a/src/project_file_walker.rs +++ b/src/project_file_walker.rs @@ -95,11 +95,15 @@ impl Iterator for WalkParallelIterator { } } -pub(crate) fn get_project_file_walker_types(language: Option) -> Types { +pub(crate) fn get_project_file_walker_types( + languages: Option>, +) -> Types { let mut types_builder = TypesBuilder::new(); types_builder.add_defaults(); - if let Some(language) = language { - types_builder.select(language.name_for_ignore_select()); + if let Some(languages) = languages { + for language in languages { + types_builder.select(language.name_for_ignore_select()); + } } else { for language in ALL_SUPPORTED_LANGUAGES.values() { types_builder.select(language.name_for_ignore_select()); diff --git a/src/query_context.rs b/src/query_context.rs index d788956..c79ff0b 100644 --- a/src/query_context.rs +++ b/src/query_context.rs @@ -4,6 +4,7 @@ use tree_sitter::{Language, Query}; use crate::plugin::Filterer; +#[derive(Clone)] pub struct QueryContext { pub query: Arc, pub capture_index: u32, diff --git a/src/searcher/glue.rs b/src/searcher/glue.rs index 3b0c7f6..1a1eba1 100644 --- a/src/searcher/glue.rs +++ b/src/searcher/glue.rs @@ -1,13 +1,14 @@ // derived from https://github.com/BurntSushi/ripgrep/blob/master/crates/searcher/src/searcher/glue.rs -use tree_sitter::{Node, QueryCursor}; +use streaming_iterator::StreamingIterator; use crate::{ lines::{self, LineStep}, query_context::QueryContext, searcher::{core::Core, Config, Range, Searcher}, sink::Sink, - treesitter::get_parser, + treesitter::get_captures, + CaptureInfo, }; #[derive(Debug, Default)] @@ -75,33 +76,17 @@ impl<'s, S: Sink> MultiLine<'s, S> { pub fn run(mut self) -> Result<(), S::Error> { if self.core.begin()? { let mut keepgoing = true; - let mut query_cursor = QueryCursor::new(); - let tree = get_parser(self.core.query_context().language) - .parse(self.slice, None) - .unwrap(); - let query = self.core.query_context().query.clone(); - let capture_index = self.core.query_context().capture_index; - let filter = self.core.query_context().filter.clone(); - let mut matches = query_cursor - .captures(&query, tree.root_node(), self.slice) - .filter_map(|(match_, found_capture_index)| { - let found_capture_index = found_capture_index as u32; - if found_capture_index != capture_index { - return None; - } - let mut nodes_for_this_capture = match_.nodes_for_capture_index(capture_index); - let single_captured_node = nodes_for_this_capture.next().unwrap(); - assert!( - nodes_for_this_capture.next().is_none(), - "I guess .captures() always wraps up the single capture like this?" - ); - match filter.as_ref() { - None => Some(single_captured_node), - Some(filter) => filter - .call(&single_captured_node) - .then_some(single_captured_node), - } - }); + let query_context = self.core.query_context(); + let query = query_context.query.clone(); + let filter = query_context.filter.clone(); + let mut matches = get_captures( + query_context.language, + self.slice, + &query, + query_context.capture_index, + filter.as_deref(), + None, + ); while !self.slice[self.core.pos()..].is_empty() && keepgoing { keepgoing = self.sink(&mut matches)?; } @@ -132,7 +117,7 @@ impl<'s, S: Sink> MultiLine<'s, S> { fn sink<'tree>( &mut self, - matches: &mut impl Iterator>, + matches: &mut impl StreamingIterator>, ) -> Result { if self.config.invert_match { return self.sink_matched_inverted(matches); @@ -173,7 +158,7 @@ impl<'s, S: Sink> MultiLine<'s, S> { fn sink_matched_inverted<'tree>( &mut self, - matches: &mut impl Iterator>, + matches: &mut impl StreamingIterator>, ) -> Result { assert!(self.config.invert_match); @@ -241,9 +226,12 @@ impl<'s, S: Sink> MultiLine<'s, S> { fn find<'tree>( &mut self, - matches: &mut impl Iterator>, + matches: &mut impl StreamingIterator>, ) -> Result, S::Error> { - Ok(matches.next().as_ref().map(Into::into)) + Ok(matches + .next() + .as_ref() + .map(|capture_info| (&capture_info.node).into())) } fn advance(&mut self, range: &Range) { diff --git a/src/searcher/mod.rs b/src/searcher/mod.rs index 1da30af..a1a5f29 100644 --- a/src/searcher/mod.rs +++ b/src/searcher/mod.rs @@ -1,24 +1,28 @@ // derived from https://github.com/BurntSushi/ripgrep/blob/master/crates/searcher/src/searcher/mod.rs use std::{ - cell::RefCell, + cell::{Ref, RefCell}, cmp, fmt, fs::File, io::{self, Read}, + ops, path::Path, }; use encoding_rs_io::DecodeReaderBytesBuilder; -use tree_sitter::{Node, QueryCursor}; +use memmap::Mmap; +use streaming_iterator::StreamingIterator; +use tree_sitter::{QueryMatch, Tree}; pub use self::mmap::MmapChoice; use crate::{ + get_matches, line_buffer::{alloc_error, DEFAULT_BUFFER_CAPACITY}, matcher::{LineTerminator, Match}, query_context::QueryContext, searcher::glue::MultiLine, sink::{Sink, SinkError}, - treesitter::get_parser, + RopeOrSlice, }; mod core; @@ -214,11 +218,30 @@ impl Searcher { self.search_file_maybe_path(query_context, Some(path), &file, write_to) } + pub fn load_file_contents( + &mut self, + path: P, + ) -> Result + where + P: AsRef, + { + let path = path.as_ref(); + let file = File::open(path).map_err(TError::error_io)?; + + if let Some(mmap) = self.config.mmap.open(&file, Some(path)) { + return Ok(mmap.into()); + } + + self.fill_multi_line_buffer_from_file(&file) + .map_err(TError::error_io)?; + return Ok(self.multi_line_buffer.borrow().into()); + } + pub fn search_path_callback( &mut self, query_context: QueryContext, path: P, - callback: impl Fn(Node, &[u8], &Path), + callback: impl FnMut(&QueryMatch, &[u8], &Path), ) -> Result<(), TError> where P: AsRef, @@ -338,7 +361,7 @@ impl Searcher { &mut self, query_context: QueryContext, slice: &[u8], - callback: impl Fn(Node, &[u8], &Path), + callback: impl FnMut(&QueryMatch, &[u8], &Path), path: &Path, ) -> Result<(), ConfigError> { self.check_config()?; @@ -349,43 +372,36 @@ impl Searcher { Ok(()) } + pub fn search_slice_callback_no_path<'a, 'text, 'tree>( + &mut self, + query_context: QueryContext, + // slice: impl TextProvider<'a> + Parseable + 'a, + slice: impl Into>, + tree: Option<&'tree Tree>, + mut callback: impl FnMut(&QueryMatch), + ) -> Result<(), ConfigError> { + self.check_config()?; + + log::trace!("slice reader: searching via multiline strategy"); + get_matches(query_context.language, slice, &query_context.query, tree).for_each( + |query_match| { + callback(query_match); + }, + ); + + Ok(()) + } + fn run_with_callback( &self, query_context: QueryContext, slice: &[u8], - callback: impl Fn(Node, &[u8], &Path), + mut callback: impl FnMut(&QueryMatch, &[u8], &Path), path: &Path, ) { - let mut query_cursor = QueryCursor::new(); - let tree = get_parser(query_context.language) - .parse(slice, None) - .unwrap(); - let query = &query_context.query; - let capture_index = query_context.capture_index; - let filter = &query_context.filter; - query_cursor - .captures(query, tree.root_node(), slice) - .filter_map(|(match_, found_capture_index)| { - let found_capture_index = found_capture_index as u32; - if found_capture_index != capture_index { - return None; - } - let mut nodes_for_this_capture = match_.nodes_for_capture_index(capture_index); - let single_captured_node = nodes_for_this_capture.next().unwrap(); - assert!( - nodes_for_this_capture.next().is_none(), - "I guess .captures() always wraps up the single capture like this?" - ); - match filter.as_ref() { - None => Some(single_captured_node), - Some(filter) => filter - .call(&single_captured_node) - .then_some(single_captured_node), - } - }) - .for_each(|node| { - callback(node, slice, path); - }); + get_matches(query_context.language, slice, &query_context.query, None).for_each(|match_| { + callback(match_, slice, path); + }); } fn check_config(&self) -> Result<(), ConfigError> { @@ -491,3 +507,31 @@ impl Searcher { } } } + +pub enum MmapOrRefByteVec<'a> { + Mmap(Mmap), + RefByteVec(Ref<'a, Vec>), +} + +impl<'a> From for MmapOrRefByteVec<'a> { + fn from(value: Mmap) -> Self { + Self::Mmap(value) + } +} + +impl<'a> From>> for MmapOrRefByteVec<'a> { + fn from(value: Ref<'a, Vec>) -> Self { + Self::RefByteVec(value) + } +} + +impl<'a> ops::Deref for MmapOrRefByteVec<'a> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + match self { + Self::Mmap(value) => value, + Self::RefByteVec(value) => value, + } + } +} diff --git a/src/treesitter.rs b/src/treesitter.rs index 5575a04..97c15d7 100644 --- a/src/treesitter.rs +++ b/src/treesitter.rs @@ -1,8 +1,18 @@ -use tree_sitter::{Language, Node, Parser, Query, QueryError}; +#![allow(clippy::too_many_arguments)] -use crate::matcher::Match; +use std::{borrow::Cow, fmt, iter, mem}; -pub(crate) fn get_parser(language: Language) -> Parser { +use ouroboros::self_referencing; +use ropey::{iter::Chunks, Rope, RopeSlice}; +use streaming_iterator::StreamingIterator; +use tree_sitter::{ + Language, Node, Parser, Query, QueryCaptures, QueryCursor, QueryError, QueryMatch, + QueryMatches, TextProvider, Tree, +}; + +use crate::{matcher::Match, plugin::Filterer}; + +pub fn get_parser(language: Language) -> Parser { let mut parser = Parser::new(); parser .set_language(language) @@ -21,3 +31,388 @@ impl From<&'_ Node<'_>> for Match { Self::new(range.start_byte, range.end_byte) } } + +pub trait Parseable { + fn parse(&self, parser: &mut Parser, old_tree: Option<&Tree>) -> Option; +} + +impl<'a> Parseable for &'a [u8] { + fn parse(&self, parser: &mut Parser, old_tree: Option<&Tree>) -> Option { + parser.parse(self, old_tree) + } +} + +impl<'a> Parseable for &'a Rope { + fn parse(&self, parser: &mut Parser, old_tree: Option<&Tree>) -> Option { + parser.parse_with( + &mut |byte_offset, _| { + let (chunk, chunk_start_byte_index, _, _) = self.chunk_at_byte(byte_offset); + &chunk[byte_offset - chunk_start_byte_index..] + }, + old_tree, + ) + } +} + +#[derive(Copy, Clone)] +pub enum RopeOrSlice<'a> { + Slice(&'a [u8]), + Rope(&'a Rope), +} + +impl<'a> TextProvider<&'a [u8]> for RopeOrSlice<'a> { + type I = RopeOrSliceTextProviderIterator<'a>; + + fn text(&mut self, node: Node) -> Self::I { + match self { + Self::Slice(slice) => { + RopeOrSliceTextProviderIterator::Slice(iter::once(&slice[node.byte_range()])) + } + Self::Rope(rope) => { + let rope_slice = rope.byte_slice(node.byte_range()); + RopeOrSliceTextProviderIterator::Rope(RopeOrSliceRopeTextProviderIterator::new( + rope_slice, + |rope_slice| rope_slice.chunks(), + )) + } + } + } +} + +impl<'a> TextProvider<&'a [u8]> for &'a RopeOrSlice<'a> { + type I = RopeOrSliceTextProviderIterator<'a>; + + fn text(&mut self, node: Node) -> Self::I { + match self { + RopeOrSlice::Slice(slice) => { + RopeOrSliceTextProviderIterator::Slice(iter::once(&slice[node.byte_range()])) + } + RopeOrSlice::Rope(rope) => { + let rope_slice = rope.byte_slice(node.byte_range()); + RopeOrSliceTextProviderIterator::Rope(RopeOrSliceRopeTextProviderIterator::new( + rope_slice, + |rope_slice| rope_slice.chunks(), + )) + } + } + } +} + +impl<'a> Parseable for RopeOrSlice<'a> { + fn parse(&self, parser: &mut Parser, old_tree: Option<&Tree>) -> Option { + match self { + Self::Slice(slice) => slice.parse(parser, old_tree), + Self::Rope(rope) => rope.parse(parser, old_tree), + } + } +} + +impl<'a> Parseable for &'a RopeOrSlice<'a> { + fn parse(&self, parser: &mut Parser, old_tree: Option<&Tree>) -> Option { + match self { + RopeOrSlice::Slice(slice) => slice.parse(parser, old_tree), + RopeOrSlice::Rope(rope) => rope.parse(parser, old_tree), + } + } +} + +impl<'a> fmt::Debug for RopeOrSlice<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Slice(arg0) => f + .debug_tuple("Slice") + .field(&std::str::from_utf8(arg0)) + .finish(), + Self::Rope(arg0) => f.debug_tuple("Rope").field(arg0).finish(), + } + } +} + +impl<'a> From<&'a [u8]> for RopeOrSlice<'a> { + fn from(value: &'a [u8]) -> Self { + Self::Slice(value) + } +} + +impl<'a> From<&'a Rope> for RopeOrSlice<'a> { + fn from(value: &'a Rope) -> Self { + Self::Rope(value) + } +} + +impl<'a> From<&'a str> for RopeOrSlice<'a> { + fn from(value: &'a str) -> Self { + Self::Slice(value.as_bytes()) + } +} + +impl<'a> From> for String { + fn from(value: RopeOrSlice<'a>) -> Self { + match value { + // TODO: should this use TryFrom instead to expose + // this fallibility? + RopeOrSlice::Slice(value) => std::str::from_utf8(value).unwrap().to_owned(), + RopeOrSlice::Rope(value) => value.into(), + } + } +} + +pub enum RopeOrSliceTextProviderIterator<'a> { + Slice(iter::Once<&'a [u8]>), + Rope(RopeOrSliceRopeTextProviderIterator<'a>), +} + +impl<'a> Iterator for RopeOrSliceTextProviderIterator<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + match self { + Self::Slice(slice_iterator) => slice_iterator.next(), + Self::Rope(rope_iterator) => rope_iterator.next().map(str::as_bytes), + } + } +} + +#[self_referencing] +pub struct RopeOrSliceRopeTextProviderIterator<'a> { + rope_slice: RopeSlice<'a>, + + #[borrows(rope_slice)] + chunks_iterator: Chunks<'a>, +} + +impl<'a> Iterator for RopeOrSliceRopeTextProviderIterator<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option { + self.with_chunks_iterator_mut(|chunks_iterator| chunks_iterator.next()) + } +} + +// I believe this type can't be Copy/Clone in order for the +// `get_captures()` unsafe stuff to be sound +pub struct CaptureInfo<'a> { + pub node: Node<'a>, + pub pattern_index: usize, +} + +#[self_referencing] +pub struct Captures<'a, 'text: 'a, 'tree: 'a> { + text: RopeOrSlice<'text>, + query_cursor: QueryCursor, + query: &'a Query, + filter: Option<&'a Filterer>, + tree: Cow<'tree, Tree>, + capture_index: u32, + #[borrows(text, mut query_cursor, query, tree)] + #[covariant] + captures_iterator: QueryCaptures<'this, 'this, RopeOrSlice<'this>, &'this [u8]>, + #[borrows(tree)] + #[covariant] + next_capture: Option>, +} + +pub fn get_captures<'a, 'text, 'tree>( + language: Language, + // text: impl TextProvider<'a> + Parseable, + text: impl Into>, + query: &'a Query, + capture_index: u32, + filter: Option<&'a Filterer>, + tree: Option<&'tree Tree>, +) -> Captures<'a, 'text, 'tree> { + let text = text.into(); + let query_cursor = QueryCursor::new(); + let tree: Cow<'tree, Tree> = tree.map_or_else( + || Cow::Owned(text.parse(&mut get_parser(language), None).unwrap()), + Cow::Borrowed, + ); + Captures::new( + text, + query_cursor, + query, + filter, + tree, + capture_index, + |text, query_cursor, query, tree| query_cursor.captures(query, tree.root_node(), *text), + |_| None, + ) +} + +impl<'a, 'text, 'tree> StreamingIterator for Captures<'a, 'text, 'tree> { + type Item = CaptureInfo<'tree>; + + fn advance(&mut self) { + self.with_mut(|all_fields| { + for (match_, index_into_query_match_captures) in all_fields.captures_iterator.by_ref() { + let this_capture = &match_.captures[index_into_query_match_captures]; + if this_capture.index != *all_fields.capture_index { + continue; + } + let single_captured_node = this_capture.node; + if all_fields + .filter + .as_ref() + .map_or(true, |filter| filter.call(&single_captured_node)) + { + *all_fields.next_capture = Some(CaptureInfo { + node: single_captured_node, + pattern_index: match_.pattern_index, + }); + return; + } + } + *all_fields.next_capture = None; + }); + } + + fn get<'this>(&'this self) -> Option<&'this Self::Item> { + let next_capture = self.borrow_next_capture(); + // SAFETY: I think this is ok as long as CaptureInfo isn't + // Copy/Clone? + // Since at that point there's no way for the "inner" + // CaptureInfo's contents to "outlive" the returned reference? + // Did this because otherwise was running into not being able + // to express that the "real" Item type for this trait (I think) + // should be CaptureInfo<'this>, not CaptureInfo<'a> + let next_capture: &'this Option> = + unsafe { mem::transmute(next_capture) }; + next_capture.as_ref() + } +} + +#[self_referencing] +pub struct CapturesForEnclosingNode<'a, 'text: 'a, 'tree: 'a> { + text: RopeOrSlice<'text>, + query_cursor: QueryCursor, + query: &'a Query, + filter: Option<&'a Filterer>, + enclosing_node: Node<'tree>, + capture_index: u32, + #[borrows(text, mut query_cursor, query, enclosing_node)] + #[covariant] + captures_iterator: QueryCaptures<'this, 'this, RopeOrSlice<'this>, &'this [u8]>, + #[borrows(enclosing_node)] + #[covariant] + next_capture: Option>, +} + +pub fn get_captures_for_enclosing_node<'a, 'text, 'tree>( + // text: impl TextProvider<'a> + Parseable, + text: impl Into>, + query: &'a Query, + capture_index: u32, + filter: Option<&'a Filterer>, + enclosing_node: Node<'tree>, +) -> CapturesForEnclosingNode<'a, 'text, 'tree> { + let text = text.into(); + let query_cursor = QueryCursor::new(); + CapturesForEnclosingNode::new( + text, + query_cursor, + query, + filter, + enclosing_node, + capture_index, + |text, query_cursor, query, enclosing_node| { + query_cursor.captures(query, *enclosing_node, *text) + }, + |_| None, + ) +} + +impl<'a, 'text, 'tree> StreamingIterator for CapturesForEnclosingNode<'a, 'text, 'tree> { + type Item = CaptureInfo<'tree>; + + fn advance(&mut self) { + self.with_mut(|all_fields| { + for (match_, index_into_query_match_captures) in all_fields.captures_iterator.by_ref() { + let this_capture = &match_.captures[index_into_query_match_captures]; + if this_capture.index != *all_fields.capture_index { + continue; + } + let single_captured_node = this_capture.node; + if all_fields + .filter + .as_ref() + .map_or(true, |filter| filter.call(&single_captured_node)) + { + *all_fields.next_capture = Some(CaptureInfo { + node: single_captured_node, + pattern_index: match_.pattern_index, + }); + return; + } + } + *all_fields.next_capture = None; + }); + } + + fn get<'this>(&'this self) -> Option<&'this Self::Item> { + let next_capture = self.borrow_next_capture(); + // SAFETY: I think this is ok as long as CaptureInfo isn't + // Copy/Clone? + // Since at that point there's no way for the "inner" + // CaptureInfo's contents to "outlive" the returned reference? + // Did this because otherwise was running into not being able + // to express that the "real" Item type for this trait (I think) + // should be CaptureInfo<'this>, not CaptureInfo<'a> + let next_capture: &'this Option> = + unsafe { mem::transmute(next_capture) }; + next_capture.as_ref() + } +} + +#[self_referencing] +pub struct Matches<'a, 'text: 'a, 'tree: 'a> { + text: RopeOrSlice<'text>, + query_cursor: QueryCursor, + query: &'a Query, + tree: Cow<'tree, Tree>, + #[borrows(text, mut query_cursor, query, tree)] + #[covariant] + matches_iterator: QueryMatches<'this, 'this, RopeOrSlice<'this>, &'this [u8]>, + #[borrows(tree)] + #[covariant] + next_match: Option>, +} + +pub fn get_matches<'a, 'text, 'tree>( + language: Language, + text: impl Into>, + query: &'a Query, + tree: Option<&'tree Tree>, +) -> Matches<'a, 'text, 'tree> { + let text = text.into(); + let query_cursor = QueryCursor::new(); + let tree: Cow<'tree, Tree> = tree.map_or_else( + || Cow::Owned(text.parse(&mut get_parser(language), None).unwrap()), + Cow::Borrowed, + ); + Matches::new( + text, + query_cursor, + query, + tree, + |text, query_cursor, query, tree| query_cursor.matches(query, tree.root_node(), *text), + |_| None, + ) +} + +impl<'a, 'text, 'tree> StreamingIterator for Matches<'a, 'text, 'tree> { + type Item = QueryMatch<'a, 'tree>; + + fn advance(&mut self) { + self.with_mut(|all_fields| { + *all_fields.next_match = all_fields.matches_iterator.next(); + }); + } + + fn get<'this>(&'this self) -> Option<&'this Self::Item> { + let next_match = self.borrow_next_match(); + // SAFETY: Not as sure on this one? + let next_match: &'this Option> = + unsafe { mem::transmute(next_match) }; + next_match.as_ref() + } +} diff --git a/src/use_searcher.rs b/src/use_searcher.rs index 10d824a..ac477b5 100644 --- a/src/use_searcher.rs +++ b/src/use_searcher.rs @@ -1,22 +1,16 @@ -use std::{ - cell::{OnceCell, RefCell}, - ptr, - rc::Rc, -}; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; use crate::{searcher::Searcher, Args}; thread_local! { - static SEARCHER: OnceCell<(Rc>, *const Args)> = Default::default(); + static SEARCHER_PER_ARGS_INSTANCE: RefCell>>> = Default::default(); } pub(crate) fn get_searcher(args: &Args) -> Rc> { - SEARCHER.with(|searcher| { - let (searcher, args_when_initialized) = - searcher.get_or_init(|| (Rc::new(RefCell::new(args.get_searcher())), args)); - assert!( - ptr::eq(*args_when_initialized, args), - "Using multiple instances of args not supported" - ); - searcher.clone() + SEARCHER_PER_ARGS_INSTANCE.with(|searcher_per_args_instance| { + searcher_per_args_instance + .borrow_mut() + .entry(args) + .or_insert_with(|| Rc::new(RefCell::new(args.get_searcher()))) + .clone() }) } diff --git a/tests/fixtures/typescript_project_with_tsx_and_ts/foo.tsx b/tests/fixtures/typescript_project_with_tsx_and_ts/foo.tsx new file mode 100644 index 0000000..4f7bbdd --- /dev/null +++ b/tests/fixtures/typescript_project_with_tsx_and_ts/foo.tsx @@ -0,0 +1 @@ +const a =
whee
; diff --git a/tests/fixtures/typescript_project_with_tsx_and_ts/hello.ts b/tests/fixtures/typescript_project_with_tsx_and_ts/hello.ts new file mode 100644 index 0000000..867813e --- /dev/null +++ b/tests/fixtures/typescript_project_with_tsx_and_ts/hello.ts @@ -0,0 +1 @@ +const x =
3; diff --git a/tests/languages.rs b/tests/languages.rs index 11e305d..772ab59 100644 --- a/tests/languages.rs +++ b/tests/languages.rs @@ -438,3 +438,69 @@ fn test_lua_auto_language() { "#, ); } + +#[test] +fn test_typescript_tsx_specific_query() { + assert_sorted_output( + "typescript_project_with_tsx_and_ts", + r#" + $ tree-sitter-grep -q '(jsx_element) @c' --language typescript + foo.tsx:1:const a =
whee
; + "#, + ); +} + +#[test] +fn test_typescript_tsx_specific_query_auto_language() { + assert_sorted_output( + "typescript_project_with_tsx_and_ts", + r#" + $ tree-sitter-grep -q '(jsx_element) @c' + foo.tsx:1:const a =
whee
; + "#, + ); +} + +#[test] +fn test_typescript_only_should_match_ts_parsing() { + assert_sorted_output( + "typescript_project_with_tsx_and_ts", + r#" + $ tree-sitter-grep -q '(type_assertion) @c' --language typescript + hello.ts:1:const x =
3; + "#, + ); +} + +#[test] +fn test_typescript_only_should_match_ts_parsing_auto_language() { + assert_sorted_output( + "typescript_project_with_tsx_and_ts", + r#" + $ tree-sitter-grep -q '(type_assertion) @c' + hello.ts:1:const x =
3; + "#, + ); +} + +#[test] +fn test_typescript_invalid_query_for_ts_or_tsx() { + assert_failure_output( + "typescript_project_with_tsx_and_ts", + r#" + $ tree-sitter-grep -q '(foo) @c' --language typescript + error: couldn't parse query for Tsx or Typescript + "#, + ); +} + +#[test] +fn test_typescript_invalid_query_for_ts_or_tsx_auto_language() { + assert_failure_output( + "typescript_project_with_tsx_and_ts", + r#" + $ tree-sitter-grep -q '(foo) @c' + error: couldn't parse query for Tsx or Typescript + "#, + ); +} diff --git a/tests/output.rs b/tests/output.rs index ef7b911..d1c1add 100644 --- a/tests/output.rs +++ b/tests/output.rs @@ -997,7 +997,7 @@ fn test_couldnt_parse_more_than_two_candidate_auto_detected_languages() { "mixed_project", r#" $ tree-sitter-grep -q '(function_itemz) @f' - error: couldn't parse query for Javascript, Rust, or Typescript + error: couldn't parse query for Javascript, Rust, or Tsx "#, ); } @@ -1008,7 +1008,7 @@ fn test_couldnt_parse_two_candidate_auto_detected_languages() { "mixed_project", r#" $ tree-sitter-grep -q '(function_itemz) @f' javascript_src/ typescript_src/ - error: couldn't parse query for Javascript or Typescript + error: couldn't parse query for Javascript or Tsx "#, ); }