This commit is contained in:
Richard Feldman
2025-04-15 16:15:21 -04:00
parent 68aadebed9
commit 7dcd76995d

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use collections::IndexMap;
use futures::StreamExt;
use gpui::{App, AsyncApp, Entity, Task};
@@ -28,44 +28,139 @@ pub struct SearchToolInput {
/// To find all Markdown files, use "**/*.md"
/// To find files in a specific directory, use "src/zed.dev/**"
/// </example>
pub path_glob: Option<String>,
pub path: Option<String>,
/// When specified, this filters the output based on the contents of the files or code symbols.
///
/// - If the "output" parameter is "symbols", then this search query be sent to the language server to filter which the code symbols (such as identifiers, types, etc.) will be included in the output.
/// - If the "output" parameter is "text", then this query will be interpreted as a regex, adn only text snippets matching that regex will be included.
/// - If the "output" parameter is "paths", then this query will be interpreted as a regex, adn only files whose text contents match that regex will be included.
/// - If the "output" parameter is "symbols", then this search query will be sent to a language server to filter which code symbols (such as identifiers, types, etc.) will be included in the output.
/// - If the "output" parameter is "text", then this query will be interpreted as a regex, and only text snippets matching that regex will be included.
/// - If the "output" parameter is "paths", then this query will be interpreted as a regex, and only files whose text contents match that regex will be included.
#[serde(default)]
pub query: Option<String>,
/// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
/// Whether the query should match case-sensitively. Defaults to false (case-insensitive).
#[serde(default)]
pub contents_regex_case_sensitive: bool,
pub query_case_sensitive: bool,
/// The desired format for the output.
pub output: Output,
/// Optional starting position for paginated results (0-based).
/// When not provided, starts from the beginning.
/// Optional position (1-based index) to start reading on, if you want to read a subset of the contents.
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
/// For paginated results, this represents the starting item (1-based).
///
/// Defaults to 1.
#[serde(default)]
pub offset: u32,
pub start: Option<u32>,
/// Optional position (1-based index) to end reading on, if you want to read a subset of the contents.
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
/// For paginated results, this represents how many items to include (starting from the start position).
///
/// Defaults to reading until the end of the file or directory, or a reasonable limit for paginated results.
#[serde(default)]
pub end: Option<u32>,
}
/// The format of the results the search tool should produce.
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, Copy, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Output {
/// Output matching file paths only.
/// If no path is specified, this outputs all the paths in the project matching
/// the query (or all paths in the project, if there is no query specified), limited
/// based on start and/or end (if they are specified).
///
/// <example>
/// To list the paths of all Rust files in the project:
/// {
/// "path": "**/*.rs",
/// "output": "paths"
/// }
/// </example>
///
/// <example>
/// To list the paths of all files in the project which contain the string "TODO":
/// {
/// "query": "TODO",
/// "output": "paths"
/// }
/// </example>
Paths,
/// Output matching arbitrary text regions within files, including their line numbers.
/// If no path is specified, this outputs text found in every file in the project
/// matching the query (a query should always be specified when using "output": "text"
/// and no path). If no query is specified, but a path is specified, reads the entire
/// contents of that path.
///
/// Output is always limited based on start and/or end (if they are specified).
///
/// <example>
/// To find all occurrences of "TODO" in all files (including paths and line numbers):
/// {
/// "query": "TODO",
/// "output": "text"
/// }
///
/// To read the first 5 lines of an individual file:
/// {
/// "path": "path/to/file.txt",
/// "output": "text",
/// "end": 5
/// }
///
/// To read all the entries in a directory:
/// {
/// "path": "path/to/directory/",
/// "output": "text"
/// }
/// </example>
Text,
/// Output matching code symbols (such as identifiers, types, etc.) within files, including their line numbers.
/// If no path is specified, outputs symbols found across the entire project.
///
/// <example>
/// To find all functions with "search" in their name:
/// {
/// "query": "search",
/// "output": "symbols"
/// }
/// </example>
Symbols,
/// Output error and warning diagnostics for files matching the `path` glob.
/// If no path is specified, outputs a summary of diagnostics found across the entire project.
/// If query is specified, it is treated as a regex, and only shows individual diagnostics
/// which match that regex.
///
/// <example>
/// To find all diagnostics in Rust files:
/// {
/// "path": "**/*.rs",
/// "output": "diagnostics"
/// }
/// </example>
///
/// <example>
/// To find diagnostics containing the word "unused":
/// {
/// "query": "unused",
/// "output": "diagnostics"
/// }
/// </example>
///
/// <example>
/// To find a summary of all errors and warnings in the project:
/// {
/// "output": "diagnostics"
/// }
/// </example>
Diagnostics,
}
// Different search modes have different pagination limits

/// Maximum number of file paths included in a single page of path results.
const PATHS_RESULTS_PER_PAGE: usize = 50;
/// Page size for text results.
/// NOTE(review): presumably counts matched text regions; its use is not visible here, so confirm.
const TEXT_RESULTS_PER_PAGE: usize = 20;
/// Page size for symbol output, measured in lines according to the name.
/// NOTE(review): its use is not visible here, so confirm where the limit is enforced.
const SYMBOLS_LINES_PER_PAGE: u32 = 1000;
@@ -89,7 +184,7 @@ impl Tool for SearchTool {
IconName::SearchCode
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<SearchToolInput>(format)
}
@@ -97,8 +192,8 @@ impl Tool for SearchTool {
match serde_json::from_value::<SearchToolInput>(input.clone()) {
Ok(input) => {
// Don't show any pattern if not specified
let path_pattern = input.path_glob.as_deref().map(MarkdownString::inline_code);
let case_info = if input.contents_regex_case_sensitive {
let path_pattern = input.path.as_deref().map(MarkdownString::inline_code);
let case_info = if input.query_case_sensitive {
" (case-sensitive)"
} else {
""
@@ -163,29 +258,175 @@ impl Tool for SearchTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<SearchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
) -> ToolResult {
let output = match serde_json::from_value::<SearchToolInput>(input) {
Ok(input) => match input.output {
Output::Paths => {
let todo = todo!(); // TODO process paths
}
Output::Text => text_search(
project,
input.path,
action_log,
input.query,
input.query_case_sensitive,
input.start,
input.end,
cx,
),
Output::Symbols => match &input.path {
Some(path_glob) => {
let todo = todo!(); // TODO go get all the symbols in all those files
}
None => cx.spawn(async move |cx| {
project_symbols(
project,
&input.query.unwrap_or_default(),
input.start,
input.end,
cx,
)
.await
}),
},
Output::Diagnostics => todo!(),
},
Err(err) => Task::ready(Err(anyhow!(err))),
};
match input.output {
Output::Paths => search_paths(input, project, cx),
Output::Text => search_text(input, project, action_log.clone(), cx),
Output::Symbols => search_symbols(input, project, action_log, cx),
ToolResult { output }
}
}
/// Searches file contents in the project and lists the paths of the files that matched.
///
/// Parameters:
/// - `project`: the project whose files are searched.
/// - `path_glob`: optional glob restricting which files are considered.
/// - `action_log`: currently not used by this function.
/// - `query`: treated as a regex first; if it does not compile as a regex, it is retried as
///   a literal text query. A query is required for now (see the TODO below).
/// - `case_sensitive`: whether the query matches case-sensitively.
/// - `start`: 1-based index of the first result to show. Defaults to 1.
/// - `end`: 1-based index of the last result to show (inclusive). When omitted, at most
///   `PATHS_RESULTS_PER_PAGE` results are shown.
///
/// Returns a task resolving to a human-readable listing of the matching paths, sorted so that
/// entries in the same directory are grouped together. The task resolves to an error when the
/// glob or the query is invalid, when `end` precedes `start`, or when no query was given.
///
/// NOTE(review): despite being dispatched for text output, this currently reports matching
/// *paths* only, not the matched text regions.
fn text_search(
    project: Entity<Project>,
    path_glob: Option<String>,
    action_log: Entity<ActionLog>,
    query: Option<String>,
    case_sensitive: bool,
    start: Option<u32>,
    end: Option<u32>,
    cx: &mut App,
) -> Task<Result<String>> {
    const MATCH_WHOLE_WORD: bool = false;
    const INCLUDE_IGNORED: bool = false;

    let files_to_exclude = PathMatcher::default();
    // Borrow the glob (rather than moving it) so it is still available for the messages below.
    let Ok(files_to_include) = PathMatcher::new(path_glob.as_deref()) else {
        return Task::ready(Err(anyhow!(
            "Invalid path glob: {}",
            path_glob.unwrap_or_default()
        )));
    };

    // Translate the 1-based, inclusive `start`/`end` into a 0-based offset and a page size.
    // A `start` of 0 is treated the same as 1.
    let first = start.unwrap_or(1).max(1);
    if let Some(end) = end {
        if end < first {
            return Task::ready(Err(anyhow!(
                "Invalid range: end ({end}) is before start ({first})"
            )));
        }
    }
    let offset = (first - 1) as usize;
    let page_size = end.map_or(PATHS_RESULTS_PER_PAGE, |end| (end - first + 1) as usize);

    let Some(query_str) = query else {
        // TODO don't actually do a search, just filter all the paths.
        // Report an error instead of panicking until that mode is implemented.
        return Task::ready(Err(anyhow!(
            "Searching without a query is not supported yet; please specify a query."
        )));
    };

    // If the query is not a valid regex, assume the model wanted an exact match.
    // The fallback closure only borrows `query_str`, so it can still be reported afterwards.
    let Ok(search_query) = SearchQuery::regex(
        &query_str,
        MATCH_WHOLE_WORD,
        case_sensitive,
        INCLUDE_IGNORED,
        false,
        files_to_include.clone(),
        files_to_exclude.clone(),
        None, // buffers
    )
    .or_else(|_| {
        SearchQuery::text(
            &query_str,
            MATCH_WHOLE_WORD,
            case_sensitive,
            INCLUDE_IGNORED,
            files_to_include,
            files_to_exclude,
            None, // buffers
        )
    }) else {
        return Task::ready(Err(anyhow!("Invalid query regex: {query_str}")));
    };

    let results = project.update(cx, |project, cx| project.search(search_query, cx));

    cx.spawn(async move |cx| {
        futures::pin_mut!(results);

        // Collect the path of every buffer that has at least one matching range.
        let mut filtered_paths = Vec::new();
        while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
            if ranges.is_empty() {
                continue;
            }
            if let Some(path) = buffer.read_with(cx, |buffer, cx| {
                buffer
                    .file()
                    .map(|file| file.full_path(cx).to_string_lossy().to_string())
            })? {
                filtered_paths.push(path);
            }
        }

        if filtered_paths.is_empty() {
            return Ok(match path_glob {
                Some(path_glob) => {
                    format!("No paths in the project had paths matching the glob {path_glob:?} and contents matching {query_str:?}")
                }
                None => {
                    format!("No paths in the project had contents matching {query_str:?}")
                }
            });
        }

        // Sort to group entries in the same directory together
        filtered_paths.sort();

        let total_matches = filtered_paths.len();
        let shown: Vec<_> = filtered_paths
            .into_iter()
            .skip(offset)
            .take(page_size)
            .collect();

        let response = if total_matches > offset.saturating_add(page_size) {
            // More results exist beyond this page; tell the caller how to fetch them.
            format!(
                "Found {} paths matching the content regex. Showing results {}-{} (provide 'start' parameter for more results):\n\n{}",
                total_matches,
                offset + 1,
                offset + shown.len(),
                shown.join("\n")
            )
        } else {
            format!(
                "Found {} paths matching the content regex:\n\n{}",
                total_matches,
                shown.join("\n")
            )
        };

        Ok(response)
    })
}
fn search_paths(
fn output_paths(
input: SearchToolInput,
project: Entity<Project>,
cx: &mut App,
) -> Task<Result<String>> {
// Clone the path_glob to avoid borrowing issues with the async closure
let path_glob_option = input.path_glob.clone();
let path_glob_option = input.path.clone();
let query_option = input.query.clone();
let case_sensitive = input.contents_regex_case_sensitive;
let case_sensitive = input.query_case_sensitive;
// Create the path matcher based on the provided glob pattern or use a matcher that matches everything
let path_matcher = if let Some(glob) = path_glob_option.as_deref() {
@@ -207,7 +448,7 @@ fn search_paths(
};
// If a query regex is provided, create a search query for filtering files by content
let regex_query = if let Some(regex) = &query_option {
let results = if let Some(regex) = &query_option {
match SearchQuery::regex(
regex,
false,
@@ -235,63 +476,7 @@ fn search_paths(
let path_glob_for_error = path_glob_option.clone();
// If we need to filter by content, use the search functionality
if let Some(query) = regex_query {
let results = project.update(cx, |project, cx| project.search(query, cx));
return cx.spawn(async move |cx| {
futures::pin_mut!(results);
let mut filtered_paths = Vec::new();
while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
if !ranges.is_empty() {
if let Some(path) = buffer.read_with(cx, |buffer, cx| {
buffer.file().map(|file| file.full_path(cx).to_string_lossy().to_string())
})? {
filtered_paths.push(path);
}
}
}
if filtered_paths.is_empty() {
return Ok(format!("No paths in the project matched the glob {:?} and content regex {:?}",
path_glob_for_error, query_option));
}
// Sort to group entries in the same directory together
filtered_paths.sort();
let total_matches = filtered_paths.len();
let response = if total_matches > PATHS_RESULTS_PER_PAGE + input.offset as usize {
let paginated_matches: Vec<_> = filtered_paths
.into_iter()
.skip(input.offset as usize)
.take(PATHS_RESULTS_PER_PAGE)
.collect();
format!(
"Found {} paths matching the content regex. Showing results {}-{} (provide 'offset' parameter for more results):\n\n{}",
total_matches,
input.offset + 1,
input.offset as usize + paginated_matches.len(),
paginated_matches.join("\n")
)
} else {
let displayed_matches: Vec<_> = filtered_paths
.into_iter()
.skip(input.offset as usize)
.collect();
format!(
"Found {} paths matching the content regex:\n\n{}",
total_matches,
displayed_matches.join("\n")
)
};
Ok(response)
});
}
if let Some(query) = regex_query {}
// If no content regex, just filter by path glob as before
cx.background_executor().spawn(async move {
@@ -352,7 +537,7 @@ fn search_paths(
})
}
fn search_text(
fn output_text(
input: SearchToolInput,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -365,11 +550,11 @@ fn search_text(
// If no query is provided and path_glob points to a specific file, read the file contents
if input.query.is_none()
&& input.path_glob.as_ref().map_or(false, |glob| {
&& input.path.as_ref().map_or(false, |glob| {
!glob.contains('*') && !glob.contains('?') && !glob.contains('[')
})
{
let file_path = input.path_glob.unwrap();
let file_path = input.path.unwrap();
return cx.spawn(async move |cx| {
let Some(project_path) = project.read_with(cx, |project, cx| {
@@ -408,7 +593,7 @@ fn search_text(
};
// Create a query based on the path glob or use a matcher that matches everything
let path_matcher = if let Some(glob) = input.path_glob.as_deref() {
let path_matcher = if let Some(glob) = input.path.as_deref() {
match PathMatcher::new([glob]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob pattern: {}", err))),
@@ -429,7 +614,7 @@ fn search_text(
let query = match SearchQuery::regex(
dbg!(&search_regex),
false,
input.contents_regex_case_sensitive,
input.query_case_sensitive,
false,
false,
path_matcher,
@@ -529,46 +714,6 @@ fn search_text(
})
}
/// Produces symbol output for the given search input.
///
/// When `path_glob` contains none of the characters `*`, `?` or `[`, it is taken to name a
/// single file and that file's outline is rendered. Otherwise project-wide symbols are
/// requested, filtered by the glob when one was given. An invalid glob yields an error task.
fn search_symbols(
    input: SearchToolInput,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    cx: &mut App,
) -> Task<Result<String>> {
    let has_glob_syntax =
        |pattern: &str| pattern.contains(|c: char| matches!(c, '*' | '?' | '['));

    match input.path_glob.as_deref() {
        // No wildcard characters: assume the pattern is a specific file path.
        Some(literal_path) if !has_glob_syntax(literal_path) => {
            let file_path = literal_path.to_owned();
            cx.spawn(async move |cx| {
                render_file_outline(
                    project,
                    file_path,
                    action_log,
                    input.query,
                    input.offset,
                    cx,
                )
                .await
            })
        }
        // Either a real glob or no path at all: fall back to project-wide symbols.
        maybe_glob => {
            let path_matcher = match maybe_glob
                .map(|glob| PathMatcher::new([glob]))
                .transpose()
            {
                Ok(matcher) => matcher,
                Err(err) => return Task::ready(Err(anyhow!("Invalid glob pattern: {}", err))),
            };

            cx.spawn(async move |cx| {
                project_symbols(project, path_matcher, &input.query.unwrap_or_default(), cx).await
            })
        }
    }
}
async fn render_file_outline(
project: Entity<Project>,
path: String,
@@ -629,26 +774,25 @@ async fn render_file_outline(
async fn project_symbols(
project: Entity<Project>,
path_matcher: Option<PathMatcher>,
query: &str,
start: Option<u32>,
end: Option<u32>,
cx: &mut AsyncApp,
) -> anyhow::Result<String> {
let symbols = project
.update(cx, |project, cx| project.symbols(query, cx))?
.await?;
if symbols.is_empty() {
return Err(anyhow!("No symbols found in project."));
// We report a different error later on if there was a query.
if symbols.is_empty() && query.is_empty() {
return Err(anyhow!(
"The language server found no code symbols in this project."
));
}
let mut symbols_by_path: IndexMap<PathBuf, Vec<&Symbol>> = IndexMap::default();
let mut symbols_by_path: IndexMap<PathBuf, Vec<Symbol>> = IndexMap::default();
for symbol in symbols.iter().filter(|symbol| {
path_matcher
.as_ref()
.map(|matcher| matcher.is_match(&symbol.path.path))
.unwrap_or(true)
}) {
for symbol in symbols {
if let Some(worktree_path) = project.read_with(cx, |project, cx| {
project
.worktree_for_id(symbol.path.worktree_id, cx)
@@ -659,11 +803,20 @@ async fn project_symbols(
}
}
// If no symbols matched the filter, return early
if symbols_by_path.is_empty() {
return Ok("No symbols found matching the criteria.".to_string());
Err(anyhow!(
"The language server found no code symbols in this project when filtering by query {query:?}."
))
} else {
render_symbols_by_path(symbols_by_path, project, cx).await
}
}
async fn render_symbols_by_path(
symbols_by_path: impl IntoIterator<Item = (PathBuf, Vec<Symbol>)>,
project: Entity<Project>,
cx: &mut AsyncApp,
) -> Result<String> {
let mut symbols_rendered: usize = 0;
let mut output = String::new();
let mut lines_shown = 0;